探索线性回归中的梯度下降法

news2024/11/25 18:55:30

目录

  • 前言
  • 1 梯度下降的基本思想
  • 2 梯度下降的公式
  • 3 梯度下降的步骤
    • 3.1 初始化参数
    • 3.2 计算梯度
    • 3.3 更新参数
    • 3.4 迭代更新
  • 4 学习率的控制
    • 4.1 过大学习率的问题
    • 4.2 过小学习率的问题
    • 4.3 学习率的调整
  • 5 批量梯度下降方法
    • 5.1 批量梯度下降(Batch Gradient Descent)
    • 5.1 小批量梯度下降(Mini-batch Gradient Descent)
  • 结语

前言

线性回归是机器学习中常用的模型之一,而梯度下降法则是优化线性回归模型参数的重要手段之一。本文将深入探讨梯度下降法在线性回归中的应用,包括其基本思想、相关公式、步骤、学习率的控制以及批量梯度下降方法。通过详细阐述这些内容,希望读者能够更好地理解和运用梯度下降法来优化线性回归模型。

1 梯度下降的基本思想

在机器学习中,梯度下降法是一种常用的优化算法,其核心思想是通过迭代的方式逐步调整模型参数,以降低目标函数(损失函数)的值。在线性回归中,我们的目标是找到最优的权重 w 和偏置 b,使得损失函数$J(w,b) $取得最小值。

线性回归的目标函数通常以平方损失为例,即

J ( w , b ) = 1 2 m ∑ i = 1 m ( y i − ( w x i + b ) ) 2 J(w, b) = \frac{1}{2m} \sum_{i=1}^{m} (y_i - (wx_i + b))^2 J(w,b)=2m1i=1m(yi(wxi+b))2
其中,m是样本数量,( x i x_i xi, y i y_i yi)是训练集中的样本。这个公式描述了平方损失的均方差,表示模型预测值与实际值之间的差异,梯度下降的目标是最小化这个损失函数。

通过梯度下降法,我们希望找到使得目标函数最小化的 w 和 b。梯度下降的基本思想是计算目标函数对于参数的梯度(偏导数),然后沿着梯度的反方向调整参数,以减小目标函数的值。

2 梯度下降的公式

梯度下降法的核心在于通过对目标函数进行偏导数的计算,求解梯度,然后根据梯度的反方向来更新模型参数。在线性回归中,我们的目标是最小化损失函数$J(w,b) $。

在这里插入图片描述

权重的更新

w : = w − α ∂ J ( w , b ) ∂ w w := w - \alpha \frac{\partial J(w, b)}{\partial w} w:=wαwJ(w,b)

w = w − α 1 m ∑ i = 1 m ( y i − ( w x i + b ) ) x i w = w - \alpha \frac{1}{m} \sum_{i=1}^{m} (y_i - (wx_i + b))x_i w=wαm1i=1m(yi(wxi+b))xi

偏置的更新

b : = b − α ∂ J ( w , b ) ∂ b b := b - \alpha \frac{\partial J(w, b)}{\partial b} b:=bαbJ(w,b)

b = b − α 1 m ∑ i = 1 m ( y i − ( w x i + b ) )   b = b - \alpha \frac{1}{m} \sum_{i=1}^{m} (y_i - (wx_i + b)) \ b=bαm1i=1m(yi(wxi+b)) 

其中, α \alpha α 是学习率,它是一个正数,用于控制每次迭代的步长。学习率的选择对梯度下降的性能影响很大,过大的学习率可能导致震荡,而过小的学习率可能导致收敛速度过慢。

在更新公式中, ∂ J ( w , b ) ∂ w \frac{\partial J(w, b)}{\partial w} wJ(w,b)表示损失函数关于权重 w的偏导数,而 $\frac{\partial J(w, b)}{\partial b} $表示关于偏置 b的偏导数。这两个偏导数告诉我们在当前参数下,目标函数的变化率,梯度下降通过不断减小这些变化率来逼近最小值。

3 梯度下降的步骤

梯度下降是一种迭代优化算法,用于最小化目标函数。在线性回归中,梯度下降的步骤可以简要概括如下。

