Pytorch入门实战 P1-实现手写数字识别

news2024/10/24 23:16:50

目录

一、前期准备(环境+数据)

1、首先查看我们电脑的配置;

2、使用datasets导入MNIST数据集

3、使用dataloader加载数据集

4、数据可视化

二、构建简单的CNN网络

三、训练模型

1、设置超参数

2、编写训练函数

3、编写测试函数

4、正式训练

四、结果可视化


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、前期准备(环境+数据)

编辑器:Pycharm

环境语言:python、pytorch

1、首先查看我们电脑的配置;

即:看看我们电脑是CPU版本还是GPU的。

import torch
torch.cuda.is_available()

# 返回 False,则是CPU版本;反之是GPU版本

查看自己的电脑配置后,一般写代码的时候,只需要看是CPU或GPU,然后根据不同的版本运行代码。一般我们会选择使用判断语句这样写:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

不懂得torch.device() 的,我把官网附在下面了。 

2、使用datasets导入MNIST数据集

本文中,我们主要用来实现手写数字的识别,因此,我们得先得到手写数字数据集,即MNIST数据集。

【MNIST数据集背景介绍】:

        手写数据集,如MNIST,是一个经典的机器学习数据集,主要用于手写数字识别。

        这个数据集包含了来自250个不同的人手写数字图片,其中50%是学生,50%来自人口普查局的工作人员。训练集一共包含了60,000张图像和标签,而测试集一共包含了10,000张图像和标签。测试集中前5000个来自最初NIST项目的训练集,后5000个来自最初MNIST项目的测试集。前5000个比后5000个要规整,这是因为前5000个数据来自于美国人口普查局的员工,他们的书写相对更标准,而后5000个来自于大学生,书写风格可能更多样。

        该数据集的收集目的是希望通过算法,实现对手写数字的识别。在手写数字识别分类中,每个样本都是一个28x28像素的灰度图像,表示手写数字0到9。

总的来说,手写数据集为机器学习领域的研究者提供了一个标准化的、大规模的、有挑战性的数据集,有助于推动手写数字识别等相关技术的发展。


【导入MNIST数据集】:

首先,使用datasets下载MNIST数据,并划分好训练集和测试集。

import torchvision 

# 训练集数据

train_ds = torchvision.datasets.MNIST('data',
                                       train=True,
                                       transform=torchvision.transforms.ToTensor(),
                                        download=True)
# 测试集数据
test_ds = torchvisio.datasets.MNIST('data',
                               train=False,
                               transform=torchvision.transforms.ToTensor(),
                               download=True)

我们先来看下原型:

torchvision.datastes.MNIST( root,train=True, transform=None, download=False)

其中:

        root是要把下载的数据集存入的文件夹的名字。

        train:True表示是训练集;False表示是测试集;

        transform: 这里的参数,选择一个你想要的数据转换函数,直接完成数据转化。

        download:True 从互联网上下载数据集,并把数据集放在root目录下。


