PyTorch实战:实现MNIST手写数字识别

news2025/1/23 10:48:31

前言

PyTorch可以说是三大主流框架中最适合初学者学习的了,相较于其他主流框架,PyTorch的简单易用性使其成为初学者们的首选。这样我想要强调的一点是,框架可以类比为编程语言,仅为我们实现项目效果的工具,也就是我们造车使用的轮子,我们重点需要的是理解如何使用Torch去实现功能而不要过度在意轮子是要怎么做出来的,那样会牵扯我们太多学习时间。以后就出一系列专门细解深度学习框架的文章,但是那是较后期我们对深度学习的理论知识和实践操作都比较熟悉才好开始学习,现阶段我们最需要的是学会如何使用这些工具。

深度学习的内容不是那么好掌握的,包含大量的数学理论知识以及大量的计算公式原理需要推理。且如果不进行实际操作很难够理解我们写的代码究极在神经网络计算框架中代表什么作用。不过我会尽可能将知识简化,转换为我们比较熟悉的内容,我将尽力让大家了解并熟悉神经网络框架,保证能够理解通畅以及推演顺利的条件之下,尽量不使用过多的数学公式和专业理论知识。以一篇文章快速了解并实现该算法,以效率最高的方式熟练这些知识。


博主专注数据建模四年,参与过大大小小数十来次数学建模,理解各类模型原理以及每种模型的建模流程和各类题目分析方法。此专栏的目的就是为了让零基础快速使用各类数学模型、机器学习和深度学习以及代码,每一篇文章都包含实战项目以及可运行代码。博主紧跟各类数模比赛,每场数模竞赛博主都会将最新的思路和代码写进此专栏以及详细思路和完全代码。希望有需求的小伙伴不要错过笔者精心打造的专栏。

一文速学-数学建模常用模型


一、数据集加载

MNIST(Modified National Institute of Standards and Technology)是一个手写数字数据集,通常用于训练各种图像处理系统。

它包含了大量的手写数字图像,这些数字从0到9。每个图像都是一个灰度图像,大小为28x28像素,表示了一个手写数字。

MNIST数据集分成两部分:训练集和测试集。训练集通常包含60,000张图像,用于训练模型。测试集包含10,000张图像,用于评估模型的性能。

MNIST数据集是一个非常受欢迎的数据集,被用于测试和验证各种机器学习和深度学习模型,特别是在图像识别任务中。大家可以直接访问官网下载或者是在程序中使用torchvision下载数据集。

官网:THE MNIST DATABASE

一共4个文件,训练集、训练集标签、测试集、测试集标签:

文件名称大小内容
train-labels-idx1-ubyte.gz9,681 kb55000张训练集,5000张验证集
train-labels-idx1-ubyte.gz29 kb训练集图片对应的标签
t10k-images-idx3-ubyte.gz1,611 kb10000张测试集
t10k-labels-idx1-ubyte.gz5 kb测试集图片对应的标签

 程序加载MNIST数据集:

from torch.utils.data import DataLoader
import torchvision.datasets as dsets

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # 将图像转为灰度
    transforms.ToTensor(),  # 将图像转为张量
    transforms.Normalize((0.1307,), (0.3081,))
])

#MNIST dataset
train_dataset = dsets.MNIST(root = '/ml/pymnist',  #选择数据的根目录
                            train = True,  #选择训练集
                            transform = transform,  #不考虑使用任何数据预处理
                            download = True  #从网络上下载图片
                           )
test_dataset = dsets.MNIST(root = '/ml/pymnist',#选择数据的根目录
                           train = False,#选择测试集
                           transform = transform, #不考虑使用任何数据预处理
                           download = True #从网络上下载图片
                          )
#加载数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size = batch_size,
                                           shuffle = True #将数据打乱
                                          )
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                           batch_size = batch_size,
                                           shuffle = True
                                         )

图片展示:

import matplotlib.pyplot as plt
digit = train_dataset.train_data[0]
plt.imshow(digit,cmap=plt.cm.binary,interpolation='none')
plt.title("Labels: {}".format(train_dataset.train_labels[0]))
plt.show()

 之后需要切分数据集,分为训练集和测试集,MNIST数据集已经做好了直接使用就好了:

print("train_data:",train_dataset.train_data.size())
print("train_labels:",train_dataset.train_labels.size())
print("test_data:",test_dataset.test_data.size())
print("test_labels:",test_dataset.test_labels.size())

train_data: torch.Size([60000, 28, 28])
train_labels: torch.Size([60000])
test_data: torch.Size([10000, 28, 28])
test_labels: torch.Size([10000])

