NeurIPS 22|四分钟内就能训练目标检测器!( AGVM)

news2025/1/16 5:49:23

文章目录

    • 引言
    • 方法介绍
    • 实验过程
    • 结果分析

引言

来自商汤的基模型团队和香港大学等机构的研究人员提出了一种大批量训练算法 AGVM,该研究已被NeurIPS 2022接收。

本文提出了一种大批量训练算法 AGVM (Adaptive Gradient Variance Modulator),不仅可以适配于目标检测任务,同时也可以适配各类分割任务。AGVM 可以把目标检测的训练批量大小扩大到 1536,帮助研究人员四分钟训练 Faster R-CNN,3.5 小时把 COCO 刷到 62.2 mAP,均打破了目标检测训练速度的世界纪录。
在这里插入图片描述

  • 论文地址:https://arxiv.org/pdf/2210.11078.pdf
  • 代码地址:https://github.com/Sense-X/AGVM

在当前的机器学习社区中,有三个普遍的趋势。首先,神经网络模型会越来越大。在 NLP 领域中最大规模的模型已经达到了上万亿级别。在视觉领域,最大规模的模型也达到了三百亿的量级。其次,训练的数据集也变得越来越大。比如,ImageNet 21k 和谷歌的 JFT 数据集都具有相当规模的数据集。另外,由于数据集变得越来越大,训练 SOTA 模型的开销越来越大。

因此,提升训练效率就变得愈发重要。而分布式训练因为其适应于数据并行、模型并行和流水线并行的加速训练方法的同时,也具备较高的 Deep Learning 通信效率而被广泛认为是一个有效的解决方案。

随着大模型时代的到来,目标检测器的训练速度越来越成为学术界和工业界的瓶颈,例如,在 COCO 的标准 setting 上把 mAP 训到 62 以上大概需要三天的时间,算上调试成本,这在业界几乎是不可接受的。那么,我们能不能把这个训练时间压到小时级别呢?事实上,在图片分类和自然语言处理任务上,先前的研究人员借助 32K 的批量大小(batch size),只需 14 分钟就可以完成 ImageNet 的训练,76 分钟完成 Bert 的训练。但是,在目标检测领域,还很欠缺这类研究,导致研究人员无法充分利用当前的算力,数据集和大模型。

大批量训练算法 AGVM 便是这个问题的最佳解决方案之一。为了支持如此大批量的训练,同时保持模型的训练精度,本研究提出了一套全新的训练算法,根据密集预测不同模块的梯度方差(gradient variance),动态调整每一个模块的学习率。作者在大量的密集预测网络和数据集上进行了实验,并且证实了该方法的合理性。

方法介绍

大批量训练是加速大型分布式系统中深度神经网络训练的关键。尤其是在如今的大模型时代,如果不采用大批量训练,一个网络的训练时间几乎是难以接受的。但是,大批量训练很难,因为它会产生泛化差距(generalization gap), 直接训练会导致其准确率降低。此前的大批量工作往往针对于图像分类以及一些自然语言处理的任务,但密集预测任务(包括检测分割等),同样在视觉中处于举足轻重的位置,此前的方法并不能在密集预测任务上有很好的表现,甚至结果比基准线更差,这导致我们难以快速训练一个目标检测器。

为了解决这个问题,研究人员进行了大量的实验。最后发现,相较于传统的分类网络,利用密集预测网络一个很重要的特征:密集预测网络往往是由多个组件组成的,以 Faster R-CNN 为例:它由四个部分组成,骨干网络 (Backbone),特征金字塔网络(FPN),区域生成网络(RPN) 和检测头网络(head),我们可以发现一个很有效的指标:密集预测网络不同组件的梯度方差,在训练批量很小时(例如 32),几乎是相同的,但当训练批量很大时(例如 512),它们呈现出很大的区别,如下图所示:
在这里插入图片描述
那么, 能不能直接把这些拉平呢? 这直接引出了 AGVM 算法。以随机梯度下降算法为例, 上角标 代表第 i 个网络模块(例如 FPN 等), 上角标 1 代表骨干网络, η t \eta_{t} ηt代表学习率, 针定骨干网络, 可以直接将不同网络组件的梯度g的方差 θ t ( i ) \theta^{(i)}_{t} θt(i):
在这里插入图片描述
梯度的方差 θ t ( i ) \theta^{(i)}_{t} θt(i)可以由以下式子估计:
在这里插入图片描述
方差的具体求解细节可以参考原文,本研究同样引入了滑动平均机制,防止网络训练发散。同时,研究证明了 AGVM 在非凸情况下的收敛性,讨论了动量以及衰减的处理方式,具体实现细节可以参考原文。

