Python】深度学习基础知识——随机梯度下降详解和示例

news2025/3/12 1:42:35

本文通过原理和示例对随机梯度下降进行了详解,并和梯度下降进行了对比分析,简单易懂。

  • 随机梯度下降
    • 原理
    • 示例
  • 动态学习率
    • 动态学习率
    • 示例
  • 总结

随机梯度下降

原理

在这里插入图片描述

示例

import torch
import torch.nn as nn
import matplotlib.pyplot as plt


def train_2d(trainer, steps=20, f_grad=None):  #@save
    """用定制的训练机优化2D目标函数"""
    # s1和s2是稍后将使用的内部状态变量
    x1, x2, s1, s2 = -5, -2, 0, 0
    results = [(x1, x2)]
    for i in range(steps):
        if f_grad:
            x1, x2, s1, s2 = trainer(x1, x2, s1, s2, f_grad)
        else:
            x1, x2, s1, s2 = trainer(x1, x2, s1, s2)
        results.append((x1, x2))
    	print(f'epoch {i + 1}, x1: {float(x1):f}, x2: {float(x2):f}')
    return results

def show_trace_2d(f, results):  #@save
    """显示优化过程中2D变量的轨迹"""
    plt.figure(figsize=(6, 3))
    plt.plot(*zip(*results), '-o', color='#ff7f0e')
    x1, x2 = torch.meshgrid(torch.arange(-5.5, 1.0, 0.1),
                          torch.arange(-3.0, 1.0, 0.1), indexing='ij')
    plt.contour(x1, x2, f(x1, x2), colors='#1f77b4')
    plt.xlabel('x1')
    plt.ylabel('x2')

def f(x1, x2):  # 目标函数
    return x1 ** 2 + 2 * x2 ** 2

def f_grad(x1, x2):  # 目标函数的梯度
    return 2 * x1, 4 * x2

def sgd(x1, x2, s1, s2, f_grad):
    g1, g2 = f_grad(x1, x2)
    # 模拟有噪声的梯度
    g1 += torch.normal(0.0, 1, (1,)).item()
    g2 += torch.normal(0.0, 1, (1,)).item()
    eta_t = eta * lr()
    return (x1 - eta_t * g1, x2 - eta_t * g2, 0, 0)

def constant_lr():
    return 1

eta = 0.1
lr = constant_lr  # 常数学习速度
show_trace_2d(f, train_2d(sgd, steps=50, f_grad=f_grad))

输出:

epoch 1, x1: -4.021846, x2: -1.076427
epoch 2, x1: -3.268095, x2: -0.605968
epoch 3, x1: -2.635950, x2: -0.365234
epoch 4, x1: -1.919869, x2: -0.238383
epoch 5, x1: -1.398639, x2: -0.098951
epoch 6, x1: -1.121853, x2: -0.104999
epoch 7, x1: -0.872707, x2: -0.064939
epoch 8, x1: -0.867427, x2: -0.210833
epoch 9, x1: -0.693494, x2: -0.037735
epoch 10, x1: -0.439256, x2: -0.120039
epoch 11, x1: -0.271984, x2: -0.087188
epoch 12, x1: -0.114489, x2: 0.089921
epoch 13, x1: -0.030704, x2: 0.034423
epoch 14, x1: 0.025138, x2: 0.039062
epoch 15, x1: 0.079952, x2: 0.154875
epoch 16, x1: -0.031280, x2: 0.187696
epoch 17, x1: 0.015347, x2: 0.018523
epoch 18, x1: -0.117373, x2: 0.125311
epoch 19, x1: -0.115787, x2: 0.256176
epoch 20, x1: 0.061688, x2: -0.018338

随机梯度下降迭代过程
在这里插入图片描述

梯度下降迭代过程
在这里插入图片描述

从上面的两个图对比可以看出,随机梯度下降中变量的轨迹比梯度下降中的轨迹更混乱一些。这是由于梯度的随机性质。也就是说,即使我们接近最小值,我们仍然受到通过的瞬间梯度所注入的不确定性的影响。即使经过20次迭代,质量仍然不那么好。更糟糕的是,经过额外的步骤,它不会得到改善。这给我们留下了唯一的选择:改变学习率。但是,如果我们选择的学习率太小,我们一开始就不会取得任何有意义的进展。另一方面,如果我们选择的学习率太大,我们将无法获得一个好的解决方案,如上所示。解决这些相互冲突的目标的唯一方法是在优化过程中动态降低学习率。

动态学习率

动态学习率

