【人工智能概论】 构建神经网络——以用InceptionNet解决MNIST任务为例

news2024/11/14 23:33:03

【人工智能概论】 构建神经网络——以用InceptionNet解决MNIST任务为例

文章目录

  • 【人工智能概论】 构建神经网络——以用InceptionNet解决MNIST任务为例
  • 一. 整体思路
    • 1.1 两条原则
    • 1.2 四个步骤
  • 二. 举例——用InceptionNet解决MNIST任务
    • 2.1 模型简介
    • 2.2 MNIST任务
    • 2.3 完整的程序


一. 整体思路

  • 两条原则,四个步骤。

1.1 两条原则

  1. 从宏观到微观
  2. 把握数据形状

1.2 四个步骤

  1. 准备数据
  2. 构建模型
  3. 确定优化策略
  4. 完善训练与测试代码

二. 举例——用InceptionNet解决MNIST任务

2.1 模型简介

  • InceptionNet的设计思路是通过增加网络宽度来获得更好的模型性能。
  • 其核心在于基本单元Inception结构块,如下图:
    在这里插入图片描述
  • 通过纵向堆叠Inception块构建完整网络。

2.2 MNIST任务

  • MNIST是入门级的机器学习任务;
  • 它是一个手写数字识别的数据集。

2.3 完整的程序

# 调包
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.optim as optim


"""数据准备"""
batch_size = 64

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])

train_dataset = datasets.MNIST(root='./mnist/',train=True,download=True,transform=transform)
train_loader = DataLoader(train_dataset,shuffle=True,batch_size=batch_size)
test_dataset = datasets.MNIST(root='./mnist/',train=False,download=True,transform=transform)
test_loader = DataLoader(test_dataset,shuffle=False,batch_size=batch_size)


"""构建模型"""
# 需要指定输入的通道数

class Inceptiona(nn.Module): 
    def __init__(self,in_channels):
        super(Inceptiona,self).__init__()
        
        self.branch1_1 = nn.Conv2d(in_channels , 16 , kernel_size= 1)
        
        self.branch5_5_1 =nn.Conv2d(in_channels, 16, kernel_size= 1)
        self.branch5_5_2 =nn.Conv2d(16,24,kernel_size=5,padding=2)
        
        self.branch3_3_1 = nn.Conv2d(in_channels, 16,kernel_size=1)
        self.branch3_3_2 = nn.Conv2d(16,24,kernel_size=3,padding=1)
        self.branch3_3_3 = nn.Conv2d(24,24,kernel_size=3,padding=1)
        
        self.branch_pooling = nn.Conv2d(in_channels,24,kernel_size=1)
        
    def forward(self,x):
        x1 = self.branch1_1(x)
        
        x2 = self.branch5_5_1(x)
        x2 = self.branch5_5_2(x2)
        
        x3 = self.branch3_3_1(x)
        x3 = self.branch3_3_2(x3)
        x3 = self.branch3_3_3(x3)
        
        x4 = F.avg_pool2d(x,kernel_size=3,stride = 1, padding=1)
        x4 = self.branch_pooling(x4)
        
        outputs = [x1,x2,x3,x4]
        return torch.cat(outputs,dim=1)      
    
# 构建完整的网络
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.conv1 = nn.Conv2d(1,10,kernel_size=5)
        self.conv2 = nn.Conv2d(88,20,kernel_size=5)
        
        self.incep1 = Inceptiona(in_channels=10)
        self.incep2 = Inceptiona(in_channels=20)
        
        self.mp = nn.MaxPool2d(2)
        self.fc = nn.Linear(1408,10)
        
    def forward(self,x):
        batch_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))
        x = self.incep1(x)
        x = F.relu(self.mp(self.conv2(x)))
        x = self.incep2(x)
        
        x = x.view(batch_size,-1)
        x = self.fc(x)
        return x


"""确定优化策略"""
model = Net()
device = torch.device('cuda:0'if torch.cuda.is_available() else 'cpu')
model.to(device) # 指定设备

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(),lr=0.01,momentum=0.5)



