我们真的需要把训练集的损失降到零吗?

news2024/11/26 21:36:34

在训练模型的时候,我们需要将损失函数一直训练到0吗?显然不用。一般来说,我们是用训练集来训练模型,但希望的是验证机的损失越小越好,而正常来说训练集的损失降到一定值后,验证集的损失就会开始上升,因此没必要把训练集的损失降低到0

既然如此,在已经达到了某个阈值之后,我们可不可以做点别的事情来提升模型性能呢?ICML2020的论文《Do We Need Zero Training Loss After Achieving Zero Training Error?》回答了这个问题,不过实际上它并没有很好的描述"为什么",而只是提出了"怎么做"

思路描述

论文提供的解决方案非常简单,假设原来的损失函数是 L ( θ ) \mathcal{L}(\theta) L(θ),现在改为 L ~ ( θ ) \tilde{\mathcal{L}}(\theta) L~(θ)
L ~ ( θ ) = ∣ L ( θ ) − b ∣ + b (1) \tilde{\mathcal{L}}(\theta)=|\mathcal{L}(\theta)-b|+b\tag{1} L~(θ)=L(θ)b+b(1)
其中 b b b是预先设定的阈值。当 L ( θ ) > b \mathcal{L}(\theta)>b L(θ)>b L ~ ( θ ) = L ( θ ) \tilde{\mathcal{L}}(\theta)=\mathcal{L}(\theta) L~(θ)=L(θ),这时就是执行普通的梯度下降;而 L ( θ ) < b \mathcal{L}(\theta)<b L(θ)<b L ~ ( θ ) = 2 b − L ( θ ) \tilde{\mathcal{L}}(\theta)=2b-\mathcal{L}(\theta) L~(θ)=2bL(θ),注意到损失函数变号了,所以这时候是梯度上升。因此,总的来说就是以 b b b为阈值,低于阈值时反而希望损失函数变大。论文把这个改动称为**“Flooding”**

这样做有什么效果呢?论文显示,在某些任务中,训练集的损失函数经过这样处理后,验证集的损失能出现"二次下降(Double Descent)",如下图


左图:不加Flooding的训练示意图;右图:加了Flooding的训练示意图

简单来说,就是最终的验证集效果可能更好一些,原论文的实验结果如下:


Flooding的实验结果:第一行W表示是否使用weight decay,第二行E表示是否使用early stop,第三行的F表示是否使用Flooding

个人分析

如何解释这个方法呢?可以想像,当损失函数达到 b b b之后,训练流程大概就是在交替执行梯度下降和梯度上升。直观想的话,感觉一步上升一步下降,似乎刚好抵消了。事实真的如此吗?我们来算一下看看。假设先下降一步后上升一步,学习率为 ε \varepsilon ε,那么:
θ n = θ n − 1 − ε g ( θ n − 1 ) θ n + 1 = θ n + ε g ( θ n ) (2) \begin{aligned}&\theta_n = \theta_{n-1} - \varepsilon g(\theta_{n-1})\\ &\theta_{n+1} = \theta_n + \varepsilon g(\theta_n) \end{aligned}\tag{2} θn=θn1εg(θn1)θn+1=θn+εg(θn)(2)
其中 g ( θ ) = ∇ θ L ( θ ) g(\theta)=\nabla_{\theta}\mathcal{L}(\theta) g(θ)=θL(θ),现在我们有
θ n + 1 =   θ n − 1 − ε g ( θ n − 1 ) + ε g ( θ n − 1 − ε g ( θ n − 1 ) ) ≈   θ n − 1 − ε g ( θ n − 1 ) + ε ( g ( θ n − 1 ) − ε ∇ θ g ( θ n − 1 ) g ( θ n − 1 ) ) =   θ n − 1 − ε 2 2 ∇ θ ∥ g ( θ n − 1 ) ∥ 2 (3) \begin{aligned}\theta_{n+1} =&\, \theta_{n-1} - \varepsilon g(\theta_{n-1}) + \varepsilon g\big(\theta_{n-1} - \varepsilon g(\theta_{n-1})\big)\\ \approx&\,\theta_{n-1} - \varepsilon g(\theta_{n-1}) + \varepsilon \big(g(\theta_{n-1}) - \varepsilon \nabla_{\theta} g(\theta_{n-1}) g(\theta_{n-1})\big)\\ =&\,\theta_{n-1} - \frac{\varepsilon^2}{2}\nabla_{\theta}\Vert g(\theta_{n-1})\Vert^2 \end{aligned}\tag{3} θn+1==θn1εg(θn1)+εg(θn1εg(θn1))θn1εg(θn1)+ε(g(θn1)εθg(θn1)g(θn1))θn12ε2θg(θn1)2(3)

