Pytorch 手写数字识别 深度学习基础分享

news2025/1/10 19:46:43

本篇是一次内部分享,给项目开发的同事分享什么是深度学习。用最简单的手写数字识别做例子,讲解了大概的原理。

手写数字识别

展示首先数字识别项目的使用。项目实现过程:

  1. 训练出模型
  2. 准备html手写板
  3. flask 框架搭建简单后端

简单手写数字识别



深度学习必备知识介绍

机器学习的概念

通俗解释
机器学习的关键内涵之一在于利用计算机的运算能力从大量的数据中发现一个规律,用这个规律实现预测或判断的功能。

深度学习算法分类

以算法区分深度学习应用,算法类别可分成三大类:

  • 常用于图片数据进行分析处理的卷积神经网络
  • 文本分析或自然语言处理的递归神经网络
  • 常用于数据生成的对抗神经网络

卷积神经网络(CNN)主要应用可分为图像分类、目标检测、语义分割

图片保存的本质

图片在计算机中以数字矩阵的形式存储。
https://h.markbuild.com/doc/binary-viewer-cn.html

图片的保存:

模型训练的通用步骤

模型训练的思想:

  1. 准备数据集
  2. 构建神经网络模型(面向对象中定义的一个类)
  3. 选择损失函数和优化器
  4. 训练模型
    • 从模型训练得出数值
    • 通过损失函数得到预测值和实际值的差距
    • 通过优化器调整模型中的参数,让结果越来越准确
    • 循环以上步骤

损失函数:衡量训练结果和实际偏差的函数。数值越大代表差距越大
优化器:优化模型的算法,让损失函数减小的方法

Q&A


Pytorch 手写数字识别讲解

模型训练使用pytorch框架,同样可以实现的框架还由tensorflow、keras。

数据集获取

手写识别使用的是MNIST数据集,手写数字图片。MNIST数据集由像素是28 × 28 的0~9的手写数字图片组成,一共有7万张图片,其中6万张是训练集,1万张是测试集。每个图片是黑底白字的形式。

pytorch 中提供了torchvision 包,可以通过该包可以下载数据集

import torchvision 
import matplotlib.pyplot as plt

# 训练数据集
train_data = torchvision.datasets.MNIST(
    root="data",    # 表示把MINST保存在data文件夹下
    download=True,  # 表示需要从网络上下载。下载过一次后,下一次就不会再重复下载了
    train=True,     # 表示这是训练数据集
    transform=torchvision.transforms.ToTensor()
                    # 要把数据集中的数据转换为pytorch能够使用的Tensor类型
)

# 测试数据集
test_data = torchvision.datasets.MNIST(
    root="data",    # 表示把MINST保存在data文件夹下
    download=True,  # 表示需要从网络上下载。下载过一次后,下一次就不会再重复下载了
    train=False,    # 表示这是测试数据集
    transform=torchvision.transforms.ToTensor()
                    # 要把数据集中的数据转换为pytorch能够使用的Tensor类型
)

演示


模型定义

模型使用的是卷积神经网络模型。定义的神经网络模型如下:

import torch.nn as nn


# 定义卷积神经网络类
class RLS_CNN(nn.Module):
    def __init__(self):
        super(RLS_CNN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16,   # 输入、输出通道数,输出通道数可以理解为提取了几种特征
                      kernel_size=(3, 3),               # 卷积核尺寸
                      stride=(1, 1),                    # 卷积核每次移动多少个像素
                      padding=1),                       # 原图片边缘加几个空白像素
                                                        # 输入图片尺寸为 1×28×28
                                                        # 第一次卷积,尺寸为                 16×28×28
            nn.MaxPool2d(kernel_size=2),                # 第一次池化,尺寸为                 16×14×14
            nn.Conv2d(16, 32, 3, 1, 1),                 # 第二次卷积,尺寸为                 32×14×14
            nn.MaxPool2d(2),                            # 第二次池化,尺寸为                 32×7 ×7
            nn.Flatten(),                               # 将三维数组变成一维数组
            nn.Linear(32*7*7, 16),                      # 变成16个卷积核,每一个卷积核是1*1,最后输出16个数字
            nn.ReLU(),                                  # 激活函数 x<0 y=0  x>0 y=x,用在反向反向传导
            nn.Linear(16, 10)                           # 将16变成10,预测0-9之间概率值
        )

    def forward(self, x):
        return self.net(x)

卷积神经网络模型组成

