随机梯度下降SGD的理解和现象分析

news2024/12/23 5:44:56

提出问题:令人疑惑的损失值

在某次瞎炼丹的过程中,出现了如下令人疑惑的损失值变化图像:

嗯,看起来还挺工整,来看看前10轮打印的具体损失值变化:

| epoch 1 |  iter 5 / 10 | time 1[s] | loss 2.3137 | lr 0.0010
| epoch 1 |  iter 10 / 10 | time 1[s] | loss 2.2976 | lr 0.0010
| epoch 2 |  iter 5 / 10 | time 1[s] | loss 2.3135 | lr 0.0010
| epoch 2 |  iter 10 / 10 | time 1[s] | loss 2.2973 | lr 0.0010
| epoch 3 |  iter 5 / 10 | time 1[s] | loss 2.3132 | lr 0.0010
| epoch 3 |  iter 10 / 10 | time 1[s] | loss 2.2970 | lr 0.0010
| epoch 4 |  iter 5 / 10 | time 1[s] | loss 2.3129 | lr 0.0010
| epoch 4 |  iter 10 / 10 | time 1[s] | loss 2.2968 | lr 0.0010
| epoch 5 |  iter 5 / 10 | time 1[s] | loss 2.3127 | lr 0.0010
| epoch 5 |  iter 10 / 10 | time 1[s] | loss 2.2965 | lr 0.0010
| epoch 6 |  iter 5 / 10 | time 1[s] | loss 2.3124 | lr 0.0010
| epoch 6 |  iter 10 / 10 | time 1[s] | loss 2.2962 | lr 0.0010
| epoch 7 |  iter 5 / 10 | time 1[s] | loss 2.3122 | lr 0.0010
| epoch 7 |  iter 10 / 10 | time 1[s] | loss 2.2960 | lr 0.0010
| epoch 8 |  iter 5 / 10 | time 1[s] | loss 2.3119 | lr 0.0010
| epoch 8 |  iter 10 / 10 | time 1[s] | loss 2.2957 | lr 0.0010
| epoch 9 |  iter 5 / 10 | time 1[s] | loss 2.3116 | lr 0.0010
| epoch 9 |  iter 10 / 10 | time 1[s] | loss 2.2954 | lr 0.0010
| epoch 10 |  iter 5 / 10 | time 1[s] | loss 2.3114 | lr 0.0010
| epoch 10 |  iter 10 / 10 | time 1[s] | loss 2.2952 | lr 0.0010

可以明显看到两列出现递减的子序列:奇数序列和偶数序列。奇数序列的损失值为2.3137, 2.3135, 2.3132, 2.3129,...;奇数序列的损失值为2.2976, 2.2973, 2.2970, 2.2968,...。事出反常必有妖,那么究竟是什么样的东西导致如此的怪象?

在尝试找具体的原因之前,我们先把涉及的具体参数描述清楚。

模型就是一个很简单的序列模型,其网络结构如下:

layers = [MatMul(W1), Sigmoid(), MatMul(W2), Sigmoid(), MSE()]

网络结构就是两层重复结构,单层为一个矩阵乘法层MatMul加上一个激活函数Sigmoid,两层计算完后用均方误差MSE计算损失值,其中参数W1,W2的赋值如下:

rn = np.random.randn
W1 = (rn(10, 1000)).astype(np.float32)
W2 = (rn(1000, 10)).astype(np.float32)

数据和标签的赋值如下:

x = (rn(1000, 10)).astype(np.float32)
t = x**2

数据就是按照正态分布随机化初始1000个10维的向量,而标签就是原来的向量按元素乘方,而炼丹的目的就是观察模型如何学习二次函数的运算法则的。
相关训练的参数如下:

epochs = 100
batch_size = 100
eval_interval = 5
lr = 0.001

训练一共进行100轮,每一轮的每一批数据有100个,对于1000个数据,那么单个轮次可以分10个批次。每个批次都会计算当前批次100个数据的平均损失值,5个批次评估一次平均损失值,然后打印出来。也就是单个轮次可以看到2次打印出来的评估数据。

显然,第1次评估的平均损失值是用前一半的数据计算出来的,而第2次的则是后一半的数据进行运算。那么可以简单猜测:造成如此令人困惑的损失值变化图像,很可能原因就在数据分批上。

本质思考:推导数学公式解释