近似那一步实际上是使用了泰勒展开,我们将 θ n − 1 \theta_{n-1} θn1看作 x x x ε g ( θ n − 1 ) \varepsilon g(\theta_{n-1}) εg(θn1)看作 Δ x \Delta x Δx,由于
g ( x − Δ x ) − g ( x ) − Δ x = ∇ x g ( x ) \frac{g(x - \Delta x) - g(x)}{-\Delta x} = \nabla_x g(x) Δxg(xΔx)g(x)=xg(x)
所以
g ( x − Δ x ) = g ( x ) − Δ x ∇ x g ( x ) g(x - \Delta x) = g(x) - \Delta x \nabla_x g(x) g(xΔx)=g(x)Δxxg(x)

最终的结果就是相当于学习率为 ε 2 2 \frac{\varepsilon^2}{2} 2ε2、损失函数为梯度惩罚 ∥ g ( θ ) ∥ 2 = ∥ ∇ θ L ( θ ) ∥ 2 \Vert g(\theta)\Vert^2 = \Vert \nabla_{\theta} \mathcal{L}(\theta)\Vert^2 g(θ)2=θL(θ)2的梯度下降。更妙的是,改为"先上升再下降",其表达式依然是一样的(这不禁让我想起"先涨价10%再降价10%“和"先降价10%再涨价10%的故事”)。因此,平均而言,Flooding对损失函数的改动,相当于在保证了损失函数足够小之后去最小化 ∥ ∇ x L ( θ ) ∥ 2 \Vert \nabla_x \mathcal{L}(\theta)\Vert^2 xL(θ)2,也就是推动参数往更平稳的区域走,这通常能提高泛化性(更好地抵抗扰动),因此一定程度上就能解释Flooding有作用的原因了

本质上来讲,这跟往参数里边加入随机扰动、对抗训练等也没什么差别,只不过这里是保证了损失足够小后再加扰动

继续脑洞

想要使用Flooding非常简单,只需要在原有代码基础上增加一行即可

logits = model(x)
loss = criterion(logits, y)
loss = (loss - b).abs() + b # This is it!
optimizer.zero_grad()
loss.backward()
optimizer.step()

有心是用这个方法的读者可能会纠结于 b b b的选择,原论文说 b b b的选择是一个暴力迭代的过程,需要多次尝试

The flood level is chosen from b ∈ { 0 , 0.01 , 0.02 , . . . , 0.50 } b\in \{0, 0.01,0.02,...,0.50\} b{0,0.01,0.02,...,0.50}

