【机器学习】loss损失讨论

news2024/12/24 16:51:59

大纲

  • 验证集loss上升,准确率也上升(即将overfitting?)
  • 训练集loss一定为要为0吗

Q1. 验证集loss上升,准确率也上升

随着置信度的增加,一小部分点的预测结果是错误的(log lik 给出了指数级的惩罚,在损失中占主导地位)。与此同时,大量其他点开始预测良好(argmax p=label),主导了预测的准确性。
在这里插入图片描述


Q2. 训练集loss一定为要为0吗

一般来说,我们是用训练集来训练模型,但希望的是验证机的损失越小越好,而正常来说训练集的损失降到一定值后,验证集的损失就会开始上升,因此没必要把训练集的损失降低到 0

既然如此,在已经达到了某个阈值之后,我们可不可以做点别的事情来提升模型性能呢?ICML2020 的论文《Do We Need Zero Training Loss After Achieving Zero Training Error?》回答了这个问题,不过实际上它并没有很好的描述 “为什么”,而只是提出了 “怎么做”

假设原来的损失函数是 L ( θ ) \mathcal {L}(\theta) L(θ),现在改为 L ~ ( θ ) \tilde {\mathcal {L}}(\theta) L~(θ)
L ~ ( θ ) = ∣ L ( θ ) − b ∣ + b (1) \tilde{\mathcal{L}}(\theta)=|\mathcal{L}(\theta)-b|+b\tag{1} L~(θ)=L(θ)b+b(1)

其中 b b b 是预先设定的阈值。当 L ( θ ) > b \mathcal {L}(\theta)>b L(θ)>b L ~ ( θ ) = L ( θ ) \tilde {\mathcal {L}}(\theta)=\mathcal {L}(\theta) L~(θ)=L(θ),这时就是执行普通的梯度下降;而 L ( θ ) < b \mathcal {L}(\theta)<b L(θ)<b L ~ ( θ ) = 2 b − L ( θ ) \tilde {\mathcal {L}}(\theta)=2b-\mathcal {L}(\theta) L~(θ)=2bL(θ),注意到损失函数变号了,所以这时候是梯度上升。因此,总的来说就是以 b b b 为阈值,低于阈值时反而希望损失函数变大。论文把这个改动称为 “Flooding”
这样做有什么效果呢?论文显示,在某些任务中,训练集的损失函数经过这样处理后,验证集的损失能出现 “二次下降(Double Descent)”,如下图
在这里插入图片描述

在这里插入图片描述

如何解释这个方法呢?可以想像,当损失函数达到 b b b 之后,训练流程大概就是在交替执行梯度下降和梯度上升。直观想的话,感觉一步上升一步下降,似乎刚好抵消了。事实真的如此吗?我们来算一下看看。假设先下降一步后上升一步,学习率为 ε \varepsilon ε,那么:
θ n = θ n − 1 − ε g ( θ n − 1 ) θ n + 1 = θ n + ε g ( θ n ) \begin{equation}\begin{aligned}&\theta_n = \theta_{n-1} - \varepsilon g(\theta_{n-1})\\ &\theta_{n+1} = \theta_n + \varepsilon g(\theta_n) \end{aligned}\tag{2}\end{equation} θn=θn1εg(θn1)θn+1=θn+εg(θn)(2)

其中 g ( θ ) = ∇ θ L ( θ ) g (\theta)=\nabla_{\theta}\mathcal {L}(\theta) g(θ)=θL(θ),现在我们有
θ n + 1 =   θ n − 1 − ε g ( θ n − 1 ) + ε g ( θ n − 1 − ε g ( θ n − 1 ) ) ≈   θ n − 1 − ε g ( θ n − 1 ) + ε ( g ( θ n − 1 ) − ε ∇ θ g ( θ n − 1 ) g ( θ n − 1 ) ) =   θ n − 1 − ε 2 2 ∇ θ ∥ g ( θ n − 1 ) ∥ 2 \begin{equation}\begin{aligned}\theta_{n+1} =&\, \theta_{n-1} - \varepsilon g(\theta_{n-1}) + \varepsilon g\big(\theta_{n-1} - \varepsilon g(\theta_{n-1})\big)\\ \approx&\,\theta_{n-1} - \varepsilon g(\theta_{n-1}) + \varepsilon \big(g(\theta_{n-1}) - \varepsilon \nabla_{\theta} g(\theta_{n-1}) g(\theta_{n-1})\big)\\ =&\,\theta_{n-1} - \frac{\varepsilon^2}{2}\nabla_{\theta}\Vert g(\theta_{n-1})\Vert^2 \end{aligned}\tag{3}\end{equation} θn+1==θn1εg(θn1)+εg(θn1εg(θn1))θn1εg(θn1)+ε(g(θn1)εθg(θn1)g(θn1))θn12ε2θg(θn1)2(3)