在这里插入图片描述

3.1 初始化参数

在开始优化过程之前,需要初始化模型参数。通常可以将权重 w 和偏置 b 初始化为零或者随机的小值。这一步是为了给优化算法一个起始点。

3.2 计算梯度

计算目标函数$J(w,b) $关于参数 w 和 b 的偏导数,即梯度。梯度告诉我们目标函数在当前参数点上的变化率。对于线性回归,梯度的计算涉及对损失函数关于权重 w 和偏置 b 的偏导数。

3.3 更新参数

使用梯度和预先设定的学习率 α,通过梯度下降的更新规则来调整参数 w 和 b。更新规则如下:

w : = w − α ∂ J ( w , b ) ∂ w w := w - \alpha \frac{\partial J(w, b)}{\partial w} w:=wαwJ(w,b)

b : = b − α ∂ J ( w , b ) ∂ b b := b - \alpha \frac{\partial J(w, b)}{\partial b} b:=bαbJ(w,b)

这一步的目的是沿着梯度的反方向调整参数,以减小目标函数的值。

3.4 迭代更新

重复步骤 b 和 c,直至满足停止条件。停止条件可以是达到最大迭代次数或者梯度趋近于零。迭代的过程中,参数不断被调整,目标函数逐渐趋近最小值。

通过这些步骤,梯度下降能够有效地搜索参数空间,找到使得损失函数最小化的最优参数,从而优化线性回归模型。

4 学习率的控制

学习率是梯度下降中一个至关重要的参数,它决定了每次迭代中模型参数更新的步长。选择合适的学习率对于梯度下降的性能和收敛速度至关重要。

4.1 过大学习率的问题

如果学习率过大,可能导致梯度下降算法在参数空间中跳动或震荡,甚至无法收敛到最小值。这是因为过大的学习率使得每次迭代参数更新过大,导致优化过程失控。

4.2 过小学习率的问题

相反,如果学习率过小,模型参数更新的步长太小,梯度下降收敛速度会很慢,甚至可能陷入局部最小值而无法找到全局最小值。

4.3 学习率的调整

一种常用的学习率调整方法是进行实验,通过尝试不同的学习率来找到一个在特定问题上表现良好的值。另一种方法是使用自适应学习率的技术,如Adagrad、Adadelta、Adam等,它们可以根据梯度的历史信息来动态地调整学习率,以更灵活地适应优化过程。

在实践中,可以从一个较小的学习率开始,观察损失函数的下降情况。如果发现收敛速度过慢,可以逐渐增大学习率。然而,需要注意不要选择过大的学习率,以免影响优化的稳定性。

通过合理调整学习率,梯度下降算法能够更好地在参数空间中搜索,加速模型的收敛,并更有效地优化线性回归模型。

5 批量梯度下降方法

梯度下降的方法不仅仅限于单一形式,批量梯度下降是其中一种形式,它的特点是每次迭代都利用所有训练样本来计算梯度。这相对于随机梯度下降更为稳定,但在大数据集上计算梯度较为耗时。为了解决这一问题,引入了小批量梯度下降,作为一种折中的选择,它使用一小部分样本来估计梯度。

5.1 批量梯度下降(Batch Gradient Descent)

在批量梯度下降中,每次迭代都需要对整个训练集进行计算。其权重和偏置的更新公式如下:

w : = w − α 1 m ∑ i = 1 m ∂ J ( w , b ) ∂ w w := w - \alpha \frac{1}{m} \sum_{i=1}^{m} \frac{\partial J(w, b)}{\partial w} w:=wαm1i=1mwJ(w,b)

b : = b − α 1 m ∑ i = 1 m ∂ J ( w , b ) ∂ b b := b - \alpha \frac{1}{m} \sum_{i=1}^{m} \frac{\partial J(w, b)}{\partial b} b:=bαm1i=1mbJ(w,b)

其中,$ m$ 是训练样本的数量,$ \alpha$ 是学习率。

5.1 小批量梯度下降(Mini-batch Gradient Descent)

