12.梯度下降法的具体解析——举足轻重的模型优化算法

news2024/10/4 3:08:47

引言

梯度下降法(Gradient Descent)是一种广泛应用于机器学习领域的基本优化算法,它通过迭代地调整模型参数,最小化损失函数以求得到模型最优解。

通过阅读本篇博客,你可以:

1.知晓梯度下降法的具体流程

2.掌握不同梯度下降法的区别

一、梯度下降法的流程

梯度下降法的流程通常分为以下四个步骤。

1.初始化模型参数

初始化模型参数其实就是random随机一个初始的 \theta (一组 W_{0},...,W_{n})。这样我们就可以得到上图中的 Starting Point(开始点)

2.计算当下参数的梯度

计算模型参数的梯度,其实就是对于当前损失函数所在位置进行求偏导,公式:

gradient_{j} = \frac{\partial J(\theta)}{\partial \theta_{j}}

公式推导 J(\theta) 是损失函数,\theta_{j} 是样本中某个特征维度 x_{j} 对应的权值系数,也可以写成 W_{j} 。对于多元线性回归来说,损失函数 J(\theta) = \frac{1}{2}(h_{\theta}x - y)^{2} (推导过程在9.深入线性回归推导出MSE——不容小觑的线性回归算法-CSDN博客中),因为我们的MSE中 Xy 是已知的,\theta 是未知的,而 \theta 不是一个变量而是许多向量组成的矩阵,所以我们只能对含有一堆变量的函数MSE中的一个变量求导,即偏导,下面就是对 \theta_{j} 求偏导。

\frac{\partial J(\theta)}{\partial \theta_{j}} = \frac{\partial \frac{1}{2}(h_{\theta}x-y)^{2}}{\partial \theta_{j}}

由于链式求导法则,我们可以推出:

\Rightarrow \frac{\partial J(\theta)}{\partial \theta_{j}} = 2 \cdot \frac{1}{2}(h_{\theta}x-y) \cdot\frac{\partial (h_{\theta}x - y)}{\partial \theta_{j}}

在多元线性回归中,h_{\theta}x 就是 W^{T}X,也就是 \omega _{0}x_{0} + \omega_{1}x_{1}+...+\omega_{n}x_{n},我们通常把它写成\sum_{n}^{i = 0}\omega _{i}x_{i} ,所以继续推导公式:

\Rightarrow \frac{\partial J(\theta)}{\partial \theta_{j}} = (h_{\theta}x - y) \cdot \frac{\partial \sum_{n}^{i =0}(\theta_{i}x_{i}-y)}{\partial \theta_{j}}

由于我们是对 \theta_{j} 求偏导,那么和 \theta_{j} 无关的可以忽略不计,所以公式变为:

\Rightarrow \frac{\partial J(\theta)}{\partial \theta_{j}} = (h_{\theta}x - y) \cdot x_{j}

所以,我们可以得到结论:\theta_{j} 对应的梯度(gradient)与预测值 \hat{y} 和真实值 y 有关,同时还与每个特征维度 x_{j} 有关。如果我们分别对每个维度求偏导,即可得到所有维度对应的梯度值。

3.根据梯度和学习率更新参数

通过11.梯度下降法的思想——举足轻重的模型优化算法-CSDN博客的学习,我们已经知道了梯度下降法的公式:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot gradient_{j}

在获得了梯度之后,我们可以将公式表示为:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot (h_{\theta}x - y) \cdot x_{j}

通过这个公式我们就可以去更新参数逼近最优解。

4.判断是否收敛

在如何判断收敛问题上,我相信大多数的人都会认为直接判断梯度(gradient)是否为0。其实这样的方法是错误的,由于非凸损失函数的存在,gradient = 0 的情况可能是极大值!所以我们使用了另外一种方法,设置合理的阈值(Threshold)来界定函数是否收敛。即判断不等式:

Loss^{t} - Loss^{t+1} < Threshold

如果前一次的损失函数 Loss^{t} 减去这次迭代后的损失函数 Loss^{t+1} 小于我们设定的阈值Threshold ,那我们认为函数收敛,当前的参数就是我们寻求的最优解。反之,我们重复第二步与第三步,一直达到最优解为止。其实我们是在判断 Loss 的下降收益是否更合理,随着迭代次数的增多,Loss 减小的幅度不再变化就可以认为停止在最低点。

