《深度学习》PyTorch 手写数字识别 案例解析及实现 <下>

news2024/9/23 1:36:02

目录

一、回顾神经网络框架

1、单层神经网络

2、多层神经网络

二、手写数字识别

1、续接上节课代码,如下所示

2、建立神经网络模型

输出结果:

3、设置训练集

4、设置测试集

5、创建损失函数、优化器

参数解析:

1)params

2)lr

3)loss

6、开始训练

运行结果:

三、总结

1、关键步骤

2、训练过程中

3、训练完成后


一、回顾神经网络框架

1、单层神经网络

        图示为单层神经完了的基本构造,首先,有信号想传入输入层,然后这些信号会在传播途中发生衰减或者增强,假如想传入的信号为x,那么经过衰减或增强后的值则为wx,这个w叫权重,此时便有了传入信号,有很多传入信号,那么神经元就会对这些信号进行处理,把这些信号的加权值求和,再将这个求和的值映射到非线性激活函数上来引入非线性特征,最后得到输出结果。

2、多层神经网络

        相比于上述的单层神经网络,单层神经网络就相当于上图中的绿色的第一列,即多个信号传入后只进行加权映射处理一次即输出结果,而多层神经网络即有多列神经元,传入信号经过第一列神经元处理后得到的值将其再次当做信号进行衰减或增强传入下一层神经元对其进行处理,至于多少层需要经过多次训练去决定,没有固定的值。

二、手写数字识别

1、续接上节课代码,如下所示

import torch
print(torch.__version__)

"""MNIST包含70,000张手写数字图像:60,000张用于训练,10,000张用于测试。
图像是灰度的,28x28像素的,并且居中的,以减少预处理和加快运行。"""

from torch import nn  # 导入神经网络模块
from torch.utils.data import DataLoader  # 数据包管理工具,打包数据,
from torchvision import datasets   # 封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor   # 数据转换,张量,将其他类型的数据转换为tensor张量,numpy arrgy,

"""下载训练数据集,图片+标签"""
training_data = datasets.MNIST(   # 跳转到函数的内部源代码,pycharm 按下ctrl +鼠标点击
    root='data',   # 表述下载的数据存放的根目录
    train=True,   # 表示下载的是训练数据集,如果要下载测试集,更改为False即可
    download=True,   # 表示如果根目录有该数据,则不再下载,如果没有则下载
    transform=ToTensor()   # 张量,图片是不能直接传入神经网络模型
    # 表示制定一个数据转换操作,将下载的图片转换为pytorch张量,因为pytorch只能处理张量tensor类型的数据
)

test_data = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()  # Tensor是在深度学习中提出并广泛应用的数据类型,它与深度学习框架(如 PyTorch、TensorFlo
)  # NumPy 数组只能在CPU上运行。Tensor可以在GPU上运行,这在深度学习应用中可以显著提高计算速度。

print(len(training_data))
print(len(test_data))

"""展示手写数字图片,把训练数据集中的前59000张图片展示一下"""

# # tensor -》numpy  矩阵类型的数据,矩阵是特殊的张量,张量可以包含任意维度的数据
# from matplotlib import pyplot as plt   # 导入绘图库
# figure = plt.figure()   # 设置一个空白画布
# for i in range(9):
#     img,label = training_data[i+59000]   # 提取第59000张图片开始,共9张,返回图片及其对应的标签值
#
#     figure.add_subplot(3,3,i+1)   # 在画布创建3行3列的小窗口,通过遍历的值i来确定每个画布展示的图片
#     plt.title(label)   # 设置每个窗口的标题,设置标签为上述返回的标签值
#     plt.axis('off')   # 取消画布中的坐标轴的图像
#     plt.imshow(img.squeeze(),cmap='gray')   # plt.imshow()将NumPy数组data中的数据显示为图像,并在图形窗口中,
#     a = img.squeeze()   # img.squeeze()从张量img中去掉维度为1的。如果该维度的大小不为1,则张量不会改变。
# plt.show()



