1.4.1机器学习——梯度下降+α学习率大小判定

news2025/1/19 7:11:45

1.4.1梯度下降

4.1、梯度下降的概念

※【总结一句话】:系统通过自动的调节参数w和b的值,得到最小的损失函数值J。

如下:是梯度下降的概念图。

  • 我们有一个损失函数 J(w,b),包含两个参数w和b(你可以想象成J(w,b) = w*x + b) ,我们想要找到最合适的w和b,尝试最小化损失函数 J(w,b)的值

  • ”梯度下降“:梯度下降(gradient descent)在机器学习中应用十分的广泛,不论是在线性回归还是Logistic回归中,它的主要目的是通过迭代找到目标函数的最小值,或者收敛到最小值

  • 它还适用于有多个参数的(如图:w1~ wn,b) :梯度下降的任务是调节w1 ~ wn 和 参数 b 的值,最小化损失函数的值

  • 梯度下降法是用来计算损失函数最小值的。它的思路很简单,想象在山顶放了一个球,一松手它就会顺着山坡最陡峭的地方滚落到谷底:

=

4.2、梯度下降概述

在这里插入图片描述

  • J(w,b) = wx + b : 我们开始的时候w和b可以设置为任意值。(这里我们设置 w=0 , b = 0)
  • “梯度下降”要做的就是:通过迭代不断的调整 w和b 的值,去尽量的降低损失函数J(w,b)
  • 直到我们的损失函数J(w,b) 达到/接近 “谷底”/最小值
  • 【※注意】 :损失函数可能不仅仅是一个如右图的抛物线,也可能是“高尔夫球场图”,这样的话minimum最小值的个数就不仅仅是一个了

在这里插入图片描述
“高尔夫球场图”

  • 图中XYZ轴分别代表:W,b , 损失函数 J(w,b)的值

  • 他不是一个 “平方误差损失函数”(线性回归使用 平方误差损失函数)图,因为 “平方误差损失函数”常常以抛物线 / 吊床 的形状结束

  • 这个小人站在哪的起始点取决于你选择的参数w 和 b的起始值

  • 两个谷底最低点都是局部最优解

1.4.2 理解梯度下降+梯度下降偏导的意义+α学习率

  • 图解定义:

在这里插入图片描述

  • 对梯度下降公式算法的解读:
    • 其中参数 w和b 是一个同步进行 修改/微调 的。(同步进行,就是计算w和b的更新的两个公式是同步计算 的)
    • 重复以上两个公式,直到达到结果收敛为止
    • 在收敛重复这个公式的过程中,在计算tmp_w 和tmp_b中 的w是上一次迭代的w的值,只有tmp_w 和tmp_b都计算完成才会进行更新
  • 你可能存在疑问:为什么不要用中间值 tmp_w 和wmp_b,直接就是 w = w - α…;b= b - α…不行吗???
    • 答:你提问的很好。但是在这个梯度下降w,b迭代执行过程中,w和b同时参与了w,b的计算。
    • 简单的说如下图,以W为例,加入在第i次迭代的情况下:
      • 在第i次迭代的情况下执行完第一步w = w -0.2 后(假设这时候 J(w,b) = 0.2),
      • 执行第二步 b = b - J(w,b) 的时候,你想放 执行w = w -0.2 更新之前的值 or 更新之后的值???
      • 每一次迭代,在这第i次中,w 和 b对应的是第i -1次的值,在第i次迭代结束时才一起更新
      • 如果不借用中间值 tmp_w 和 tmp_b,那么w 和 b 的最终值 就会受到影响
      • 在这里插入图片描述

