paddle实现手写数字模型(一)

news2024/10/7 20:33:11
  1. 参考文档:paddle官网文档
  2. 环境:Python 3.12.2 ,pip 24.0 ,paddlepaddle 2.6.0
    python -m pip install paddlepaddle==2.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
  3. 调试代码如下:
    LeNet.py
import paddle
import paddle.nn.functional as F

class LeNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1,out_channels=6,kernel_size=5,stride=1,padding=2)
        self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1,stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x

train.py


import paddle
from paddle.vision.transforms import Compose,Normalize,ToTensor
import paddle.vision.transforms as T  

import numpy as np
import matplotlib.pyplot as plt
from paddle.metric import Accuracy

from LeNet import LeNet
from PIL import Image



print(paddle.__version__)
transform = Compose([Normalize(mean=[127.5],std=[127.5],data_format='CHW')])
print('下载和加载训练数据...')
train_dataset = paddle.vision.datasets.MNIST(mode='train',transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test',transform=transform)
print('load finished')

train_data0,train_label_0 = train_dataset[0][0],train_dataset[0][1]
train_data0 = train_data0.reshape([28,28])
plt.figure(figsize=(2,2))
plt.imshow(train_data0,cmap=plt.cm.binary)
#plt.show()
print('train_data0 label is: '+str(train_label_0))


model = paddle.Model(LeNet())   # 用Model封装模型
optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

# 配置模型
print('配置模型...')
model.prepare(
    optim,
    paddle.nn.CrossEntropyLoss(),
    Accuracy()
    )
# 训练模型
print('训练模型...')
model.fit(train_dataset,
        epochs=2,
        batch_size=64,
        verbose=1
        )
# 保存模型  
model.save('./model/mnist_model')  # 默认保存模型结构和参数 

#预测模型
print('预测模型...')
model.evaluate(test_dataset, batch_size=64, verbose=1)


predicted.py


import paddle

import numpy as np

from LeNet import LeNet
from PIL import Image

# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):
    # 从img_path中读取图像,并转为灰度图
    im = Image.open(img_path).convert('L')
    #plt.imshow(im,cmap='gray')
    # print(np.array(im))
    im = im.resize((28, 28), Image.Resampling.LANCZOS)
    im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
    # 图像归一化,保持和数据集的数据范围一致
    im = 1 - im / 255 
    return im

# 加载训练好的模型参数
model = LeNet()
model.load_dict(paddle.load('./model/mnist_model.pdparams'))

# 设置模型为评估模式
model.eval()

# 准备一个MNIST样例图像
example_image = load_image("d:/8.png")

# 转换为Tensor并进行推理
with paddle.no_grad():
    example_tensor = paddle.to_tensor(example_image)
    prediction = model(example_tensor)
    print(prediction)

# 获取预测类别
predicted_class = np.argmax(prediction.numpy(), axis=1)[0]
print(f"Predicted class: {predicted_class}")

说明:先通过执行train.py训练数据集,将模型保存在model文件夹中,
然后运行predicted.py加载训练出来的数据集,推理出d:/8.png图片的结果。
结果图片如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1577418.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

初学python记录:力扣1600. 王位继承顺序

题目: 一个王国里住着国王、他的孩子们、他的孙子们等等。每一个时间点,这个家庭里有人出生也有人死亡。 这个王国有一个明确规定的王位继承顺序,第一继承人总是国王自己。我们定义递归函数 Successor(x, curOrder) ,给定一个人…

数据结构——二叉树——二叉搜索树(Binary Search Tree, BST)

目录 一、98. 验证二叉搜索树 二、96. 不同的二叉搜索树 三、538. 把二叉搜索树转换为累加树 二叉搜索树:对于二叉搜索树中的每个结点,其左子结点的值小于该结点的值,而右子结点的值大于该结点的值 一、98. 验证二叉搜索树 给你一个二叉树的…

GAMES Webinar 317-渲染专题-图形学 vs. 视觉大模型|Talk+Panel形式

两条路线:传统渲染路线,生成路线 两种路线的目的都是最终生成图片或者视频等在现在生成大火的情况下,传统路线未来该如何发展呢,两种路线是否能够兼容呢 严令琪 这篇工作是吸取这两条路各自优势的一篇工作 RGB是一张图&#xff…

好用的AI智能工具:AI写作、AI绘画、AI翻译全都有

在科技不断进步的今天,人工智能(AI)已经成为我们日常生活中不可或缺的一部分。它不仅在各个领域都有应用,还为我们提供了许多方便快捷的工具。对此,小编今天推荐7款人工智能软件,AI写作、AI绘画、AI翻译全都…

Vue - 你知道Vue组件之间是如何进行数据传递的吗

难度级别:中级及以上 提问概率:85% 这道题还可以理解为Vue组件之间的数据是如何进行共享的,也可以理解为组件之间是如何通信的,很多人叫法不同,但都是说的同一个意思。我们知道,在Vue单页面应用项目中,所有的组件都是被嵌套在App.vue内…

2024/4/1—力扣—BiNode

代码实现: /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/void convertBiNode_pro(struct TreeNode *root, struct TreeNode **p) {if (root) {convertBiNode_pro(roo…

Git - 如何重置或更改 Git SSH 密钥的密码?

Git 使用 ssh 方式拉取代码时,报 ssh password login,提示输入密码,这时很容易误填为 Git 的登录密码,其实这时需要输入 SSH 证书的密码,下面直接提供更改以及重新导入证书的方式。 首先需要确认你的本地是否有 SSH 钥…

HIS系统是什么?一套前后端分离云HIS系统源码 接口技术RESTful API + WebSocket + WebService

HIS系统是什么?一套前后端分离云HIS系统源码 接口技术RESTful API WebSocket WebService 医院管理信息系统(全称为Hospital Information System)即HIS系统。 常规模版包括门诊管理、住院管理、药房管理、药库管理、院长查询、电子处方、物资管理、媒体管理等&…

与汇智知了堂共舞,HW行动开启你的网络安全新篇章!

**网安圈内一年一度的HW行动来啦! ** 招募对象 不限,有HW项目经验 或持有NISP二级、CISP证书优先 HW时间 以官方正式通知为准 工作地点:全国 薪资待遇 带薪HW (根据考核成绩500-4000元/天不等) 招募流程 1.填写报名…

中科数安 || 公司电脑文件资料防泄密系统

#公司电脑文件资料防泄密# 中科数安推出的公司电脑文件资料防泄密系统,是一款专为企业电脑终端设计的数据安全解决方案,旨在全方位保护公司电脑中存储、处理、传输的各类文件资料免遭非法窃取、泄露或滥用。 中科数安 || 文件数据资料防泄密软件 PC地址…

第二十五周代码(蓝桥杯查缺补漏)

2024/03/31 周日 填充 题目链接 【参考代码】 想用暴力&#xff0c;没过 //枚举&#xff0c;未出结果QAQ #include <bits/stdc.h> using namespace std; string s00 "00"; string s11 "11"; int ans 0; //m个问号&#xff0c;子串有2^m…

如何本地搭建Discuz论坛并实现无公网IP远程访问

文章目录 前言1.安装基础环境2.一键部署Discuz3.安装cpolar工具4.配置域名访问Discuz5.固定域名公网地址6.配置Discuz论坛 前言 Crossday Discuz! Board&#xff08;以下简称 Discuz!&#xff09;是一套通用的社区论坛软件系统&#xff0c;用户可以在不需要任何编程的基础上&a…

基于velero和minio实现k8s数据的备份

1.30部署minio rootk8s-harbor:/etc/kubeasz/clusters/k8s-cluster1# docker run \ -d --restartalways -p 9000:9000 -p 9090:9090 –name minio -v /data/minio/data:/data -e “MINIO_ROOT_USERadmin” -e “MINIO_ROOT_PASSWORD12345678” quay.io/minio/minio server…

Golang | Leetcode Golang题解之第9题回文数

题目&#xff1a; 题解&#xff1a; func isPalindrome(x int) bool {// 特殊情况&#xff1a;// 如上所述&#xff0c;当 x < 0 时&#xff0c;x 不是回文数。// 同样地&#xff0c;如果数字的最后一位是 0&#xff0c;为了使该数字为回文&#xff0c;// 则其第一位数字也…

2024Spring> HNU-计算机系统-实验2-datalab-导引

前言 datalab考验对于位运算以及浮点数存储的理解&#xff0c;如果真的肯花时间去搞懂&#xff0c;对计算机系统存储的理解真的能上一个台阶。与课程考试关联性上来说不是很大&#xff0c;但对于IEEE的浮点数表示一定要熟练掌握。 导引 ①实验工具包 要完成的是bits.c中的15个…

Java | Leetcode Java题解之第13题罗马数字转整数

题目&#xff1a; 题解&#xff1a; class Solution {Map<Character, Integer> symbolValues new HashMap<Character, Integer>() {{put(I, 1);put(V, 5);put(X, 10);put(L, 50);put(C, 100);put(D, 500);put(M, 1000);}};public int romanToInt(String s) {int …

Linux中shell脚本的学习第一天,编写脚本的规范,脚本注释、变量,特殊变量的使用等,包含面试题

4月7日没参加体侧的我自学shell的第一天 Shebang 计算机程序中&#xff0c;shebang指的是出现在文本文件的第一行前两个字符 #&#xff01; 1)以#!/bin/sh 开头的文件&#xff0c;程序在执行的时候会调用/bin/sh, 也就是bash解释器 2)以#!/usr/bin/python 开头的文件&#…

Qt通讯录管理系统

在git上面找的一个操作文件的qt通讯录管理系统,尝试将它复刻了一下. 成果展示 分两个txt文件存储,一个是手机联系人,一个是电话卡联系人,主要功能就是增删改查,主要使用的是OOP的编程思想. 实现过程 界面布局 考虑设计三个界面,主界面,添加联系人界面和修改联系人的界面.于是添…

深入理解nginx realip模块[上]

目录 1. 引言2. Real IP模块的使用2.1 启用Real IP模块2.2 配置Real IP模块2.2.1 配置指令2.2.2 举例 3. 变量的使用 深入理解nginx realip模块[上] 深入理解nginx realip模块[下] 1. 引言 nginx 的 Real IP 模块用于解决代理服务器转发请求到nginx上时可能出现的 IP 地址问题…

ES入门十一:正排索引和倒排索引

索引本质上就是一种加快检索数据的存储结构&#xff0c;就像书本的目录一下。 为了更好的理解正排索引和倒排索引&#xff0c;我们借由一个 **唐诗宋词比赛&#xff0c;**这个比赛一共有两个项目&#xff1a; 给定诗词名称&#xff0c;背诵整首给诗词中几个词语&#xff0c;让…