【炼金术士】BatchSize对网络训练的影响

news2024/12/22 20:07:51

文章目录

  • 1 BatchSize对于网络训练的影响
  • 2 调整学习率可以提高大BatchSize的性能
  • 3 实际训练时的建议
    • 3.1 设置初始学习率的方法
    • 3.2 多卡训练时学习率的设置

参考资料:

  • 【深度学习】Batch Size对神经网络训练的影响
  • 【AI不惑境】学习率和batchsize如何影响模型的性能?
  • 论文:ON LARGE-BATCH TRAINING FOR DEEP LEARNING: GENERALIZATION GAP AND SHARP MINIMA
  • 如何找到最优学习率
  • 论文:Cyclical Learning Rates for Training Neural Networks

一句话总结:当BatchSize增大k倍时,学习率也要对应增大k倍,以达到同样的泛化效果。

1 BatchSize对于网络训练的影响

网络、超参数、训练数据一样的情况下:

  • BatchSize越大,网络训练速度越快;
  • 大BatchSize训练出来的模型,泛化性比小BatchSize训练出来的模型泛化性要差。

针对BatchSize越大,网络训练速度快这条很好理解,在相同的单位时间内,一次训练的数据越多,网络训练的速度也就越快。那么为什么在条件相同的情况下,大BatchSize训练出来的模型泛化性比小Batch训练出来的模型泛化性要差呢?

Keskar 等人对小批量和大批量之间的性能差距提出了一种解释:使用小批量的训练倾向于收敛到平坦的极小化,该极小化在极小化的小邻域内仅略有变化,而大批量则收敛到尖锐的极小化,这变化很大。平面minimizers 倾向于更好地泛化,因为它们对训练集和测试集之间的变化更加鲁棒 。
在这里插入图片描述
Source: Keskar et al.

此外,他们发现与大批量训练相比,小批量训练可以找到距离初始权重更远的最小值。他们解释说,小批量训练可能会为训练引入足够的噪声,以退出锐化minimizers 的损失池,而是找到可能更远的平坦minimizers

这个证明非常有意思,已知基于Batch参数更新公式如下:
θ = θ − α ( 1 B Σ i = 1 B ∇ θ J i ( θ ) ) \theta = \theta - \alpha(\frac{1}{B} \Sigma_{i=1}^{B}\nabla_{\theta}J_i(\theta)) θ=θα(B1Σi=1BθJi(θ))
其中B代表BatchSize的大小。假设a,b分别代表batch为1的两次迭代,a+b代表batch为2的一次迭代。

在这里插入图片描述
Source: 【深度学习】Batch Size对神经网络训练的影响

很明显, ∣ a ∣ + ∣ b ∣ ≥ ∣ a + b ∣ |a| + |b| \geq |a+b| a+ba+b的。只有当a和b更新的方向与(a+b)更新的方向一致,且 ∣ a ∣ + ∣ b ∣ = ∣ a + b ∣ |a| + |b| = |a+b| a+b=a+b时,两次小batch的迭代才等价为一次大batch的迭代。然而这种情况发生的概率非常低。所以可以得出结论:小Batch可以引入更加丰富的梯度信息,使网络有更好的泛化性。

2 调整学习率可以提高大BatchSize的性能

从上图可以看出,2轮小Batch的参数更新幅度是要比1次大Batch的更新幅度要大的,为了弥补这种差距,相应的提高大Batch的学习率,就可以使得大Batch与多轮小Batch迭代效果差不多,从而弥补大Batch训练出来的模型泛化性问题。

实验也证明了这一点:批量大小为 32、64、128 和 256。我们将对批量大小 32 使用 0.01 的基本学习率,并相应地缩放其他批量大小。
在这里插入图片描述
Source: 【深度学习】Batch Size对神经网络训练的影响

事实上,我们发现调整学习率确实消除了小批量和大批量之间的大部分性能差距。现在,批量大小 256 的验证损失为 0.352 而不是 0.395——更接近批量大小 32 的损失 0.345。

3 实际训练时的建议

在实际网络训练过程中,BatchSize的设置往往取决于模型的大小和GPU的显存,那么为了提高训练的效率,可以:

  • 最大限度设置BatchSize,使GPU的利用率最高;
  • 在单卡GPU上测试最佳初始学习率;

3.1 设置初始学习率的方法

