【AI知识点】机器学习中的常用优化算法(梯度下降、SGD、Adam等)

news2024/10/10 0:33:42

更多AI知识点总结见我的专栏:【AI知识点】
AI论文精读、项目和一些个人思考见我另一专栏:【AI修炼之路】
有什么问题、批评和建议都非常欢迎交流,三人行必有我师焉😁


1. 什么是优化算法?

在机器学习中优化算法(Optimization Algorithm) 的任务是找到模型参数(如权重、偏置等),使得损失函数(例如均方误差、交叉熵等)最小化。损失函数度量的是模型预测值与真实标签之间的误差。优化算法通过不断调整模型的参数,使损失函数达到全局或局部最小值。

在神经网络中,优化算法需要通过反向传播(Backpropagation)计算每个参数对损失函数的导数(即梯度),并根据这些梯度更新模型的参数。


2. 基于梯度的优化算法

这些算法是深度学习中最常用的优化方法,通过计算梯度来找到损失函数最小的方向。

a. 梯度下降(Gradient Descent)

梯度下降是最基本的优化算法,核心思想是:朝着使损失函数减少的方向更新参数,直到达到最小值。

  • 更新规则
    θ = θ − α ⋅ ∇ θ J ( θ ) \theta = \theta - \alpha \cdot \nabla_\theta J(\theta) θ=θαθJ(θ)
    其中:
    • θ \theta θ 是模型的参数。
    • α \alpha α 是学习率,控制每次更新的步长。
    • ∇ θ J ( θ ) \nabla_\theta J(\theta) θJ(θ) 是损失函数 J ( θ ) J(\theta) J(θ) 对参数 θ \theta θ 的梯度。

b. 随机梯度下降(Stochastic Gradient Descent, SGD)

梯度下降的一个问题是,当数据集很大时,计算所有样本的梯度会很耗时。随机梯度下降(SGD) 是对梯度下降的改进,每次迭代只使用一个数据点来计算梯度,从而大大加快了参数更新。

c. 小批量梯度下降(Mini-batch Gradient Descent)

这是梯度下降和随机梯度下降的折中版本。它通过对一小部分数据(称为mini-batch)进行梯度计算和更新,这样既加快了计算速度,又保持了一定的稳定性。

d. 动量法(Momentum)

SGD 更新参数时每次依赖于当前梯度的方向,但有时可能会在方向上震荡。动量法通过加入“动量”项,积累过去几次梯度的方向,使得优化算法能够更快速地朝着最优解的方向前进。

  • 更新规则
    v t = β v t − 1 + α ∇ θ J ( θ ) v_t = \beta v_{t-1} + \alpha \nabla_\theta J(\theta) vt=βvt1+αθJ(θ)
    θ = θ − v t \theta = \theta - v_t θ=θvt
    其中, β \beta β 是动量项的系数。

e. RMSProp

RMSProp 是另一种改进的优化算法,它对每个参数都使用不同的学习率,通过对每个参数的梯度平方进行平滑加权平均,使得参数的更新步长更加合适。

  • 更新规则
    E [ ∇ θ 2 J ( θ ) ] t = β E [ ∇ θ 2 J ( θ ) ] t − 1 + ( 1 − β ) ∇ θ 2 J ( θ ) E[\nabla_\theta^2 J(\theta)]_t = \beta E[\nabla_\theta^2 J(\theta)]_{t-1} + (1 - \beta) \nabla_\theta^2 J(\theta) E[θ2J(θ)]t=βE[θ2J(θ)]t1+(1β)θ2J(θ)
    θ = θ − α E [ ∇ θ 2 J ( θ ) ] t + ϵ ∇ θ J ( θ ) \theta = \theta - \frac{\alpha}{\sqrt{E[\nabla_\theta^2 J(\theta)]_t + \epsilon}} \nabla_\theta J(\theta) θ=θE[θ2J(θ)]t+ϵ αθJ(θ)
    其中, ϵ \epsilon ϵ 是一个很小的数,用于避免除以零。

