基于飞桨paddle的极简方案构建手写数字识别模型测试代码

news2024/11/13 10:11:18

基于飞桨paddle的极简方案构建手写数字识别模型测试代码
在这里插入图片描述
原始测试图片为255X252的图片
因为是极简方案采用的是线性回归模型,所以预测结果数字不一致
本次预测的数字是 [[3]]
测试结果:

PS E:\project\python> & D:/Python39/python.exe e:/project/python/MNIST.py
10.0.0
2.4.2
图像数据形状和对应数据为: (28, 28)
图像标签形状和对应数据为: (1,) [5]

打印第一个batch的第一个图像,对应标签数字为[5]
epoch_id: 0, batch_id: 0, loss is: [34.4626]
epoch_id: 0, batch_id: 1000, loss is: [7.599941]
epoch_id: 0, batch_id: 2000, loss is: [4.583123]
epoch_id: 0, batch_id: 3000, loss is: [2.8974648]
epoch_id: 1, batch_id: 0, loss is: [3.610869]
epoch_id: 1, batch_id: 1000, loss is: [5.6290216]
epoch_id: 1, batch_id: 2000, loss is: [1.9465038]
epoch_id: 1, batch_id: 3000, loss is: [2.1046467]
epoch_id: 7, batch_id: 2000, loss is: [4.63013]
epoch_id: 7, batch_id: 3000, loss is: [4.4638147]
epoch_id: 8, batch_id: 0, loss is: [3.0043283]
epoch_id: 8, batch_id: 1000, loss is: [1.633965]
epoch_id: 8, batch_id: 2000, loss is: [3.1906333]
epoch_id: 8, batch_id: 3000, loss is: [2.4461133]
epoch_id: 9, batch_id: 0, loss is: [3.9595613]
epoch_id: 9, batch_id: 1000, loss is: [1.3417265]
epoch_id: 9, batch_id: 2000, loss is: [2.3505783]
epoch_id: 9, batch_id: 3000, loss is: [2.0194921]
原始图像shape:  (252, 255)
采样后图片shape:  (28, 28)
result Tensor(shape=[1, 1], dtype=float32, place=Place(cpu), stop_gradient=False,
       [[3.94108272]])
本次预测的数字是 [[3]]
PS E:\project\python>

测试代码如下所示:

#加载飞桨和相关类库
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
import os
import numpy as np
import matplotlib.pyplot as plt
# 导入图像读取第三方库
from PIL import Image,ImageFilter
print(Image.__version__)    #10.0.0
#原来是在pillow的10.0.0版本中,ANTIALIAS方法被删除了,使用新的方法即可Image.LANCZOS
#或降级版本为9.5.0,安装pip install Pillow==9.5.0
print(paddle.__version__)   #2.4.2

#飞桨提供了多个封装好的数据集API,涵盖计算机视觉、自然语言处理、推荐系统等多个领域,
# 帮助读者快速完成深度学习任务。
# 如在手写数字识别任务中,
# 通过paddle.vision.datasets.MNIST可以直接获取处理好的MNIST训练集、测试集,
# 飞桨API支持如下常见的学术数据集:
'''
mnist
cifar
Conll05
imdb
imikolov
movielens
sentiment
uci_housing
wmt14
wmt16
'''

#数据处理
# 设置数据读取器,API自动读取MNIST数据训练集
train_dataset = paddle.vision.datasets.MNIST(mode='train')

train_data0 = np.array(train_dataset[0][0])
train_label_0 = np.array(train_dataset[0][1])

# 显示第一batch的第一个图像
'''
import matplotlib.pyplot as plt
plt.figure("Image") # 图像窗口名称
plt.figure(figsize=(2,2))
plt.imshow(train_data0, cmap=plt.cm.binary)
plt.axis('on') # 关掉坐标轴为 off
plt.title('image') # 图像题目
plt.show()
'''

print("图像数据形状和对应数据为:", train_data0.shape)                          #(28, 28)
print("图像标签形状和对应数据为:", train_label_0.shape, train_label_0)         #(1,) [5]
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(train_label_0))   # [5]

