神经网络基础-手写数字识别

news2024/11/24 13:03:18

手写数字识别神经网络

基本原理

图像本质上被认为是一个矩阵,每个像素点都是一个对应的像素值,相当于在多维数据上进行相关的归类或者其他操作。

线性函数

线性函数的一个从输入到输出的映射,用于给目标一个每个类别对应的得分。

图像 ( 32 ∗ 32 ∗ 3 ) → f ( x , W ) Y 图像(32*32*3) \stackrel{f(x,W)}{\rightarrow} Y 图像(32323)f(x,W)Y

其中 x x x为3072维的一个向量,
W W W为parameters
Y Y Y为图像对应每个类别对应的得分

f ( x , W ) = W x ( + b ) f(x,W)=Wx(+b) f(x,W)=Wx(+b)

其中 f ( x , W ) f(x,W) f(x,W)是10*1维度
W W W是10*3072维度
x x x是3072*1维度
b b b是10*1维度

请添加图片描述

损失函数

得到了输入图像和分类目标直接对应的每类得分,我们如何去分析衡量分类的结果?我们可以使用损失函数去明确当前模型的效果是好是坏。
损失函数可以表示为:

损失函数 = 数据损失 + 正则化惩罚项

L = 1 N ∑ i = 1 N m a x ( 0 , f ( x i ; W ) j − f ( x i ; W ) y i + 1 ) + λ R ( W ) L=\frac{1}{N} \sum\limits\limits_{i=1}^{N}max(0,f(x_i;W)_j-f(x_i;W)_{y_i}+1)+\lambda R(W) L=N1i=1Nmax(0,f(xi;W)jf(xi;W)yi+1)+λR(W)
其中R(W)项为正则化惩罚项,用于减少模型复杂度,防止过拟合,其中其 λ \lambda λ参数越大惩罚力度越大,也就是我们约不希望他过拟合。
R ( W ) = ∑ k ∑ l W k , l 2 R(W)=\sum\limits_{k}\sum\limits_{l}W _{k,l}^{2} R(W)=klWk,l2

其中一个常用损失函数为
L i = ∑ j ≠ y i m a x ( 0 , s j − s y i + 1 ) L_{i}=\sum\limits_{j \neq y_i}max(0,s_j-s_{y_i}+1) Li=j=yimax(0,sjsyi+1)
在某次训练过程中,几个任务图像的线性函数输出结果如下所示:

请添加图片描述

我们分别计算其损失函数:

= m a x ( 0 , 5.1 − 3.2 + 1 ) + m a x ( 0 , − 1.7 − 3.2 + 1 ) = m a x ( 0 , 2.9 ) + m a x ( 0 , − 0.39 ) = 2.9 \begin{aligned} &=max(0,5.1-3.2+1)+max(0,-1.7-3.2+1) \\ &=max(0,2.9)+max(0,-0.39)\\ &=2.9 \end{aligned} =max(0,5.13.2+1)+max(0,1.73.2+1)=max(0,2.9)+max(0,0.39)=2.9

= m a x ( 0 , 1.3 − 4.9 + 1 ) + m a x ( 0 , 2.0 − 4.9 + 1 ) = m a x ( 0 , − 2.6 ) + m a x ( 0 , − 1.9 ) = 0 \begin{aligned} &=max(0,1.3-4.9+1)+max(0,2.0-4.9+1) \\ &=max(0,-2.6)+max(0,-1.9)\\ &=0 \end{aligned} =max(0,1.34.9+1)+max(0,2.04.9+1)=max(0,2.6)+max(0,1.9)=0

= m a x ( 0 , 2.2 − ( − 3.1 ) + 1 ) + m a x ( 0 , 2.5 − ( − 3.1 ) + 1 ) = m a x ( 0 , 5.3 ) + m a x ( 0 , 5.6 ) = 10.9 \begin{aligned} &=max(0,2.2-(-3.1)+1)+max(0,2.5-(-3.1)+1) \\ &=max(0,5.3)+max(0,5.6)\\ &=10.9 \end{aligned} =max(0,2.2(3.1)+1)+max(0,2.5(3.1)+1)=max(0,5.3)+max(0,5.6)=10.9
我们可以根据本轮损失函数的计算去判断当前分类效果的好坏。
所以损失值我们通过一下流程进行得到:

请添加图片描述

前向传播(梯度下降)

我们知道了当前的模型的性能效果,那么肯定要对模型进行更新来达到一个更佳的状态

暂略

整体架构

请添加图片描述

数据模块

我们可以使用现成的torch中帮忙封装的MNIST数据,通过datasets包可以直接进行下载,并且使用dataloader加载数据。

def data_pre():
    train_data=torchvision.datasets.MNIST(
        root='MNIST',
        train=True,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )
    test_data=torchvision.datasets.MNIST(
        root='MNIST',
        train=False,
        transform=torchvision.transforms.ToTensor(),
        download=True
    )
    train_load=DataLoader(dataset=train_data,batch_size=100,shuffle=True)
    test_load=DataLoader(dataset=test_data,batch_size=100,shuffle=True)
    return train_data, test_data

torchvision.datasets.MNIST参数含义

  • root:

存放训练和测试数据的文件根目录

  • train:(数据类型bool)

如果为True则从training创建数据集,否则从test.pt创建数据集

  • download:(数据类型bool)

如果为ture则从网络上下载数据集并放在根目录下,如果数据已经存在不会进行重复下载

  • transform:(数据类型callable)

对数据内容进行转换处理的函数,具体见torchvision.transforms中的参数设置,此处为将PIL文件转换成tensor的数据格式

DataLoader参数含义

  • dataset:(数据类型 dataset)

PyTorch中的数据集类型。

  • batch_size:(数据类型 int)

每次输入数据的行数,默认为1。PyTorch训练模型时调用数据不是一个一个进行的,而是一批一批输入用于提升效率。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。

  • shuffle:(数据类型 bool)

洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。

神经网络框架

class neuralnet():
    def __init__(self, input_nodes, hidden_nodes, output_nodes, learning_rate):
        self.inodes = input_nodes   # 输入层节点设定
        self.hnodes = hidden_nodes  # 隐藏层节点设定
        self.onodes = output_nodes  # 输出层节点设定
        self.lr = learning_rate     # 学习率设定

        # 初始化w_ih
        # 输入层与隐藏层之间的连接参数
        self.wih = (np.random.normal(0.0, pow(self.hnodes, -0.5),\
                        (self.hnodes, self.inodes)))
        # 隐藏层与输出层之间的连接参数
        self.who = (np.random.normal(0.0, pow(self.onodes,-0.5),\
                        (self.onodes,self.hnodes)))
        # 激活函数,返回sigmoid函数
        self.activation_function = lambda x:spe.expit(x)
    
    
    def train(self, inputs_list, targets_list):
        # 输入进来的二维图像数据
        inputs = np.array(inputs_list, ndmin=2).T
        # 隐藏层计算
        hidden_inputs = np.dot(self.wih, inputs)   
        # 隐藏层的输出经过sigmoid函数处理         
        hidden_outputs = self.activation_function(hidden_inputs)
        # 输出层计算
        final_inputs = np.dot(self.who, hidden_outputs)
        # 输出经过sigmoid函数处理
        final_outputs = self.activation_function(final_inputs)

        # 取得对应的标签
        targets = np.array(targets_list, ndmin=2).T

        # 计算数据预测误差,将其用于向前反馈
        output_errors = targets - final_outputs
        # 根据公式计算得到反向传播参数
        hidden_errors = np.dot(self.who.T,output_errors)

        # 根据反馈参数去修改两个权重
        self.who += self.lr * np.dot((output_errors * final_outputs *(1.0 - final_outputs)),np.transpose(hidden_outputs))  
        self.wih += self.lr * np.dot((hidden_errors * hidden_outputs * (1.0-hidden_outputs)), np.transpose(inputs))

    def query(self, inputs_list):
        # 输入进来的二维图像数据
        inputs = np.array(inputs_list, ndmin=2).T
        # 隐藏层计算
        hidden_inputs = np.dot(self.wih, inputs)   
        # 隐藏层的输出经过sigmoid函数处理         
        hidden_outputs = self.activation_function(hidden_inputs)
        # 输出层计算
        final_inputs = np.dot(self.who, hidden_outputs)
        # 输出经过sigmoid函数处理
        final_outputs = self.activation_function(final_inputs)

        return final_outputs