f. Adam(Adaptive Moment Estimation)

Adam 是目前深度学习中最常用的优化算法之一,它结合了动量法RMSProp的优点。Adam 同时对一阶和二阶矩进行估计,能够自适应地调整每个参数的学习率。

  • 更新规则
    Adam 分别维护了两个动量变量:

    • 一阶动量(梯度的加权平均): m t m_t mt
    • 二阶动量(梯度平方的加权平均): v t v_t vt

    m t = β 1 m t − 1 + ( 1 − β 1 ) ∇ θ J ( θ ) m_t = \beta_1 m_{t-1} + (1 - \beta_1) \nabla_\theta J(\theta) mt=β1mt1+(1β1)θJ(θ)
    v t = β 2 v t − 1 + ( 1 − β 2 ) ∇ θ 2 J ( θ ) v_t = \beta_2 v_{t-1} + (1 - \beta_2) \nabla_\theta^2 J(\theta) vt=β2vt1+(1β2)θ2J(θ)
    然后对动量进行偏差校正:
    m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t} m^t=1β1tmt,v^t=1β2tvt
    最终更新参数:
    θ = θ − α m ^ t v ^ t + ϵ \theta = \theta - \alpha \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} θ=θαv^t +ϵm^t


3. 梯度下降的图示

图片来源:https://mlpills.dev/machine-learning/gradient-descent/

这张图形象地说明了梯度下降的工作原理:从一个随机的初始参数开始,经过多次迭代更新,逐步逼近最低的损失值,最终找到最佳的模型参数。

  1. 坐标轴

    • 横轴(w):表示模型的参数(权重),它是通过优化调整的变量。
    • 纵轴(Cost):表示模型的成本或损失值,即模型预测与实际结果之间的误差。
  2. 随机初始值

    • 图中左侧标注的“Random initial value”表示算法开始时模型参数的随机初始值。优化过程从这个点开始。
  3. 学习步骤

    • 图中多个蓝色圆点表示算法在每次迭代中的参数值。每个点都对应一个特定的成本值。
    • “Learning step”表示每次迭代中,算法根据当前的梯度(导数)调整参数,以降低成本。每次调整的幅度和方向由学习率决定。
  4. 最小值

    • 图中黄色圆点标识了成本函数的最小值,表示在这个参数值下,模型的预测效果最好,损失最小。
  5. 下降路径

    • 蓝色圆点之间的连接线展示了模型在参数空间中逐步接近最小值的过程。这条路径表明,随着迭代的进行,模型参数不断调整,成本值逐渐降低。

4. 局部最优解和全局最优解

在复杂的损失函数中,可能会存在多个局部最优解。优化算法的目标是找到全局最优解,即损失函数的全局最小值。然而,梯度下降类算法可能会陷入局部最优解,因此一些改进的算法(如动量法、Adam)引入了额外的策略来帮助模型跳出局部最优解。

  • 局部最优解:损失函数的一个小范围内的最小值,但不是全局最小值。
  • 全局最优解:整个损失函数范围内的最小值。

图解:

图片来源:https://easyai.tech/en/ai-definition/gradient-descent/#google_vignette

这张图展示了梯度下降(Gradient Descent) 的概念。图中呈现的三维曲面代表了一个目标函数,通常是损失函数,反映了模型参数与损失之间的关系。黑色箭头表示梯度的方向。不同的点代表不同的参数组合,曲面的高低则表示损失值的大小。其中最低的凹点就是全局最优解,而不是最低点的其他凹点则代表各种局部最优解


5. 优化算法的比较

算法优点缺点
梯度下降简单易懂,适合小规模数据集。计算量大,尤其是大数据集时速度慢。
SGD快速更新参数,适合大规模数据集。收敛不稳定,路径波动大,需要调节学习率。
动量法减少梯度震荡,加快收敛。需要调整动量参数 β \beta β,对不同问题敏感。
RMSProp适应性学习率,避免步长过大或过小,适合深度网络。需要调整超参数 β \beta β,在某些任务上表现不稳定。
Adam自适应学习率,结合动量和 RMSProp 的优点,广泛用于深度学习。需要调整较多的超参数,对学习率敏感,可能导致局部最优解。