小批量梯度下降是一种折中方案,每次迭代时仅利用一小部分样本来估计梯度。这样可以在保持一定稳定性的同时,减少计算开销。更新公式如下:

w : = w − α 1 b a t c h _ s i z e ∑ i = 1 b a t c h _ s i z e ∂ J ( w , b ) ∂ w w := w - \alpha \frac{1}{batch\_size} \sum_{i=1}^{batch\_size} \frac{\partial J(w, b)}{\partial w} w:=wαbatch_size1i=1batch_sizewJ(w,b)

b : = b − α 1 b a t c h _ s i z e ∑ i = 1 b a t c h _ s i z e ∂ J ( w , b ) ∂ b b := b - \alpha \frac{1}{batch\_size} \sum_{i=1}^{batch\_size} \frac{\partial J(w, b)}{\partial b} b:=bαbatch_size1i=1batch_sizebJ(w,b)

其中,$ batch_size $ 是每次迭代使用的样本数量。

选择何种梯度下降方法取决于数据集的规模和计算资源的可用性。批量梯度下降适用于较小的数据集,而小批量梯度下降则可以在大规模数据集上更高效地进行计算。随机梯度下降则是一种更为轻量级的方法,适用于在线学习或数据流式处理。

通过灵活选择梯度下降的形式,我们能够更好地平衡计算效率和模型稳定性,从而优化线性回归模型。

结语

通过本文对梯度下降在线性回归中的深入探讨,我们理解了其基本思想、公式、步骤、学习率的控制以及批量梯度下降方法。在实际应用中,灵活运用梯度下降算法,调整参数和学习率,将有助于优化线性回归模型,提高其性能和泛化能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1457085.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代码随想录算法训练营DAY20 | 二叉树(7) (续)

一、LeetCode 236 二叉树的最近公共祖先 题目链接:236.二叉树的最近公共祖先https://leetcode.cn/problems/lowest-common-ancestor-of-a-binary-tree/description/ 思路:利用后序遍历是天然回溯过程、方便实现自底向上查找的原理,递归寻找公…

基于SpringBoot的高校竞赛管理系统

基于SpringBoot的高校竞赛管理系统的设计与实现~ 开发语言:Java数据库:MySQL技术:SpringBootMyBatis工具:IDEA/Ecilpse、Navicat、Maven 系统展示 主页 个人中心 管理员界面 老师界面 摘要 高校竞赛管理系统是为了有效管理学校…

书生开源大模型-第2讲-笔记

1.环境准备 1.1环境 先克隆我们的环境 bash /root/share/install_conda_env_internlm_base.sh internlm-demo1.2 模型参数 下载或者复制下来,开发机中已经有一份参数了 mkdir -p /root/model/Shanghai_AI_Laboratory cp -r /root/share/temp/model_repos/inter…

分库分表浅析

简介 对于任何系统而言,都会设计到数据库随着时间增长而累积越来越多的数据,系统也因为越来越多的需求变迁导致原有的设计不再满足现状,为了解决这些问题,分库分表就会走进视野,带着几个问题走入分库分表。 什么是分…

嵌入式学习第十八天(目录IO)

目录IO: 1. mkdir int mkdir(const char *pathname, mode_t mode); 功能:创建目录文件 参数: pathname:文件路径 mode:文件的权限 rwx rwx rwx 111 111 111 0 7 7 7 r:目录中是否能够查看文件 w:目…

瑞_23种设计模式_代理模式

文章目录 1 代理模式(Proxy Pattern)1.1 介绍1.2 概述1.3 代理模式的结构 2 静态代理2.1 介绍2.2 案例——静态代理2.3 代码实现 3 JDK动态代理★★★3.1 介绍3.2 代码实现3.3 解析代理类3.3.1 思考3.3.2 使用 Arthas 解析代理类3.3.3 结论 3.4 动态代理…

ARM体系在linux中的中断抢占

上一篇说到系统调用等异常通过向量el1_sync做处理,中断通过向量el1_irq做处理,然后gic的工作都是为中断处理服务,在rtos中,我们一般都会有中断嵌套和优先级反转的概念,但是在linux中,中断是否会被其他中断抢…

