深度学习-第P1周——实现mnist手写数字识别

news2024/11/24 2:48:07

深度学习-第P1周——实现mnist手写数字识别

  • 深度学习-第P1周——实现mnist手写数字识别
    • 一、前言
    • 二、我的环境
    • 三、前期工作
      • 1、导入依赖项并设置GPU
      • 2、导入数据集
      • 3、数据可视化
    • 四、构建简单的CNN网络
    • 五、训练模型
      • 1、设置超参数
      • 2、编写训练函数
      • 3、编写测试函数
      • 4、正式训练
    • 六、结果可视化
    • 七、总结

深度学习-第P1周——实现mnist手写数字识别

一、前言

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

二、我的环境

  • 电脑系统:Windows 10
  • 语言环境:Python 3.8.5
  • 编译器:colab在线编译
  • 深度学习环境:Pytorch

三、前期工作

1、导入依赖项并设置GPU

所需库函数的介绍:

  • numpy是Python科学计算的基本包。
  • matplotlib 是在Python中常用的绘制图形的库。
  • PyTorch是一个开源的深度学习框架,提供了各种张量操作并通过自动求导可以自动进行梯度计算,方便构建各种动态神经网络
  • torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型
  • torch.device代表将torch.Tensor分配到的设备的对象,其包含一个设备类型('cpu’或’cuda’设备类型)和可选的设备的序号。如果设备序号不存在,则为当前设备

导入依赖项:

1.#导入所需要的库
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

在这里插入图片描述

2、导入数据集

使用dataset下载MNIST数据集,并划分训练集和测试集

使用dataloader加载数据,并设置好基本的batch_size

torchvision.datasets.MNIST详解

torchvision.datasets.MNIST是pytorch自带的一个数据库,我们可以通过代码在线下载数据,这里使用的是torchvision.datasets中的MNIST数据集

函数原型:

torchvision.datasets.MNIST(root, train = True, transform = None, download = False)

参数说明:

  • root(string): 数据地址
  • train(string): True = 训练集, False = 测试集
  • transform(callable, optional): 这里的参数选择一个你想要的数据转化函数,直接完成数据转化
  • download(bool, optional):如果为True,从互联网上下载数据集,并把数据集放在root目录下
2.#加载数据集,并划分训练集和测试集
train_ds = torchvision.datasets.MNIST('data',
                    train = True,  
                    transform = torchvision.transforms.ToTensor(), #将数据类型转化为Tensor类型
                    download = True  )
test_ds = torchvision.datasets.MNIST('data',
                    train = False,
                    transform = torchvision.transforms.ToTensor(),
                    download = True)

torch.utils.data.DataLoader详解

  • torch.utils.data.DataLoader是pytorch自带的一个数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据量
  • 在训练模型时使用到此函数,用来把训练数据分成多个小组,此函数每次抛出一组数据。直至把所有的数据都抛出。就是做一个数据的初始化。

torch.utils.data.DataLoader(dataset, batch_size, shuffle)

参数说明:

  • dataset: 加载的数据集
  • batch_size:每批加载的样本大小,默认为1
  • shuffle:如果为True,则打乱数据的顺序
  • 如下第一个函数就是将训练集生成迭代数据,每次迭代的数据为32个,shuffle为洗牌操作,即打乱顺序。
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds, batch_size = batch_size, shuffle = True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size = batch_size)

#取一个批次查看数据格式
#数据的shape为:[batch_size, channel, height, weight]
#其中batch_size是样本批次,channel, height, weight分别是图片的通道数,高度和宽度
#train_dl本质上是一个可迭代对象,可以使用iter()进行访问,采用iter(dataloader)返回的是一个迭代器,然后可以使用next()访问。
#也可以使用enumerate(dataloader)的形式访问。
imgs, labels = next(iter(train_dl)) #第一次使用next访问的是迭代器里的第一批数据,二次是第二批,以此类推

img = imgs[0] #访问第一批数据里的第一个数据
print(imgs.shape, img.shape)

()

3、数据可视化

3.#数据可视化
import numpy as np

#指定图片大小,图片大小为20宽,5高的绘图(单位为英寸)
plt.figure(figsize = (20, 5))
#enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
for i, imgs in enumerate(imgs[:20]): 
  #维度缩减,把维度为1的去掉 如:[32, 1, 28,28]->[32, 28, 28]
  npimg = np.squeeze(imgs.numpy())
  #将整个figure分成2行10列,绘制第i + 1个子图
  plt.subplot(2, 10, i + 1)
     # imshow()函数格式为:
	#matplotlib.pyplot.imshow(X, cmap=None)
	#X: 要绘制的图像或数组。
	#cmap: 颜色图谱(colormap), 默认绘制为RGB(A)颜色空间。
  plt.imshow(npimg, cmap = plt.cm.binary)
  plt.axis('off')

