卷积神经网络--手写数字识别

news2025/4/28 3:39:45

本文我们通过搭建卷积神经网络模型,实现手写数字识别。

pytorch中提供了手写数字的数据集 ,我们可以直接从pytorch中下载

MNIST中包含70000张手写数字图像:60000张用于训练,10000张用于测试

图像是灰度的,28x28像素

首先,下载数据集

import torch
from torchvision import datasets #封装与图像相关的模型,数据集
from torchvision.transforms import ToTensor # #数据转换,张量,将其他类型的数据转换为tensor张量

training_data=datasets.MNIST(
    root='data',#表示下载的手写数字到哪个路径
    train=True,#读取下载后数据中的训练集
    download=True,#如果之前已经下载过,就不用再下载
    transform=ToTensor(),#张量,图片不能直接传入神经网络模型
)

test_data=datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor(),
)

打包数据

from torch.utils.data import DataLoader 

train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)

判断当前设备是否支持GPU

device='cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'using {device} device')

构建卷积神经网络模型

from torch import nn #导入神经网络模块

class CNN(nn.Module):
    def __init__(self):#初始化类
        super(CNN,self).__init__()#初始化父类
        self.conv1=nn.Sequential(# 将多个层(如卷积、激活函数、池化等)按顺序打包,输入数据会​​依次通过这些层​​,无需手动编写每一层的传递逻辑。
            nn.Conv2d(#2D 卷积层,提取空间特征。
                in_channels=1,#输入通道数
                out_channels=16,#输出通道数
                kernel_size=3,#卷积核大小
                stride=1,#步长
                padding=1,#填充
            ),
            nn.ReLU(),#激活函数,引入非线性变换,使得神经网络能够学习复杂的非线性变换,增强表达能力
            nn.MaxPool2d(kernel_size=2)# 2x2最大池化(尺寸减半)
        )

        self.conv2=nn.Sequential(
            nn.Conv2d(16,32,3,1,1),
            nn.ReLU(),
            # nn.Conv2d(32,32,3,1,1),
            # nn.ReLU(),
            nn.MaxPool2d(2),
        )

        self.conv3=nn.Sequential(
            nn.Conv2d(32,64,3,1,1)
        )
        self.out=nn.Linear(64*7*7,10)

    def forward(self,x):#前向传播
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        x=x.view(x.size(0),-1)# 展平为向量(保留batch_size,合并其他维度)
        output=self.out(x)  # 全连接层输出
        return output

返回的output结果大致如图所示

 模型传入GPU

model=CNN().to(device)
print(model)

  损失函数,衡量的是​​模型预测的概率分布​​与​​真实的类别分布​​之间的差异。

loss_fn=nn.CrossEntropyLoss()

  优化器,用于在训练神经网络时更新模型参数,目的是​​在神经网络训练过程中,自动调整模型的参数(权重和偏置),以最小化损失函数​​。

optimizer=torch.optim.Adam(model.parameters(),lr=0.01)

 模型训练

def train(dataloader,model,loss_fn,optimizer):
    model.train()
    batch_size_num=1
    for X,y in dataloader:
        X,y=X.to(device),y.to(device)
        pred=model.forward(X)
        loss=loss_fn(pred,y)

        # Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络
        optimizer.zero_grad()               #梯度值清零
        loss.backward()                     #反向传播计算得到每个参数的梯度值
        optimizer.step()                    #根据梯度更新网络参数

        loss_value=loss.item()
        if batch_size_num%100==0:
            print(f'loss:{loss_value:>7f}[number:{batch_size_num}]')
        batch_size_num+=1

epochs=10

for i in range(epochs):
    print(f'第{i}次训练')
    train(train_dataloader, model, loss_fn, optimizer)

模型测试

def test(dataloader,model,loss_fn):

    size = len(dataloader.dataset)# 测试集总样本数
    num_batches = len(dataloader)# 测试集总批次数
    model.eval()#进入到模型的测试状态,所有的卷积核权重被设为只读模式
    test_loss, correct = 0, 0# 初始化累计损失和正确预测数
    #禁用梯度计算
    with torch.no_grad():#一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所用内存消耗。
        for X,y in dataloader:
            X,y=X.to(device),y.to(device)
            pred=model.forward(X)
            test_loss+=loss_fn(pred,y).item()
            correct+=(pred.argmax(1)==y).type(torch.float).sum().item()
            a=(pred.argmax(1)==y)
            b=(pred.argmax(1)==y).type(torch.float)
    test_loss/=num_batches
    correct/=size

    print(f'Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}')