不过笔者倒是有另外一个脑洞: b b b无非就是决定什么时候开始交替训练罢了,那如果我们从一开始就用不同的学习率进行交替训练呢?也就是自始自终都执行
θ n = θ n − 1 − ε 1 g ( θ n − 1 ) θ n + 1 = θ n + ε 2 g ( θ n ) (4) \begin{aligned}&\theta_n = \theta_{n-1} - \varepsilon_1 g(\theta_{n-1})\\ &\theta_{n+1} = \theta_n + \varepsilon_2 g(\theta_n) \end{aligned}\tag{4} θn=θn1ε1g(θn1)θn+1=θn+ε2g(θn)(4)
其中 ε 1 > ε 2 \varepsilon_1 > \varepsilon_2 ε1>ε2,这样我们就把 b b b去掉了(引入了 ε 1 , ε 2 \varepsilon_1, \varepsilon_2 ε1,ε2的选择,天下没有免费的午餐)。重复上述近似展开,我们就得到
θ n + 1 =   θ n − 1 − ε 1 g ( θ n − 1 ) + ε 2 g ( θ n − 1 − ε 1 g ( θ n − 1 ) ) ≈   θ n − 1 − ε 1 g ( θ n − 1 ) + ε 2 ( g ( θ n − 1 ) − ε 1 ∇ θ g ( θ n − 1 ) g ( θ n − 1 ) ) =   θ n − 1 − ( ε 1 − ε 2 ) g ( θ n − 1 ) − ε 1 ε 2 2 ∇ θ ∥ g ( θ n − 1 ) ∥ 2 =   θ n − 1 − ( ε 1 − ε 2 ) ∇ θ [ L ( θ n − 1 ) + ε 1 ε 2 2 ( ε 1 − ε 2 ) ∥ ∇ θ L ( θ n − 1 ) ∥ 2 ] (5) \begin{aligned} \theta_{n+1} =& \, \theta_{n-1} - \varepsilon_1g(\theta_{n-1})+\varepsilon_2g(\theta_{n-1} - \varepsilon_1g(\theta_{n-1}))\\ \approx&\, \theta_{n-1} - \varepsilon_1g(\theta_{n-1}) + \varepsilon_2(g(\theta_{n-1}) - \varepsilon_1\nabla_\theta g(\theta_{n-1})g(\theta_{n-1}))\\ =&\, \theta_{n-1} - (\varepsilon_1 - \varepsilon_2) g(\theta_{n-1}) - \frac{\varepsilon_1\varepsilon_2}{2}\nabla_{\theta}\Vert g(\theta_{n-1})\Vert^2\\ =&\,\theta_{n-1} - (\varepsilon_1 - \varepsilon_2)\nabla_{\theta}\left[\mathcal{L}(\theta_{n-1}) + \frac{\varepsilon_1\varepsilon_2}{2(\varepsilon_1 - \varepsilon_2)}\Vert \nabla_{\theta}\mathcal{L}(\theta_{n-1})\Vert^2\right] \end{aligned}\tag{5} θn+1===θn1ε1g(θn1)+ε2g(θn1ε1g(θn1))θn1ε1g(θn1)+ε2(g(θn1)ε1θg(θn1)g(θn1))θn1(ε1ε2)g(θn1)2ε1ε2θg(θn1)2θn1(ε1ε2)θ[L(θn1)+2(ε1ε2)ε1ε2θL(θn1)2](5)
这就相当于自始自终都在用学习率 ε 1 − ε 2 \varepsilon_1-\varepsilon_2 ε1ε2来优化损失函数 L ( θ ) + ε 1 ε 2 2 ( ε 1 − ε 2 ) ∥ ∇ θ L ( θ ) ∥ 2 \mathcal{L}(\theta) + \frac{\varepsilon_1\varepsilon_2}{2(\varepsilon_1 - \varepsilon_2)}\Vert\nabla_{\theta}\mathcal{L}(\theta)\Vert^2 L(θ)+2(ε1ε2)ε1ε2θL(θ)2了,也就是说一开始就把梯度惩罚给加了进去,这样能提升模型的泛化性能吗?《Backstitch: Counteracting Finite-sample Bias via Negative Steps》里边指出这种做法在语音识别上是有效的,请读者自行测试甄别

效果检验

我随便在网上找了个竞赛,然后利用别人提供的以BERT为baseline的代码,对Flooding的效果进行了测试,下图分别是没有做Flooding和参数 b = 0.7 b=0.7 b=0.7的Flooding损失值变化图,值得一提的是,没有做Flooding的验证集最低损失值为0.814198,而做了Flooding的验证集最低损失值为0.809810

根据知乎文章一行代码发一篇ICML?底下用户Curry评论所言:“通常来说 b b b值需要设置成比’Validation Error开始上升’的值更小,1/2处甚至更小,结果更优”,所以我仔细观察了下没有加Flooding模型损失值变化图,大概在loss为0.75到1.0左右的时候开始出现过拟合现象,因此我又分别设置了 b = 0.4 b=0.4 b=0.4 b = 0.5 b=0.5 b=0.5,做了两次Flooding实验,结果如下图

