深度学习笔记之梯度下降、反向传播与内置优化器

news2024/9/23 12:14:14

文章目录

    • 1. 梯度下降法
    • 2. 反向传播算法
    • 3. PyTorch内置的优化器
      • 3.1 SGD优化器
      • 3.2 RMSprop优化器
      • 3.3 Adam优化器

1. 梯度下降法

笔者往期的机器学习笔记: 机器学习之梯度下降算法

梯度下降法是一种致力于找到函数极值点的算法。

所谓“训练”或“学习”就是改进模型参数,以便通过大量训练步骤将损失最小化的过程。

训练过程就是求解损失函数最小值的过程,在此过程中将梯度下降法应用于寻找损失函数的极值点便构成了依据输入数据的模型学习。

梯度下降法的思路很简单,就是沿着损失函数下降最快的方向改变模型参数,直到到达最低点。在此过程中,需要求解模型参数的梯度,梯度是一种数学运算,它与导数类似,是微积分中一个很重要的概念,在单变量的函数中,梯度其实就是函数的微分,代表函数在某个给定点的切线的斜率;在多变量函数中,梯度是一个向量,向量有方向,梯度的方向就指出了函数在给定点的上升最快的方向。因此梯度可应用于输入为一个向量、输出为一个标量的函数,损失函数就属于这种类型。

因为梯度的方向就是损失函数变化最快的方向,所以当参数沿着梯度相反的方向改变时,就能让函数值下降得最快。所以,我们的训练就是重复利用这个方法反复求取梯度、修改模型参数,直到最后达到损失函数的最小值。

梯度的输出是一个由若干偏导数构成的向量,它的每个分量对应于函数对输入向量的相应分量的偏导,在求偏导时,可将当前变量以外的所有变量视为常数,然后运用单变量求导法则。有一点需要注意,当提及损失函数的输入变量时,指的是模型的参数(权重和偏置),而非实际数据集的输入特征。一旦给定数据集和所要使用的特征类型,这些输入特征便固定下来,无法进行优化。

我们所计算的偏导数是损失函数相对于模型中的每个参数而言的,为了更简洁地解释损失优化过程,我们绘制了一个假设的损失函数曲线,如下图所示:

在这里插入图片描述

假设上述曲线对应于损失函数曲线,箭头所在的点代表模型参数的当前值,即现在所在的位置。我们需要沿着梯度的反方向移动,在上图中用箭头表示,因此,为了减小损失,需要沿着箭头向左移动。此外,箭头的长度概念化地表示了如果在对应的方向移动,损失能够下降多少。在训练过程中,我们沿着箭头的方向移动,再次计算梯度,并重复这个过程,直到梯度的模为0,将到达损失函数的极小值点。这正是训练的目标,这个过程的图形化表示可参考下图:

在这里插入图片描述

权重的优化过程可以用公式表示如下:
W e i g h t s i + 1 = W e i g h t s i − l r × ∇ W e i g h t s i Weightsi+1=Weightsi-lr×∇Weightsi Weightsi+1=Weightsilr×Weightsi
式中,lr表示学习速率(learning rate),用来对梯度进行缩放,新的权重等于当前权重减去权重的梯度乘以学习速率,学习速率并不是模型需要推断的值,它是一个超参数(hyperparameter)。

所谓超参数是指那些需要我们手工配置的参数,需要为它指定正确的值。如果学习速率太小,则找到损失函数极小值点时可能需要许多轮迭代,训练过程会非常慢;如果学习速率太大,则算法可能会“跳过”极小值点并且周期性地来回“跳跃”而且永远无法找到极小值点,这种现象被称为“超调”,如下图所示:在这里插入图片描述

因此,在设定学习速率的值时,我们既希望学习速率足够大,能够快速地进行梯度下降学习,又希望学习速率不能太大,甚至越过最低点,导致模型的损失函数在最低点附近来回跳跃。

学习速率是在训练过程中需要特别注意的一个超参数,在实际应用中,过大的学习速率很容易引起梯度的震荡,导致训练失败。当我们观察损失变化曲线时,如果发现损失没有下降趋势,反而在震荡,首先应该想到调整学习速率。

在实践中,可在模型训练的初期使用较大的学习速率,在训练接近极值点时使用较小的学习速率,从而逼近损失函数极小值点

大家可能会有疑问,为什么要使用学习速率对梯度进行缩放呢?如果不设置学习速率对梯度进行缩放,那么梯度的变化几乎决定于当前训练批次,而忽略了所有之前的训练样本,这显然是不合适的,我们不希望单个样本或者单个批次的样本主导模型学习,而是希望模型能从所有的样本中学习,这就需要通过学习速率对梯度进行缩放,模型在训练循环中,从每一个批次中都得到学习,同时模型参数的变化又不是由单个批次主导的,这样训练出的模型泛化能力才足够好。

