基于多层感知机(MLP)实现MNIST手写体识别

news2025/2/28 14:57:52

实现步骤

  1. 下载数据集
  2. 处理好数据集
  3. 确定好模型(初始化模型参数等等)
  4. 确定优化函数(损失函数也称为目标函数)和优化方法(一般选用随机梯度下降 SDG )
  5. 进行模型的训练
  6. 进行模型的评估
import torch
import torchvision
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# 1. 下载数据集
mnist_train = torchvision.datasets.MNIST(root='../data', train=True, transform=transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root='../data', train=False, transform=transforms.ToTensor(), download=True)

# 2. 创建批量数据迭代器
train_iter = DataLoader(mnist_train, batch_size=256, shuffle=True)
test_iter = DataLoader(mnist_test, batch_size=256)

# 3. 可视化检查数据
var = next(iter(train_iter))
plt.title(str(var[1][0]))  # 显示标签
plt.imshow(var[0][0].squeeze().numpy(), cmap='gray')  # 显示图片
plt.show()

# 4. 定义模型:多层感知机
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 512),
    nn.ReLU(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10) # 注意这里是不需要加 Softmax 了的,因为后面定义了,nn.CrossEntropyLoss()这个会自动帮我们进行 Softmax 以及进行损失计算。其实就是目标函数
)

# 初始化模型参数
def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

# 5. 定义损失函数和优化器
loss_fn = nn.CrossEntropyLoss()  # CrossEntropyLoss已经包含了softmax,所以不需要LogSoftmax
optimizer = optim.SGD(net.parameters(), lr=0.2)

# 6. 训练模型
epoch_num = 20
for epoch in range(epoch_num):
    net.train()  # 设置为训练模式
    total_loss = 0
    for X, y in train_iter:
        optimizer.zero_grad()  # 清除梯度
        y_hat = net(X)  # 前向传播
        loss = loss_fn(y_hat, y)  # 计算损失
        loss.backward()  # 反向传播
        optimizer.step()  # 更新参数

        total_loss += loss.item() * X.shape[0]  # 累积损失

    avg_loss = total_loss / len(mnist_train)  # 计算平均损失
    print(f'Epoch {epoch + 1}/{epoch_num}, Loss: {avg_loss:.4f}')

# 7. 评估模型
def evaluate_model(net, test_iter):
    net.eval()  # 设置为评估模式
    correct, total = 0, 0
    with torch.no_grad():  # 在评估时不需要计算梯度
        for X, y in test_iter:
            y_hat = net(X)
            _, predicted = torch.max(y_hat, 1)  # 获取预测的标签
            correct += (predicted == y).sum().item()  # 统计正确的个数
            total += y.size(0)  # 统计总数

    accuracy = correct / total
    print(f'Accuracy on test set: {accuracy * 100:.2f}%')

# 评估模型的表现
evaluate_model(net, test_iter)


代码实践的结果:

  1. 自己不会去计算损失
  2. 在模型进行训练的时候不知道如何把数据放进去:原来只需要创建好了 DataLoader 以后,通过枚举
    就可以拿到数据了。
  3. 最后进行模型评估的时候也是用 AI 进行完成了。所以多少还是差点意思。
  4. 后面的代码多去实践实践,并且思考吧!!!

关于代码中交叉熵计算的理解

理解损失函数(loss_fn)是如何计算的,对于训练神经网络来说是非常重要的。具体到你提到的这行代码:

loss = loss_fn(y_hat, y)  # 计算损失

损失函数的定义:

在你的代码中,损失函数是:

loss_fn = nn.CrossEntropyLoss()

nn.CrossEntropyLoss() 是一种常用于多分类问题的损失函数,它实际上包含了两个步骤:

  1. Softmax:将模型的输出转换为概率分布。
  2. 交叉熵损失:计算真实标签与预测概率分布之间的差距。

为什么要用交叉熵呢?因为交叉熵可以来衡量预测差距,这个我们只需要这个知识点,并且知道上面的公式就好了。

我们逐步分析这两个步骤。

1. Softmax(概率转换)

假设模型的输出 y_hat 是一个向量,其中每个元素代表对应类别的“分数”(或者说是原始的 logits)。例如,假设有 3 个类别,模型的输出可能是:

