Ai项目十四:基于 LeNet5 的手写数字识别及训练

news2025/1/13 6:07:36

若该文为原创文章,转载请注明原文出处。

一、介绍

pytorch复现lenet5模型,并检测自己手写的数字图片。

利用torch框架搭建模型相对比较简单,但是也会遇到很多问题,网上资料很多,搭建模型的方法大同小异,在我尝试了自己搭建搭建出来模型,无论是训练还是检测都会遇到很多的问题,像这种自己遇到的问题,请教别人也没有用。原本使用的是github上的一份代码来复现,环境搭建完成后,才发现要有GPU,而我搭建是使用CPU,失败告终,为了复现,租用了AutoDL平台,在次搭建,这里记录GPU下的操作,CPU版本需要修改源码,自行修改,我的目的是在要训练自己的模型并在RK3568上部署,所以先训练并测试好。为后续部署作基础。

二、环境

三、搭建

1、创建虚拟环境

 conda create -n LeNet5_env python==3.8

2、安装pytorch

Previous PyTorch Versions | PyTorch

根据官方PyTorch,安装pytorch,使用的是CPU版本,其他版本自行安装,安装命令:

​​​​​​​
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
 -i https://pypi.tuna.tsinghua.edu.cn/simple

还需要安装一些其他的库

pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple

3、数据集下载

http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

直接把上面地址复制到网页上,就只可以下载

下载后保存到data/MNIST/raw目录下

四、训练代码

训练模型有四个文件分别为:LeNet5.py;myDatast.py;readMnist.py;train.py

文件LeNet5.py是网络层模型

train.py

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import DataLoader
from readMnist import *
from myDatast import Mnist
from LeNet5 import LeNet5

train_images = load_train_images()
train_labels = load_train_labels()

trainData = Mnist(train_images, train_labels)
train_data = DataLoader(dataset=trainData, batch_size=1, shuffle=True)

lenet5 = LeNet5()
lenet5.cuda()

lossFun = nn.CrossEntropyLoss()

optimizer = torch.optim.Adam(params=lenet5.parameters(), lr=1e-4)

Epochs = 100
L = len(train_data)

for epoch in range(Epochs):
    for i, (img, id) in enumerate(train_data):

        img = img.float()
        id = id.float()

        img = img.cuda()
        id = id.cuda()

        img = Variable(img, requires_grad=True)
        id = Variable(id, requires_grad=True)

        Output = lenet5.forward(img)
        loss = lossFun(Output, id.long())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        iter = epoch * L + i + 1
        if iter % 100 == 0:
            print('epoch:{},iter:{},loss:{:.6f}'.format(epoch + 1, iter, loss))

    torch.save(lenet5.state_dict(), 'lenet5.pth')

 LeNet5.py

import torch.nn as nn


class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc1 = nn.Sequential(
            nn.Linear(in_features=16 * 4 * 4, out_features=120),
            nn.Sigmoid()
        )

        self.fc2 = nn.Sequential(
            nn.Linear(in_features=120, out_features=84),
            nn.Sigmoid()
        )

        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, img):
        img = self.conv1.forward(img)
        img = self.conv2.forward(img)

        img = img.view(img.size()[0], -1)

        img = self.fc1.forward(img)
        img = self.fc2.forward(img)
        img = self.fc3.forward(img)

        return img

 readMnist.py

from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np


class Mnist(Dataset):
    def __init__(self, dataset, label):
        self.dataset = dataset
        self.label = label
        self.len = len(self.label)
        self.transforms = transforms.Compose([transforms.ToTensor() , transforms.Normalize(mean=[0.5], std=[0.5])])

    def __len__(self):
        return self.len

    def __getitem__(self, item):
        img = self.dataset[item]
        img_id = self.label[item]

        img = np.transpose(img,(1,2,0))
        img = self.transforms(img)

        return img, img_id

readMnist.py

import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2

fpath = 'G:/enpei_Project_Code/21_LeNet5/LeNet5-master/myLeNet5/data/MNIST/raw/'

# 训练集文件
train_images_idx3_ubyte_file = fpath + 'train-images-idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = fpath + 'train-labels-idx1-ubyte'

# 测试集文件
test_images_idx3_ubyte_file = fpath + 't10k-images-idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = fpath + 't10k-labels-idx1-ubyte'