除了学习速率,还有一些其他问题也会影响梯度下降算法的性能。例如,损失函数的局部极值点。我们再次回到之前的损失函数曲线示例,如果权值的初值靠近损失函数右侧的“谷底”,则该算法的工作过程如下图所示:

在这里插入图片描述

如果权重的起始点落在了上图中第一个箭头位置,随着训练,模型参数将沿着梯度下降的方向移动,直到到达最低点并终止迭代,因为模型认为已经找到了一个极值点,在这个极值点,梯度的模为0,但显然这不是全局最小极值点。梯度下降法无法区分迭代终止时到底是到达了全局最小点还是局部极小点,后者往往只在一个很小的邻域内为最优。

在深度学习发展的今天,局部极值问题已经不再是研究的重点,因为我们可通过更好的权值随机初始化来改善局部极值点问题。如果权值随机落在了左边,模型就可以找到全局最小值点。通过使用随机值初始化权重,可以增加从靠近全局最优点附近开始下降的机会。从实际训练来看,损失函数总能下降到全局最小点。

2. 反向传播算法

反向传播(back propagation,BP)算法是一种高效计算数据流图中梯度的算法,在多层神经网络的训练中具有举足轻重的地位。

多层神经网络的参数多,梯度计算比较复杂,反向传播算法可以用来解决深层参数训练问题。

神经网络训练是通过误差反向传播实现的,通过这种方法,根据前一次运行获得的错误率对神经网络的权值进行微调。

详细来说,计算分为两个过程,首先是前馈,也就是输入经过网络得到预测输出,并用预测输出与真实值计算损失;然后反向传播算法根据损失进行后向传递,计算网络中模型的每一个权值的梯度;最后根据梯度调整模型的权值。正确地采用这种方法可以降低错误率,提高模型的可靠性。

总结这个监督学习过程,就是试图找到一个将输入数据映射到正确输出的函数,从而较好地实现某个特定的功能(如分类就是将猫的图片映射到猫这个标签)。

具体到前馈神经网络,需要实现的目的很简单,就是希望损失函数达到最小值,因为只有这样,实际输出和预期输出的差值(损失)才最小。

那么,如何从众多网络参数(神经元之间的连接权值和偏置)中找到最佳的参数使得损失最小呢?这就应用到了梯度下降的方法找极值,其中最关键的是,如何计算梯度呢?那就用到了反向传播算法。

在多层神经网络中要计算每一层的梯度,就需要从输出层开始逐层计算,反向传播算法利用链式法则,避开了这种逐层计算的冗余。

神经网络中每一层的导数都是后一层的导数与前一层输出之积,这正是链式法则的奇妙之处,误差反向传播算法利用的正是这一特点避免了计算冗余。

前馈时,从输入开始逐一计算每个隐藏层的输出,直到输出层;然后开始计算导数,并从输出层经各隐藏层逐一反向传播。为了减少计算量,还需对所有已完成计算的元素进行复用,这便是反向传播算法名称的由来。

3. PyTorch内置的优化器

PyTorch框架内置了自动求导模块和优化器,自动求导模块可以根据损失函数对模型的参数进行求梯度运算,这个过程中使用了反向传播算法,而优化器会根据计算得到的梯度,利用一些策略去更新模型的参数,最终使得损失函数下降。

可以看出,优化器的功能就是管理和更新模型中可学习参数的值,并通过一次次训练使得模型的输出更接近真实标签。

下面重点讲解PyTorch中的优化器,在PyTorch中,torch.optim模块实现了各种优化算法,最常用的优化方法已经得到内置支持,可以直接调用。

3.1 SGD优化器

SGD(stochastic gradient descent,torch.optim.SGD)是最为简单的优化器,它所实现的就是前面介绍的梯度下降算法,当然在梯度下降过程中要使用学习速率缩放梯度。SGD的缺点在于收敛速度慢,可能在鞍点处震荡,并且,如何合理地选择学习速率是SGD的一大难点。针对SGD可能在鞍点处震荡这个缺点,可以为其引入动量Momentum加速SGD在正确方向的下降并抑制震荡,具体使用中可通过参数momentum设置。

3.2 RMSprop优化器

RMSprop优化器(torch.optim.RMSprop)是对AdaGrad算法的一种改进。在原始的优化算法中,目标函数自变量的每一个元素在相同时间步都使用同一个学习速率来自我迭代,但是统一的学习速率难以适应所有维度变化不同的问题,因此RMSprop根据自变量在每个维度的梯度值的大小来调整各个维度上的学习速率,并且增加了一个衰减系数来控制历史信息的获取多少。也就是说,设置全局学习速率之后,每次通过全局学习速率逐参数除以经过衰减系数控制的历史梯度平方和的平方根,使得每个参数的学习速率不同。