test(test_dataloader,model,loss_fn)

得到结果如图所示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2343440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL Server 2019 安装与配置详细教程

一、写在最前的心里话 和 MySQL 对比,SQL Server 的安装和使用确实要处理很多细节: 需要选择配置项很多有“定义实例”的概念,同一机器可以运行多个数据库服务设置身份验证方式时,需要同时配置 Windows 和 SQL 登录要想 Spring …

MyBatisPlus文档

一、MyBatis框架回顾 使用springboot整合Mybatis,实现Mybatis框架的搭建 1、创建示例项目 (1)、创建工程 新建工程 创建空工程 创建模块 创建springboot模块 选择SpringBoot版本 (2)、引入依赖 <dependencies><dependency><groupId>org.springframework.…

Memcached 主主复制架构搭建与 Keepalived 高可用实现

实验目的 掌握基于 repcached 的 Memcached 主主复制配置 实现通过 Keepalived 的 VIP 高可用机制 验证数据双向同步及故障自动切换能力 实验环境 角色IP 地址主机名虚拟 IP (VIP)主节点10.1.1.78server-a10.1.1.80备节点10.1.1.79server-b10.1.1.80 操作系统: CentOS 7 软…

鸿蒙ArkUI之相对布局容器(RelativeContainer)实战之狼人杀布局,详细介绍相对布局容器的用法,附上代码,以及效果图

在鸿蒙应用开发中&#xff0c;若是遇到布局相对复杂的场景&#xff0c;往往需要嵌套许多层组件&#xff0c;去还原UI图的效果&#xff0c;若是能够掌握相对布局容器的使用&#xff0c;对于复杂的布局场景&#xff0c;可直接减少组件嵌套&#xff0c;且随心所欲完成复杂场景的布…

线程函数库

pthread_create函数 pthread_create 是 POSIX 线程库&#xff08;pthread&#xff09;中的一个函数&#xff0c;用于创建一个新的线程。 头文件 #include <pthread.h> 函数原型 int pthread_create(pthread_t *thread, const pthread_attr_t *attr,void *(*s…

[C]基础13.深入理解指针(5)

博客主页&#xff1a;向不悔本篇专栏&#xff1a;[C]您的支持&#xff0c;是我的创作动力。 文章目录 0、总结1、sizeof和strlen的对比1.1 sizeof1.2 strlen1.3 sizeof和strlen的对比 2、数组和指针笔试题解析2.1 一维数组2.2 字符数组2.2.1 代码12.2.2 代码22.2.3 代码32.2.4 …

OpenCV 图形API(60)颜色空间转换-----将图像从 YUV 色彩空间转换为 RGB 色彩空间函数YUV2RGB()

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 将图像从 YUV 色彩空间转换为 RGB。 该函数将输入图像从 YUV 色彩空间转换为 RGB。Y、U 和 V 通道值的常规范围是 0 到 255。 输出图像必须是 8…

hbuilderx云打包生成的ipa文件如何上架

使用hbuilderx打包&#xff0c;会遇到一个问题。开发的ios应用&#xff0c;需要上架到app store&#xff0c;因此&#xff0c;就需要APP store的签名证书&#xff0c;并且还需要一个像xcode那样的工具来上架app store。 我们这篇文章说明下&#xff0c;如何在windows电脑&…

Golang | 位运算

位运算比常规运算快&#xff0c;常用于搜索引擎的筛选功能。例如&#xff0c;数字除以二等价于向右移位&#xff0c;位移运算比除法快。

产品动态|千眼狼sCMOS科学相机捕获单分子荧光信号

单分子荧光成像技术&#xff0c;作为生物分子动态研究的关键工具&#xff0c;对捕捉微弱信号要求严苛。传统EMCCD相机因成本高昂&#xff0c;动态范围有限&#xff0c;满阱容量低等问题&#xff0c;制约单分子研究成果产出效率。 千眼狼精准把握科研需求与趋势&#xff0c;自研…