还需要确定批次的尺寸,在神经网络训练中,batch_size 是指每次迭代训练时,模型同时处理的样本数量。它在训练过程中起到了几个重要作用:

  • 加速训练过程:通过同时处理多个样本,利用了现代计算机的并行计算能力,可以加速训练过程,尤其在使用GPU时。
  • 减少内存消耗:一次性加载整个训练集可能会占用大量内存,而将数据分批次加载可以降低内存消耗,使得在内存受限的环境中也能进行训练。
  • 提高模型泛化能力:在训练过程中,模型会根据每个batch的数据调整权重,而不是依赖于整个训练集。这样可以提高模型对不同样本的泛化能力。
  • 避免陷入局部极小值:随机选择的小批量样本可以帮助模型避免陷入局部极小值。
  • 增加噪声鲁棒性:在每次迭代中,模型只看到了一个小样本,这可以看作一种随机噪声,有助于提高模型的鲁棒性。
  • 方便在线学习:对于在线学习任务,可以动态地加载新的数据批次,而不需要重新训练整个模型。

总的来说,合理选择合适的batch_size可以使训练过程更加高效、稳定,并且能够提高模型的泛化能力。然而,过大的batch_size可能会导致内存溢出或训练速度变慢,过小的batch_size可能会导致模型收敛困难。因此,选择合适的batch_size需要在实践中进行调试和优化。

print("批次的尺寸:",train_loader.batch_size)
print("load_train_data:",train_loader.dataset.train_data.shape)
print("load_train_labels:",train_loader.dataset.train_labels.shape)

 

批次的尺寸: 100
load_train_data: torch.Size([60000, 28, 28])
load_train_labels: torch.Size([60000])

 从输出结果中,可以看到原始数据集和数据打乱按照批次读取的数据集的总行数是一样的,实际操作中train_loader以及test_loader将作为神经网络的输入数据源。

二、定义神经网络

在前面的文章中已经带着大家搭建过好几遍神经网络了,注意初始化网络和对应的输入层,隐藏层和输出层。

import torch.nn as nn
import torch

input_size = 784 #mnist的像素为28*28
hidden_size = 500
num_classes = 10#输出为10个类别分别对应于0~9

#创建神经网络模型
class Neural_net(nn.Module):
#初始化函数,接受自定义输入特征的维数,隐含层特征维数以及输出层特征维数
    def __init__(self,input_num,hidden_size,out_put):
        super(Neural_net,self).__init__()
        self.layer1 = nn.Linear(input_num,hidden_size) #从输入到隐藏层的线性处理
        self.layer2 = nn.Linear(hidden_size,out_put) #从隐藏层到输出层的线性处理
        
    def forward(self,x):
        x = self.layer1(x) #输入层到隐藏层的线性计算
        x = torch.relu(x) #隐藏层激活
        x = self.layer2(x) #输出层,注意,输出层直接接loss
        return x
        
net = Neural_net(input_size,hidden_size,num_classes)
print(net)
        
Neural_net(
  (layer1): Linear(in_features=784, out_features=500, bias=True)
  (layer2): Linear(in_features=500, out_features=10, bias=True)
)

 super(Neural_net, self).init() 是 Python 中用于调用父类的方法或属性的一种方式。在这里,Neural_net 是你定义的神经网络模型的类名,它继承了 nn.Module 类,而 nn.Module 是 PyTorch 中用于构建神经网络模型的基类。也就是说,你的神经网络模型会继承 nn.Module 的所有属性和方法,这样你可以在 Neural_net 类中使用 nn.Module 中定义的各种功能,比如添加神经网络层、指定损失函数等。

三、训练模型

只有要注意一下Variable,之前的文章中又提到过Variable。

Variable是PyTorch早期版本(0.4版本之前)中用于构建计算图的抽象,它包含了data、grad和grad_fn等属性,可以用于构建计算图,并在反向传播时自动计算梯度。但从PyTorch 0.4版本开始,Variable被官方废弃,而Tensor直接支持了自动求导功能,不再需要显式地创建Variable。

因此,Autograd是PyTorch实现自动求导的核心机制,而Variable是早期版本中用于构建计算图的一种抽象,现在已经被Tensor所取代。 Autograd会自动追踪Tensor上的操作,并在需要时计算梯度,从而实现反向传播。

#optimization
import numpy as np
from torchvision import transforms

learning_rate = 1e-3 #学习率
num_epoches = 5
criterion =nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(),lr = learning_rate) #随机梯度下降

