GiantPandaCV | 提升分类模型acc(一):BatchSizeLARS

news2025/1/12 12:17:53

本文来源公众号“GiantPandaCV”,仅用于学术分享,侵权删,干货满满。

原文链接:提升分类模型acc(一):BatchSize&LARS

在使用大的bs训练情况下,会对精度有一定程度的损失,本文探讨了训练的bs大小对精度的影响,同时探究Layer-wise Adaptive Rate Scaling(LARS)是否可以有效的提升精度。

论文链接:https://arxiv.org/abs/1708.03888

论文代码: https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
知乎专栏: https://zhuanlan.zhihu.com/p/406882110

1 引言

如何提升业务分类模型的性能,一直是个难题,毕竟没有99.999%的性能都会带来一定程度的风险,所以很多时候我们只能通过控制阈值来调整准召以达到想要的效果。本系列主要探究哪些模型trick和数据的方法可以大幅度让你的分类性能更上一层楼,不过要注意一点的是,tirck不一定是适用于不同的数据场景的,但是数据处理方法是普适的。本篇文章主要是对于大的bs下训练分类模型的情况,如果bs比较小的可以忽略,直接看最后的结论就好了(这个系列以后的文章讲述的方法是通用的,无论bs大小都可以用)。

2 实验配置

  • 模型:ResNet50

  • 数据:ImageNet1k

  • 环境:8xV100

3 BatchSize对精度的影响

我这里设计了4组对照实验,256, 1024, 2048和4096的batchsize,开了FP16也只能跑到了4096了。采用的是分布式训练,所以单张卡的bs就是bs = total_bs / ngpus_per_node。这里我没有使用跨卡bn,对于bs 64单卡来说理论上已经很大了,bn的作用是约束数据分布,64的bs已经可以表达一个分布的subset了,再大的bs还是同分布的,意义不大,跨卡bn的速度也更慢,所以大的bs基本可以忽略这个问题。但是对于检测的任务,跨卡bn还是有价值的,毕竟输入的分辨率大,单卡的bs比较小,一般4,8,16,这时候统计更大的bn会对模型收敛更好。

很明显可以看出来,当bs增加到4k的时候,acc下降了将近0.8%个点,1k的时候,下降了0.2%个点,所以,通常我们用大的bs训练的时候,是没办法达到最优的精度的。个人建议,使用1k的bs和0.4的学习率最优。

4 LARS(Layer-wise Adaptive Rate Scaling)

4.1. 理论分析

由于bs的增加,在同样的epoch的情况下,会使网络的weights更新迭代的次数变少,所以需要对LR随着bs的增加而线性增加,但是这样会导致上面我们看到的问题,过大的lr会导致最终的收敛不稳定,精度有所下降。

LARS代码如下:

class LARC(object):
    def __init__(self, optimizer, trust_coefficient=0.02, clip=True, eps=1e-8):
        self.optim = optimizer
        self.trust_coefficient = trust_coefficient
        self.eps = eps
        self.clip = clip

    def step(self):
        with torch.no_grad():
            weight_decays = []
            for group in self.optim.param_groups:
                # absorb weight decay control from optimizer
                weight_decay = group['weight_decay'] if 'weight_decay' in group else 0
                weight_decays.append(weight_decay)
                group['weight_decay'] = 0
                for p in group['params']:
                    if p.grad is None:
                        continue
                    param_norm = torch.norm(p.data)
                    grad_norm = torch.norm(p.grad.data)

                    if param_norm != 0 and grad_norm != 0:
                        # calculate adaptive lr + weight decay
                        adaptive_lr = self.trust_coefficient * (param_norm) / (
                                    grad_norm + param_norm * weight_decay + self.eps)

                        # clip learning rate for LARC
                        if self.clip:
                            # calculation of adaptive_lr so that when multiplied by lr it equals `min(adaptive_lr, lr)`
                            adaptive_lr = min(adaptive_lr / group['lr'], 1)

                        p.grad.data += weight_decay * p.data
                        p.grad.data *= adaptive_lr

        self.optim.step()
        # return weight decay control to optimizer
        for i, group in enumerate(self.optim.param_groups):
            group['weight_decay'] = weight_decays[i]

