python手写数字识别(PaddlePaddle框架、MNIST数据集)

news2025/1/10 18:18:16

python手写数字识别(PaddlePaddle框架、MNIST数据集)

import paddle
import paddle.nn.functional as F
from paddle.vision.transforms import Compose, Normalize

transform = Compose([Normalize(mean=[127.5],
                               std=[127.5],
                               data_format='CHW')])
# 使用transform对数据集做归一化
print('download training data and load training data')
# 使用飞桨框架自带的 paddle.vision.datasets.MNIST 完成mnist数据集的加载
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transform)
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transform)
print('load finished')

# 用paddle.nn下的API,如Conv2D、MaxPool2D、Linear完成卷积神经网络的构建
class CNN(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.conv1 = paddle.nn.Conv2D(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.max_pool1 = paddle.nn.MaxPool2D(kernel_size=2,  stride=2)
        self.conv2 = paddle.nn.Conv2D(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.max_pool2 = paddle.nn.MaxPool2D(kernel_size=2, stride=2)
        self.linear1 = paddle.nn.Linear(in_features=16*5*5, out_features=120)
        self.linear2 = paddle.nn.Linear(in_features=120, out_features=84)
        self.linear3 = paddle.nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.max_pool1(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.max_pool2(x)
        x = paddle.flatten(x, start_axis=1,stop_axis=-1)
        x = self.linear1(x)
        x = F.relu(x)
        x = self.linear2(x)
        x = F.relu(x)
        x = self.linear3(x)
        return x


# 开始对模型进行训练,先构建train_loader,加载训练数据,然后定义train函数,设置好损失函数后,按batch加载数据,完成模型的训练
train_loader = paddle.io.DataLoader(train_dataset, batch_size=128, shuffle=True)
# 加载训练集 batch_size 设为 128
def train(model):
    model.train()
    epochs = 3
    optim = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
    # 用Adam作为优化函数
    print("Training:")
    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader()):
            x_data = data[0]
            y_data = data[1]
            predicts = model(x_data)
            loss = F.cross_entropy(predicts, y_data)
            # 计算损失
            acc = paddle.metric.accuracy(predicts, y_data)
            loss.backward()

            if batch_id % 300 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
            optim.step()
            optim.clear_grad()
model = CNN()
train(model)


# 训练完成后,需要验证模型的效果,此时,加载测试数据集,然后用训练好的模对测试集进行预测,计算损失与精度。
test_loader = paddle.io.DataLoader(test_dataset, places=paddle.CPUPlace(), batch_size=128)
# 加载测试数据集
def test(model):
    model.eval()
    print("Testing:")
    for batch_id, data in enumerate(test_loader()):
        x_data = data[0]
        y_data = data[1]
        predicts = model(x_data)
        # 获取预测结果
        loss = F.cross_entropy(predicts, y_data)
        acc = paddle.metric.accuracy(predicts, y_data)

        if batch_id % 50 == 0:
            print("batch_id: {}, loss is: {}, acc is: {}".format(batch_id, loss.numpy(), acc.numpy()))
test(model)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1678709.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件确认测试详细介绍

在软件开发流程中,确认测试是一个至关重要的环节,它确保软件产品满足预定的需求、性能和质量标准。本文将详细介绍软件确认测试的概念、目的、方法、执行步骤以及其在软件开发周期中的重要性。   一、软件确认测试的概念   软件确认测试,…

SuperBox设计出图的效率提升!新增内门自动开孔和垫高支架图纸输出功能

越来越多的配电箱项目要求带内门,内门不仅可以有效减少外界灰尘、异物进入配电箱内部,保障配电箱正常运行,还能够隔离操作人员意外触摸导电部件,减少触电事故的发生。但是配电箱在配置内门后,会给设计带来更多的要求&a…

[图解]SysML和EA建模住宅安全系统-04

1 00:00:01,200 --> 00:00:04,710 我们首先来看一下需求图的一些要点 2 00:00:05,810 --> 00:00:07,080 需求图用来干什么 3 00:00:07,210 --> 00:00:12,080 用来记录文本形式的一些需求 4 00:00:12,090 --> 00:00:13,480 和需求的素材 5 00:00:14,540 --> …

【GESP】2023年12月图形化三级 -- 小杨做题

小杨做题 【题目描述】 为了准备考试,小杨每天都要做题。第 1 天,小杨做了 a a a 道题,第 2 天,小杨做了 b b b

GPT-4o: 从最难的“大海捞针”基准看起

大模型技术论文不断,每个月总会新增上千篇。本专栏精选论文重点解读,主题还是围绕着行业实践和工程量产。若在阅读过程中有些知识点存在盲区,可以回到如何优雅的谈论大模型重新阅读。另外斯坦福2024人工智能报告解读为通识性读物。若对于如果…

高压无源探头能测整流桥电压吗?

高压无源探头是用于测量高电压电路中信号的一种工具,它不需要外部电源供电。然而,对于测量整流桥电压,需要考虑几个因素以确定是否可以使用高压无源探头。 首先,让我们了解一下整流桥的基本原理。整流桥是一种电路,用…

华为OD机试 - 反射计数 - 矩阵(Java 2024 C卷 200分)

华为OD机试 2024C卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(A卷B卷C卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测试…

人工智能领域向量化技术加速多模态大模型训练与应用

目录 前言1、TextIn文档解析技术1.1、文档解析技术1.2、目前存在的问题1.2.1、不规则的文档信息示例 1.3、合合信息的文档解析1.3.1、合合信息的TextIn文档解析技术架构1.3.2、版面分析关键技术 Layout-engine1.3.3、文档树提取关键技术 Catalog-engine1.3.4、双栏1.3.5、非对称…

【Java基础】集合(1) —— Collection

存储不同类型的对象: Object[] arrnew object[5];数组的长度是固定的, 添加或删除数据比较耗时 集合: Object[] toArray可以存储不同类型的对象随着存储的对象的增加,会自动的扩容集合提供了非常丰富的方法,便于操纵集合相当于容器,可以存储多…

运行npm install时报错“npm ERR! code 1”

目录 一、问题分析 二、解决问题 一、问题分析 有registry淘宝镜像地址过期的问题,改一下地址 npm淘宝镜像过期解决办法-CSDN博客主要问题是node-sass和sass-loader版本冲突 打开cmd,输入"node -v"查看node版本 我的版本是16,应…

电子企业实施数字工厂管理系统会遇到哪些挑战

随着信息技术的飞速发展,数字化转型已成为电子企业提升竞争力、实现可持续发展的关键途径。数字工厂管理系统作为数字化转型的核心部分,旨在通过集成各种信息技术,实现生产过程的自动化、智能化和高效化。然而,电子企业在实施数字…

低成本创业分享,一个不用自己囤货、进货、直播的项目|抖音小店

大家好,我是喷火龙 在抖音上面开店,不仅可以卖自己的商品,还可以卖别人的商品赚差价, 并且不需要你囤货、进货、直播、剪视频,也不需要有粉丝。 这个项目就是抖音小店无货源。 很多朋友对抖音小店无货源模式的玩法…

【自然语言处理】【大模型】DeepSeek-V2论文解析

论文地址:https://arxiv.org/pdf/2405.04434 相关博客 【自然语言处理】【大模型】DeepSeek-V2论文解析 【自然语言处理】【大模型】BitNet:用1-bit Transformer训练LLM 【自然语言处理】BitNet b1.58:1bit LLM时代 【自然语言处理】【长文本…

融资融券最低利率4.0!,融资融券利息计算公式,怎么开通?

融资融券的费率: 融资融券的费率主要包括融资利率和融券费率,这些费率的高低主要取决于证券公司的成本、政策倾向以及投资者的资金量大小。 融资利率方面,多数券商的优惠融资利率在5.5%到7.5%之间,与券商的成本和政策有关。一些…

【车载开发系列】AutoSar中的Port

【车载开发系列】AutoSar中的Port 一. Port概念 AutoSAR 接口定义了 SWC 之间、BSW 模块之间以及 SWC 和 BSW 模块之间交互的信息。AutoSAR 接口通过 SWC 和/或 BSW 模块端口(Port)的形式实现。通过这些端口,SWC 和 BSW 模块之间实现了数据…

Adobe Premiere Pro v24.3.0 解锁版 (领先的视频编辑软件)

Adobe系列软件安装目录 一、Adobe Photoshop PS 25.6.0 解锁版 (最流行的图像设计软件) 二、Adobe Media Encoder ME v24.3.0 解锁版 (视频和音频编码渲染工具) 三、Adobe Premiere Pro v24.3.0 解锁版 (领先的视频编辑软件) 四、Adobe After Effects AE v24.3.0 解锁版 (视…

vue3和vite

vue3 1、vue3使如何实现效率提升的 客户端渲染效率比vue2提升了1.3~2倍 SSR渲染效率比vue2提升了2~3倍 1.1、静态提升 解释&#xff1a; 1. 对于静态节点&#xff08;如&#xff1a;<h1>接着奏乐接着舞</h1>&#xff09;&#xff0c;vue3直接提出来了&#xff…

应用层之 HTTP 协议

HTTP 协议 HTTP (全称为 "超文本传输协议") 是一种应用非常广泛的 应用层协议。所谓 "超文本" 的含义, 就是传输的内容不仅仅是文本(比如 html, css 这个就是文本), 还可以是一些 其他的资源, 比如图片, 视频, 音频等二进制的数据。浏览器获取到网页&#…

了解 Robot Framework :接口自动化测试教程!

开源自动化测试利器&#xff1a;Robot Framework Robot Framework 是一个用于实现自动化测试和机器人流程自动化&#xff08;RPA&#xff09;的开放源代码框架。它由一个名为 Robot Framework Foundation 的组织得到推广&#xff0c;得到了多家领军企业在软件开发中的广泛应用。…

Shopline和Shopify哪个更好?Shopline和Shopify的区别

Shopline和Shopify哪个更好取决于用户面向的市场&#xff0c;面向亚洲市场就更适合有本地化支持的Shopline&#xff0c;而如果希望拓展全球业务&#xff0c;Shopify可能更好。 Shopline和Shopify都是知名的电子商务平台&#xff0c;可以很好的帮助商家搭建和管理在线商店&…