(三)Pytorch快速搭建卷积神经网络模型实现手写数字识别(代码+详细注解)

news2025/1/10 23:37:56

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

  • 前言
    • Q1:卷积网络和传统网络的区别
    • Q2:卷积神经网络的架构
    • Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在
    • 4、 具体的实现代码+网络搭建


前言

深度学习pytorch系列第三篇啦,之前更了FC,NN,这篇是卷积神经网络(cNN)模型实现手写数字识别,依然是重在理解哈,具体的理解内容我都以注释的形式放在了代码中,我就直接放代码了,因为我把一些知识点和理解的东西用注释的形式写了


首先是关于卷积神经网络的一些点

Q1:卷积网络和传统网络的区别

传统网络只适合结构化数据,不适合图像数据,由于图像数据的数据量大(表现为像素点多),传统网络需要使用的参数量太大

Q2:卷积神经网络的架构

卷积神经网络包括:输入层,卷积层,池化层,全连接层
重点介绍卷积层!!
卷积就是针对每个区域去计算特征。可以这样做的原因是:图片是有像素点构成的,针对每个像素点进行处理,需要的参数量过于庞大,并且相邻的像素点之间是存在联系的
特征图的个数与卷积核的个数一致。每个卷积核通过对输入特征图进行卷积操作,生成一个输出特征图。因此,卷积核的个数决定了输出的特征图的个数。
使用不同的卷积核学习同一个位置,可以得到不同的特征图,从而使特征多样化
卷积核的大小一般使用3*3
卷积核的大小规格一般是固定的,卷积核的数量理论上是越多越好
卷积层涉及的参数有:滑动窗口步长,卷积核尺寸,边缘填充,卷积核个数
卷积结果计算公式:长:h2=(h1-Fh+2p)/s +1 宽:w2=(w1-Fw+2p)/s +1
其中:w1,h1表示输入的宽度,长度;w2和h2表示输出特征图的宽度、长度,F表示卷积核的长和宽,s表示滑动窗口的补偿,p表示边界填充
经过卷积操作后,特征图的长和宽也可以保持不变
池化层的作用就是筛选好的特征,pool是只筛选位置的,channel是全部使用的
池化也称为下采样,(一次只能下采样原来的一半,不能直接224-16)
卷积神经网络由多个block组成,重点就在于怎么设计这个block的组成
关于卷积神经网络的层数,带权重参数的就算是一层,6个conn+1个fc,就可以说是7层网络结构

Q3:卷积神经网络中的参数共享,也是比传统网络的优势所在

同一个卷积核在各个位置上的参数都是一致的
权重参数的个数与输入数据的大小无关

4、 具体的实现代码+网络搭建

# 读取数据
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets,transforms
# transforms  进行预处理,比如进行tensor转换
import matplotlib.pyplot as plt
import numpy as np
#全连接:batch*28*28,全连接各个像素点之间无关
# cnn:batch*1*28*28  ,多了一个参数channel,卷积会综合考虑一个窗口之间的关系,因此各个像素点并不是独立的,卷积网络更适合处理图像数据
# 定义超参数
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片
# 训练集
train_dataset = datasets.MNIST(root='./data',
                            train=True,
                            transform=transforms.ToTensor(),
                            download=True)
# 测试集
test_dataset = datasets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor())

# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size=batch_size,
                                           shuffle=True)
# 卷积网络模块构建
# 一般卷积层,relu层,池化层可以写成一个套餐
# 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务
# 定义一个网络
class CNN(nn.Module):
    def __init__(self):
        #         构造函数
        # 卷积网络一般是组合进行的:conv pool relu可以当一个组合
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(  # 输入大小 (1, 28, 28)
            nn.Conv2d(  # 2d卷积做任务
                in_channels=1,  # 灰度图
                out_channels=16,  # 要得到几多少个特征图,就是卷积核的个数,相当于有16个卷积核
                kernel_size=5,  # 卷积核大小 5*5的
                stride=1,  # 步长
                padding=2,  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1,一般是这么希望的
                #                                             如果不能整除pytorch采用向下取整
            ),  # 输出的特征图为 (16, 28, 28)
            nn.ReLU(),  # relu层
            nn.MaxPool2d(kernel_size=2),  # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14),一般是pooling后是之前的一半
        )
        self.conv2 = nn.Sequential(  # 下一个套餐的输入 (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),  # 输出 (32, 14, 14)
            nn.ReLU(),  # relu层
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 输出 (32, 7, 7)
        )

        self.conv3 = nn.Sequential(  # 下一个套餐的输入 (32, 7, 7)
            nn.Conv2d(32, 64, 5, 1, 2),  # 输出 (64, 7, 7)
            nn.ReLU(),  # 输出 (64, 7, 7)
        )
        # 只有pool的时候才会筛选特征

        self.out = nn.Linear(64 * 7 * 7, 10)  # 全连接层得到的结果,最后的任务是10分类任务,进行一个wx+b的操作去做分类

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)  # flatten操作,结果为:(batch_size, 32 * 7 * 7),和reshape操作一样
        # reshape操作:总的大小是不变的,提供一个维度后,后边的维度自动计算
        # 比如当前的x:64*7*7,x.size:64,也就是要从三维转成两维,总的大小不变,就变为64*49这样,-1可以简单的看成一个占位符号
        # 变换维度,开始是64*7*7,转成batchsize*特征个数,比如64*49
        output = self.out(x)
        return output
