Pytorch-08 实战:手写数字识别

news2024/11/20 6:21:29

手写数字识别项目在机器学习中经常被用作入门练习,因为它相对简单,但又涵盖了许多基本的概念。这个项目可以视为机器学习中的 “Hello World”,因为它涉及到数据收集、特征提取、模型选择、训练和评估等机器学习中的基本步骤,所以手写数字识别项目是一个很好的起点。

我们的要做的是,训练出一个人工神经网络,使它能够识别手写数字(如下图所示):

以下是一个简单的示例代码,展示如何使用PyTorch创建一个手写数字识别的模型,包括数据集加载、训练和测试过程。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 检查GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 加载MNIST数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
print(f"训练集第1张图像形状 = {train_dataset.__getitem__(0)[0].shape}")
print(f"训练集第1张图像标签 = {train_dataset.__getitem__(0)[1]}")
print(f"测试集第1张图像形状 = {test_dataset.__getitem__(0)[0].shape}")
print(f"测试集第1张图像标签 = {test_dataset.__getitem__(0)[1]}")

# 使用数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# 定义神经网络模型并将其移至GPU
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

model = Net().to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型,训练过程输出损失值
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移至GPU
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')

# 测试模型,输出数字识别准确率
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)  # 将数据移至GPU
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy on the test set: {100 * correct / total}%')

程序运行后,输出如下:

训练集第1张图像形状 = torch.Size([1, 28, 28])
训练集第1张图像标签 = 5
测试集第1张图像形状 = torch.Size([1, 28, 28])
测试集第1张图像标签 = 7
Epoch [1/5], Loss: 0.3935443162918091
Epoch [2/5], Loss: 0.1757822483778
Epoch [3/5], Loss: 0.1337398886680603
Epoch [4/5], Loss: 0.03868262842297554
Epoch [5/5], Loss: 0.025882571935653687
Accuracy on the test set: 96.85%

进程已结束,退出代码为 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1691871.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

强化学习4:DQN 算法

看这篇文章之前,建议先了解一下:Q-Learning 算法。 1. 算法介绍 DQN 算法全称为 Deep Q-Network,即深度Q网络。它将 Q-Learning 与 Deep Learning 结合在了一起。 1.1 Q-Network Q-Learning 是使用 Q-table 才存储决策信息的,…

spring常用知识点

1、拦截器和过滤器区别 1. 原理不同: 拦截器是基于java的反射机制,而过滤器采用责任链模式是基于函数回调的。 2. 使用范围不同: 过滤器Filter的使用依赖于Tomcat等容器,导致它只能在web程序中使用 拦截器是一个Sping组件&am…

IO模型:同步阻塞、同步非阻塞、同步多路复用、异步非阻塞

目录 stream和channel对比 同步、异步、阻塞、非阻塞 线程读取数据的过程 同步阻塞IO 同步非阻塞IO 同步IO多路复用 异步IO 优缺点对比 stream和channel对比 stream不会自动缓冲数据,channel会利用系统提供的发送缓冲区、接收缓冲区。stream仅支持阻塞API&am…

【C++】哈希和unordered系列容器

目录 一、unordered系列关联式容器的引入 二、容器使用 2.1 unordered_map的文档说明 2.2 unordered_map的使用 2.3 unordered_set 三、底层结构 3.1 哈希概念 3.2 哈希表 3.3 哈希冲突 3.4 哈希函数 3.5 哈希冲突解决 3.5.1 闭散列 3.5.2 开散列 3.5.3 思考 四…

C++ RBTree

目录 概念 性质 节点的定义 树的结构 Insert 1. pparent->_left parent 1.1 uncle && uncle->_col RED 1.2 !(uncle && uncle->_col RED) 1.2.1 parent->_left cur 1.2.2 parent->_right cur 2. pparent->_right parent …

hive3从入门到精通(一)

Hive3入门至精通(基础、部署、理论、SQL、函数、运算以及性能优化)1-14章 第1章:数据仓库基础理论 1-1.数据仓库概念 数据仓库(英语:Data Warehouse,简称数仓、DW),是一个用于存储、分析、报告的数据系统。 数据仓库的目的是构…

第十六讲:数据在内存中的存储

第十六讲:数据在内存中的存储 1.整数在内存中的存储1.1存储方式1.2大小端字节序1.3大小端字节序排序规则1.4为什么要有大小端1.5练习1.5.1练习11.5.2练习21.5.3练习31.5.4练习41.5.5练习51.5.6练习61.5.7练习7 2.浮点数在内存中的存储2.1练习2.2浮点数的存储2.3浮点…

常见的几种数据库通过SQL对表信息进行查询