下载完成后的目录是这样的:(如果已经下载一次了,后续就可以把download改为False,不然每次运行都会下载

3、使用dataloader加载数据集

使用dataloader加载数据集,并设置好基本的batch_size。

import torch 

batch_size = 32   # 每批加载样本的大小

# 加载训练集数据
train_dl = torch.utils.data.DataLoader(train_ds,
                                       batch_size=batch_size,
                                       shuffle=True)   # 每个epoch重新排列数据

# 加载测试集数据
test_dl = torch.utils.data.DataLoader(test_ds,
                                       batch_size=batch_size)

我们可以取一个批次查看下数据的格式:

imgs,labels = next(iter(train_dl))

print(imgs.shape)   # 得到结果是   torch.Size([32,1,28,28])

其中:我们得到的数据的shape位:[batch_size , channel, height, weight]

                batch_size :是我们自己设置的(上面的代码中有设置过)

                channel: 通道数  (黑白图像一般的通道数为1;RGB格式图像的通道数为:3)

                height: 图片的高度

                weight: 图片的宽度

train_dl 就是我们上面的使用dataloader加载的训练集的数据。

iter(train_dl) 将数据加载器转换为一个迭代器(iterator),使得我们可以使用Python的next()函数来逐个访问数据加载器中的元素。

next()  函数用于获取迭代器中的下一个元素。这里,它被用来获取train_dl中的下一批量数据。


4、数据可视化

数据可视化,就是使用代码展示下,我们上面获取的数据(获取20个数字的图片)。

plt.figure(figsize=(20,5))
for i,imgs in enumerate(imgs[:20]):
    # 维度缩减
    npimg = np.squeeze(imgs.numpy())
    plt.subplot(2,10,i+1)  # 指定划分的行数、列数及子图的索引。
    plt.imshow(npimg,cmap=plt.cn.binary)  # 展示图片,以cmap给的色彩展示
    plt.axis('off')  # 关闭坐标轴
plt.show()  # 展示图片

 运行结果展示:


至此,我们的前期准备工作准备结束。我们即将进入第二部分!!!

二、构建简单的CNN网络

我们现在简单看下上面这个图,上图里面是一个简单的CNN网络图。

依次包括:输入层、卷积层1、池化层1、卷积层2、池化层2、全连接层1、全连接层2、全连接层3、输出层。 

对于一般的CNN网络来说,都是由【特征网络】和【分类网络】构成的。

①nn.Conv2d  为卷积层,用于提取图片的特征,传入参数为:input_channel、out_channel、kernal_size;

②nn.MaxPool2d 为池化层,进行下采样,更高层的抽象表示图像特征,传入参数为:kernal_size

③nn.ReLU 为激活函数,使得模型可以拟合非线性数据。

④nn.Squential 可以按构造顺序连接网络,在初始化阶段就设定好网络结构,不需要在前向传播中重新写一遍。


下面的代码,我们以这个图为例,两层卷积、两层池化、全连接层。

num_classes = 10   # 图片的类别数


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # 特征提取网络
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # 输入图像的通道数、输出图像的通道数、卷积核大小  (RGB图像的输入通道数为3)
        self.pool1 = nn.MaxPool2d(2)                    # 设置池化层,池化核大小为2*2
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)   # 第二层卷积,卷积核大小为3*3
        self.pool2 = nn.MaxPool2d(2)

        # 分类网络
        self.fc1 = nn.Linear(1600,64)
        self.fc2 = nn.Linear(64,num_classes)

    # 前向传播
    def forward(self,x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))

        x = torch.flatten(x,start_dim=1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x
# 打印并加载模型
model = Model().to(device)
print(model)
print("查看模型信息:")
summary(model)

​​​​​​​

这个步骤很重要,这块就是我们一般修改模型的地方。如果是写论文的话,这里是很重要的。因为是刚开始学习,就先能大概了解就行,后续我还会继续学习的。也会多多更新这里的。

三、训练模型

我们现在已经构建好了CNN的网络模型,那么就开始设置一些参数训练模型吧。

1、设置超参数

loss_fn = nn.CrossEntropyLoss()  # 创建损失函数
learn_rate = 1e-1  # 学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

2、编写训练函数

'''
    1、optimizer.zero_grad() 函数会遍历模型的所有参数,通过内置方法截断反向传播的梯度流,再将每个参数的梯度值设为0,即上次的梯度记录会被清空。
    2、loss.backward() Pytorch的反向传播(即:tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。
    3、optimizer.step()  step() 函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step() 函数前应先指向那个loss.backward()函数来计算梯度。
'''

# 训练循环
print('准备进入----训练集里面')


def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共60000张图片
    num_batches = len(dataloader)  # 批次数目,1875(60000/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率

    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)  # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值,即为损失。

        # 反向传播   (以下三个基本上是固定的)
        optimizer.zero_grad()  # grade属性归零
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(
            torch.float).sum().item()  # 表示计算预测正确的样本数量,并将其作为一个标量值返回。这通常用于评估分类模型的准确率或计算分类问题的正确预测数量。
        '''
            pred.argmax(1)返回数组pred在第一个轴(即行)上最大值所在的索引。这通常用于多分类问题中,其中pred是一个包含预测概率的二维数组,每行表示一个样本的预测概率分布。
            pred.argmax(1) == y是一个布尔值,其中等号是否成立代表对应样本的预测是否正确。(True表示正确,False表示错误)

            .type(torch.float)是将布尔数组的数据类型转换为浮点数类型,即将True转换为1.0;将False转换为0.0
            .sum() 是对数组中的元素进行求和,计算出预测正确的样本数量。
            .item() 将求和结果转换为标量值,以便在Python中使用或打印。
        '''
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss

