通过PyTorch 手写数字识别 入门神经网络 详细讲解

news2025/1/22 16:52:37

通过PyTorch 手写数字识别 入门神经网络

数据集

在这里插入图片描述

MNIST数据集中有手写数字图片7万张,划分训练集6万张,划分测试集1万张。

每张图片都会有一张标签,也就是代表着图片的真实值(真实含义)。

概念 计算机是如何读取图片的呢?

把照片当作一个数列矩阵给计算机读取,将照片特征从右到左拼接成一列输入到网络层。

在这里插入图片描述

网络层

每一次的节点是由前一层计算得到的,a和b 分别代表系数(权重)和偏置项。

i表示前一层的节点序号,j表示当前节点的序号。

在机器学习和深度学习中,偏置项(Bias)是模型的一个重要组成部分。它是一个可学习的参数,通常用来调整模型的输出,使其能够更好地拟合训练数据。

在这里插入图片描述

最后图像的信息通过网络的传播一直传播到最后一层,k代表的是网络的层数。最后一层就是输出层,而最后有10个节点代表的分别是10个数字的可能性结果,每个节点对应一种可能。
在这里插入图片描述

归一化处理

因为每个节点代表的应该是概率,而每个节点的数值都应该是0<p<1,且每层的总和应该为1。

所以我们需要对节点做归一化处理

Softmax归一化

在多分类问题中,通常会使用softmax函数作为网络输出层的激活函数,softmax函数可以对输出值进行归一化操作,把所有输出值都转化为概率(0~1之间),所有概率值加起来等于1∶
在这里插入图片描述

例如:某个神经网络有3个输出值,为[1,5,3]。

在数学中有个数叫e(数学中一个常数,是一个无限不循环小数,且为超越数,其值约为2.718281828459045)

先计算出e1(e的1次方),e5,e3和它们的和的数值来,e1=2.718、e5=148.413,e3=20.086、e1+e5+e^3=171.217
1的概率:

在这里插入图片描述

3的概率:在这里插入图片描述
5的概率:
在这里插入图片描述

0.016+0.867+0.117=1

在这里插入图片描述

训练

现在我们的输出具有了概率这个概念,那么我们真正要使得我们的概率有意义,那么就需要进行"训练"!

我们一开始的概率分布是随机的,而我们这张图片代表是7,而理想状态下,这张图片是7的概率应该是百分百,而现实训练过程中与理想状态的差值便是损失loss。

所以为了减小损失,我们需要在训练过程中调整网络参数,也就是a和b,使得更接近与理想状态的预测判断概率。

调整网络参数的算法有很多,比如梯度下降算法,ADAM算法等等。从而神经网络问题就变成了一个最优化问题,在多次尝试下寻找到最优解。

在这里插入图片描述

而这仅仅是一张图片,假如我们对于上万张图片进行训练,从而调整得到合适的网络参数,便能使得我们的神经网络具备预测的能力,因为一次只输入一张图片,我们的效率会很低的,所以我们会分批次,几张图片一起输入到网络,这个批次的概念叫做batchSize。

在这里插入图片描述

激活函数

如果没有激活函数,观察我们的节点计算,我们会发现我们节点中的计算都是线性的,但我们生活中很多问题都不是线性的,输入和输出之间存在着非线性,因为一个线性函数无论怎么调整都调整不出非线性函数的效果(模拟出非线性行为),所以我们会在每一次计算中都套一层激活函数,从而达到非线性计算。

在这里插入图片描述

常见的激活函数如下:
在这里插入图片描述

在这个手写数字识别中,我们采用整流函数,因为当x小于0的时候都归0,x大于0的时候才会有数值,也就相当于激活的效果。

项目实现

安装库: pip install numpy torch torchvision matplotlib

pytorch GPU的安装方法更详细可以参考这篇文章:全网最详细的安装pytorch GPU方法,一次安装成功!!包括安装失败后的处理方法!-CSDN博客

首次运行会安装MNIST数据集。

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt


