# 手写数字识别:使用PyTorch构建MNIST分类器

news2025/4/21 2:01:34

手写数字识别:使用PyTorch构建MNIST分类器

在这篇文章中,我将引导你通过使用PyTorch框架构建一个简单的神经网络模型,用于识别MNIST数据集中的手写数字。MNIST数据集是一个经典的机器学习数据集,包含了60,000张训练图像和10,000张测试图像,每张图像都是28x28像素的灰度手写数字。
在这里插入图片描述

在这里插入图片描述

环境准备

首先,确保你的环境中安装了PyTorch和torchvision。可以通过以下命令安装:

pip install torch torchvision

数据加载与预处理

我们首先加载MNIST数据集,并将图像转换为PyTorch张量格式,以便模型可以处理。

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

'''下载训练数据集(包含训练图片+标签)'''
training_data = datasets.MNIST( #跳转到函数的内部源代码,pycharm 按下ctrl+鼠标点击 training_data:Dataset
    root="data",#表示下载的手写数字 到哪个路径。60000
    train=True, #读取下载后的数据 中的 训练集
    download=True,#如果你之前已经下载过了,就不用再下载
    transform=ToTensor(), #张量,图片是不能直接传入神经网络模型
)   #对于pytorch库能够识别的数据一般是tensor张量。


'''下载测试数据集(包含训练图片+标签)'''
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
print(len(training_data))

数据可视化

为了更好地理解数据,我们可以展示一些手写数字图像。