"""完善训练与测试代码"""
def train(epoch):
    running_loss = 0.0
    for batch_index, data in enumerate(train_loader,0):
        inputs, target = data
        # 把数据和模型送到同一个设备上
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs,target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() 
        # 用loss.item不会构建计算图,得到的不是张量,而是标量
        if batch_index % 300 == 299:
            # 每三百组计算一次平均损失
            print('[%d,%5d] loss: %.3f' %(epoch+1,batch_index+1,running_loss/300))
            # 给出的是平均每一轮的损失
            running_loss = 0.0
            
def test():
    correct = 0
    total = 0
    with torch.no_grad(): 
        # 测试的环节不用求梯度
        for data in test_loader:
            images , labels = data
            images , labels = images.to(device),  labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data,dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d,%%'%(100*correct/total))
    return 100*correct/total  # 将测试的准确率返回

# 执行训练
if __name__=='__main__':
    score_best = 0
    for epoch in range(10):
        train(epoch)
        score = test()
        if score > score_best:
            score_best = score
            torch.save(model.state_dict(), "model.pth")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/440057.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【ChatGPT 】ChatGPT Sidebar 实战:自定义 ChatGPT 搜索页面回复模板(示例开发和文员专用模板)

目录 一、前言 二、ChatGPT Sidebar 通用配置 (1)通用配置入口 (2)设置 ① 如何访问 ChatGPT ② 语言 ③ 主题 三、ChatGPT Sidebar 搜索页面 (1)搜索页面入口 (2)设置 …

Node 05-Node.js模块化

Node.js 模块化 介绍 什么是模块化与模块 ? 将一个复杂的程序文件依据一定规则(规范)拆分成多个文件的过程称之为 模块化 其中拆分出的 每个文件就是一个模块,模块的内部数据是私有的,不过模块可以暴露内部数据以便其他模块使用…

【问题解决】Git报错:failed to push some refs to xxxxx

Git报错:failed to push some refs to xxxxx To https://xxxxxxxxxxxx.git ! [rejected] master -> master (fetch first) error: failed to push some refs to ‘https://xxxxxxxx.git’ hint: Updates were rejected because the remote contains work that yo…

AI+明厨亮灶智能算法 yolo

AI明厨亮灶智能算法通过pythonyolo网络模型分析算法,AI明厨亮灶模型算法可接对后厨实现如口罩识别、厨师服穿戴、夜间老鼠监测、厨师帽识别、厨师玩手机打电话识别、抽烟识别等实时分析监测。Python是一种由Guido van Rossum开发的通用编程语言,它很快就…

领导力专题︱聊聊领导力的主要问题