近似那一步实际上是使用了泰勒展开,我们将 θ n − 1 \theta_{n-1} θn1 看作 x x x ε g ( θ n − 1 ) \varepsilon g (\theta_{n-1}) εg(θn1) 看作 Δ x \Delta x Δx,由于
g ( x − Δ x ) − g ( x ) − Δ x = ∇ x g ( x ) \frac{g(x - \Delta x) - g(x)}{-\Delta x} = \nabla_x g(x) Δxg(xΔx)g(x)=xg(x) 所以
g ( x − Δ x ) = g ( x ) − Δ x ∇ x g ( x ) g(x - \Delta x) = g(x) - \Delta x \nabla_x g(x) g(xΔx)=g(x)Δxxg(x)

最终的结果就是相当于学习率为 ε 2 2 \frac {\varepsilon^2}{2} 2ε2、损失函数为梯度惩罚 ∥ g ( θ ) ∥ 2 = ∥ ∇ θ L ( θ ) ∥ 2 \Vert g (\theta)\Vert^2 = \Vert \nabla_{\theta} \mathcal {L}(\theta)\Vert^2 g(θ)2=θL(θ)2 的梯度下降。更妙的是,改为 “先上升再下降”,其表达式依然是一样的(这不禁让我想起 “先涨价 10% 再降价 10%” 和 “先降价 10% 再涨价 10% 的故事”)。因此,平均而言,Flooding 对损失函数的改动,相当于在保证了损失函数足够小之后去最小化 ∥ ∇ x L ( θ ) ∥ 2 \Vert \nabla_x \mathcal {L}(\theta)\Vert^2 xL(θ)2,也就是推动参数往更平稳的区域走,这通常能提高泛化性(更好地抵抗扰动),因此一定程度上就能解释 Flooding 有作用的原因了

本质上来讲,这跟往参数里边加入随机扰动、对抗训练等也没什么差别,只不过这里是保证了损失足够小后再加扰动

想要使用 Flooding 非常简单,只需要在原有代码基础上增加一行即可

logits = model(x)
loss = criterion(logits, y)
loss = (loss - b).abs() + b # This is it!
optimizer.zero_grad()
loss.backward()
optimizer.step()

有心是用这个方法的读者可能会纠结于 b b b 的选择,原论文说 b b b 的选择是一个暴力迭代的过程,需要多次尝试

The flood level is chosen from b ∈ { 0 , 0.01 , 0.02 , . . . , 0.50 } b\in \{0, 0.01,0.02,...,0.50\} b{0,0.01,0.02,...,0.50}

不过笔者倒是有另外一个脑洞: b b b 无非就是决定什么时候开始交替训练罢了,那如果我们从一开始就用不同的学习率进行交替训练呢?也就是自始自终都执行
θ n = θ n − 1 − ε 1 g ( θ n − 1 ) θ n + 1 = θ n + ε 2 g ( θ n ) \begin{equation}\begin{aligned}&\theta_n = \theta_{n-1} - \varepsilon_1 g(\theta_{n-1})\\ &\theta_{n+1} = \theta_n + \varepsilon_2 g(\theta_n) \end{aligned}\tag{4}\end{equation} θn=θn1ε1g(θn1)θn+1=θn+ε2g(θn)(4)