# 定义准确率
def accuracy(predictions, labels):
    pred = torch.max(predictions.data, 1)[1] # 最大值是多少,最大值的索引,只要索引就可以
    rights = pred.eq(labels.data.view_as(pred)).sum()
    return rights, len(labels)
# 训练网络模型
# 实例化
net = CNN()
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器,学习率是0.001
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器,普通的随机梯度下降算法
# 开始训练循环
for epoch in range(num_epochs):
    # 当前epoch的结果保存下来
    train_rights = []

    for batch_idx, (data, target) in enumerate(train_loader):  # 针对容器中的每一个批进行循环
        net.train()
        output = net(data)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        right = accuracy(output, target)
        train_rights.append(right)
        # 每一个batch都进行训练,每一百个batch进行一次评估
        if batch_idx % 100 == 0:

            net.eval()
            val_rights = []

            for (data, target) in test_loader:
                output = net(data)
                right = accuracy(output, target)
                val_rights.append(right)

            # 准确率计算
            train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
            val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))

            print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(
                epoch, batch_idx * batch_size, len(train_loader.dataset),
                       100. * batch_idx / len(train_loader),
                loss.data,
                       100. * train_r[0].numpy() / train_r[1],
                       100. * val_r[0].numpy() / val_r[1]))

实现结果
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1265175.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

美创科技受邀亮相第二届全球数字贸易博览会

11月23日-27日,由浙江省人民政府、商务部共同主办的第二届全球数字贸易博览会(以下简称“数贸会”)圆满落幕。围绕“国家级、国际性、数贸味”的目标定位,以“数字贸易 商通全球”为主题,数贸会重点展示数字贸易全产业…

哈希函数:保护数据完整性的关键

🤍 前端开发工程师(主业)、技术博主(副业)、已过CET6 🍨 阿珊和她的猫_CSDN个人主页 🕠 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 🍚 蓝桥云课签约作者、已在蓝桥云…

佳易王个体诊所管理系统电子处方软件,个体诊所人员服务软件,卫生室配方模板电子病历系统教程

佳易王个体诊所管理系统电子处方软件,个体诊所人员服务软件,卫生室配方模板电子病历系统教程 软件试用版下载可以点击最下方官网卡片 软件功能: 1、配方模板:可以自由添加配方分类,预先设置药品配方,可以…

字符串逆序问题

写一个函数,可以将任意输入的字符串逆序(要可以满足多组输入) 这个题有三个点 1.要读入键盘输入的字符串,所以要用到字符串输入函数 2.可以进行多组输入 3.把输入的n组字符串都逆序 #define _CRT_SECURE_NO_WARNINGS 1 #incl…

[栈迁移+ret滑梯]gyctf_2020_borrowstack

题目来源buuctf——gyctf_2020_borrowstack 参考链接https://www.shawroot.cc/2097.html 题目信息ubuntu16、64位 第一个read仅溢出一个机器字长,需要栈迁移 解题步骤栈偏移到全局变量bank中,ret2libcgadget 关键步骤 ret滑梯 第二个payload需要添…

Android flutter项目 启动优化实战(一)使用benchmark分析项目

背景描述 启动时间是用户对应用的第一印象,较慢的加载会对用户的留存和互动造成负面影响 在刚上线的B端项目中: 1.提高启动速度能提高整体流程的效率 2.提高首次运行速度能提高应用推广的初体验效果 问题描述 项目刚上线没多久、目前存在冷启动过程存在…

《融合SCADA系统数据的天然气管道泄漏多源感知技术研究》误报数据识别模型开发

数据处理不作表述。因为我用的是处理后的数据,数据点这。 文章目录 工作内容1CC040VFD电流VFD转速压缩机转速反馈进出口差压 紧急截断阀开到位进出电动阀开到位发球筒电筒阀开到位收球筒电动阀开到位电动阀2005开到位越站阀开到位 工作内容2工作内容3 工作内容1 任…

