使用简单MLP实现0-9数字识别,数据集为MNIST

news2025/1/6 19:59:42

简介

  • 需求:
    1. 基于pytorch实现简单MLP,完成数字识别,
    2. 采用MNIST手写数字作为数据集,MNIST:有6万张训练图片,1万张测试图片
    3. 训练结束后,随机取3张测试图片,展示模型的预测结果和真实图片
  • 模型结构:
    1. 一层线性层作为输入层,转换输入
    2. 中间三层线性层
    3. 一层softmax作为输出层,输出结果概率

模型结果

  • 输出结果
Accuracy: 0.0958
epoch:  0 Accuracy:  0.9406
epoch:  1 Accuracy:  0.9611
epoch:  2 Accuracy:  0.9631
epoch:  3 Accuracy:  0.9687
epoch:  4 Accuracy:  0.9703
epoch:  5 Accuracy:  0.9691
epoch:  6 Accuracy:  0.9737
epoch:  7 Accuracy:  0.974
epoch:  8 Accuracy:  0.9717
epoch:  9 Accuracy:  0.9731
Prediction:  tensor(7)
Prediction:  tensor(5)
Prediction:  tensor(4)
  • 真实的测试图片
    在这里插入图片描述

代码及注释

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

# 定义模型结构
class Net(torch.nn.Module):
    # 初始化模型
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Linear(28*28, 64) # 输入维度为28*28,输出维度为64,第一个全连接层
        self.conv2 = torch.nn.Linear(64, 64) # 输入维度为64,输出维度为64,第二个全连接层
        self.conv3 = torch.nn.Linear(64, 64) # 输入维度为64,输出维度为64,第三个全连接层
        self.conv4 = torch.nn.Linear(64, 10) # 输入维度为64,输出维度为10,第四个全连接层,最后输出10个类别

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        # 由于我们的模型是一个多分类问题,所以我们需要使用softmax来归一化输出最后的概率,加上log是为了防止数值过小,导致数值溢出
        # 由于输出的形状为(batch_size,P),dim=1表示对第一个维度进行softmax操作,即对概率P进行softmax操作
        x = torch.nn.functional.log_softmax(self.conv4(x), dim=1)
        return x
# 定义加载数据函数
def get_data_loader(is_train):
    dataset = MNIST(root='./data', train=is_train, download=True, transform=transforms.ToTensor())
    # shuffle打乱原数据集,并按64个为一批次,返回到data_loader中
    # 数据加载器会按照指定的批次大小,从数据集中逐批加载数据。
    # 所以在每个迭代中,数据加载器会提供一个批次的数据,但是数据加载器会动态地在每次迭代中加载下一个批次的数据,直到遍历完整个数据集。
    data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
    # 返回数据加载器/迭代器
    return data_loader
# 定义测试函数
def evaluate(test_dataloader, net):
    correct = 0
    total = 0
    # 由于在测试部分,所以取消梯度更新
    with torch.no_grad():
        # 由于我们是在每个data_loader中,对模型进行参数更新,而模型是每个迭代进行一次更新,所以我们将data_loader叫做迭代器
        # 从迭代器,中获取一个批次的数据,返回给x,y
        for x,y in test_dataloader:
            # 由于我们的模型结构的第一层为linear(28*28,64),而数据集中的样本形状为(batch_size,1,28,28),1表示通道数,灰度图像只有一个通道
            # 所以我们需要将输入x的形状转换为(batch_size,28*28),-1表示自动计算,可以使用reshape或者view来进行形状转换
            # 随后将x输入到模型中,得到输出outputs
            outputs = net(x.reshape(-1, 28*28))
            # enumerate(outputs) 是一个 Python 内置函数,它用于将一个可迭代对象(如列表、元组、字符串等)包装成一个枚举对象,同时返回一个索引和对应的值。
            # 例如,outputs = [0.9, 0.8, 0.7],那么 enumerate(outputs) 的结果是 [(0, 0.9), (1, 0.8), (2, 0.7)]
            for i,output in enumerate(outputs):
                # torch.argmax(output) 返回output中最大值的索引,由于我们的输出维度为10,所以刚好为预测的类别(0-9)
                # 如果相同,那么预测正确,correct+1,同时不论正确与否,预测总数total+1
                if torch.argmax(output) == y[i]:
                    correct += 1
                total += 1
    return correct/total