我们先把模型抽象为数学上的函数 F F F,其具体形式如下:
L o s s = F ( x , t , w ) Loss = F(x,t,w) Loss=F(x,t,w)
其中, x x x为数据, t t t为标签, w w w为权重, L o s s Loss Loss为损失值。
考虑到数据分批,对数据分成 m m m批的情况,实际上存在 m m m个子函数,如下:
L 1 = F 1 ( x 1 , t 1 , w ) L 2 = F 2 ( x 2 , t 2 , w ) L 3 = F 3 ( x 3 , t 3 , w ) . . . L m = F m ( x m , t m , w ) \begin{matrix} L_{1} = F_{1} (x_{1},t_{1},w)\\L_{2} = F_{2} (x_{2},t_{2},w) \\L_{3} = F_{3} (x_{3},t_{3},w) \\... \\L_{m} = F_{m} (x_{m},t_{m},w) \end{matrix} L1=F1(x1,t1,w)L2=F2(x2,t2,w)L3=F3(x3,t3,w)...Lm=Fm(xm,tm,w)
如果将 w ( i , j ) w_{(i,j)} w(i,j)表示为第 i 轮 i轮 i j j j批的权重值,那么很显然对第 i i i轮的训练批次来说,存在如下关系:
w i , 0 = w i − 1 , m w i , 1 = w i , 0 + k ∂ F 1 ∂ w ∣ w = w i , 0 w i , 2 = w i , 1 + k ∂ F 2 ∂ w ∣ w = w i , 1 w i , 3 = w i , 2 + k ∂ F 3 ∂ w ∣ w = w i , 2 . . . w i , m = w i , m − 1 + k ∂ F m ∂ w ∣ w = w i , m − 1 \begin{matrix} w_{i,0}=w_{i-1,m}\\w_{i,1} = w_{i,0}+k\frac{\partial F_{1}}{\partial w}|_{w=w_{i,0}} \\w_{i,2} = w_{i,1}+k\frac{\partial F_{2}}{\partial w}|_{w=w_{i,1}} \\w_{i,3} = w_{i,2}+k\frac{\partial F_{3}}{\partial w}|_{w=w_{i,2}} \\... \\w_{i,m} = w_{i,m-1}+k\frac{\partial F_{m}}{\partial w}|_{w=w_{i,m-1}} \end{matrix} wi,0=wi1,mwi,1=wi,0+kwF1w=wi,0wi,2=wi,1+kwF2w=wi,1wi,3=wi,2+kwF3w=wi,2...wi,m=wi,m1+kwFmw=wi,m1
其中 k k k为学习率的相反数,且一般情况下取值都较小(如取 k = − 0.001 k=-0.001 k=0.001)。考虑到 k k k取值较小,所以有如下近似公式:
w i , 0 = w i − 1 , m w i , 1 = w i , 0 + k ∂ F 1 ∂ w ∣ w = w i , 0 w i , 2 ≈ w i , 1 + k ∂ F 2 ∂ w ∣ w = w i , 0 w i , 3 ≈ w i , 2 + k ∂ F 3 ∂ w ∣ w = w i , 0 . . . w i , m ≈ w i , m − 1 + k ∂ F m ∂ w ∣ w = w i , 0 \begin{matrix} w_{i,0}=w_{i-1,m}\\w_{i,1} = w_{i,0}+k\frac{\partial F_{1}}{\partial w}|_{w=w_{i,0}} \\w_{i,2} \approx w_{i,1}+k\frac{\partial F_{2}}{\partial w}|_{w=w_{i,0}} \\w_{i,3} \approx w_{i,2}+k\frac{\partial F_{3}}{\partial w}|_{w=w_{i,0}} \\... \\w_{i,m} \approx w_{i,m-1}+k\frac{\partial F_{m}}{\partial w}|_{w=w_{i,0}} \end{matrix} wi,0=wi1,mwi,1=wi,0+kwF1w=wi,0wi,2wi,1+kwF2w=wi,0wi,3wi,2+kwF3w=wi,0...wi,mwi,m1+kwFmw=wi,0
从而进一步得到如下具体的近似公式:
w i , j ≈ w i − 1 , j + ∑ t = 1 m k ∂ F t ∂ w ∣ w = w i − 1 , j w_{i,j} \approx w_{i-1,j}+\sum_{t=1}^{m} k\frac{\partial F_{t}}{\partial w}|_{w=w_{i-1,j}} wi,jwi1,j+t=1mkwFtw=wi1,j
为了直观得到结论,采用如下表示:
v t = k ∂ F t ∂ w ∣ w = w i − 1 , j v_{t} = k\frac{\partial F_{t}}{\partial w}|_{w=w_{i-1,j}} vt=kwFtw=wi1,j
那么之前的表达式就可以简写为:
w i , j ≈ w i − 1 , j + ∑ t = 1 m v t w_{i,j} \approx w_{i-1,j}+\sum_{t=1}^{m} v_{t} wi,jwi1,j+t=1mvt
对于 w i , j w_{i,j} wi,j来说, v j v_{j} vj才是其让损失值下降最快的方向,其他的向量代表其他批的数据,往往得到的方向与该方向比较随机,最后得到的和可能趋于0或者其他损失值下降不太快的方向。

