手写数字识别基本思路

news2024/10/2 2:22:30

问题

什么是MNIST?如何使用Pytorch实现手写数字识别?如何进行手写数字对模型进行检验?

方法

mnist数据集

MNIST数据集是美国国家标准与技术研究院收集整理的大型手写数字数据集,包含了60,000个样本的训练集以及10,000个样本的测试集。

使用Pytorch实现手写数字识别

1.进行数据预处理对于MNIST数据集,可以通过torchvision中的datasets进行下载。

root (string):表示数据集的根目录,其中根目录存在MNIST/processed/training.pt和MNIST/processed/test.pt的子目录。

train (bool, optional):如果为True,则从training.pt创建数据集,否则从test.pt创建数据集。

download (bool, optional):如果为True,则从internet下载数据集并将其放入根目录。如果数据集已下载,则不会再次下载。

transform (callable, optional):接收PIL图片并返回转换后版本图片的转换函数。

bat = 128
transform = transforms.Compose([
   transforms.ToTensor(),
   transforms.Normalize(0.1307, 0.3081)  # (均值,方差)
])  # Compoes 两个操作合为一个
train_ds = datasets.MNIST(root='data', download=False, train=True,
                         transform=transform)
train_ds, val_ds = torch.utils.data.random_split(train_ds, [50000, 10000])
test_ds = datasets.MNIST(root='data', download=True, train=False,
                        transform=transform)
train_loader = DataLoader(dataset=train_ds, batch_size=bat, shuffle=True)
val_loader = DataLoader(dataset=val_ds, batch_size=bat)
test_loader = DataLoader(dataset=test_ds, batch_size=bat)

2.构建模型

class MyNet(nn.Module):
   def __init__(self) -> None:
       super().__init__()

       self.flatten = nn.Flatten()  # 将28*28的图像拉伸为784维向量
       # 第一个全连接层Full Connection(FC)
       self.fc1 = nn.Linear(in_features=784,
                            out_features=256)
       self.fc2 = nn.Linear(in_features=256,
                            out_features=128)
       self.fc3 = nn.Linear(in_features=128,
                            out_features=10)

   def forward(self, x):
       x = self.flatten(x)
       x = torch.relu(self.fc1(x))
       x = torch.relu(self.fc2(x))
       out = torch.relu(self.fc3(x))
       return out

构建一个三层的神经网络MNIST数据集中的图片都是28×28大小的,而且是灰度图。而全连接神经网络的输入要是一个行向量,所以我们要把28×28的矩阵转换成28×28=764的行向量,作为神经网络的输入

3.优化器的选择,参数设置

使用优化器和损失函数。优化器选择SGD,SGD随机梯度下降,lr学习率取值0.2最优,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生。

optimizer=torch.optim.SGD(net.parameters(),lr=0.2)#lr学习率,momentum用于加速SGD在某一方向上的搜索以及抑制震荡的发生
#损失函数
#衡量y与y_hat之间的差异
loss_fn=nn.CrossEntropyLoss()

4.对模型进行训练测试,网络的输入,输入尺寸B*C*H*W B是batch,一个batch一个batch交给网络处理,x=torch.rand(size=(128,1,28,28)),基于loss信息利用优化器从后向前更新网络全部参数。

def train(dataloader, net, loss_fn, optimizer, epoch):
   size = len(dataloader.dataset)
   corrent = 0
   epoch_loss = 0.0
   batch_num = len(dataloader)
   net.train()

   # 一个batch一个batch的训练网络
   for batch_idx, (X, y) in enumerate(dataloader):
       pred = net(X)

       # 衡量y与y_hat之间的loss
       # y:128, pred:128x10 CrossEntropyloss
       loss = loss_fn(pred, y)

       # 基于loss信息利用优化器从后向前更新网络全部参数 <---
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       epoch_loss += loss.item()
       corrent += (pred.argmax(1) == y).type(torch.float).sum().item()
       if batch_idx % 100 == 0:
           # f-string
           print(f'[{batch_idx + 1:>5d}/{batch_num + 1:>5d}],loss:{loss.item()}')
   avg_loss = epoch_loss / batch_num
   avg_accuracy = corrent / size
   # loss_list.append(avg_loss)
   return avg_accuracy, avg_loss
def test(dataloader, net, loss_fn):
   size = len(dataloader.dataset)
   batch_num = len(dataloader)
   corrent = 0
   losses = 0
   net.eval()
   with torch.no_grad():
       for X, y in test_loader:
           pred = net(X)
           correct = (pred.argmax(1) == y).type(torch.int).sum().item()
           # print(y.size(0))
           # print(correct)
           corrent += correct
   accuracy = corrent / size
   avg_loss = losses / batch_num
   return accuracy, avg_loss


