从零编写一个神经网络完成手写数字的识别分类(pytorch实现)

news2024/12/25 1:39:25

1. 前言

很多人都有这样的困惑:

“我已经看过很多有关神经网络的书和视频了,但为什么感觉还是似懂非懂呢?”

那是因为,你从来都没有完整的、从头编写并训练过一个神经网络

学习AI相关的算法,尤其是深度学习方向;

真的不是学几个公式,了解几个名词概念就可以的。

因为深度学习,是一门实践课程!

举个例子:

激活函数、损失函数、前向传播和反向传播,这些概念,相信大家都听过。

几个相关的问题,大家看看能不能回答出来。

激活函数:

激活函数必须要有吗?一般要放在哪里?

是放在线性层计算后,还是放在线性层计算前?

或者有没有都可以、放在哪都可以?

损失函数:

损失函数是用来做什么的?

有哪些常用的损失函数?

分类问题用什么损失函数?回归问题用什么损失函数?

前向传播和反向传播:

什么是前向传播、什么又是反向传播?

为什么使用Pytorch,要定义前向传播forward函数?

梯度计算是在前向传播的过程中,还是在反向传播的过程中?

对于上述这些简单问题,如果你觉得很模糊;

好像能回答出来,又好像回答不出来;

那只能说明一个问题:

就是你从来没动手编写过神经网络。

本文,我会具体讲解一个神经网络的编程案例,并附上代码。

大家看完这个案例后,动手写一写,然后再想一想;

就会发现,前面的那些问题,都能迎刃而解。

2. 问题简述

我要使用手写数字识别,这个例子:

来说明到底如何设计、实现并训练一个标准的前馈神经网络。

具体来说,我们要设计并训练一个3层的神经网络

这个神经网络会以数字图像作为输入。经过神经网络的计算,就会识别出图像中的数字是几,从而实现数字图像的分类。

在这个过程中,重点讲解3个块内容:

1)神经网络的设计和实现

2)训练数据的准备和处理

3)模型的训练和测试流程。

3. 神经网络的设计和实现

首先需要观察数据的样子?

为了设计一个处理数字图像的神经网络,首先要弄清楚输入图像的大小和格式。

其中,分辨率就是是图像的高和宽

可以发现,我们要处理的图片是,28*28像素的灰色单通道图像。

这样的灰色图像,包括了28*28=784个数据点。

每次在处理数字图像时,输入给神经网络的,就是这784个数据点。

在将它输入给神经网络前,这个28*28的二维图片向量,会被展平为1*784大小的一维线性向量【因为我们使用的是线性模型,而非卷积模型

比如这张图,左侧代表了28*28个像素对应的图像;

右侧是一个展平后的一维向量,包括了x0到x783,一共784个像素点。

这样这个向量才能被神经网络的输入层所接收和处理。

3.1 输入层的设计

我们会使用一个3层神经网络,来处理图片对应的向量x:

如图

输入层需要接收784维的图片向量x。图中的红色箭头,就代表了数据的输入。

x中的每一个维度的数据,都有一个神经元来接收。

因此,输入层就要包含784个神经元。

3.2 隐藏层的设计

隐藏层是指除了输入层的后面层数,也有的是说包含权重的层数,只需要记住,隐藏层的个数等于神经网络层数-1即可。例如本文实现的是3层神经网络,那么隐藏层的个数就是2

隐藏层用于特征提取,它将输入的特征向量,处理为更高级的特征向量。

由于手写数字图像并不复杂,这里就将隐藏层的神经元个数,设置为256。

256就是个经验值,大家也可以设置为128、512,甚至999。

对于手写数字这个问题,并没有太大影响。

这样输入层与隐藏层之间,就会有一个784*256大小的线性层。

它可以将一个784维的输入向量,转换为256维的输出向量。

该输出向量会继续向前传播,到达输出层。

3.3 输出层的设计

由于最终要将数字图像,识别为0到9,10种可能的数字;

因此,输出层需要定义10个神经元,对应这10种数字。

256维的向量,再经过隐藏层和输出层之间的线性层计算后,就得到了10维的输出结果。

这个10维的向量,就是代表了10个数字的预测得分。

