CNN手写数字识别1——模型搭建与数据准备

news2025/3/12 19:16:33

模型搭建

我们这次使用LeNet模型,LeNet是一个经典的卷积神经网络(Convolutional Neural Network, CNN)架构,最初由Yann LeCun等人在1998年提出,用于手写数字识别任务

创建一个文件model.py。实现以下代码。

源码

# 导入PyTorch库
import torch
# 从PyTorch库中导入神经网络模块
from torch import nn
# 从torchsummary库中导入summary函数,用于打印模型的结构和参数数量
from torchsummary import summary

# 定义LeNet类,它继承自nn.Module,是一个神经网络模型
class LeNet(nn.Module):
    # 初始化函数,定义模型的层次结构
    def __init__(self):
        # 调用父类的初始化函数
        super().__init__()
        # 第一个卷积层,输入通道为1,输出通道为6,卷积核大小为5x5,padding为2
        self.c1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2)
        # Sigmoid激活函数
        self.sig = nn.Sigmoid()
        # 第一个平均池化层,池化窗口为2x2,步长为2
        self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)
        # 第二个卷积层,输入通道为6,输出通道为16,卷积核大小为5x5
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # 第二个平均池化层,池化窗口为2x2,步长为2
        self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)
        # Flatten层,用于将多维的输入一维化,以便输入到全连接层
        self.flatten = nn.Flatten()
        # 第一个全连接层,输入特征数为400,输出特征数为120
        self.f5 = nn.Linear(400, 120)
        # 第二个全连接层,输入特征数为120,输出特征数为84
        self.f6 = nn.Linear(120, 84)
        # 第三个全连接层,输入特征数为84,输出特征数为10(通常对应分类任务中的类别数)
        self.f7 = nn.Linear(84, 10)

    # 前向传播函数,定义数据通过网络的方式
    def forward(self, x):
        x = self.sig(self.c1(x))  # 通过第一个卷积层和Sigmoid激活函数
        x = self.s2(x)            # 通过第一个平均池化层
        x = self.sig(self.c3(x))  # 通过第二个卷积层和Sigmoid激活函数
        x = self.s4(x)            # 通过第二个平均池化层
        x = self.flatten(x)       # 通过Flatten层
        x = self.sig(self.f5(x))  # 通过第一个全连接层和Sigmoid激活函数
        x = self.sig(self.f6(x))  # 通过第二个全连接层和Sigmoid激活函数
        x = self.sig(self.f7(x))  # 通过第三个全连接层和Sigmoid激活函数
        return x

# 主函数
if __name__ == "__main__":
    # 自动检测是否有可用的GPU,如果有则使用GPU,否则使用CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 实例化LeNet模型,并将其移动到指定的设备上(GPU或CPU)
    model = LeNet().to(device)
    # 使用torchsummary的summary函数打印模型的结构和参数数量,输入形状为(1, 28, 28)
    print(summary(model, (1, 28, 28)))

源码解析

神经网络构建

LeNet网络主要由卷积层、池化层、激活函数和全连接层组成。有2个卷积层,2个池化层,3个全连接层。

  • 卷积层‌(nn.Conv2d):用于提取图像中的特征。这里有两个卷积层,第一个卷积层有6个输出通道,第二个卷积层有16个输出通道,卷积核大小都是5x5。
  • 激活函数‌(nn.Sigmoid):用于引入非线性,使得网络能够学习更复杂的模式。这里使用了Sigmoid激活函数。
  • 池化层‌(nn.AvgPool2d):用于降低特征图的尺寸,减少计算量,同时保留重要特征。这里使用了平均池化层,池化窗口大小为2*2,步长为2。
  • Flatten层‌(nn.Flatten):用于将多维的特征图展平成一维向量,以便输入到全连接层。
  • 全连接层‌(nn.Linear):用于分类任务,将特征向量映射到类别空间。这里有三个全连接层,分别将特征维度从400降到120,再从120降到84,最后从84降到10(对应10个类别)。

参数计算

从代码中可以看到每一层神经网络都有自己的参数,这里面通道数,卷积核大小,步长和感受野,一定程度上可以当做超参数人为自由设定,其他的参数都需要事先根据输入数据进行计算。

首先,假设输入的图像数据大小都是28*28*1,即宽28个像素,高28个像素,由于是灰度图所以色彩通道只有1。

在第一个卷积层c1,卷积核大小是5*5,卷积核个数是6个,步长默认是1,填充是2,这些是我们人为设定的。可可以得出输出通道数=卷积核个数=6,经过这一层的输出数据通道数为6,尺寸通过公式计算:

O=\frac{IN+2P-F}{S}+1 

公式里面O是输出的宽/高,IN是输入的宽/高,P是填充,F是卷积核的宽/高或者感受野的宽/高,S是步长。

即输出图像的宽高是(28+2*2-5)/1+1=28。输出数据量是28*28*6。可以看到卷积层可以有效提升数据的通道数。

在第一个池化层s2,感受野是2*2,步长是2,填充默认是0,可以计算出输出图像的宽高是(28+2*0-2)/2+1=14。可以看到经过池化层之后数据的特征量明显减少,此时输出的数据量是14*14*6(池化不会改变通道数)。

在第二个卷积层c3,卷积核大小是5*5,步长默认是1,填充默认是0,有16个卷积核,也就是通道数增加到了16。根据公式可以计算出输出图像的宽高是(14+2*0-5)/1+1=10。输出数据量是10*10*16。

在第二个池化层s4,感受野是2*2,步长为2,那么输出数据宽高就是(10+2*0-2)/2+1=5。输出数据量是5*5*16。

到全连接层的第一层就比较关键了,因为这里的参数有一个“输入特征数”,也就是刚才算的5*5*16=400。如果前面的计算不对,到了这一步模型是会报错的,因此每一层的输出特征量都需要计算出来。

后面的全连接层就比较好写了,输入的特征量是上一层全连接层的神经元个数。

前向传播

前向传播定义了数据通过网络的方式。对于输入x,它首先通过第一个卷积层和Sigmoid激活函数,然后通过第一个平均池化层;接着通过第二个卷积层和Sigmoid激活函数,再通过第二个平均池化层;最后通过Flatten层将多维特征展平成一维向量,并依次通过三个全连接层和Sigmoid激活函数得到最终输出。

验证

在主函数中,我们首先检测是否有可用的GPU,并将模型移动到合适的计算设备上(GPU或CPU)。然后,我们使用torchsummary的summary函数打印模型的结构和参数数量,以便了解模型的复杂度和计算需求。

从图中可以看出,池化层是不包含参数的,整个模型的大部分参数都在全连接层(48120+10164+850 = 59134,将近6万个参数在伺候全连接层)。

数据准备

FashionMNIST是一个流行的数据集,包含了10种类别的70,000个灰度图像,通常用于计算机视觉和机器学习的教学与研究。我们这次通过远程下载的方式来获取数据。

另外创建一个plot.python。用来下载数据和预览数据。这部分代码可写可不写,模型训练的时候还会重新加载数据。

源码

# 导入必要的库和模块
from torchvision import transforms  # 用于图像预处理的变换
from torchvision.datasets import FashionMNIST  # 导入FashionMNIST数据集
import torch.utils.data as Data  # 用于数据加载的实用工具
import numpy as np  # 导入NumPy库,用于数值计算
import matplotlib.pyplot as plt

# 准备训练数据
train_data = FashionMNIST(
    root='./data',  # 数据集存储的根目录
    train=True,  # 指定为训练数据集
    transform=transforms.Compose([  # 图像预处理步骤
        transforms.Resize(size=224),  # 将图像大小调整为224x224
        transforms.ToTensor()  # 将图像转换为PyTorch张量
    ]),
    download=True  # 如果数据集不存在,则下载
)

# 创建数据加载器
train_loader = Data.DataLoader(
    dataset=train_data,  # 指定数据集
    batch_size=64,  # 每个批次的大小
    shuffle=True,  # 在每个epoch开始时打乱数据
    num_workers=0  # 使用0个工作线程(对于Windows系统,有时需要设置为0以避免多进程问题)
)


# 遍历数据加载器
for step, (b_x, b_y) in enumerate(train_loader):
    if step > 0:  # 只处理第一个批次的数据
        break
# 将PyTorch张量转换为NumPy数组
batch_x = b_x.squeeze().numpy()  # 移除批次维度(如果可能),并转换为NumPy数组
batch_y = b_y.numpy()  # 将标签转换为NumPy数组

# 获取数据集中的类别标签
class_label = train_data.classes  # 这是一个包含所有类别名称的列表
# 打印类别标签
print(class_label)  # 输出类别标签列表

# 设置图形的大小
plt.figure(figsize=(12, 5))