初始的学习率肯定是有一个最优值的,过大则导致模型不收敛,过小则导致模型收敛特别慢或者无法学习,下图展示了不同大小的学习率下模型收敛情况的可能性。
在这里插入图片描述
Source: tensorflow loss分析

首先在单卡GPU上,先将BatchSize拉满,然后通过以下的方式寻找最佳初始学习率。

Leslie N. Smith 在2015年的一篇论文“Cyclical Learning Rates for Training Neural Networks”中的3.3节描述了一个非常棒的方法来找初始学习率,方法非常简单。首先我们设置一个非常小的初始学习率,比如1e-5,然后在每个batch之后都更新网络,同时增加学习率,统计每个batch计算出的loss。最后我们可以描绘出学习的变化曲线和loss的变化曲线,从中就能够发现最好的学习率。
关键点:

  • 初始学习率要足够小,并随着迭代次数增加而逐渐增大
  • 选择使loss下降最快的学习率作为最佳初始学习率
    在这里插入图片描述
    Source: 如何找到最优学习率
    从上图可以看出最佳学习初始学习率为0.1。

3.2 多卡训练时学习率的设置

上一小节讲述了如何在单卡GPU上寻找最佳初始学习率,如果有多卡GPU同时训练时,就需要遵循线性缩放原则。有多少块GPU同时训练,就意味着BatchSize增大多少倍,那么学习率也要对应增大多少倍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1630588.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Windows使用SSH登录本机Linux虚拟机

SSH(Secure Shell),一种网络协议,可以在安全外壳下实现数据传输通信,所以主要用于计算机间加密登录,可以简单理解为远程控制。除了计算机间直接互联,在git中也可以看到,常见的协议有…

购买 DDoS 高防 IP 防护哪家好?

DDoS 高防 IP 哪里买会比较好?在这场攻与守的游戏里,DDoS 高防 IP 是一种针对 DDoS 攻击的防护措施,通过将网站或应用的 IP 地址映射到高防 IP 上,实现对流量的清洗和过滤,从而有效抵御 DDoS 攻击。在选择 DDoS 高防 IP 服务提供…

Ubuntu下部署gerrit+报错分析(超详细)

Ubuntu下部署gerrit代码平台 之前安装过几次 最后都在Apache代理这里失败了,如下图,总是gerrit.config与Apache2.config配置有问题,后面换了使用ngnix代理,简单多了 安装Mysql、gerrit、jdk、git 这一步也是非必须得&#xff0…

无监督学习的评价指标

轮廓系数(Silhouette Coefficient) 轮廓系数用于判断聚类结果的紧密度和分离度。轮廓系数综合了样本与其所属簇内的相似度以及最近的其他簇间的不相似度。 其计算方法如下: 1、计算簇中的每个样本i 1.计算a(i) &#x…

实时采集麦克风并播放(springboot+webscoekt+webrtc)

项目技术 springbootwebscoektwebrtc 项目介绍 项目通过前端webrtc采集麦克风声音,通过websocket发送后台,然后处理成g711-alaw字节数据发生给广播UDP并播放。 后台处理项目使用线程池(5个线程)接受webrtc数据并处理g711-alaw字节数组放到Map容器中&…

opencv基础篇 ——(九)图像几何变换

图像几何变换是通过对图像的几何结构进行变换来改变图像的形状、大小、方向或者透视关系。常见的图像几何变换包括缩放、旋转、平移、仿射变换和透视变换等。下面对这些几何变换进行简要介绍: 矩阵的转置(transpose ): 对于图像来…

吴恩达2022机器学习专项课程(一) 7.1 逻辑回归的成本函数第三周课后实验:Lab4逻辑回归的损失函数

问题预览/关键词 上节课回顾逻辑回归模型使用线性回归模型的平方误差成本函数单个训练样本的损失损失函数,成本函数,代价函数的区别线性回归损失函数和逻辑回归损失函数的区别逻辑回归模型的成本函数是什么?逻辑回归模型的损失函数实验逻辑回…

STL——List常用接口模拟实现及其使用

认识list list的介绍 list是可以在常数范围内在任意位置进行插入和删除的序列式容器,并且该容器可以前后双向迭代。 list的底层是双向链表结构,双向链表中每个元素存储在互不相关的独立节点中,在节点中通过指针指向其前一个元素和后一个元素…

linux tcpdump的交叉编译以及使用

一、源码下载 官网:点击跳转 二、编译 1、解压 tar -xf libpcap-1.10.4.tar.xz tar -xf tcpdump-4.99.4.tar.xz 2、配置及编译 //libpcap: ./configure --hostarm-linux --targetarm-linux CCarm-linux-gcc --with-pcaplinux --prefix$PWD/build//t…

