pytorch入门级项目--基于卷积神经网络的数字识别

news2025/2/25 16:41:49

文章目录

  • 前言
  • 1.数据集的介绍
  • 2.数据集的准备
  • 3.数据集的加载
  • 4.自定义网络模型
    • 4.1卷积操作
    • 4.2池化操作
    • 4.3模型搭建
  • 5.模型训练
    • 5.1选择损失函数和优化器
    • 5.2训练
  • 6.模型的保存
  • 7.模型的验证
  • 结语

前言

本篇博客主要针对pytorch入门级的教程,实现了一个基于卷积神经网络(CNN)架构的数字识别,带你了解由数据集到模型验证的全过程。

1.数据集的介绍

MNIST数据集是一个广泛用于机器学习和计算机视觉领域的手写数字图像数据集。它是深度学习入门的经典数据集,常用于图像识别任务,特别是手写数字识别。该数据集分为训练集和测试集:

  • 训练集:包含60,000张手写数字图像,用于模型的训练。
  • 测试集:包含10,000张手写数字图像,用于模型的评估。

2.数据集的准备

Torchvision在torchvision. datasets模块中提供了许多内置数据集,其中便包含了MNIST数据集,因此可以直接通过torchvision. datasets直接下载。

  • root代表存储路径,此处采用的是相对路径,保存在当前文件所在文件夹的data文件夹下
  • train用来区分训练集和测试集
  • download用来表示是否下载数据集

同时我们可以查看其中一条数据,看看该数据及具体形式

import torchvision
trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True)
testset=torchvision.datasets.MNIST(root='./data',train=False,download=True)
trainset[0]

在这里插入图片描述
这里我的数据集已经下载好了,所以很快执行好了
(<PIL.Image.Image image mode=L size=28x28>, 5)通过输出我们可以发现,第一条数据是一个元组,第一个元素表示data(灰度图,大小28*28),第二个元素表示label,指的是该图片的类别,下面我们可以查看每个类别的含义
在这里插入图片描述
因此我们可以看到第一张图片表示的就是手写数字5,也可以通过trainset[0][0].show()进行显示
在这里插入图片描述
在进行后续操作前,需要将图片格式由PIL调整为tensor类型。

import torchvision
trainset=torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=torchvision.transforms.ToTensor())
testset=torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=torchvision.transforms.ToTensor())
trainset[0][0].shape

在这里插入图片描述

至此,数据集准备阶段就算完成了。

3.数据集的加载

在训练深度学习模型时,通常不会一次性将整个数据集输入到模型中,而是将数据集分成多个小批量(mini-batches)进行训练。

import torch

trainloader=torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testloader=torch.utils.data.DataLoader(testset,batch_size=64,shuffle=False)
type(trainloader),len(trainloader),len(trainset)
  • batch_size:表示批量大小
  • shuffle:表示是否打乱,一般训练集打乱,测试集打乱不打乱都可以
  • num_workers:表示多个进程,windows系统可能报错
    这里提供解决windows系统中多进程报错的可能解决方案:
    加上if __name__ == '__main__':即可,如果你是用的是jupyter,不用此操作

在这里插入图片描述
此时,我们取出一条数据查看类型:
在这里插入图片描述
此时我们可以看到,一条数据为torch.Size([64, 1, 28, 28]),64表示批量大小,1表示该图像单通道,即为灰度图,28*28表示图像大小
我们可以通过tensorboard,查看当前批量的图片

from torch.utils.tensorboard import SummaryWriter

writer=SummaryWriter("./logs")
steps=0
for data in trainloader:
    # print(data[0].shape)
    # break
    images,labels=data
    writer.add_images('mnist_images',images,steps)
    steps+=1
writer.close()

在这里插入图片描述
此处在环境中需要装tensorboard

conda install tensorboard

或者

pip install tensorboard

然后在终端执行

tensorboard --logdir=logs

即可

4.自定义网络模型

这里我们随便选了一张网络结构图,也可以自行设计
在这里插入图片描述
来源:图片来源

4.1卷积操作