#飞桨将维度是28×28的手写数字图像转成向量形式存储,
# 因此使用飞桨数据加载器读取到的手写数字图像是长度为784(28×28)的向量。

#模型设计
#模型的输入为784维(28×28)数据,输出为1维数据,

# 定义mnist数据识别网络结构,同房价预测网络
#===========================================
class MNIST(paddle.nn.Layer):
    def __init__(self):
        super(MNIST, self).__init__()
        
        # 定义一层全连接层,输出维度是1
        self.fc = paddle.nn.Linear(in_features=784, out_features=1)
        
    # 定义网络结构的前向计算过程
    def forward(self, inputs):
        outputs = self.fc(inputs)
        return outputs
#===========================================

#训练配置
# 声明网络结构
model = MNIST()
def train(model):
    # 启动训练模式
    model.train()
    # 加载训练集 batch_size 设为 16
    train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), 
                                        batch_size=16, 
                                        shuffle=True)
    # 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001
    opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
#===========================================
# 图像归一化函数,将数据范围为[0, 255]的图像归一化到[0, 1]
def norm_img(img):
    # 验证传入数据格式是否正确,img的shape为[batch_size, 28, 28]
    assert len(img.shape) == 3
    batch_size, img_h, img_w = img.shape[0], img.shape[1], img.shape[2]
    # 归一化图像数据
    img = img / 255
    # 将图像形式reshape为[batch_size, 784]
    img = paddle.reshape(img, [batch_size, img_h*img_w])
    
    return img  
#===========================================   
import paddle
# 确保从paddle.vision.datasets.MNIST中加载的图像数据是np.ndarray类型
paddle.vision.set_image_backend('cv2')