对象与JSON字符串互转

1、JSON字符串转化成JSON对象 JSONObject jsonobject JSON.parseObject(str); 或者 JSONObject jsonobject JSONObject.parseObject(str); 功能上是一样的,都是将JSON字符串(str)转换成JSON对象 jsonobject 。注意str一定得是以键值对存在…

STM32之HAL开发——电容按键

电容按键原理 电容器 (简称为电容) 就是可以容纳电荷的器件,两个金属块中间隔一层绝缘体就可以构成一个最简单的电容。如图 32_1 (俯视图),有两个金属片,之间有一个绝缘介质,这样就构成了一个电容。这样一个电容在电路板上非常容…

二维数组求最大值(C语言)

一、N-S流程图&#xff1b; 二、运行结果&#xff1b; 三、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h>int main() {//初始化变量值&#xff1b;int i, j, max 0, row 0, colum 0;int arr[3][4] { {1, 2, 3}, {4, 5, 16}, {7, 8, 9} …

线上办理离婚快速离婚,无需双方见面异地可办

现在离婚有两种方式 一种是协议离婚&#xff0c;双方都同意的情况下&#xff0c;可以去民政局协议离婚&#xff0c;有30天冷静期&#xff0c;冷静期过后需要双方再次去民政局办理离婚手续。 另一种是诉讼离婚&#xff0c;一方不同意离婚&#xff0c;可以选择诉讼离婚。可以全…

Vue 3 路由机制详解与实践

一、路由的理解 路由是指导用户界面导航的一种机制。它通过映射 URL 到应用程序的不同视图组件来实现页面间的切换和导航。 二、路由基本切换效果 路由基本切换效果指的是当用户在应用程序中进行页面导航时&#xff0c;通过路由可以实现页面的切换&#xff0c;从而展示不同的…

ICMP详解

3 ICMP ICMP&#xff08;Internet Control Message Protocol&#xff0c;因特网控制报文协议&#xff09;是一个差错报告机制&#xff0c;是TCP/IP协议簇中的一个重要子协议&#xff0c;通常被IP层或更高层协议&#xff08;TCP或UDP&#xff09;使用&#xff0c;属于网络层协议…

CTF(web方向)--md5的“===”和“==”的绕过

一、PHP弱类型说明 1.简介 php是一种弱类型语言&#xff0c;对数据的类型要求并不严格&#xff0c;可以让数据类型互相转换。 在php中有两种比较符号: 一种是 &#xff0c;另外一种是 &#xff0c;都是用来比较两个数值是否相等的操作符&#xff0c;但他们也是有区别的: &a…

Linux 小技巧1

目录 一. 统计文件的总行数二. 获取从第二行开始的内容三. 合并两个文件为一个文件四. 统计指定列唯一值的数量五. 列出文件的绝对路径六. 获取除了空白行和注释之外的部分 一. 统计文件的总行数 ⏹非压缩文件 统计当前文件夹下csv文件的行数 wc -l ./*.csv统计指定文件夹下…

想要应聘前端工程师——学习路线指南

前端工程师学习路线 按照前端岗位需求,以优先学习工作更需要,面试更常考的内容为原则,由浅入深,层层铺垫,与时俱进,可以较容易地总结出前端学习路线图: HTML / CSS / JavaScript 基础学习 《Web 入门》 MDN 权威入门指南,HTML / CSS / JavaScript 快速上手 《CSS 世界…

面试中算法(链表)

链表相关的题 有一个单向链表&#xff0c;链表中有可能出现“环”&#xff0c;如图所示&#xff0c;如何用程序来判断该链表是否为有环链表呢? 对于这道题&#xff0c;有一个很巧妙的方法&#xff0c;这个方法利用了两个指针。 首先创建两个指针pi和p2(在Python里就是两个对象…

【问题分析】TaskDisplayArea被隐藏导致的黑屏以及无焦点窗口问题【Android 14】

1 问题描述 用户操作出的偶现的黑屏以及无焦点窗口问题。 直接原因是&#xff0c;TaskDisplayArea被添加了eLayerHidden标志位&#xff0c;导致所有App的窗口不可见&#xff0c;从而出现黑屏和无焦点窗口问题&#xff0c;相关log为&#xff1a; 这个log是MTK添加的&#xff0…