卷积神经网络通常由3个部分构成:卷积层,池化层,全连接层。各部分的功能:

  • 卷积层:负责提取图像中的特征,可以输出一张图片的很多种特征。
  • 池化层:用来缩小尺寸,大幅降低参数量级,降低计算量
  • 全连接层:合并特征并输出结果

美颜相机的原理就是提取图片的特征,如下图片第二张模糊轮廓,第三张是突出轮廓。

卷积

卷积的功能:提取图片的多种特征信息
卷积的原理:用一个卷积核和图片的矩阵相乘,得到一个新的矩阵。新矩阵就是一个新的特征。
卷积核
卷积核也是一个矩阵,通常是33的矩阵,或者是55的矩阵。卷积运算的过程如下:

图像边缘提取
使用如下的卷积核就可以提取图像的边缘轮廓特征

调参
卷积核矩阵由3*3一共9个参数组成,这些参数都是模型自动生成的,所谓的调参,其中一部分就是指调整卷积核矩阵的参数,让其提取的特征能够使预测更加准确

池化

池化的功能:池化就是缩小矩阵的尺寸,从而减少后续操作的参数数量。通常会在相邻的卷积层之间加入一个池化层。
池化的原理:池化的运算过程:将一个44的矩阵最大池化成22的矩阵,就是取4*4矩阵中对应区域中最大的一个数值。

池化通常有两种:

  • 最大池化(max pooling):选图像区域的最大值作为该区域池化后的值。
  • 平均池化(average pooling):计算图像区域的平均值作为该区域池化后的值。

全连接

全连接功能: 全连接的作用是组合特征分类
在前面两个步骤中从一张图片提取多种特征,并将特征矩阵进行了压缩。当数据到达全连接层时得到是一张图片的多种特征。
某一个特征并不能说整个图片是什么,否则就是盲人摸象。那么全连接层就是将多种特征组合起来形成一个完整的特征,并根据特征计算出图片是某一个类型的概率。
全连接层最终输出就是概率。比如手写数字识别,最终全连接层输出就是某一个手写数字在0~9上的概率。

tensor([[ 0.949,  3.032,  0.771, -2.173, -0.038, -0.236,  0.013,  0.614, -1.125, -2.6991]])

全连接的原理
全连接层实现的是特征组合,原理和卷积类似,也就是用一个卷积核对矩阵做运算,最后得到一个一维的数组,也就是0-9的概率。

调参:全连接的实现也需要卷积核的参与,所以卷积核矩阵也是参数的一部分,调参就包括该部分的参数。

手写数字识别的模型定义

手写数字识别的卷积神经网络,下面分析卷积+池化+全连接的过程:

Q&A


选择损失函数和优化器

损失函数功能:衡量训练结果和实际偏差的函数。数值越大代表差距越大
优化器功能:让模型不断优化,让损失函数减小的方法

手写数字识别中使用的损失函数和优化器如下:

# 交叉熵损失函数,选择一种方法计算误差值
loss_func = torch.nn.CrossEntropyLoss()

# 优化器,随机梯度下降算法
optimizer = torch.optim.SGD(model.parameters(), lr=0.2)

损失函数

手写识别中选择了交叉熵损失函数,pytorch一共有19中损失函数可以使用,比较好理解的是平方差损失函数

优化器

手写识别中选了随机梯度下降算法,用来实现反向传播参数的修改。pytorch中一共有11中优化器可以使用。

模型训练

模型训练的流程:

  1. 定义训练的次数
  2. 遍历训练集,调用模型类传入图片,得到概率结果
  3. 通过损失函数计算损失值
  4. 通过优化器调整参数
  5. 训练完成保存模型
# 定义训练次数
cnt_epochs = 5 # 训练5个循环

# 循环训练
for cnt in range(cnt_epochs):
    # 把训练集中的数据训练一遍
    for imgs, labels in train_dataloader:
        outputs = model(imgs)  # 输出0~9预测的结果概率
        loss = loss_func(outputs, labels) # 和输入做一个比较,得到一个误差
        optimizer.zero_grad()   # 初始化梯度,清空梯度。注意清空优化器的梯度,防止累计
        loss.backward()  # 方向传播计算
        optimizer.step() # 累加1,执行一次

# 保存训练的结果(包括模型和参数)
torch.save(model, "my_cnn.nn")

需要注意的点:

  • 训练的规律
  • my_cnn.nn 模型保存的内容

Q&A