class Net(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # 输入为28*28的像素图片,中间三层都放了64个节点
        self.fc1 = torch.nn.Linear(28 * 28, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)

    def forward(self, x):
        # fc1全连接线性计算,再套上激活函数relu
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        # log_softmax softmax归一化再套上log让计算更稳定
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x


def get_data_loader(is_train):
    # 进行数据转换,tensor就是一个多维数组(又叫张量) 定义数据转换类型
    to_tensor = transforms.Compose([transforms.ToTensor()])
    # 下载MNIST数据集,""代表当前目录,is_train用于指定是导入训练集还是测试集
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    # batch_size 表示一个批次包含15张图片 ,shuffle表示数据是否是随机打乱的
    return DataLoader(data_set, batch_size=15, shuffle=True)

# 用于评估神经网络的正确率
def evaluate(test_data, net):
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        # 在测试集中按批次取出数据
        for (x, y) in test_data:
            # 计算神经网络的预测值 x代表图片, y代表真实结果(标签)
            outputs = net.forward(x.view(-1, 28 * 28))
            # 再与真实结果进行比较进行累加记录
            for i, output in enumerate(outputs):
                # argmax是找到一个数列中最大值的序号,也就是预测结果
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total


def main():
    # 导入训练集和测试集
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    # 初始化神经网络
    net = Net()
    # 打印初始网络的正确率 一般是0.1,因为10种结果,猜对的概率是十分之一
    print("initial accuracy:", evaluate(test_data, net))

    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    # epoch是训练轮次
    for epoch in range(2):
        # 这部分基本是通用写法
        for (x, y) in train_data:
            # 初始化
            net.zero_grad()
            # 正向传播
            output = net.forward(x.view(-1, 28 * 28))
            # 计算差值 nll_loss 是一个对数损失函数 是为了匹配前面的log_softmax的对数运算
            loss = torch.nn.functional.nll_loss(output, y)
            # 反向误差传播
            loss.backward()
            # 优化网络参数
            optimizer.step()
        # 每个epoch结束后打印一次正确率
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))

    # 随机抽取三张图片验证模型性能
    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28 * 28)))
        plt.figure(n)
        plt.imshow(x[0].view(28, 28))
        plt.title("prediction: " + str(int(predict)))
    plt.show()


if __name__ == "__main__":
    main()

正向传播传播算出当前的概率或值,反向传播将计算得到的损失告诉网络层,从而进行优化调整。

epoch代表的是训练轮次,比如六万张图片,训练了两次六万张图片,那么就是训练两个epoch。

最终结果:在这里插入图片描述

训练一次的时候正确率就从0.09达到了0.95,第二次就到了0.96,提升就相对少了。

概念和语法问题

上下文管理器

with torch.no_grad(): 是一个上下文管理器,它确保在其控制下的代码块内不会执行梯度计算。在 PyTorch 中,当我们构建计算图时,默认情况下会对每个操作进行跟踪,以便能够计算梯度。这对于训练模型是必要的,因为我们需要通过反向传播来更新权重。然而,在模型的评估阶段或者当我们只需要前向传递来得到输出而不需要更新模型参数时,保持梯度计算是不必要的,甚至会消耗额外的内存和计算资源。

使用 torch.no_grad() 的好处包括:

  • 节省内存:不需要存储中间变量的梯度信息。
  • 提高性能:省去了梯度计算的时间。

当你看到 with torch.no_grad():,这意味着在这段代码执行期间,所有涉及到自动梯度计算的操作都将被忽略,即创建的 Tensor 不会被加入到计算图中。这对于评估模型、生成模型输出以及任何不需要梯度计算的任务都是有用的。

示例代码:

with torch.no_grad():
    # 在这里创建的所有 Tensor 和执行的所有操作都不会被记录在计算图中
    predictions = model(inputs)

在你的代码中,torch.no_grad() 被用在 evaluate 函数中,以确保在评估模型的准确率时不会进行不必要的梯度计算,从而提高效率并节约内存。这是因为评估阶段我们关心的是模型的性能而非更新模型参数。

全连接层(Fully Connected Layer,简称 FC 层)是神经网络中最基本的组件之一,也是最直观的一种层。在一个全连接层中,前一层的所有神经元(节点)都与后一层的所有神经元相连,也就是说,每一层的每一个神经元都会接受前一层所有神经元的输出作为输入。

全连接层的工作原理

在一个全连接层中,每个神经元的输出是由前一层所有神经元的输出经过加权求和后加上偏置项(bias),然后通过激活函数计算得出的。数学上,可以用以下公式来表示:

z=W⋅x+bz=Wx+b

其中:

  • zz 是神经元的加权输入(未经过激活函数之前的值)。
  • WW 是权重矩阵。
  • xx 是输入向量。
  • bb 是偏置项。

接着,zz 会通过一个激活函数(如 ReLU、Sigmoid、tanh 或者其他激活函数)来产生非线性映射:

h=f(z)h=f(z)

这里的 h 是最终的输出,f 是激活函数。

应用场景

全连接层通常用于处理一维的数据,例如从卷积层提取的特征向量或者是展平后的图像数据。在图像识别、自然语言处理等领域,全连接层常用于提取特征之后的分类任务。例如,在你的代码中,全连接层用于将输入图像(展平后的28x28像素,共784个元素)映射到一个较低维度的空间,最后输出类别概率分布。

优化器

optimizer = torch.optim.Adam(net.parameters(), lr=0.001) 这一行代码的作用是在 PyTorch 中创建一个优化器实例,用于更新神经网络的参数。让我们分解一下这条语句的含义:

