使用PyTorch构建卷积神经网络(CNN)源码(详细步骤讲解+注释版) 01 手写数字识别

news2024/12/28 9:52:33

1 卷积神经网络(CNN)简介

在使用PyTorch构建GAN生成对抗网络一文中,我们使用GAN构建了一个可以生成人脸图像的模型。但尽管是较为简单的模型,仍占用了1G左右的GPU内存,因此需要探索更加节约资源的方式。

卷积神经网络(Convolutional Neural Network,简称CNN)是一种深度学习模型,主要应用于图像处理、语音识别等领域。它的主要思想是通过卷积操作对输入图像的特征进行提取,再通过多层网络对特征进行分类和判断。

CNN的网络结构通常由卷积层、池化层和全连接层组成。卷积层的作用是对输入图像的特征进行提取,池化层的作用是减少数据的维度,以提高计算效率;全连接层则用于对特征进行分类和判断。

CNN可以通过训练学习到输入图像的特征表示,从而可以在未知图像上进行分类、识别等任务。它已经成为计算机视觉领域的重要技术,在诸多应用中取得了良好的效果。
在这里插入图片描述

2 从普通BP到CNN的网路结构转变

以前面建立好的手写数字分类器为例,(使用PyTorch构建神经网络构建手写数字分类器)在模型结构定义中,需要对神经网络层做出相应的修改:

        self.model = nn.Sequential(
            # expand 1 to 10 filters
            nn.Conv2d(1, 10, kernel_size=5, stride=2),
            nn.LeakyReLU(0.02),
            nn.BatchNorm2d(10),
        
            # 10 filters to 10 filters
            nn.Conv2d(10, 10, kernel_size=3, stride=2),
            nn.LeakyReLU(0.02),
            nn.BatchNorm2d(10),
            
            View(250),
            nn.Linear(250, 10),
            nn.Sigmoid()
        )

更新后的神经网络架构如下:

  1. 第一个卷积层:把1个通道的输入图像扩展为10个通道,使用5x5的卷积核,步长为2。
  2. 第二个卷积层:10个通道的输入图像不变,使用3x3的卷积核,步长为2。
  3. 第一个全连接层:把250个节点的一维向量映射到10个节点。

其中用到的函数的含义:
4. Conv2d:对由一个或多个输入平面组成的输入信号进行二维卷积。第1个参数是输入参数,对于黑白图像,输入的通道数即为1。第2个参数是输出通道的数量。在上面的代码中,我们创建了10个卷积核,从而生成10个特征图。kernel_size函数代表了卷积核的大小,使用的是5×5的卷积核。stride是卷积核移动时的大小。该数值小于卷积核大小时,说明卷积核所覆盖的区域有重叠。
5. LeakyReLU:非线性激活函数,常用于生成对抗网络。
6. BatchNorm2d:批量归一化,用于提高网络的稳定性和收敛速度。
7. View:将多维张量展平为一维向量。(自定义函数,详见完整代码)
8. Sigmoid:S形函数,用于二分类问题的输出。

对于一个28*28像素的图片,第一步卷积之后将会生成一个12*12像素的图片(计算方式:共走了 28 − 5 2 \frac{28-5}{2} 2285 步)。第二步卷积之后将会生成一个5*5像素的图片(计算方式:共走了 12 − 3 2 \frac{12-3}{2} 2123 步)。

3 从普通BP到CNN的辅助修改

在网络结构中用到了View函数,在上面的参考博文中并未涉及这部分代码,因此把这给你功能进行补充。(与人脸识别篇代码中的View完全相同)

class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape,

    def forward(self, x):
        return x.view(*self.shape)

此外,修改后的CNN网络结构,其传入的图片应将其修改为4D数据。因此在模型训练时,将传入的数据进行变形。

start_time = time.perf_counter()  # 计时开始
C = Classifier()
epochs = 3
for i in range(epochs):
    print('training epoch', i+1, 'of', epochs)
    for label, image_data_tensor, target_tensor in mnist_dataset:
        C.train(image_data_tensor.view(1, 1, 28, 28), target_tensor)