因此,要想让第 j j j批的数据对应的损失值稳定下降,还得靠一轮一轮的循环才行,靠同一轮的其他批次是不太合理的(只有一部分情况才能如此)

合理外推:实验数据验证想法

如果看懂了前面的数学推导,那么很自然就能想到:对于批次 m m m较大的情况下,损失函数图像会呈现整体趋势下降的条带,如下图:

其中训练参数改动如下:

x = (rn(2000, 10)).astype(np.float32)
t = x**2
epochs = 200

你说啥?数学推导没看懂?那也没关系,其实到最后只是为了说明一个事情:你把训练数据分成很多个批次去炼丹,对于具体的某个批次的损失值下降,主要是依赖该批次的下一轮迭代,而不是同一轮的其他批次。

如果你感觉条带形状的损失值碍眼,感觉损失值起起伏伏的,很多计算资源都浪费了,那么用一招就能“瞒天过海”:把损失值的评估计算改为一整轮的平均损失,比如有 m m m批数据,那么统计损失值时使用这 m m m个批次的损失值总平均值即可,效果绝对立竿见影:

其中训练参数改动如下:

x = (rn(2000, 10)).astype(np.float32)
t = x**2
epochs = 200
batch_size = 100
eval_interval = 20

这参数里面,一共有2000个数据,100个数据为1批,共20批数据,然后20批数据评估一次整体平均损失值,训练200轮。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1657664.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hadoop3:HDFS的Shell操作(常用命令汇总)

一、简介 什么是HDFS的Shell操作? 很简单,就是在Linux的终端,通过命令来操作HDFS。 如果,你们学习过git、docker、k8s,应该会发现,这些命令的特点和shell命令非常相似 二、常用命令 1、准备工作相关命令…

【VTKExamples::Rendering】第一期 TestAmbientSpheres(环境照明系数)

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例TestAmbientShperes,介绍环境照明系数对Actor颜色的影响,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动…

Qt模型视图代理之QListView使用的简单介绍

往期回顾: Qt绘图与图形视图之Graphics View坐标系的简单介绍_graphics view 坐标系-CSDN博客 Qt模型视图代理之MVD(模型-视图-代理)概念的简单介绍-CSDN博客 Qt模型视图代理之QTableView应用的简单介绍-CSDN博客 Qt模型视图代理之QListView使用的简单介绍 一、最…

视频监控平台:交通运输标准JTT808设备SDK接入源代码函数分享