实验过程

本研究首先在目标检测、实例分割、全景分割和语义分割的各种密集预测网络上进行了测试,通过下表可以看到,当用标准批量大小训练时,AGVM 相较传统方法没有明显优势,但当在超大批量下训练时,AGVM 相较传统方法拥有压倒性的优势,下图第二列从左至右分别表示目标检测,实例分割,全景分割和语义分割的表现,AGVM 超越了有史以来的所有方法:
在这里插入图片描述
下表详细对比了 AGVM 和传统方法,体现出了本研究方法的优势:
在这里插入图片描述
同时,为了说明 AGVM 的优越性,本研究进行了以下三个超大规模的实验。研究人员把 Faster R-CNN 的 batch size 放到了 1536,这样利用 768 张 A100 可以在 4.2 分钟内完成训练。其次,借助 UniNet-G,本研究可以在利用 480 张 A100 的情况下,3.5 个小时让模型在 COCO 上达到 62.2mAP(不包括骨干网络预训练的时间),极大的减小了训练时间:
在这里插入图片描述
甚至,在 RetinaNet 上,本研究把批量大小扩展到 10K。这在目标检测领域是从未见的批量大小,在如此大的批量下,每一个 epoch 只有十几个迭代次数,AGVM 在如此大的批量下,仍然能展现出很强的稳定性,性能如下图所示:
在这里插入图片描述

结果分析

本研究探究了一个很重要的问题:以 RetinaNet 为例,如下图第一列所示,探究为什么会出现梯度方差不匹配这一现象。

本研究认为,这一现象来自于:网络不同模块间的有效批量大小 (effective batch size) 是不同的。例如,RetinaNet 的头网络的输入是由特征金字塔的五层网络输出的,特征金字塔的 top-down 和 bottom-up pathways,以及像素维度的损失函数计算会导致头网络和骨干网络的等效批量大小不同,这一原理导致了梯度方差不匹配的现象。

为了验证这一假设,本研究依次给每一层特征使用单独的头网络,移去特征金字塔网络,随机忽略掉 75% 的用于计算损失函数的像素,最终,本研究发现骨干网络和头网络的梯度方差曲线重合了,本研究也对 Faster R-CNN 做了类似的实验,如下图第二列所示,更多的讨论请参见原文。
在这里插入图片描述
在这里插入图片描述
转载至微信公众号目标检测与机器学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/19060.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Shell脚本

文章目录Shell脚本学习1. Shell概念1.1Shell脚本的好处1.2 Shell脚本的入门1.2.1 Linux环境中默认Shell版本1.2.2 Shell脚本1.2.3 编写简单的hello,world 脚本1.2.4 Shell 脚本的多种执行方法1.2.4.1 第一种 bash 或 sh 加文件的路径1.2.4.2 第二种 文件的路径直接执行1.2.4.2.…

Day07--wxs的概念以及其基本的用法

一.概念 1.啥子是wxs呢? *****************************************************************************************************************************************************************************************************************************…

【附源码】Python计算机毕业设计特大城市地铁站卫生防疫系统

项目运行 环境配置: Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术: django python Vue 等等组成,B/S模式 pychram管理等等。 环境需要 1.运行环境:最好是python3.7.7,…

数字电路和模拟电路-10时序逻辑电路的分析和设计

前言:学习同步时序逻辑电路的分析、设计 一、同步时序逻辑电路的分析 1、时序逻辑电路的分析步骤 步骤一 逻辑图 同步or异步 计数器or状态机 一条总线同步,多条总线是异步 计数器无输入,状态机有输入 状态机还分摩尔型和米里型 步骤二 驱动…

力扣(LeetCode)17. 电话号码的字母组合(C++)

回溯 将 222——999 和字母对应起来,用字符串数组保存。 递归遍历 digitsdigitsdigits 每一个数字,每一个数字对应的字母,又可以递归遍历,和下一个数字的字母组成排列。当排列长度等于 digitsdigitsdigits 的长度,就…

详解MySQL非常重要的日志—bin log

前言 bin log想必大家多多少少都有听过,它是MySQL中一个非常重要的日志,所以各位架构师们,如果有不了解的,一定要好好学习了,因为它涉及到数据库层面的主从复制、高可用等设计。 bin log是什么? bin log…

【博客538】BGP优雅重启机制

bgp优雅重启机制 背景 以BGP为代表的路由协议,从设计之初,就关注路由表的正确性,因为这是确保整个网络系统正常工作的最基本要求。因此每个BGP路由器,总是会以最快的速度收敛到整个网络最新的状态上。当一个BGP peer的BGP连接断开…