模型验证

  • 模型在测试集上的准确率
  • 一批模型准确率展示

总结

  1. 数据集非常重要。html手写识别中遇到的问题,以及如何解决。颜色,大小
  2. 数学知识。训练过程中遇到的数据知识:矩阵乘法
  3. 为什么需要GPU?如何使用GPU?
  4. 模型训练的过程。卷积 + 池化 + 全连接 + 损失函数 + 优化器
  5. 目标检查的训练过程和手写识别有何不同?
    图像分类:LeNet、AlexNet、VGG、GoogLeNet
    目标检测:RCNN、Fast RCNN、Faster RCNN、YOLO、YOLOv2、SSD

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2257037.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WPS EXCEL 使用 WPS宏编辑器 写32位十六进制数据转换为浮点小数的公式。

新建EXCLE文件 另存为xlsm格式的文件 先打开WPS的开发工具中的宏编辑器 宏编辑器编译环境 在工作区添加函数并编译&#xff0c;如果有错误会有弹窗提示&#xff0c;如果没有错误则不会弹 函数名字 ”HEXTOFLOAT“ 可以自己修改。 function HEXTOFLOAT(hex) { // 将十六…

沃丰科技智能客服在跨境电商独立站中的核心角色

随着全球化进程的加速和互联网技术的不断发展&#xff0c;跨境电商行业蓬勃兴起&#xff0c;为消费者提供了更广阔、更便捷的购物选择。在这样一个竞争激烈的市场环境中&#xff0c;优质的客户服务成为了企业脱颖而出的关键。沃丰科技智能客服凭借其先进的技术和人性化的设计理…

langgraph实现无观测推理 (Reasoning without Observation)

图例 1. 图状态 在 LangGraph 中&#xff0c;每个节点都会更新一个共享的图状态。当任何节点被调用时&#xff0c;状态就是该节点的输入。 下面&#xff0c;我们将定义一个状态字典&#xff0c;用以包含任务、计划、步骤和其他变量。 from typing import List from typing…

2024企业数据资产入表合规指引——解读

更多数据资产资讯关注公众&#xff1a;数字化转型home 本报告旨在为企业数据资产入表提供合规保障。随着数字经济的发展&#xff0c;数据资产已成为重要战略资源和新生产要素。财政部发布的《企业数据资源相关会计处理暂行规定》明确&#xff0c;自2024年1月1日起&#xff0c;数…

19,[极客大挑战 2019]PHP1

这个好玩 看到备份网站字眼&#xff0c;用dirsearch扫描 在kali里打开 找出一个www.zip文件 访问一下 解压后是这个页面 class.php <?php include flag.php; error_reporting(0); class Name{ private $username nonono; private $password yesyes; public …

计算机键盘简史 | 键盘按键功能和指法

注&#xff1a;本篇为 “计算机键盘简史 | 键盘按键功能和指法” 相关文章合辑。 英文部分机翻未校。 The Evolution of Keyboards: From Typewriters to Tech Marvels 键盘的演变&#xff1a;从打字机到技术奇迹 Introduction 介绍 The keyboard has journeyed from a humb…

《Clustering Propagation for Universal Medical Image Segmentation》CVPR2024

摘要 这篇论文介绍了S2VNet&#xff0c;这是一个用于医学图像分割的通用框架&#xff0c;它通过切片到体积的传播&#xff08;Slice-to-Volume propagation&#xff09;来统一自动&#xff08;AMIS&#xff09;和交互式&#xff08;IMIS&#xff09;医学图像分割任务。S2VNet利…

HarmonyOS(65) ArkUI FrameNode详解

Node 1、Node简介2、FrameNode2.1、创建和删除节点2.2、对FrameNode的增删改2.3、 FramNode的查询功能3、demo源码4、总结5、参考资料1、Node简介 在HarmonyOS(63) ArkUI 自定义占位组件NodeContainer介绍了自定义节点复用的原理(阅读本本篇博文之前,建议先读读这个),在No…

Elasticsearch使用(2):docker安装es、基础操作、mapping映射

1 安装es 1.1 拉取镜像 docker pull swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/library/elasticsearch:7.17.3 1.2 运行容器 运行elasticsearch容器&#xff0c;挂载的目录给更高的权限&#xff0c;否则可能会因为目录权限问题导致启动失败&#xff1a; docker r…

java实现SpringBoot项目分页查询和消费的方法

