卷积神经网络实现MNIST手写数字识别 - P1

news2025/1/11 12:59:51
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第P1周:实现mnist手写数字识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 引用需要的包
      • 设置GPU
    • 数据准备
      • 下载数据集
      • 数据集预览
      • 数据集准备
    • 模型设计
    • 模型训练
      • 超参数设置
      • helper函数
      • 正式训练
    • 结果呈现
  • 总结与心得体会


环境

  • 系统:Linux
  • 语言: Python 3.8.10
  • 深度学习框架:PyTorch 2.0.0+cu118

步骤

环境设置

引用需要的包

Python写程序都需要做的事

import torch # 有些API直接在模块下
import torch.nn as nn # 大部分和模型相关的API
import torch.optim as optim # 优化器相关API
# 一些可以直接调用的函数封装(和nn下的很多方法是一样的效果不同的形式)
import torch.nn.functional as F 

from torch.utils.data import DataLoader # 数据集做分批,随机排序
from torchvision import datasets, transforms # 预置数据集下载,数据增强

import matplotlib.pyplot as plt # 图表库
import numpy as np # 用来操作numpy数组,图像展示用

from torchinfo import summary # 打开模型结构

设置GPU

首先用一个全局的对象设置一下当前的设备,是使用CPU还是CPU

# 有显卡就用显卡,没有就用CPU
device = torch.device('cuda'if torch.cuda.is_available() else 'cpu')

数据准备

下载数据集

调用torchvision包预置的API可以一键下载MNIST数据集

train_dataset = datasets.MNIST(
	root='data',  # 数据存放位置
	train=True, # 加载训练集还是验证集
	download=True,  # 本地没有是否从远程下载
	transform=transforms.ToTensor()) # 载入后将图像转换成pytorch的tensor对象
test_dataset = datasets.MNIST(
	root='data',  
	train=False,  # False说明是验证集
	download=True,
	transform=transforms.ToTensor())

数据集预览

先看看数据集中图像的样子,比如是单通道还是三通道,长宽是多少,然后就可以设置缩放以及模型的一些参数

image, label = train_dataset[0]
image.shape

图片信息
结果表明数据集中的图片应该是单通道的高28宽28的图像

打印里面20个图看看是什么样的

plt.figure(figsize=(20, 4)) # 设置一个plt图表画板的宽和高,单位是英寸。。
for i in range(20):
	image, label = train_dataset[i]
	plt.subplot(2, 10, i+1) # 以2行10列的形式展示图片
	# 先把tensor转为了numpy数组,然后把(1, 28, 28)第0维用squeeze去掉
	# cmap=plt.cm.binary说明是一个单通道的灰度图
	plt.imshow(np.squeeze(image.numpy()), cmap=plt.cm.binary)
	plt.title(label) # 打印一下对应的标签
	plt.axis('off') # 不显示坐标轴

图像预览

数据集准备

设置一下数据的批次大小

batch_size = 32
# 训练集上将数据的顺序打乱一下
train_loader = DataLoader(train_dataset, batch_size=batch_size,shuffle=True)
test_loader= DataLoader(test_dataset, batch_size=batch_size)

模型设计

采用一个类似于LeNet的小型卷积网络

