CNN卷积网络实现MNIST数据集手写数字识别

news2024/12/22 18:59:21

步骤一:加载MNIST数据集

train_data = MNIST(root='./data',train=True,download=False,transform=transforms.ToTensor())
train_loader = DataLoader(train_data,shuffle=True,batch_size=64)
# 测试数据集
test_data = MNIST(root='./data',train=False,download=False,transform=transforms.ToTensor())
test_loader = DataLoader(test_data,shuffle=False,batch_size=64)

首先,通过MNIST类创建了train_data对象,指定了数据集的路径root='./data',并且将数据集标记为训练集train=Truedownload=False表示不自动从网络上下载数据集,而是使用已经下载好的数据集。我是之前自己已经下载过该数据集所以这里填的是False,如果之前没有下载的话就要填True。下面测试集也是一样。transforms.ToTensor()将数据转换为张量形式。

然后,通过DataLoader类创建了train_loader对象,指定了使用train_data作为数据源。shuffle=True表示在每个epoch开始时,将数据打乱顺序。batch_size=64表示每次抓取64个样本。

接下来,同样的步骤也被用来创建了测试集的数据加载器test_loader。不同的是,这里将数据集标记为测试集train=False,并且shuffle=False表示不需要打乱顺序。

加载完的数据集存在MNIST文件夹的raw文件夹下内容如下:

其中t10k-images-idx3-ubyte是测试集的图像,t10k-labels-idx3-ubyte是测试集的标签。train-images-idx3-ubyte是训练集的图像,train-labels-idx1-ubyte是训练集的标签。

存下来的这些数据集是二进制的形式,可以通过下面的代码(1.py)读取:

"""
Created on Sat Jul 27 15:26:39 2024

@author: wangyiyuan
"""
# 导入包
import struct
import numpy as np
from PIL import Image

class MnistParser:
   # 加载图像
   def load_image(self, file_path):

       # 读取二进制数据
       binary = open(file_path,'rb').read()

       # 读取头文件
       fmt_head = '>iiii'
       offset = 0

       # 读取头文件
       magic_number,images_number,rows_number,columns_number = struct.unpack_from(fmt_head,binary,offset)

       # 打印头文件信息
       print('图片数量:%d,图片行数:%d,图片列数:%d'%(images_number,rows_number,columns_number))

       # 处理数据
       image_size = rows_number * columns_number
       fmt_data = '>'+str(image_size)+'B'
       offset = offset + struct.calcsize(fmt_head)

       # 读取数据
       images = np.empty((images_number,rows_number,columns_number))
       for i in range(images_number):
           images[i] = np.array(struct.unpack_from(fmt_data, binary, offset)).reshape((rows_number, columns_number))
           offset = offset + struct.calcsize(fmt_data)
           # 每1万张打印一次信息
           if (i+1) % 10000 == 0:
               print('> 已读取:%d张图片'%(i+1))

       # 返回数据
       return images_number,rows_number,columns_number,images


   # 加载标签
   def load_labels(self, file_path):
       # 读取数据
       binary = open(file_path,'rb').read()

       # 读取头文件
       fmt_head = '>ii'
       offset = 0

       # 读取头文件
       magic_number,items_number = struct.unpack_from(fmt_head,binary,offset)

       # 打印头文件信息
       print('标签数:%d'%(items_number))

       # 处理数据
       fmt_data = '>B'
       offset = offset + struct.calcsize(fmt_head)

       # 读取数据
       labels = np.empty((items_number))
       for i in range(items_number):
           labels[i] = struct.unpack_from(fmt_data, binary, offset)[0]
           offset = offset + struct.calcsize(fmt_data)
           # 每1万张打印一次信息
           if (i+1)%10000 == 0:
               print('> 已读取:%d个标签'%(i+1))

       # 返回数据
       return items_number,labels


   # 图片可视化
   def visualaztion(self, images, labels, path):
       d = {0:0, 1:0, 2:0, 3:0, 4:0, 5:0, 6:0, 7:0, 8:0, 9:0}
       for i in range(images.__len__()):
            im = Image.fromarray(np.uint8(images[i]))
            im.save(path + "%d_%d.png"%(labels[i], d[labels[i]]))
            d[labels[i]] += 1
            # im.show()
            
            if (i+1)%10000 == 0:
               print('> 已保存:%d个图片'%(i+1))
               