for epoch in range(num_epoches):
    print('current epoch = %d' % epoch)
    for i ,(images,labels) in enumerate(train_loader,0):
        images=images.view(-1,28*28)
        outputs = net(images) #将数据集传入网络做前向计算
        labels = torch.tensor(labels, dtype=torch.long)
        
        loss = criterion(outputs, labels) #计算loss
        
        optimizer.zero_grad() #在做反向传播之前先清楚下网络状态
        loss.backward() #Loss反向传播
        optimizer.step() #更新参数
        
        if i % 100 == 0:
            print('current loss = %.5f' % loss.item())
            
print('finished training')

 

current epoch = 1
current loss = 0.27720
current loss = 0.23612
current loss = 0.39341
current loss = 0.24683
current loss = 0.18913
current loss = 0.31647
current loss = 0.28518
current loss = 0.18053
current loss = 0.34957
current loss = 0.31319
current epoch = 2
current loss = 0.15138
current loss = 0.30887
current loss = 0.24257
current loss = 0.46326
current loss = 0.30790
current loss = 0.17516
current loss = 0.32319
current loss = 0.32325
current loss = 0.32066
current loss = 0.24271

 四、准确度测试

各层的权重通过随机梯度下降法更新Loss之后,针对测试集数字分类的准确率:

#prediction
total = 0
correct =0 
acc_list_test = []
for images,labels in test_loader:
    images=images.view(-1,28*28)
    outputs = net(images) #将数据集传入网络做前向计算
    
    _,predicts = torch.max(outputs.data,1)
    total += labels.size(0)
    correct += (predicts == labels).sum()
    acc_list_test.append(100 * correct / total)
    
print('Accuracy = %.2f'%(100 * correct / total))
plt.plot(acc_list_test)
plt.xlabel('Epoch')
plt.ylabel('Accuracy On TestSet')
plt.show()

 

 

点关注,防走丢,如有纰漏之处,请留言指教,非常感谢

以上就是本期全部内容。我是fanstuck ,有问题大家随时留言讨论 ,我们下期见。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1020899.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AI绘图提示词Stable Diffusion Prompt 笔记

基础 提示词分为正向提示词(positive prompt)和反向提示词(negative prompt),用来告诉AI哪些需要,哪些不需要词缀的权重默认值都是1,从左到右依次减弱,权重会影响画面生成结果。AI …

登录认证方式汇总,例如ThreadLocal+拦截器+Redis、JWT

登录方式汇总 先讲讲传统的登录方式 1.Cookie方案 用cookie作为媒介存放用户凭证。 用户登录系统之后,会返回一个加密的cookie,当用户访问子应用的时候会带上这个cookie,授权以解密cookie并进行校验,校验通过后即可登录当前用…

基于Java的高校竞赛管理系统设计与实现(亮点:发起比赛、报名、审核、评委打分、获奖排名,可随意更换主题如蓝桥杯、ACM、王者荣耀、吃鸡等竞赛)

高校竞赛管理系统 一、前言二、我的优势2.1 自己的网站2.2 自己的小程序(小蔡coding)2.3 有保障的售后2.4 福利 三、开发环境与技术3.1 MySQL数据库3.2 Vue前端技术3.3 Spring Boot框架3.4 微信小程序 四、功能设计4.1 主要功能描述4.2 系统角色 五、系统…

linux常用命令(1):zip/unzip命令(压缩文件/解压缩文件)

文章目录 前言一、linux安装zip文件二、zip语法2.1、常用选项实例2.1.1、-r (压缩文件夹,解决80%的场景)2.1.2、-q2.1.3、-d(从压缩文件中删除指定文件)2.1.4、-u(更新文件)2.1.5、-f2.1.6、-m2…

leetcode1797. 设计一个验证系统(java)

设计一个验证系统 题目描述哈希表题目描述 题目描述 难度 - 中等 leetcode1797. 设计一个验证系统 你需要设计一个包含验证码的验证系统。每一次验证中,用户会收到一个新的验证码,这个验证码在 currentTime 时刻之后 timeToLive 秒过期。如果验证码被更新…

XGBoost算法讲解和公式推导

本文节选自《机器学习入门基础(微课版)》 10.5 XGBoost 算法 XGBoost 是 2014 年 2 月由华盛顿大学的博士生陈天奇发明的基于梯度提升算法(GBDT)的机器学习算法,其算法不但具有优良的学习效果,而且训练速度高效,在数据…

Linux下的系统编程——线程同步(十三)

前言: 在多线程编程中,如果多个线程同时访问和修改共享资源,可能会产生竞争条件和数据不一致的问题。同步机制用于协调线程之间的访问和操作,确保数据的正确性和一致性。为了避免多个线程同时访问和操作共享资源导致的问题&#…

小程序键盘没有【小数点】输入

<input v-model"formData.number" :auto-height"true" placeholder"请输入" confirm-type"done" type"digit" maxlength"11" input"inputNumber" />number&#xff1a;数字键盘&#xff08;没有小…

