手写数字识别案例分析(torch,深度学习入门)

news2025/1/4 12:39:04

在人工智能和机器学习的广阔领域中,手写数字识别是一个经典的入门级问题,它不仅能够帮助我们理解深度学习的基本原理,还能作为实践编程和模型训练的良好起点。本文将带您踏上手写数字识别的深度学习之旅,从数据集介绍、模型构建到训练与评估,一步步深入探索。

一、引言

手写数字识别(Handwritten Digit Recognition)是指通过计算机程序自动识别手写数字的过程。最著名的手写数字数据集之一是MNIST(Modified National Institute of Standards and Technology database),它包含了大量的手写数字图片,每张图片都被标记了对应的数字(0-9)。这个数据集成为了初学者学习深度学习,尤其是卷积神经网络(CNN)的首选。

二、MNIST数据集简介

MNIST数据集由60,000个训练样本和10,000个测试样本组成,每个样本都是一张28x28像素的灰度图像,代表了一个手写数字。这些图像已经被归一化并居中在图像中心,使得数字不会受到位置变化的影响。

 PyTorch 和 torchvision 库来下载并准备 MNIST 数据集,包括训练集和测试集

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

'''下载训练数据集(图片+标签)'''
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
  1. 打印设备信息:您的代码已经很好地检查了CUDA和MPS(针对Apple M系列芯片)的可用性,并设置了相应的设备。但是,在打印设备信息时,有一个小错误在字符串格式化上。您需要确保在字符串中正确地包含变量名。

  2. 打印数据形状:您已经正确地设置了DataLoader并打印了测试数据集中的一个批次的数据和标签的形状。这是一个很好的实践,可以帮助您了解数据的维度。

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 通常训练时会打乱数据  
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=False)  # 测试时不需要打乱数据  
  
# 打印测试数据集的一个批次的数据和标签的形状  
for x, y in test_dataloader:  
    print(f"Shape of x [N,C,H,W]: {x.shape}")  # 注意这里的x是图像,但MNIST是灰度图,所以C=1  
    print(f"Shape of y: {y.shape}, {y.dtype}")  # y是标签,通常是一维的,且为long类型  
    break  
  
# 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU  
device = "cuda" if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else "cpu")  
print(f"Using {device} device")  # 确保在字符串中正确地包含了变量名  
  

三、训练模型选择

一、创建一个具有多个隐藏层的神经网络,这些层都使用了nn.Linear来定义全连接层,并使用torch.sigmoid作为激活函数。

import torch  
import torch.nn as nn  
  
class NeuralNetwork(nn.Module):  
    def __init__(self):  
        super().__init__()  
        self.flatten = nn.Flatten()  
        self.hidden1 = nn.Linear(28 * 28, 256)  
        self.relu1 = nn.ReLU()  
        self.hidden2 = nn.Linear(256, 128)  
        self.relu2 = nn.ReLU()  
        self.hidden3 = nn.Linear(128, 64)  
        self.relu3 = nn.ReLU()  
        self.hidden4 = nn.Linear(64, 32)  
        self.relu4 = nn.ReLU()  
        self.out = nn.Linear(32, 10)  # 输出层对应于10个类别的得分  
  
       def forward(self, x):
        x = self.flatten(x)
        x = self.hidden1(x)
        x = torch.sigmoid(x)
        x = self.hidden2(x)
        x = torch.sigmoid(x)
        x = self.hidden3(x)
        x = torch.sigmoid(x)
        x = self.hidden4(x)
        x = torch.sigmoid(x)
        x = self.out(x)
        return x 
  
  
model = NeuralNetwork().to(device)  
print(model)  
  

二、定义了一个具有三个卷积层的CNN,每个卷积层后面都跟着ReLU激活函数,前两个卷积层后面还跟着最大池化层。最后,通过一个全连接层将卷积层的输出转换为10个类别的得分。

import torch  
import torch.nn as nn  
  
class CNN(nn.Module):  
    def __init__(self):  
        super(CNN, self).__init__()  
        self.conv1 = nn.Sequential(  
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),  
            nn.ReLU(),  
            nn.MaxPool2d(kernel_size=2),  
        )  
        self.conv2 = nn.Sequential(  
            nn.Conv2d(16, 32, 5, 1, 2),  
            nn.ReLU(),  
            nn.Conv2d(32, 32, 5, 1, 2),  
            nn.ReLU(),  
            nn.MaxPool2d(2),  
        )  
        self.conv3 = nn.Sequential(  
            nn.Conv2d(32, 64, 5, 1, 2),  
            nn.ReLU(),  
        )  
        self.out = nn.Linear(64 * 7 * 7, 10)  # 确保这里的输入特征数与卷积层输出后的特征数相匹配  
  
    def forward(self, x):  
        x = self.conv1(x)  
        x = self.conv2(x)  
        x = self.conv3(x)  # 输出应为(batch_size, 64, 7, 7)  
        x = x.view(x.size(0), -1)  # 展平操作,输出为(batch_size, 64*7*7)  
        output = self.out(x)  
        return output  
  