注:上面两个VIEW并不相同,一个是我们自行定义用于分类器类使用的函数,一个是torch的自带功能。
除此之外代码均可保持不变,这部分的原始代码可在此找到到或文末留言申请。

4 模型评估

在这里插入图片描述

在训练初期,可以看到模型的损失呈现迅速下降。下面使用测试集对模型准确率进行评价:
在这里插入图片描述
使用一张图片来查看模型的生成。此处我们分别选择了一张数字0和数字6,可以发现与BP模型相比,CNN模型对结果变得更有信心了。
在这里插入图片描述在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/193346.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue 文件.env.production、.env.development简析

静下来 慢慢看 首先 我们需要搭建一个项目 依赖包会自动下载好 无需自己 npm i .env 无论什么环境都会加载 .env.production 生产环境加载 .env.development 测试开发环境加载 我们下面的例子分开来写 只用 .env.production .env.development 在项目根目录新建两个文件 分别…

Git和SVN

一:两者的区别 ——Git是分布式版本控制系统,SVN是集中式版本控制系统 ——集中式版本控制系统 早期出现的版本控制系统有:SVN、CVS等,它们是集中式版本控制系统,集中式版本控制系统有一个单一的集中管理的服务器&…

【微服务】分布式搜索引擎elasticsearch(2)

分布式搜索引擎elasticsearch(2)1.DSL查询文档1.1.DSL查询分类1.2.全文检索查询1.2.1.使用场景1.2.2.基本语法1.2.3.示例1.2.4.总结1.3.精准查询1.3.1.term查询1.3.2.range查询1.3.3.总结1.4.地理坐标查询1.4.1.矩形范围查询1.4.2.附近查询1.5.复合查询1…

Power BI饼图

饼图展现的是个体占总体的比例,利用扇面的角度来展示比例大小。 在PowerBI默认的可视化对象中当然也有这种标准饼图,通过点击拖拽轻松生成, 通过对标题和图例的格式稍加设置,一个简单而不失专业的饼图就可以用了, 饼…

学术加油站|HIST,面向海量数据的学习型多维直方图

编者按 本文系东北大学李俊虎所著,也是「 OceanBase 学术加油站」系列第九篇内容。 「李俊虎:东北大学计算机科学与工程学院在读硕士生,课题方向为数据库查询优化,致力于应用 AI 技术改进传统基数估计器,令数据库选择…

mysql逻辑架构和数据库缓冲池

逻辑架构 典型的CS架构,服务端程序使用的是mysqld 客户端进程向服务器进程发送一段文本(SQL语句),服务器进程处理后再向客户端进程发送文本(处理结果) # 应用连接层: # 连接处理,用户鉴权(username,host…

分享|2023年全球市场准入认证咨讯

作为全球应用安全科学专家,UL Solutions服务全球100多个国家和地区的客户,将产品安全、信息安全和可持续性挑战转化为客户的机遇。UL Solutions 提供测试、检验、认证(TIC),以及软件产品和咨询服务,以支持客…

Deep Learning for 3D Point Clouds: A Survey - 3D点云的深度学习:一项调查 (IEEE TPAMI 2020)

Deep Learning for 3D Point Clouds: A Survey - 3D点云的深度学习:一项调查(IEEE TPAMI 2020)摘要1. 引言2. 背景2.1 数据集2.2 评估指标3. 3D形状分类3.1 基于多视图的方法3.2 基于体积的方法3.3 基于点的方法3.3.1 逐点MLP方法3.3.2 基于卷…

AI助力多文档审查丨合同风险审查、招投标文件、合同和中标通知书一致性审查

当下,企业管理的数据和文档管理中充斥着大量有复用价值的数据、资料和内容性信息。每一家企业都有许多商业文档和法律文档需要使用和维护,其中包含了不同语言文字、手写体、数字、公式等。然而,目前企业的各种文档资料仍主要依靠人工手段进行…