5.保存最优的模型

net.load_state_dict(torch.load('model_best.pth'))
   test(test_loader,net,loss_fn)

6.读入自己的写入数字,进行识别

model = MyNet()
model.load_state_dict(torch.load('model_best.pth'))
img = Image.open("7.png").convert("L")  # 转为灰度图像
img = transform(img)
# img = np.array(img)
# print(img)
result = model(img)
_, predict = torch.max(result.data, dim=1)
print(result)
print("the result is:",predict.item())

4d1fb1c7127ae2e3d99e583625960373.png

结语

minist是一个28*28的图像,所以输入就是28*28=784的维度,输出为10,0-9十个数字。手写数字识别首先需要初始化全局变量,构建数据集。然后构建模型,构建迭代器与损失函数,进行训练测试。最后可以将训练的模型进行保存,通过读取自己写的数字进行识别验证,完成一个简单的深度学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/484455.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RIP笔记

目录 RIP路由信息协议——UDP520端口(RIPNG521端口) RIP使用的算法——贝尔曼福特算法 RIP的版本 RIP的数据包 RIP的工作过程 RIP的计时器 周期更新计时器——默认30s 失效计时器——默认180s 垃圾回收计时器——默认120s RIP的环路问题 解决方法&#xff1a; RIP的…

12种接口优化的通用方案

一、背景 针对老项目&#xff0c;去年做了许多降本增效的事情&#xff0c;其中发现最多的就是接口耗时过长的问题&#xff0c;就集中搞了一次接口性能优化。本文将给小伙伴们分享一下接口优化的通用方案。 二、接口优化方案总结 1.批处理 批量思想&#xff1a;批量操作数据库…

Item冷启优化

Item冷启动的目标&#xff1a; 1.精准推荐。 2.激励发布。 3.挖掘高潜。 Item冷启动优化措施&#xff1a; 1.优化全链路&#xff08;召回和排序&#xff09; 2.流量调控&#xff08;新老物品的流量分配&#xff09; 评价指标&#xff1a; 作者侧&#xff1a; 发布渗透率&a…

【基于Ubuntu18.04+Melodic的realsense D435安装】

【基于Ubuntu18.04Melodic的realsense D435安装】 1. RealSense SDK安装1.1 克隆SDK1. 2 安装相关依赖1.3 安装权限脚本1. 4 进行编译与安装1.5 测试安装是否成功 2. D435i 安装ROS接口2.1 方法一realsense—ros源码2.2 方法二安装相机库 3. 总结 1. RealSense SDK安装 系统硬…

C++:分治算法之选择问题的选择第k小元素问题

目录 3.2.6 选择问题 分析过程&#xff1a; 解法一&#xff1a; 算法代码&#xff1a; 【单组数据】 【多组数据】 运行结果&#xff1a; 解法二 代码&#xff1a; 运行结果&#xff1a; 解法三&#xff1a; 3.2.6 选择问题 ¢ 对于给定的 n 个元素的数组 a[0 …

DAY 53 Haproxy负载均衡集群

常见的Web集群调度器 目前常见的Web集群调度器分为软件和硬件&#xff1a; 软件通常使用开源的LVS、Haproxy、 Nginx LVS性能最好&#xff0c;但是搭建相对复杂&#xff1b;Nginx 的upstream模块支持群集功能&#xff0c;但是对群集节点健康检查功能不强&#xff0c;高并发性能…

第一章 Linux是什么

Linux是一套操作系统&#xff0c;如同下图所示&#xff0c;Linux就是核心与系统调用接口那两层。至于应用程序不算Linux。 1.1 Linux当前应用的角色 由于Linux kernel实在是非常的小巧精致&#xff0c;可以在很多强调省电以及较低硬件资源的环境下面执行&#xff1b; 此外&…

【Elasticsearch】NLP简单应用

文章目录 NLP简介ES中的自然语言处理(NLP)NLP演示将opennlp插件放在ESplugins路径中下载NER模型配置opennlp重启ES、验证 NLP简介 NLP代表自然语言处理&#xff0c;是计算机科学和人工智能领域的一个分支。它涉及使用计算机来处理、分析和生成自然语言&#xff0c;例如英语、中…

企业对网络安全的重视度开始降低

近日&#xff0c;英国科学技术部发布了《2023年企业网络安全合规调查报告》&#xff08; Cyber Security Breaches Survey &#xff09;&#xff0c;对英国所有企业和社会性组织目前的网络威胁态势和合规建设进行研究&#xff0c;同时也就如何提升新一代网络应用的合规性给出专…

