基于Paddle的手写数字识别模型

news2025/1/15 6:50:20

  百度飞桨(paddlepaddle)是百度的开源深度学习平台,今天就利用paddle来编写入门级的手写数字模型.

一,准备数据

  1. 下载数据集,这里我们使用的是MNIST数据集

    # 下载原始的 MNIST 数据集并进行解压
    wget https://paddle-imagenet-models-name.bj.bcebos.com/data/mnist.tar
    tar -xf mnist.tar
    

    数据集的目录格式如下:

    mnist.
    ├── train
    │   └── imgs
    │       ├── 0
    │       ├── 1
    │       ├── 2
    │       ├── 3
    │       ├── 4
    │       ├── 5
    │       ├── 6
    │       ├── 7
    │       ├── 8
    │       └── 9
    └── val
       └── imgs
            ├── 0
            ├── 1
            ├── 2
            ├── 3
            ├── 4
            ├── 5
            ├── 6
            ├── 7
            ├── 8
            └── 9
    

    train和val目录下均有一个标签文件label.txt

  2. 定义数据加载模块,这里使用百度paddle提供paddle.io.Dataset来实现自定义的MyDataSet:
    class MyDataSet(Dataset):
    
        def __init__(self, data_dir, label_path, tansform=None):
            super(MyDataSet, self).__init__()
            self.data_list = []
    
            with open(label_path, encoding='utf-8') as f:
    
                for line in f.readlines():
                    image_path, label = line.split('\t')
                    image_path = os.path.join(data_dir, image_path)
                    self.data_list.append([image_path, label])
    
            self.tansform = tansform
    
        def __getitem__(self, index):
    
            image_path, label = self.data_list[index]
            image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
            image = image.astype('float32')
    
            # 应用图像变换
            if self.tansform is not None:
                self.tansform(image)
    
            label = int(label)
            return image, label
    
        def __len__(self):
            return len(self.data_list)

二,模型实现

这里我们参照LeNet进行实现,下面看一下LeNet的网络结构

 定义一个MyNet:

class MyNet(nn.Layer):
    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        # 定义
        self.conv1 = nn.Conv2D(1, 6, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2D(6, 16, 5, stride=1, padding=0)
        self.relu = nn.ReLU()

        self.features = nn.Sequential(
            self.conv1,
            self.relu,
            nn.MaxPool2D(2, 2),
            self.conv2,
            self.relu,
            nn.MaxPool2D(2, 2))

        if num_classes > 0:
            self.linear = nn.Sequential(
                nn.Linear(400, 120),
                nn.Linear(120, 84),
                nn.Linear(84, num_classes)
            )

    def forward(self, x):
        x = self.features(x)
        if self.num_classes > 0:
            x = paddle.flatten(x, 1)
            x = self.linear(x)
        return x

三,模型训练

# 定义一个网络
model = MyNet()
# 可视化模型组网结构和参数
params_info = paddle.summary(model, (1, 1, 28, 28))
print(params_info)

这里定义一个MyNet的网络,然后查看网络的结构,下面开始进行模型训练

total_epoch = 5
    batch_size = 16

    # transform = F.normalize(mean=[127.5], std=[127.5], data_format=['CHW'])
    transform = Normalize(mean=[127.5], std=[127.5], data_format=['CHW'])

    # 训练集
    data_dir_train = './mnist/train'
    label_path_train = './mnist/train/label.txt'
    

    # 加载数据
    train_dataset = MyDataSet(data_dir_train, label_path_train, transform)
    val_dataset = MyDataSet(data_dir_val, label_path_val, transform)
    print(f'训练图片张数:{len(train_dataset)} 测试集图张数:{len(val_dataset)}')
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    optim = paddle.optimizer.Adam(parameters=model.parameters())
    # 设置损失函数
    loss_fn = paddle.nn.CrossEntropyLoss()
    for epoch in range(total_epoch):
        for batch_id, data in enumerate(train_loader):

            x_data = data[0]  # 训练数据
            y_data = data[1]  # 训练数据标签
            # print(y_data)
            # print(y_data.shape)
            # 增加维度
            x_data = paddle.unsqueeze(x_data, axis=1)
            predicts = model(x_data)  # 预测结果
            # print(f'predicts:{predicts} predicts.shape={predicts.shape}')
            y_data = paddle.unsqueeze(y_data, axis=1)
            # 计算损失 等价于 prepare 中loss的设置
            loss = loss_fn(predicts, y_data)

            # 计算准确率 等价于 prepare 中metrics的设置
            acc = paddle.metric.accuracy(predicts, y_data)

            # 下面的反向传播、打印训练信息、更新参数、梯度清零都被封装到 Model.fit() 中
            # 反向传播
            loss.backward()
            # print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id + 1, loss.numpy(),
            #                                                                 acc.numpy()))

            if (batch_id + 1) % 100 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id + 1, loss.numpy(),
                                                                                acc.numpy()))
                write_to_log("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id + 1, loss.numpy(),
                                                                                acc.numpy()))
            # 更新参数
            optim.step()
            # 梯度清零
            optim.clear_grad()
        paddle.save(model.state_dict(), f'./mynet/mynet.ep{epoch}.pdparams')