其中 ε 1 > ε 2 \varepsilon_1 > \varepsilon_2 ε1>ε2,这样我们就把 b b b 去掉了(引入了 ε 1 , ε 2 \varepsilon_1, \varepsilon_2 ε1,ε2 的选择,天下没有免费的午餐)。重复上述近似展开,我们就得到
θ n + 1 =   θ n − 1 − ε 1 g ( θ n − 1 ) + ε 2 g ( θ n − 1 − ε 1 g ( θ n − 1 ) ) ≈   θ n − 1 − ε 1 g ( θ n − 1 ) + ε 2 ( g ( θ n − 1 ) − ε 1 ∇ θ g ( θ n − 1 ) g ( θ n − 1 ) ) =   θ n − 1 − ( ε 1 − ε 2 ) g ( θ n − 1 ) − ε 1 ε 2 2 ∇ θ ∥ g ( θ n − 1 ) ∥ 2 =   θ n − 1 − ( ε 1 − ε 2 ) ∇ θ [ L ( θ n − 1 ) + ε 1 ε 2 2 ( ε 1 − ε 2 ) ∥ ∇ θ L ( θ n − 1 ) ∥ 2 ] \begin{equation}\begin{aligned} \theta_{n+1} =& \, \theta_{n-1} - \varepsilon_1g(\theta_{n-1})+\varepsilon_2g(\theta_{n-1} - \varepsilon_1g(\theta_{n-1}))\\ \approx&\, \theta_{n-1} - \varepsilon_1g(\theta_{n-1}) + \varepsilon_2(g(\theta_{n-1}) - \varepsilon_1\nabla_\theta g(\theta_{n-1})g(\theta_{n-1}))\\ =&\, \theta_{n-1} - (\varepsilon_1 - \varepsilon_2) g(\theta_{n-1}) - \frac{\varepsilon_1\varepsilon_2}{2}\nabla_{\theta}\Vert g(\theta_{n-1})\Vert^2\\ =&\,\theta_{n-1} - (\varepsilon_1 - \varepsilon_2)\nabla_{\theta}\left[\mathcal{L}(\theta_{n-1}) + \frac{\varepsilon_1\varepsilon_2}{2(\varepsilon_1 - \varepsilon_2)}\Vert \nabla_{\theta}\mathcal{L}(\theta_{n-1})\Vert^2\right] \end{aligned}\tag{5}\end{equation} θn+1===θn1ε1g(θn1)+ε2g(θn1ε1g(θn1))θn1ε1g(θn1)+ε2(g(θn1)ε1θg(θn1)g(θn1))θn1(ε1ε2)g(θn1)2ε1ε2θg(θn1)2θn1(ε1ε2)θ[L(θn1)+2(ε1ε2)ε1ε2θL(θn1)2](5)

这就相当于自始自终都在用学习率 ε 1 − ε 2 \varepsilon_1-\varepsilon_2 ε1ε2 来优化损失函数 L ( θ ) + ε 1 ε 2 2 ( ε 1 − ε 2 ) ∥ ∇ θ L ( θ ) ∥ 2 \mathcal {L}(\theta) + \frac {\varepsilon_1\varepsilon_2}{2 (\varepsilon_1 - \varepsilon_2)}\Vert\nabla_{\theta}\mathcal {L}(\theta)\Vert^2 L(θ)+2(ε1ε2)ε1ε2θL(θ)2 了,也就是说一开始就把梯度惩罚给加了进去,这样能提升模型的泛化性能吗?《Backstitch: Counteracting Finite-sample Bias via Negative Steps》里边指出这种做法在语音识别上是有效的,请读者自行测试甄别

效果检验

我随便在网上找了个竞赛,然后利用别人提供的以 BERT 为 baseline 的代码,对 Flooding 的效果进行了测试,下图分别是没有做 Flooding 和参数 b = 0.7 b=0.7 b=0.7 的 Flooding 损失值变化图,值得一提的是,没有做 Flooding 的验证集最低损失值为 0.814198,而做了 Flooding 的验证集最低损失值为 0.809810
在这里插入图片描述

根据知乎文章一行代码发一篇 ICML?底下用户 Curry 评论所言:“通常来说 b b b 值需要设置成比 'Validation Error 开始上升 ’ 的值更小,1/2 处甚至更小,结果更优”,所以我仔细观察了下没有加 Flooding 模型损失值变化图,大概在 loss 为 0.75 到 1.0 左右的时候开始出现过拟合现象,因此我又分别设置了 b = 0.4 b=0.4 b=0.4 b = 0.5 b=0.5 b=0.5,做了两次 Flooding 实验,结果如下图
在这里插入图片描述