# 遍历batch_y中的每一个元素,即每一个样本的标签
for ii in np.arange(len(batch_y)):
    # 创建子图,4行16列,第ii+1个子图
    # 这里假设一个批次有64个样本,因此用4x16的布局来显示它们
    plt.subplot(4, 16, ii + 1)

    # 显示图像
    # batch_x[ii, :, :]表示第ii个样本的图像数据
    # cmap=plt.cm.gray指定使用灰度色彩映射
    plt.imshow(batch_x[ii, :, :], cmap=plt.cm.gray)

    # 设置标题为对应的类别标签
    # class_label[batch_y[ii]]根据标签索引获取类别名称
    # size=10设置标题字体大小
    plt.title(class_label[batch_y[ii]], size=10)

    # 关闭坐标轴显示
    plt.axis("off")

# 调整子图之间的间距
# wspace=0.05设置子图之间的宽度间距
plt.subplots_adjust(wspace=0.05)

# 显示图形
plt.show()

源码解析

下载和加载数据

首先准备训练数据。FashionMNIST数据集将被下载到指定的根目录,并进行图像预处理

为了高效地加载数据,我们使用PyTorch的DataLoader来创建数据加载器。

dataloader的参数解释如下:

  • dataset:指定要加载的数据集。
  • batch_size:每个批次加载的样本数。
  • shuffle:是否在每个epoch开始时打乱数据。
  • num_workers:加载数据时使用的工作线程数。在Windows系统上,有时需要设置为0以避免多进程问题。

展示数据

我们遍历数据加载器,但只处理第一个批次的数据(为了简化示例)。使用squeeze()方法移除批次维度(如果可能),并将PyTorch张量转换为NumPy数组,以便使用matplotlib进行可视化。随后使用matplotlib的subplot()方法创建子图,并在每个子图中显示一个图像样本。我们使用灰度色彩映射(cmap=plt.cm.gray)来显示图像。最后,使用plt.show()方法显示图形。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2299711.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习04 数据增强、调整学习率

目录 数据增强 常用的数据增强方法 调整学习率 学习率 调整学习率 ​调整学习率的方法 有序调整 等间隔调整 多间隔调整 指数衰减 余弦退火 ​自适应调整 自定义调整 数据增强 数据增强是通过对训练数据进行各种变换(如旋转、翻转、裁剪等)&am…

PH热榜 | 2025-02-16

1. Cal.com Routing 标语:根据客户线索,系统会智能地自动安排约会。 介绍:告别繁琐的排期!Cal.com 推出了新的路由功能,能更智能地分配预约,让你的日程安排更顺畅。这项功能运用智能逻辑和深入的数据分析…

数据库基本概念及基本使用

数据库基本概念 什么是数据库: 数据库特点: 常见的数据库软件: 不同的公司进行不同的实践,生成了不同的产品。 比如买汽车,汽车只是一个概念,你要买哪个牌子哪个型号的汽车,才是真正的汽车的一…

gozero实现数据库MySQL单例模式连接

在 GoZero 框架中实现数据库的单例连接可以通过以下步骤来完成。GoZero 使用 gorm 作为默认的数据库操作框架,接下来我会展示一个简单的单例模式实现。 ### 1. 定义数据库连接的单例结构 首先,你需要定义一个数据库连接的结构体,并在初始化…

CSS flex布局 列表单个元素点击 本行下插入详情独占一行

技术栈:Vue2 javaScript 简介 在实际开发过程中有遇到一个场景:一个list,每行个数固定,点击单个元素后,在当前行与下一行之间插入一行元素详情,便于更直观的查看到对应的数据详情。 这种情形&#xff0c…

无人机航迹规划: 梦境优化算法(Dream Optimization Algorithm,DOA)求解无人机路径规划MATLAB

一、梦境优化算法 梦境优化算法(Dream Optimization Algorithm,DOA)是一种新型的元启发式算法,其灵感来源于人类的梦境行为。该算法结合了基础记忆策略、遗忘和补充策略以及梦境共享策略,通过模拟人类梦境中的部分记忆…

权限五张表

重点:权限五张表的设计 核心概念: 在权限管理系统中,经典的设计通常涉及五张表,分别是用户表、角色表、权限表、用户角色表和角色权限表。这五张表的设计可以有效地管理用户的权限,确保系统的安全性和灵活性。 用户&…

Docker-数据卷

1.数据卷 容器是隔离环境,容器内程序的文件、配置、运行时产生的容器都在容器内部,我们要读写容器内的文件非常不方便。大家思考几个问题: 如果要升级MySQL版本,需要销毁旧容器,那么数据岂不是跟着被销毁了&#xff1…