# 保存为图片格式
def change_and_save():
    mnist =  MnistParser()

    trainImageFile = './train-images-idx3-ubyte'
    _, _, _, images = mnist.load_image(trainImageFile)
    trainLabelFile = './train-labels-idx1-ubyte'
    _, labels = mnist.load_labels(trainLabelFile)
    mnist.visualaztion(images, labels, "./images/train/")

    testImageFile = './train-images-idx3-ubyte'
    _, _, _, images = mnist.load_image(testImageFile)
    testLabelFile = './train-labels-idx1-ubyte'
    _, labels = mnist.load_labels(testLabelFile)
    mnist.visualaztion(images, labels, "./images/test/")


# 测试
if __name__ == '__main__':
    change_and_save()


将这个1.py文件和下载好的数据集放在同一个文件夹下:

新建一个文件夹images,在文件夹images里面新建两个文件夹分别叫test和train。

运行完可以发现train和test里的内容如下:

步骤二:建立模型

class Model(nn.Module):
    def __init__(self):
        super(Model,self).__init__()
        self.linear1 = nn.Linear(784,256)
        self.linear2 = nn.Linear(256,64)
        self.linear3 = nn.Linear(64,10) # 10个手写数字对应的10个输出

    def forward(self,x):
        x = x.view(-1,784) # 变形
        x = torch.relu(self.linear1(x))
        x = torch.relu(self.linear2(x))
        # x = torch.relu(self.linear3(x))
        return x

这里是建立了一个神经网络模型类(Model)。这个模型有三个线性层(linear1、linear2、linear3)。输入维度为784(因为每一张图片的大小是28*28=784),输出维度为256、64、10(因为有十个类)。forward函数定义了模型的前向传播过程,其中x.view(-1, 784)将输入张量x变形为(batch_size, 784)的大小。然后经过三个线性层和relu激活函数进行运算,最后返回输出结果x。

步骤三:训练模型

model = Model()
criterion = nn.CrossEntropyLoss() # 交叉熵损失,相当于Softmax+Log+NllLoss
optimizer = torch.optim.SGD(model.parameters(),0.8) # 第一个参数是初始化参数值,第二个参数是学习率

# 模型训练
# def train():
for index,data in enumerate(train_loader):
        input,target = data # input为输入数据,target为标签
        optimizer.zero_grad() # 梯度清零
        y_predict = model(input) # 模型预测
        loss = criterion(y_predict,target) # 计算损失
        loss.backward() # 反向传播
        optimizer.step() # 更新参数
        if index % 100 == 0: # 每一百次保存一次模型,打印损失
            torch.save(model.state_dict(),"./model/model.pkl") # 保存模型
            torch.save(optimizer.state_dict(),"./model/optimizer.pkl")
            print("损失值为:%.2f" % loss.item())

首先创建了一个模型对象model,一个损失函数对象criterion和一个优化器对象optimizer。然后使用一个for循环遍历训练数据集train_loader,每次取出一个batch的数据。接着将优化器的梯度清零,然后使用模型前向传播得到预测结果y_predict,计算损失值loss,然后进行反向传播和参数更新。每训练100个batch,保存模型和优化器的参数,并打印当前的损失值。

步骤四:保存模型参数

if os.path.exists('./model/model.pkl'):
    model.load_state_dict(torch.load("./model/model.pkl")) # 加载保存模型的参数

在当前文件夹下新建一个名叫model的文件夹。保存步骤三中训练完模型的参数。

步骤五:检验模型


    correct = 0 # 正确预测的个数
    total = 0 # 总数
    with torch.no_grad(): # 测试不用计算梯度
        for data in test_loader:
            input,target = data
            output=model(input) # output输出10个预测取值,其中最大的即为预测的数
            probability,predict=torch.max(output.data,dim=1) # 返回一个元组,第一个为最大概率值,第二个为最大值的下标
            total += target.size(0) # target是形状为(batch_size,1)的矩阵,使用size(0)取出该批的大小
            correct += (predict == target).sum().item() # predict和target均为(batch_size,1)的矩阵,sum()求出相等的个数
        print("准确率为:%.2f" % (correct / total))


参数说明:

  • correct:记录正确预测的个数
  • total:记录总样本数
  • test_loader:测试集的数据加载器
  • input:输入数据
  • target:目标标签
  • output:模型的输出结果
  • probability:最大概率值
  • predict:最大值的下标

过程:

  • 使用torch.no_grad()包装测试过程,表示不需要计算梯度
  • 遍历测试集中的每个数据,获取输入数据和目标标签
  • 将输入数据输入模型,得到模型的输出结果
  • 使用torch.max()函数返回预测结果中的最大概率值和最大值的下标
  • 更新总数和正确预测的个数
  • 最后计算并输出准确率。

步骤六:检测自己的手写数据