这里简单介绍一下卷积的运算方式
在这里插入图片描述
在这里插入图片描述

4.2池化操作

池化层是子采样的一种具体实现方式,这里我们介绍最大池化
在这里插入图片描述

4.3模型搭建

  1. 由图可知,第一层卷积,输入通道数为1(原始图像为灰度图),输出通道数为6,因为图像大小未发生改变,所以padding=2
    在这里插入图片描述
  2. 第二层池化,本文使用的是最大池化
  3. 第三层卷积,输入通道数为6(上一层卷积的输出),输出通道数为16(由图),因为图像大小未发生改变,所以padding=2
  4. 第四层池化,使用最大池化
  5. 全连接层,图示定义了两层
  6. 定义前向传播函数
  7. 输出网络结构
class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        # 第一层卷积层
        self.conv1=torch.nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)
        # 第一层池化层
        self.pool1=torch.nn.MaxPool2d(kernel_size=2,stride=2)
        # 第二层卷积层
        self.conv2=torch.nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,padding=2)
        # 第二层池化层
        self.pool2=torch.nn.MaxPool2d(kernel_size=2,stride=2)
        # 全连接层
        self.fc1=torch.nn.Linear(in_features=16*7*7,out_features=84)
        self.fc2=torch.nn.Linear(in_features=84,out_features=10)

    def forward(self,x):
        x=self.pool1(torch.nn.functional.relu(self.conv1(x)))
        x=self.pool2(torch.nn.functional.relu(self.conv2(x)))
        x=x.view(-1,16*7*7)
        x=torch.nn.functional.relu(self.fc1(x))
        x=self.fc2(x)
        return x
net=Net()
net

在这里插入图片描述

5.模型训练

5.1选择损失函数和优化器

对于分类问题,一般选用交叉熵损失CrossEntropyLoss(),优化器此处我们选用随机梯度下降SGD

5.2训练

我们先对训练集进行一轮训练

running_loss=0
for i,data in enumerate(trainloader,1):
    inputs,labels=data
    outputs=net(inputs)
    loss=criterion(outputs,labels)
    running_loss+=loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i%100==0:
        print('epoch:{} loss:{}'.format(i,running_loss/100))
        running_loss=0


在这里插入图片描述
我们发现loss持续下降,我们训练十轮

running_loss=0
for epoch in range(10):
    for i,data in enumerate(trainloader,1):
        inputs,labels=data
        outputs=net(inputs)
        loss=criterion(outputs,labels)
        running_loss+=loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i%100==0:
            print('[%d,%5d] loss:%.3f'%(epoch+1,i,running_loss/100))
            running_loss=0.0
print('Finished Training')


在这里插入图片描述

6.模型的保存

在这里插入图片描述

7.模型的验证

此处我们选择test数据集一个批次验证
在这里插入图片描述
计算整体上的正确率