在这里插入图片描述

四、构建简单的CNN网络

对于一般的CNN网络来说,都是由特征提取网络和分类网络构成,其中特征提取网络用于提取图片的特征,分类网络用于将图片进行分类。

  • nn.Conv2d为卷积层,用于提取图片的特征,传入参数为输入channel,输出channel,池化核大小
  • nn.MaxPool2d为池化层,进行下采样,用更高层的抽象表示图像特征,传入参数为池化核大小
  • nn.ReLU为激活函数,使模型可以拟合非线性数据
  • nn.Linear为全连接层,可以起到特征提取器的作用,最后一层的全连接层也可以认为是输出层,传入参数为输入特征数和输出特征数(输入特征数由特征提取网络计算得到,如果不会计算可以直接运行网络,报错中会提示输入特征数的大小,下方网络中第一个全连接层的输入特征数为1600)
  • nn.Sequential可以按构造顺序连接网络,在初始化阶段就设定好网络结构,不需要在前向传播中重新写一遍

下方代码各函数参数详解:
nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding) 二维卷积

  • in_channels: 输入张量的channels数
  • out_channels:期望输出张量的channels数
  • kernel_size:卷积核的大小
  • stride:步长,即卷积核在图像窗口上每次平移的间隔
  • padding:图像的填充数

nn.MaxPool2d(kernel_size)

  • kernel_size: 最大池化的窗口大小

nn.Linear(in_features,out_features,bias = Ture)

  • in_features:输入的神经元个数
  • out_features:输出的神经元个数
  • bias:是否包含偏置
#二、构建简单的CNN网络
# 创建并设置卷积神经网络
# 卷积层:通过卷积操作对输入图像进行降维和特征抽取
# 池化层:是一种非线性形式的下采样。主要用于特征降维,压缩数据和参数的数量,减小过拟合,同时提高模型的鲁棒性。
# 全连接层:在经过几个卷积和池化层之后,神经网络中的高级推理通过全连接层来完成。
import torch.nn.functional as F
num_classes = 10
class Model(nn.Module):
  def __init__(self):
    super().__init__()
    #特征提取网络
    self.conv1 = nn.Conv2d(1, 32, kernel_size = 3) # 第一层卷积,卷积核大小为3*3
    self.pool1 = nn.MaxPool2d(2)            #设置池化层,池化核大小为2*2  
    self.conv2 = nn.Conv2d(32, 64, kernel_size = 3) #第二层卷积,卷积核大小为3*3
    self.pool2 = nn.MaxPool2d(2)

    #分类网络
    self.fc1 = nn.Linear(1600, 64)
    self.fc2 = nn.Linear(64, num_classes)

  #前向传播
  def forward(self, x):
    x = self.pool1(F.relu(self.conv1(x)))
    x = self.pool2(F.relu(self.conv2(x)))

    x = torch.flatten(x, start_dim = 1) #Flatten层,连接卷积层与全连接层

    x = F.relu(self.fc1(x))
    x = self.fc2(x)

    return x

加载并打印模型:

#加载并打印模型
from torchinfo import summary

model = Model().to(device)
summary(model)

在这里插入图片描述

五、训练模型

1、设置超参数

  • torch.optim是一个实现了各种优化算法的库。大部分常用的方法得到支持,并且接口具备足够的通用性,使得未来能够集成更加复杂的方法。
  • 为了使用torch.optim,你需要构建一个optimizer对象。这个对象能够保持当前参数状态并基于计算得到的梯度进行参数更新。
  • 其中的SGD是optim中的一个算法(优化器):随机梯度下降算法
  • 动手学深度学习-多层感知机中:updater = torch.optim.SGD(params, lr=lr)。其中的updater就是一个optimizer对象。“”"
loss_fn = nn.CrossEntropyLoss() #创建损失函数
learning_rate = 1e-2
opt = torch.optim.SGD(model.parameters(), lr = learning_rate) #待优化参数的iterable(w和b的迭代), 学习率

2、编写训练函数

  • .item():求出张量具体位置的元素值的高精度值
def train(dataloader, model, loss_fn, optimizer):
  size = len(dataloader.dataset) #训练集的大小, 一共60000张照片
  num_batches = len(dataloader) #批次数目,1875(60000/32)

  train_loss, train_acc = 0, 0 #初始化训练损失和正确率

  for X, y in dataloader: # 获取图片及其标签
    X, y = X.to(device), y.to(device)

    # 计算预测误差
    pred = model(X) #网络输出
    loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算两者差值即为损失

    # 反向传播
    optimizer.zero_grad() # grad属性归0
    loss.backward() # 反向传播
    optimizer.step() # 每一步自动更新

    #记录acc和 loss
    train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
    train_loss += loss.item()

  train_acc /= size
  train_loss /= num_batches

  return train_acc, train_loss