一种PEG衍生物Azide-PEG-Biotin|N3-PEG-Biotin|叠氮-PEG-生物素|956748-40-6

1、名称 英文:N3-PEG-Biotin,Azide-PEG-Biotin 中文:叠氮-聚乙二醇-生物素 2、CAS编号:956748-40-6 3、所属分类: Azide PEG Biotin PEG 4、分子量:可定制 5、质量控制:95% 6、储存&…

[附源码]java毕业设计食堂线上点餐系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

玩转MySQL:命令大全~忘记了SQL该怎么写就回来看看~

引言 相信大家在编写SQL时一定有一个困扰,就是明明记得数据库中有个命令/函数,可以实现自己需要的功能,但偏偏不记得哪个命令该怎么写了,这时只能靠盲目的去百度,以此来寻找自己需要的命令。 时间是最厉害的武器&…

Anaconda3安装部署(二) 百篇文章学PyQT

本文章是百篇文章学PyQT的第二篇,本文讲述如何安装Anaconda3工具,Anaconda3 在安装过程中会遇到很多问题,博主在本篇文章中将遇到和踩过的坑总结出来,可以供大家参考,希望大家安装顺利。包括 安装、遇到问题的解决方案…

实战十八:通过ItemKNN算法实现基于协同过滤的商品推荐 代码+数据

项目概述: 推荐系统任务描述:通过用户的历史行为(比如浏览记录、购买记录等等)准确的预测出用户未来的行为;好的推荐系统不仅如此,而且能够拓展用户的视野,帮助他们发现可能感兴趣的却不容易发现的item;同时将埋没在长尾中的好商品推荐给可能感兴趣的用户。ItemKNN推荐…

【专栏】基础篇05| Redis 该怎么保证数据不丢失(下)

前言 上一小节我们讲了AOF是什么以及它是如何保证Redis的Crash Safe的,这一节我们再来看一看Redis的RDB和AOF有何不同,两者是怎么样的关系 RDB的工作模式 RDB全称Redis Database,我们也常叫做Redis的内存快照,它与AOF最大的不同在…

基于java+ssm幼儿园教学网站管理系统vue-计算机毕业设计

项目介绍 要想做好幼升小的衔接工作,首先我们要明确小学生相对于幼儿园来说的不同之处。在幼儿园阶段,我们更多的是让小朋友做游戏,培养他们的学习兴趣等。而进入小学后,课程种类增加了,阅读信息不再是简单的图片&…

PHP房屋租售信息管理系统可以用wamp、phpstudy运行定制开发mysql数据库BS模式

一、源码特点 PHP房屋租售信息管理系统 是一套完善的web设计系统,对理解php编程开发语言有帮助,系统具有完整的源代码和数据库系统主要采用B/S模式开发,开发环境为PHP APACHE,数据库为mysql5.0,使用php语言开发 PHP房屋租售信…

kubernetes组件再认知

背景 之前学习k8s的各组件还是感觉不深入, 只停留在名字解释上面。总是不能深入理解,例如应用部署后kuber-proxy会在master 和node上添加什么样的iptables规则、部署一个应用的完整流程( 手画各组件功能并介绍10分钟以上 )、schedule具体是怎么调度的、limit reque…

计算机视觉|针孔成像,相机内外参及相机标定,矫正的重要性

计算机视觉读书笔记|相机内外参及相机标定,矫正的重要性 这篇博客将介绍针孔成像,透镜(弥补了针孔成像曝光不足成像速度慢的缺点,但引进了畸变,主要是径向畸变和切向畸变,径向畸变主要是离中心越远越弯曲&…

Tableau指标排行

2022年11月15日,深圳数据交易所举行揭牌暨数据交易成果发布活动。 文章目录前言一、整体数据排行二、数据排行TOP N三、根据需要也可以显示具体排名总结前言 分享Tableau指标排行制作过程中遇到的问题及其解决方式,供各位小伙伴参考。 一、整体数据排行…

谷粒学院(二) 讲师管理模块

一、讲师管理模块配置 1、在service下面service-edu模块中创建配置文件 2. resources目录下创建文件 application.properties # 服务端口 server.port8001 # 服务名 spring.application.nameservice-edu# 环境设置:dev、test、prod spring.profiles.activedev# my…

java基于ssm网上超市购物商城-计算机毕业设计

项目介绍 网上超市是商业贸易中的一条非常重要的道路,可以把其从传统的实体模式中解放中来,网上购物可以为消费者提供巨大的便利。通过网上超市这个平台,可以使用户足不出户就可以了解现今的流行趋势和丰富的商品信息,为用户提供…