class Model(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		# 定义两个卷积层,核都是3x3的,通道数递增
		self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
		self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
		# 池化层没有参数需要学习,可以复用一个
		self.maxpool = nn.MaxPool2d(2)

		# 全连接层的输入维度要结果计算,可以在forward的时候算一下
		self.fc1 = nn.Linear(5*5*32128)
		# 最后一层的输出得是分类的数量
		self.fc2 = nn.Linear(128, num_classes)
	
	def forward(self, x):
		# 28x28 -> conv1 -> 26x26 -> maxpool -> 13x13
		x = self.maxpool(F.relu(self.conv1(x)))
		# 13x13 -> conv2 -> 11x11 -> maxpool -> 5x5
		x = self.maxpool(F.relu(self.conv2(x)))

		# 这里要进全连接层了,需要把数据压平,保留第0维,从第1维开始压
		x = torch.Flatten(start_dim=1)
		x = F.relu(self.fc1(1))
		# 最后一层就不加激活函数了
		x = self.fc2()
# 将模型创建后,设备设置为上面定义的设备对象
model = Model(10).to(device)
# 一定要加input_size,不然打印的就不是实际执行的样子,而是按self中定义的顺序,复用的组件也展示不出来
summary(model, input_size(1, 1, 28, 28))

模型结构

模型训练

接下来就到了训练模型的环节了

超参数设置

需要设置的超参数有训练的轮次epoch和学习率learning_rate

# 轮次
epochs = 10
# 学习率
larning_rate = 0.001
# 创建优化器,将模型参数进去,并设置学习率
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 分类问题,无脑使用交叉熵损失
loss_fn = nn.CrossEntropyLoss()

helper函数

编写两个函数用来封装模型训练和模型验证的过程

  1. 模型训练
def train(train_loader, model, loss_fn, optimizer):
	size = len(train_loader.dataset) # 训练总数据量
	num_batches = len(train_loader) # 批次数量
	train_loss, train_acc = 0, 0 # 记录并返回本次训练过程的状态数据
	for x, y in train_loader:
		x, y = x.to(device), y.to(device) # 将数据加载到和模型相同的设备中,不然取不到值

		preds = model(x) # 这样模型会自动调用forward并进行一些参数的跟踪操作等
		loss = loss_fn(preds, y) # 计算当前批次的损失

		optimizer.zero_grad() # 清空之前训练时产生的梯度
		loss.backward() # 在损失函数上对参数执行反向传播计算梯度
		optimizer.step() # 执行参数更新操作

		# 累加当前数据
		train_loss += loss.item()
		# 计算正确数需要使用argmax求概率最大的一个分类然后和ground truth比较
		train_acc += (preds.argmax(1) == y).type(torch.float).sum().item()
	train_loss /= num_batches # 因为一个批次只计算一次损失,求平均值
	train_acc /= size # 正确率是在总数上计算的
	
	return train_loss, train_loss # 返回数据
  1. 模型验证
# 基本上就是train函数的简化
def test(test_loader, model, loss_fn):
	size = len(test_loader.dataset)
	num_batches = len(test_loader)

	test_loss, test_acc = 0, 0
	for x, y in test_loader:
		x, y = x.to(device), y.to(device)
	
		preds = model(x)
		loss = loss_fn(preds, y)

		test_loss += loss.item()
		test_acc += (preds.argmax(1) == y).type(torch.float).sum().item()

	test_loss /= num_batches
	test_acc /= size

	return test_loss, test_acc

正式训练

开始正式训练,其实也可以封装成一个helper

# 记录训练过程的数据
train_loss, train_acc = [],[]
test_loss, test_acc = [],[]

for epoch in range(epochs):
	model.train() # 切换模型为训练模式
	epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
	
	model.eval() # 切换模型为评估模式
	epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)

	# 记录本轮次数据
	train_loss.append(epoch_train_loss)
	train_acc.append(epoch_train_acc)
	test_loss.append(epoch_test_loss)
	test_acc.append(epoch_test_acc)

	# 打印本轮次的数据信息
	print(f"Epoch:{epoch+1}, Train loss: {epoch_train_loss:.3f}, Train accuracy: {epoch_train_loss*100:.1f}, Validation loss: {epoch_test_loss:.3f}, Validation accuracy: {epoch_test_acc*100:.1f}")

训练过程

结果呈现

上面打印的结果不够直观我们可以用折线图打印一下