经验上,RMSprop被证明是有效且实用的深度学习网络优化算法,特别是针对序列问题的训练,使用RMSprop会有不错的效果。

3.3 Adam优化器

Adam优化器(torch.optim.Adam)可以认为是RMSprop和Momentum的结合。Adam不仅对二阶动量使用指数移动平均,还对一阶动量也使用指数移动平均计算。这就相当于在RMSProp基础上对小批量随机梯度也做了指数加权移动平均。

Adam主要包含以下几个显著的优点。

  • 实现简单,计算高效,对内存需求少。

  • 参数的更新不受梯度的伸缩变换影响。

  • 超参数具有很好的解释性,且通常无须调整或仅需很少的微调。

  • 更新的步长能够被限制在大致的范围内(初始学习速率)。

  • 能自动调整学习速率。

  • 适合应用于大规模的数据及参数的场景。

  • 适用于不稳定目标函数。

  • 适用于梯度稀疏或梯度存在很大噪声的问题。

    工程上,Adam是目前最常用的优化器,在很多情况下,作为默认优化器都可以获得不错的效果。因此在以后的学习和训练中,可以优先选择使用Adam优化器。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/505662.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年跨境电商必读:海外网红营销的4大雷区及应对策略

随着跨境电商的快速发展和海外网红的普及,海外网红营销已成为越来越多跨境电商企业推广品牌的必备手段。然而,在进行海外网红营销时,企业需要注意一系列雷区,以确保营销的有效性和可持续性。本文Nox聚星将和大家探讨2023年跨境电商…

AI自动写文章_免费在线原创文章生成器

自动写文章生成器 自动写文章生成器是一种利用人工智能和自然语言处理技术,帮助用户快速生成文章的工具。该软件可以根据用户的需求和选择,自动生成符合要求的文章,无需手动编写和修改。 自动写文章生成器的主要功能包括以下几个方面&#…

(三)打造华丽的即时通信系统主界面,让你的聊天体验更有质感

文章目录 一、引言1、即时通信系统的基本概念和应用场景2、Qt框架在实现即时通信系统中的应用 二、主界面设计2.1 界面设计的基本要求2.2 主界面的设计 三、通信功能实现3.1 通信协议的选择3.1.1 TCP协议和UDP协议的优缺点比较3.1.2选择何种协议进行即时通信系统的实现 3.2 通信…

单片机的电子秤方案设计

电子秤是一种利用电子技术实现重量计量的设备,广泛应用于商业、工业、医疗、科学研究等领域。电子秤是一种高精度的计重装置,不仅精度高,而且使用方便、稳定可靠。下面,我们从结构设计、工作原理、功能参数、产品种类四个方面来介…

MyBatis基础知识点总结

MyBatis了解 MyBatis 是什么? MyBatis 是支持定制化 SQL、存储过程以及高级映射的优秀的持久层框架 MyBatis 避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集 MyBatis 可以使用简单的XML或注解用于配置和原始映射,将接口和Java的 POJO&#x…

Canvas 2D详解

在我书的第六章中有一个关于MNIST手写数字的例子&#xff0c;当数据集加载完成之后&#xff0c;用户可以在<canvas/>上输入手写数字&#xff0c;点击「预测」按钮之后&#xff0c;浏览器会弹出经模型预测之后的结果&#xff1b;在我书的第九章和第十章中&#xff0c;分别…

2023年宜昌市中等职业学校技能大赛 “网络搭建与应用”竞赛题-2

2023年宜昌市中等职业学校技能大赛 “网络搭建与应用”竞赛题 一、竞赛内容分布 “网络搭建及应用”竞赛共分二个部分&#xff0c;其中&#xff1a; 第一部分&#xff1a;企业网络搭建部署项目&#xff0c;占总分的比例为50%&#xff1b; 第二部分&#xff1a;企业网络服…

第十四届蓝桥杯大赛软件赛省赛(Java 大学A组)

蓝桥杯 2023年省赛真题 Java 大学A组 试题 A: 特殊日期  试题 B: 与或异或 把填空挂上跟大伙对对答案&#xff0c;先把C/C B组的做了。 试题 A: 特殊日期 本题总分&#xff1a;5 分 【问题描述】 记一个日期为 y y \small yy yy 年 m m \small mm mm 月 d d \small dd dd 日…

链表的初步认识

什么是链表&#xff1f;链表是一种物理存储结构上非连续存储结构&#xff0c;数据元素的逻辑顺序是通过链表中的引用链接次序实现的 。 就如现实中的火车或铁链一般&#xff0c;环环相扣。当我们到达一个节点时&#xff0c;就可以通过这个节点找到下一个节点。链表与顺序表一样…