参考

(29条消息) [ PyTorch ] torch.utils.data.DataLoader 中文使用手册_江南蜡笔小新的博客-CSDN博客

(29条消息) 「学习笔记」torchvision.datasets.MNIST 参数解读/中文使用手册_江南蜡笔小新的博客-CSDN博客_torchvision.datasets.mnist

Python scipy.special.expit用法及代码示例 - 纯净天空 (vimsky.com)

002-深度学习数学基础(神经网络、梯度下降、损失函数) - 小小猿笔记 - 博客园 (cnblogs.com)

TensorBoard的最全使用教程:看这篇就够了 - 腾讯云开发者社区-腾讯云 (tencent.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/465034.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode刷题(9)二叉树(3)

各位朋友们,提前祝大家五一劳动节快乐啊!!!今天我为大家分享的是关于leetcode刷题二叉树相关的第三篇我文章,让我们一起来看看吧。 文章目录 1.二叉树的层序遍历题目要求做题思路代码实现 2.从前序与中序遍历序列构造二…

Authing 正式发布应用集成网关 - Authing Gateway

2023 年 2月, Authing 推出了身份领域的 PaaS化应用集成网关 - Authing Gateway 。 Authing Gateway 提供将原有应用快速集成到 Authing 身份云产品的能力,在扩充身份认证方式的同时,提高资源的安全性和数据的隐私可靠性。 01.Authing Gatew…

如何查看声卡、pcm设备以及tinyplay、tinymix、tinycap的使用

命令列表 功能命令查看当前录音进程状态dumpsys media.audio_flinger查看当前音频策略状态dumpsys media.audio_policy查看pcm节点信息cat /proc/asound/pcm查看声卡信息cat /proc/asound/cards查看声卡物理设备节点ls /dev/snd/驱动层录音命令tinycap xx.wav -D 0 -d 1 -c 2 …

【Java EE】-博客系统(前端页面)

作者:学Java的冬瓜 博客主页:☀冬瓜的主页🌙 专栏:【JavaEE】 分享: 且视他人如盏盏鬼火,大胆地去走你的道路。——史铁生《病隙碎笔》 主要内容:博客系统 登陆页面,列表页面,详情页…

OpenAI推企业版ChatGPT,英伟达造AI安全卫士

GPT现在已经进入了淘金时代。虽然全球涌现出成千上万的大模型或ChatGPT变种,但一直能挣钱的人往往是卖铲子的人。 这不,围绕暴风眼中的大模型,已经有不少企业,开始研究起了大模型的“铲子”产品,而且开源和付费两不误…

【C++】——string的功能介绍及使用

前言: 在上期,我们简单的介绍了关于 模板和STL ,今天我就带领大家学习一下关于 【string】类。本期,我们主要讲解的是关于 【string】的基本介绍以及【string】类的常用接口说明。有了以上的基本认识之后,在下期&…

全球SPD市场迎来黄金时代,中国领跑全球增长

近日,专注于前沿领域的国际咨询机构ICV发布了全球单光子探测器市场研究报告,报告分析了单光子探测器(SPD)市场,包括产品定位、下游应用、主要供应商、市场情况和未来趋势等各个方面,以进行分析和预测。 研究…

微服务 - kong安装,API网关设计(原理篇)

概述 微服务实践的第二个关键组件,微服务API网关设计,API网关是对微服务做统一的鉴权、限流、黑白名单、负载均衡等功能实现,这篇我们先来介绍Api网关的意义和安装kong/konga需要的组件。 网关的作用和意义 网关可以使得服务本身更专注自己的领域&…

Linux Ansible管理变量、管理事实、管理机密

目录 Ansible变量 变量定义范围 变量类型 定义变量并引用 事实变量与魔法变量 事实变量 魔法变量 Ansible加密 ansible-vault参数 ansible-vault举例 Ansible变量 Ansible支持利用变量来存储值,并且可以在Ansible项目的所有文件中重复使用这些值 变量可能…

浏览器缓存原理

使用 HTTP 缓存的好处:通过复用缓存资源,减少了客户端等待服务器响应的时间和网络流量,同时也能缓解服务器端的压力。可以显著的提升网站的应用性能。 HTTP 缓存策略分为两种:强制缓存、协商缓存。 强制缓存 浏览器缓存没有过期…

[pgrx开发postgresql数据库扩展]5.自定义函数与SQL组合应用

老规矩的声明: 并不是所有场景都需要(或者适合)用rust来写的,绝大部分操作数据库的功能和计算,用SQL就已经足够了! 本系列中,所有的案例,仅用于说明pgrx的能力,而并非是…

BPMN2.0 任务-用户任务

“用户任务(user task)”用于对需要人工执行的任务进行建模。当流程执行到达用户任务时,会为指派至该任务的用户或组的任务列表创建一个新任务。 用户任务用左上角有一个小用户图标的标准任务(圆角矩形)表示。 用户任务在XML中如下定义。其中id是必须属性,name是可选属性…

提高网络安全性:探索ADAudit Plus的全功能IT安全审计解决方案

网络安全一直是组织和企业需要关注的重要问题之一,因为随着企业数字化的加速和技术的不断发展,网络攻击的威胁也变得越来越严峻。因此,组织和企业需要采取措施保护其信息资产和网络安全。 ADAudit Plus是一种全功能的IT安全审计解决方案&…

2023年商票研究报告

第一章 行业概况 1.1 定义 商票是指出票人依托商业汇票系统,以数据电文形式制作的,委托付款人在指定日期无条件支付确定的金额给收款人或者持票人的票据。按承兑人的不同,商业汇票分为银行承兑汇票和商业承兑汇票(即商票&#x…

flex布局 高度没有自动撑到max-height

在做一个项目时,用到了竖向flex布局,我写了max-height: 820px, 但是到小屏幕时,只能撑到773px,解决方法是height: max-content. 但是不知道为什么只能撑到773px便撑不动了。 https://zhuanlan.zhihu.com/p/130460207 这个文档说的…

workerman开发者必须知道的几个问题

1、windows环境限制 windows系统下workerman单个进程仅支持200个连接。 windows系统下无法使用count参数设置多进程。 windows系统下无法使用status、stop、reload、restart等命令。 windows系统下无法守护进程,cmd窗口关掉后服务即停止。 windows系统下无法在一个…

目标检测之损失函数

损失函数的作用为度量神经网络预测信息与期望信息(标签)的距离,预测信息越接近期望信息,损失函数值越小。 在目标检测领域,常见的损失分为分类损失和回归损失。 L1损失 L1 Loss也称为平均绝对值误差(MAE&…

[HNCTF 2022 WEEK4]ezheap

Index 前言Checksec & IDA 前言 手把手教学,覆盖一切途中会遇到的问题。 [HNCTF 2022 WEEK4]ezheap Checksec & IDA 保护全开,但是四肢健全(四项功能 增删改查),因此是ezheap。 主要来观察函数add和show。 d…

注意力机制:基于Yolov5/Yolov7的Triplet注意力模块,即插即用,效果优于cbam、se,涨点明显

论文:https://arxiv.org/pdf/2010.03045.pdf 本文提出了可以有效解决跨维度交互的triplet attention。相较于以往的注意力方法,主要有两个优点: 1.可以忽略的计算开销 2.强调了多维交互而不降低维度的重要性,因此消除了通道和权…

信号完整性分析基础知识之传输线和反射(三):仿真和测试反射波形

使用上面反射系数的定义,可以计算来自任意阻抗的反射信号。当终端阻抗为阻性元件时,阻抗恒定,反射电压容易计算。当终端具有更复杂的阻抗行为(例如电容性或电感性终端,或两者的某种组合)时,如果…