不要忘了还得有softmax层!

为了继续得到10个数字的预测概率,我们还要将输出层的输出,输入到softmax层。

softmax层会将10维的向量,转换为10个概率值,p0到p9。

每个概率值,都对应一个数字,也就是输入图片,是某一个数字的可能性。

另外,p0到p9这10个概率值,相加到一起的总和是1。

这是由softmax函数的性质决定的。

以上就是神经网络的设计思路。

3.4 代码实现

对于初学者,我知道很难直接按照这个设计思路,将代码编写出来。

大家最开始的时候,可以先模仿着写,进行练习;

慢慢的自己就会写出完整的模型了。

下面我会基于刚刚的思路,实现Pytorch代码。

如果想进一步理解代码,最好的方式还是将代码编写出来后,然后再将代码跑起来。

首先,定义神经网络Network。

在init函数中:

定义两个线性层layer1和layer2。

layer1和layer2分别是输入层和隐藏层、隐藏层和输出层之间的线性层。

它们的大小分别是784*256和256*10。

也就是右侧图中,红色标记的layer1和layer2。

在前向传播,forward函数中:

函数的输入为图像x。

这个x就是1个或者多个,28*28像素数字图像。

在函数中,需要先将输入的图像x,使用view函数,将x展平。

也就是将n*28*28的数据,展平成n*784的数据。

然后将x输入至layer1;

接着使用relu激活;

最后输入至layer2计算结果,再返回。

另外,需要注意的是:

我没有在forward中直接定义softmax层,

这是因为后面会使用CrossEntropyLoss损失函数。

在这个损失函数中,会实现softmax的计算。

4. 训练数据的准备和处理

如果想要理解一个模型,我们要先理解给它输入的数据。

理解了数据定义和读取,再去看模型,会事半功倍。

4.1 训练数据哪里来?

手写数字识别的训练数据,可以直接使用MNIST数据集。

这个数据集可以从torchvision.datasets中获取。

这里会将数据分别保存到train和test两个目录中,其中:

 

1) train有60000个数据

2)test有10000个数据

它们分别用来模型的训练和测试。

在train和test,这两个目录中,都包括了10个子目录:

子目录的名字就对应了图像中的数字。例如,在名为3的文件夹中,就保存了数字3的图像。

其中图像的名称是随机的字符串签名。

4.2 如何处理和读取这些数据?

完成数据的准备后,实现数据的读取功能,我会基于这一部分的代码进行讲解。

初学者在学习这一部分时,只要知道大致的数据处理流程就可以了。

数据的处理包括三块内容。

第1步,图像数据预处理:

需要实现图像的预处理pipeline,transform。

它包括了将图像转为灰度图和转张量两个功能。

这一步可以简单的理解为,将数组数据处理为训练时所用的张量数据。

第2步,构建数据集对象:

数据集对象的作用,就是用来整体操作训练数据,可以更方便的访问这些数据。

具体来说,使用ImageFolder函数,读取数据文件夹,构建数据集dataset。这个函数会将保存数据的文件夹的名字,作为数据的标签,组织数据。

例如,对于名字为“3”的文件夹,就会将“3”就会作为文件夹中的图像数据的标签。

标签和图像配对,用于后续的训练,ImageFolder使用起来非常方便。

这里我们分别读取训练数据文件夹train和测试数据文件夹test;

这样会得到train_dataset和test_dataset,两个数据集对象。

如果我们此时运行程序,会打印出它们的长度;

会看到,train_dataset是60000,test_dataset是10000。

这就代表了在训练集有60000个数据,测试集中有10000个数据。

第3步,小批量加载数据:

小批量加载数据直接和模型的训练有关。

小批量的数据读取,是训练各类深度学习模型的前提!

以下是创建小批量读取器dataloader的样例代码:

我们会使用train_loader,实现小批量的数据读取。

这里设置小批量的大小,batch_size=64。

也就是每个批次,包括64个数据,一次计算64个数据的梯度!

这时如果运行程序,会打印train_loader的长度,然后看到结果是938。

 具体来说,60000个训练数据,如果每个小批量,读入64个样本;