值得一提的是, b = 0.4 b=0.4 b=0.4 b = 0.5 b=0.5 b=0.5时,验证集上的损失值最低仅为0.809958和0.796819,而且很明显验证集损失的整体上升趋势更加缓慢。接下来我做了一个实验,主要是验证"继续脑洞"部分以不同的学习率一开始就交替着做梯度下降和梯度上升的效果,其中,梯度下降的学习率我设为 1 e − 5 1e-5 1e5,梯度上升的学习率为 1 e − 6 1e-6 1e6,结果如下图,验证集的损失最低仅有0.783370

References

  • 我们真的需要把训练集的损失降低到零吗?
  • 一行代码发一篇ICML?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/51885.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手摸手教你 docker+jenkins+gitlab 部署你的前端项目

学习了一周的CICD,踩了很多坑,都是泪,特此记录一下整个过程,本次项目产出效果是,git push的时候自动化直接部署到服务器上,以下是整个大致流程: 本地代码push到gitlab gitlab通过webhook通知到jenkins jenkins拉取gitlab仓库代码,并执行shell脚本 shell脚本执行docker命令,打…

python数据分析——NumPy基础

目录 一、创建数组的方法 二、array的属性 三、创建特殊的数组 四、数组的变换 4.1、数组重塑 4.2、数组合并 4.3、数组分割 4.4、数组转置和轴对换 五、数组的索引和切片 5.1、一维数组的索引 5.2、多维数组的索引 5.3、多维数组的访问 六、数组的运算 6.1、数组…

Apollo 应用与源码分析:guardian 紧急处置

目录 概念 代码 分析 概念 Guardian模块的主要作用是监控自动驾驶系统状态&#xff0c;当出现模块为失败状态的时候&#xff0c;会主动切断控制命令输出&#xff0c;并且刹车。 有点像是保险丝&#xff0c;有一个fallback机制。 guardian模块的触发条件主要有2个。 上报…

虚拟机的快照与克隆

简单回顾以下快照 快照的拍摄&#xff1a; 记录虚拟机当前的状态 拍摄快照时&#xff0c;系统一定要处于关机状态 转到&#xff1a; 回到某一个历史快照节点 克隆 复制某一个历史快的的节点 克隆的方式 链接克隆&#xff1a; 当前节点文件家只存储差异性数据 相同数据放在原…

RabbitMQ之延迟队列

延迟消息是指的消息发送出去后并不想立即就被消费&#xff0c;而是需要等&#xff08;指定的&#xff09;一段时间后才触发消费。 例如下面的业务场景&#xff1a;在支付宝上面买电影票&#xff0c;锁定了一个座位后系统默认会帮你保留15分钟时间&#xff0c;如果15分钟后还没付…

zcu106 lwip搭建以太网配置寄存器

文章目录实验一1.配置网口GEM32.导出xsa文件&#xff0c;在vitis中创建工程&#xff0c;选择freertos10_xilinx的操作系统来使用3.配置lwip211&#xff0c;选择SOCKET API的模式4.创建工程 选择FreeRTOS Iwip TCP Perf Server模板5.代码分析main.cfreertos_tcp_perf_server.cfr…

基于yolov5n的轻量级MSTAR遥感影像目标检测系统设计开发实战

做过很多目标检测类的项目了&#xff0c;最近看到一个很早之前用过的数据集MSTAR&#xff0c;之前老师给的任务是基于这个数据集来搭建图像识别模型&#xff0c;殊不知他也是可以用来做目标检测的&#xff0c;今天正好有点时间就想着基于这个数据集来做一下目标检测实践。 首先…

利用车载摄像头了解道路语义的鸟瞰图

以下内容来自从零开始机器人SLAM知识星球 每日更新内容 点击领取学习资料 → 机器人SLAM学习资料大礼包 #论文##开源代码# Understanding Bird’s-Eye View of Road Semantics using an Onboard Camera 论文地址&#xff1a;https://arxiv.org/abs/2012.03040 作者单位&#…

自助建站工具

