PyTorch手写体数字识别实例

news2024/11/15 14:06:48

MNIST数据集的准备

“HelloWorld”是所有编程语言入门的基础程序,在开始编程学习时,我们打印的第一句话通常就是这个“HelloWorld”。本书也不例外,在深度学习编程中也有其特有的“HelloWorld”,一般就是采用MNIST完成一项特定的深度学习项目。

MNIST是一个手写数字图像数据库,如图2-21所示,它有60 000个训练样本集和10 000个测试样本集。读者可直接使用本书源码库提供的MNIST数据集,它位于配套源码的dataset文件夹中,如图2-22所示。

然后使用NumPy数据库进行数据的读取,代码如下:

import numpy as np

x_train = np.load("./dataset/mnist/x_train.npy")

y_train_label = np.load("./dataset/mnist/y_train_label.npy")

或者读者也可以在网上搜索MNIST的下载地址,下载MNIST文件中包含的数据集train-images-idx3-ubyte.gz(训练图片集)、train-labels-idx1-ubyte.gz(训练标签集)、t10k-images-idx3-ubyte.gz(测试图片集)和t10k-labels-idx1-ubyte.gz(测试标签集),如图2-23所示。

图2-23  MNIST文件中包含的数据集

将下载的4个文件进行解压缩。解压缩后,会发现这些文件并不是标准的图像格式,而是二进制文件,将文件保存到源码可以访问到的目录下。

基于PyTorch的手写体识别

下面我们开始基于PyTorch的手写体识别。通过2.3.4小节的介绍可知,我们还需要定义的一个内容就是深度学习的优化器部分,在这里采用Adam优化器,这部分代码如下:

model = NeuralNetwork()

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

完整的手写体识别首先需要定义模型,然后将模型参数传入优化器中。其中lr是对学习率的设定,根据设定的学习率进行模型计算。完整的手写体识别模型代码如下:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编号
import torch
import numpy as np
from tqdm import tqdm

batch_size = 320            	#设定每次训练的批次数
epochs = 1024              	#设定训练次数

#device = "cpu"    	#PyTorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"   	#这里默认使用GPU,如果出现运行问题,可以将其改成CPU模式


#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28*28,312),
            torch.nn.ReLU(),
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)
        )
    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)

        return logits

model = NeuralNetwork()
model = model.to(device)          	#将计算模型传入GPU硬件等待计算
#model = torch.compile(model)     	#PyTorch 2.0的特性,加速计算速度,选择性使用
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")

train_num = len(x_train)//batch_size