二、梯度下降法的分类

我们根据梯度下降法流程中求取梯度的步骤样本数量的不同,将梯度下降法分为三个基本的类别。它们每次学习(更新模型参数)使用的样本个数,每次更新使用不同的样本会导致每次学习的准确性和学习时间不同

1.全量梯度下降(Batch Gradient Descent)

全量梯度下降(Batch Gradient Descent)通过使用整个数据集在每次迭代中计算损失函数的梯度,以此更新模型参数(也称批量梯度下降)。由于我们使用整个数据集的样本,所以全量梯度下降的公式为:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot \sum_{m}^{i = 1}(h_{\theta}x_{i} - y_{i}) \cdot x_{j}

在全量梯度下降中,对于 \theta 的更新,所有的样本都有贡献,也就是参与调整 \theta 。所以从理论上来说一次更新的幅度是比较大的。

全量梯度下降法的优点在于收敛稳定,每次更新都朝着全局最优的方向移动。并且能够净化噪声,由于使用整个数据集计算梯度,随机噪声对更新的影响较小,使得损失函数的路径相对平滑。

缺点也是相当明显,当数据集非常大时,全量梯度下降法每个迭代计算数据集的梯度是非常耗时且占用内存的。所以不适合处理实时数据,比如在线学习和实时更新数据场景。

上图表示的梯度下降法中两个维度参数的关系,我们可以将圆圈看成一个碗的俯视图,碗底就是我们要找的最优解。我们不难发现,全量梯度下降法每次迭代都直接向碗底行进,目标明确。

2.随机梯度下降(Stochastic Gradient Descent)

随机梯度下降(Stochastic Gradient Descent)通过使用数据集中的一个随机样本在每次迭代中计算损失函数的梯度,以此更新模型参数。由于使用随机的一个样本,所以随机梯度下降的公式就是:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot (h_{\theta}x - y) \cdot x_{j}

随机梯度下降的优点在于计算速度快,由于每次迭代只对一个样本计算梯度,因此更新速度快,适合大规模数据集。它还拥有更强的泛化能力,由于引入了随机性,SGD能更好地跳出局部最优,避免过拟合(过拟合相关内容会在专栏后续文章中更新)。并且能够处理实时数据,可以在线学习,所以适用于动态更新的场景。

同样地,由于每次更新只基于一个样本,SGD的收敛并不稳定,梯度波动较大,会导致损失函数的收敛路径不平稳。并且由于随机性的存在,SGD通常需要更多的迭代次数才能收敛到最优解,即收敛速度变慢

从上图我们可以看出,相比较全量梯度下降,SGD需要迭代更多的次数才能找到最优解。

3.小批量梯度下降(Mini-batch Gradient Descent)

小批量梯度下降(Mini-batch Gradient Descent)通过使用数据集的一部分样本在每次迭代中计算损失函数的梯度,以此更新模型参数。由于使用了数据集的部分样本,所以小批量梯度下降的公式为:

W_{j}^{t+1} = W_{j}^{t} - \eta \cdot \sum_{batchsize}^{i = 1}(h_{\theta}x_{i} - y_{i}) \cdot x_{j}

小批量梯度下降综合了全量梯度下降与随机梯度下降,在更新速度与更新次数中取得一个平衡。其每次更新从数据集中随机选择 batchsize 个样本进行学习。相对于随机梯度下降法,小批量梯度下降法降低了收敛的波动性(降低了参数更新的方差),使得更新更加稳定。相对于全量梯度下降法,其提高了每次学习的速度。

小批量梯度下降的优点在于平衡了计算效率和收敛稳定性。并且不用担心内存瓶颈而使用向量化计算,还能利用GPU的并行计算能力提高计算速度。在每个小批量中,我们可以设置不同的学习率,提高模型的训练表现。

小批量梯度下降的缺点则在于样本的大小会影响训练效果,所以我们要人为地选择合适的样本大小。

从下图中我们就能看到随机梯度下降与小批量梯度下降的区别。

总结