值得一提的是, b = 0.4 b=0.4 b=0.4 b = 0.5 b=0.5 b=0.5 时,验证集上的损失值最低仅为 0.809958 和 0.796819,而且很明显验证集损失的整体上升趋势更加缓慢。接下来我做了一个实验,主要是验证 “继续脑洞” 部分以不同的学习率一开始就交替着做梯度下降和梯度上升的效果,其中,梯度下降的学习率我设为 1 e − 5 1e-5 1e5,梯度上升的学习率为 1 e − 6 1e-6 1e6,结果如下图,验证集的损失最低仅有 0.783370在这里插入图片描述

References

我们真的需要把训练集的损失降低到零吗?
LossUpAccUp -Github
https://wmathor.com/index.php/archives/1551/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1148847.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VSCode编写Unity代码自动补全配置

1.下载并安装.NET 7.0&#xff08;C#插件需要&#xff09;和.NET Framework 4.7.1&#xff08;Unity需要&#xff09; .NET 7.0下载链接&#xff1a;https://dotnet.microsoft.com/en-us/download .NET Framework 4.7.1下载链接&#xff1a;https://dotnet.microsoft.com/en-…

Python 日期和时间处理教程:datetime 模块的使用

Python 中的日期不是独立的数据类型&#xff0c;但我们可以导入一个名为 datetime 的模块来使用日期作为日期对象。 示例&#xff1a;导入 datetime 模块并显示当前日期&#xff1a; import datetimex datetime.datetime.now() print(x)日期输出 当我们执行上面示例中的代码…

Java利用Scanner类,从键盘接受一个字符串输入,该字符串包含小写字母,大写字母和数字。分别输出该字符串所包含的大写字母、小写字母和数字的个数。

题目要求&#xff1a;利用Scanner类&#xff0c;从键盘接受一个字符串输入&#xff0c;该字符串包含小写字母&#xff0c;大写字母和数字。分别输出该字符串所包含的大写字母、小写字母和数字的个数。 import java.util.Scanner;public class Demo1 {public static void main(…

算法篇 : 并查集

介绍 英文名&#xff1a;union find set 作用&#xff1a;合并集合&#xff0c;查询集合 合并&#xff1a;将有直接关系的顶点放在一个集合里面 查找&#xff1a;查询某个顶点所属的集合 集合的标志&#xff1a;用祖先点的标号作为每个集合的标识 案例 如果说将下图的集合2合并…

H5游戏源码分享-接苹果游戏拼手速

H5游戏源码分享-接苹果游戏拼手速 看看在20秒内能接多少个苹果 <html> <head><title>我是你的小苹果</title><meta charset"utf-8"/><meta name"viewport" content"initial-scale1, user-scalableno, minimum-scale…

【DevChat】智能编程助手 - 使用评测

写在前面&#xff1a;博主是一只经过实战开发历练后投身培训事业的“小山猪”&#xff0c;昵称取自动画片《狮子王》中的“彭彭”&#xff0c;总是以乐观、积极的心态对待周边的事物。本人的技术路线从Java全栈工程师一路奔向大数据开发、数据挖掘领域&#xff0c;如今终有小成…

【unity小技巧】unity排序问题的探究

文章目录 前言一、排序图层二、sorting Group的使用三、树木排序设计方法一 代码控制方法二 拆分图片方法三 透视排序1. 普通物品排序2. TileMap瓦片排序设计 完结 前言 unity的排序问题其实之前分享的项目多多少少都有提到一点&#xff0c;但是没有单独拿出来说&#xff0c;所…

常用第三方库

Moment GTC(Greenwish Mean Time)&#xff1a;格林威治时间&#xff0c;太阳时&#xff0c;精确到毫秒UTC(Universal Time Coodinated)&#xff1a;世界协调时间&#xff0c;原子种计时&#xff0c;精确到纳秒 GTC和UTC都是以0时区作为标准时间戳&#xff1a;以UTC的1970-1-1 …

天气数据可视化平台-计算机毕业设计vue

天气变幻无常&#xff0c;影响着我们生活的方方面面&#xff0c;应用天气预报信息可以及时了解天气的趋势&#xff0c;给人们的工作、生活等带来便利&#xff0c;也可以为我们为未来的事情做安排和打算&#xff0c;所以一个精准的、易读 通过利用 程序对气象网站大量的气象信息…

H5游戏源码分享-命悬一线