目录 一、JT/T 808标准简介 (一)概述 (二)协议特点 1、通信方式 2、鉴权机制 3、消息分类 (三)协议主要内容 1、位置信息 2、报警信息 3、车辆控制 4、数据转发 二、代码和解释 (一…

电商核心技术揭秘53:社群营销的策略与实施

相关系列文章 电商技术揭秘相关系列文章合集(1) 电商技术揭秘相关系列文章合集(2) 电商技术揭秘相关系列文章合集(3) 电商技术揭秘四十一:电商平台的营销系统浅析 电商技术揭秘四十二&#…

手机传输助手有哪些?如何快速互传文件?

手机已经成为我们生活和工作中不可或缺的一部分,而手机传输助手,作为一种帮助我们在不同设备之间快速、方便地共享文件的工具,其重要性不言而喻。无论是在工作中需要将文件从电脑传输到手机,还是在生活中想要与朋友分享美好的瞬间…

【智能算法】人工原生动物优化算法(APO)原理及实现

目录 1.背景2.算法原理2.1算法思想2.2算法过程 3.结果展示4.参考文献5.获取代码 1.背景 2024年,X Wang受到自然界原生动物启发,提出了人工原生动物优化算法( Artificial Protozoa Optimizer, APO)。 2.算法原理 2.1算法思想 AP…

一文玩转Vue3参数传递——全栈开发之路--前端篇(8)

全栈开发一条龙——前端篇 第一篇:框架确定、ide设置与项目创建 第二篇:介绍项目文件意义、组件结构与导入以及setup的引入。 第三篇:setup语法,设置响应式数据。 第四篇:数据绑定、计算属性和watch监视 第五篇 : 组件…

导弹追踪效果实现_unity基础开发教程

Unity开发中导弹追踪的原理与实现 前言原理逻辑实现导弹逻辑目标赋值 应用效果结语 前言 ⭕在之前的一个项目的开发中,需要加入一个导弹追踪的游戏功能,且还要实现不规则发射路径,但是这种功能是第一次做,经过查阅资料和询问做过的…

java io包

InputStream InputStream 是 Java I/O 中所有输入流的抽象基类,它定义了读取字节流的基本方法。InputStream 类提供了许多子类,用于从不同的数据源读取数据,如文件、网络连接、内存等。 InputStream 提供了以下常用的方法: int…

Magic Studio Eraser API使用教程

AI橡皮擦 - 使用网址 Magic Studio的AI橡皮擦功能非常好用,能去除图片中的杂物。但是网页版只支持低分辨率下载,想要原图就得开会员,价格不菲。 不过官网其实提供了API接入方式,并且有100次的免费试用机会 API接入网站 在这里可…

PyQt6--Python桌面开发(3.运行QTDesigner生成的ui文件程序)

运行QTDesigner生成的ui文件程序 用QTDesigner设计一个简单的UI 保存ui文件,放到项目里面去 通过pyqt6包里面的uic来加载ui文件 import sysfrom PyQt6.QtWidgets import QApplication from PyQt6 import uicif __name__ __main__:appQApplication(sys.argv)uiui…

C++对象引用作为函数参数

使用对象引用作为函数参数最常见,它不但有指针作为参数的优点,而且比指针作为参数更简单、更方便。 引用方式进行参数传递,形参对象就是实参对象的“别名”,对形参的操作其实就是对实参的操作。 例如:用对象引用进行参数传…

基于Springboot的校园悬赏任务平台(有报告)。Javaee项目,springboot项目。

演示视频: 基于Springboot的校园悬赏任务平台(有报告)。Javaee项目,springboot项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构…

1010: 折半查找的实现

解法&#xff1a; #include<iostream> #include<vector> using namespace std; void solve() {int n;cin >> n;vector<int> vec(n);for (int& x : vec) cin >> x;int x;cin >> x;int l 0, r n-1, cnt 0;while (l < r) {cnt;int…

Python语言在地球科学交叉领域中的实践技术融合应用

Python是功能强大、免费、开源&#xff0c;实现面向对象的编程语言&#xff0c;Python能够运行在Linux、Windows、Macintosh、AIX操作系统上及不同平台&#xff08;x86和arm&#xff09;&#xff0c;Python简洁的语法和对动态输入的支持&#xff0c;再加上解释性语言的本质&…

代码随想录算法训练营第四十二天| 01背包问题(二维、一维)、416.分割等和子集

系列文章目录 目录 系列文章目录动态规划&#xff1a;01背包理论基础①二维数组②一维数组&#xff08;滚动数组&#xff09; 416. 分割等和子集①回溯法&#xff08;超时&#xff09;②动态规划&#xff08;01背包&#xff09;未剪枝版剪枝版 动态规划&#xff1a;01背包理论基…

【MySQL基本查询(上)】

文章目录 一、多行插入 指定列插入数据更新表中某个数据的信息&#xff08;on duplicate&#xff09;了解affected报告信息 二、检索功能1.select 查询1.1全列查询1.2指定列查询1.3where条件筛选子句案例 2.结果排序案例 3.筛选分页结果offset实现分页 一、多行插入 指定列插…

QT:小项目:登录界面 (下一章连接数据库)

一、效果图 登录后&#xff1a; 二、项目工程结构 三、登录界面UI设计 四主界面 四、源码设计 login.h #ifndef LOGIN_H #define LOGIN_H#include <QDialog>namespace Ui { class login; }class login : public QDialog {Q_OBJECTpublic:explicit login(QWidge…

es使用遇到的bug总结

本来版本7.4.0不行&#xff0c;最后换了个版本7.15.1就可以了&#xff0c;但又出现以下问题了&#xff1a; Beanpublic ElasticsearchClient elasticsearchClient() { // RestClient client RestClient.builder(new HttpHost("localhost", 9200,"http&q…