IT : 是工作還是嗜好? Delphi 30周年快乐!

又到2月14日了, 自从30多年前收到台湾宝蓝(Borland)公司一大包的3.5 磁盘片, 上面用黑色油性笔写着Delphi Beta开始, Delphi便和我的工作生涯有了密不可分的关系. 一年后Delphi大获成功, 自此对于使用Delphi的使用者来说2月14日也成了一个特殊的日子! 我清楚记得Delphi Beta使用…

DeepPose

目录 摘要 Abstract DeepPose 算法框架 损失函数 创新点 局限性 训练过程 代码 总结 摘要 DeepPose是首个将CNN应用于姿态估计任务的模型。该模型在传统姿态估计方法的基础上,通过端到端的方式直接从图像中回归出人体关键点的二维坐标,避免了…

[HarmonyOS]鸿蒙(添加服务卡片)推荐商品 修改卡片UI(内容)

什么是服务卡片 ? 鸿蒙系统中的服务卡片(Service Card)就是一种轻量级的应用展示形式,它可以让用户在不打开完整应用的情况下,快速访问应用内的特定功能或信息。以下是服务卡片的几个关键点: 轻量级&#…

DeepSeek R1 本地部署和知识库搭建

一、本地部署 DeepSeek-R1,是幻方量化旗下AI公司深度求索(DeepSeek)研发的推理模型 。DeepSeek-R1采用强化学习进行后训练,旨在提升推理能力,尤其擅长数学、代码和自然语言推理等复杂任务 。 使用DeepSeek R1, 可以大大…

领域驱动设计叕创新,平安保险申请DDD专利

DDD领域驱动设计批评文集 做强化自测题获得“软件方法建模师”称号 《软件方法》各章合集 见下图: 这个名字拼得妙:领域驱动设计模式。 是领域驱动设计?还是设计模式?还是领域驱动设计设计模式?和下面这个知乎文章的…

团体程序设计天梯赛-练习集——L1-041 寻找250

前言 10分的题,主要的想法就一个,按这个想法可以出几个写法 L1-041 寻找250 对方不想和你说话,并向你扔了一串数…… 而你必须从这一串数字中找到“250”这个高大上的感人数字。 输入格式: 输入在一行中给出不知道多少个绝对值…

C#控制台大小Console.SetWindowSize函数失效解决

在使用C#修改控制台大小相关API会失效. 由于VS将控制台由命令提示符变成了终端,因此在设置大小时会出现问题 测试代码: Console.SetWindowSize(100, 50);

spring boot 对接aws 的S3 服务,实现上传和查询

1.aws S3介绍 AWS S3(Amazon Simple Storage Service)是亚马逊提供的一种对象存储服务,旨在提供可扩展、高可用性和安全的数据存储解决方案。以下是AWS S3的一些主要特点和功能: 1.1. 对象存储 对象存储模型:S3使用…

25/2/16 <算法笔记> DirectPose

DirectPose 是一种直接从图像中预测物体的 6DoF(位姿:6 Degrees of Freedom)姿态 的方法,包括平移和平面旋转。它在目标检测、机器人视觉、增强现实(AR)和自动驾驶等领域中具有广泛应用。相比于传统的位姿估…

数据结构-8.Java. 七大排序算法(下篇)

本篇博客给大家带来的是排序的知识点, 由于时间有限, 分两天来写, 下篇主要实现最后一种排序算法: 归并排序。同时把中篇剩下的快排非递归实现补上. 文章专栏: Java-数据结构 若有问题 评论区见 欢迎大家点赞 评论 收藏 分享 如果你不知道分享给谁,那就分享给薯条. 你们的支持是…

DeepSeek私有化部署+JAVA通过API调用离线大模型问答

在当今快速发展的数字化时代,企业对于高效、灵活的技术解决方案需求日益增长。DeepSeek作为一款领先的智能搜索与分析平台,凭借其强大的数据处理能力和精准的搜索结果,已经成为众多企业提升运营效率的得力助手。为了更好地满足企业对数据安全…

【吾爱出品】针对红警之类老游戏适用WIN10和11的补丁cnc-ddraw7.1汉化版

针对红警之类老游戏适用WIN10和11的补丁cnc-ddraw7.1汉化版 链接:https://pan.xunlei.com/s/VOJ8PZd4avMubnDzHQAeZDxWA1?pwdnjwm# 直接复制到游戏安装目录,保持与游戏主程序同目录下。