六、神经网络完整训练流程(MNIST数据集为例)

news2024/9/20 20:24:34

一、下载数据集

MNIST数据集
将下载好的数据集解压放入同级项目路径下
在这里插入图片描述

二、导包

import torch
import torch.nn as nnn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets,transforms
%matplotlib inline

%matplotlib inline可将绘图内嵌到notebook中,可以省略掉plt.show()

三、加载数据集

设置一些参数信息、datasetsdataloader
MNIST数据集图像大小均为28*28像素,共10个类别

#输入图像大小为28*28,10个类别,全部图像训练循环3次,每次训练64张
input_size = 28
num_classes = 10
num_epochs = 3
batch_size = 64

train_dataset = datasets.MNIST(root="./data/",train=True,transform=transforms.ToTensor(),download=True)
test_dataset = datasets.MNIST(root="./data/",train=False,transform=transforms.ToTensor(),download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)

四、构架模型

模型总共分为2层和1个输出

1,layer1包括卷积层、激活函数、最大池化层

因为是数据集样本是单颜色通道,大小为28*28像素点,形状为:[1,28,28]

卷积层:
卷积核为大小为5*5,卷积核个数为16,滑动步长为1,加边圈数为2
在这里插入图片描述
由公式计算可得,通过卷积层输出的特征图形状为[16,28,28]
池化层:池化核大小为2*2,仅对特征图大小进行砍半,即[14,14]
最终通过layer1之后的特征图形状为[16,14,14]
在这里插入图片描述

2,layer2包括卷积层、激活函数、最大池化层

通过layer1之后得到的特征图形状为[16,14,14]

卷积层:
卷积核为大小为5*5,卷积核个数为32,滑动步长为1,加边圈数为2
同理通过计算得到特征图形状为[32,14,14]

池化层:池化核大小为2*2,仅对特征图大小进行砍半,即[7,7]
最终通过layer2之后的特征图形状为[32,7,7]

3,输出层为线性层全连接

通过layer2之后得到的特征图形状为[32,7,7]
首先将该特征图进行展开成一行,x.view(x.size(0),-1)
每张图像的特征元素全部占一行,有多张图像,故得到(batch_size, 32*7*7)

之后因为最终的是十分类任务,在对其通过矩阵[32*7*7,10]进行线性变换,最终将其转换为十个输出值,即[batch_size,10],每张图片为十个值,分别对应预测成为0-9这十个数字的概率,共batch_size张

class yy_model(nn.Module):
    def __init__(self):
        super(yy_model,self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1,out_channels=16,kernel_size=5,stride=1,padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=16,out_channels=32,kernel_size=5,stride=1,padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.out = nn.Linear(in_features=7*7*32,out_features=10)
        
    def forward(self,x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(x.size(0),-1) #(batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

五、准确率

传入模型预测结果和实际的真实结果,进行对比即可,找到模型预测结果中最大的值,也就是模型预测成为的数字

def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1] 
    rights = pred.eq(labels.data.view_as(pred)).sum() 
    return rights, len(labels) 

六、模型训练

# 实例化
net = yy_model() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法

#开始训练循环
for epoch in range(num_epochs):
    #当前epoch的结果保存下来
    train_rights = [] 
    
    for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环
        net.train()                             
        output = net(data) 
        loss = criterion(output, target) 
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step() 
        right = accuracy(output, target) 
        train_rights.append(right) 

    
        if batch_idx % 100 == 0: 
            
            net.eval() 
            val_rights = [] 
            
            for (data, target) in test_loader:
                output = net(data) 
                right = accuracy(output, target) 
                val_rights.append(right)
                
            #准确率计算
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                100. * batch_idx / len(train_loader), 
                loss.data, 
                100. * train_r[0].numpy() / train_r[1], 
                100. * val_r[0].numpy() / val_r[1]))

七、检测模型训练效果

感觉模型训练的效果还行,实际标签为4,预测的结果也是4,嘿嘿