def decode_idx3_ubyte(idx3_ubyte_file):
    """
    解析idx3文件的通用函数
    :param idx3_ubyte_file: idx3文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx3_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽
    offset = 0
    fmt_header = '>iiii'  # 因为数据结构中前4行的数据类型都是32位整型,所以采用i格式,但我们需要读取前4行数据,所以需要4个i。我们后面会看到标签集中,只使用2个ii。
    magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))

    # 解析数据集
    image_size = num_rows * num_cols
    offset += struct.calcsize(fmt_header)  # 获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。
    print(offset)
    fmt_image = '>' + str(
        image_size) + 'B'  # 图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)
    print(fmt_image, offset, struct.calcsize(fmt_image))
    images = np.empty((num_images, 1, num_rows, num_cols))
    # plt.figure()
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
            print(offset)
        images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((1, num_rows, num_cols))
        # print(images[i])
        offset += struct.calcsize(fmt_image)
    #        plt.imshow(images[i],'gray')
    #        plt.pause(0.00001)
    #        plt.show()
    # plt.show()

    return images


def decode_idx1_ubyte(idx1_ubyte_file):
    """
    解析idx1文件的通用函数
    :param idx1_ubyte_file: idx1文件路径
    :return: 数据集
    """
    # 读取二进制数据
    bin_data = open(idx1_ubyte_file, 'rb').read()

    # 解析文件头信息,依次为魔数和标签数
    offset = 0
    fmt_header = '>ii'
    magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)
    print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))

    # 解析数据集
    offset += struct.calcsize(fmt_header)
    fmt_image = '>B'
    labels = np.empty(num_images)
    for i in range(num_images):
        if (i + 1) % 10000 == 0:
            print('已解析 %d' % (i + 1) + '张')
        labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]
        offset += struct.calcsize(fmt_image)
    return labels


def load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):
    """
    TRAINING SET IMAGE FILE (train-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  60000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):
    """
    TRAINING SET LABEL FILE (train-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  60000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.

    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):
    """
    TEST SET IMAGE FILE (t10k-images-idx3-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000803(2051) magic number
    0004     32 bit integer  10000            number of images
    0008     32 bit integer  28               number of rows
    0012     32 bit integer  28               number of columns
    0016     unsigned byte   ??               pixel
    0017     unsigned byte   ??               pixel
    ........
    xxxx     unsigned byte   ??               pixel
    Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

    :param idx_ubyte_file: idx文件路径
    :return: n*row*col维np.array对象,n为图片数量
    """
    return decode_idx3_ubyte(idx_ubyte_file)


def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):
    """
    TEST SET LABEL FILE (t10k-labels-idx1-ubyte):
    [offset] [type]          [value]          [description]
    0000     32 bit integer  0x00000801(2049) magic number (MSB first)
    0004     32 bit integer  10000            number of items
    0008     unsigned byte   ??               label
    0009     unsigned byte   ??               label
    ........
    xxxx     unsigned byte   ??               label
    The labels values are 0 to 9.

    :param idx_ubyte_file: idx文件路径
    :return: n*1维np.array对象,n为图片数量
    """
    return decode_idx1_ubyte(idx_ubyte_file)


if __name__ == '__main__':

    train_images = load_train_images()
    train_labels = load_train_labels()
    test_images = load_test_images()
    test_labels = load_test_labels()

    pass

    # 查看前十个数据及其标签以读取是否正确
    for i in range(10):
        print(train_labels[i])

        img = train_images[i]
        img = np.transpose(img, (1, 2, 0))

        cv2.namedWindow('img')
        cv2.imshow('img', img)
        cv2.waitKey(100)

    print('done')


上面代码需要注意的是数据集的路径,需要修改成对应的路径。

运行python train.py

训练大概5小时

五、测试

from LeNet5 import LeNet5
import torch
from readMnist import *
from myDatast import Mnist
from torch.utils.data import DataLoader
import numpy as np
import cv2

test_images = load_test_images()
test_labels = load_test_labels()

testData = Mnist(test_images, test_labels)
test_data = DataLoader(dataset=testData, batch_size=1, shuffle=True)

lenet5 = LeNet5()
lenet5.load_state_dict(torch.load('lenet5.pth'))
lenet5.eval()

showimg = True
js = 0
for i, (img, id) in enumerate(test_data):

    img = img.float()
    outid = lenet5(img)

    oid = torch.argmax(outid)
    if oid == id:
        js = js + 1

    if showimg == True:
        img = img.numpy()
        img = np.squeeze(img)

        id = id.numpy()
        id = np.squeeze(id)
        id = np.int32(id)

        oid = oid.numpy()
        oid = np.squeeze(oid)

        maxv = np.max(img)
        minv = np.min(img)

        img = (img - minv) / (maxv - minv)

        cv2.namedWindow("img", 0)
        cv2.imshow("img", img)

        title = "img, predicted value:{},truth value:{}".format(oid, id)
        cv2.setWindowTitle("img",title)

        cv2.waitKey(1)

print('准确率:{:.6f}'.format(js / (i + 1)))

测试结果准确率达到0.986基本达到要求 

如有侵权,或需要完整代码,请及时联系博主。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1058117.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

匿名上位机V7波形显示教程-简单能用

匿名上位机V7波形显示教程-简单能用 匿名上位机V7下位机数据格式根据匿名上位机V7的手册说明文档,编写对应的指令在主函数中初始化ANDmessage驱动连接匿名上位机V7 匿名上位机V7下位机数据格式 DATA区域内容: 举例说明DATA区域格式,例如上文&…

Altium Designer 批量添加元器件后缀

Altium Designer 批量添加元器件后缀 方法一方法二可能出现的问题要注意 方法一 您可以使用 Altium Designer 中的“批量修改元器件名称”功能来批量添加元器件后缀。具体步骤如下: 1.为了方便显示 操作流程,我这里复制了几个原理图的文件,粘…

【漏洞复现】用友GPR-U8 slbmbygr SQL注入漏洞

文章目录 一、漏洞描述二、网络空间搜索引擎搜索三、漏洞利用 一、漏洞描述 用友GRP-U8是面向政府及行政事业单位的财政管理应用。北京用友政务软件有限公司GRP-U8 SQL注入漏洞。 ![在这里插入图片描述](https://img-blog.csdnimg.cn/fe260ff4d6d14abeb0e576e4bbf3c385.png 二…

计算机组成原理期末复习

第一章 上机前的准备:建立数学模型、确定计算方法和编制解题程序n位操作码有 2 n 2^n 2n种不同操作主储存器(主存/内存)包括存储体M、各种逻辑部件及控制电路。储存体有多个储存单元,储存单元有多个储存元件,每个存储…

SDL2绘制ffmpeg解析的mp4文件

文章目录 1.FFMPEG利用命令行将mp4转yuv4202.ffmpeg将mp4解析为yuv数据2.1 核心api: 3.SDL2进行yuv绘制到屏幕3.1 核心api 4.完整代码5.效果展示 本项目采用生产者消费者模型,生产者线程:使用ffmpeg将mp4格式数据解析为yuv的帧,消费者线程&am…

latex表格内容换行

问题描述: 在用latex表格中编写公式时,可能出现公式太长,表格中后面的内容不能在文档中呈现,如下图1,故要进行行内内容的换行,使内容呈现完全而传统的\换行后,换行内容会顶格,如图2。 解决方…

PE文件之导入表

1. 导入表 2. 显示导入表信息的例子 ; 作用: 将RVA地址转成FOA即文件偏移 ; 参数: _pFileHdr 指向读到内存中文件的基址指针 ; _dwRVA 目标RVA地址 ; 返回: 目标RVA转成文件偏移的值 RVA2FOA PROC USES esi edi edx, _pFileHdr:PTR BYTE, _dwRVA:DWORDmov esi, _pFil…

饲料微生物检验 采样.

声明 本文是学习GB-T 42959-2023 饲料微生物检验 采样. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本文件规定了以微生物检验为目的的采样原则、采样人员、设备和材料、采样方案、采样步骤和采样 报告。 本文件适用于以微生物检验为目的…

Can‘t pickle <class ‘__main__.Test‘>: it‘s not the same object as __main__.Test

目录 原因1 类名重复了 案例1 变量名和类名重复 原因1 类名重复了 检查项目代码&#xff0c;是不是其他地方有同名类。 案例1 变量名和类名重复 转自&#xff1a;python3报错Cant pickle <class __main__.Test>: its not the same object as __main__.Test解决 - 知乎…

接口日志,统一记录(AOP+自定义注解)

需求 指定接口&#xff0c;记录请求的日志。 接口日志的核心内容包括&#xff1a;请求方法&#xff0c;接口路径&#xff0c;请求参数等。 方案 采用的方案是&#xff1a;AOP 自定义注解 说明&#xff1a; 在需要记录日志的接口上&#xff0c;加上自定义注解ApiLog&…

样品运输与贮存

声明 本文是学习GB-T 42959-2023 饲料微生物检验 采样. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本文件规定了以微生物检验为目的的采样原则、采样人员、设备和材料、采样方案、采样步骤和采样 报告。 本文件适用于以微生物检验为目的…

如何限制文件只能通过USB打印机打印,限制打印次数和时限并且无法在打印前查看或编辑内容

在今天这个高度信息化的时代&#xff0c;文档打印已经成为日常工作中不可或缺的一部分。然而&#xff0c;这也带来了诸多安全风险&#xff0c;如文档被篡改、知识产权被侵犯以及信息泄露等。为了解决这些问题&#xff0c;只印应运而生。作为一款独特的软件工具&#xff0c;只印…

《视觉 SLAM 十四讲》V2 第 4 讲 李群与李代数 【什么样的相机位姿 最符合 当前观测数据】

P71 文章目录 4.1 李群与李代数基础4.1.3 李代数的定义4.1.4 李代数 so(3)4.1.5 李代数 se(3) 4.2 指数与对数映射4.2.1 SO(3)上的指数映射罗德里格斯公式推导 4.2.2 SE(3) 上的指数映射SO(3),SE(3),so(3),se(3)的对应关系 4.3 李代数求导与扰动模型4.3.2 SO(3)上的李代数求导…

S-Clustr(影子集群)僵尸网络@Мартин.

公告 项目地址:https://github.com/MartinxMax/S-Clustr/tree/V1.0.0 1.成功扩展3类嵌入式设备,组建庞大的"僵尸网络" |——C51[开发中] |——Arduino |——合宙AIR780e[开发中] 2.攻击者端与服务端之间通讯过程全程加密,防溯源分析 3.Generate一键自动生成Arduino…

将数组和减半的最少操作【贪心2】

题目&#xff1a;将数组和减半的最少操作 贪心思路&#xff1a;每次挑选最大的数来减半。 解法&#xff1a;贪心大根堆 class Solution { public:int halveArray(vector<int>& nums) {priority_queue<double> heap;double sum 0.0;for(int& x : nums){hea…

Linux性能优化--性能工具:系统内存

3.0.概述 本章概述了系统级的Linux内存性能工具。本章将讨论这些工具可以测量的内存统计信息&#xff0c;以及如何使用各种工具收集这些统计结果。阅读本章后&#xff0c;你将能够&#xff1a; 理解系统级性能的基本指标&#xff0c;包括内存的使用情况。明白哪些工具可以检索…

vtk 动画入门 1 代码

实现效果如图&#xff1a; #include <vtkAutoInit.h> //VTK_MODULE_INIT(vtkRenderingOpenGL2); //VTK_MODULE_INIT(vtkInteractionStyle); VTK_MODULE_INIT(vtkRenderingOpenGL2); VTK_MODULE_INIT(vtkInteractionStyle); //VTK_MODULE_INIT(vtkRenderingFreeType); #in…

lv8 嵌入式开发-网络编程开发 01什么是互联网

目录 1 计算机网络的定义与分类 1.1 按照网络的作用范围进行分类 1.2 按照网络的使用者进行分类 2 网络的网络 2.1 名词解释 2.2 边缘与核心 3 互联网基础结构发展的三个阶段 3.1 第一阶段&#xff1a;1969 – 1990 3.2 第二阶段&#xff1a;1985 – 1993 3.3 第三阶…

Python大数据之PySpark(四)SparkBaseCore

文章目录 SparkBase&Core环境搭建-Spark on YARN扩展阅读-Spark关键概念[了解]PySpark角色分析[了解]PySpark架构后记 SparkBase&Core 学习目标掌握SparkOnYarn搭建掌握RDD的基础创建及相关算子操作了解PySpark的架构及角色 环境搭建-Spark on YARN Yarn 资源调度框…

UE中制作棋盘格材质效果

在UE中通过这个小技巧制作棋盘格材质效果&#xff0c;可以快速预览UV拉伸情况&#xff0c;方便调试导入的模型。 1.操作步骤 1.1 首先新建材质&#xff0c;Shading Model&#xff08;着色模式&#xff09;设置为Unlit&#xff08;无光照&#xff09;&#xff1a; 1.2 我们…