【服务器数据恢复】EXT4文件系统下KVM虚拟机数据恢复案例

服务器数据恢复环境&#xff1a; 服务器采用的Linux操作系统EXT4文件系统&#xff1b; 服务器中有3台KVM虚拟机&#xff1a;一台运行Mysql数据库&#xff0c;一台存放数据库备份&#xff0c;一台存放程序代码文件&#xff1b; 每台虚拟机包含一个qcow2格式的磁盘文件和一个raw格…

联合发版+主题演讲|GBASE南大通用亮相鲲鹏开发者峰会2023

5月6-7日&#xff0c;以“创未来 享非凡”为主题的鲲鹏开发者峰会2023在东莞松山湖举办&#xff0c;旨在打造生态伙伴和开发者学习、共享、交流的平台&#xff0c;帮助开发者深入了解鲲鹏、昇腾全栈技术&#xff0c;加速行业技术、产品和解决方案的创新。行业技术领袖、产业技术…

Apache FtpServer在Windows上使用以及SpringBoot中集成apache ftpserver实现Ftp 服务端搭建

场景 Apache Ftpserver Apache FtpServer是100&#xff05;纯Java FTP服务器。它被设计为基于当前可用的开放协议的完整且 可移植的FTP服务器引擎解决方案。FtpServer可以作为Windows服务或Unix / Linux守护程序独立运行&#xff0c; 也可以嵌入Java应用程序中。我们还提供…

【图】邻接表

目录 无向图的邻接表 链表&#xff08;存相邻顶点下标&#xff09;的类 数组里放的顶点 邻接表&#xff08;操作&#xff09; 构造和析构&#xff08;创建销毁邻接表&#xff09; 插入顶点 插入边 获取下标 插v1、v2之间的边 删除顶点 删除边 输出&#xff1a; 其他…

多种采购方式下,数智化招采系统解决方案(实例)

广发证券成立于1991年&#xff0c;是国内首批综合类证券公司&#xff0c;先后于2010年和2015年在深圳证券交易所及香港联合交易所主板上市。 多年来&#xff0c;广发证券在竞争激烈、复杂多变的行业环境中努力开拓、锐意进取&#xff0c;以卓越的经营业绩、持续完善的全面风险…

Node.js对ES6 及更高版本的支持

目录 1、简介 2、默认情况下什么特性随着 Node.js 一起发布&#xff1f; 3、有哪些特性在开发中&#xff1f; 4、移除这个标记&#xff08;--harmony&#xff09;吗 5、Node.js 对应 V8 引擎 1、简介 Node.js 是针对 V8 引擎构建的。通过与此引擎的最新版本保持同步&…

PMP课堂模拟题目及解析(第5期)

41. 项目的混凝土供应商通知项目经理&#xff0c;材料将比预定时间晚三个星期交付。项目经理更新了进度计划并通知项目团队。在这种情况下&#xff0c;哪种合同类型承担的风 险最小&#xff1f; A. 总价加激励费用合同。 B. 总价加经济价格调整合同。 C. 工料合同。 D. 固…

matlab学习笔记

一、序言 1. 图像的输入输出和显示 fimread("test.png"); frgb2gray(f);%rgb图像转化为灰度图像 imshow(f); imwrite(f,"result.jpg","quality",50);%50代表jpg形式压缩质量0-1002. matlab支持的四种图像类别 灰度级图像(Gray-scale images) …

类和对象【C++】【中篇】

目录 一、类的6个默认成员函数 1、构造函数 2、析构函数 3、拷贝构造函数 4、赋值重载函数 二、赋值运算符重载 一、类的6个默认成员函数 注意&#xff1a;默认成员函数不能在类外面定义成全局函数。因为类里没有的话会自动生成&#xff0c;就会产生冲突。 1、构造函数…

k8s采用ansible安装

一、准备工作 测试服务器 服务器配置进程功能备注192.168.0.189CPU:4核 内存32Gansibleansible一键安装k8s192.168.0.141CPU&#xff1a;12核 内存&#xff1a;10Gdocker&#xff0c;kube-apiserver&#xff0c;etcd&#xff0c;kube-scheduler&#xff0c;kube-controller-m…

产品经理 - 原型图设计软件

原型图设计软件哪个好用&#xff1f;6款好用软件推荐&#xff01; - 知乎 摩客, 墨刀 2014 即时设计是一款支持在线协作的专业级 UI 设计工具&#xff0c;用户数已突破230万&#xff1b; 2021年 5,000万(美元) 国外 axure 老牌 如果有进一步模拟的必要&#xff0c;再换用Ad…