本篇博客讲解了梯度下降法的流程和大致的分类。希望可以对大家起到作用,谢谢。


关注我,内容持续更新(后续内容在作者专栏《从零基础到AI算法工程师》)!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2187270.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

力扣 中等 129.求根节点到叶子结点数字之和

文章目录 题目介绍解法 题目介绍 解法 法一&#xff1a;有返回值、 class Solution {public int sumNumbers(TreeNode root) {return dfs(root, 0);}public int dfs(TreeNode root, int x) {if (root null) {return 0;}x x * 10 root.val;if (root.left root.right) { //…

LC刷题专题:dfs、哈希表合集

自己刷题缺少分类思想&#xff0c;总是这里刷一道那里刷一道&#xff0c;以后建立几个专辑&#xff0c;然后自己新刷的同类型的题目都会即使更新上。 文章目录 690. 员工的重要性 690. 员工的重要性 2024-10-03 题目描述&#xff1a; 我第一次写并没有考虑到dfs&#xff0c;…

基于Arduino的L298N电机驱动模块使用

一.简介&#xff1a; L298N作为电机驱动芯片&#xff0c;具有驱动能力强&#xff0c;发热量低&#xff0c;抗干扰能力强的特点,一个模块可同时驱动两个直流电机工作&#xff0c;能够控制电机进行正转、反转、PWM调速。 说明&#xff1a; 1&#xff09;12V输入端口接入供电电压…

esp32开发环境搭建和烧录测试

文章目录 前言一、硬件环境1、esp32开发板。2、两个micro usb 数据线&#xff0c;一路用于供电&#xff0c;另一路用于烧录和调试3、喇叭&#xff0c; 淘宝上买的 4 欧姆 3 W扬声器 二、软件环境配置1、开发软件2、ESP-IDF简介下载 3、vscode安装配置1、安装vscode2、安装IDF插…

论文提纲怎么写?分享5款AI论文写作软件

在学术研究和写作过程中&#xff0c;撰写高质量的论文是一项挑战性的任务。幸运的是&#xff0c;随着人工智能技术的发展&#xff0c;AI论文写作工具逐渐成为帮助学者和学生提高写作效率的重要工具。这些工具不仅能够提高写作效率&#xff0c;还能帮助简化复杂的写作流程&#…

C++(string类的实现)

1. 迭代器、返回capacity、返回size、判空、c_str、重载[]和clear的实现 string类的迭代器的功能就类似于一个指针&#xff0c;所以我们可以直接使用一个指针来实现迭代器&#xff0c;但如下图可见迭代器有两个&#xff0c;一个是指向的内容可以被修改&#xff0c;另一个则是指…

【JNI】hello world

JNI&#xff0c;作为java和C/C的中间层&#xff0c;为在Java中调用C/C代码提供了便利。作为初学者&#xff0c;这里简单记录学习的过程。 本文所有的操作都在kali linux上进行&#xff0c;jdk环境以及gcc&#xff0c;g编译器需自行提前安装好 操作系统&#xff1a; jdk&#…

行为型模式-命令-迭代-观察者-策略

命令模式 是什么 将一个请求封装成为一个对象, 从而可以使用不同的请求对客户进行参数化,对请求排队或记录请求日志,以及可以撤销的操作 实例 请求封装成为对象 //用来声明执行操作的接口 public abstract class Command { protected Receiver receiver; public Comma…

【网络原理】Udp报文结构,保姆级详解,建议收藏

&#x1f490;个人主页&#xff1a;初晴~ &#x1f4da;相关专栏&#xff1a;计算机网络那些事 一、UDP报文格式 ​ ​ 可以看出UDP报文主要由报头和正文两部分构成&#xff0c;报头存储了此次报文的一些重要信息&#xff0c;而正文才是真正需要传输的内容。本篇文章就主要…

【Kubernetes】常见面试题汇总(五十二)

目录 116. K8S 集群服务暴露失败&#xff1f; 117.外网无法访问 K8S 集群提供的服务&#xff1f; 特别说明&#xff1a; 题目 1-68 属于【Kubernetes】的常规概念题&#xff0c;即 “ 汇总&#xff08;一&#xff09;~&#xff08;二十二&#xff09;” 。 题目 69-…