if __name__ == '__main__':
    # 自定义测试
    image = Image.open('C:/Users/wangyiyuan/Desktop/20201116160729670.jpg') # 读取自定义手写图片
    image = image.resize((28,28)) # 裁剪尺寸为28*28
    image = image.convert('L') # 转换为灰度图像
    transform = transforms.ToTensor()
    image = transform(image)
    image = image.resize(1,1,28,28)
    output = model(image)
    probability,predict=torch.max(output.data,dim=1)
    print("此手写图片值为:%d,其最大概率为:%.2f" % (predict[0],probability))
    plt.title('此手写图片值为:{}'.format((int(predict))),fontname="SimHei")
    plt.imshow(image.squeeze())
    plt.show()

这里的C:/Users/wangyiyuan/Desktop/20201116160729670.jpg是我自己从网上找的的手写图片。这段代码意思如下:

  1. 打开并读取一张手写图片,图片的路径为'C:/Users/wangyiyuan/Desktop/20201116160729670.jpg'。
  2. 调整图片尺寸为28x28。
  3. 将图片转换为灰度图像,以便后续处理。
  4. 使用transforms.ToTensor()将图片转换为PyTorch张量。
  5. 调整图片尺寸为(1, 1, 28, 28)以适应模型的输入要求。
  6. 将处理后的图片输入模型,获取预测输出。
  7. 通过torch.max函数获得输出中的最大值及其索引,即预测的数字和其概率。
  8. 打印预测的数字和概率。
  9. 在图像上显示预测结果和手写图片。
  10. 展示图像。

步骤七:结果展示

我的原图是:

测试得到的结果为:


损失值为:4.16
损失值为:0.93
损失值为:0.31
损失值为:0.19
损失值为:0.24
损失值为:0.15
损失值为:0.13
损失值为:0.11
损失值为:0.18
损失值为:0.02
此手写图片值为:2,其最大概率为:6.57

----------------------码字不易,请多多关注博主!-----------------------------------------------主程序可以关注博主后,私信秒发-------------------

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1973187.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Project #0 - C++ Primer

知识点 1.pragma once C和C中的一个非标准但广泛支持的预处理指令,用于使当前源文件在单次编译中只被包含一次。 #pragma once class F {}; // 不管被导入多少次,只处理他一次2.explicit C中的一个关键字,它用来修饰只有一个参数的类构造函…

遇到突发事故,您是否能够应对自如?

近期发生的全球性大规模系统技术故障为我们敲响了警钟——仅仅依赖一朵公共云服务存在其固有的脆弱性。全球多地视窗系统因一款安全软件更新而宕机,出现“蓝屏”, 航空、医疗、传媒、金融、零售、物流等多个行业均受影响。这一事件凸显了对强大、多元化云…

forwardRef和useImperativeHandle到底能做啥