correct=0
total=0
with torch.no_grad():
    for data in testloader:
        images,labels=data
        outputs=net2(images)
        _,predicted=torch.max(outputs.data,1)
        total+=labels.size(0)
        correct+=(predicted==labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

在这里插入图片描述

结语

本篇博客通过训练卷积神经网络CNN模型,实现了对数字的识别,希望对你有所帮助

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2305934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

yolov12部署(保姆级教程)

yolov12部署 戳链接访问原论文论文地址 戳链接访问原代码代码地址 直接把源代码以ZIP的形式下载到本地&#xff0c;然后解压用IDE打开就可以了&#xff08;这一步比较简单不过多介绍&#xff09; 在IDE中打开可以看见一个README.md文件&#xff0c;这里有我们将yolov12部署本…

五、Three.js顶点UV坐标、纹理贴图

一部分来自1. 创建纹理贴图 | Three.js中文网 &#xff0c;一部分是自己的总结。 一、创建纹理贴图 注意&#xff1a;把一张图片贴在模型上就是纹理贴图 1、纹理加载器TextureLoader 注意&#xff1a;将图片加载到加载器中 通过纹理贴图加载器TextureLoader的load()方法加…

汽车零部件工厂如何通过ESD监控系统闸机提升产品质量

在汽车零部件工厂的生产过程中&#xff0c;静电带来的危害不容小觑。从精密的电子元件到复杂的机械部件&#xff0c;静电都可能成为影响产品质量的 “隐形杀手”。而 ESD 监控系统闸机的出现&#xff0c;为汽车零部件工厂解决静电问题、提升产品质量提供了关键的技术支持。 一、…

Pi币与XBIT:在去中心化交易所的崛起中重塑加密市场

在加密货币市场迅猛发展的背景下&#xff0c;Pi币和XBIT正在成为投资者关注的焦点。Pi币作为一项创新的数字货币&#xff0c;通过独特的挖矿机制和广泛的用户基础&#xff0c;迅速聚集了大量追随者&#xff0c;展示了强大的市场潜力。同时&#xff0c;币应XBIT去中心化交易所的…

【Python量化金融实战】-第2章:金融市场数据获取与处理:2.1 数据源概览:Tushare、AkShare、Baostock、通联数据(DataAPI)

本章将详细介绍四大主流金融数据源&#xff08;Tushare、AkShare、Baostock、通联数据&#xff08;DataAPI&#xff09;&#xff09;&#xff0c;分析其特点与适用场景&#xff0c;并通过实战案例展示数据获取与处理的全流程。 &#x1f449; 点击关注不迷路 &#x1f449; 点击…

首次使用WordPress建站的经验分享(一)

之前用过几种内容管理系统(CMS),如:dedeCMS、phpCMS、aspCMS,主要是为了前端独立建站,达到预期的效果,还是需要一定的代码基础的,至少要有HTML、Css、Jquery基础。 据说WordPress 是全球最流行的内容管理系统CMS,从现在开始记录一下使用WordPress 独立建站的步骤 选购…

Mysql 主从集群同步延迟问题怎么解决

目录 前言&#xff1a; 复制过程分为几个步骤&#xff1a; 一、同步延迟的危害 二、同步延迟的常见原因 1. 主库写入压力过大 2. 网络传输瓶颈 3. 从库硬件性能不足 4. 配置参数不合理 5. 特殊操作影响 三、深度诊断方法 1. 查看同步状态 2. 性能分析工具 四、十大解…

Unity Shader 学习13:屏幕后处理 - 使用高斯模糊的Bloom辉光效果

目录 一、基本的后处理流程 - 以将画面转化为灰度图为例 1. C#调用shader 2. Shader实现效果 二、Bloom辉光效果 1. 主要变量 2. Shader效果 &#xff08;1&#xff09;提取较亮区域 - pass1 &#xff08;2&#xff09;高斯模糊 - pass2&3 &#xff08;3&#xff…

三、《重学设计模式》-单例模式

单例模式 单例模式分为四大类&#xff0c;饿汉式、懒汉式、静态内部类、枚举 饿汉式 优点&#xff1a;类装载时进行实例化&#xff0c;避免同步问题 缺点&#xff1a;造成内存浪费 实现一 1.构造器私有化 2.内部创建对象实例 3.提供静态方法 public class Type1 {public s…

SpringBoot3整合Swagger3时出现Type javax.servlet.http.HttpServletRequest not present错误

目录 错误详情 错误原因 解决方法 引入依赖 修改配置信息 创建文件 访问 错误详情 错误原因 SpringBoot3和Swagger3版本不匹配 解决方法 使用springdoc替代springfox&#xff0c;具体步骤如下&#xff1a; 引入依赖 在pom.xml文件中添加如下依赖&#xff1a; <…

项目实战--网页五子棋(匹配模块)(4)

上期我们完成了游戏大厅的前端部分内容&#xff0c;今天我们实现后端部分内容 1. 维护在线用户 在用户登录成功后&#xff0c;我们可以维护好用户的websocket会话&#xff0c;把用户表示为在线状态&#xff0c;方便获取到用户的websocket会话 package org.ting.j20250110_g…

【Java毕业设计】商城购物系统(附源码+数据库脚本)

本系统是基于JavaEEServletJSPMysql实现的商城购物系统。包括用户登录、用户注册、商品分类、添加购物车、订单支付等基本功能&#xff0c;具体页面及功能如下&#xff1a; 感谢阅读&#xff01; 如需获取完整项目源码及更多项目信息&#xff0c;可添加V&#xff1a;

POST请求提交数据的三种方式及通过Postman实现

1、什么是POST请求&#xff1f; POST请求是HTPP协议中一种常用的请求方法&#xff0c;它的使用场景是向客户端向服务器提交数据&#xff0c;比如登录、注册、添加等场景。另一种常用的请求方法是GET&#xff0c;它的使用场景是向服务器获取数据。 2、POST请求提交数据的常见编…

Spring Boot 整合 Spring MVC /(整合Web)笔记

1. Spring Boot 整合 Web 功能 Spring Boot 通过自动配置简化了 Spring MVC 的集成。只需在 pom.xml 中添加 spring-boot-starter-web 依赖&#xff0c;Spring Boot 就会自动配置 Spring MVC 的相关组件。 <dependency><groupId>org.springframework.boot</gr…

[特殊字符]清华大学:DeepSeek从入门到精通.pdf(清华领航,驾驭DeepSeek,开启AI新境界)

不愧是清华大学出品的deepseek手册&#xff0c;简直是新手 福音&#xff0c;非常实用&#xff01; 这份《DeepSeek&#xff1a;从入门到精通》手册从基础到高 阶&#xff0c;手把手教你玩转DeepSeek&#xff0c;特别适合刚入门的小白&#xff0c;拿来就能用&#xff01; 1.Deep…

深度学习技术全景图:从基础架构到工业落地的超级进化指南

&#x1f50d; 目录导航 基础架构革命训练优化秘技未来战场前瞻 &#x1f9e9; 一、基础架构革命 1.1 前馈神经网络&#xff08;FNN&#xff09; ▍核心结构 import torch.nn as nnclass FNN(nn.Module):def __init__(self):super().__init__()self.fc1 nn.Linear(784, 25…

PyTorch-基础(CUDA、Dataset、transforms、卷积神经网络、VGG16)

PyTorch-基础 环境准备 CUDA Toolkit安装&#xff08;核显跳过此步骤&#xff09; CUDA Toolkit是NVIDIA的开发工具&#xff0c;里面提供了各种工具、如编译器、调试器和库 首先通过NVIDIA控制面板查看本机显卡驱动对应的CUDA版本&#xff0c;如何去下载对应版本的Toolkit工…

IO/网络IO基础全览

目录 IO基础CPU与外设1. 程序控制IO&#xff08;轮询&#xff09;2. 中断中断相关知识中断分类中断处理过程中断隐指令 3. DMA&#xff08;Direct Memory Access&#xff09; 缓冲区用户空间和内核空间IO操作的拷贝概念传统IO操作的4次拷贝减少一个CPU拷贝的mmap内存映射文件(m…

【DeepSeek-R1背后的技术】系列十一:RAG原理介绍和本地部署(DeepSeekR1+RAGFlow构建个人知识库)

【DeepSeek-R1背后的技术】系列博文&#xff1a; 第1篇&#xff1a;混合专家模型&#xff08;MoE&#xff09; 第2篇&#xff1a;大模型知识蒸馏&#xff08;Knowledge Distillation&#xff09; 第3篇&#xff1a;强化学习&#xff08;Reinforcement Learning, RL&#xff09;…

鸿蒙开发深入浅出04(首页数据渲染、搜索、Stack样式堆叠、Grid布局、shadow阴影)

鸿蒙开发深入浅出04&#xff08;首页数据渲染、搜索、Stack样式堆叠、Grid布局、shadow阴影&#xff09; 1、效果展示2、ets/pages/Home.ets3、ets/views/Home/SearchBar.ets4、ets/views/Home/NavList.ets5、ets/views/Home/TileList.ets6、ets/views/Home/PlanList.ets7、后端…