Minitab Express for Mac(数据分析软件)附破解补丁 v1.5.0 支持M1

Minitab Express是一款专为Mac用户设计的数据分析和统计软件。它提供了一套全面的工具和功能&#xff0c;用于分析数据、执行统计计算和生成可视化。 下载&#xff1a;Minitab Express for Mac(数据分析软件)附破解补丁 以下是 Minitab Express for Mac 的一些主要功能&#x…

随机森林案例分析

阅读随机森林模型前&#xff0c;建议首先阅读决策树模型手册&#xff08;点击后跳到决策树模型的帮助手册页面&#xff09;&#xff0c;因为随机森林模型实质上是多个决策树模型的综合&#xff0c;决策树模型只构建一棵分类树&#xff0c;但是随机森林模型构建非常多棵决策树&a…

数字虚拟人制作简明指南

如何在线创建虚拟人&#xff1f; 虚拟人&#xff0c;也称为数字化身、虚拟助理或虚拟代理&#xff0c;是一种可以通过各种在线平台与用户进行逼真交互的人工智能人。 在线创建虚拟人变得越来越流行&#xff0c;因为它为个人和企业带来了许多好处。 推荐&#xff1a;用 NSDT编辑…

阿里云无影云电脑和传统PC有什么区别?

阿里云无影云电脑和传统电脑PC有什么区别&#xff1f;区别大了&#xff0c;无影云电脑是云端的桌面服务&#xff0c;传统PC是本地的硬件计算机&#xff0c;无影云电脑的数据是保存在云端&#xff0c;本地传统PC的数据是保存在本地客户端&#xff0c;阿里云百科分享阿里云无影云…

低代码与低代码平台

随着数字化转型和软件需求的不断增长&#xff0c;传统的手写代码开发方式已经无法满足迅速推出应用程序的需求。为了加快软件开发的速度并降低技术门槛&#xff0c;低代码开发模式应运而生。本文将介绍低代码的概念&#xff0c;探讨什么是低代码、什么是低代码平台&#xff1f;…

RFID设备在自动化堆场中的管理应用

随着信息技术的高速发展&#xff0c;带动了港口生产和管理技术的长足进步&#xff0c;港口堆场内的自动化场桥的智能化水平成为码头提高生产率一个重要标签。各地海关着力于现代化科技手段&#xff0c;努力构筑新型的便捷通关模式&#xff0c;在进出口环节做好管理和服务。 全…

已解决 Kotlin Error: Type mismatch: inferred type is String but Int was expected

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页: &#x1f405;&#x1f43e;猫头虎的博客&#x1f390;《面试题大全专栏》 &#x1f995; 文章图文并茂&#x1f996…

idea之maven的安装与配置

我们到maven的官网里下载maven&#xff0c;地址&#xff1a;https://maven.apache.org/download.cgi下载完成后解压即可配置环境变量 此电脑–>右键–>属性–>高级系统设置–>环境变量–>系统变量&#xff08;S&#xff09;–>新建一个系统变量 变量名&…

【开发记录01】开发环境副本/页的导入&带用户权限管理系统

在蒋老师的指导下大概了解了: 1.开发环境的数据导入/导出 共享组件的同步 因为应用程序277是应用程序100的子程序&#xff0c;所以共享组件必须和100保持一致。 但是会出现一个小问题&#xff1a; 在APEX开发过程中同时打开两个不同的应用程序&#xff0c;但是编辑过程中经…

Java 华为真题-猴子爬山

需求&#xff1a; 一天一只顽猴想去从山脚爬到山顶&#xff0c;途中经过一个有个N个台阶的阶梯&#xff0c;但是这猴子有一个习惯&#xff1a;每一次只能跳1步或跳3步&#xff0c;试问猴子通过这个阶梯有多少种不同的跳跃方式&#xff1f; 输入描述 输入只有一个整数N&#xff…

Python - 小玩意 - 键盘记录器

pip install keyboardimport keyboard import timedef get_time():date_time time.strftime("%Y-%m-%d %H:%S", time.localtime())return date_timedef abc(x):if x.event_type down:print(f"{get_time()}你按下了{x.name}")with open(./键盘记录器.txt,…

CG Magic分享同一场景里下,VR渲染器和CR渲染器哪个好?

渲染操作时&#xff0c;VR渲染器和CR渲染器的对比成为常见问题了。这个问题很多人都会问。 今天CG Magic小编通过一个真实的项目&#xff0c;就是同一场景下来比较一下VR渲染器和CR渲染器的区别。 以下图为例是用来测试的场景当年的最终图。采用了当年的一个伊丽莎白大街152号的…