model = CNN().to(device)  
print(model)
  • in_channels=1:这指定了输入图像的通道数。

  • out_channels=16:这指定了卷积操作后输出的通道数,也就是卷积核(或称为滤波器)的数量。

  • kernel_size=5:这定义了卷积核的大小。

  • stride=1:这指定了卷积核在输入数据上滑动的步长。

  • padding=2:这定义了要在输入数据周围添加的零填充(zero-padding)的数量。

四、处理数据集和测试集

训练集处理:

def train(dataloader, model, loss_fn, optimizer):  
    model.train()  # 将模型设置为训练模式  
    batch_size_num = 1  # 这不是标准的用法,但在这里用作计数已处理批次的数量  
    for x, y in dataloader:  # 遍历数据加载器中的每个批次  
        x, y = x.to(device), y.to(device)  # 将数据和标签移动到指定的设备(如GPU)  
        pred = model(x)  # 通过模型进行前向传播  
        loss = loss_fn(pred, y)  # 计算预测和真实标签之间的损失  
        optimizer.zero_grad()  # 清除之前的梯度  
        loss.backward()  # 反向传播,计算当前梯度  
        optimizer.step()  # 更新模型的权重  
        loss_value = loss.item()
        if batch_size_num % 200 == 0:
            print(f"{loss_value:>7f}[number:{batch_size_num}]")#打印结果
        
        batch_size_num += 1  # 增加已处理批次的数量

测试集处理:

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            a = (pred.argmax(1) == y)
            b = (pred.argmax(1) == y).type(torch.float)
    test_loss /= num_batches
    correct /= size
    print(f'Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}')

模型训练:

loss_fn = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

epochs = 10
for t in range(epochs):
    print(f"-----------------------------------------------\nepcho{t+1}")
    train(train_dataloader, model, loss_fn, optimizer)
print("Done!")
test(test_dataloader, model, loss_fn)

train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model, loss_fn)

结果:

神经网络:

cnn:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2158824.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全栈开发(四):使用springBoot3+mybatis-plus+mysql开发restful的增删改查接口

1.创建user文件夹 作为增删改查的根包 路径 src/main/java/com.example.demo/user 2.文件夹里文件作用介绍 1.User(实体类) package com.example.demo.user; import com.baomidou.mybatisplus.annotation.TableId; import com.baomidou.mybatisplus.annotation.IdType; impo…

利用影刀实现批量发布文章的RPA流程(附视频演示)

前言 大家好,我是小智。在这篇文章中,我将分享一个实战案例,展示如何利用影刀实现批量发布文章的RPA流程。这里主要介绍其中一个简单步骤,其它步骤将通过视频演示。有使用方面的疑问可以留言。 影刀是一款强大的自动化工具&#x…

Matlab|考虑柔性负荷的综合能源系统低碳经济优化调度

目录 1 主要内容 2 部分代码 3 程序结果 4 下载链接 1 主要内容 程序主要实现的是考虑柔性负荷的综合能源系统低碳经济优化调度,模型参考《考虑柔性负荷的综合能源系统低碳经济优化调度》,求解方法采用的是混合整数规划算法,通过matlabc…

医学数据分析实训 项目四 回归分析--预测帕金森病病情的严重程度