一、前言 我们查询数据库表的信息,一般都使用界面化的连接工具查看,很少使用SQL语句去查,而且不同的数据库SQL语句又各自有差异。但如果通过代码去获取数据库表的信息,这时就需要通过SQL语句去查了,这个在逆向代码生成…

【案例分享】医疗布草数字化管理系统:聚通宝赋能仟溪信息科技

内容概要 本文介绍了北京聚通宝科技有限公司与河南仟溪信息科技有限公司合作开发的医疗布草数字化管理系统。该系统利用物联网技术实现了医疗布草生产过程的实时监控和数据分析,解决了医疗布草洗涤厂面临的诸多挑战,包括人工记录、生产低效率和缺乏实时…

打造专业级网页排版:全方位解析专业字体家族font-family实践与全球知名字体库导览

CSS中的字体家族(font-family)属性用于指定文本所使用的字体系列。它允许开发者选择一种或多种字体作为备选,确保在浏览器中以最佳可用字体显示文本。本文将深度解析专业级网页排版中字体家族(font-family)设置的实践技…

掌握Python基本语法的终极指南【基本语法部分】

一、基本语法部分 1.简单数据类型 1.1字符串类型及操作 字符串访问: 1.索引访问 mystr"Hello world" #索引访问 print(mystr[0]) #H print(mystr[-1]) #d print(mystr[-7]) #o print(mystr[6]) #w 2.切片访问 [头下标:尾下标] &#x…

车灯合面合壳密封使用UV胶的优缺点是什么呢?汽车车灯的灯罩如果破损破裂破洞了要怎么修复?

车灯合面合壳密封使用UV胶的优缺点是什么呢? 车灯合壳密封使用UV胶的优缺点如下: 优点: 快速固化:UV胶通过紫外线照射可以在短时间内迅速固化,大大缩短了车灯制造的工艺流程时间,提高了生产效率。高度透明&#xff…

SVG批量转为pdf超有效的方式!

最近在整理工作,发现ppt里面画的图智能导出svg格式无法导出pdf格式,由于在线的网站会把我的图片搞乱而且不想下载visio(会把本地的word搞坏),因此琢磨出这种批量转换的方式。 1. 下载并安装Inkscape 下载链接&#xf…

基于Matlab完整版孤立词识别系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 孤立词识别是语音识别领域的一个重要分支,其目标是将输入的语音信号转换为计算机可…

成都爱尔眼科医院《中、欧国际近视手术大数据白皮书2.0》解读会圆满举行

2024年5月12日,爱尔眼科联合中国健康促进基金会健康传播与促进专项基金、新华社新媒体中心与中南大学爱尔眼科研究院、爱尔数字眼科研究所重磅发布《中、欧国际近视手术大数据白皮书2.0》。这是继2021、2022年在国内相继发布《国人近视手术白皮书》、《2022中、欧近…

C++笔试强训day32

目录 1.素数回文 2.活动安排 3.合唱团 1.素数回文 链接https://www.nowcoder.com/practice/d638855898fb4d22bc0ae9314fed956f?tpId290&tqId39945&ru/exam/oj 现将其转化为回文数(这里用字符串存储比较方便转化),然后判断是否为…

无线网卡有几种接口?怎么给电脑选择一款合适的无线网卡?

前言 这篇文章一共有两个问题: 无线网卡有几种接口 怎么给电脑选择一款合适的无线网卡 目测这一期的文章很长很长,但不水。想要给笔记本或台式机升级无线网卡的小伙伴看过来了! 最近有小伙伴问:华硕r555笔记本能不能升级无线…

MySql的环境配置与安装

MySQL 数据库 MySQL是一款关系型数据库 关系型数据库 ​ 基本单位是表,一个表中存储一类信息,表与表之间存在关联关系 sql语言(Structured Query Language) 数据库操作语言也属于一种编程语言,专门用作数据库操作分为三种语言 如下 sql安装使用流程 官网 href https://…

【设计模式】JAVA Design Patterns——Converter(转换器模式)

🔍目的 转换器模式的目的是提供相应类型之间双向转换的通用方法,允许进行干净的实现,而类型之间无需相互了解。此外,Converter模式引入了双向集合映射,从而将样板代码减少到最少 🔍解释 真实世界例子 在真实…

大众汽车集团CARIAD中国领导团队莅临知迪科技考察交流

5月23日,大众汽车集团旗下软件子公司CARIAD中国领导团队莅临知迪科技参观考察,知迪科技COO尹晓航先生率公司技术代表热情接待。 CARIAD中国一行来宾首先参观了知迪科技数采项目改制车。知迪科技软硬件工程师为考察团领导专家们讲解了知迪智驾数采系统&am…