Windows 环境上安装 NASM 和 YASM 教程

NASM 和 YASM NASM NASM&#xff08;Netwide Assembler&#xff09;是一个开源的、可移植的汇编器&#xff0c;它支持多种平台和操作系统。它可以用来编写16位、32位以及64位的代码&#xff0c;并且支持多种输出格式&#xff0c;包括ELF、COFF、OMF、a.out、Mach-O等。NASM使用…

复习HTML(进阶)

前言 上一篇的最后我介绍了在表单中&#xff0c;上传文件需要使用到 method属性 和enctype属性。本篇博客主要是详细的介绍这些知识 <form action"http://localhost:8080/test" method"post" enctype"multipart/form-data"> method属性…

SQL Inject-基于报错的信息获取

常用的用来报错的函数 updatexml() : 函数是MYSQL对XML文档数据进行查询和修改的XPATH函数。 extractvalue(): 函数也是MYSQL对XML文档数据进行查询的XPATH函数。 floor(): MYSQL中用来取整的函数。 思路&#xff1a; 在MySQL中使用一些指定的函数来制造报错&am…

YOLOv8改进 - 注意力篇 - 引入SEAttention注意力机制

一、本文介绍 作为入门性篇章&#xff0c;这里介绍了SEAttention注意力在YOLOv8中的使用。包含SEAttention原理分析&#xff0c;SEAttention的代码、SEAttention的使用方法、以及添加以后的yaml文件及运行记录。 二、SEAttention原理分析 SEAttention官方论文地址&#xff1…

深度学习——线性神经网络(一、线性回归)

目录 一、线性回归1.1 线性回归的基本元素1.1.1 术语介绍1.1.2 线性模型1.1.3 损失函数1.1.4 解析解1.1.5 随机梯度下降1.1.6 模型预测 1.2 正态分布与平方损失 因为线性神经网络篇幅比较长&#xff0c;就拆成几篇博客分开发布。目录序号保持连贯性。 一、线性回归 回归&#x…

基于单片机智能百叶窗卷帘门自动门系统

** 文章目录 前言概要功能设计设计思路 软件设计效果图 程序文章目录 前言 &#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN特邀作者、博客专家、CSDN新星计划导师&#xff0c;一名热衷于单片机技术探索与分享的博主、专注于 精通51/STM32/MSP430/AVR等单片机设计 主要对…

二叉树深度学习——将二叉搜索树转化为排序的双向链表

1.题目解析 题目来源&#xff1a;LCR 155.将二叉搜索树转化为排序的双向链表 测试用例 2.算法原理 首先题目要求原地进行修改并且要求左指针代表前驱指针&#xff0c;右指针代表后继指针&#xff0c;所以思路就是 1.使用前序遍历创建两个指针cur、prev代表当前节点与前一个节点…

STM32三种启动模式:【详细讲解】

STM32在上电后&#xff0c;从那里启动是由BOOT0和BOOT1引脚的电平决定的&#xff0c;如下表&#xff1a; BOOT模式选引脚启动模式BOOT0BOOT1X0主Flash启动01系统存储器启动11内置SRAM启动 BOOT 引脚的值在重置后 SYSCLK 的第四个上升沿时被锁定。在重置后,由用户决定是如何设…

硬件开发笔记(三十):TPS54331电源设计(三):设计好的原理图转设计PCB布板,12V输入电路布局设计

若该文为原创文章&#xff0c;转载请注明原文出处 本文章博客地址&#xff1a;https://hpzwl.blog.csdn.net/article/details/142694484 长沙红胖子Qt&#xff08;长沙创微智科&#xff09;博文大全&#xff1a;开发技术集合&#xff08;包含Qt实用技术、树莓派、三维、OpenCV…

挖矿病毒记录 WinRing0x64.sys

之前下载过福晰pdf编辑器&#xff0c;使用正常。 某天发现机器启动后&#xff0c;过个几分钟(具体为5min)会自动运行几个 cmd 脚本(一闪而过)&#xff0c;但是打开任务管理器没有发现异常程序&#xff08;后面发现病毒程序伪装成System系统程序&#xff0c;见下图&#xff09;…