3、编写测试函数

测试函数、训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器。
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)  # 批次数目313 (10000/32=321.5 向上取整)
    test_loss, test_acc = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches

    return test_acc, test_loss

4、正式训练

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):  # epoch 索引值
    model.train()  # 启用Batch Normalization和Dropout
    '''
        如果模型中有BN(Batch Normalization)和Dropout ,需要在训练时添加model.train() 。 model.train() 是保证BN层能够用到每一批数据的均值和方差。
        对于Dropout ,model.train() 是随机取一部分网络连接来训练更新参数。
    '''
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    model.eval()  # 不启用Batch Normalization 和Dropout
    '''
        如果模型中有BN(Batch Normalization)和Dropout ,需要在测试时添加model.eval() . model.eval() 是保证BN层能够用全部训练数据的均值和方差,
        即:测试过程中要保证BN层的均值和方差不变。对于Dropout, model.eval() 是利用到了所有网络连接,即:不进行随机舍弃神经元。

        训练完train样本后,生成的模型model要用来测试样本。在model(test)之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。
        这是model中还有BN层和Dropout所带来的性质。
    '''
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f}'
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print('Done')

四、结果可视化

结果可视化,主要使用的是import matplotlib.pyplot as plt 的绘图。

warnings.filterwarnings('ignore')  # 忽略警告信息
# plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rc('font', family='PingFang HK')
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Train Loss')
plt.plot(epochs_range, test_loss, label="Test Loss")
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

上述我们已经先开始训练模型了,正式训练模型,我们会得到很多数据,因此,我们要用一些可视化工具来清晰的展示,我们得到的数据,看看训练数据和测试数据会有哪些差异呢。

由于我本地电脑跑起来,风扇呼呼响。这里我用的是谷歌提供的免费的 工具跑的代码,训练数据共花费3分钟左右。


至此,我们使用Pytorch完成了手写数字识别,也算是一个简单的基础入门实战啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1497765.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

蚂蚁感冒 刷题笔记