三种 Gradient Descent 的形象图示:

在这里插入图片描述
图片来源:https://www.nomidl.com/machine-learning/what-is-gradient-descent-batch-gradient-descent-stochastic-gradient-descent-mini-batch-gradient-descent/


6. 如何选择合适的优化算法

选择优化算法时,需要根据具体的任务需求、数据特点和模型架构来选择合适的算法。以下是一些常见的选择依据:

a. 数据规模

  • 小规模数据集:可以使用标准的梯度下降,因为计算量不大。
  • 大规模数据集:通常使用随机梯度下降(SGD)小批量梯度下降(Mini-batch Gradient Descent)。这些算法对大规模数据更有效。

b. 模型复杂性

  • 浅层模型:如逻辑回归、线性回归等浅层模型,使用SGD动量法 可以取得良好的效果。
  • 深层神经网络:深度学习通常使用AdamRMSProp,它们能够自动调整学习率,适应深度网络中的复杂性和高维性。

c. 收敛速度

  • 如果需要快速收敛,并且可以承受一定的波动性,可以使用SGD动量法
  • 如果需要更加平稳的收敛过程,建议使用AdamRMSProp,这些算法通过自适应调整学习率来保证收敛的平稳性。

7. 超参数调整的重要性

所有的优化算法都有一些关键的超参数,如学习率( α \alpha α)、动量系数( β \beta β)、RMSProp 和 Adam 中的动量参数和二阶动量参数等。这些超参数的选择对于模型性能的影响非常大。

a. 学习率

  • 学习率决定了每次参数更新的步长。学习率太大,可能会导致跳过最优解;学习率太小,模型收敛速度太慢,甚至可能陷入局部最优解。

b. 动量系数

  • 动量系数(通常记为 β \beta β)用于在动量法和 Adam 中,它决定了过去梯度的影响。动量系数过大会导致优化过程“过冲”,而动量系数过小则无法有效加速收敛。

c. 自适应学习率

  • 像 Adam 和 RMSProp 这样的优化算法会根据每个参数的梯度历史自动调整学习率。虽然这些算法自适应学习率,但仍然需要仔细调整初始学习率和其他超参数,才能获得良好的性能。

8. 一些常见的优化技巧

在深度学习和机器学习中,优化算法的性能很大程度上取决于使用者是否有效地结合了各种优化技巧。以下是一些常见的优化技巧:

a. 学习率衰减

  • 在训练的早期使用较大的学习率来加快收敛速度,然后随着训练进行逐渐减小学习率,帮助模型在最优解附近进行更细致的搜索。这可以避免模型在靠近最优解时仍然使用较大的步长导致震荡。

b. 提前停止(Early Stopping)

  • 提前停止是一种防止过拟合的技巧,它会监控模型在验证集上的表现,当验证集上的损失不再降低时,就提前停止训练。这避免了模型过度拟合训练数据,并可以加快训练过程。

c. 批归一化(Batch Normalization)

  • 批归一化在每一层对输入数据进行归一化,使得神经网络中各层的输入数据分布更加稳定,能够加速训练并提高模型的收敛速度。

d. 梯度裁剪(Gradient Clipping)

  • 当模型中梯度过大时,可能会导致梯度爆炸问题,尤其是在深层神经网络或循环神经网络(RNN)中。梯度裁剪将梯度限制在一个固定的范围内,从而防止梯度过大导致不稳定的训练。

9. 总结

优化算法 是机器学习和深度学习的核心工具,它通过调整模型参数,使损失函数最小化,从而提高模型的性能。不同的优化算法适用于不同的数据规模、模型复杂度和任务类型,常见的算法包括梯度下降、动量法、Adam等。选择合适的优化算法和调整超参数是成功训练机器学习模型的关键。结合学习率衰减、提前停止、批归一化等优化技巧,模型的训练效率和效果可以显著提高。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2200524.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