plt.figure(figsize=(16, 4))
series = range(epochs)
plt.subplot(1, 2, 1) # 一排两个图表
plt.plot(series, train_loss, label='train loss')
plt.plot(series, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')
plt.subplot(1, 2, 2)
plt.plot(series, train_acc, label='train accuracy')
plt.plot(series, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练结果


总结与心得体会

通过整个过程可以发现,手写数字的识别还是非常简单的,训练的效率比较快,结果也不错。非常适合拿来练手,学习一些基本概念、深度学习框架和分类任务实践过程等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/839329.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

架构训练营学习笔记:5-3接口高可用

序 架构决定系统质量上限,代码决定系统质量下限,本节课串一下常见应对措施的框架,细节不太多,侧重对于技术本质有深入了解。 接口高可用整体框架 雪崩效应:请求量超过系统处理能力后导致系统性能螺旋快速下降 链式…

mssqlmysql数据库忽略大小写

一、mssql -- 创建数据时指定排序集Latin1_General_CI_AS CREATE DATABASE [数据库名] COLLATE Latin1_General_CI_AS 查询效果: 二、mysql

课程作业-基于Python实现的迷宫搜索游戏附源码

简单介绍一下 该项目不过是一个平平无奇的小作业,基于python3.8开发,目前提供两种迷宫生成算法与三种迷宫求解算法,希望对大家的学习有所帮助。 项目如果有后续的跟进将会声明,目前就这样吧~ 效果图如下所示: 环境…

AP2400 LED汽车摩灯照明电源驱动 过EMC DC-DC降压恒流IC

产品特点 宽输入电压范围:5V~100V 可设定电流范围:10mA~6000mA 固定工作频率:150KHZ 内置抖频电路,降低对其他设备的 EMI干扰 平均电流模式采样,恒流精度更高 0-100%占空比控制&#xff0…

基于Pyqt5+serial的串口电池监测工具

本章,其他的没有,废话没有,介绍一下新开源了一个公司的测试工具,写了差不多三周吧。先来看看界面: 这是一个串口调试界面,使用Pyqt5serial完成。升级功能暂未移入,占一个坑位。 基于serial二次开…

javaAPI(一):String

String的特性 String底层源码 1、String声明为final,不可被继承 2、String实现了Serializable接口:表示字符支持序列化 实现了Comparable接口:表示String可以比较大小 3、String内部定义了final char[] value用于存储字符串 4、通过字面量的…

20天突破英语四级高频词汇——第①天

20天突破英语四级高频词汇~第一天加油(ง •_•)ง💪 🐳博主:命运之光 🌈专栏:英语四级高频词汇速记 🌌博主的其他文章:点击进入博主的主页 目录 20天突破英语四级…

CSS基础介绍笔记1

官方文档 CSS指的是层叠样式(Cascading Style Sheets)地址:CSS 教程离线文档:放大放小:ctrl鼠标滚动为什么需要css:简化修改HTML元素的样式;将html页面的内容与样式分离提高web开发的工作效率&…

windows开机运行jar

windows开机自启动jar包: 一、保存bat批处理文件 echo off %1 mshta vbscript:CreateObject("WScript.Shell").Run("%~s0 ::",0,FALSE)(window.close)&&exit java -jar E:\projects\ruoyi-admin.jar > E:\server.log 2>&1 &…

自监督去噪:Neighbor2Neighbor原理分析与总结

文章目录 1. 方法原理1.1 先前方法总结1.2 Noise2Noise回顾1.3 从Noise2Noise到Neighbor2Neighbor1.4 框架结构2. 实验结果3. 总结 文章链接:https://arxiv.org/abs/2101.02824 参考博客:https://arxiv.org/abs/2101.02824 1. 方法原理 1.1 先前方法总…

C数据结构与算法——哈希表/散列表创建过程中的冲突与聚集(哈希查找) 应用

实验任务 (1) 掌握散列算法(散列函数、散列存储、散列查找)的实现; (2) 掌握常用的冲突解决方法。 实验内容 (1) 选散列函数 H(key) key % p,取散列表长 m 为 10000,p 取小于 m 的最大素数; (2) 测试 α…

javaWeb项目--二级评论完整思路

先来看前端需要什么吧: 通过博客id,首先需要显示所有一级评论,包括评论者的头像,昵称,评论时间,评论内容 然后要显示每个一级评论下面的二级评论,包括,评论者的头像,昵称…

python:基于Kalman滤波器的移动物体位置估计

CSDN@_养乐多_ Kalman滤波器是一种经典的估计方法,广泛应用于估计系统状态的问题。本篇博客将介绍Kalman滤波器的基本原理,并通过一个简单的Python代码示例,演示如何使用Kalman滤波器来估计移动物体的位置。 通过运行代码,我们将得到一个包含两个子图的图像,分别展示了估…

数学知识(三)

一、容斥原理 #include<iostream> #include<algorithm>using namespace std;const int N 20;typedef long long LL; int n,m; int p[N];int main() {cin>>n>>m;for(int i 0;i < m;i ) cin>>p[i];int res 0;//从1枚举到2^m(位运算)for(int …

SpringBoot+vue 大文件分片下载

学习链接 SpringBootvue文件上传&下载&预览&大文件分片上传&文件上传进度 VueSpringBoot实现文件的分片下载 video标签学习 & xgplayer视频播放器分段播放mp4&#xff08;Range请求交互过程可以参考这个里面的截图&#xff09; 代码 FileController …

HTML|计算机网络相关

1.三次握手 第一次握手&#xff1a;客户端首先向服务端发送请求。 第二次握手&#xff1a;服务端在接收到客户端发送的请求之后&#xff0c;需要告诉客户端已收到请求。 第三次握手&#xff1a;客户端在接收到服务端发送的请求和确认信息之后&#xff0c;同样需要告诉服务端已…

python并发编程(多线程、多进程、多协程)

文章截图来源来源B站&#xff1a;蚂蚁学python 引入并发&#xff0c;就是为了提升程序运行速度 1、基础介绍 1-1 CPU密集型计算、IO密集型计算 1-2 多进程、多线程、多协程对比 2、全局解释器锁GIL 2-1 python速度慢的两大原因 2-2 GIL是什么 2-3 为什么有GIL这个东西 2-4 怎样…

Vue [Day3]

Vue生命周期 生命周期四个阶段 生命周期函数&#xff08;钩子函数&#xff09; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale…

企业服务器数据库中了devos勒索病毒怎么办如何解决预防勒索病毒攻击

随着科学技术的不断发展&#xff0c;计算机可以帮助我们完成很多重要的工作&#xff0c;但是随之而来的网络威胁也不断提升。近期&#xff0c;我们收到很多企业的求助&#xff0c;企业的服务器数据库遭到了devos勒索病毒攻击&#xff0c;导致系统内部的许多重要数据被加密无法正…

1310. 数三角形

题目链接&#xff1a;https://www.acwing.com/problem/content/1312/ 首先不考虑三点共线的情况一共有 种&#xff0c;现在来计算三点共线的情况 1.三点在一条直线上 2.三点在一条竖线上 3.三点在一条斜线上&#xff0c;正反斜线对称&#xff0c;仅需考虑一边的情况 如果…