#开始计算
for epoch in range(20):
    train_loss = 0
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        train_batch = torch.tensor(x_train[start:end]).to(device)
        label_batch = torch.tensor(y_train_label[start:end]).to(device)

        pred = model(train_batch)
        loss = loss_fu(pred,label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  #记录每个批次的损失值

    #计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_size
    print("train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

模型的训练结果如图3-5所示。

图3-5  训练结果

可以看到,随着模型循环次数的增加,模型的损失值在降低,而准确率在提高,具体请读者自行验证测试。

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942539.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java二十三种设计模式-代理模式模式(8/23)

代理模式:为对象访问提供灵活的控制 引言 代理模式(Proxy Pattern)是一种结构型设计模式,它为其他对象提供一个代替或占位符,以控制对它的访问。 基础知识,java设计模式总体来说设计模式分为三大类&#…

Ant Design Vue中日期选择器快捷选择 presets 用法

ant写文档的纯懒狗 返回的是一个day.js对象 范围选择时可接受一个数组 具体参考 操作 Day.js 话不多说 直接上代码 <a-range-pickerv-model:value"formData.datePick"valueFormat"YYYY-MM-DD HH:mm:ss"showTime:presets"presets"change&quo…

一、C#概述

本文是网页版《C# 12.0 本质论》第一章解读。欲完整跟踪本系列文章&#xff0c;请关注并订阅我的Essential C# 12.0解读专栏。 前言 第一章的内容非常简单&#xff0c;毕竟仅仅是Introducing C#。不过正如《0.前言》所述&#xff0c;《C# 12.0本质论》本身就不是一本零基础的…

【Redis】主从复制分析-基础

1 主从节点运行数据的存储 在主从复制中, 对于主节点, 从节点就是自身的一个客户端, 所以和普通的客户端一样, 会被组织为一个 client 的结构体。 typedef struct client {// 省略 } client;同时无论是从节点, 还是主节点, 在运行中的数据都存放在一个 redisServer 的结构体中…

S71200 - 笔记

1 S71200 0 ProfiNet - 2 PLC编程 01.如何零基础快速上手S7-1200_哔哩哔哩_bilibili 西门子S7-1200PLC编程设计学习视频&#xff0c;从入门开始讲解_哔哩哔哩_bilibili

Facebook在内容创作中的新策略与机会

随着社交媒体的不断发展&#xff0c;内容创作已经成为了平台吸引和留住用户的核心竞争力。Facebook作为全球最大的社交平台之一&#xff0c;不断调整和优化其内容创作策略&#xff0c;以适应用户需求的变化和技术的进步。本文将深入探讨Facebook在内容创作中的新策略与机会&…

【深度学习】yolov8-det目标检测训练,拼接图的分割复原

项目背景 https://blog.csdn.net/x1131230123/article/details/140606459 似乎这个任务是简单的&#xff0c;利用目标检测是否可以完成得好呢? 生成数据集 利用这个代码产生数据集&#xff1a; 为了将标签转换为YOLOv5格式&#xff0c;需要将左上角和右下角的坐标转换为Y…

websocket实现进度条

websocket实现进度条 做一个简易的websocket实现进度条的练习&#xff0c;效果如下&#xff1a; 前端vue3 <template><el-progress type"circle" :percentage"this.progressValue" :status"this.perstatus" /><el-button cli…

【Python的wxauto】快速入门案例:简单操作微信发送消息

使用wxauto库发送消息是一个相对简单的过程。以下是一个详细的文字教程&#xff0c;以及相应的Python代码示例&#xff0c;指导您如何使用wxauto库发送消息。 文字教程&#xff1a;使用wxauto库发送消息 效果展示 步骤1&#xff1a;环境准备 确保您的计算机上安装了Python…

人工智能增强的心电图推导的身体质量指数作为未来心脏代谢疾病预测指标| 文献-基于人工智能(AI base)医学影像研究与疾病诊断

Title 题目 Artificial intelligence-enhancedelectrocardiography derived body massindex as a predictor of futurecardiometabolic disease 人工智能增强的心电图推导的身体质量指数作为未来心脏代谢疾病预测指标 01 文献速递介绍 心电图&#xff08;ECG&#xff09;可…

深度学习模型Transformer结构

Transformer结构是一种基于自注意力&#xff08;Self-Attention&#xff09;机制的深度学习模型&#xff0c;最初由Vaswani等人在2017年的论文《Attention Is All You Need》中提出&#xff0c;用于解决自然语言处理&#xff08;NLP&#xff09;领域的任务&#xff0c;如机器翻…

五年Java手,竟被一个用MemFire Cloud的前端给秒了

小李是个有五年经验的Java开发工程师&#xff0c;在公司里也算得上是技术大拿。可有一天&#xff0c;他却在一次项目竞赛中被一个刚入行不久的前端新手给秒了。这让他大感意外&#xff0c;不禁自问&#xff1a;“难道我的Java生涯要完了么&#xff1f;” 事情的真相是&#xf…

私密文件的绿色通道,使用极空间Docker部署视频文件加密工具『Alist-encrypt』

私密文件的绿色通道&#xff0c;使用极空间Docker部署视频文件加密工具『Alist-encrypt』 哈喽小伙伴们好&#xff0c;我是Stark-C~ 关于Alist我就不用过多介绍了&#xff0c;作为多网盘存储挂载工具&#xff0c;它不仅支持文件列表全能展示&#xff0c;还可以链接分享与下载…

# Redis 入门到精通(八)-- 服务器配置-redis.conf配置与高级数据类型

Redis 入门到精通&#xff08;八&#xff09;-- 服务器配置-redis.conf配置与高级数据类型 一、redis 服务器配置–redis.conf 配置 1、服务器端设定 1&#xff09;设置服务器以守护进程的方式运行&#xff1a; daemonize yes|no 2&#xff09;绑定主机地址&#xff1a; bin…

【unity小技巧】新输入系统InputSystem重新绑定控制按键(最全最完美解决方案)

文章目录 前言安装InputSystem,并导入重新绑定控制按键例子输入控制拿例子的重绑定按钮预制体绑定对应按钮升级文本新增全屏覆盖的提示文本配置绑定绑定当前启用的输入键禁用一些按钮的绑定和退出按键绑定状态重复绑定按钮问题重置绑定重复按钮修改按钮绑定名字添加两个变量勾选…

【常见开源库的二次开发】基于openssl的加密与解密——MD5算法源码解析(五)

一、MD5算法分析 &#xff1a; 1.1 关于MD5 “消息摘要”是指MD5&#xff08;Message Digest Algorithm 5&#xff09;算法。MD5是一种广泛使用的密码散列函数&#xff0c;它可以生成一个128位&#xff08;16字节&#xff09;的散列值。 RFC 1321: MD5由Ronald Rivest在1992…

Windows 磁盘分区样式有几种?如何查看电脑分区样式?

在使用 Windows 操作系统的过程中&#xff0c;磁盘分区是一个重要的概念。磁盘分区的方式直接影响到数据存储和系统运行的效率。磁盘分区的时候也有不同的样式&#xff0c;你知道分区类型有哪些吗&#xff1f;不同的分区样式决定了硬盘的分区方式、可支持的最大存储容量以及兼容…

某企业网络及服务器规划与设计

目录 1. 项目需求与设计... 5 1.1 项目需求... 5 1.2 组建企业网络内部网的流程... 5 1) 构思阶段... 5 2) 方案设计阶段... 6 3) 工程实施阶段... 6 4) 测试验收... 6 5) 管理维护... 7 1.3 技术可行性分析... 7 1.4 网络组网规则... 8 1.5 网络拓扑... 8 2. 项目所…

气膜体育馆内运动舒服吗—轻空间

气膜体育馆作为一种新型的体育设施&#xff0c;以其灵活的结构和高效的功能受到越来越多体育爱好者的青睐。很多人可能会担心在这种环境中运动是否会感到不适。轻空间将从气膜体育馆的结构特点、环境控制和用户体验三个方面&#xff0c;详细分析在气膜体育馆内运动的舒适度。 气…

如何用JavaScript实现视频观看时间追踪

在网页开发中&#xff0c;跟踪用户与多媒体内容&#xff08;如视频&#xff09;的互动是一项常见需求。无论是教育平台、数据分析&#xff0c;还是用户参与度统计&#xff0c;监控用户如何观看视频内容都能提供宝贵的见解。这篇文章将探索如何使用JavaScript实现视频播放时长的…