3、编写测试函数

def test(dataloader, model, loss_fn):
size = len(dataloader.dataset) #测试集的大小,一共10000张照片
num_batches = len(dataloader)  #批次数目,313(10000/32=312.5,向上取整)
test_loss, test_acc =0, 0

# 当不进行训练时,停止梯度更新,节省计算内存消耗
with torch.no_grad():
  for imgs, target in dataloader:
    imgs, target = imgs.to(device), target.to(device)

    #计算Loss
    target_pred = model(imgs)
    loss = loss_fn(target_pred, target)
    test_loss += loss.item()
    test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

test_acc /= size
test_loss /= num_batches

return test_acc, test_loss

4、正式训练

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
model.train()
epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

train_acc.append(epoch_train_acc)
train_loss.append(epoch_train_loss)
test_loss.append(epoch_test_loss)
test_acc.append(epoch_test_acc)

template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}')
print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')

在这里插入图片描述

六、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") # 忽视警告
plt.rcParams['font.sans-serif'] = ['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示正负号
plt.rcParams['figure.dpi'] = 100 # 分辨率

epochs_range = range(epochs)

plt.figure(figsize = (12, 3))

plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label = 'Training Accuracy')
plt.plot(epochs_range, test_acc, label = 'Test Accuracy')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Aaccuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label = 'Training Loss')
plt.plot(epochs_range, test_loss, label = 'Test Loss')
plt.legend(loc = 'lower right')
plt.title('Training and Validation Loss')

plt.show()

在这里插入图片描述

七、总结

该项目花了整整一天的时间,自己能理解的部分为目录一到四和七,从构建CNN网络开始就略有乏力,训练模型就完全看不懂了,自己也是尽最大程度的把各函数及其参数搜索并记忆。从结果来看,还是自身基本知识掌握的不到位,连CNN网络的概念都不清楚,下一步的计划是把吴恩达深度学习的视频和课后作业刷完,至少要明白卷积,池化,全连接层的定义和输出的结果以及这些层在背后的意义。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/66601.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ADSP-21489的图形化编程详解(7:延时、增益、分频、反馈、响度)

延时 21489 可以做延时,音频高手会运用此项算法来增强音效,我们做个最简单的,让大家知道怎么用它,至于怎么样嵌入到自己的系统里实现更好的效果,则需要各位调音师专业的耳朵来判断,调音无上限!…

MySQL之索引及其背后的数据结构

✨博客主页: 荣 ✨系列专栏: MySQL ✨一句短话: 难在坚持,贵在坚持,成在坚持! 文章目录一. 索引的介绍1. 什么是索引2. 索引的使用二. 索引背后的数据结构1. 考虑使用哈希表2. 二叉搜索树3. N叉搜索树(B树, B树)4. 注意事项一. 索引的介绍 1. 什么是索引 索引 (Index) 是帮助…

[激光原理与应用-39]:《光电检测技术-6》- 光干涉的原理与基础

目录 第1章 概述 1.1 什么是光干涉 1.2 产生干涉的必要条件 1.3 非相干光 - 自发辐射无法产生干涉 1.4 相干光 - 受激辐射 1.5 时间相干性 1.6 空间相干性 它山之石 第1章 概述 1.1 什么是光干涉 它是指因两束光波相遇而引起光的强度重新分布的现象。 指两列或两列以上…

Verilog入门学习笔记:Verilog基础语法梳理

无论是学IC设计还是FPGA开发,Verilog都是最基本、最重要的必备技能。但任何一门编程语言的掌握都需要长期学习。并不是简简单单的随便读几本书,随便动动脑筋那么简单。Verilog是一门基于硬件的独特语言,由于它最终所实现的数字电路&#xff0…

基于AVDTP信令分析蓝牙音频启动流程

前言 公司项目edifier那边需要在原来音频SBC,AAC基础上增加LHDC5.0编码,在打通lhdc协议栈之前,学习记录一番AVDTP音频服务流程。 一、AVDTP音频流基础知识 分析音频流程首先应具备的最简单基础概念知识:AVDTP信令signal,流端点se…

【JVM】垃圾回收机制详解(GC)

目录一.GC的作用区域二.关于对象是否可回收1.可达性分析算法和引用计数算法2.四种引用类型三.垃圾收集算法1.标记-清除算法2.复制算法3.标记-整理算法4.分代收集算法四.轻GC(Minor GC)和重GC(Full GC)一.GC的作用区域 可以看jvm详解之后,再来理解这篇文章更好 堆和…

[附源码]计算机毕业设计农村人居环境治理监管系统Springboot程序

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

ASP.NET Core 3.1系列(18)——EFCore中执行原生SQL语句