RTC时钟

目录 一、STM32F407内部RTC硬件框图,主要由五大部分组成: 二、硬件相关引脚 三、具体代码设置步骤 四、了解其它知识点 一、STM32F407内部RTC硬件框图,主要由五大部分组成: ① 时钟源 (1)LSE:一般我们选择 LSE&am…

网络编程_TCP通信综合练习:

1 //client:: public class Client {public static void main(String[] args) throws IOException {//多次发送数据//创建socket对象,填写服务器的ip以及端口Socket snew Socket("127.0.0.1",10000);//获取输出流OutputStream op s.getOutput…

python统计分析——一元线性回归分析

参考资料:用python动手学统计学 1、导入库 # 导入库 # 用于数值计算的库 import numpy as np import pandas as pd import scipy as sp from scipy import stats # 用于绘图的库 import matplotlib.pyplot as plt import seaborn as sns sns.set() # 用于估计统计…

【高效开发工具系列】PyCharm使用

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

vue3项目配置按需自动导入API组件unplugin-auto-import

场景应用:避免写一大堆的import,比如关于Vue和Vue Router的 1、安装unplugin-auto-import npm i -D unplugin-auto-import 2、配置vite.config import AutoImport from unplugin-auto-import/vite//按需自动加载API插件 AutoImport({ imports: ["…

Unity中的Lerp插值的使用

Unity中的Lerp插值使用 前言Lerp是什么如何使用Lerp 前言 平时在做项目中插值的使用避免不了,之前一直在插值中使用存在误区,在这里浅浅记录一下。之前看的博客或者教程还多都存在一个“永远到达不了,只能无限接近”的一个概念。可能是之前脑…

ThreadLocal “你”真的了解吗?

今天想梳理一个常见的面试题。在开始之前,让我们一起来回顾一下昨天的那篇文章——《Spring 事务原理总结七》。这篇文章比较啰嗦,层次也不太清晰,所以以后有机会我一定要重新整理一番。这篇文章主要想表达这样一个观点:Spring的嵌…

对于软件测试的理解

前言 “尽早的介入测试,遇到问题的解决成本就越低” 随着软件测试技术的发展,测试工作由原来单一的寻找缺陷逐渐发展成为预防缺陷,探索测试,破坏程序的过程,测试活动贯穿于整个软件生命周期中,故称为全程…

【SpringBoot】项目启动增加自定义Banner

SpringBoot项目启动增加自定义Banner 前言 最近有个老哥推荐我给博客启动的时候加上自定义Banner,开始我还不太明白他说的是那部分,后面给我发了这样一个,瞬间就懂了~ // _ooOoo_ …

Python(九十三)函数的参数总结

❤️ 专栏简介:本专栏记录了我个人从零开始学习Python编程的过程。在这个专栏中,我将分享我在学习Python的过程中的学习笔记、学习路线以及各个知识点。 ☀️ 专栏适用人群 :本专栏适用于希望学习Python编程的初学者和有一定编程基础的人。无…

不要0!我们需要1!

解法一&#xff1a; 十进制转二进制同时数1的个数 #include<iostream> #define endl \n using namespace std; void solve(int x) {int cnt 0;while (x) {if (x % 2 1) cnt;x / 2;}cout << cnt << endl; } int main() {int n;cin >> n;solve(n);re…

2024-2-19 LC200. 岛屿数量

其实还是用并查集将 独立的岛屿视为独立的子集。 count其实是集合的个数&#xff0c;同一个块岛屿被压缩成了一个集合&#xff0c;而每个表示海洋的格子依然被看作独立的集合&#xff0c;在所有的格子都走完一遍后&#xff0c;count 被压缩的岛屿 所有表示海洋的独立格子的数…

2024.2.19

使用fread和fwrite完成两个文件的拷贝 #include<stdio.h> #include<stdlib.h> #include<string.h> int main(int argc, const char *argv[]) {FILE *fpNULL;if((fpfopen("./tset.txt","w"))NULL){perror("open error");retur…