Hot100方法及易错点总结2

本文旨在记录做hot100时遇到的问题及易错点 五、234.回文链表141.环形链表 六、142. 环形链表II21.合并两个有序链表2.两数相加19.删除链表的倒数第n个节点 七、24.两两交换链表中的节点25.K个一组翻转链表(坑点很多&#xff0c;必须多做几遍)138.随机链表的复制148.排序链表 N…

网络:手写HTTP

目录 一、HTTP是应用层协议 二、HTTP服务器 三、HTTP服务 认识请求中的uri HTTP支持默认首页 响应 功能完善 套接字复用 一、HTTP是应用层协议 HTTP下层是TCP协议&#xff0c;站在TCP的角度看&#xff0c;要提供的服务是HTTP服务。 这是在原来实现网络版计算器时&am…

【计算机视觉】CV实战项目 - 基于YOLOv5的人脸检测与关键点定位系统深度解析

基于YOLOv5的人脸检测与关键点定位系统深度解析 1. 技术背景与项目意义传统方案的局限性YOLOv5多任务方案的优势 2. 核心算法原理网络架构改进关键点回归分支损失函数设计 3. 实战指南&#xff1a;从环境搭建到模型应用环境配置数据准备数据格式要求数据目录结构 模型训练配置文…

【python】如何将python程序封装为cpython的库

python程序在发布时&#xff0c;往往会打包为cpython的库&#xff0c;并且根据应用服务器的不同架构&#xff08;x86/aarch64&#xff09;&#xff0c;以及python的不同版本&#xff0c;封装的输出类型也是非常多。本文介绍不同架构指定python下的代码打包方式&#xff1a; 首…

计算机组成原理 课后练习

例一&#xff1a; 例二&#xff1a; 1. 原码一位乘 基本原理 原码是一种直接表示数值符号和大小的方式&#xff1a;最高位为符号位&#xff08;0表示正&#xff0c;1表示负&#xff09;&#xff0c;其余位表示数值的绝对值。原码一位乘的核心思想是逐位相乘&#xff0c;并通…

SVN仓库突然没有权限访问

如果svn仓库突然出现无法访问的情况&#xff0c;提示没有权限&#xff0c;所有账号都是如此&#xff0c;新创建的账号也不行。 并且会突然提示要输入账号密码。 出现这个情况时&#xff0c;大概率库里面的文件有http或者https的字样&#xff0c;因为单独给该文件添加权限导致…

【Qt】文件

&#x1f308; 个人主页&#xff1a;Zfox_ &#x1f525; 系列专栏&#xff1a;Qt 目录 一&#xff1a;&#x1f525; Qt 文件概述 二&#xff1a;&#x1f525; 输入输出设备类 三&#xff1a;&#x1f525; 文件读写类 四&#xff1a;&#x1f525; 文件和目录信息类 五&…

【AI】[特殊字符]生产规模的向量数据库 Pinecone 使用指南

一、Pinecone 的介绍 Pinecone是一个完全托管的向量数据库服务&#xff0c;专为大规模机器学习应用设计。它允许开发者轻松存储、搜索和管理高维向量数据&#xff0c;为推荐系统、语义搜索、异常检测等应用提供强大的基础设施支持。 1.1 Pinecone的核心特性 1. 高性能向量搜…

dstream

DStream转换DStream 上的操作与 RDD 的类似&#xff0c;分为 Transformations&#xff08;转换&#xff09;和 Output Operations&#xff08;输出&#xff09;两种&#xff0c;此外转换操作中还有一些比较特殊的原语&#xff0c;如&#xff1a;updateStateByKey()、transform(…

HFSS5(李明洋)——设置激励(波端口激励)

Magnetic是适用于铁磁氧导体的&#xff0c;只有前三种激励类型可以用于计算S参数 1波端口激励 也可以设置在模型内部&#xff0c;如果是设置在模型内部必须加一段理想导体&#xff0c;用于指定端口方向 1.1——模式 number 输入N&#xff1a;计算1-N的模式都计算 1.2——模式…