# 声明网络结构
model = MNIST()
#===========================================
def run(model):
    # 启动训练模式
    model.train()
    # 加载训练集 batch_size 设为 16
    train_loader = paddle.io.DataLoader(paddle.vision.datasets.MNIST(mode='train'), 
                                        batch_size=16, 
                                        shuffle=True)
    # 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001
    opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=model.parameters())
    EPOCH_NUM = 10
    for epoch in range(EPOCH_NUM):
        for batch_id, data in enumerate(train_loader()):
            images = norm_img(data[0]).astype('float32')
            labels = data[1].astype('float32')
            
            #前向计算的过程
            predicts = model(images)
            
            # 计算损失
            loss = F.square_error_cost(predicts, labels)
            avg_loss = paddle.mean(loss)
            
            #每训练了1000批次的数据,打印下当前Loss的情况
            if batch_id % 1000 == 0:
                print("epoch_id: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, avg_loss.numpy()))
            
            #后向传播,更新参数的过程
            avg_loss.backward()
            opt.step()
            opt.clear_grad()
#===========================================
#调用训练            
run(model)
paddle.save(model.state_dict(), './mnist.pdparams')  

#模型测试

#===========================================
def showImage(im):
  #img_path = 'example_0.jpg'
  # 读取原始图像并显示
  #im = Image.open('example_0.jpg')
  plt.imshow(im)
  plt.show()
  # 将原始图像转为灰度图
  im = im.convert('L')
  print('原始图像shape: ', np.array(im).shape)
  # 使用Image.ANTIALIAS方式采样原始图片
  im = im.resize((28, 28), Image.LANCZOS)
  plt.imshow(im)
  plt.show()
  print("采样后图片shape: ", np.array(im).shape)
#===========================================
im = Image.open('example_0.jpg')
showImage(im)

# 读取一张本地的样例图片,转变成模型输入的格式
#=========================================== 
def load_image(img_path):
    # 从img_path中读取图像,并转为灰度图
    im = Image.open(img_path).convert('L')
    # print(np.array(im))
    im = im.resize((28, 28), Image.LANCZOS)
    im = np.array(im).reshape(1, -1).astype(np.float32)
    # 图像归一化,保持和数据集的数据范围一致
    im = 1 - im / 255
    return im
#=========================================== 
# 定义预测过程
def test():
  model = MNIST()
  params_file_path = 'mnist.pdparams'
  img_path = 'example_0.jpg'
  # 加载模型参数
  param_dict = paddle.load(params_file_path)
  model.load_dict(param_dict)
  # 灌入数据
  model.eval()
  tensor_img = load_image(img_path)  
  result = model(paddle.to_tensor(tensor_img))
  print('result',result)
  #  预测输出取整,即为预测的数字,打印结果
  print("本次预测的数字是", result.numpy().astype('int32'))
#=========================================== 
test(); 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/802760.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第五章 数组

定义 数组是一组相同类型元素的集合,但我们需要创建多个相同类型的变量时,只需要创建一个类型的数组,就相当于同时创建很多相同类型的变量。 一维数组 数组如何创建 从定义来入手看一下数组的创建: type_t arr_name[const_n];…

《向量数据库指南》——FAISS和Chroma:两种流行的向量数据库的比较

目录 FAISS Chroma 比较 向量数据库是一种可以存储和检索高维向量数据的数据库,高维向量数据是一种可以表示任何类型数据的A.I原生方式,比如文本、图像、音频等。向量数据库可以用于实现各种基于相似度搜索和聚类的A.I应用,比如语义搜索、推荐系统、图像识别等。在本文中…

Spring Boot——Spring Boot自动配置原理

系列文章目录 Spring Boot启动原理 Spring Boot自动配置原理 系列文章目录前言一、Spring Boot自动配置原理剖析二、自动配置生效三、总结: 前言 一直在使用Spring Boot特别好奇的是为什么Spring Boot比Spring在项目构建和开发过程中要方便很多,无需编…

二叉树的层序遍历(两种方法:迭代+递归)

题目: 给你二叉树的根节点 root ,返回其节点值的 层序遍历 。 (即逐层地,从左到右访问所有节点)。 输入:root [3,9,20,null,null,15,7] 输出:[[3],[9,20],[15,7]] 解题思路:迭代法…

【设计模式——学习笔记】23种设计模式——组合模式Composite(原理讲解+应用场景介绍+案例介绍+Java代码实现)

案例引入 学校院系展示 编写程序展示一个学校院系结构: 需求是这样,要在一个页面中展示出学校的院系组成,一个学校有多个学院,一个学院有多个系 【传统方式】 将学院看做是学校的子类,系是学院的子类,小的组织继承大…

位1的个数,编写一个函数,输入是一个无符号整数(以二进制串的形式),返回其二进制表达式中数字位数为 ‘1‘ 的个数(也被称为汉明重量)。

题记: 编写一个函数,输入是一个无符号整数(以二进制串的形式),返回其二进制表达式中数字位数为 ‘1’ 的个数(也被称为汉明重量)。 提示: 请注意,在某些语言&#xff…

MySQL使用xtrabackup备份和恢复教程

1、xtrabackup说明 xtrabackup是percona开源的mysql物理备份工具。 xtrabackup 8.0支持mysql 8.0版本的备份和恢复。 xtrabackup 2.4支持mysql 5.7及以下版本的备份和恢复。 这里我以xtrabackup 8.0为例讲解备份和恢复的具体操作方法。 xtrabackup 2.4版本的使用上和8.0版本相…

PX4从放弃到精通(二十九):传感器冗余机制

文章目录 前言一、parametersUpdate二、imuPoll三、 put四、 confidence五、 get_best 前言 PX4 1.13.2 一个人可以走的更快,一群人才能走的更远,可加文章底部微信名片 代码的位置如下 PX4冗余机制主要通过传感读数错误计数和传感器的优先级进行选优 …

解决[Vue Router warn]: No match found for location with path “/day“问题

首先是升级vue-router4.0后会警告[Vue Router warn]: No match found for location with path "/day" 找了许久解决方案如下: 一、404页面不需要再异步路由后边添加,直接放到静态路由里即可 二、要注意不能写name,否则会刷新默认…

Parameter ‘roleList‘ not found.

Parameter roleList not found. Available parameters are [arg1, arg0, param1, param2] 多半是Mapper层传入多个参数的时候,没有加Param注解,导致BindException错误

ORA-01187 ORA-01110

ORA-01187: cannot read from file because it failed verification tests ORA-01110: data file 201: ‘/u01/app/oracle/oradata/CNDB/temp01.dbf’ 查询临时文件是存在的 重建临时数据文件 删除临时文件: alter database tempfile /u01/app/oracle/oradata…

56. 合并区间 排序

Problem: 56. 合并区间 文章目录 思路Code 思路 对数组排序,按照左端点从小到大排序。初始化Merged,将第一个区间放入。遍历intervals ,如果当前区间的左端点比merged最后一个区间的右端点大,不重合,直接将该区间加入最后&#xf…

《零基础入门学习Python》第070讲:GUI的终极选择:Tkinter7

上节课我们介绍了Text组件的Indexs 索引和 Marks 标记,它们主要是用于定位,Marks 可以看做是特殊的 Indexs,但是它们又不是完全相同的,比如在默认情况下,你在Marks指定的位置中插入数据,Marks 的位置会自动…

指针的基础应用(数组的颠倒和排序,二维数组的表示)

1.数组的颠倒&#xff1a;若有10个数字&#xff0c;那么数组的颠倒即 a[0]与a[9]交换,a[1]与a[8]交换&#xff0c;a[2]与a[7]交换&#xff0c;......a[4]与a[5]交换&#xff0c;所以到a[4]就颠倒完毕&#xff0c;即 (n-1)/2 若不用指针代码如下 #include<stdio.h>voi…

交互式AI技术与模型部署:使用Gradio完成一项简单的交互式界面

下面的这段代码使用Gradio库创建了一个简单的交互式界面。用户可以输入名称、选择是早上还是晚上、拖动滑动条来选择温度&#xff0c;然后点击"Launch"按钮&#xff0c;界面会显示相应的问候语和摄氏度温度。例如&#xff0c;如果用户输入"John"&#xff0…

iperf3 编译安装及网讯WX1860千兆网口测试

iperf3 编译安装及网讯1860千兆网口测试 编译安装 安装包下载地址:https://github.com/esnet/iperf/archive/refs/tags/3.8.tar.gz 将安装包iperf-3.8.tar.gz拷贝测试系统盘桌面,使用如下命令进行编译安装: tar zxvf iperf-3.8.tar.gz cd iperf-3.8 ./configure make s…

LeetCode-222-完全二叉树的节点个数

一&#xff1a;题目描述&#xff1a; 给你一棵 完全二叉树 的根节点 root &#xff0c;求出该树的节点个数。 完全二叉树 的定义如下&#xff1a;在完全二叉树中&#xff0c;除了最底层节点可能没填满外&#xff0c;其余每层节点数都达到最大值&#xff0c;并且最下面一层的节…

点餐系统测试报告

文章目录 一、项目介绍项目简介功能介绍 二、测试计划1 功能测试功能测试用例发现的 BUG 和 解决方法注册功能上传图片功能 2 自动化测试3 性能测试 一、项目介绍 项目简介 该项目是一个门店点餐系统&#xff0c;采用前后端分离的方式实现&#xff0c;后端框架是SSM&#xff…

R-并行计算

本文介绍在计算机多核上通过parallel包进行并行计算。 并行计算运算步骤&#xff1a; 加载并行计算包&#xff0c;如library(parallel)。创建几个“workers”,通常一个workers一个核&#xff08;core&#xff09;&#xff1b;这些workers什么都不知道&#xff0c;它们的全局环…

第一次后端复习整理(JVM、Redis、反射)

1. JVM 文章仅为自身笔记 详情查看一篇文章掌握整个JVM&#xff0c;JVM超详细解析&#xff01;&#xff01;&#xff01; 1.1 什么是JVM jvm是Java虚拟机 1.2 Java文件的编译过程 程序员编写代码形成.java文件经过javac编译成.class文件再通过JVM的类加载器进入运行时数据…