y_hat = [2.0, 1.0, -1.0]  # 这三个数字是 logits,不是概率

通过 Softmax 函数,我们将这些 logits 转换成概率:

# 计算 softmax
softmax = torch.nn.functional.softmax(y_hat, dim=-1)

softmax 的输出会是一个概率分布,每个数值的范围在 [0, 1] 之间,且所有数值加起来为 1。例如,经过 Softmax 后可能得到:

softmax = [0.7, 0.2, 0.1]  # 类别 0 的概率是 0.7,类别 1 的概率是 0.2,类别 2 的概率是 0.1

2. 交叉熵损失(Cross Entropy Loss)

交叉熵是衡量两个概率分布之间差异的一个标准方法。在分类任务中,我们希望预测的类别概率与真实标签分布尽可能接近。

对于一个单一的样本,交叉熵损失的计算公式为:

L = − ∑ i = 1 C y i log ⁡ ( p i ) L = - \sum_{i=1}^{C} y_i \log(p_i) L=i=1Cyilog(pi)

  • ( C ) 是类别数。
  • ( y_i ) 是真实标签(在 one-hot 编码下,真实类别的标签为 1,其他类别为 0)。
  • ( p_i ) 是模型预测的概率。

对于多分类任务来说,交叉熵损失会选择对应真实标签的类别概率 ( p_{\text{true}} ) 来计算损失。例如,如果真实标签是类别 0,那么我们只关注模型在类别 0 上的预测概率。

假设真实标签 y 是类别 0,对应的 one-hot 编码是 [1, 0, 0],而模型的预测是:

softmax = [0.7, 0.2, 0.1]

那么交叉熵损失为:

L = − ( 1 ⋅ log ⁡ ( 0.7 ) + 0 ⋅ log ⁡ ( 0.2 ) + 0 ⋅ log ⁡ ( 0.1 ) ) = − log ⁡ ( 0.7 ) ≈ 0.3567 L = - (1 \cdot \log(0.7) + 0 \cdot \log(0.2) + 0 \cdot \log(0.1)) = - \log(0.7) \approx 0.3567 L=(1log(0.7)+0log(0.2)+0log(0.1))=log(0.7)0.3567

nn.CrossEntropyLoss() 如何工作

在 PyTorch 中,nn.CrossEntropyLoss 会自动处理上述两个步骤:

  1. y_hat(logits)转换为概率。
  2. 使用真实标签 y 计算交叉熵损失。
输入和输出:
  • y_hat: 这是模型的原始输出(logits),形状为 (batch_size, num_classes)。每一行是一个样本的 logits。
  • y: 这是标签,通常是一个包含类别索引的向量,形状为 (batch_size,)。每个元素是该样本的真实类别索引。

例如:

假设我们有以下数据:

  • 模型的输出(logits)为:

    y_hat = torch.tensor([[2.0, 1.0, -1.0],  # 第一个样本
                          [0.5, 1.5, 0.3]]) # 第二个样本
    
  • 真实标签 y 为:

    y = torch.tensor([0, 1])  # 第一个样本的标签是类别 0,第二个样本的标签是类别 1
    

使用 nn.CrossEntropyLoss() 计算损失:

loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(y_hat, y)

CrossEntropyLoss 会首先对 y_hat 进行 softmax 转换,然后计算每个样本的交叉熵损失。你可以通过打印出来的 loss 来查看模型的表现。

总结:

  • y_hat 是模型的原始输出(logits),表示每个类别的“分数”。
  • nn.CrossEntropyLoss 会自动处理 softmax 和交叉熵损失的计算。
  • 损失函数的目的是衡量模型的输出与真实标签之间的差异,差异越小,损失值越小,说明模型的预测越准确。