代码解析

  1. torch.optim.Adam:这是 PyTorch 提供的一种优化算法——Adam(Adaptive Moment Estimation)算法的实现。Adam 是一种自适应学习率优化方法,它结合了动量(Momentum)和 RMSProp 的优点,能够在训练过程中动态调整每个参数的学习率。
  2. net.parameters():这是神经网络模型 net 的所有可学习参数的迭代器。这些参数通常是模型中的权重和偏置项,它们是训练过程中需要更新的对象。
  3. lr=0.001:这是学习率(Learning Rate)的设定值。学习率决定了参数更新的步长大小。较高的学习率会使参数更新更快,但也可能导致训练过程不稳定;较低的学习率则会使训练过程更稳定,但可能需要更多的时间来收敛。

作用

这句话的主要作用是创建一个 Adam 优化器,并指定要优化的参数集合以及学习率为 0.001。这个优化器将在训练过程中使用,具体来说:

  • 初始化优化器:创建一个 Adam 优化器实例,准备好对模型的参数进行优化。
  • 参数绑定:将模型的所有可学习参数传递给优化器,以便在训练过程中更新这些参数。
  • 设置学习率:确定了优化过程中每次参数更新的步长大小。

参考:

10分钟入门神经网络 PyTorch 手写数字识别_哔哩哔哩_bilibili

pytorch tutorial: PyTorch 手写数字识别 教程代码 (gitee.com)

秒懂Softmax归一化_矩阵列向量softmax归一化计算-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2212850.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鸿蒙开发案例:记忆翻牌

【游戏简介】 记忆翻牌游戏是一种经典的益智游戏&#xff0c;玩家需要翻开隐藏的卡片&#xff0c;找出所有成对的图案。每翻开一对卡片&#xff0c;如果图案相同&#xff0c;则这对卡片会永久显示出来&#xff0c;否则会在一段时间后自动翻回背面。游戏的目标是在尽可能短的时…

LabVIEW提高开发效率技巧----跨平台开发

在如今的多平台环境下&#xff0c;开发者常常面临不同操作系统的需求&#xff0c;如Windows、Linux和RT&#xff08;实时&#xff09;系统等。而LabVIEW作为一种强大的开发工具&#xff0c;提供了支持跨平台开发的能力&#xff0c;但要使其无缝迁移&#xff0c;开发者需要掌握一…

干货分享 | 同星多设备间的时间戳同步机制TSync功能与使用

随着汽车网络测试的通道数量不断增加&#xff0c;时常需要多个同星设备同时连接在同一台电脑的同一个TSMaster应用程序&#xff0c;并进行多设备同时执行CAN报文收发和记录等功能&#xff0c;必然有多设备之间的时间戳同步以及设备与电脑上操作系统的时间同步的要求。 为了满足…

5G 双卡双通演进

█ 双卡技术的演进历程 前面我有提到&#xff0c;世界上第一台双卡手机&#xff0c;诞生于 2004 年。 之所以会有双卡手机的出现&#xff0c;和当时特殊的历史背景有关。那一时期&#xff0c;中国大陆市场只有两家移动通信运营商&#xff0c;分别是中国移动和中国联通。中国移…

轻松入门:Maven核心功能详解

White graces&#xff1a;个人主页 &#x1f649;专栏推荐:Java入门知识&#x1f649; ⛳️点赞 ☀️收藏⭐️关注&#x1f4ac;卑微小博主&#x1f64f; ⛳️点赞 ☀️收藏⭐️关注&#x1f4ac;卑微小博主&#x1f64f; 目录 Maven Maven核心功能 1. 项目构建 2. 依赖管…

超材料光子晶体和禁带分析实例_CST电磁仿真教程

光子晶体是由周期性排列的不同折射率的介质制造的光学结构&#xff0c;可被视为广义超材料metamaterial的一种。本期我们演示设计一个基于光频能带(PBG,photonics band gap) 的二维光子晶体波导&#xff0c;能带分析方法也可适用于微波波段&#xff08;EBG,electromagetic band…

QT事件与网络通信

闹钟 头文件 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow> #include <QTimer> #include <QTextToSpeech> // 添加此行以引入QTextToSpeech类QT_BEGIN_NAMESPACE namespace Ui { class MainWindow; } QT_END_NAMESPACEclass MainWin…

通信接入技术

一、xDSL 1、xDSL&#xff1a;利用电话线中的高频信息传输数据&#xff0c;高频信号损耗大&#xff0c;容易受噪声干扰。【速率越高&#xff0c;传输距离越近】 1.1 ADSL虚拟拨号&#xff1a;采用专门的协议PPPover Ethernet&#xff0c;拨号后直接由验证服务器进行检验&#…

免费版视频压缩软件:让视频处理更便捷