为了方便理解α(学习率),我们这里把b设置为0,只考虑有w的一维情况

  • 其中α:称为学习率(粉色框内容),通常这个值在 0~1 之间,是控制w和b的调参步长

    • 合适:在这里插入图片描述

    过程解释:

    红框中:是求w的偏导,这里求完偏导知道是>0

    α永远是 大于0 小于1 的数

    然后w = w - α(正数) ====> w变小

    所以就实现了微调

    在这里插入图片描述

    当W的偏导< 0 时:
    红框中:是求w的偏导,这里求完偏导知道是<0

    α永远是 大于0 小于1 的数

    然后w = w - α(负数) ====> w变大

    所以就实现了微调.

    在这里插入图片描述

    • 如果α非常大:那么这个是一个非常激进的梯度下降的过程,步长会非常大,反而会越过谷底,不断上升:+

    • 动图地址
      图片: ![Alt](https://img-blog.csdnimg.cn/img_convert/b8a2549083eae7aa3a33cc80c6cee5dd.webp?x-oss-process=image/format,png)

    吴恩达老师的解释:

    ​ 如果α过大/过大的步长:

    ​ 会导致超过 , 并且从来不会达到最小值。

    ​ 不能直线收敛,甚至是发散

在这里插入图片描述

  • 如果α非常小:比如 α = 0.0001 就过于小了,迭代 20 次后离谷底还很远,实际上 100 次后都无法到达谷底:

    梯度下降会起作用,但是需要很长的时间

  • 动图地址
    在这里插入图片描述

吴恩达的解释:

​ 如果α 步长太小,会实现收敛,但是这个收敛的过程会很慢很慢

  • **总结:**不同的步长α ,随着迭代次数的增加,会导致被优化函数f(x) 的值有不同的变化:

img

关于α选择以及判定的 详细内容看: 2.2.3机器学习—— 判定梯度下降是否收敛 + α学习率的选择

1.4.3 用于线性回归的梯度下降

在这里插入图片描述

  • 公式推到来源:

    在这里插入图片描述

  • 那么 线性回归梯度下降就是:

重复对w 和 b 执行更新直到收敛

在这里插入图片描述

  • 当使用线性回归的平方误差损失函数时,全局只有一个最低点

在这里插入图片描述

  • 当使用非平方误差函数时,就是非线性回归梯度下降的时候就会出现 >=1 的局部最优解

在这里插入图片描述

1.4.4 线性回归的梯度下降的应用

  • 等高线最中心那个圈是 损失函数值最小的点,Cost值越小,说明线性回归的拟合越好,直到我们达到全局最小值
  • 比如当 Size in feet 是1250 ,对应的回归预测值是 250 K

在这里插入图片描述

  • “批量梯度下降”:每一次的梯度下降使用的是全部的训练的数据:
  • 当计算 w 和 b 的偏导时,我们从 1 ~ m 所有的数据都计算上,然后相加求平均

在这里插入图片描述

※1.4.5 线性回归的梯度下降函数的代码实现

1、求损失函数

求损失函数代码如下:

def compute_cost(x, y, w, b): 
    """
    Computes the cost function for linear regression.
    
    Args:
      x (ndarray (m,)): Data, m examples 
      y (ndarray (m,)): target values
      w,b (scalar)    : model parameters  
    
    Returns
        total_cost (float): The cost of using w,b as the parameters for linear regression
               to fit the data points in x and y
    """
    # number of training examples
    m = x.shape[0] 
    
    cost_sum = 0 
    for i in range(m): 
        f_wb = w * x[i] + b   
        cost = (f_wb - y[i]) ** 2  
        cost_sum = cost_sum + cost  
    total_cost = (1 / (2 * m)) * cost_sum  

    return total_cost

2、求偏导 / 梯度

※求偏导代码如下:

def compute_gradient(x, y, w, b): 
    """
    Computes the gradient for linear regression 
    Args:
      x (ndarray (m,)): Data, m examples 
      y (ndarray (m,)): target values
      w,b (scalar)    : model parameters  
    Returns
      dj_dw (scalar): The gradient of the cost w.r.t. the parameters w
      dj_db (scalar): The gradient of the cost w.r.t. the parameter b     
     """
    
    # Number of training examples
    m = x.shape[0]    
    dj_dw = 0
    dj_db = 0
    
    for i in range(m):  
        f_wb = w * x[i] + b 
        dj_dw_i = (f_wb - y[i]) * x[i] 
        dj_db_i = f_wb - y[i] 
        dj_db += dj_db_i
        dj_dw += dj_dw_i 
    dj_dw = dj_dw / m 
    dj_db = dj_db / m 
        
    return dj_dw, dj_db

3、梯度下降函数

def gradient_descent(x, y, w_in, b_in, alpha, num_iters, cost_function, gradient_function): 
    """
    Performs gradient descent to fit w,b. Updates w,b by taking 
    num_iters gradient steps with learning rate alpha
    
    Args:
      x (ndarray (m,))  : Data, m examples 
      y (ndarray (m,))  : target values
      w_in,b_in (scalar): initial values of model parameters  
      alpha (float):     Learning rate
      num_iters (int):   number of iterations to run gradient descent
      cost_function:     function to call to produce cost
      gradient_function: function to call to produce gradient
      
    Returns:
      w (scalar): Updated value of parameter after running gradient descent
      b (scalar): Updated value of parameter after running gradient descent
      J_history (List): History of cost values
      p_history (list): History of parameters [w,b] 
      """
    
    w = copy.deepcopy(w_in) # avoid modifying global w_in
    # An array to store cost J and w's at each iteration primarily for graphing later
    J_history = []
    p_history = []
    b = b_in
    w = w_in
    
    for i in range(num_iters):
        # Calculate the gradient and update the parameters using gradient_function
        dj_dw, dj_db = gradient_function(x, y, w , b)     

        # Update Parameters using equation (3) above
        b = b - alpha * dj_db                            
        w = w - alpha * dj_dw                            

        # Save cost J at each iteration
        if i<100000:      # prevent resource exhaustion 
            J_history.append( cost_function(x, y, w , b))
            p_history.append([w,b])
        # Print cost every at intervals 10 times or as many iterations if < 10
        if i% math.ceil(num_iters/10) == 0:
            print(f"Iteration {i:4}: Cost {J_history[-1]:0.2e} ",
                  f"dj_dw: {dj_dw: 0.3e}, dj_db: {dj_db: 0.3e}  ",
                  f"w: {w: 0.3e}, b:{b: 0.5e}")
 
    return w, b, J_history, p_history #return w and J,w history for graphing
※※关于α选择以及判定的 详细内容看: 2.2.3机器学习—— 判定梯度下降是否收敛 + α学习率的选择

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1370359.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows环境下使用HTML5播放RTSP流

本文搭建环境选用Windows平台 一、使用笔记本摄像头推送RTSP视频流 参考文章:笔记本摄像头模拟监控推送RTSP流-CSDN博客 二、搭建Vue项目 参考项目&#xff1a; HTML5 播放 rtsp视频流 2.1 下载压缩包 2.2 解压文件&#xff0c;使用VSCode打开项目 2.3 启动项目 注&#…

vue day06

1、路由模块封装 2、声明式导航 实现导航高亮效果 直接通过这两个类名对相应标签设置样式 点击a链接进入my页面时&#xff0c;a链接 我的音乐高亮&#xff0c;同时my下的a、b页面中的 我的音乐也有router-link-active类&#xff0c;但没有精确匹配的类&#xff08;只有my页…

Talk | EMNLP 2023 最佳长论文:以标签为锚-从信息流动的视角分析上下文学习

本期为TechBeat人工智能社区第561期线上Talk。 北京时间1月4日(周四)20:00&#xff0c;北京大学博士生—王乐安的Talk已准时在TechBeat人工智能社区开播&#xff01; 他与大家分享的主题是: “以标签为锚-从信息流动的视角分析上下文学习”&#xff0c;介绍了他的团队在上下文学…

Python 安卓开发:Kivy、BeeWare、Flet、Flutter

kivy&#xff1a;https://github.com/kivy python-for-android &#xff1a;https://python-for-android.readthedocs.io/en/latest/ BeeWare&#xff1a;https://docs.beeware.org/en/latest/ Flet&#xff1a;https://github.com/flet-dev/flet 把 PySide6 移植到安卓上去&a…

1880_安装QEMU_for_ARC

Grey 全部学习内容汇总&#xff1a; https://github.com/GreyZhang/g_ARC 主标题 想学习一点ARC相关的知识&#xff0c;但是手里没有开发板。看了下&#xff0c;使用QEMU似乎是一个很好的选择&#xff0c;正好也有这么一个分支。在此&#xff0c;记录一下环境搭建的过程。 …

Java项目:01 springboot智能养生平台设计与实现

项目介绍 Java项目 智能养生平台 使用技术&#xff1a;springmybatisspringmvchtmlJavaScriptcsslayuijQuery 运行环境 jdk8mysqlIntelliJ IDEAmaven 主要分两个端&#xff0c;用户端和管理员端 网站功能&#xff1a;实现论坛帖子管理&#xff0c;论坛帖子分类管理&#xff0c…

计算机基础面试题 |20.精选计算机基础面试题

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

YOLOv5改进 | 2023主干篇 | EfficientViT替换Backbone(高效的视觉变换网络)

一、本文介绍 本文给大家带来的改进机制是EfficientViT(高效的视觉变换网络),EfficientViT的核心是一种轻量级的多尺度线性注意力模块,能够在只使用硬件高效操作的情况下实现全局感受野和多尺度学习。本文带来是2023年的最新版本的EfficientViT网络结构,论文题目是Effici…

4.MapReduce 序列化

目录 概述序列化序列化反序例化java自带的两种Serializable非Serializable hadoop序例化实践 分片/InputFormat & InputSplit日志 结束 概述 序列化是分布式计算中很重要的一环境&#xff0c;好的序列化方式&#xff0c;可以大大减少分布式计算中&#xff0c;网络传输的数…

推荐VSCODE插件:为`package.json`添加注释信息

众所周知&#xff0c;JSON文件是不支持注释的&#xff0c;除了JSON5/JSONC之外&#xff0c;我们在开发项目特别是前端项目时&#xff0c;大量会用到JSON文件&#xff0c;特别是在编写package.json中的scripts时&#xff0c;由于缺少注释,当有大量的命令脚本时&#xff0c;就有了…

Spring Cloud 介绍

文章目录 微服务技术栈Spring Cloud 介绍京东、阿里的微服务架构SpringBoot 和 SpringCloud 版本选择Springboot版本选择Springcloud版本选择Springcloud和Springboot之间的依赖关系如何看Spring Cloud 组件的升级替换 微服务技术栈 [toc] Spring Cloud 介绍 Spring Cloud是…

Element+vue3.0 tabel合并单元格span-method

Elementvue3.0 tabel合并单元格 span-method :span-method"objectSpanMethod"详解&#xff1a; 在 objectSpanMethod 方法中&#xff0c;rowspan 和 colspan 的值通常用来定义单元格的行跨度和列跨度。 一般来说&#xff0c;rowspan 和 colspan 的值应该是大于等于…

一氧化碳中毒悲剧频发:探究道合顺电化学传感器促进家庭取暖安全

1月6日&#xff0c;陕西省榆林市发生了一起疑似因使用煤炭炉取暖中毒事件。通报称&#xff0c;经公安部门现场调查&#xff0c;并结合医院救治情况&#xff0c;初步判断5人属一氧化碳中毒&#xff0c;其中4人抢救无效死亡&#xff0c;令人痛心。 一般来说&#xff0c;这种在日…

【System Verilog and UVM实力进阶1】SVA语法

毛主席说过&#xff1a;人不犯我我不犯人&#xff0c;人若犯我我必犯人。 目录 1 SVA介绍 1.1 什么是断言 1.2 为什么用System Verilog 断言&#xff08;SVA&#xff09; 1.3 System Verilog的调度 1.4 SVA术语 1.4.1 并发断言 1.4.2 即时断言 1.5 建立SVA块 1.6 一个简…

抖音矩阵云混剪系统源码 短视频矩阵营销系统V2.2.1(免授权版)

抖音矩阵云混剪系统源码 短视频矩阵营销系统V2.2.1&#xff08;免授权版&#xff09; 中网智达矩阵营销系统多平台多账号一站式管理&#xff0c;一键发布作品。智能标题&#xff0c;关键词优化&#xff0c;排名查询&#xff0c;混剪生成原创视频&#xff0c;账号分组&#xff…

基于Jackson自定义json数据的对象转换器

1、问题说明 后端数据表定义的id主键是Long类型&#xff0c;一共有20多位。 前端在接收到后端返回的json数据时&#xff0c;Long类型会默认当做数值类型进行处理。但前端处理20多位的数值会造成精度丢失&#xff0c;于是导致前端查询数据出现问题。 测试前端Long类型的代码 …

单机多卡训练报错NCCL版本有问题

torch.distributedtorch.distributed…DistBackendErrorDistBackendError: : NCCL error in: …/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1275, internal error, NCCL version 2.14.3 这个不知道什么原因&#xff0c;然后解决方法是 增加环境变量NCCL_SOCKET_IFNAM…

FreeRTOS——软件定时器

一、什么是定时器 简单可以理解为闹钟&#xff0c;到达指定一段时间后&#xff0c;就会响铃。 STM32 芯片自带硬件定时器&#xff0c;精度较高&#xff0c;达到定时时间后会触发中断&#xff0c;也可以生成 PWM 、输入捕获、输出 比较&#xff0c;等等&#xff0c;功能强大&a…

HarmonyOS应用开发学习笔记 应用上下文Context 获取文件夹路径

1、 HarmoryOS Ability页面的生命周期 2、 Component自定义组件 3、HarmonyOS 应用开发学习笔记 ets组件生命周期 4、HarmonyOS 应用开发学习笔记 ets组件样式定义 Styles装饰器&#xff1a;定义组件重用样式 Extend装饰器&#xff1a;定义扩展组件样式 5、HarmonyOS 应用开发…

用html和css实现一个加载页面【究极简单】

要创建一个简单的加载页面&#xff0c;你可以使用 HTML 和 CSS 来设计。以下是一个基本的加载页面示例&#xff1a; HTML 文件 (index.html): <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"…