/* 解题思路 首先根据题意可知 1.蚂蚁速度均为1 即同向蚂蚁永远不可能追上 我们需要求最后感冒蚂蚁的数量 因为蚂蚁碰头将会掉头 效果和俩蚂蚁互相穿过继续走是一样的 所以我们将俩蚂蚁碰头视作穿过 2. 如果俩蚂蚁相向而行 则俩蚂蚁必定碰头 首先 我们获得第一个感冒蚂蚁的…

Requests教程-15-文件上传与下载

领取资料,咨询答疑,请➕wei: June__Go 上一小节,我们学习了requests的HTTPS请求方法,本小节我们讲解一下在requests文件上传与下载。 文件上传 使用requests库上传文件时,需要使用files参数,并将文件打…

IDEA 配置文件乱码,项目编码设置

见下图 其中第一二项控制全局以及工程的编码格式,下方的则是 properties 配置文件的格式,统一调整为 UTF-8 后不再乱码

【Python学习篇】Python基础入门学习——你好Python(一)

个人名片: 🦁作者简介:学生 🐯个人主页:妄北y 🐧个人QQ:2061314755 🐻个人邮箱:2061314755qq.com 🦉个人WeChat:Vir2021GKBS 🐼本文由…

Android制作.9图回忆

背景 多年前,做app开发遇到IM需求,那会用到.9图做聊天气泡背景,现在总结下使用png图片制作.9图。方法有很多,这里主要介绍Android studio制作.9图。当然使用ps、draw9patch都行。 第一步、打开Android studio,切换到dr…

stm32学习笔记:I2C通信协议原理和软件I2C读写MPU6050

概述 第一块:介绍协议规则,然后用软件模拟的形式来实现协议。 第二块:介绍STM32的iic外设,然后用硬件来实现协议。 程序一现象:通过软件I2C通信,对MPU6050芯片内部的寄存器进行读写,写入到配…

Linux安装代理

Linux安装代理 1.下载安装包2.进行解压3.点击运行4.进行配置5.设置系统网络 1.下载安装包 2.进行解压 3.点击运行 4.进行配置 导入链接 5.设置系统网络 测试运行是否成功

迭代器失效问题(C++)

迭代器失效就是迭代器指向的位置已经不是原来的含义了,或者是指向的位置是非法的。以下是失效的几种情况: 删除元素: 此处发生了迭代器的失效,因为erase返回的是下一个元素的位置的迭代器,所以在删除1这个元素的时候&…

鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:浮层)

设置组件的遮罩文本。 说明: 从API Version 7开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 overlay overlay(value: string | CustomBuilder, options?: { align?: Alignment; offset?: { x?: number; y?: number } })…

Spring揭秘:BeanDefinitionRegistry应用场景及实现原理!

内容概要 BeanDefinitionRegistry接口提供了灵活且强大的Bean定义管理能力,通过该接口,开发者可以动态地注册、检索和移除Bean定义,使得Spring容器在应对复杂应用场景时更加游刃有余,增强了Spring容器的可扩展性和动态性&#xf…

GB 2312字符集:中文编码的基石

title: GB 2312字符集:中文编码的基石 date: 2024/3/7 19:26:00 updated: 2024/3/7 19:26:00 tags: GB2312编码中文字符集双字节编码区位码规则兼容性问题存储空间优化文档处理应用 一、GB 2312字符集的背景 GB 2312字符集是中国国家标准委员会于1980年发布的一种…

【Python】6. 基础语法(4) -- 列表+元组+字典篇

列表和元组 列表是什么, 元组是什么 编程中, 经常需要使用变量, 来保存/表示数据. 如果代码中需要表示的数据个数比较少, 我们直接创建多个变量即可. num1 10 num2 20 num3 30 ......但是有的时候, 代码中需要表示的数据特别多, 甚至也不知道要表示多少个数据. 这个时候,…

线上企业展厅:突破时空限制,展示企业实力的新平台

引言: 在数字化时代,企业宣传和展示已不再受限于传统的实体展厅。线上企业展厅作为一种创新的展示方式,不仅能够突破时空限制,还能充分利用多媒体技术,为企业带来更为丰富、立体的展示效果。 一、线上企业展厅的优势 …

YOLOv9中train.py与train_dual.py的异同!

专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!! 首先,train.py(左)与train_dual.py(右)中的损失函数是不一样的,这也解释了为什么使用train.py除了填入…

浅谈数据中心末端配电母线槽技术的实现及产品监控选型

安科瑞电气股份有限公司 上海嘉定 201801 【摘要】末端配电母线槽是一种新型的数据中心配电解决方案。本文针对额定电流、额定冲击耐受电压、额定短时耐受电流三个*点技术参数展开探讨,分析了母线槽依据的国家标准,指出了*点技术参数的选择依据&#xf…

【STM32】HAL库 CubeMX 教程 --- 高级定时器 TIM1 定时

实验目标: 通过CUbeMXHAL,配置TIM1,1s中断一次,闪烁LED。 一、常用型号的TIM时钟频率 1. STM32F103系列: 所有 TIM 的时钟频率都是72MHz;F103C8不带基本定时器,F103RC及以上才带基本定时器。…

聊一聊ThreadLocal的原理?

1.ThreadLocal创建方式 ThreadLocal<String> threadlocal1 new ThreadLocal(); ThreadLocal<String> threadlocal2 new ThreadLocal(); ThreadLocal<String> threadlocal3 new ThreadLocal(); 2.首先介绍一下&#xff0c;ThreadLocal的原理&#xff1a; 如…

Git你必须知道的知识

一&#xff1a;使用Git的原因 我们在写版本的时候&#xff0c;可能会谢谢改改&#xff0c;可能要回到之前的文件&#xff0c;修改之前的文件&#xff0c;因此总是要保持很多个文件&#xff0c;且书写文件名也很麻烦。git可以有一个仓库&#xff0c;版本库&#xff0c;可以保存这…

五、循环神经网络语言模型(RNN)

1 循环神经网络基础知识 循环核&#xff08;Recurrent Cell&#xff09;定义&#xff1a; 指在时刻 t 时的神经网络单元&#xff0c;用来处理当前时刻的输入和上一时刻的隐藏状态&#xff0c;并生成当前时刻的输出和下一时刻的隐藏状态。记忆体&#xff08;Memory&#xff09;定…

vue面试--9, 1 ObjectProperty与vue3Proxy区别。2 MVVM的理解 3 双向绑定原理?

1 ObjectProperty与vue3Proxy区别 2 MVVM的理解 3 双向绑定原理&#xff1f;