if __name__ == '__main__':
    # 创建训练和测试数据加载器
    train_dataloader = get_data_loader(is_train=True)
    test_dataloader = get_data_loader(is_train=False)
    # 初始化神经网络
    net = Net()
    # 输出最开始时模型的准确率
    print("Accuracy:", evaluate(test_dataloader, net))
    # 定义优化器
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    # 开始训练
    for epoch in range(10):
        for x,y in train_dataloader:
            # 重置梯度
            optimizer.zero_grad()
            # 正向传播
            outputs = net(x.reshape(-1, 28*28))
            # 计算损失
            # 定义损失函数
            # 由于我们的模型输出为log_softmax,所以我们使用负对数似然损失函数
            # 同时,由于nll_loss不支持无参数传入,所以我们在此处定义损失函数,并同时传入outputs和y
            loss = torch.nn.functional.nll_loss(outputs, y)
            # 反向传播
            loss.backward()
            # 更新参数
            optimizer.step()
        # 每个epoch完成后,都输出模型的准确率
        print('epoch: ',epoch,"Accuracy: ", evaluate(test_dataloader, net))
    # 训练完成后,随机选取三张图像进行预测
    for images, labels in test_dataloader:
        for i in range(3):
            output = net(images[i].reshape(-1, 28*28))
            print("Prediction: ", torch.argmax(output))
            plt.imshow(images[i].view(28, 28), cmap='gray')
            plt.show()
        # print("Image Shape:", images.shape)  # 打印图像的形状
        # # print("Label:", labels.item())  # 打印标签
        # # print("Tensor Data:", images)  # 打印图像的张量数据
        # # plt.imshow(images.view(28, 28), cmap='gray')  # 显示图像
        # # plt.show()
        # outputs = net(images.reshape(-1, 28 * 28))
        # print(outputs.shape)
        # print(outputs)
        break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1570867.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

策略模式图

策略模式 小小的图解 主要的三个角色 Strategy—抽象策略角色ConcreateStrategy—具体策略角色Context—上下文角色 封装了对具体策略的调用可以使用set的依赖注入也可以使用构造方法 核心是上下文角色 只要调用上下文角色就行,实现解耦 策略 工厂 将上下文角…

【java探索之旅】逻辑控制掌握 顺序结构 分支语句

🎥 屿小夏 : 个人主页 🔥个人专栏 : Java编程秘籍 🌄 莫道桑榆晚,为霞尚满天! 文章目录 📑前言一、逻辑控制的概念二、顺序结构三、分支结构3.1 if语句3.2 if习题巩固3.3 细节注意项…

TR3 - Transformer算法详解

目录 文本输入处理词向量位置向量 编码器 EncoderSelf-Attention多头注意力机制残差连接 解码器 Decoder线性层与Softmax损失函数总结与心得体会 这周来看一下Transformer是怎么将文本转换成向量,然后又输入到模型处理并得到最终的输出的。 文本输入处理 词向量 …

递归算法解读

递归(Recursion)是计算机科学中的一个重要概念,它指的是一个函数(或过程)在其定义中直接或间接地调用自身。递归函数通过把问题分解为更小的相似子问题来解决原问题,这些更小的子问题也使用相同的解决方案&…

2024-04-05 问AI: 介绍一下深度学习中的Leaky ReLU函数

文心一言 Leaky ReLU(Leaky Rectified Linear Unit)函数是深度学习领域中的一种激活函数,主要用于解决ReLU(Rectified Linear Unit)函数在负值区域为零的问题。ReLU函数在输入为正时保持原样,而在输入为负…

Navicat工具使用

Navicat的本质: 在创立连接时提前拥有了数据库用户名和密码 双击数据库时,相当于建立了一个链接关系 点击运行时,远程执行命令,就像在xshell上操作Linux服务器一样,将图像化操作转换成SQL语句去后台执行 一、打开Navi…

Python学习: 错误和异常

Python 语法错误 解析错误(Parsing Error)通常指的是程序无法正确地解析(识别、分析)所给定的代码,通常是由于代码中存在语法错误或者其他无法理解的结构导致的。这可能是由于缺少括号、缩进错误、未关闭的引号或其他括号等问题造成的。 语法错误(Syntax Error)是指程序…

CSS设置网页颜色

目录 前言: 1.颜色名字: 2.十六进制码: 3.RGB: 4.RGBA: 5.HSL: 1.hue: 2.saturation: 3.lightness: 6.HSLA: 前言: 我们在电脑显示器&…