那么60000个数据会被分成938组。

我们可以计算938*64=60032,不足60000;

这就说明最后一组,会不够64个数据。

小批量的遍历数据,是训练的关键前提 

我们可以通过循环遍历train_loader来获取每个小批量数据。

这里的每一次循环,都会取出64个图像数据,作为一个小批量batch。

此时如果,打印前3个batch观察:

可以看到数据的尺寸data.shape是64*1*28*28:

它表示了每组数据包括64个图像;

每个图像有1个灰色通道;

图像的尺寸是28*28。

接着打印图像的标签label:

可以看到64个图片对应的数字。

其中保存的数值是0到9,对应了10个数字。

5. 模型训练

实际上,对于训练一个深度学习模型,训练后再测试这个深度学习模型;

这两个过程,都是定式。

也就是,无论你训练的模型简单还是复杂,是前馈神经网络还是Transformer,都是哪几个步骤。

当然,对于一些特殊的神经网络,可能会做一些专门的训练优化。

但本质还是那几个步骤,大家在看下面的讲解时,重点是了解这些步骤;

对于每句代码的具体含义,如果真相搞懂,最好还是将代码写出来,然后进行运行和调试。

相同的数据读入步骤

关于模型的训练,前半部分是图像数据的读入。

包括:

1)图像的预处理transform

2)读入并构造数据集train_dataset

3)使用train_loader进行小批量的数据读入。

这一块和刚刚讲的是一样的。

创建核心对象(变量)

在使用Pytorch训练模型时,需要创建三个核心对象(变量)。

大家要记住,无论训练哪种深度学习模型;

下面说的这三个对象,都要创建!

第1个是:

模型本身model,它就是我们设计的神经网络。

第2个是:

优化器optimizer,它用来优化模型中的参数。

初学的时候,直接使用Adam优化器就可以了。

第3个是:

损失函数criterion,对于分类问题,就直接使用CrossEntropyLoss,交叉熵损失误差;

进入模型的循环迭代

模型的循环迭代,同样是定式!

大家记住,迭代深度学习模型,就是两层循环。

这两层循环,分别是:

表示训练轮数的外层循环;

表示梯度下降的内层循环!

具体来说:

外层循环,代表了整个训练数据集的遍历次数。

整个训练集要循环多少轮,是10次、20次或者100次都是可能的。

这里根据经验,设置为10次。

内层循环使用train_loader,进行小批量的数据读取。

内层循环,每循环一次,就会进行一次梯度下降算法。

梯度下降算法

内层循环所包含的梯度下降算法,包括了5个步骤。

这5个步骤,又是使用pytorch框架训练模型的定式。

初学的时候,可以先记住。

具体来说:

1)计算神经网络的前向传播结果output。

2)计算output和标签label之间的损失loss。

3)使用backward计算梯度。

4)使用optimizer.step更新参数。

5)最后将梯度清零。

另外,我们每迭代100个小批量,就打印一次模型的损失,观察训练的过程。

运行程序,就会观察到,模型的损失loss,不断变小。

最后使用torch.save保存模型,模型名字为mnist.pth。

这个“mnist.pth”就是我们最后得到的神经网络模型;

将来再进行数字图片的预测时,就要用它来识别图像。

6. 模型测试

完成模型训练后,需要对模型进行测试。

测试的流程与训练差不多,我们要测试出模型的效果。

测试的过程,也相当于模型的“使用过程”了。

前面是类似的数据读入和模型定义:

首先需要读取测试数据集test_dataset。

然后定义神经网络模型,并加载刚刚训练好的模型文件mnist.pth。

然后是遍历测试数据集,进行预测,统计正确率:

定义变量right,保存正确识别的数量。

遍历test_dataset,将其中的数据x输入到模型model中,计算结果output。

然后从output中,使用argmax,选择概率最大标签的作为预测结果,保存到predict。

接着对比预测值predict和真实标签y。

这里将识别错误的样本打印了出来。

可以看到错误case的预测值predict、真实值y和文件路径。

最终计算出的测试效果为0.978。

也就是10000个数据,有9779个数据识别正确