线上个官网例子 App.js import { useRef } from react; import MyInput from ./MyInput.js;export default function Form() {const ref useRef(null);function handleClick() {ref.current.focus();// This wont work because the DOM node isnt exposed:// ref.current.sty…

2024年必备技能:智联招聘岗位信息采集技巧全解析

随着大数据时代的发展,精准定位职业机会成为程序员求职的关键。本文将深入解析如何利用Python高效采集智联招聘上的岗位信息,助你在2024年的职场竞争中脱颖而出。通过实战代码示例,揭示网络爬虫背后的秘密,让你轻松掌握这一必备技…

【算法】双指针-OJ题详解1

双指针-OJ题 移动零(点击跳转)原理讲解代码实现 复写零(点击跳转)原理讲解代码实现 快乐数(点击跳转)原理讲解代码实现 盛最多水的容器(点击跳转)原理讲解代码实现 有效三角形的个数…

模式植物构建orgDb数据库 | 以org.Slycompersicum.eg.db为例

原文链接:模式植物构建orgDb数据库 | 以org.Slycompersicum.eg.db为例 本期教程 一步构建模式植物OrgDb数据库 source("../Set_OrgDb_Database.R")# 使用函数 Set_OrgDb_Database(emapper_file "out.emapper_tomato.csv", ## 输入的eggnog结果文件json_…

使用 MinIO、Langchain 和 Ray Data 构建分布式嵌入式子系统

嵌入子系统是实现检索增强生成所需的四个子系统之一。它将您的自定义语料库转换为可以搜索语义含义的向量数据库。其他子系统是用于创建自定义语料库的数据管道,用于查询向量数据库以向用户查询添加更多上下文的检索器,最后是托管大型语言模型 &#xff…

Stream 33

package Array.collection;import java.util.*; import java.util.stream.Stream;public class stream1 {public static void main(String[] args) {//、如何茯取List集合的Stream流?List<String> names new ArrayList<>();Collections. addAll(names,"方法…

超声波眼镜清洗机哪个品牌好?四款高性能超声波清洗机测评剖析

对于追求高生活质量的用户来说&#xff0c;眼镜的清洁绝对不能马虎。如果不定期清洁眼镜&#xff0c;时间久了&#xff0c;镜片的缝隙中会积累大量的灰尘和细菌&#xff0c;眼镜靠近眼部&#xff0c;对眼部健康有很大影响。在这种情况下&#xff0c;超声波清洗机显得尤为重要。…

现象:程序没问题,compile成功,在load时,提示prg.dll没找到

现象:程序没问题&#xff0c;compile成功&#xff0c;在load时&#xff0c;提示prg.dll没找到 解决方法&#xff1a;使用新的ATE电脑主机&#xff0c;导致的问题&#xff0c;又换回原来的电脑主机&#xff0c;问题解决。

数据结构与算法--【链表1】力扣练习 || 链表 / 移除链表元素

声明&#xff1a;本文参考代码随想录。 一、链表定义 1、概念 将线性表L(a0,a1,……,an-1)中各元素分布在存储器的不同存储块&#xff0c;称为结点&#xff0c;通过地址或指针建立元素之间的联系。 每一个结点由两部分组成&#xff1a;数据域和指针域。结点的data域存放数据…

第三届Apache Flink 极客挑战赛暨AAIG CUP比赛攻略_大浪813团队

关联比赛: 第三届 Apache Flink 极客挑战赛暨AAIG CUP——电商推荐“抱大腿”攻击识别 第三届Apache Flink 极客挑战赛暨AAIG CUP比赛攻略_大浪813团队 第三届Apache Flink 极客挑战赛暨AAIG CUP 自2021年8月17日上线以来已有 4537 个参赛队伍报名。11月09号&#xff0c;大赛…

Android入门之路 - WebView加载数据的几种方式

之前客户端加载H5时遇到了一些问题&#xff0c;我为了方便解决问题&#xff0c;所以将对应场景复刻到了Demo中&#xff0c;从之前的网络加载模拟为了本地加载Html的方式&#xff0c;但是没想到无意被一个基础知识点卡了一些时间&#xff0c;翻看往昔笔记发现未曾记录这种基础场…

C语言初阶(10)

1.野指针 野指针就是指向未知空间的指针&#xff0c;有以下几种情况 &#xff08;1)指针未初始化 int main() {int a0;int*b;return 0; } 上面指针就是没有初始化&#xff0c;形成一种指向一个随机空间的地址的指针&#xff0c;我们可以修改成 int main() {int a0;int*bNU…

甘肃雀舌面:舌尖上的独特韵味

雀舌面&#xff0c;顾名思义&#xff0c;其面条形状如同雀舌般小巧精致。这一独特的形态并非偶然所得&#xff0c;而是源于精湛的手工技艺。制作雀舌面&#xff0c;对面粉的选择和面团的揉制有着极高的要求。经验丰富的师傅会精心挑选优质面粉&#xff0c;加入适量的水&#xf…

嵌入式学习---DAY17:共用体与位运算

链表剩余的一些内容 一、共用体 union 共用体名 名称首字母大写 { 成员表列&#xff1b; }&#xff1b; union Demo {int i;short s;char c; }; int main(void) {union Demo d;d.i 10;d.s 100;d.c 200;printf("%d\n", sizeof(d)); /…

一起学习LeetCode热题100道(24/100)

24.回文链表(学习) 给你一个单链表的头节点 head &#xff0c;请你判断该链表是否为 回文链表 。如果是&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,2,1] 输出&#xff1a;true 示例 2&#xff1a; …

鸿蒙Scroll布局,横向与纵向

注意&#xff0c;当横向scroll时&#xff0c;直接子元素的宽&#xff0c;不能100%&#xff0c; 当纵向scroll时&#xff0c;直接子元素的高&#xff0c;不能100%​​​​​​​ 1、纵向代码&#xff1a; 方法1&#xff1a;用数值计算&#xff0c;来设置中间的高度&#xff1a; …

Django函数视图和类视图

函数视图 1.全局环境的urls.py引入映入应用的urls&#xff0c;避免后期开发路由过多而导致杂乱 from django.contrib import admin from django.urls import path, includeurlpatterns [path(account/, include(account.urls)),#使用include函数引入&#xff0c;表示account…

搜狗爬虫(www.sogou.com)IP及UA,真实采集数据

一、数据来源&#xff1a; 1、这批搜狗爬虫&#xff08;www.sogou.com&#xff09;IP来源于尚贤达猎头网站采集数据&#xff1b; ​ 2、数据采集时间段&#xff1a;2023年10月-2024年7月&#xff1b; 3、判断标准&#xff1a;主要根据用户代理是否包含“www.sogou.com”和IP核实…