H5游戏源码分享-命悬一线 在合适的时机跳下绳子&#xff0c;能安全站到木桩上&#xff0c;就通过。 游戏源码 <!DOCTYPE html> <html> <head><meta http-equiv"Content-Type" content"text/html; charsetutf-8" /><meta name&…

Mybatis延迟加载(缓存)

延迟加载 分步查询的优点&#xff1a;可以实现延迟加载&#xff0c;但是必须在核心配置文件中设置全局配置信息&#xff1a;lazyLoadingEnabled&#xff1a;延迟加载的全局开关。当开启时&#xff0c;所有关联对象都会延迟加载 aggressiveLazyLoading&#xff1a;当开启时&…

Openssl数据安全传输平台017:Linux客户端代码的编译与调试-Bug记录

文章目录 1 在windows上先预编译2 Centos上进入项目文件夹进行编译2.0 最终的编译指令2.1 找不到lprotobuf&#xff0c;找不到protobuf的google文件夹2.1.1 编译指令及提示2.1.2 问题分析2.1.3 解决办法 2.2 json类中方法unreference2.2.1 编译指令及提示2.2.2 问题分析 *** 最…

hadoop权威指南第四版

第一部分 HaDOOP基础知识 1.1 面临的问题 存储越来越大&#xff0c;读写跟不上。 并行读多个磁盘。 问题1 磁盘损坏 – 备份数据HDFS 问题2 读取多个磁盘用于分析&#xff0c;数据容易出错 --MR 编程模型 1.2 衍生品 1 在线访问的组件是hbase 。一种使用hdfs底层存储的模型。…

Spring中Bean的完整生命周期!(Bean实例化的流程,Spring后处理器,循环依赖解释及解决方法)附案例演示

Bean实例化的基本流程 加载xml配置文件&#xff0c;解析获取配置中的每个的信息&#xff0c;封装成一个个的BeanDefinition对象将BeanDefinition存储在一个名为beanDefinitionMap的Map<String,BeanDefinition>中ApplicationContext底层遍历beanDefinitionMap&#xff0c…

寻找倒数第K个节点

这篇文章也是凑数的 ... 寻找倒数第K个节点 描述 : 找出单向链表中倒数第 k 个节点。返回该节点的值。 题目 : LeetCode 返回倒数第K个节点 : 面试题 02.02. 返回倒数第 k 个节点 说明 : 给定的 k 保证是有效的。 分析 : 我们给出个例子 : 首先&#xff0c;我们创建两个…

Java-API简析_java.io.FilterOutputStream类(基于 Latest JDK)(浅析源码)

【版权声明】未经博主同意&#xff0c;谢绝转载&#xff01;&#xff08;请尊重原创&#xff0c;博主保留追究权&#xff09; https://blog.csdn.net/m0_69908381/article/details/134106510 出自【进步*于辰的博客】 因为我发现目前&#xff0c;我对Java-API的学习意识比较薄弱…

ubuntu安装tinycudann

tinycudann几乎是nerf开发环境必备的条件之一。在测试sdfstudio和nerfstudio都需要安装tinycudann。按照官方的操作步骤&#xff0c;最简单的安装步骤为&#xff1a; pip install githttps://github.com/NVlabs/tiny-cuda-nn/#subdirectorybindings/torch常见会出现如下两个问…

requests之post请求实例-百度翻译

视频版教程&#xff1a;一天掌握python爬虫【基础篇】 涵盖 requests、beautifulsoup、selenium 打开百度翻译网址&#xff0c;我们输入需要翻译的英文&#xff0c;谷歌 F12 打开开发者工具&#xff0c;network可以看到网络请求&#xff0c;我们需要找到请求的API&#xff0c;…

Git 提交时提示 GPG 签名错误

本来应该一切都是正常的&#xff0c;但今天提交的时候提示 GPG 签名错误。 错误的信息就是 GPG 签名失败。 gpg: skipped "942395299055675C": No secret key gpg: signing failed: No secret key error: gpg failed to sign the data fatal: failed to write commi…

PicGo+雨云ROS搭建自己的图床,可配合Typora使用

本文将手把手带你使用 PicGo雨云对象存储ROS(Rain Object Storage) 搭建自己专属的免费图床&#xff0c;并且可以配合Typora使用。 雨云对象存储服务介绍和使用教程&#xff1a;https://blog.zeruns.tech/archives/733.html 目前雨云对象存储是公测阶段&#xff0c;暂时是免费…