用与时间相关的学习率n(t)取代n增加了控制优化算法收敛的复杂性。特别是,我们需要弄清n的衰减速度。如果太快,我们将过早停止优化。如果减少的太慢,我们会在优化上浪费太多时间。以下是随着时间推移调整n时使用的一些基本策略(稍后我们将讨论更高级的策略):
在这里插入图片描述
在第一个分段常数(piecewise constant)场景中,我们会降低学习率,例如,每当优化进度停顿时。这是训练深度网络的常见策略。或者,我们可以通过指数衰减(exponential decay)来更积极地减低它。

示例

import math
def exponential_lr():
    # 在函数外部定义,而在内部更新的全局变量
    global t
    t += 1
    return math.exp(-0.1 * t)

t = 1
lr = exponential_lr
show_trace_2d(f, train_2d(sgd, steps=1000, f_grad=f_grad))
epoch 1, x1: -4.210943, x2: -1.249196
epoch 2, x1: -3.593916, x2: -0.855586
epoch 3, x1: -3.100004, x2: -0.743650
epoch 4, x1: -2.810694, x2: -0.598492
epoch 5, x1: -2.505691, x2: -0.404280
''''
''''
''''
epoch 993, x1: -0.832819, x2: 0.002843
epoch 994, x1: -0.832819, x2: 0.002843
epoch 995, x1: -0.832819, x2: 0.002843
epoch 996, x1: -0.832819, x2: 0.002843
epoch 997, x1: -0.832819, x2: 0.002843
epoch 998, x1: -0.832819, x2: 0.002843
epoch 999, x1: -0.832819, x2: 0.002843
epoch 1000, x1: -0.832819, x2: 0.002843

在这里插入图片描述
与上图的进行对比可知,效果已经有了较大提升,但距离真实解还有一定差距。

这也是现在为什么模型训练时,要添加学习率衰减或者余弦等策略,让学习率动态变化,需要快的地方学习率大一些,需要慢的地方,学习率小一些。

总结

  1. 随机梯度下降相较于梯度下降,计算量会小很多。
  2. 合理的调整学习率,将有助于收敛和避免陷入局部最优解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1495464.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue.js大师: 构建动态Web应用的全面指南

VUE ECMAScript介绍什么是ECMAScriptECMAScript 和 JavaScript 的关系ECMAScript 6 简介 ES6新特性let基本使用const不定参数箭头函数对象简写模块化导出导入a.jsb.jsmain.js Vue简介MVVM 模式的实现者——双向数据绑定模式 Vue环境搭建在页面引入vue的js文件即可。创建div元素…

简单两步,从补税到退税

大家好,我是拭心。 最近到了一年一度的个人所得税年度申报时期,有人可以退好几千,而有的人则需要补上万元,人类的悲喜这一刻并不相通。 我申报的时候,提示我需要补税一万多,心有不甘但差一点就认了&#xf…

瑞芯微第二代8nm高性能AIOT平台 RK3576 详细介绍

RK3576处理器 RK3576瑞芯微第二代8nm高性能AIOT平台,它集成了独立的6TOPS(Tera Operations Per Second,每秒万亿次操作)NPU(神经网络处理单元),用于处理人工智能相关的任务。此外,R…

07-prometheus的自定义监控-pushgateway工具组件

一、概述 pushgateway用于自定义监控节点、节点中服务的工具,用户可以通过自定义的命令获取数据,并将数据推送给pushgateway中; prometheus服务,从pushgateway中获取监控数据; 二、部署pushgateway 我们可以“随便”找…

分库分表浅析原理

数据库存放数据大了,查询等操作就会存在瓶颈,怎么办? 1. 如果是单张表数据大了,可以在原有库上新建几张表table_0、table1、table2、.....table_n 写程序对数据进行分表: --这里提供一种一种分表策略,这里只需维护分…

VR全景数字工厂,制造业企业线上营销新助手

VR全景技术逐渐渗透到各行各业,其中,很多实体工厂的线上营销宣传也借助720云VR全景技术也迎来了新的变革。 一、VR全景技术的独特魅力 VR全景技术是一种基于虚拟现实技术的全新视觉呈现方式,能够为用户带来身临其境的沉浸式体验。通过VR全景…

Clion开发STM32之printf的缓冲区验证方式

前言 clion开发stm32时,涉及到printf函数的重写,一般我们重写__io_putchar函数经发现printf函数在没有加换行符时,数据不会立刻通过串口发送数据,而是等到1024字节之后发送数据 测试 主程序 现象 debug调试 不加换行的情况 总…

985硕的4家大厂实习与校招经历专题分享(part1)

先简单介绍一下我的个人经历: 985硕士24届毕业生,实验室方向:CV深度学习 就业:工程-java后端 关注大模型相关技术发展 校招offer: 阿里巴巴 字节跳动 等10 研究生期间独立发了一篇二区SCI 实习经历:字节 阿里 京东 B站 (只看大厂…

超简单Windows-kafka安装配置

参考大佬文章: Kafka(Windows)安装配置启动(常见错误扫雷)教程_kafka在windows上的安装、运行-CSDN博客Kafka(Windows)安装配置启动(常见错误扫雷)教程_kafka在windows上…

力扣爆刷第87天之hot100五连刷21-25

力扣爆刷第87天之hot100五连刷21-25 文章目录 力扣爆刷第87天之hot100五连刷21-25一、240. 搜索二维矩阵 II二、160. 相交链表三、206. 反转链表四、234. 回文链表五、141. 环形链表 一、240. 搜索二维矩阵 II 题目链接:https://leetcode.cn/problems/search-a-2d-…

解决cs不能生成Linux木马的问题

要解决的问题:众所周知,msf上面的shell或者是其他的shell想反弹给cs默认情况下是只支持windows的,因为cs的监听模块默认没有linux的,但是有些主机就是用linux搭建的,这可怎么办呢。就要用到一个插件CrossC2。 下载插件…

Qt 二维数组的访问与应用

配色方案有多种类型,可以根据不同的需求和应用场景来选择合适的配色方法。在柱状图、饼状图中都会用到不同的配色,本文将配色方案使用二维数组进行存储,对常用的配色进行了整理: 效果图 示例代码 void MainWindow::InitUI() {QS…

Rust入门:Rust如何调用C静态库的函数

关于Rust调用C,因为接口比较复杂,貌似Rust不打算支持。而对于C函数,则相对支持较好。 如果要研究C/Rust相互关系的话,可以参考: https://docs.rs/cxx/latest/cxx/ Rust ❤️ C 这里只对调用C静态库做一个最简短的介…

云原生架构设计:开放应用模型(OAM)的重要性与实践

在当今云计算时代,云原生架构已经成为许多企业追求的理想状态,在云原生架构设计中,开放应用模型是至关重要的一部分。本文灸哥将和你一起探讨开放应用模型的概念、意义以及实践方法,以帮助大家更好地理解和应用云原生架构中的开放…

lvs集群介绍

目录 一、LVS集群基本介绍 1、什么是集群 2、集群的类型 2.1 负载均衡群集(Load Balance Cluster) 2.2 高可用群集(High Availiablity Cluster) 2.3 高性能运算群集(High Performance Computing Cluster) 3、负载均衡集群的结构 ​编辑 4、LVS集群类型中的…

上位机图像处理和嵌入式模块部署(qmacvisual三个特色)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 了解了qmacvisual的配置之后,正常来说,我们需要了解下不同插件的功能是什么。不过我们不用着急,可以继续学习下…

K倍区间 刷题笔记

法一 前缀和暴力搜索 &#xff08;数据大会超时&#xff09; #include<iostream> #include<cstring> #include<algorithm> #include<cstdio> using namespace std; const int N100010; int a[N],s[N]; int n,k; int main(){ cin>>n>>…

面试题 -- UI控件

文章目录 一、CAAnimation的层级结构二、 UITableView优化三、离屏渲染四、Autolayout本质原理五、生成二维码的步骤 一、CAAnimation的层级结构 二、 UITableView优化 Cell复用机制Cell高度预先计算缓存Cell高度圆角切割 三、离屏渲染 指的是GPU在当前屏幕缓冲区以外新开辟…

Unity性能优化篇(九) 模型优化之LOD技术概述以及操作方法

LOD模型优化技术概述: 1.LOD技术可以根据摄像头远近来显示不同精度的模型(例如吃鸡游戏 随着跳伞高度 来显示下面树木以及建筑的模型精度) LOD模型优化技术操作方法: 可使用Unity自带的LOD Group组件&#xff0c;并根据项目的情况来调整该组件的属性。Untiy资源商店也有一些其…

Zabbix(三)

监控Nginx服务 nginx配置 增加location{} [rootwenzi ~]#vim /etc/nginx/sites-enabled/defaultserver_name _; #_是通配符。服务器将响应任何域名的请求 ...location /status { stub_status;} ...访问 http://IP/status 即可 zabbix配置 Nginx by HTTP&#xff1a;无…