1、前言 前一篇博客介绍了EFCore中常见的一些查询操作,使用Linq或Lambda结合实体类的操作相当方便。但在某些特殊情况下,我们仍旧需要使用原生SQL来获取数据。好在EFCore中提供了完整的方法支持原生SQL,下面开始介绍。 2、构建测试数据库 …

Radare2 框架介绍及使用

Radare2 框架介绍及使用 欢迎入群交流 radare2 这是整个框架的核心工具,它具有debugger和Hexeditor的核心功能,使您能够像打开普通的文件一样,打开许多输入/输出源,包括磁盘、网络连接、内核驱动和处于调试中的进程等。 它实现了…

旧版本金庸群侠传3D新Unity重置修复版入门-lua”脚本“

金庸3DUnity重置入门系列文章 金庸3dUnity重置入门 - lua 语法 金庸3dUnity重置入门 - UniTask插件 金庸3dUnity重置入门 - Cinemachine 动画 金庸3dUnity重置入门 - 大世界实现方案 金庸3dUnity重置入门 - 素材极限压缩 (部分可能放到付费博客) 2022年底~20…

Apifox和Eolink两个测试工具谁最实用?

目前行业内有 postman、jmeter 为代表开源 Api 工具派系,我想对大家对这两个词并不陌生。虽然它们能解决基本的接口测试,但是无法解决接口链路上的所有问题,一个工具难以支持整个过程。在国内,我们可以看到有国产 API 管理工具&am…

Spring Cloud 微服务讲义

Spring Cloud 微服务讲义第一部分 微服务架构第 1 节 互联网应用架构演进第 2 节 微服务架构体现的思想及优缺点第 3 节 微服务架构中的核心概念第二部分 Spring Cloud 综述第 1 节 Spring Cloud 是什么第 2 节 Spring Cloud 解决什么问题第 3 节 Spring Cloud 架构3.1 Spring …

CCES软件做开发,如果仿真器连不进目标板怎么解决?(Failed to connect to processor)

ADI的DSP调试,我在Visual DSP软件下写过一个详细的帖子,来说明仿真器如果连不进目标板,可能存在的几种问题以及解决办法,现在在CCES软件下遇到了同样的问题,所以准备再写一个帖子说明一下。 我们都知道ADI的DSP&#…

智慧工地管理平台系统厂家哪家强|喜讯科技

喜讯科技针对施工现场涉及面广,多种元素交叉,状况较为复杂,如人员出入、机械运行、物料运输等工程项目管理在一定程度上存在着决策层看不清、管理层管不住、执行层做不好的问题。 围绕施工现场管理,构建全方位的智能监控防范体系弥…

Redis——Linux下安装以及命令操作

一、概述 redis是什么? Redis(Remote Dictionary Server ),即远程字典服务 是一个开源的使用ANSI C语言编写、支持网络、可基于内存亦可持久化的日志型、Key-Value数据库,并提供多种语言的API。 是一款高性能的NOSQL系列的非关系型…

每日一题:冒泡排序

每日一题:冒泡排序每日一题:冒泡排序第一种写法:第二种写法:每日一题:冒泡排序 冒泡排序是八大排序中较为简单的一种,具体详细可见:冒泡排序_百度百科 (baidu.com) 我们重点来看冒泡排序的步骤: 冒泡排序…

程序员如何写游戏搞钱?

ConcernedApe,一个叫做Eric Barone的程序员研发了一款叫做星露谷的小游戏,以乡村经营生活为核心,打造了一个虚拟的小世界,在这个小世界,你可以种植农作物,经营农场并挖矿钓鱼。 其中钓鱼的玩法是十分新颖的…

Git常见问题

1.拉取的项目很大,如1G以上,此时报错early EOF 具体报错如下: Cloning into csp-doc... remote: Counting objects: 6061, done. remote: Compressing objects: 100% (4777/4777), done. error: RPC failed; curl 18 transfer closed with …

Spring - FactoryBean扩展实战_MyBatis-Spring 启动过程源码解读

文章目录PrePreMyBatis-Spring 组件扩展点org.mybatis.spring.SqlSessionFactoryBeanInitializingBean扩展接口 afterPropertiesSetFactoryBean 扩展接口 getObjectApplicationListener扩展接口 onApplicationEvent扩展点org.mybatis.spring.mapper.MapperFactoryBeanSqlSessio…

【Linux基本命令归纳整理】

Linux 是一套免费使用和自由传播的类 Unix 操作系统,是一个基于 POSIX 和 UNIX 的多用户、多任务、支持多线程和多 CPU 的操作系统。严格来讲,Linux 这个词本身只表示 Linux 内核,但实际上人们已经习惯了用 Linux 来形容整个基于 Linux 内核&…