基于Pytorch的CNN手写数字识别

news2024/11/16 3:49:00

作为深度学习小白,我想把自己学习的过程记录下来,作为实践部分,我会写一个通用框架,并会不断完善这个框架,作为自己的入门学习。因此略过环境搭建和基础知识的步骤,直接从代码实战开始。

一.下载数据集并加载

在这里使用MINST开源数字识别数据集。

首先导入必要的库,设置训练的设备(gpu或cpu),设置训练的轮次(epoch),然后设置数据集train_data、test_data,并使用torchvision的datasets来读取,下载的MINSt数据集被保存在当前路径的dataset文件夹下,对于训练集和测试集分别设置train的参数,最后把它转成tensor张量。

接着对设置好的数据集进行读取,调用了torch.utils.data下的DataLoader,分别读取训练集和测试集,同时设置batch_size,即为每一次读取多少张图片,然后对训练集数据进行展平(通常测试集不需要)。

# 搭建CNN卷积神经网络对MNIST数据集实现数字识别

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
import cv2
import matplotlib.pyplot as plt
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epoch = 10


train_data = datasets.MNIST("./dataset", train=True,download=True,transform=transforms.ToTensor())
test_data = datasets.MNIST("./dataset", train=False, download=True,transform=transforms.ToTensor())

train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
test_loader = DataLoader(test_data, batch_size=16, shuffle=False)

二.定义训练网络

其中super().__init__()允许我们调用父类(nn.Module)的方法,

对于卷积操作nn.Conv2d(输入通道数,输出通道数,卷积核尺寸,步长,padding大小)参数如此,因为输入为灰度图,则对于第一个卷积的输入通道数等于1,最后线性层会输出一个包含10个数据的变量,分别代表10个数字(类别)的概率。

然后,我们实例化model为网络的对象,定义损失函数为交叉熵损失函数,使用Adam优化器对参数(model.parameters())进行优化,初始化学习率为0.001,并调用学习率更新器。

class Dight(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 10, 5),  #输入:batch*1*28*28  输出:batch*10*24*24(28 -5 + 1)
            nn.ReLU(),  #保持shape不变  输出:batch*10*24*24(28 -5 + 1)
            nn.MaxPool2d(2),   #输入:batch*10*24*24(28 -5 + 1) 输出:batch*10*12*12
            nn.Conv2d(10, 20, 3),   #输入:batch*10*12*12  输出:batch*20*10*10(12 - 3 + 1)
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(20*10*10, 500),   #输入:batch2000   输出:batch 500
            nn.ReLU(),    #保持shape不变
            nn.Linear(500, 10)  #输入:batch 500  输出:batch 10
        )

    def forward(self, x):
        return self.model(x)
    
model = Dight()
model = model.to(device)
loss_fn = nn.CrossEntropyLoss()
loss_fn =  loss_fn.to(device)
optimizer = optim.Adam(model.parameters(), lr = 0.001)
scheduler = StepLR(optimizer, step_size = 5, gamma = 0.5)

三.开始训练

使用model.train()开始训练,使用for循环遍历数据集中的数据(imgs)和标签(targets),对梯度初始化,将数据传入model进行前向传播,并输出前向传播结果(outputs),根据outputs和给定的标签targets计算交叉熵损失loss,根据loss进行反向传播,根据反向传播更新模型参数。

同时,每1000步打印一下当前的步数和loss,用于观察训练进度和效果。

#定义训练方法
def train():
    #模型训练
    model.train()
    train_step = 0
    for batch_index, (imgs, targets) in enumerate(train_loader):
        #部署到device上
        imgs, targets = imgs.to(device), targets.to(device)
        #梯度初始化为0
        optimizer.zero_grad()
        #训练后的结果
        outputs = model(imgs)
        #计算损失
        loss = loss_fn(outputs, targets)   #交叉熵损失,适用于多分类任务,二分类适用于sigmoid
        #反向传播
        loss.backward()
        #参数更新
        optimizer.step()

        train_step += 1
        if train_step % 1000 == 0:
            print(f"train Epoch: {train_step} , Loss: {loss.item()}")

四.测试方法