以上就是从零设计并训练神经网络的过程。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924534.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【原创】springboot+mysql小区疫情防控网站设计与实现

个人主页:程序猿小小杨 个人简介:从事开发多年,Java、Php、Python、前端开发均有涉猎 博客内容:Java项目实战、项目演示、技术分享 文末有作者名片,希望和大家一起共同进步,你只管努力,剩下的交…

搭建hadoop+spark完全分布式集群环境

目录 一、集群规划 二、更改主机名 三、建立主机名和ip的映射 四、关闭防火墙(master,slave1,slave2) 五、配置ssh免密码登录 六、安装JDK 七、hadoop之hdfs安装与配置 1)解压Hadoop 2)修改hadoop-env.sh 3)修改 core-site.xml 4)修改hdfs-site.xml 5) 修改s…

2、ASPX、.NAT(环境/框架)安全

ASPX、.NAT&#xff08;环境/框架&#xff09;安全 源自小迪安全b站公开课 1、搭建组合&#xff1a; WindowsIISaspxsqlserver .NAT基于windows C开发的框架/环境 对抗Java xx.dll <> xx.jar 关键源码封装在dll文件内。 2、.NAT配置调试-信息泄露 功能点&#xf…

Docker 容器内的php 安装redis扩展

1、https://pecl.php.net/package/redis 下载redis扩展 2、解压redis扩展包&#xff0c;然后通过命令拷贝到php容器 docker cp ~/nginx/redis-4.3.0/* myphp-fpm:/usr/src/php/ext/redis/ myphp-fpm是你的php容器 &#xff5e;/nginx/redis**** 是redi扩展包路径 3、进入php容…

jenkins系列-09.jpom构建java docker harbor

本地先启动jpom server agent: /Users/jelex/Documents/work/jpom-2.10.40/server-2.10.40-release/bin jelexjelexxudeMacBook-Pro bin % sh Server.sh start/Users/jelex/Documents/work/jpom-2.10.40/agent-2.10.40-release/bin jelexjelexxudeMacBook-Pro bin % ./Agent.…

单元测试实施最佳方案(背景、实施、覆盖率统计)

1. 什么是单元测试&#xff1f; 对于很多开发人员来说&#xff0c;单元测试一定不陌生 单元测试是白盒测试的一种形式&#xff0c;它的目标是测试软件的最小单元——函数、方法或类。单元测试的主要目的是验证代码的正确性&#xff0c;以确保每个单元按照预期执行。单元测试通…

mysql-联合查询

一.联合查询的概念 .对于unio查询,就是把多次查询的结果合并起来,形成一个新的查询果集。 SELECT 字段列表 FROM 表A... UNION[ALL] SELECT 字段列表 FROM 表B...&#xff0c; 二.将薪资低于5000的员工,和年龄大于50岁的员工全部查询出来 select * from emp where salary&…

Vulnhub靶场 | DC系列 - DC2

目录 环境搭建渗透测试 环境搭建 靶机镜像下载地址&#xff1a;https://vulnhub.com/entry/dc-2,311/需要将靶机和 kali 攻击机放在同一个局域网里&#xff1b;本实验kali 的 IP 地址&#xff1a;192.168.10.146。 渗透测试 使用 nmap 扫描 192.168.10.0/24 网段存活主机 …

window下安装go环境

一、go官网下载安装包 官网地址如下&#xff1a;https://golang.google.cn/dl/ 选择对应系统的安装包&#xff0c;这里是window系统&#xff0c;可以选择zip包&#xff0c;下载完解压就可以使用 二、配置环境变量 这里的截图配置以win11为例 我的文件解压目录是 D:\Software…

【进阶篇-Day9:JAVA中单列集合Collection、List、ArrayList、LinkedList的介绍】

目录 1、集合的介绍1.1 概念1.2 集合的分类 2、单列集合&#xff1a;Collection2.1 Collection的使用2.2 集合的通用遍历方式2.2.1 迭代器遍历&#xff1a;&#xff08;1&#xff09;例子&#xff1a;&#xff08;2&#xff09;迭代器遍历的原理&#xff1a;&#xff08;3&…

Halcon机器视觉15种缺陷检测案例_2不均匀表面刮伤检测

