CNN实现手写数字识别(Pytorch)

news2024/9/24 11:20:53

CNN结构

CNN(卷积神经网络)主要包括卷积层、池化层和全连接层。输入数据经过多个卷积层和池化层提取图片信息后,最后经过若干个全连接层获得最终的输出。
在这里插入图片描述CNN的实现主要包括以下步骤:

  1. 数据加载与预处理
  2. 模型搭建
  3. 定义损失函数、优化器
  4. 模型训练
  5. 模型测试

以下基于Pytorch框架搭建一个CNN神经网络实现手写数字识别。

CNN实现

此处使用MNIST数据集,包含60000个训练样本和10000个测试样本。分为图片和标签,每张图片是一个 28 × 28 28 \times 28 28×28 的像素矩阵,标签是0~9一共10种数字。每个样本的格式为[data, label]。

1. 导入相关库

import numpy as np
import torch 
from torch import nn
from torchvision import datasets, transforms,utils
from PIL import Image
import matplotlib.pyplot as plt

2. 数据加载与预处理

# 定义超参数
batch_size = 128 # 每个批次(batch)的样本数

# 对输入的数据进行标准化处理
# transforms.ToTensor() 将图像数据转换为 PyTorch 中的张量(tensor)格式,并将像素值缩放到 0-1 的范围内。
# 这是因为神经网络需要的输入数据必须是张量格式,并且需要进行归一化处理,以提高模型的训练效果。
# transforms.Normalize(mean=[0.5],std=[0.5]) 将图像像素值进行标准化处理,使其均值为 0,标准差为 1。
# 输入数据进行标准化处理可以提高模型的鲁棒性和稳定性,减少模型训练过程中的梯度爆炸和消失问题。
transform = transforms.Compose([transforms.ToTensor(),
                               transforms.Normalize(mean=[0.5],std=[0.5])])

# 加载MNIST数据集
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                           train=True, 
                                           transform=transform, 
                                           download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', 
                                          train=False, 
                                          transform=transform, 
                                          download=True)
                                          
# 创建数据加载器(用于将数据分次放进模型进行训练)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True, # 装载过程中随机乱序
                                           num_workers=2) # 表示2个子进程加载数据
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False,
                                          num_workers=2) 

加载完数据后,可以得到60000个训练样本和10000个测试样本

print(len(train_dataset))
print(len(test_dataset))

在这里插入图片描述

以及469个训练批次和79测试批次

# batch=128
# train_loader=60000/128 = 469 个batch
# test_loader=10000/128=79 个batch
print(len(train_loader))
print(len(test_loader))

在这里插入图片描述

打印前5个手写数字样本看看

for i in range(0,5):
    oneimg,label = train_dataset[i]
    grid = utils.make_grid(oneimg)
    grid = grid.numpy().transpose(1,2,0) 
    std = [0.5]
    mean = [0.5]
    grid = grid * std + mean
    # 可视化图像
    plt.subplot(1, 5, i+1)
    plt.imshow(grid)
    plt.axis('off')

plt.show()

在这里插入图片描述
这里用了 make_grid() 函数将多张图像拼接成一张网格图像,并调整了网格图像的形状,使得它可以直接作为 imshow() 函数的输入。这种方式可以在一张图中同时显示多张图像,比单独显示每张图像更加方便,常用于可视化深度学习中的卷积神经网络(CNN)中的特征图、卷积核等信息。
在 PyTorch 中,默认的图像张量格式是 (channel, height, width),即通道维度在第一个维度。 torchvision.transforms.ToTensor() 函数会将 PIL 图像对象转换为 PyTorch 张量,并将通道维度放在第一个维度。因此,当我们使用 ToTensor() 函数加载图像数据时,得到的 PyTorch 张量的格式就是 (channel, height, width)。代码中的 oneimg.numpy().transpose(1,2,0) 就是将 PyTorch 张量 oneimg 转换为 NumPy 数组,然后通过 transpose 函数将图像数组中的通道维度从第一个维度(channel-first)调整为最后一个维度(channel-last),即将 (channel, height, width) 调整为 (height, width, channel),以便于 Matplotlib 库正确处理通道信息。