train_dataloader = DataLoader(training_data,batch_size=64)  # 调用上述定义的DataLoader打包库,将训练集的图片和标签,64张图片为一个包,
test_dataloader = DataLoader(test_data,batch_size=64)   # 将测试集的图片和标签,每64张打包成一份
for x,y in test_dataloader:
    # x是表示打包好的每一个数据包,其形状为[64,1,28,28],64表示批次大小,1表示通道数为1,即灰度图,28表示图像的宽高像素值
    # y表示每个图片标签
    print(f"shape of x[N,C,H,W]:{x.shape}")   # 打印图片形状
    print(f"shape of y:{y.shape}{y.dtype}")   # 打印标签的形状和数据类型
    break  # 跳出并终止循环,表示只遍历一个包的数据情况

"""判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""  # 返回cuda,mps,cpu, m1,m2集显CPU+GPU RTX3060
device = "cuda" if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")  # 字符串的格式化。CUDA驱动软件的功能:pytorch能够去执行cuda的命令,cuda通过GPU指令集
# 神经网络的模型也需要传入到GPU,1个batchsize的数据集也需要传入到GPU,才可以进行训练。

2、建立神经网络模型

class NeuralNetwork(nn.Module):  # 通过调用类的形式来使用神经网络,神经网络的模型
    def __init__(self):
        super().__init__()   # 继承的父类初始化
        self.flatten = nn.Flatten()   # 将输入的多维数据展开成一维数据,创建一个展开对象flatten
        self.hidden1 = nn.Linear(28*28,128)  # 创建一个全连接层,其上包含权重的值,第1个参数表示有多少个神经元传入进来,第2个参数表示有多少个数据传出去,这里表示上述多层神经网络的第一列神经元
        self.hidden2 = nn.Linear(128, 256)  # 再次创建一个全连接层,将第一层输出个数转变成这一层的输入信号数,然后在设定输出层个数为256
        self.hidden3 = nn.Linear(256,128)  # 继承上一层的输出信号个数,将其当做这一层的输入为256层,设定输出为128层
        self.hidden4 = nn.Linear(128,64)   # 输入128层,输出64层
        self.out = nn.Linear(64, 10)   # 设定输出层,输出必需和标签的类别相同,输入必须是上一层的神经元个数

    def forward(self,x):   # 设定前向传播函数,你得告诉它,数据的流向。是神经网络层连接起来,函数名称不能改。当你调用forward函数的时候,传入进来的图像数据
        x = self.flatten(x)  # 传入信号x图像,首先进行展开
        x = self.hidden1(x)  # 将其当做输入层
        x = torch.sigmoid(x)   # 使用sigmoid激活函数得到结果,torch使用的relu函数 relu,tanh
        x = self.hidden2(x)    # 再次对上一层的结果进行加权
        x = torch.sigmoid(x)
        x = self.hidden3(x)
        x = torch.sigmoid(x)
        x = self.hidden4(x)
        x = torch.sigmoid(x)
        x = self.out(x)  # 输出结果
        return x  # 将输出结果返回出去

model = NeuralNetwork().to(device)   # 把刚刚创建的模型传入到gpu
print(model)   # 打印模型的构造

        此处的神经网络层数可以自己设置,如果想多设置几层或者少设置基层,只需改变self.hidden的个数以及前向传播函数中激活函数的个数。

        输出结果:

3、设置训练集

def train(dataloader,model,loss_fn,optimizer):   # 导入参数,dataloader表示打包,数据加载器,model导入上述定义的神经网络模型,loss_fn表示损失值,optimizer表示优化器
    model.train()   # 模型设置为训练模式
    # 告诉模型,我要开始训练,模型中权重w进行随机化操作,已经更新w。在训练过程中,w会被修改的
    # #pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
    # 一般用法是:在训练开始之前写上model.train(),在测试时写上model.eval()。
    batch_size_num = 1
    for x,y in dataloader:    # 遍历打包的图片的每一个包中的每一张图片及其对应的标签,其中batch为每一个数据的编号
        x,y = x.to(device),y.to(device)   # 把训练数据集和标签传入cpu或GPU
        pred = model.forward(x)    # 模型进行前向传播,输入图片信息后得到预测结果,forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值
        loss = loss_fn(pred,y)     # 调用交叉熵损失函数计算损失值loss,输入参数为预测结果和真实结果,
        # Backpropaqation 进来一个batch的数据,计算一次梯度,更新一次网络
        optimizer.zero_grad()    # 梯度值清零,在反向传播之前先清除之前的梯度
        loss.backward()     # 反向传播,计算得到每个参数的梯度值w
        optimizer.step()    # 根据梯度更新权重w参数

        loss_value = loss.item()   # 从tensor数据中提取数据出来,tensor获取损失值
        if batch_size_num % 200 == 0:  # 判断遍历包的个数是否整除于200,用于将训练到的包的个数打印出来,整除200目的是节省资源
            print(f"loss:{loss_value:>7f}   [number: {batch_size_num}]")  # 打印损失值及其对应的值,损失值最大宽度为7,右对齐
        batch_size_num += 1    # 每遍历一个包增加一次,以达到显示出来遍历的包的个数

4、设置测试集

def test(dataloader,model,loss_fn):  # 输入参数打包的图片、训练好的模型、以及损失值
    size = len(dataloader.dataset)   # 返回测试数据集的样本总数
    num_batches = len(dataloader)   # 返回当前dataloader配置下的批次数
    model.eval()    # 表示此为模型测试,w就不能再更新。
    test_loss,correct = 0, 0   # 设置总损失值初始化为0,正确预测结果初始化为0
    with torch.no_grad():    # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算
        for x,y in dataloader:   # 遍历测试集中的每个包的每个图片及其对应的标签
            x,y = x.to(device),y.to(device)   # 将其传入gpu
            pred = model.forward(x)   # 图片数据进行前向传播
            test_loss += loss_fn(pred,y).item()    # test_loss是会自动累加每一个批次的损失值
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # pred.argmax(1) == y用于判断预测结果最大值对用的标签是否与真实值相同,然后将判断结果的bool值转变为浮点数并求和
            a = (pred.argmax(1) == y)   # dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值对应的索引号
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches    # 总损失值除以打包的批次数,返回测试的每一个包的损失值的均值,能来衡量模型测试的好坏。
    correct /= size   # 平均的正确率
    print(f"Test result: \n Accuracy:{(100 * correct)}%, Avg loss:{test_loss}")

5、创建损失函数、优化器

loss_fn = nn.CrossEntropyLoss()  # 创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果
optimizer = torch.optim.Adam(model.parameters(),lr=0.0045)  # 创建一个优化器,SGD为随机梯度下降算法,学习率或者叫步长为0.0045
参数解析:
        1)params

                表示要训练的参数,一般我们传入的都是model.parameters()

        2)lr

                learning_rate学习率,也就是步长

        3)loss

                表示模型训练后的输出结果与样本标签的差距。如果差距越小,就表示模型训练越好,越逼近于真实的型。

6、开始训练

epochs = 25  # 设置训练的轮数为25轮,因为模型中设置了权重值的更新,所以重复训练会更新模型的权值
for i in range(epochs):
    print(f"Epoch {i+1}\n--------------------")
    train(train_dataloader,model,loss_fn,optimizer)
print('Done!!')
test(test_dataloader,model,loss_fn)   # 导入测试集进行测试

        将上述所有代码连贯起来运行。

        运行结果:

        至此训练结束,用户可手动更改模型激活函数、学习率、模型训练轮数等等,以找到最优结果。

        祝你成功!!

三、总结

1、关键步骤

        使用PyTorch进行手写数字识别可以分为几个关键步骤。首先,需要准备手写数字数据集,通常使用MNIST数据集。然后,需要定义神经网络模型,可以使用PyTorch提供的各种层和激活函数来构建模型架构。接下来,需要定义损失函数,通常使用交叉熵损失函数来计算预测结果与真实标签之间的差异。

2、训练过程中

        在训练过程中,使用优化算法(如随机梯度下降)来更新模型的权重和偏置,使其逐渐接近最优解。训练过程中会使用训练集的数据进行反向传播,并根据损失函数的结果来调整模型参数。可以设置训练的轮数和批次大小,以更好地优化模型。

3、训练完成后

        在训练完成后,可以使用测试集或新的手写数字图像来评估模型的性能。通过将图像输入已训练好的模型,可以获得预测结果,并与真实标签进行比较。可以计算准确率等指标来评估模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2139091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Cesium 计算3d凸包(ConvexHull)

Cesium 计算3d凸包(ConvexHull) Cesium 计算3d凸包(ConvexHull)

Unity实战案例全解析:PVZ 植物放置分析

前篇:Unity实战案例全解析:PVZ 植物卡片状态分析-CSDN博客 植物应该如何从卡牌状态转为实物? 其实就只需要考虑两个步骤加一个后续处理: 1.点击卡牌后就实例化 需要一个植物状态枚举,因为卡牌分为拿在手上和种植下…

Android 10.0 mtk平板camera2横屏预览旋转90度横屏保存圆形预览缩略图旋转90度功能实现

1.前言 在10.0的系统rom定制化开发中,在进行一些平板等默认横屏的设备开发的过程中,需要在进入camera2的 时候,默认预览图像也是需要横屏显示的,在上一篇已经实现了横屏预览功能,然后发现横屏预览后,点击录像和照片下保存的圆形预览缩略图 依然是竖屏的,所以说同样需要…

需求导向的正则表达式

目录 re.sub 需求:把 1. 2.这些序号转成(1) (2) 需求:反过来,把(1)->1. ,(2)》2. 。 需求:把出现的 1 2 3都转成下标 进阶1!只想让化学符…

Redis入门2

在java中操作Redis Redis的Java客户端 Redis 的 Java 客户端很多,常用的几种: Jedis Lettuce Spring Data Redis Spring Data Redis 是 Spring 的一部分,对 Redis 底层开发包进行了高度封装。 在 Spring 项目中,可以使用Spring Data R…

把项目部署到Linux系统上(如何在阿里云服务器上安装和配置SpringBoot+vue全栈开发环境)

项目部署上线 环境准备下载安装Linux系统和ssh连接工具背景知识安装虚拟机安装Linux系统选择installCentOS7按命令IP addr查看服务器IP地址,ens33网卡中会出现IP地址配置好后就可以查看了一个可远程连接Linux服务器的工具1.(基于finalshell工具&#xff…

小明震惊OpenAI 的新模型 01

在硅谷的中心,繁忙的咖啡馆和创业中心周围,年轻的软件工程师小明坐在他的办公桌前,面露困惑。科技界一直在盛传一项新的AI突破,但他持怀疑态度,不敢抱太大希望。他认为AI泡沫即将破灭,炒作列车即将出轨&…

网络原理 IP协议与以太网协议

博主主页: 码农派大星. 数据结构专栏:Java数据结构 数据库专栏:MySQL数据库 JavaEE专栏:JavaEE 关注博主带你了解更多数据结构知识 目录 1.网络层 IP协议 1.IP协议格式 2.地址管理 2.1 IP地址 2.2 解决IP地址不够用的问题 2.3NAT网络地址转换 2.4网段划分 3.路由选择…

北极星计划的回响:从Leap Motion到Midjourney的AI 3D硬件梦想

在科技的浩瀚星空中,总有一些梦想如同北极星般璀璨,指引着探索者前行。六年前,Leap Motion的CEO David以一篇充满激情的博客文章,向我们揭示了“北极星计划”——一个旨在打破数字与物理界限,创造流畅统一体验的增强现实平台。今天,随着Midjourney在AI文生图领域的全球爆…

2024.9.15周报

一、题目信息 题目:Physics-informed neural networks for solving flow problems modeled by the 2D Shallow Water Equations without labeled data 链接:物理信息神经网络用于解决由二维浅水方程建模的流动问题,无需标记数据- ScienceDi…

【Node.js】初识 RabbitMQ

概述 MQ 顾名思义,是消息队列。 RabbitMQ 是一个消息队列系统,用于实现异步通信。基于 AMQP。AMQP(高级消息队列协议) 实现了对于消息的排序,点对点通讯,和发布订阅,保持可靠性、保证安全性。 在 Node.js 的微服务架…

LAMP+WordPress

一、简介 LAMP: L:linux——操作系统,提供服务器运行的基础环境。A:apache(httpd)——网页服务器软件,负责处理HTTP请求和提供网页内容。M:mysql,mariadb——数据库管理…

PCL 窗口可视化两个点云

目录 一、概述 1.1原理 1.2实现步骤 1.3 应用场景 二、代码实现 2.1关键函数 2.2完整代码 三、实现效果 PCL点云算法汇总及实战案例汇总的目录地址链接: PCL点云算法与项目实战案例汇总(长期更新) 一、概述 本文将介绍如何使用PCL库…

8.4Prewitt算子边缘检测

基本原理 Prewitt算子是一种用于边缘检测的经典算子,它通过计算图像中像素值的(一阶导数)梯度来检测边缘。Prewitt算子通常包括两个3x3的卷积核,一个用于检测水平方向上的边缘,另一个用于检测垂直方向上的边缘。 示例…

【动漫资源管理系统】Java SpringBoot助力,搭建一个高清动漫在线观看网站

🍊作者:计算机毕设匠心工作室 🍊简介:毕业后就一直专业从事计算机软件程序开发,至今也有8年工作经验。擅长Java、Python、微信小程序、安卓、大数据、PHP、.NET|C#、Golang等。 擅长:按照需求定制化开发项目…

【插件】【干货】用EPPlus在Unity中读写Excel表

EPPlus是什么我就不说了,你都点进来了肯定知道 几个常用的api 1.index下标都是从1开始的 2.可以读取任意单元格上的任意内容,不需要给excel表写规则 但是如果你写了规则,就需要自己用额外的代码 --- 数据结构去实现 3.打开excel表 ExcelP…

[数据集][目标检测]智慧交通铁路异物入侵检测数据集VOC+YOLO格式802张7类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):802 标注数量(xml文件个数):802 标注数量(txt文件个数):802 标注类别…

萤石举办2024清洁机器人新品发布会 多维智能再造行业标杆

导言:作为智慧生活守护者,萤石今日发布了两款清洁机器人,AI扫拖机器人RS20 Pro Ultra 和AI洗地机器人RX30 Max ,标志着萤石在智能清洁领域的全新突破。RS20 Pro Ultra基于CutFree 2.0内切割滚刷专利,有效解决毛发缠绕难…

速通GPT:《Improving Language Understanding by Generative Pre-Training》全文解读

文章目录 速通GPT系列几个重要概念1、微调的具体做法2、任务感知输入变换3、判别式训练模型 Abstract概括分析和观点1. 自然语言理解中的数据问题2. 生成预训练和监督微调的结合3. 任务感知输入变换4. 模型的强大性能 Introduction概括分析和观点1. 自然语言理解的挑战在于对标…

探索Python的HTML处理神器:pyquery的魔力

文章目录 探索Python的HTML处理神器:pyquery的魔力背景:为何选择pyquery?pyquery是什么?安装pyquery五个简单的库函数使用方法1. $:选择元素2. .text():获取文本内容3. .html():获取HTML内容4. …