我们会使用测试集对网络进行验证,通过model.eval()对模型进行验证,因为验证时不会计算梯度也不算反向传播,所以与训练不同的是需要使用语句with torch.no_grad(),同样的对测试集进行遍历(这里也可以仿照训练时的写法),之后,同样的计算outputs和loss,还会对test_loss和accuracy进行累计,观察网络在测试集的效果

#定义测试方法
def test():
    #模型验证
    model.eval()
    #正确率
    accuracy = 0.0
    #测试损失
    test_loss = 0.0


    with torch.no_grad():  #不会计算梯度也不会反向传播
        for imgs, targets in test_loader:
            #部署到device上
            imgs, targets = imgs.to(device), targets.to(device)
            #测试数据
            outputs = model(imgs)
            #计算测试损失
            loss = loss_fn(outputs, targets)
            test_loss += loss.item()


            #累计正确的值
            accuracy += (outputs.argmax(1) == targets).sum().item()
        
        test_loss /= len(test_loader)
        accuracy /= len(test_data)
        print(f"整体测试集上的损失: {test_loss},准确率 : {accuracy}")

 五.模型保存

调用

torch.save(model, "my_CNN.pth")

print("模型已保存")

即可

整合上面代码

if __name__ == "__main__":
    #调用方法
    for epoch in range(1, epoch + 1):
        print(f"-------------------第{epoch}轮训练开始------------------")
        train()
        # 调整学习率
        scheduler.step()

        test()

    torch.save(model, "my_CNN.pth")
    print("模型已保存")

六.结果测试

创建另一个py文件,输入任意一张数字图片,对图片的数字进行预测(多分类)。

打开image,并将它resize为28*28,如这里使用的3.jpg为

 用torch.load()加载模型

from PIL import Image
import torchvision
import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential

img_path = "/home/lm/数字识别/picture/3.jpg"
image = Image.open(img_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

transform = torchvision.transforms.Compose([torchvision.transforms.Resize((28, 28)),
                                            torchvision.transforms.ToTensor()])

image = transform(image)

class Dight(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 10, 5),  #输入:batch*1*28*28  输出:batch*10*24*24(28 -5 + 1)
            nn.ReLU(),  #保持shape不变  输出:batch*10*24*24(28 -5 + 1)
            nn.MaxPool2d(2),   #输入:batch*10*24*24(28 -5 + 1) 输出:batch*10*12*12
            nn.Conv2d(10, 20, 3),   #输入:batch*10*12*12  输出:batch*20*10*10(12 - 3 + 1)
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(20*10*10, 500),   #输入:batch2000   输出:batch 500
            nn.ReLU(),    #保持shape不变
            nn.Linear(500, 10)  #输入:batch 500  输出:batch 10
        )

    def forward(self, x):
        return self.model(x)

model = torch.load("/home/lm/数字识别/my_CNN.pth")


image = torch.reshape(image, (1,1,28,28)).to(device)
model.eval()
with torch.no_grad():
    output = model(image)
print(output)

print(output.argmax(1))

最终输出为

tensor([[-14.0138,  -4.8722,  -7.2821, -11.5329,   6.1589,  -8.7089,  -7.8535,
          -6.8521,  -5.4265,  -7.6144]], device='cuda:0')
tensor([4], device='cuda:0')

可以看出模型可以正确预测出图片类别

七.数据集转换

问题

在上一步加载图片时,我们使用了MINST数据集的图片,但是我们下载的MINST数据集的格式是这样的

 数据集介绍

MNIST数据集来自美国国家标准与技术研究所, National Institute of Standards and Technology (NIST)。训练集(training set)由来自250个不同人手写的
数字构成,其中50%是高中学生,50%来自人口普查局(the Census Bureau)的工作人员。测试集(test set)也是同样比例的手写数字数据,但保证了测试集和训练集
的作者集不相交。

  MNIST数据集一共有7万张图片,其中6万张是训练集,1万张是测试集。每张图片是28 × 28 28\times 2828×28的0 − 9 0-90−9的手写数字图片组成。每个图片是黑底