使用`nn.CrossEntropyLoss 会自动进行独热编码

在计算交叉熵损失时,nn.CrossEntropyLoss 会自动处理标签,并且不需要你手动将标签转换为独热编码(one-hot encoding)。

具体来说:

  • y_hat:是模型的原始输出(logits),形状为 (batch_size, num_classes),每一行是一个样本的预测结果,包含每个类别的分数(logits)。
  • y:是标签,形状为 (batch_size,),每个元素是该样本的真实类别的 索引,而不是独热编码。

nn.CrossEntropyLoss 会自动使用标签 y 中的类别索引(如类别 0, 1, 2)来计算损失,它会根据该类别索引选择对应的模型输出进行计算,而不需要你事先将标签转换为独热编码。

举个例子:

假设我们有一个批次的两个样本,模型的输出 y_hat 和真实标签 y 如下:

模型的输出 y_hat(logits):
y_hat = torch.tensor([[2.0, 1.0, -1.0],  # 第一个样本的 logits
                      [0.5, 1.5, 0.3]]) # 第二个样本的 logits
真实标签 y(类别索引):
y = torch.tensor([0, 1])  # 第一个样本的标签是类别 0,第二个样本的标签是类别 1

在这个例子中,y_hat 的形状是 (2, 3),表示有两个样本,每个样本有三个类别的 logits。

  • 对于第一个样本,它的真实标签是类别 0y[0] = 0
  • 对于第二个样本,它的真实标签是类别 1y[1] = 1

当使用 nn.CrossEntropyLoss 时,它会根据真实标签中的类别索引来选择对应的类别概率(通过 softmax 后的结果),然后计算交叉熵损失。PyTorch 会自动完成:

  1. Softmax 转换:将 y_hat 的 logits 转换为概率分布。
  2. 损失计算:根据真实标签的类别索引计算交叉熵损失。

损失计算过程:

假设 y_hat 的 softmax 输出(经过 softmax 转换后)为:

softmax = torch.nn.functional.softmax(y_hat, dim=1)

# softmax 结果:
# [[0.7, 0.2, 0.1],  # 第一个样本的概率
#  [0.3, 0.6, 0.1]]  # 第二个样本的概率

然后计算交叉熵损失时:

  • 对于第一个样本(真实标签是 0),损失将仅计算 softmax[0][0](即类别 0 的概率),即 -log(0.7)
  • 对于第二个样本(真实标签是 1),损失将仅计算 softmax[1][1](即类别 1 的概率),即 -log(0.6)

因此,nn.CrossEntropyLoss 会自动从 y 中提取类别索引,并计算相应的损失。

总结:

  • 不需要手动进行独热编码,因为 nn.CrossEntropyLoss 会自动处理。
  • 标签 y 只需要是一个包含类别索引的张量(每个元素是一个整数,表示类别)。
  • nn.CrossEntropyLoss 会根据标签中的类别索引,从 y_hat 中提取对应的 logits,然后计算交叉熵损失。

通过这种方式,PyTorch 简化了损失计算的过程,让你不需要显式地进行独热编码的转换。

为什么在实验的时候我去掉了参数初始化对结果并没有很大影响?

在深度学习模型中,尤其是在多层感知机(MLP)中,初始化模型参数对于训练的收敛速度和结果有重要影响。虽然初始化通常有助于提升训练过程的稳定性和准确性,但在某些情况下,即使没有显式初始化,训练仍然可能得到较好的结果。这有以下几个原因:

1. PyTorch 默认的权重初始化

如果你没有手动初始化模型参数,PyTorch 会使用默认的初始化方式。对于 nn.Linear 层,PyTorch 默认使用的是 均匀分布初始化U(-sqrt(k), sqrt(k))),其中 k 是每一层的输入特征数(即该层的输入维度)。这种初始化方式通常可以在大部分场景下有效工作,尤其是在简单的任务或小型模型中。

2. 任务本身比较简单

MNIST 数据集是一个相对简单的任务,具有以下特点:

  • 样本相对简单(28x28 的灰度图像)。
  • 类别数量较少(10 个类别)。
  • 数据集规模较小(60,000 个训练样本)。

由于这些原因,即使没有特别优化初始化方式,模型仍然能在训练过程中较好地拟合数据,因此准确率可能不会受到显著影响。

3. 优化器的鲁棒性

现代优化器(如 SGD、Adam 等)通常具有较强的鲁棒性,能够在一定范围内有效地调整模型的参数,避免了初始化差异带来的过度影响。即使没有进行显式初始化,优化器也能够逐步调整模型的参数,从而避免梯度消失或梯度爆炸等问题,保证训练的顺利进行。

4. 训练过程中参数的调整

在模型训练初期,即使初始化不完美,随着训练的进行,网络的权重会在反向传播过程中逐步调整到合适的值。因此,即使开始时的参数较为随机,优化过程仍然能够找到有效的解决方案。这就是深度学习的一个特性:即使参数初始不理想,优化过程通常能通过梯度更新找到合适的解。

5. 初始化不影响最终收敛结果

对于一些简单的任务,模型可能在多个初始化条件下都能够达到一个相对接近的局部最优解。在这种情况下,即使没有手动初始化权重,模型也能收敛到较好的解。

总结:

  • 默认初始化(PyTorch 内部的初始化方式)通常已经能在很多简单的任务中有效工作,特别是像 MNIST 这样简单的图像分类任务。
  • 优化器的鲁棒性帮助模型调整参数,避免了初始化不完美时对结果产生显著影响。
  • 对于 MNIST 这种简单任务,初始化参数的不同可能不会导致显著差异,尤其是在训练的过程中,优化器能够找到较好的解。

然而,在一些更复杂的任务中,初始化的方式会直接影响模型的训练效率和性能。在这些任务中,精心设计的初始化(例如 Xavier、He 初始化等)能够帮助模型更快地收敛并避免训练过程中遇到的问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2307442.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【机器学习】Logistic回归#1基于Scikit-Learn的简单Logistic回归

主要参考学习资料: 《机器学习算法的数学解析与Python实现》莫凡 著 前置知识:线性代数-Python 目录 问题背景数学模型类别表示Logistic函数假设函数损失函数训练步骤 代码实现特点 问题背景 分类问题是一类预测非连续(离散)值的…

8.Dashboard的导入导出

分享自己的Dashboard 1. 在Dashboard settings中选择 JSON Model 2. 导入 后续请参考第三篇导入光放Dashboard,相近

next.js-学习2

next.js-学习2 1. https://nextjs.org/learn/dashboard-app/getting-started2. 模拟的数据3. 添加样式4. 字体,图片5. 创建布局和页面页面导航 1. https://nextjs.org/learn/dashboard-app/getting-started /app: Contains all the routes, components, and logic …

视频推拉流EasyDSS直播点播平台授权激活码无效,报错400的原因是什么?

在当今数字化浪潮中,视频推拉流 EasyDSS 视频直播点播平台宛如一颗璀璨的明珠,汇聚了视频直播、点播、转码、精细管理、录像、高效检索以及时移回看等一系列强大功能于一身,全方位构建起音视频服务生态。它既能助力音视频采集,精准…

【论文详解】Transformer 论文《Attention Is All You Need》能够并行计算的原因

文章目录 前言一、传统 RNN/CNN 存在的串行计算问题二、Transformer 如何实现并行计算?三、Transformer 的 Encoder 和 Decoder 如何并行四、结论 前言 亲爱的家人们,创作很不容易,若对您有帮助的话,请点赞收藏加关注哦&#xff…

Framework层JNI侧Binder

目录 一,Binder JNI在整个系统的位置 1.1 小结 二,代码分析 2.1 BBinder创建 2.2 Bpinder是在查找服务时候创建的 2.3 JNI实现 2.4 JNI层android_os_BinderProxy_transact 2.5 BPProxy实现 2)调用IPCThreadState发送数据到Binder驱动…

Excel大文件拆分

import pandas as pddef split_excel_file(input_file, output_prefix, num_parts10):# 读取Excel文件df pd.read_excel(input_file)# 计算每部分的行数total_rows len(df)rows_per_part total_rows // num_partsremaining_rows total_rows % num_partsstart_row 0for i i…

OpenCV计算摄影学(7)HDR成像之多帧图像对齐的类cv::AlignMTB

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 该算法将图像转换为‌中值阈值位图‌(Median Threshold Bitmap,MTB): 1.位图生成‌:…

Axure PR 9 中继器 03 翻页控制

大家好,我是大明同学。 接着上期的内容,这期内容,我们来了解一下Axure中继器图表翻页控制。 预览地址:https://pvie5g.axshare.com 翻页控制 1.打开上期RP 文件,在元件库中拖入一个矩形,宽值根据业务实际…

IO流(师从韩顺平)

文章目录 文件什么是文件文件流 常用的文件操作创建文件对象相关构造器和方法应用案例 获取文件的相关信息应用案例 目录的操作和文件删除应用案例 IO 流原理及流的分类Java IO 流原理IO流的分类 IO 流体系图-常用的类IO 流体系图(重要!!&…

Spring Boot集成Jetty、Tomcat或Undertow及支持HTTP/2协议

目录 一、常用Web服务器 1、Tomcat 2、Jetty 3、Undertow 二、什么是HTTP/2协议 1、定义 2、特性 3、优点 4、与HTTP/1.1的区别 三、集成Web服务器并开启HTTP/2协议 1、生成证书 2、新建springboot项目 3、集成Web服务器 3.1 集成Tomcat 3.2 集成Jetty 3.3 集成…

《Python实战进阶》专栏 No 5:GraphQL vs RESTful API 对比与实现

《Python实战进阶》专栏包括68集,每一集聚焦一个中高级技术知识点,涵盖Python在Web开发、数据处理、自动化、机器学习、并发编程等领域的应用,系统梳理Python开发者的知识集。本集的主题为: No4 : GraphQL vs RESTful API 对比与实…

MYSQL 5.7数据库,关于1067报错 invalid default value for,解决方法!

???作者: 米罗学长 ???个人简介:混迹java圈十余年,精通Java、小程序、数据库等。 ???各类成品java毕设 。javaweb,ssm,springboot,mysql等项目,源码丰富,欢迎咨询。 ???…

【Linux基础】Linux下的C编程指南

目录 一、前言 二、Vim的使用 2.1 普通模式 2.2 插入模式 2.3 命令行模式 2.4 可视模式 三、GCC编译器 3.1 预处理阶段 3.2 编译阶段 3.3 汇编阶段 3.4 链接阶段 3.5 静态库和动态库 四、Gdb调试器 五、总结 一、前言 在Linux环境下使用C语言进行编程是一项基础且…

浅谈HTTP及HTTPS协议

1.什么是HTTP? HTTP全称是超文本传输协议,是一种基于TCP协议的应用非常广泛的应用层协议。 1.1常见应用场景 一.浏览器与服务器之间的交互。 二.手机和服务器之间通信。 三。多个服务器之间的通信。 2.HTTP请求详解 2.1请求报文格式 我们首先看一下…

Pytest自定义测试用例执行顺序

文章目录 1.前言2.pytest默认执行顺序3.pytest自定义执行顺序 1.前言 在pytest中,我们可能需要自定义测试用例的执行顺序,例如登陆前需要先注册,这个时候就需要先执行注册的测试用例再执行登录的测试用例。 本文主要讲解pytest的默认执行顺序…

人大金仓KCA | 用户与角色

人大金仓KCA | 用户与角色 一、知识预备1. 用户和角色 二、具体实施1. 用户管理-命令行1.1 创建和修改用户1.2 修改用户密码1.3 修改用户的并发连接数1.4 修改用户的密码有效期 2.用户管理-EasyKStudio2.1 创建和修改用户2.2 修改用户密码2.3 修改用户的并发连接数2.4 修改用户…

【Azure 架构师学习笔记】- Azure Databricks (12) -- Medallion Architecture简介

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Databricks】系列。 接上文 【Azure 架构师学习笔记】- Azure Databricks (11) – UC搭建 前言 使用ADB 或者数据湖,基本上绕不开一个架构“Medallion”, 它使得数据管理更为简单有效。ADB 通过…

智能证件照处理器(深度学习)

功能说明:支持常见证件照尺寸(一寸、二寸、护照等) 智能背景去除(使用深度学习模型)自定义背景颜色选择自动调整尺寸并保持比例实时预览处理效果注意:整合rembg进行抠图,使用Pillow处理图像缩放和背景替换,定义常见证件照尺寸,并提供用户交互选项。首次运行时会自动下…

C++-第十三章:红黑树

目录 第一节:红黑树的特征 第二节:实现思路 2-1.插入 2-1-1.unc为红 2-1-2.cur为par的左子树,且par为gra的左子树(cur在最左边) 2-1-2-1.unc不存在 2-1-2-2.unc为黑 2-1-3.cur为par的右子树,且par为gra的右子树(cur在最右侧) 2-…