简介 why&#xff1a; 最近在项目中&#xff0c;有一个sql需要查询100多万的数据&#xff0c;且需要在代码中遍历处理。面临两个问题 一次性查询出太多数据&#xff0c;速度较慢当前服务器内存支持以上操作&#xff0c;但是随着数据量的增多&#xff0c;以后可能会出现内存溢出…

专为高性能汽车设计的Armv9架构的Neoverse V3AE CPU基础知识与软件编码特性解析

一、ARMv9以及V3AE处理器架构 Armv9架构的Arm Neoverse V系列处理器是专为高性能计算设计的产品线&#xff0c;其中V3AE&#xff08;Advanced Efficiency&#xff09;特别强调了性能与效率之间的平衡。以下是关于Armv9架构下Neoverse V3AE处理器结构和指令集的一些详细解读&am…

Python数据清洗之重复数据处理

大家好&#xff0c;在数据处理和分析的过程中&#xff0c;重复数据是一个常见的问题。重复的数据不仅会影响数据的准确性&#xff0c;还可能导致模型训练中的偏差。因此&#xff0c;检测并清理重复数据是数据清洗中的重要步骤。Python 的 Pandas 提供了强大的功能来检测、标记和…

【实战教程】使用YOLO和EasyOCR实现视频车牌检测与识别【附源码】

《------往期经典推荐------》 一、AI应用软件开发实战专栏【链接】 项目名称项目名称1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】3.【手势识别系统开发】4.【人脸面部活体检测系统开发】5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】7.【…

【项目实战】基于python+爬虫的电影数据分析及可视化系统

注意&#xff1a;该项目只展示部分功能&#xff0c;如需了解&#xff0c;文末咨询即可。 本文目录 1.开发环境2 系统设计 2.1 设计背景2.2 设计内容 3 系统页面展示 3.1 用户页面3.2 后台页面3.3 功能展示视频 4 更多推荐5 部分功能代码 5.1 爬虫代码5.2 电影信息代码 1.开发环…

SDXL的优化工作

本文详细介绍SDXL在SD系列的基础上做了什么优化&#xff0c;包括模型架构优化和训练过程数据的相关优化策略。 目录 Stable Diffusion XL核心基础内容 SDXL整体架构初识 Base模型 Refiner模型 Base——VAE Base——U-Net Base——Text Encoder Refiner GPT补充【TODO】 SDXL官方…

计算机网络 —— HTTPS 协议

前一篇文章&#xff1a;计算机网络 —— HTTP 协议&#xff08;详解&#xff09;-CSDN博客 目录 前言 一、HTTPS 协议简介 二、HTTPS 工作过程 1.对称加密 2.非对称加密 3.中间人攻击 4.引入证书 三、HTTPS 常见问题 1.中间人能否篡改证书&#xff1f; 2.中间人能否调…

YonBuilder移动开发——调用手机系统的浏览器打开网页

概述 在YonBuilder移动开发中&#xff0c;可以通过使用引擎提供的 api.openWin 或者 api.openFrame 函数方法通过内置的浏览器引擎在App内部打开相关的远程H5网站的网页。但是在实际项目开发中&#xff0c;可能会有一种需求&#xff0c;调用手机操作系统提供的系统浏览器去打开…

美畅物联丨视频接入网关如何配置 HTTPS 证书

在安防领域&#xff0c;视频接入网关&#xff08;Video Access Gateway&#xff0c;VAG&#xff09;是视频监控系统的重要组成部分&#xff0c;其职责是把视频数据从前端设备传输至后端服务器。配置HTTPS证书后&#xff0c;可对视频流进行加密传输&#xff0c;避免数据在网络传…

Redis原理—2.单机数据库的实现

大纲 1.Redis数据库的结构 2.读写Redis数据库键值时的处理 3.Redis数据库的构成 4.Redis过期键的删除策略 5.Redis的RDB持久化 6.Redis的AOF持久化 7.Redis的AOF重写机制 8.Redis持久化是影响其性能的高发地 9.Redis基于子进程实现持久化的使用建议 10.Redis持久化的…

Android平台GB28181设备接入模块动态文字图片水印技术探究

技术背景 前几年&#xff0c;我们发布的了Android平台GB28181设备接入模块&#xff0c;实现了不具备国标音视频能力的 Android终端&#xff0c;通过平台注册接入到现有的GB/T28181—2016或GB/T28181—2022服务。 Android终端除支持常规的音视频数据接入外&#xff0c;还可以支…