这里有一个超参数,trust_coefficient,也就是公式里面所提到的, 这个参数对精度的影响比较大,实验部分我们会给出结论。

4.2. 实验结论

可以很明显发现,使用了LARS,设置turst_confidence为1e-3的情况下,有着明显的掉点,设置为2e-2的时候,在1k和4k的情况下,有着明显的提升,但是2k的情况下有所下降。

LARS一定程度上可以提升精度,但是强依赖超参,还是需要细致的调参训练。

5 结论

  • 8卡进行分布式训练,使用1k的bs可以很好的平衡acc&speed。

  • LARS一定程度上可以提升精度,但是需要调参,做业务可以不用考虑,刷点的话要好好训练。

6 结束语

本文是提升分类模型acc系列的第一篇,后续会讲解一些通用的trick和数据处理的方法,敬请关注。

THE END !

文章结束,感谢阅读。您的点赞,收藏,评论是我继续更新的动力。大家有推荐的公众号可以评论区留言,共同学习,一起进步。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1803216.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SuntoryProgrammingContest2024(AtCoder Beginner Contest 357)(A~F)(最爱线段树的一集)

A - Sanitize Hands 题意: 模拟 // Problem: A - Sanitize Hands // Contest: AtCoder - SuntoryProgrammingContest2024(AtCoder Beginner Contest 357) // URL: https://atcoder.jp/contests/abc357/tasks/abc357_a // Memory Limit: 1024…

【python报错】TypeError: ‘dict_values‘ Object IsNot Subscriptable

【Python报错】TypeError: ‘dict_values’ object is not subscriptable 在Python中,字典(dict)提供了几种不同的视图对象,包括dict_keys、dict_values和dict_items。这些视图对象允许你以只读方式遍历字典的键、值或键值对。如果…

30-unittest生成测试报告(HTMLTestRunner插件)

批量执行完测试用例后,为了更好的展示测试报告,最好是生成HTML格式的。本文使用第三方HTMLTestRunner插件生成测试报告。 一、导入HTMLTestRunner模块 这个模块下载不能通过pip安装,只能下载后手动导入,下载地址是:ht…

Elasticsearch之写入原理以及调优

1、ES 的写入过程 1.1 ES支持四种对文档的数据写操作 create:如果在PUT数据的时候当前数据已经存在,则数据会被覆盖,如果在PUT的时候加上操作类型create,此时如果数据已存在则会返回失败,因为已经强制指定了操作类型…

vue ts 导入 @/assets/ 红色显示的问题解决