决策树随机森林-笔记

决策树 1. 什么是决策树? 决策树是一种基于树结构的监督学习算法,适用于分类和回归任务。 根据数据集构建一棵树(二叉树或多叉树)。 先选哪个属性作为向下分裂的依据(越接近根节点越关键)?…

【动态规划-最长递增子序列(LIS)】【hard】力扣1671. 得到山形数组的最少删除次数

我们定义 arr 是 山形数组 当且仅当它满足&#xff1a; arr.length > 3 存在某个下标 i &#xff08;从 0 开始&#xff09; 满足 0 < i < arr.length - 1 且&#xff1a; arr[0] < arr[1] < … < arr[i - 1] < arr[i] arr[i] > arr[i 1] > … &g…

【hot100-java】二叉搜索树中第 K 小的元素

二叉树 二叉搜索树的中序遍历是递增序列。 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode lef…

【C++】面向对象之继承

不要否定过去&#xff0c;也不要用过去牵扯未来。不是因为有希望才去努力&#xff0c;而是努力了&#xff0c;才能看到希望。&#x1f493;&#x1f493;&#x1f493; 目录 ✨说在前面 &#x1f34b;知识点一&#xff1a;继承的概念及定义 •&#x1f330;1.继承的概念 •&…

ECCV24高分论文:MVSplat稀疏视图下的高效的前馈三维重建模型

目录 一、概述 二、相关工作 1、稀疏视角场景重建 2、前馈NeRF 3、前馈3DGS 4、多视角立体视觉 三、MVSplat 1、多视角Transformer 一、概述 本文提出了一个MVSplat高效的前馈三维重建模型&#xff0c;可以从稀疏的多视图图像中预测3D高斯分布&#xff0c;并且相较于p…

三角形面积 python

题目&#xff1a; 计算三角形面积 代码&#xff1a; a int(input("请输入三角形的第一个边长&#xff1a;")) b int(input("请输入三角形的第二个边长&#xff1a;")) c int(input("请输入三角形的第三个边长&#xff1a;")) s (abc) / 2 #…

我谈均值平滑模板——给均值平滑模板上升理论高度

均值平滑&#xff08;Mean Smoothing&#xff09;&#xff0c;也称为盒状滤波&#xff08;Box Filter&#xff09;&#xff0c;通过计算一个像素及其周围像素的平均值来替换该像素的原始值&#xff0c;从而达到平滑图像的效果。 均值平滑通常使用一个模板&#xff08;或称为卷…

ISCC认证是什么?ISCC认证的申请流程有哪些注意事项?

ISCC认证&#xff0c;即国际可持续发展与碳认证&#xff08;International Sustainability & Carbon Certification&#xff09;&#xff0c;是一个全球通用的可持续发展认证体系。以下是对ISCC认证的详细介绍&#xff1a; 一、起源与背景 ISCC认证体系起源于德国&#x…

如何使用pymysql和psycopg2执行SQL语句

在Python中&#xff0c;pymysql和psycopg2是两个非常流行的库&#xff0c;用于与MySQL和PostgreSQL数据库进行交互。本文将详细介绍如何使用这两个库来执行SQL查询、插入、更新和删除操作。 1. 准备工作 首先&#xff0c;确保已经安装了pymysql和psycopg2库。如果尚未安装&a…

Linux驱动---光电开关、火焰传感器、人体红外传感器