现在不少人已经习惯通过视频来记录生活、传播信息和进行娱乐的重要方式。但是由于设备大家现在录制的文件都会比较大&#xff0c;这时候就比较需要一些缩小视频的工具了。今天我们一起来探讨视频压缩软件免费版来为我们带来的生动世界。 1.Foxit视频压缩大师 链接直达&#x…

【论文笔记】Adversarial Diffusion Distillation

Abstract 本文提出了一种新的训练方法&#xff0c;在保持较高图像质量的前提下&#xff0c;仅用1~4步就能有效地对大规模传统图像扩散模型进行采样&#xff0c;使用分数蒸馏(score distillation)&#xff0c;来利用大规模现成的图像扩散模型作为教师信号&#xff0c;并结合对抗…

CVE-2022-26965靶机渗透

​ 开启环境 ​ ​ 进入环境 ​ ​ 使用弱口令admin登录 ​ ​ 利用cms主题构造木马 ​ 需要将主题中的info.php文件修改&#xff0c;再打包成zip再上传&#xff0c;通过网络搜索找到Github中的Pluck CMS&#xff0c;进入后随便下载任一主题 https://github.com/sear…

解锁编程的力量:SPL的学习之旅

SPL 一、前言二、集算器应用场景三、下载四、集算器的基本使用 一、前言 一种面向结构化数据的程序计算语言 集算器又称&#xff1a;SPL&#xff08;Structured Process Language&#xff09; 敏捷计算是集算器的主要特征 二、集算器应用场景 数据准备&#xff08;跑批&…

闭着眼学机器学习——决策树分类

引言&#xff1a; 在正文开始之前&#xff0c;首先给大家介绍一个不错的人工智能学习教程&#xff1a;https://www.captainbed.cn/bbs。其中包含了机器学习、深度学习、强化学习等系列教程&#xff0c;感兴趣的读者可以自行查阅。 1. 算法介绍 决策树是一种常用的机器学习算法…

Linux SSH无密码使用私钥远程登录连接详细配置流程

文章目录 前言1. Linux 生成SSH秘钥对2. 修改SSH服务配置文件3. 客户端秘钥文件设置4. 本地SSH私钥连接测试5. Linux安装Cpolar工具6. 配置SSHTCP公网地址7. 远程SSH私钥连接测试8. 固定SSH公网地址9. 固定SSH地址测试 前言 本文将详细介绍如何将Linux SSH服务与cpolar相结合&…

modbus tcp wireshark抓包

Modbus TCP报文详解与wireshark抓包分析_mbap-CSDN博客 关于wireshark无法分析出modbusTCP报文的事情_wireshark 协议一列怎么没有modbus tcp-CSDN博客 使用Wireshark过滤Modbus功能码 - 技象科技 连接建立以后才能显示Modbus TCP报文 modbus.func_code 未建立连接时&…

论文阅读MOE-DAMEX

目录 Abstract 1. Introduction 3. 传统的MOE 4. 方法 题目&#xff1a;DAMEX: Dataset-aware Mixture-of-Experts for visual understanding of mixture-of-datasets数据集感知的专家混合模型&#xff0c;用于混合数据集的视觉理解 Abstract 通用普通的detector的构建提…

使用HTML、CSS和JavaScript创建图像缩放功能

使用HTML、CSS和JavaScript创建图像缩放功能 在这篇博客文章中&#xff0c;我们将介绍如何使用HTML、CSS和JavaScript创建一个简单的图像缩放功能。这个功能可以增强用户体验&#xff0c;让访问者在点击图像时查看更大的版本。 效果 步骤1&#xff1a;设置HTML结构 首先&…

python异常检测 - 随机离群选择Stochastic Outlier Selection (SOS)

python异常检测 - Stochastic Outlier Selection (SOS) 前言 随机离群选择SOS算法全称stochastic outlier selection algorithm. 该算法的作者是jeroen janssens. SOS算法是一种无监督的异常检测算法. 随机离群选择SOS算法原理 随机离群选择SOS算法的输入: 特征矩阵(featu…

架构设计笔记-14-云原生架构设计理论与实践

知识要点 云原生&#xff08;Cloud Native&#xff09;架构原则&#xff1a; 服务化原则&#xff1a;通过微服务架构&#xff0c;小服务&#xff08;MiniService&#xff09;架构把不同生命周期的模块分离出来&#xff0c;分别进行业务迭代&#xff0c;避免迭代频繁模块被慢速…

10 分钟使用豆包 MarsCode 帮我搭建一套后台管理系统

以下是「 豆包MarsCode 体验官」优秀文章&#xff0c;作者把梦想揉碎。 十分钟使用豆包 MarsCode 搭建后台管理项目 在这个快节奏的时代&#xff0c;开发者们总是希望能够快速、高效地完成项目的搭建与开发工作。无论是初创企业还是大型公司&#xff0c;后台管理系统都是必不可…