【Python 训练营】N_12 打印菱形图案

题目 打印菱形图案 分析 先把图形分成两部分来看待,前四行一个规律,后三行一个规律,利用双重for循环,第一层控制行,第二层控制列。 答案 # 方法一 for i in range(4):block **(2*i1)print({:^7}.format(block))…

web:NewsCenter

题目 打开页面显示如下 页面有个输入框,猜测是sql注入,即search为注入参数点,先尝试一下 返回空白显示错误 正常显示如下 是因为单引号与服务端代码中的’形成闭合,输入的字符串hello包裹,服务端代码后面多出来一个‘导…

MYSQL 8.X Linux-Generic 通用版本安装

下载对应版本MySQL :: Download MySQL Community Server (Archived Versions) 这里我选择的是Linux - Generic (glibc 2.12) (x86, 64-bit), TAR 解压到服务器 只需要里面的mysql-8.0.24-linux-glibc2.12-x86_64.tar.xz 在目录下创建需要的文件夹 这里我改名为mysql-8.0.24…

分享一个适用于 Vue3 的好的组件库,PrimeVue组件。

一、PrimeVue介绍 PrimeVue 是一个基于 Vue.js 的 UI 组件库,专注于提供丰富、灵活、现代的 UI 组件,以帮助开发者构建功能强大的 Web 应用程序。PrimeVue 提供了一系列的组件,涵盖了从基本的表单元素到高级的数据表格和图表等各种组件。 二、…

RPC之grpc重试策略

1、grpc重试策略 RPC 调用失败可以分为三种情况: 1、RPC 请求还没有离开客户端; 2、RPC 请求到达服务器,但是服务器的应用逻辑还没有处理该请求; 3、服务器应用逻辑开始处理请求,并且处理失败; 最后一种…

函数声明与函数表达式

函数声明 一个标准的函数声明&#xff0c;由关键字function 、函数名、形参和代码块组成。 有名字的函数又叫具名函数。 举个例子&#xff1a; function quack(num) { for (var i 0; i < num; i) {console.log("Quack!")} } quack(3)函数表达式 函数没有名称…

4.7 构建onnx结构模型-Transpose

前言 构建onnx方式通常有两种&#xff1a; 1、通过代码转换成onnx结构&#xff0c;比如pytorch —> onnx 2、通过onnx 自定义结点&#xff0c;图&#xff0c;生成onnx结构 本文主要是简单学习和使用两种不同onnx结构&#xff0c; 下面以transpose 结点进行分析 方式 方…

音视频学习(十九)——rtsp收流(tcp方式)

前言 本文主要介绍以tcp方式实现rtsp拉流。 流程图 流程说明: 客户端发起tcp请求&#xff0c;如向真实相机设备请求&#xff0c;端口一般默认554&#xff1b;tcp连接成功&#xff0c;客户端与服务端开始rtsp信令交互&#xff1b;客户端收到play命令响应后&#xff0c;开启线…

UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown

一个奇怪的BUG 这个代码会报下面的错&#xff1a; 但是把模型导入部分注释掉之后就没有这个错误了&#xff08;第六行&#xff09; 解决办法&#xff1a;在模型加载后面加入一行代码 matplotlib.use( TkAgg’)&#xff0c;这个bug的问题就是模型加载改变了matplotlib使用的终端…

Leetcode算法之哈希表

目录 1.两数之和2.判定是否互为字符重排3.存在重复元素I4.存在重复元素II5.字母异位词分组 1.两数之和 两数之和 class Solution { public:vector<int> twoSum(vector<int>& nums, int target) {unordered_map<int,int> hash;for(int i0;i<nums.si…

力扣141-环形链表

文章目录 力扣141-环形链表示例代码实现要点剖析 力扣141-环形链表 给你一个链表的头节点 head &#xff0c;判断链表中是否有环。 如果链表中有某个节点&#xff0c;可以通过连续跟踪 next 指针再次到达&#xff0c;则链表中存在环。 为了表示给定链表中的环&#xff0c;评测…

【C++】类型转换 ⑤ ( 常量和非常量之间的类型转换 - 常量类型转换 const_cast | const 左数右指原则 | 代码示例 )

文章目录 一、const 关键字简介1、const 修饰普通数据2、const 修饰指针 ( 左数右指原则 | 指针常量 | 常量指针 ) 二、常量和非常量 之间的类型转换 - 常量类型转换 const_cast1、常量类型转换 const_cast2、常量不能直接修改3、修改常量值的方法4、特别注意 - 确保指针指向的…