2. 模型搭建

我们将使用Pytorch构建一个如下图所示的CNN,包含两个卷积层,和全连接层,并使用Relu作为激活函数。
在这里插入图片描述
接下来看以下不同层的参数。

卷积层: Connv2d

  • in_channels ——输入数据的通道数目
  • out_channels ——卷积产生的通道数目
  • kernel_size ——卷积核的尺寸
  • stride——步长
  • padding——输入数据的边缘填充0的层数

池化层: MaxPool2d

  • kernel_siez ——池化核大小
  • stride——步长
  • padding——输入数据的边缘填充0的层数

全连接层: Linear

  • in_features:输入特征数
  • out_features:输出特征数

代码实现如下:

class CNN(nn.Module):
    # 定义网络结构
    def __init__(self):
        super(CNN, self).__init__()
        # 图片是灰度图片,只有一个通道
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, 
                               kernel_size=5, stride=1, padding=2)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, 
                               kernel_size=5, stride=1, padding=2)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=7*7*32, out_features=256)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=256, out_features=10)
	
    # 定义前向传播过程的计算函数
    def forward(self, x):
        # 第一层卷积、激活函数和池化
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)
        # 第二层卷积、激活函数和池化
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)
        # 将数据平展成一维
        x = x.view(-1, 7*7*32)
        # 第一层全连接层
        x = self.fc1(x)
        x = self.relu3(x)
        # 第二层全连接层
        x = self.fc2(x)
        return x

定义损失函数和优化函数

import torch.optim as optim

learning_rate = 0.001 # 学习率

# 定义损失函数,计算模型的输出与目标标签之间的交叉熵损失
criterion = nn.CrossEntropyLoss()
# 训练过程通常采用反向传播来更新模型参数,这里使用的是SDG(随机梯度下降)优化器
# momentum 表示动量因子,可以加速优化过程并提高模型的泛化性能。
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)
#也可以选择Adam优化方法
# optimizer = torch.optim.Adam(net.parameters(),lr=1e-2)

3. 模型训练

model = CNN() # 实例化CNN模型
num_epochs = 10 # 定义迭代次数

# 如果可用的话使用 GPU 进行训练,否则使用 CPU 进行训练。
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 将神经网络模型 net 移动到指定的设备上。
model = model.to(device)
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (images,labels) in enumerate(train_loader):
        images=images.to(device)
        labels=labels.to(device)
        optimizer.zero_grad() # 清空上一个batch的梯度信息
        # 将输入数据 inputs 喂入神经网络模型 net 中进行前向计算,得到模型的输出结果 outputs。
        outputs=model(images) 
        # 使用交叉熵损失函数 criterion 计算模型输出 outputs 与标签数据 labels 之间的损失值 loss。
        loss=criterion(outputs,labels)
        # 使用反向传播算法计算模型参数的梯度信息,并使用优化器 optimizer 对模型参数进行更新。
        loss.backward()
         # 更新梯度
        optimizer.step()
        # 输出训练结果
        if (i+1) % 100 == 0:
            print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, loss.item()))

print('Finished Training')

在这里插入图片描述

保存模型

# 模型保存
PATH = './mnist_net.pth'
torch.save(model.state_dict(), PATH)

4. 模型测试

# 测试CNN模型
with torch.no_grad(): # 进行评测的时候网络不更新梯度
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0
        correct += (predicted == labels).sum().item()
    print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

在这里插入图片描述
这里训练的模型准确率达到了98%,非常高,如果还想继续提高模型准确率,可以调整迭代次数、学习率等参数或者修改CNN网络结构实现。

可视化检验一个批次测试数据的准确性

# 将 test_loader 转换为一个可迭代对象 dataiter
dataiter = iter(test_loader)
# 使用 next(dataiter) 获取 test_loader 中的下一个 batch 的图像数据和标签数据
images, labels = next(dataiter)