注意输入的数据的维度要与网络结构保持一致 

四,模型推理

下面来看一下模型的效果:

输出结果:

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/24336.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

12.数组的初始化和引用

数组的初始化 定义数组的时候,顺便给数组的元素赋予初值,即开辟空间的同时并且给数组元素赋值 一维数组的初始化 a. 全部初始化 int a[5] {2,4,7,8,5}; 代表的意思:a[0] 2 , a[1] 4 , a[2] 7 , a[3] 8, a[4] 5; b. 部分初始化 int …

Clever Internet Suite for Delphi, C++Builder

为Internet应用程序添加即时SSL/TLS安全性,并实现许多有用的Internet相关功能。 聪明的互联网套件允许您添加下载、上传和提交互联网资源;发送和接收MIME消息;HTTP、FTP、SMTP、POP3、IMAP和NNTP客户端/服务器解决方案;带有数字证书的SSL/TLS通道支持您的VCL应用程序…

电脑分辨率怎么调?电脑分辨率怎么调合适

​无论是笔记本电脑的用户,还是说台式电脑的用户,在使用电脑的时候,如果电脑分辨率调整的不对,很容易造成显示与观感方面的模糊。电脑分辨率怎么调?电脑分辨率怎么调最佳?本篇文章,小编就来教教…

ASEMI肖特基二极管1N5822参数,1N5822特征,1N5822应用

编辑-Z ASEMI肖特基二极管1N5822参数: 型号:1N5822 最大重复峰值反向电压(VRRM):40V 最大RMS电桥输入电压(VRMS):28V 最大直流阻断电压(VDC)&#xff1a…

三、简单了解kafka设计原理

系列文章目录 文章目录系列文章目录一、Kafka核心总控制器Controller二、kafka高性能简单理解一、Kafka核心总控制器Controller 在Kafka集群中会有一个或者多个broker,其中有一个broker会被选举为控制器(Kafka Controller),它负责…

[ZJCTF 2019]Login--动态调试--详细版

前言 主要是因为太菜了,看了别人的exp,还是懵懵懂懂的,都是静态分析,不明白为会在改密码的时候会导致最后的getshell。今天给它动态分析整一个,看看到底哪里出错了。 基本原理 网上有很多介绍的,在这里说…

Linux学习——01 gcc编译器

一、程序构建过程 高级语言的代码无法被计算机执行,需要将高级语言代码编译成汇编语言,然后再将汇编语言翻译成机器指令,最后通过链接生成最后的可执行文件,此时该文件才可以被计算机执行。总共有四步: 1.1 预编译&a…

[02] BLEMotion-Kit 基于QMI8658传感器使用加速度计进行倾斜检测

文章目录1. 先修知识2. 原理(单轴为例)2.1 单轴倾斜2.2 双轴倾斜2.3 三轴倾斜1. 先修知识 2. 原理(单轴为例) 首先我们要知道的是:当目标轴(本例中为X轴)与地球表面平行时,传感器处于 0g 场。顺时针或逆时针旋转90 将…