每用一次自助建站工具&#xff0c;就有一个程序员失业。 作为企业老板的你&#xff0c;要为公司的获客&#xff0c;企业推广发愁&#xff0c;但是预算有限&#xff0c;招人也很困难&#xff0c;不仅要面试程序员&#xff0c;后续还要检验这个程序员的功力&#xff0c;实在是太…

CentOS升级python3版本

介绍 本文将详细介绍在CentOS7.9系统的服务器将自带的python3.6.8版本升级到3.8.0版本的过程。 在升级前CentOS7.9中已经同时存在两个python版本分别是2.7.5和3.6.8。 查看CentOS版本命令&#xff1a; cat /etc/centos-release这是我升级后的python版本&#xff08;python3升…

Minio设置文件永久访问和下载

1. docker pull minio/mc 2. docker run -it --entrypoint/bin/sh minio/mc 3. mc config host add <ALIAS> <YOUR-S3-ENDPOINT> <YOUR-ACCESS-KEY> <YOUR-SECRET-KEY> [--api API-SIGNATURE] mc ls minio ALIAS: 别名就是给你的云存储服务起了一个…

2021-02-01

oracle设置定期修改密码 --通过如下sql查询用户密码有效期配置 SELECT username,PROFILE FROM dba_users; --上述sql查询结果一般为default --使用如下sql可以查询到default的默认值 select * from dba_profiles where profile DEFAULT and resource_name PASSWORD_LI…

HTML+CSS大作业【传统文化艺术耍牙15页】学生个人网页设计作品

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

MES系统中生产计划模块的重要作用

MES系统中的“生产调度”支持从“工单管理”中收到的作业队列&#xff0c;根据生产目标&#xff08;时间和数量&#xff09;&#xff0c;必须考虑到人员、设备、材料的可用性等限制和生产过程中的各种中断&#xff0c;生成一个作业时间表&#xff0c;即生产作业计划。MES系统“…

iOS 接入firebase简易步骤

接入准备 去firebase官网注册应用并下载配置文件GoogleService-Info.plist 接入步骤 1.通过cocopods导入以下两个依赖 pod Firebase/Analytics pod Firebase/Core 2.导入成功后将配置文件GoogleService-Info.plist拖入项目中 3.代码支持 引入#import <Firebase/Firebas…

【云原生系列】第五讲:Knative Eventing 下

目录 序言 1.Parallel介绍 1.1 Parallel Spec ​编辑 2.Sequence 2.1.Sequence Spec 2.2适用场景 2.3 Broker/Trigger 2.4 代码示例 3.投票 序言 三言两语&#xff0c;不如细心探索。 今天整理了一下Eventing 相关知识点 ParallelSequence希望此文&#xff0c;能帮…

C#正则表达式总结

推荐一个专门用于编写正则表达式的网站&#xff1a; regex101: build, test, and debug regex 参考文档&#xff1a; https://zh.wikipedia.org/wiki/%E6%AD%A3%E5%88%99%E8%A1%A8%E8%BE%BE%E5%BC%8F 特殊字符的意义&#xff1a; ^ : 表示字符串的开头 例子&#xff1a; …

局域网的网络硬件主要包括有什么

一、网络服务器 是计算机局域网的核心部件。 网络操作系统是在网络服务器上运行的&#xff0c;网络服务器的效率直接影响整个网络的效率。 因此&#xff0c;一般要用高档计算机或专用服务器计算机作为网络服务器。 二、网络工作站 网络工作站是通过网络接口卡连接到网络上的…

银行软件测试:基于互联网金融平台的测试框架设计与分析

目前互联网金融火的一塌糊涂&#xff0c;基于互联网金融平台的自动化测试的项目也是如火如荼的进行。笔者手头上负责一个p2p项目的测试框架开发&#xff0c;因此如何设计一套有效的测试框架也成为工作所需和互相交流测试经验的必须。进入》软件测试社群学习交流 这个网站的后台…

科研绘图配色方案

科研绘图配色方案 在撰写论文的时候&#xff0c;美观&#xff0c;大气&#xff0c;上档次的图表能够很好地给自己的论文加分。但是在绘制图表的时候往往会面临色彩搭配的问题&#xff0c;选择合适的色彩搭配能够有效地展示自己的方法&#xff0c;但是色彩搭配选择不当的话往往…