文章目录 项目四:回归分析实践目的实践平台实践内容 预测帕金森病病情的严重程度作业(一)数据读入及理解(二)数据准备(三)模型建立(四)模型预测(五&#xff0…

如何使用cmd命令查看本机电脑的主机名?

1、按键盘win R 键,输入cmd,然后按一下【回车】 2、输入ping -a localhost , 然后按下【回车】 3、如下Ping 后面的DESKTOP-ALB9JF7即是本机电脑的【主机名】

浮动静态路由

浮动静态路由 首先我们知道静态路由的默认优先级是60&#xff0c;然后手动添加一条静态路由优先级为80的路由作为备份路由。当主路由失效的备份路由就会启动。 一、拓扑图 二、基本配置 1.R1: <Huawei>system-view [Huawei]sysname R1 [R1]interface GigabitEthernet…

linux的ssh命令使用介绍

目录 一、SSH的基本概念 二、SSH的工作原理 1、建立连接 2、密钥交换 3、认证 4、加密通信 三、SSH的主要功能 1、远程登录 2、文件传输 3、端口转发 四、SSH的安全性 五、SSH的应用场景 六、SSH的实现软件 一、SSH的基本概念 SSH主要用于登录远程服务器和执行命令、传输文…

使用Conda配置python环境到Pycharm------Window小白版

使用Conda配置python环境到Pycharm 一、Conda安装和环境配置1.1 安装Conda软件1.2 判断是否安装成功1.3 创建Conda虚拟环境 二、 pycharm的安装2.1 Pycharm使用手册2.2 安装pycharm 三、 pycharm导入Conda环境 一、Conda安装和环境配置 anaconda官网 1.1 安装Conda软件 运行…

TryHackMe 第4天 | Pre Security (三)

该学习路径讲解了网络安全入门的必备技术知识&#xff0c;比如计算机网络、网络协议、Linux命令、Windows设置等内容。过去两篇已经对计算机网络和网络协议进行了简单介绍&#xff0c;本篇博客将记录 Linux命令 部分。 Linux 系统的优点就是其轻量级&#xff0c;有些 Linux 系…

通过spring-boot创建web项目

依赖的软件 maven 1. 官网下载zip 文件&#xff0c;比如apache-maven-3.9.9-bin.zip 2. 解压到某个盘符&#xff0c;必须保证父亲目录的名字包含英文&#xff0c;数字&#xff0c;破折号&#xff08;-&#xff09; 3. 设置环境变量M2_HOME, 并将%M2_HOME%\bin添加到windown…

Linux:Bash中的文件描述符详解

相关阅读 Linuxhttps://blog.csdn.net/weixin_45791458/category_12234591.html?spm1001.2014.3001.5482 Linux中的所有进程&#xff0c;都拥有自己的文件描述符(File Descriptor, FD)&#xff0c;它是操作系统在管理进程和文件时的一种抽象概念。每个文件描述符由一个非负整…

2012-2019全球地表平均夜光年度数据

数据详情 2012-2019全球地表平均夜光年度数据 数据属性 数据名称&#xff1a;全球地表平均夜光年度数据 数据时间&#xff1a;2012-2019 空间位置&#xff1a;全球 数据格式&#xff1a;tif 空间分辨率&#xff1a;1500米 时间分辨率&#xff1a;年 坐标系&#xff1a;…

【自学笔记】支持向量机(3)——软间隔

引入 上一回解决了SVM在曲线边界的上的使用&#xff0c;使得非线性数据集也能得到正确的分类。然而&#xff0c;对于一个大数据集来说&#xff0c;极有可能大体呈线性分类趋势&#xff0c;但是边界处混杂&#xff0c;若仍采用原来的方式&#xff0c;会得到极其复杂的超平面边界…

高效高质量SCI论文撰写及投稿

第一章、论文写作准备即为最关键 1、科技论文写作前期的重要性及其分类 2、AI工具如何助力学术论文 3、研究主题确定及提高创新性 兴趣与背景&#xff1a;选择一个您感兴趣且有背景知识的研究领域。 创新性&#xff1a;选题和研究设计阶段如何提高学术创新性的方法。 研究缺…

FreeMarker 禁止自动转义标签-noautoesc

&#x1f496;简介 FreeMarker 是一个用 Java 语言编写的模板引擎&#xff0c;它被设计用来生成文本输出&#xff08;HTML 网页、电子邮件、配置文件等&#xff09;。在 FreeMarker 中&#xff0c;默认情况下&#xff0c;当你在模板中输出变量时&#xff0c;如果这些变量包含 …

应用密码学第一次作业(9.23)

一、Please briefly describe the objectives of information and network security,such as confidentiality, integrity, availability , authenticity , and accountability The objectives of information and network security include: Confidentiality: Protecting se…

在线思维导图怎么制作?只需要台这些组合分析法!

思维导图经历了漫长的进化&#xff0c;现已成为信息组织、记忆和头脑风暴的重要工具。其制作方式主要有手绘和软件两种&#xff0c;随着互联网的发展&#xff0c;软件制作因其便捷性和易于保存逐渐占据主导。如今&#xff0c;在线工具使得用户能够免费创建思维导图。本文将以即…

828华为云征文 | 云服务器Flexus X实例,Docker集成搭建Redis集群

828华为云征文 | 云服务器Flexus X实例&#xff0c;Docker集成搭建Redis集群 Redis 集群是一种分布式的 Redis 解决方案&#xff0c;能够在多个节点之间分片存储数据&#xff0c;实现水平扩展和高可用性。与传统的主从架构不同&#xff0c;Redis 集群支持数据自动分片、主节点故…

基于SpringBoot+Vue+MySQL的教学资料管理系统

系统展示 管理员后台界面 教师后台界面 系统背景 在当今信息化高速发展的时代&#xff0c;教育机构面临着日益增长的教学资料管理需求。为了提升教学管理的效率&#xff0c;优化资源的配置与利用&#xff0c;开发一套高效、便捷的教学资料管理系统显得尤为重要。基于SpringBoot…

通信工程学习:什么是MANO管理编排

MANO&#xff1a;管理编排 MANO&#xff1a;Management and Network Orchestration&#xff08;管理和网络编排&#xff09;在网络功能虚拟化&#xff08;NFV&#xff09;架构中扮演着至关重要的角色。MANO是一个由多个功能实体组合而成的层次&#xff0c;这些功能实体负责管理…