【NLP练习】中文文本分类-Pytorch实现

中文文本分类-Pytorch实现 🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 一、准备工作 1. 任务说明 本次使用Pytorch实现中文文本分类。主要代码与文本分类代码基本一致,不同的是本次任务使用…

[中级]软考_软件设计_计算机组成与体系结构_07_存储系统

存储系统 层次划存储概念图局促性原理分类存储器位置存取方式按内容存储按地址存储 工作方式拓展 往年真题 高速缓存(cache)概念案例解析:求取平均时间 Cache与主存的地址映射映像往年真题 主存编制计算编址大小的求取编址与计算存储单元编址内容总容量求取例题解析…

c# wpf template itemtemplate+dataGrid

1.概要 2.代码 <Window x:Class"WpfApp2.Window8"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schemas.microsoft.com/expression/blend…

[C#]OpenCvSharp使用帧差法或者三帧差法检测移动物体

关于C版本帧差法可以参考博客 [C]OpenCV基于帧差法的运动检测-CSDN博客https://blog.csdn.net/FL1768317420/article/details/137397811?spm1001.2014.3001.5501 我们将参考C版本转成opencvsharp版本。 帧差法&#xff0c;也叫做帧间差分法&#xff0c;这里引用百度百科上的…

【力扣每日一题】1026. 节点与其祖先之间的最大差值

LC 1026. 节点与其祖先之间的最大差值 题目描述 给定二叉树的根节点 root&#xff0c;找出存在于 不同 节点 A 和 B 之间的最大值 V&#xff0c;其中 V |A.val - B.val|&#xff0c;且 A 是 B 的祖先。 &#xff08;如果 A 的任何子节点之一为 B&#xff0c;或者 A 的任何子…

https证书申请方式

网站HTTPS证书&#xff0c;也称为SSL证书或TLS证书&#xff0c;是一种数字证书&#xff0c;用于在用户浏览器与网站服务器之间建立安全的加密连接。当网站安装了HTTPS证书后&#xff0c;用户访问该网站时&#xff0c;浏览器地址栏会显示为"https://"开头&#xff0c;…

CSS层叠样式表学习(文本属性)

&#xff08;大家好&#xff0c;今天我们将继续来学习CSS文本属性的相关知识&#xff0c;大家可以在评论区进行互动答疑哦~加油&#xff01;&#x1f495;&#xff09; 目录 四、CSS文本属性 4.1 文本颜色 4.2 对齐文本 4.3 装饰文本 4.4 文本缩进 4.5 行间距 4.6 文本…

简单的购物商城

SSM整合后的一个及其简单的商城&#xff0c;首页数据是模拟的&#xff0c;主要测试购物车模块 启动 创建数据库&#xff1a;shopping导入建表脚本&#xff1a;shopping.sql修改db.properties部署和启动项目&#xff08;项目的path为项目名&#xff09;访问 http://localhost:…

Python语言在地球科学领域中的应用

Python是功能强大、免费、开源&#xff0c;实现面向对象的编程语言&#xff0c;Python能够运行在Linux、Windows、Macintosh、AIX操作系统上及不同平台&#xff08;x86和arm&#xff09;&#xff0c;Python简洁的语法和对动态输入的支持&#xff0c;再加上解释性语言的本质&…

猫头虎技术分享 || 断网了,还能ping127.0.0.1吗?

博主猫头虎的技术世界 &#x1f31f; 欢迎来到猫头虎的博客 — 探索技术的无限可能&#xff01; 专栏链接&#xff1a; &#x1f517; 精选专栏&#xff1a; 《面试题大全》 — 面试准备的宝典&#xff01;《IDEA开发秘籍》 — 提升你的IDEA技能&#xff01;《100天精通鸿蒙》 …

Shell GPT:直接安装使用的chatgpt应用软件

ShellGPT是一款基于预训练生成式Transformer模型&#xff08;如GPT系列&#xff09;构建的智能Shell工具。它将先进的自然语言处理能力集成到Shell环境中&#xff0c;使用户能够使用接近日常对话的语言来操作和控制操作系统。 官网&#xff1a;GitHub - akl7777777/ShellGPT: *…

liteIDE自定义主题推荐

代码编辑器配色 \liteidex38.3-win64-qt5.15.2\liteide\share\liteide\liteeditor\color <?xml version"1.0" encoding"UTF-8"?> <style-scheme version"1.0" name"Sublime Text 2"><style name"Text" f…