x,y = train_dataset[9]#第9个数据x为图片,对应的结果为4
x.shape # torch.Size([1, 28, 28])
y # 4
x = x.view(-1,1,28,28) # 因为投喂网络需要格式为[B,C,W,H],需要变成相应的格式
x.shape # torch.Size([1, 1, 28, 28])
y_hat = net(x) # 模型预测
y_hat
"""
tensor([[ -7.2561,  -3.0549,  -3.3932,  -6.8128,  10.3861, -11.8726,  -7.2241,
          -0.6564,  -2.5825, -10.9693]], grad_fn=<AddmmBackward0>)
"""
pred_maxvalue, pred_maxindex = torch.max(y_hat,dim=1) # 得到最大的值和索引下标
pred_maxvalue # tensor([10.3861], grad_fn=<MaxBackward0>)
pred_maxindex # tensor([4])
pred_maxindex.item() # 4 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/709050.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华芯微特SWM34-IO速度优化之模拟SPI写速度提速

本文以在SWM34S&#xff08;M33内核,150Mhz&#xff0c;编译器Keil MDK 5.36&#xff09;上优化为例&#xff0c;说明优化方法和需要注意的地方&#xff0c;其他MCU可以参考。 在编写模拟SPI通信驱动LCD的例子的时候&#xff0c;会用到一个发送字节的核心函数&#xff0c;其基本…

【JavaSE】初步认识

目录 【1】Java语言概述 【1.1】Java是什么 【1.2】Java语言重要性 【1.3】Java语言发展简史 【1.4】Java语言特性 【1.5】 Java开发环境安装 【2】初识Java的main方法 【2.1】main方法示例 【2.2】运行Java程序 【3】注释 【3.1】基本规则 【3.2】注释规范 【4】…

ESP32-S2启动异常分析

客户反馈最近一批50块基于ESP32-S2的LoRaWAN gateway&#xff0c;有5块偶尔网络灯能亮&#xff0c;经常不能亮。 反复分析&#xff0c;定位&#xff0c;一个共同现象是用示波器看&#xff0c;串口输出一串信息后再没输出了。因为用了 ESP-ROM:esp32s2-rc4-20191025 Build:Oct …

企业构建知识库方案

AI模型理解误区&#xff1a;百万成本微调垂直行业达模型VS低成本建立企业专属知识库或ai助理_哔哩哔哩_bilibili

vscode关闭调试工具栏

问题描述 项目启动的时候老是蹦出这玩意 很碍眼 解决方案&#xff1a; 设置里搜索 选项改为hidden即可

前端Vue自定义注册界面模版 手机号邮箱账号输入框 验证码输入框 包含手机号邮箱账号验证

前端Vue自定义注册界面模版 手机号邮箱账号输入框 验证码输入框 包含手机号邮箱账号验证 &#xff0c; 下载完整代码请访问uni-app插件市场地址&#xff1a;https://ext.dcloud.net.cn/plugin?id13306 效果图如下&#xff1a;

CSS 备忘录2-动画、渐变、颜色、选择器等

1、背景 background属性是八个属性的简写形式&#xff1a; background-image 指定一个文件或生成的颜色渐变作为背景图片background-position 设置图片的初始位置background-size 指定背景图片的渲染尺寸background-repeat 是否平铺图片ba…

缺少msvcp140.dll丢失该如何解决?

msvcp140.dll是什么东西?相信很多人都遇到过msvcp140.dll这个文件吧?那么为什么一丢失msvcp140.dll电脑软件就会打不开?如果缺失了这个东西会怎么样?小编今天就来给大家详细的说说&#xff0c;其实这些都是一些比较常见的电脑知识&#xff0c;我们是需要去了解一下的。 msv…

Python 利用@property装饰器和property()方法将一个方法变成属性调用

目录 方法一&#xff1a;使用property装饰器 方法二&#xff1a;使用property()创建类属性 在创建实例属性时&#xff0c;如果直接把实例属性暴露出去&#xff0c;虽然写起来简单&#xff0c;但是存在一些风险&#xff0c;比如实例属性可以在外部被修改。 为了限制外部操作&a…