白字的形式,黑底用0表示,白字用0-1之间的浮点数表示,越接近1,颜色越白。每个元素表示图片对应的数字出现的概率,显然,该向量标签表示的是数字5。

  MNIST数据集下载地址是http://yann.lecun.com/exdb/mnist/,它包含了4 44个部分:

    (1)训练数据集:train-images-idx3-ubyte.gz (9.45 MB,包含60,000个样本)。
    (2)训练数据集标签:train-labels-idx1-ubyte.gz(28.2 KB,包含60,000个标签)。
    (3)测试数据集:t10k-images-idx3-ubyte.gz(1.57 MB ,包含10,000个样本)。
    (4)测试数据集标签:t10k-labels-idx1-ubyte.gz(4.43 KB,包含10,000个样本的标签)。

数据集转换

编写一个脚本把原二进制格式的数据转换成jpg格式,这里先转换100张

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import cv2
import numpy as np

with open("./dataset/MNIST/raw/train-images-idx3-ubyte", "rb") as f:
    file = f.read()


for i in range(1,100):
    image1 = [int(str(item).encode('ascii'), 16) for item in file[16+784*(i-1) : 16+784*i]]
    print(image1)

    image1_np = np.array(image1, dtype = np.uint8).reshape(28, 28, 1)
    cv2.imwrite(f"./picture/{i}.jpg", image1_np)

最后,可在picture文件夹下找到转换完成的jpg数据,再用它进行结果测试即可

八.总结

本文介绍了一个通用简单的pytorch框架,还有很多不足和缺点,后续会在本系列继续完善框架

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1110819.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【遮天】最新预告,叶凡一怒报仇,导演再删减人物,还暴露一个严重问题

Hello,小伙伴们,我是小郑继续为大家深度解析遮天国漫资讯。 《遮天》动漫第30集预告已出,叶凡被挟持进入荒古禁地!这一集看下来,导演又删减人物了,还暴露一个问题。 在预告中,叶凡已经被姬家和姜家的人带往…

【C++ 学习 ㉙】- 详解 C++11 的 constexpr 和 decltype 关键字

目录 一、constexpr 关键字 1.1 - constexpr 修饰普通变量 1.2 - constexpr 修饰函数 1.3 - constexpr 修饰类的构造函数 1.4 - constexpr 和 const 的区别 二、decltype 关键字 2.1 - 推导规则 2.2 - 实际应用 一、constexpr 关键字 constexpr 是 C11 新引入的关键字…

Spring Boot学习笔记(1)

Spring Boot学习笔记(1) 1.环境1.win2.mac3. IDEA 2.知识点1.Record类2.Switch开关表达式3. var和sealed4.springboot5.启用lombok 学习资料: 官网, 手册, 视频。 1.环境 1.win 1.下载vscode 2.安装jdk&#xff0…

求助C语言大佬:C语言的main函数参数问题