本文内容结构 一、领导力的主要问题:领导者与下属 1、让人敬佩的领导者的能力与技巧 2、下属的期望 (1)热情 (2)重视 (3)欣赏 (4)归属感 3、下属(追随…

微结构MRI参数估计的神经网络:在白质扩散-弛豫模型中的应用

导读 通过使用生物物理模型来解释弛豫-扩散MRI大脑数据,可以研究白质微观结构的具体特征。尽管更复杂的模型有可能揭示组织的更多细节,但也会导致参数估计耗时较长,由于简并拟合地形中普遍存在局部最小值,这些参数估计可能会收敛…

软件测试工程师需要达到什么水平才能顺利拿到 20k 无压力?

最近有粉丝朋友问:软件测试员需要达到什么水平才能顺利拿到 20k 无压力? 这里写一篇文章来详细说说: 目录 扎实的软件测试基础知识:具备自动化测试经验和技能:熟练掌握编程语言:具备性能测试、安全测试、全…

前端Vue.js项目开发,不重启项目,快速切换后台地址---使用nginx负载简单快速实现更换后台代理地址

前端Vue.js项目开发,不重启项目,快速切换后台地址—使用nginx负载简单快速实现更换后台代理地址 本文实现了在vue项目不重启的情况下,快速实现更换联调后台服务器的方法, 能够大大节省vue项目重启时间 chen 2023-04-20 文档源码地址,最新版本会在这里修改…

互交式3d地球仪工具:Earth 3D - World Atlas Mac

Earth 3D - World Atlas for Mac是一款3d地球仪。这个交互式 3D 地球仪以世界奇观、政治和物理地图以及天气为特色。发现许多关于我们星球的有趣事实和有用信息!原始的彩色图形、用户友好的界面和准确的信息——这就是 Earth 3D - World Atlas 的全部意义所在&#…

leetcode Two Sum-Java 和Python 的写法

我想这题是正要开始写LeetCode 的人,大部分的人的第一题吧,这题是个基本题算在easy 的题型,看到题目直接就会想到使用双回圈的写法,不过双回圈时间复杂度只有达到 O(N^2) 不那么理想,如果比较资深的工程师会用HashMap …

wsl的图像化实现,在wsl中启动浏览器

最近在学习wsl,原本我看以前的教程说wsl和vmware的区别有一点就是,wsl只能使用命令行,而vmware可以实现图像化,结果我在 microsoft 官方发现现在的wsl 2已经实现了 GUI 界面,所以就来记录一下吧。 wsl 的 GUI 实现 首…

Vue3.2 + TypeScript + Pinia + Vite4 + Element-Plus + 微前端(qiankun) 后台管理系统模板(已开源)

最终效果 一、前言 Wocwin-Admin,是基于 Vue3.2、TypeScript、Vite、Pinia、Element-Plus、Qiankun(微前端) 开源的一套后台管理模板;同时集成了微前端 qiankun也可以当做一个子应用。项目中组件页面使用了Element-plus 二次封装 t-ui-plus 组件&#xf…

C/C++每日一练(20230420)

目录 1. 存在重复元素 II 🌟 2. 外观数列 🌟🌟 3. 最优路线 🌟🌟🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏 Java每日一练 专…

搭建sentry来监控Django项目

sentry搭建 我的环境: centos7,已安装docker和docker compose 下载最新zip包到 /usr/local/ https://github.com/getsentry/self-hosted/tagshttps://github.com/getsentry/self-hosted/tags解压 unzip self-hosted-23.4.0.zip 安装期间会提示是否…

Ceph入门到精通-Cephadm安装Ceph(v17.2.5 Quincy)全网最全版本

Deploy Ceph(v17.2.5 Quincy) cluster to use Cephadm - DevOps - dbaselife Install cephadm Cephadm creates a new Ceph cluster by “bootstrapping” on a single host, expanding the cluster to encompass any additional hosts, and then depl…

【洛谷 P1003】[NOIP2011 提高组] 铺地毯 题解(数组+贪心算法)

[NOIP2011 提高组] 铺地毯 题目描述 为了准备一个独特的颁奖典礼,组织者在会场的一片矩形区域(可看做是平面直角坐标系的第一象限)铺上一些矩形地毯。一共有 n n n 张地毯,编号从 1 1 1 到 n n n。现在将这些地毯按照编号从小…

阿里云mysql8小版本升级造成磁盘不断增长,undolog持续增长不释放

现象: 1.用户升级之后,实例上磁盘空间以每分钟1g的速度不断增长, 2.高频dml表的空间不断变大但表数据其实不大,binlog大量产生 3.通过select * from innodb_tablespaces where name like %undo%发现undo 空间上涨较快&#xff0…

常见的九种大数据分析模型

常见的9种大数据分析模型分别为: 事件分析、 属性分析、 渠道分析、 Session分析、 留存分析、 归因分析、 漏斗分析、 路径分析、 分布分析 1、【事件分析】 事件分析,是指用户在 APP、网站等应用上发生的行为,即何人,何时&…

Python OpenCV 蓝图:1~5

原文:OpenCV with Python Blueprints 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自【ApacheCN 计算机视觉 译文集】,采用译后编辑(MTPE)流程来尽可能提升效率。 当别人说你没有底线的时候,你最好真…

【AI】NVIDIA CUDA-X AI名词解释

0、NVIDIA CUDA-X AI NVIDIA CUDA-X AI是一套完整的深度学习软件 官网:https://developer.nvidia.com/deep-learning-software https://github.com/NVIDIA:NVIDIA产品、演示、示例、入门教程 1、深度学习训练 Deep Learning Training 1.1、DALI 数据加载库 (DALI)是一…