2&#xff1a; 不均匀表面刮伤检测 思路 1、获取图像 2、分割图像 3、处理区域 4、获取大&#xff0c;小缺陷 效果 原图 代码 *02 不均匀表面刮伤检测 dev_update_off () dev_close_window ()*****************第一步 获取图像******************* read_image (Image, 2.不…

集成excel工具:自定义导入回调监听器、自定义类型转换器、web中的读

文章目录 I 封装导入导出1.1 定义工具类1.2 自定义读回调监听器: 回调业务层处理导入数据1.3 定义文件导入上下文1.4 定义回调协议II 自定义转换器2.1 自定义枚举转换器2.2 日期转换器2.3 时间、日期、月份之间的互转2.4 LongConverterIII web中的读3.1 使用默认回调监听器3.2…

NAT地址转换+多出口智能选路,附加实验内容

本章主要讲&#xff1a;基于目标IP、双向地址的转换 注意&#xff1a;基于目标NAT进行转换 ---基于目标IP进行地址转换一般是应用在服务器端口映射&#xff1b; NAT的基础知识 1、服务器映射 服务器映射是基于目标端口进行转换&#xff0c;同时端口号也可以进行修改&…

AI算法14-套索回归算法Lasso Regression | LR

套索回归算法概述 套索回归算法简介 在统计学和机器学习中&#xff0c;套索回归是一种同时进行特征选择和正则化&#xff08;数学&#xff09;的回归分析方法&#xff0c;旨在增强统计模型的预测准确性和可解释性&#xff0c; 正则化是一种回归的形式&#xff0c;它将系数估…

接口基础知识2:http通信的组成

课程大纲 一、http协议 HTTP&#xff08;Hypertext Transfer Protocol&#xff0c;超文本传输协议&#xff09;是互联网中被使用最广的一种网络协议&#xff0c;用于客户端与服务器之间的通信。 HTTP协议定义了一系列的请求方法&#xff0c;例如 GET、POST、PUT、DELETE 等&…

一篇学通Axios

Axios 是一个基于 Promise 的 HTTP 客户端&#xff0c;用于浏览器和 node.js 环境。它提供了一种简单易用的方式来发送 HTTP 请求&#xff0c;并支持诸如请求和响应拦截、转换数据、取消请求以及自动转换 JSON 数据等功能。 Axios 名字的由来 Axios 的名字来源于希腊神话中的…

在Linux系统实现瑞芯微RK3588部署rknntoolkit2进行模型转换

一、首先要先安装一个虚拟的环境 安装Miniconda包 Miniconda的官网链接:Minidonda官网 下载好放在要操作的linux系统,我用的是远程服务器的linux系统,我放在whl这个文件夹里面,这个文件夹是我自己创建的 运行安装 安装的操作都是yes就可以了 检查是否安装成功,输入下面…

秋招突击——7/13——多线程编程(基础知识回顾+编程练习 )

文章目录 引言基础知识Synchronized关键字使用方式用于同步方法针对同步块的方法静态方法使用原理解析 Volatile使用方式实现原理 final关键字 编程练习&#xff08;synchronized就能实现&#xff09;双线程轮流打印1-100个人实现参考实现 三线程顺序打出1-100个人实现参考实现…

笔记 4 :linux 0.11 中继续分析 0 号进程创建一号进程的 fork () 函数

&#xff08;27&#xff09;本条目开始&#xff0c; 开始分析 copy_process () 函数&#xff0c;其又会调用别的函数&#xff0c;故先分析别的函数。 get_free_page &#xff08;&#xff09; &#xff1b; 先 介绍汇编指令 scasb &#xff1a; 以及 指令 sstosd &#xff1a;…

[USACO24OPEN] Smaller Averages G (单调性优化dp)

来源 题目 Bessie 有两个长度为 N的数组&#xff08;1≤N≤500&#xff09;。第一个数组的第 i 个元素为 ai​&#xff08;1≤ai​≤10^6&#xff09;&#xff0c;第二个数组的第 i个元素为bi​&#xff08;1≤bi​≤10^6&#xff09;。 Bessie 希望将两个数组均划分为若干非空…