# print images
test_img = utils.make_grid(images)
test_img = test_img.numpy().transpose(1,2,0)
std = [0.5]
mean =  [0.5]
test_img = test_img*std+0.5
plt.imshow(test_img)
plt.show()
plt.savefig('./mnist_net.png')
print('GroundTruth: ', ' '.join('%d' % labels[j] for j in range(128)))

在这里插入图片描述

参考来源:
使用Pytorch框架的CNN网络实现手写数字(MNIST)识别
PyTorch初探MNIST数据集

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/488370.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SDN — OpenvSwitch 软硬件融合加速方案

目录 文章目录 目录OVS-DPDKOvS-DPDK v.s. SR-IOV东西向流量南北向流量 / 跨服务器东西流量 OVS Hardware OffloadOVS-DPDK Hardware OffloadDPDK Hardware offloadOvS-DPDK Hardware offloadOvS-DPDK Hardware offload with vDPA OVS-DPDK 上图中的深色模块就是引入 DPDK 的相…

Sublime软件及html相关软件安装

Sublime软件及html相关软件安装 下载Sublime编译器并安装下载链接: [https://www.sublimetext.com/3](https://www.sublimetext.com/3)安装emmet自动补全插件 验证 下载Sublime编译器并安装 下载链接: https://www.sublimetext.com/3 安装emmet自动补全插件 第一步&#xff1…

css分享 | 常用按钮效果记录(关注追加)

今日分享几个css样式,在日常业务中,我们会追求更友好的交互体验,所以记录一些业务中常用的按钮样式,下次遇到可以拿来即用。 目录 1.按钮水波纹点击效果 2.流光波光闪烁效果 3.按钮点击立体效果 4.按钮悬停出现箭头效果 1.按钮…

玩客云刷armbian证书错误server certificate verification failed

文章目录 前言大概操作:1、换http源(感觉https应该也行)2、修改armbian.list3、证书认证4、更新软件源、索引5、安装证书、更新证书6、禁用ssl7、手动添加网站证书(好像失败了)8、安装debian软件包公钥(好像…

uni push2.0使用

uni push2.0配置 需要开通uniCloud服务(推荐阿里云) 生成证书:安卓(https://ask.dcloud.net.cn/article/68),ios(https://docs.getui.com/getui/mobile/ios/apns/) 进入开发者中心…

C++ 1.基础语法

1.using namespace std; 建议a:项目中尽量不要用上述语句。b:日常练习中使用。c:项目中指定名空间访问展开常用。 这个语句表示标准库的东西都放到std,为了解决自己定义的名字和库名发生冲突。如果定义和库名冲突的名字&#xf…

基于 Docker 的 MySQL GTID 主从复制与测试

目录 一、规划1.1 基础环境1.2 应用架构1.3 路径规划 二、部署2.1 服务部署2.2 主从配置2.2.1 主从同步配置2.2.2 主主同步配置 2.3 主从验证2.3.1 主从同步验证2.3.2 主主同步验证 2.4 客户端连接2.4.1 控制台2.4.2 图形化 三、压测3.1 安装 sysbench3.2 sysbench 压测3.2.1 读…

玩具蛇+正则问题(JAVA解法)

玩具蛇:用户登录 题目描述 本题为填空题,只需要算出结果后,在代码中使用输出语句将所填结果输出即可。 小蓝有一条玩具蛇,一共有 16 节,上面标着数字 1 至 16。每一节都是一个正方形的形状。相邻的两节可以成直线或…

5 个冷门且实用的 Kubectl 使用技巧

kubectl 是 K8s 官方附带的命令行工具,可以方便的操作 K8s 集群。这篇文章主要介绍一些 kubectl 的别样用法,希望读者有一定基础的 K8s 使用经验。 有一篇文章也介绍了一些技巧,写博客的时候正好搜到了,正好也分享出来吧。 Ready…

【Linux】 OpenSSH_7.4p1 升级到 OpenSSH_8.7p1(亲测无问题,建议收藏❤)

🍁博主简介 🏅云计算领域优质创作者   🏅华为云开发者社区专家博主   🏅阿里云开发者社区专家博主 💊交流社区:运维交流社区 欢迎大家的加入! 文章目录 文章声明前述安装一些必要的命令&…

细讲shell中的循环语句--for、while、until

目录 一:何为循环 1.循环概述 2.使用循环的好处 二:for循环语句 1.for语句的用法 ​2. 语法结构 (1)一般格式 (2)类C语言格式 (3)死循环 3.事例 ​4.常用转义符 ​5.制作九九乘法表 …

Winform从入门到精通(39)——ToolStrip(史上最全)更新中

1、Name获取控件对象 2、AllowDrop 3、AllowItemReorder 4、AllowMerge 5、Anchor 设置ToolStrip如何锚定父控件 6、AutoSize 设置ToolStrip的尺寸大小是否根据Font属性的变化而变化 7、BackColor 设置ToolStrip的背景色 8、BackgroundImage 设置背景图像 9、Back…

精炼计算机网络——序章(二)

文章目录 前言1.4 计算机网络在我国的发展1.5 计算机网络的类别1.5.1计算机网络的定义1.5.2 几种不同类别的计算机网络 1.6 计算机网络的性能1.6.1 计算机网络的性能指标1.6.2 计算机网络的非性能特征 1.7 计算机网络体系结构1.7.1 计算机网络体系结构的形成1.7.2 协议与划分层…

ChatExcel?

大家好,我是章北海mlpy 最近在浅学LangChain,在大模型时代,感觉这玩意很有前途。 LangChain是一个开源的应用开发框架,目前支持Python和TypeScript两种编程语言。 它赋予LLM两大核心能力:数据感知,将语言模…

5月1日起正式实施!图解《关键信息基础设施安全保护要求》

2023年5月1日,GB/T 39204-2022《信息安全技术 关键信息基础设施安全保护要求》将正式实施。作为我国第一项关键信息基础设施安全保护的国家标准,对于我国关键信息基础设施安全保护有着极为重要的指导意义。 《信息安全技术 关键信息基础设施安全保护要求…

Swift 技术 监听电话中断,音乐(用于恢复播放音乐)(源码)

一直觉得自己写的不是技术,而是情怀,一个个的教程是自己这一路走来的痕迹。靠专业技能的成功是最具可复制性的,希望我的这条路能让你们少走弯路,希望我能帮你们抹去知识的蒙尘,希望我能帮你们理清知识的脉络&#xff0…

【高烧39°考研上岸】23上交819考研经验分享

笔者来自通信考研小马哥23上交819全程班学员 一,基本情况介绍和考研经历 大家好,首先介绍一下我的基本情况。我本科毕业于东南大学,报考的是上海交通大学电子系电子与通信工程专业(专业学位)。我二战上岸&#xff0c…

Selenium:HTML测试报告

自动化测试过程中,获得用例的执行结果后,需要有具象化、简洁明了的测试结果,比如:用例执行时间、失败用例数、失败的原因等,这时候,就需要用到测试报告。 HTML测试报告是python语言自带的单元测试框架&…

等保各项费用支出明细

等保收费主要依据文件: 等保工作的定级指南文件_luozhonghua2000的博客-CSDN博客 Q7:做等级保护要多少钱? 答:开展等级保护工作主要包含:规划费用、建设或整改费用、运维费用、测评费用等,具体费用因各单位现状、保护对象承载业务功能、重要程度、所在地区等差异较大。 …

Input事件在应用中的传递(一)

Input事件在应用中的传递(一) hongxi.zhu 2023-4-25 前面我们已经梳理了input事件在native层的传递,这一篇我们接着探索input事件在应用中的传递与处理,我们将按键事件和触摸事件分开梳理,这一篇就只涉及按键事件。 一、事件的接收 从前面的…