''展示手写字图片,把训练数据集中的前59000张图片展示一下'''

from matplotlib import pyplot as plt
figure = plt.figure()
for i in range(9):
    img, label = training_data[i+59000] #提取第59000张图片

    figure.add_subplot(3, 3, i+1) #图像窗口中创建多个小窗口,小窗口用于显示图片
    plt.title(label)
    plt.axis("off") # plt.show(I)#是示矢量,
    plt.imshow(img.squeeze(), cmap="gray")
    a = img.squeeze()
plt.show()

创建DataLoader

为了高效地加载数据,我们使用DataLoader来批量加载数据。

# '"创建数据DataLoader(数据加载器)开'
#  'batch_size:将数据集分成多份,每一份为batch_size个数据'
#  '优点:可以减少内存的使用,提高训练速度。

train_dataloader = DataLoader(training_data, batch_size=64) #64张图片为一个包,train_dataloader:<torch
test_dataloader = DataLoader(test_data, batch_size=64)

模型定义

接下来,我们定义一个简单的神经网络模型,包含两个隐藏层和一个输出层。

'''定义神经网络类的继承这种方式'''
class NeuralNetwork(nn.Module):  #通过调用类的形式来使用神经网络,神经网络的模型,nn.module
    def __init__(self): #python基础关于类,self类自已本身
        super().__init__() #继承的父类初始化
        self.flatten = nn.Flatten() #展开,创建一个展开对象flatten
        self.hidden1 = nn.Linear(28*28, 128 ) #第1个参数:有多少个神经元传入进来,第2个参数:有多少个数据传出
        self.hidden2 = nn.Linear(128, 256)
        self.out = nn.Linear(256, 10)
    def forward(self, x):
        x = self.flatten(x) #图像进行展开
        x = self.hidden1(x)
        x = torch.relu(x) #激活函数,torch使用的relu函数 relu,tanh
        x = self.hidden2(x)
        x = torch.relu(x)
        x = self.out(x)
        return x

model = NeuralNetwork().to(device) #把刚刚创建的模型传入到Gpu
print(model)

训练与测试

我们定义训练和测试函数,使用交叉熵损失函数和随机梯度下降优化器。

def train(dataloader, model, loss_fn, optimizer):
    model.train() #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
# #pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
 # 一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()

    batch_size_num = 1
    for X, y in dataloader: #其中batch为每一个数据的编号
        X, y = X.to(device), y.to(device) #把训练数据集和标签传入cpu或GPU
        pred = model.forward(X) #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化w
        loss= loss_fn(pred, y) #通过交叉熵损失函数计算损失值loss
        # Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络
        optimizer.zero_grad() #梯度值清零
        loss.backward() #反向传播计算得到每个参数的梯度值w
        optimizer.step() #根据梯度更新网络w参数

        loss_value = loss.item() #从tensor数据中提取数据出来,tensor获取损失值
        if batch_size_num % 100 ==0:
            print(f"loss: {loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1


def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset) #10000
    num_batches = len(dataloader) #打包的数量
    model.eval() #测试,w就不能再更新。
    test_loss, correct = 0, 0
    with torch.no_grad(): #一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model.forward(X)
            test_loss += loss_fn(pred, y).item()  #test_loss是会自动累加每一个批次的损失值
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)   #dim=1表示每一行中的最大值对应的索引号,dim=0表示每一列中的最大值
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches #能来衡量模型测试的好坏。
    correct /= size #平均的正确率

    print(f"Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}")

训练模型

最后,我们训练模型并测试其性能。

loss_fn = nn.CrossEntropyLoss() #创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果

optimizer = torch.optim.SGD(model.parameters(), lr=0.01) #创建一个优化器,SGD为随机梯度下降算法
# #params:要训练的参数,一般我们传入的都是model.parameters()# #lr:learning_rate学习率,也就是步长

#loss表示模型训练后的输出结果与,样本标签的差距。如果差距越小,就表示模型训练越好,越逼近干真实的模型。

# train(train_dataloader, model, loss_fn, optimizer)
# test(test_dataloader, model, loss_fn)

epochs = 30
for t in range(epochs):
    print(f"Epoch {t+1}\n")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

运行结果

在这里插入图片描述

结论

通过这篇文章,我们成功构建了一个简单的神经网络模型来识别MNIST数据集中的手写数字。这个模型展示了如何使用PyTorch进行数据处理、模型定义、训练和测试。希望这能帮助你开始自己的深度学习项目!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2339127.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode:DFS综合练习

简单 1863. 找出所有子集的异或总和再求和 一个数组的 异或总和 定义为数组中所有元素按位 XOR 的结果&#xff1b;如果数组为 空 &#xff0c;则异或总和为 0 。 例如&#xff0c;数组 [2,5,6] 的 异或总和 为 2 XOR 5 XOR 6 1 。 给你一个数组 nums &#xff0c;请你求出 n…

Perf学习

重要的能解决的问题是这些&#xff1a; perf_events is an event-oriented observability tool, which can help you solve advanced performance and troubleshooting functions. Questions that can be answered include: Why is the kernel on-CPU so much? What code-pa…

齐次坐标变换+Unity矩阵变换

矩阵变换 变换&#xff08;transform)&#xff1a;指的是我们把一些数据&#xff0c;如点&#xff0c;方向向量甚至是颜色&#xff0c;通过某种方式&#xff08;矩阵运算&#xff09;&#xff0c;进行转换的过程。 变换类型 线性变换&#xff1a;保留矢量加和标量乘的计算 f(x)…

Pandas取代Excel?

有人在知乎上提问&#xff1a;为什么大公司不用pandas取代excel&#xff1f; 而且列出了几个理由&#xff1a;Pandas功能比Excel强大&#xff0c;运行速度更快&#xff0c;Excel除了简单和可视化界面外&#xff0c;没有其他更多的优势。 有个可怕的现实是&#xff0c;对比Exce…

启动vite项目报Unexpected “\x88“ in JSON

启动vite项目报Unexpected “\x88” in JSON 通常是文件被防火墙加密需要寻找运维解决 重启重装npm install

HTTP测试智能化升级:动态变量管理实战与效能跃迁

在Web应用、API接口测试等领域&#xff0c;测试场景的动态性和复杂性对测试数据的灵活管理提出了极高要求。传统的静态测试数据难以满足多用户并发、参数化请求及响应内容验证等需求。例如&#xff0c;在电商系统性能测试中&#xff0c;若无法动态生成用户ID、订单号或实时提取…

关于一对多关系(即E-R图中1:n)中的界面展示优化和数据库设计

前言 一对多&#xff0c;是常见的数据库关系。在界面设计时&#xff0c;有时为了方便&#xff0c;就展示成逗号分割的字符串。例如&#xff1a;学生和爱好的界面。 存储 如果是简单存储&#xff0c;建立数据库&#xff1a;爱好&#xff0c;课程&#xff0c;存在一张表中。 但…

JVM笔记【一】java和Tomcat类加载机制

JVM笔记一java和Tomcat类加载机制 java和Tomcat类加载机制 Java类加载 * loadClass加载步骤类加载机制类加载器初始化过程双亲委派机制全盘负责委托机制类关系图自定义类加载器打破双亲委派机制 Tomcat类加载器 * 为了解决以上问题&#xff0c;tomcat是如何实现类加载机制的…

React 组件类型详解:类组件 vs. 函数组件

React 是一个用于构建用户界面的 JavaScript 库&#xff0c;其核心思想是组件化开发。React 组件可以分为类组件&#xff08;Class Components&#xff09;和函数组件&#xff08;Function Components&#xff09;&#xff0c;它们在设计理念、使用方式和适用场景上有所不同。随…

GPT-SoVITS 使用指南

一、简介 TTS&#xff08;Text-to-Speech&#xff0c;文本转语音&#xff09;&#xff1a;是一种将文字转换为自然语音的技术&#xff0c;通过算法生成人类可听的语音输出&#xff0c;广泛应用于语音助手、无障碍服务、导航系统等场景。类似的还有SVC&#xff08;歌声转换&…

美信监控易:数据采集与整合的卓越之选

在当今复杂多变的运维环境中&#xff0c;一款具备强大数据采集与整合能力的运维管理软件对于企业的稳定运行和高效决策至关重要。美信监控易正是这样一款在数据采集与整合方面展现出显著优势的软件&#xff0c;以下是它的一些关键技术优势&#xff0c;值得每一个运维团队深入了…

End-to-End从混沌到秩序:基于LLM的Pipeline将非结构化数据转化为知识图谱

摘要:本文介绍了一种将非结构化数据转换为知识图谱的端到端方法。通过使用大型语言模型(LLM)和一系列数据处理技术,我们能够从原始文本中自动提取结构化的知识。这一过程包括文本分块、LLM 提示设计、三元组提取、归一化与去重,最终利用 NetworkX 和 ipycytoscape 构建并可…

MySql 三大日志(redolog、undolog、binlog)详解

![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/aa730ab3f84049638f6c9a785e6e51e9.png 1. redo log&#xff1a;“你他妈别丢数据啊&#xff01;” 干啥的&#xff1f; 这货是InnoDB的“紧急备忘录”。比如你改了一条数据&#xff0c;MySQL怕自己突然断电嗝屁了&am…

HTTP:九.WEB机器人

概念 Web机器人是能够在无需人类干预的情况下自动进行一系列Web事务处理的软件程序。人们根据这些机器人探查web站点的方式,形象的给它们取了一个饱含特色的名字,比如“爬虫”、“蜘蛛”、“蠕虫”以及“机器人”等!爬虫概述 网络爬虫(英语:web crawler),也叫网络蜘蛛(…

2025妈妈杯数学建模C题完整分析论文(共36页)(含模型建立、可运行代码、数据)

2025 年第十五届 MathorCup 数学建模C题完整分析论文 目录 摘 要 一、问题分析 二、问题重述 三、模型假设 四、 模型建立与求解 4.1问题1 4.1.1问题1思路分析 4.1.2问题1模型建立 4.1.3问题1代码&#xff08;仅供参考&#xff09; 4.1.4问题1求解结果&#xff08;仅…

数据结构排序算法全解析:从基础原理到实战应用

在计算机科学领域&#xff0c;排序算法是数据处理的核心技术之一。无论是小规模数据的简单整理&#xff0c;还是大规模数据的高效处理&#xff0c;选择合适的排序算法直接影响着程序的性能。本文将深入解析常见排序算法的核心思想、实现细节、特性对比及适用场景&#xff0c;帮…

UMG:ListView

1.创建WBP_ListView,添加Border和ListView。 2.创建Object,命名为Item(数据载体&#xff0c;可以是其他类型)。新增变量name。 3.创建User Widget&#xff0c;命名为Entry(循环使用的UI载体).添加Border和Text。 4.设置Entry继承UserObjectListEntry接口。 5.Entry中对象生成时…

每天学一个 Linux 命令(18):mv

​​可访问网站查看&#xff0c;视觉品味拉满&#xff1a; http://www.616vip.cn/18/index.html 每天学一个 Linux 命令&#xff08;18&#xff09;&#xff1a;mv 命令功能 mv&#xff08;全称&#xff1a;move&#xff09;用于移动文件/目录或重命名文件/目录&#xff0c;是…

ubuntu24.04上使用qemu和buildroot模拟vexpress-ca9开发板构建嵌入式arm linux环境

1 准备工作 1.1 安装qemu 在ubuntu系统中使用以下命令安装qemu。 sudo apt install qemu-system-arm 安装完毕后&#xff0c;在终端输入: qemu- 后按TAB键&#xff0c;弹出下列命令证明安装成功。 1.2 安装arm交叉编译工具链 sudo apt install gcc-arm-linux-gnueabihf 安装之…

IntelliSense 已完成初始化,但在尝试加载文档时出错

系列文章目录 文章目录 系列文章目录前言一、原因二、使用步骤 前言 IntelliSense 已完成初始化&#xff0c;但在尝试加载文档时出错 File path: E:\QtExercise\DigitalPlatform\DigitalPlatform\main\propertyWin.ui Frame GUID:96fe523d-6182-49f5-8992-3bea5f7e6ff6 Frame …