springboot集成openfeign,集成Histric

一、Feign简介 Feign是一个声明式的伪Http客户端&#xff0c;它使得写Http客户端变得更简单。使用Feign&#xff0c;只需要创建一个接口并注解。它具有可插拔的注解特性&#xff0c;可使用Feign 注解和JAX-RS注解。Feign支持可插拔的编码器和解码器。Feign默认集成了Ribbon&…

Echarts入门(SpringBoot + Vue)

一、Echarts简介 代码已上传至码云:echarts_boot: echarts使用demo ECharts是一个使用 JavaScript 实现的"数据可视化"库, 它可以流畅的运行在 PC 和移动设备上 什么是数据可视化? 也就是可以将数据通过图表的形式展示出来&#xff0c; Echarts官网:Apache ECh…

9-如何获取N维数组元素?【视频版】

目录 问题视频解答 问题 视频解答 点击观看&#xff1a; 9-如何获取N维数组元素&#xff1f;

基于 Opencv python实现批量图片去黑边—裁剪—压缩软件

简介 批量处理图片文件&#xff0c;批量提取GIF图片中的每一帧&#xff0c;具有范围裁剪、自动去除黑/白边、调整大小、压缩体积等功能。 先看一些软件的界面&#xff0c;是基于Tkinter写的GUI 裁剪等功能基于Opencv 下载 我添加了处理GIF的github&#xff1a; 原作者的gith…

基于Ant DesignPro Vue + SpringBoot 前后端分离 - 后端微服化 + 接口网关 + Nacos

基于Ant DesignPro Vue SpringBoot 前后端分离 - 后端微服化 接口网关 Nacos 通过Ant DesignPro Vue SpringBoot 搭建的后台管理系统后&#xff0c;实现了前后端分离&#xff0c;并实现了登录认证&#xff0c;认证成功后返回该用户相应权限范围内可见的菜单。 后端采用Spri…

一、枚举类型——新特性(将 switch 作为表达式)

switch 一直以来都只是一个语句&#xff0c;并不会生成结果。 JDK 14 使得 switch 还可以作为 一个表达式来使用&#xff0c;因此它可以得到一个值&#xff1a; SwitchExpression.java public class SwitchExpression {static int colon(String s) {var result switch (s) {ca…

基于单片机的智能点滴速度输液液体检测

功能介绍 以51单片机作为主控系统&#xff1b;显示液位&#xff0c;滴数&#xff0c;温度等信息&#xff1b;通过水位传感器检测当前药瓶是否有水&#xff1b;通过滴速传传感器利用单片机定时器计算当前滴速&#xff1b;通过DS18B20温度传感器采集当前药液温度&#xff0c;继电…

【前端】JS语法——数据类型转换

一、字符串&#xff08;里面必须数字&#xff09;转换为数字类型&#xff08;number&#xff09; 1、强制转换&#xff1a;(parseInt()、parseFloat()、Number())&#xff1b; 2、隐式转换&#xff08;number[-/*%]string&#xff09;&#xff1b; <script>let s &qu…

红米K60刷入MIUI.EU安装面具magisk与root教程

文章目录 前言1.解锁BootLoader2.刷入Recovery3.刷入EU的ROM包4.刷入magisk面具后话 前言 教程大概就是四步&#xff0c;解锁&#xff0c;刷入rec&#xff0c;刷入系统&#xff0c;刷入面具&#xff0c;跟着教程走即可。这次是刷机方式&#xff1a;卡刷&#xff08;Recovery&a…

SELECT * 会导致查询效率低的原因

SELECT * 会导致查询效率低的原因 前言一、适合SELECT * 的使用场景二、SELECT * 会导致查询效率低的原因2.1、数据库引擎的查询流程2.2、SELECT * 的实际执行过程2.3、使用 SELECT * 查询语句带来的不良影响 三、优化查询效率的方法四、总结 前言 因为 SELECT * 查询语句会查…

PCL可视化 3D点云PCD文件

工具安装 sudo apt install pcl-tools 启动命令&#xff1a; pcl_viewer 000000.pcd