vue ts 导入 /assets/ 红色显示的问题解决 一、问题描述 在使用的时候这样导入会出现如上的错误。 在使用的时候,导入的类型也没有对应的代码提示,说明导入有问题。 二、解决 在 tsconfig.json 中添加如下内容: {"compilerOptions&…

【机器学习】因TensorFlow所适配的numpy版本不适配,用anaconda降低numpy的版本

目录 0 TensorFlow最高支持的numpy版本 1 激活你的环境(如果你正在使用特定的环境) 2 查找可用的NumPy版本 3 安装特定版本的NumPy 4. 验证安装 5.(可选)如果你更改了base环境 0 TensorFlow最高支持的numpy版本 要使用 …

pytorch 笔记:pytorch 优化内容(更新中)

1 Tensor创建类 1.1 直接创建Tensor,而不是从Python或Numpy中转换 不要使用原生Python或NumPy创建数据,然后将其转换为torch.Tensor直接用torch.Tensor创建或者直接:torch.empty(), torch.zeros(), torch.full(), torch.ones(), torch.…

关于python中的关键字参数

在python语言中存在两种传参方式: 第一种是按照先后顺序来传参,这种传参风格,称为“位置参数”这是各个编程语言中最普遍的方式。 关键字传参~按照形参的名字来进行传参! 如上图所示,在函数中使用关键字传参的最大作…

MySQL-备份(三)

备份作用:保证数据的安全和完整。 一 备份类别 类别物理备份 xtrabackup逻辑备份mysqldump对象数据库物理文件数据库对象(如用户、表、存储过程等)可移植性差,不能恢复到不同版本mysql对象级备份,可移植性强占用空间占…

为什么说组合优于继承?

在编程中,继承和组合是用于在面向对象语言中设计和构建类和对象的两种基本技术。 继承,它允许一个类(称为派生类或子类)从另一个类(称为基类或超类)继承属性和行为。换句话说,子类“是”超类的…

独立游戏之路 -- 获取OAID提升广告收益

Unity 之 获取手机:OAID、IMEI、ClientId、GUID 前言一、Oaid 介绍1.1 Oaid 说明1.2 移动安全联盟(MSA) 二、站在巨人的肩膀上2.1 本文实现参考2.2 本文实现效果2.3 本文相关插件 三、Unity 中获取Oaid3.1 查看实现源码3.2 工程配置3.3 代码实现3.4 场景搭建 四、总…

编译和运行qemu-uboot-arm64单板的Armbian系统

这篇文章ARM虚拟机安装OMV-CSDN博客遗留一个启动qemu-uboot-arm64单板Armbian镜像的问题,使用官方下载的镜像,会报错: fatal: no kernel available .... Failed to load /vmlinuz ...... qemu-system-aarch64 -smp 8 -m 8G -machine virt …

搭建多平台比价系统需要了解的电商API接口?

搭建一个多平台比价系统涉及多个步骤,以下是一个大致的指南: 1. 确定需求和目标 平台选择:确定你想要比较价格的平台,例如电商网站、在线旅行社等。数据类型:明确你需要收集哪些数据,如产品价格、产品名称…

苹果手机微信如何直接打印文件

在快节奏的工作和生活中,打印文件的需求无处不在。但你是否曾经遇到过这样的困扰:打印店价格高昂,让你望而却步?今天,我要给大家介绍一款神奇的微信小程序——琢贝云打印,让你的苹果手机微信直接变身移动打…

二叉树的实现(初阶数据结构)

1.二叉树的概念及结构 1.1 概念 一棵二叉树是结点的一个有限集合,该集合: 1.或者为空 2.由一个根结点加上两棵别称为左子树和右子树的二叉树组成 从上图可以看出: 1.二叉树不存在度大于2的结点 2.二叉树的子树有左右之分,次序不能…

后端启动项目端口冲突问题解决

后端启动项目端口冲突 原因: Vindows Hyper-V虚拟化平台占用了端口。 解决方案一: 查看被占用的端口范围,然后选择一个没被占用的端口启动项目。netsh interface ipv4 show excludedportrange protocoltcp 解决方案二: 禁用H…

调试环境搭建(Redis 6.X 版本)

今儿,我们来搭建一个 Redis 调试环境,目标是: 启动 Redis Server ,成功断点调试 Server 的启动过程。使用 redis-cli 启动一个 Client 连接上 Server,并使用 get key 指令,发起一次 key 的读取。 视频可见…

鸿蒙状态管理-@Builder自定义构建函数

Builder 将重复使用的UI元素抽象成一个方法 在build方法里调用 使其成为 自定义构建函数 Entry Component struct BuilderCase {build() {Column(){Row(){Text("西游记").fontSize(20)}.justifyContent(FlexAlign.Center).backgroundColor("#f3f4f5").hei…

技术革命的十年:计算机、互联网、大数据、云计算与AI

近10年来,计算机、互联网、大数据、云计算和人工智能等技术领域发展迅速,带来了巨大的变革和创新。以下是各个领域的发展历史、现状、问题瓶颈、未来趋势以及可能的奇点。 计算机技术: 发展历史: 过去:过去十年间&am…

SpringBoot+Vue学生作业管理系统【附:资料➕文档】

前言:我是源码分享交流Coding,专注JavaVue领域,专业提供程序设计开发、源码分享、 技术指导讲解、各类项目免费分享,定制和毕业设计服务! 免费获取方式--->>文章末尾处! 项目介绍047: 【…