【Redis】Redis面试题详解与使用案例(金三银四面试专栏启动)

📫作者简介:小明java问道之路,2022博客之星全国TOP3,专注于后端、中间件、计算机底层、架构设计演进与稳定性建工设优化。 文章内容兼具广度深度、大厂技术方案,对待技术喜欢推理加验证,就职于知名金融公司…

ipv6实验

r1的配置为 [r1]int g0/0/0 [r1-GigabitEthernet0/0/0]ip address 12.1.1.1 24 [r1]int lo0 [r1-LoopBack0]ip address 1.1.1.1 32 [r1]ip route-static 0.0.0.0 0 12.1.1.2 r2的配置为 [r2]int g0/0/0 [r2-GigabitEthernet0/0/0]ip address 12.1.1.2 24 [r2]int lo0 […

小红书破局品牌增长:4大阶段+8个种草建议

品牌如何从激烈的竞争中突围,成为快速增长的“黑马”?本文就和大家一起聊聊围绕产品面对不同阶段的人群“种草”策略,希望能够帮助品牌更好地与用户沟通并提升营销效率,实现品效合一。 1、种草1.0 —— 立住产品,抢占赛…

Minecraft 1.19.2 Fabric模组开发 10.建筑生成

我们本次尝试在Fabric 1.19.2中生成一个自定义的建筑。 效果展示效果展示效果展示 由于版本更新缘故,1.19的建筑生成将不涉及任何Java包的代码编写,只需要在数据包中对建筑生成进行自定义。 1.首先我们要使用游戏中的结构方块制作一个建筑,结构方块使用…

企业内部沟通即时通讯软件要怎么选?

随着企业信息化的快速发展,为了工作效率的提高,即时通讯工具已经成为了众多企业办公时的标配,同时,各大企业对即时通讯功能的要求也越来越高。但是,现在市场上即时通信软件众多,各种功能和服务都是参差不齐…

python爬虫基本库的使用

博主简介:博主是一个大二学生,主攻人工智能领域研究。感谢缘分让我们在CSDN相遇,博主致力于在这里分享关于人工智能,C,python,爬虫等方面的知识分享。如果有需要的小伙伴,可以关注博主&#xff…

Java面试基础篇

目录 一、集合 1.集合与集合之间的区别 2.集合子类之间的区别(数据结构) 二、线程 三、面向对象 继承 多态 四、异常 五、IO流 六、序列化与反序列化 今天给大家分享 Java基础篇的面试题,小编给大家稍微整理了一下,希望即…

RHCE(web服务器)

文章目录一、www简介(一)网址及HTTP简介(二)HTTP协议请求的工作流程二、www服务器的类型(一)仅提供用户浏览的单向静态网页(二)提供用户互动接口的动态网站三、www服务器的基本配置四…

苹果证书p12和描述文件的创建教程

在hbuilderx或apicloud这些uniapp框架工具打包苹果APP的时候,需要p12证书和证书profile文件来编译,目前网上很少使用windows电脑生成p12证书的教程,官方的教程都是需要使用苹果电脑来创建的。 这里,我们这篇文章来教会大家如何使…

【虚拟仿真】Unity3D中实现鼠标悬浮UI上显示文字

推荐阅读 CSDN主页GitHub开源地址Unity3D插件分享简书地址我的个人博客 大家好,我是佛系工程师☆恬静的小魔龙☆,不定时更新Unity开发技巧,觉得有用记得一键三连哦。 一、前言 本篇文章实现一个鼠标悬浮在UI上显示文字的功能,实…

LMK04828时钟芯片配置历程——SPI接口

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 LMK04828时钟芯片配置历程——SPI接口总结最近有一个开发板需要去调试,开发板上包含了AD9371和LMK04828时钟芯片,而我的任务是需要将他们都配置起来。…