文章目录 一、电路连接二、设备树三、驱动代码 一、电路连接 人体红外 – PF12 检测到人体时会产生一个上升沿 光电开关 – PE15 有遮挡物时会产生一个上升沿 火焰传感器 – PF5 有火焰时会产生一个上升沿 二、设备树 /{ //人体红外PF12human{ compatible "zyx,huma…

电池大师 2.3.9 | 专业电池管理,延长寿命优化性能

Battery Guru 显示电池使用情况信息&#xff0c;测量电池容量&#xff08;mAh&#xff09;&#xff0c;并通过有用技巧帮助用户改变充电习惯&#xff0c;延长电池寿命。支持显示电池健康状况&#xff0c;优化电池性能。 大小&#xff1a;9.6M 百度网盘&#xff1a;https://pan…

数据库软题7-数据库设计

一、概念结构设计 题1-ER图的属性分类 题2-局部ER图的冲突分类 1.命名冲突 命名冲突有同名异义&#xff0c;异名同义2.结构冲突 结构冲突分为&#xff1a;统一实体不同属性&#xff0c;同一对象在不同关系里可能为属性或者实体 教师其实就是职工&#xff0c;他们有不同的属性…

基于Arduino的超声波测距模块HC-SR04

一. HC-SR04超声波模块简介 HC-SR04超声波模块是一种常用的测距模块&#xff0c;通过不断检测超声波发射后遇到障碍物所反射的回波&#xff0c;从而测出发射和接收回波的时间差&#xff0c;并据此求出距离。它主要由两个‌压电陶瓷超声传感器和一个外围信号处理电路构成&#…

重生之我在代码随想录刷算法第十九天 | 第77题. 组合、216.组合总和III、 17.电话号码的字母组合

参考文献链接&#xff1a;代码随想录 本人代码是Java版本的&#xff0c;如有别的版本需要请上代码随想录网站查看。 第77题. 组合 力扣题目链接 解题思路 这道题目乍一看可以用暴力解法解决&#xff0c;但如果k的数量增加那就需要套特别多的循环&#xff0c;所以这种组合类…

植物大战僵尸修改器-MFC

创建项目 创建mfc应用 基于对话框 打开资源视图下的 IDD_MFCAPPLICTION2_DIALOG 限制对话框大小 将属性中Border的值改为对话框外框 删除对话框中原有的控件 属性-外观-Caption 设置对话框标题 工具箱中拖放一个按钮 修改按钮名称 将按钮ID改为IDC_COURSE 在MFCApplication2…

django(二):定义第一个函数及url介绍

1.定义index函数 """ django里的第一个函数必须是request,不写会报错 """def index(request):return HttpResponse("Hello, world. Youre at the index of djangoProject.")注意&#xff01; ①.index函数里的形参必须为request ②.r…

STM32输入捕获模式详解(上篇):原理、测频法与测周法

1. 前言 在嵌入式系统的开发过程中&#xff0c;常常需要对外部信号进行精确的时间测量&#xff0c;如测量脉冲信号的周期、频率以及占空比等。STM32系列微控制器提供了丰富的定时器资源&#xff0c;其中的输入捕获&#xff08;Input Capture, IC&#xff09;模式能实现对信号的…

【测试】BUG篇——BUG

bug的概念 定义&#xff1a;⼀个计算机bug指在计算机程序中存在的⼀个错误(error)、缺陷(flaw)、疏忽(mistake)或者故障(fault)&#xff0c;这些bug使程序⽆法正确的运⾏。Bug产⽣于程序的源代码或者程序设计阶段的疏忽或者错误。 准确的来说&#xff1a; 当且仅当规格说明&am…

网站集群批量管理-Ansible(ad-hoc)

1. 概述 1. 自动化运维: 批量管理,批量分发,批量执行,维护 2. 无客户端,基于ssh进行管理与维护 2. 环境准备 环境主机ansible10.0.0.7(管理节点)nfs01 10.0.0.31(被管理节点)backup10.0.0.41(被管理节点) 2.1 创建密钥认证 安装sshpass yum install -y sshpass #!/bin/bash ##…

SpringBoot整合MyBatis记录

整体目录结构 创建数据库 创建一个MySQL的表&#xff0c;表名是student。 create table student (id int auto_increment comment 唯一标识idprimary key,name varchar(30) not null comment 姓名,age int not null comment 年龄 ) 插入一条数据记录到数据库当中去…