springboot+java大学生西部计划志愿者岗位补助管理系统

本课题要求实现一套大学生西部计划管理系,系统主要包括系统个人中心、志愿者管理、岗位信息管理、补助信息管理、交流论坛、系统管理等功能模块。 为完善志愿者、岗位信息,应当建立健全志愿者的补助和管理机制,建立有效的激励机制&#xff0c…

Android Studio无法连接设备,一直显示Loading Devices...

不知道什么时候做了啥,从某个时间点之后,电脑就特别容易断开adb,有时候重启电脑都不管用。 一直显示"Loading Devices...",拔插设备,重启Android Studio都没用,甚至重启电脑有时候也不行。 反正…

全部售罄!1,000 多个Sports Land NFT 在 24 小时内被抢空!

现在还来得及,抓紧时间!👀 在不到24小时的时间里,来自《Sports Land:足球爱好者》作品集(2022 年 11 月 16 日发布)的1000 多个可穿戴 NFT 已被售出! 祝贺 Hermit Crab Game Studio …

bootstrap学习(一)

(1)bootstrap第一个程序 (2)bootstrap排版 (1)bootstrap第一个程序 创建boot文件夹方置bootstrap所需要的文件目录,拷贝过来 创建base目录,创建html页面: 引入css&#…

python复杂网络分析库NetworkX

文章目录1.Networkx简介2.图的类型(Graphs)3.图的创建(Graph Creation)4.图的属性(Graph Reporting)5.图算法(Algorithms)6.图的绘制(Drawing)7.数据结构8.图…

A股api交易接口文档怎么使用?

A股api交易接口是在股票量化交易中常用到的一种量化工具,对于它的用法,小编针对性的以文档的例子说明: 交易接口API 功能概述: 名称 功能 基本函数 Init API 初始化 Deinit API 反初始化 Logon 登录交易账户 Logoff 登…

年底了,接个大活儿,做一个回顾公司五年发展的总结ppt,要求做成H5网页

公司想做个五年总结 这不快年底了么,公司高层打算把这五年的发展历程做一次回顾巡礼,一方面宣扬一下公司文化,另一方面歌颂一下公司这五年来取得的辉煌成就,单纯的做个海报,写个公众号文章,或整个传统ppt在…

最强大脑记忆曲线(11)—— 30天结束第一轮复习后的操作

对于30天以后,结束第一轮(6次)复习以后,我们要做点什么操作呢? 对第一轮复习效果的评判可以是客观的,也可以是主观的。所谓客观的,是按“复习的正确率”来评判,大于某个值&#xff0…

内部类_Java

作者:爱塔居的博客_CSDN博客-JavaSE领域博主 专栏:JavaSE 文章目录 目录 文章目录 一、内部类的概念 二、内部类的分类 1.静态内部类(被static修饰) 2.非静态内部类 3.局部内部类 4.匿名内部类 一、内部类的概念 当一个事物…

【JVM】jvm的双亲委派机制

双亲委派机制一、JVM体系结构二、双亲委派机制的含义三、双亲委派机制的源代码四、双亲委派机制的意义五、示例代码一、JVM体系结构 我们先在这里放一张 JVM 的体系架构图,方便我们有个总体认知。 在了解JVM的双亲委派机制之前,你不得不需要知道的几个…

【Mapbox GL JS 入门】Hello world

目录Mapbox GL JS 简介安装Access tokenHello worldMapbox GL JS 简介 官网:https://www.mapbox.com/ git:https://github.com/mapbox/mapbox-gl-js/ 是一个客户端JavaScript库,为了web开发人员可以在web浏览器中动态绘制地图,在…

pico3pro使用unity播放360全景视频及事件交互

1.准备好全景视频,看起来是这样子的。 2.新建一个Materal 注意选择Shader如上图,Render Queue选择AlphaTest,因为我们要在视频前面放置按钮,UI的渲染值为3000,所以可以避免UI不显示的问题,这样UI会一直显示…