02-管理员登录与维护 尚筹网

一、管理员登陆 需要做的&#xff1a; 对存入数据库的密码进行MD5加密在登录界面登录失败时的处理抽取后台页面的公共部分检查登录状态&#xff0c;防止未登录时访问受保护资源的情况 具体操作如下&#xff1a; 1&#xff09;、MD5加密 ​ 使用到的CrowdConstant类中的一些…

人的全面发展评价指标体系—基于相关-主成分分析构建

本文先从经济、社会、生活质量和人口素质四个方面海选了众多人的全面发展评价指标&#xff0c;然后根据可观测性原则剔除无法获得的指标进行了初步筛选&#xff0c;再利用相关性分析删除相关系数大的指标&#xff0c;以及通过主成分分析删除因子负载小的指标&#xff0c;完成了…

CCD视觉检测设备如何选择光源

CCD视觉检测设备的机器视觉系统对光源的要求很高&#xff0c;光源是决定图像质量的一个重要因素。那么&#xff0c;我们就来看看CCD图像加网设备和机器视觉系统光源的选择点——CCD图像加网设备。 CCD视觉检测设备机器视觉系统光源选择要点&#xff1a; 1. 对比度&#xff1a;…

最新VUE面试题

前言 本文以前端面试官的角度出发&#xff0c;对 Vue 框架中一些重要的特性、框架的原理以问题的形式进行整理汇总&#xff0c;意在帮助作者及读者自测下 Vue 掌握的程度。 本文章节结构以从易到难进行组织&#xff0c;建议读者按章节顺序进行阅读&#xff0c;当然大佬级别的…

P1915 [NOI2010] 成长快乐

此题为世纪难题 题目提供者 洛谷 难度 NOI/NOI/CTSC 输入输出样例 输入 #1 5 1 6 0 0 1 5 2 2 0 0 输出 #1 1 5 5 2 2 1 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~此题非常难&#xff0c;小白就不用想着独自完成了 题解&#xff1a; #…

如何在 Windows 11 启用 Hyper-V

准备在本机玩一下k8s&#xff0c;需要先启用 Hyper-V&#xff0c;谁知道这一打开&#xff0c;没有 Hyper-V选项&#xff1a; 1、查看功能截图&#xff1a; 2、以下文件保存记事本&#xff0c;然后重命名为*.bat pushd "%~dp0" dir /b %SystemRoot%\servicing\Packa…

常用的MySQL 优化方法

数据库优化一方面是找出系统的瓶颈&#xff0c;提高MySQL数据库的整体性能&#xff0c;而另一方面需要合理的结构设计和参数调整&#xff0c;以提高用户的相应速度&#xff0c;同时还要尽可能的节约系统资源&#xff0c;以便让系统提供更大的负荷。   本文我们来谈谈项目中常用…

maven中的 type ,scope的作用

dependency为什么会有type为pom&#xff0c;默认的值是什么&#xff1f; dependency中type默认为jar即引入一个特定的jar包。那么为什么还会有type为pom呢?当我们需要引入很多jar包的时候会导致pom.xml过大&#xff0c;我们可以想到的一种解决方…

Linux指令-2

文章目录 一、 m a n man man [选项] 命令1、功能&#xff1a;2、常用选项&#xff1a;3、运用实例 二、 c p cp cp [选项] 源文件/目录 目标文件/目录1、功能&#xff1a;2、常用选项&#xff1a;3、运用实例 三、 m v mv mv [选项] 源文件/目录 目标文件/目录1、功能…

PySide6/PyQT多线程之 编程入门指南:基础概念和最佳实践

前言 本篇文章介绍 PySide6/PyQT多线程编程的基本概念&#xff0c;用到的知识点&#xff0c;以及PySide6/PyQT多线程的基本使用。 看多线程介绍&#xff0c;就看 知识点&#x1f4d6;&#x1f4d6; &#xff1b; 看多线程代码&#xff0c;就看 实现 。 知识点&#x1f4d6;&…

《手腕光电容积图智能手表对房颤检测的录制长度和其他心律失常的影响》阅读笔记

目录 一、论文摘要 二、论文十问 三、论文亮点与不足之处 四、与其他研究的比较 五、实际应用与影响 六、个人思考与启示 参考文献 一、论文摘要 本研究旨在评估手腕光电容积图&#xff08;PPG&#xff09;的定量分析是否能检测到房颤&#xff08;AF&#xff09;。使用心…