最近在敲代码的过程中,突发奇想,产生了一个疑问: 为什么main函数可以任由我们定义:可以接收一个参数、两个参数、三个参数都接接收,或者可以不接收?这是如何实现的 int main(){retrun 0; } int main (int…

移动app安全检测报告有什么作用?

移动app安全测试是一项至关重要的任务,它能够帮助确保移动应用程序在使用过程中不会受到各种安全威胁的侵害。在如今移动应用程序日益普及的时代,移动app安全测试尤为重要。移动app安全检测报告是基于专业的安全测试团队进行的全面分析后生成的&#xff…

博客积分上一万了

博客积分上一万了 继续努力,勇往直前。

JOSEF约瑟 JD3-40/23 JD3-70/23漏电继电器 AC220V\0.05-0.5A

JD3系列漏电继电器(以下简称继电器)适用于交流电压至1140V,频率为50Hz,该继电器与分励脱扣器或失压脱扣器的断路器、交流接触器、磁力启动器等组成漏电保护装置,作漏电和触电保护之用,可配备蜂鸣器、信号等…

短视频是“风口”还是“疯口”?

熟悉我的粉丝都知道,最近去追了下短视频的风口,折腾了几个视频出来。且不说视频效果如何,单单是制作视频的过程,就差点没要了童话的老命。看似短短的几分钟,真的应了那句话:台上一分钟,台下十年…

Ubuntu系统忘记Root用户密码-无法登录系统-更改Root密码-Ubuntu系统维护

一、背景 很多时候,我们总会设计复杂的密码,但是大多数时候,我们反而会先忘记我们的密码,导致密码不仅仅阻挡其他用户进入系统,同时也阻碍我们进入系统。 本文将介绍在忘记密码的情况下,如何进入系统并更改…

macOS Sonoma 桌面小工具活学活用!

macOS Sonoma 虽然不算是很大型的改版,但当中触目的新功能是「桌面小工具」(Widget)。如果我们的萤幕够大,将能够放更多不同的Widget,令用户无须开App 就能显示资讯,实在相当方便。 所有iPhone Widget 也能…

基于Springboot服装商品管理系统免费分享

基于Springboot服装商品管理系统 作者: 公众号(擎云毕业设计指南) 更多毕设项目请关注公众号,获取更多项目资源。如需部署请联系作者 注:禁止使用作者开源项目进行二次售卖,发现必究!!! 运行环境&…

controller调用service层报错Invalid bound statement (not found)

报错信息: "Invalid bound statement (not found): com.gelei.system.service.TbUserFollowService.getMyUserFanList" 这个问题就很神奇,请看下图,我测试的时候就是这么个情况; 综上所述,解决方法如下&…

pragma once与ifndef的区别

概要 代码编译过程中,为了防止同一份代码被重复引用,通常有两种实现方式 方式一 #pragma once 方式二 #ifndef _TEST_H_ #define _TEST_H_ #endif // !TEST_H 通常情况下,使用上述两种方式中的任意一种都是可以的。最近工作中,代…

阿里云ECS服务器的搭建学习

云服务器ECS: 云服务器(Elastic Compute Service,简称ECS)是阿里云提供的性能卓越、稳定可靠、弹性扩展的IaaS(Infrastructure as a Service)级别云计算服务。云服务器ECS免去了您采购IT硬件的前期准备&a…

直线模组有哪些配件组成的?

直线模组又称线性模组或线性滑台,是自动化设备中重要的传动元件,主要由以下几部分组成: 1、直线导轨:直线导轨又称线性滑轨,是用于直线往复运动场合的重要零部件,它具有比直线轴承更高的额定负载&#xff0…

吉利高端品牌领克汽车携手体验家,重塑智能创新的汽车服务体验

浙江吉利控股集团(以下简称“吉利集团”)始建于1986年,1997年进入汽车行业,一直专注实业,专注技术创新和人才培养,坚定不移地推动企业转型升级和可持续发展。现资产总值超5100亿元,员工总数超过…

【内网击穿工具 】NATAPP

内网穿透又叫内网映射,功能是把内网IP映射到公网,使公网也能轻松访问所搭建的服务。 内网与外网 外网指的是一个组织或网络中可公开访问的网络,即对外开放的网络。外网可以通过公共互联网进行访问 内网是相对于外网而言的,指的…

十四、Django框架使用

目录 一、框架简介二、MVT模型简介三、Python的虚拟环境3.1 安装virtualenv 虚拟环境3.2 创建和使用虚拟环境四、Django项目的搭建4.1 安装Django包4.2 创建Django项目4.3 创建Django项目的应用4.4 使用pycharm打开Django项目4.5 注册Django项目的应用4.6 启动Django项目五、OR…

Guava-RateLimiter详解

简介: 常用的限流算法有漏桶算法和令牌桶算法,guava的RateLimiter使用的是令牌桶算法,也就是以固定的频率向桶中放入令牌,例如一秒钟10枚令牌,实际业务在每次响应请求之前都从桶中获取令牌,只有取到令牌的请…

聚观早报 | 荣耀Play8T上市;阿芙“超级品牌日”上线

【聚观365】10月19日消息 荣耀Play8T上市 阿芙“超级品牌日”上线 特斯拉家庭充电服务包更新 TikTok Shop印尼站关停 高通与谷歌合作开发RISC-V芯片 荣耀Play8T上市 3月28日,荣耀推出了荣耀Play 7T系列手机,其最大的卖点就是搭载了6000mAh大电池&a…