【论文阅读】Long-Tailed Recognition via Weight Balancing(CVPR2022)附MaxNorm的代码

news2024/11/15 23:26:02

目录

  • 论文
  • 使用方法
    • weight decay
    • MaxNorm
  • 如果使用原来的代码报错的可以看下面这个

论文

问题:真实世界中普遍存在长尾识别问题,朴素训练产生的模型在更高准确率方面偏向于普通类,导致稀有的类别准确率偏低。
key:解决LTR的关键是平衡各方面,包括数据分布、训练损失和学习中的梯度。
文章主要讨论了三种方法: L2normalization, weight decay, and MaxNorm
本文提出了一个两阶段训练的范式
a. 利用调节权重衰减的交叉熵损失学习特征。
b. 通过调节权重衰减和Max Norm使用类平衡损失学习分类器。
一些有用的看法

  1. 研究表明,与联合训练特征学习和分类器学习的模型相比,解耦特征学习和分类器学习导致了显著的改进。
  2. 根据基准测试结果,通过集成专家模型或采用主动数据增强技术的自监督预训练来实现最好精度。
  3. 研究发现,SGD动量导致LTR出现问题,阻碍了进一步改善。
  4. 最近,Kang等人令人信服地证明了阶段性训练对LTR很重要。
  5. 权重衰减有助于学习隐藏层的平衡权重。
  6. 重要的是,我们的探索发现,虽然在分类器上使用L2规范化约束进行训练比简单训练有所改进,但它的表现不如下面描述的其他两个正则化。
  7. 与严格将所有滤波器权重的范数值设置为1的L2归一化不同,MaxNorm放松了这一约束,允许权重在训练期间在范数球内移动。
  8. 权重衰减中,不同数据集的最优λ各不相同——较大的数据集需要较小的权重衰减,直观地说,因为在更多数据上学习有助于泛化,因此需要较少的正则化。
    单阶段使用不平衡损失训练效果不好的原因:虽然他们没有解释为什么具有类平衡损失的单阶段训练表现不佳,但直观地说,这是因为类平衡损失人为地放大了从罕见的类训练数据计算的梯度,这损害了特征表示学习,从而损害了最终的LTR性能。
    本文作者使用了weight decay和max norm两种方法结合,因为发现两个结合效果更好。让模型不同类之间权重相差不会很大的同时,还能让这些权重缓慢增加。
    下面这幅图就是解释了这些方法的特点。
    在这里插入图片描述
    第一个就是普通方法训练的,它常见的类别权重增长快。
    第二个是L2 normalization,它把所有类别的权重都限定在一个常数。
    第三个是权重衰减,它的所有类的权重小,而且权重在增长。
    第四个是MaxNorm,它限制最大的权重。
    第五个是权重衰减和MaxNorm,会导致范数中的权重较小且平衡。

使用方法

weight decay

先定义好权重衰减的值。

weight_decay = 0.1 #weight decay value

然后在优化器中调用。Adam还有其他的都有weight_decay。

optimizer = optim.SGD([{'params': active_layers, 'lr': base_lr}], lr=base_lr, momentum=0.9, weight_decay=weight_decay)

MaxNorm

就是这个论文中的regularizers.py中的代码。只要会使用就好。就是要是不是作者代码中的模型的话,model.encoder.fc还需要根据自己的代码修改。

#使用前先定义好初始化好
pgdFunc = MaxNorm_via_PGD(thresh=thresh)
pgdFunc.setPerLayerThresh(model) # set per-layer thresholds这个是计算模型每一层的权重的阈值,这篇论文中只计算最后线性层的权重,并对最后线性层的权重进行限制

当模型训练一个epoch结束后,对已经更新完毕的模型权重进行限制,如果超过阈值就进行更新,让权重在最大范数的约束下。

 if pgdFunc:# Projected Gradient Descent
     pgdFunc.PGD(model)#对权重进行限制
import torch
import torch.nn as nn
import math
# The classes below wrap core functions to impose weight regurlarization constraints in training or finetuning a network.

class MaxNorm_via_PGD():
    def __init__(self, thresh=1.0, LpNorm=1, tau=1):
        self.thresh = thresh
        self.LpNorm = LpNorm
        self.tau = tau
        self.perLayerThresh = []

    def setPerLayerThresh(self, model):#根据指定的模型设置每层的阈值
        #set pre-layer thresholds
        self.perLayerThresh = []

        for curLayer in [model.encoder.fc.weight, model.encoder.fc.bias]:#遍历模型的最后两层
            curparam = curLayer.data#获取当前层的数据
            if len(curparam.shape) <= 1:#如果层只有一个维度,是一个偏置或者是一个1D的向量,则设置这一层的阈值为无穷大,继续下一层
                self.perLayerThresh.append(float('inf'))
                continue
            curparam_vec = curparam.reshape((curparam.shape[0], -1))#如果不是,把权重张量展开
            neuronNorm_curparam = torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1).detach().unsqueeze(-1)#沿着第一维计算P番薯,结果存储
            curLayerThresh = neuronNorm_curparam.min() + self.thresh*(neuronNorm_curparam.max() - neuronNorm_curparam.min())#计算每一层的阈值及神经元范数的最小值加上最大值和最小值之间的缩放差
            self.perLayerThresh.append(curLayerThresh)#每层阈值存储

    def PGD(self, model):#定义PGD函数,用于在模型的参数上执行投影梯度下降,试试最大范数约束
        if len(self.perLayerThresh) == 0:#如果每层的阈值是空,用setPerLayerThresh方法初始化
            self.setPerLayerThresh(model)
        for i, curLayer in enumerate([model.encoder.fc.weight, model.encoder.fc.bias]):#遍历模型的最后两层
            curparam = curLayer.data#获取当前层的数据张量值
            curparam_vec = curparam.reshape((curparam.shape[0], -1))#变成一维
            neuronNorm_curparam = (torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1)**self.tau).detach().unsqueeze(-1)#在最后加一维
            #计算权重张量中每行神经元番薯的tau次方
            scalingVect = torch.ones_like(curparam)#创建一个形状与当前层数据相同的张量,用1初始化
            curLayerThresh = self.perLayerThresh[i]#获取阈值

            idx = neuronNorm_curparam > curLayerThresh#创建bool保存超过阈值的神经元
            idx = idx.squeeze()#
            tmp = curLayerThresh / (neuronNorm_curparam[idx].squeeze())**(self.tau)#根据每层的阈值和超过阈值的神经元番薯计算缩放因子
            for _ in range(len(scalingVect.shape)-1):#扩展缩放因子以匹配当前层数据的维度
                tmp = tmp.unsqueeze(-1)

            scalingVect[idx] = torch.mul(scalingVect[idx],tmp)
            curparam[idx] = scalingVect[idx] * curparam[idx]
            curparam[idx] = scalingVect[idx] * curparam[idx]#通过缩放值更新当前层的数据,以便对超过阈值的神经元进行缩放。完成权重更新


如果使用原来的代码报错的可以看下面这个

我的网络只有一层是线性层idx = idx.squeeze(),idx是(1,1)形状的,squeeze就没了,所以报错,如果有这个原因的可以改成idx = idx.squeeze(1)。maxnorm只改最后两层/一层权重所以,定义了一个列表存储线性层只取最后两层或者一层。

class MaxNorm_via_PGD():
    # learning a max-norm constrainted network via projected gradient descent (PGD)
    def __init__(self, thresh=1.0, LpNorm=2, tau=1):
        self.thresh = thresh
        self.LpNorm = LpNorm
        self.tau = tau
        self.perLayerThresh = []

    def setPerLayerThresh(self, model):
        # set per-layer thresholds
        self.perLayerThresh = []#存储每一层的阈值
        self.last_two_linear_layers = []#提取线性层
        for name, module in model.named_children():
            if isinstance(module, nn.Linear):
                self.last_two_linear_layers.append(module)

        for linear_layer in self.last_two_linear_layers[-min(2, len(self.last_two_linear_layers)):]:  # here we only apply MaxNorm over the last two layers
            curparam = linear_layer.weight.data
            if len(curparam.shape) <= 1:
                self.perLayerThresh.append(float('inf'))
                continue
            curparam_vec = curparam.reshape((curparam.shape[0], -1))
            neuronNorm_curparam = torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1).detach().unsqueeze(-1)
            curLayerThresh = neuronNorm_curparam.min() + self.thresh * (
                        neuronNorm_curparam.max() - neuronNorm_curparam.min())
            self.perLayerThresh.append(curLayerThresh)

    def PGD(self, model):
        if len(self.perLayerThresh) == 0:
            self.setPerLayerThresh(model)
        for i, curLayer in enumerate([self.last_two_linear_layers[-min(2,
                                                             len(self.last_two_linear_layers))]]):  # here we only apply MaxNorm over the last two layers

            curparam = curLayer.weight.data

            curparam_vec = curparam.reshape((curparam.shape[0], -1))
            neuronNorm_curparam = (
                        torch.linalg.norm(curparam_vec, ord=self.LpNorm, dim=1) ** self.tau).detach().unsqueeze(-1)
            scalingVect = torch.ones_like(curparam)
            curLayerThresh = self.perLayerThresh[i]

            idx = neuronNorm_curparam > curLayerThresh
            idx = idx.squeeze(1)
            tmp = curLayerThresh / (neuronNorm_curparam[idx].squeeze()) ** (self.tau)
            for _ in range(len(scalingVect.shape) - 1):
                tmp = tmp.unsqueeze(-1)

            scalingVect[idx] = torch.mul(scalingVect[idx], tmp)
            curparam[idx] = scalingVect[idx] * curparam[idx]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1420601.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AutoMQ Kafka 云上十倍成本节约的奥秘(一): SPOT 实例

近年来&#xff0c;无论是海外还是国内&#xff0c;虽然受疫情影响&#xff0c;公有云的市场规模增速有所放缓&#xff0c;但是云的市场总规模仍然是持续增长的。公有云作为一个各个国家重点布局的战略方向和其本身万亿级市场的定位[1]&#xff0c;我们学习用好云是非常有必要的…

彻底解决 MAC Android Studio gradle async 时出现 “connect timed out“ 问题

最近在编译一个比较老的项目&#xff0c;git clone 之后使用 async 之后出现一下现象&#xff1a; 首先确定是我网络本身是没有问题的&#xff0c;尝试几次重新 async 之后还是出现问题&#xff0c;网上找了一些方法解决了本问题&#xff0c;以此来记录一下问题是如何解决的。 …

网络地址相关函数一网打尽

这块的函数又多又乱&#xff0c;今天写篇日志&#xff0c;以后慢慢补充 1. 网络地址介绍 1.1 ipv4 1.1.1 点、分十进制的ipv4 你对这个地址熟悉吗&#xff1f; 192.168.10.100&#xff0c;这可以当做一个字符串。被十进制数字、 “ . ”分开。IP地址的知识就不再多讲…

关于MyBatis和JVM的最常见的十道面试题

ORM项目中类属性名和数据库字段名不一致会导致什么问题&#xff1f;它的解决方案有哪些&#xff1f; 在ORM项目中&#xff0c;如果类的属性名称和数据库字段名不一致会场导致插入、修改时设置的这个不一致字段为null&#xff0c;查询的时候即使数据库有数据&#xff0c;但是查…

Jenkins如何从GIT下拉项目并启动Tomcat

一、先添加服务器 二、添加视图 点击控制台输出&#xff0c;滑到最下面&#xff0c;出现这个就说明构建成功了&#xff0c;如果没有出现&#xff0c;说明构建有问题&#xff0c;需要解决好问题才能启动哦~

Python 九九乘法表的7种实现方式

Python 九九乘法表的7种实现方式 九九乘法表是初学者学习编程的必要练手题目之一&#xff0c;因此各种语言都有对应的实现方式&#xff0c;而 Python 也不例外。在 Python 中&#xff0c;我们可以使用多种方式来生成一个简单的九九乘法表。 实现方式一&#xff1a;双重循环 f…

使用 Node.js 和 Cheerio 爬取网站图片

写一个关于图片爬取的小案例 爬取效果 使用插件如下&#xff1a; {"dependencies": {"axios": "^1.6.0","cheerio": "^1.0.0-rc.12","request": "^2.88.2"} }新建一个config.js配置文件 // 爬取图片…

Android T 远程动画显示流程(更新中)

序 本地动画和远程动画区别是什么? 本地动画&#xff1a;自给自足。对自身SurfaceControl矢量动画进行控制。 远程动画&#xff1a;拿来吧你&#xff01;一个app A对另一个app B通过binder跨进程通信&#xff0c;控制app B的SurfaceControl矢量动画。 无论是本地动画还是远程…

F5负载均衡有何技术优势?为你详细解读

当今数字化时代&#xff0c;网络应用的性能对于企业的成功至关重要。负载均衡建立在现有网络结构之上&#xff0c;提供了有效的方法扩展网络设备和服务器的带宽、增加吞吐量、加强网络数据处理能力、提高网络的灵活性和可用性。F5负载均衡技术则成为了许多企业实现高可用性和高…

原生table样式

HTML <div><table style"width: 100%;"><thead><tr><th style"width:25%;">董事会</th><th style"width:25%;">监事会</th><th style"width:25%;">股东</th><th sty…

物理信息神经网络PINN2024最新改良方案汇总(含复现代码)

传统的数值方法在处理复杂问题时可能需要大量的计算资源和时间&#xff0c;而改良后的PINN可以通过更有效的算法减少计算成本&#xff0c;使得求解过程更加高效。 在写论文时&#xff0c;我们也可以通过改进PINN减少数据需求、加速模型收敛、提高预测准确性、增强可解释性&…

linux -- 内存管理 -- SLAB分配器

SLAB分配器&#xff08;slab allocator&#xff09; SLAB分配器用于小内存空间管理&#xff0c;基本思想是&#xff1a;先利用页面分配器分配出单个或多个连续的物理页面&#xff0c;然后再此基础上将整块页面分割为多个相等的小内存单元&#xff0c;来满足小内存空间分配的需…

kerberos+kafka(2.13)认证(单节点ubuntu)

一&#xff1a;搭建kerberos。 1. 运行安装命令 apt-get install krb5-admin-server krb5-kdc krb5-user krb5-config2. 检查服务是否启动。 systemctl status krb5-admin-server systemctl status krb5-kdcsystemctl start krb5-admin-server systemctl startkrb5-kdc3. 修…

网络安全知识和华为防火墙

网络安全 网络空间安全 ---Cyberspace 2003年美国提出的网络空间概念 ---一个由信息基础设施组成的互相依赖的网络。 我国官方文件定义&#xff1a;网络空间为继海、陆、空、天以外的第五大人类互动领域。 通信保密阶段 --- 计算机安全阶段 --- 信息系统安全 --- 网络空间安…

校园教学气象站是什么

TH-XQ3在当今社会&#xff0c;气象科学的重要性日益凸显。它不仅关系到农业、交通、航空等多个领域的安全&#xff0c;更对人类的生活产生深远影响。因此&#xff0c;许多学校纷纷开设气象学相关课程&#xff0c;帮助学生了解气象知识&#xff0c;培养他们的科学素养。而在这其…

【数据结构:顺序表】

文章目录 线性表顺序表1.1 顺序表结构的定义1.2 初始化顺序表1.3 检查顺序表空间1.4 打印1.5 尾插1.6 头插1.7 尾删1.8 头删1.9 查找1.10 指定位置插入1.11 删除指定位置数据1.12 销毁顺序表 数据结构(Data Structure)是计算机存储、组织数据的方式&#xff0c;指相互之间存在一…

termux 玩法(一)

termux基础 termux基础玩法推荐国光写的手册&#xff1a;Termux 高级终端安装使用配置教程 | 国光 (sqlsec.com) termux安装 个人使用F-Droid安装的termux&#xff1a;Termux | F-Droid - Free and Open Source Android App Repository 基础知识 这些基础知识简单了解一下…

HDFS Federation前世今生

一 背景 熟悉大数据的人应该都知道&#xff0c;HDFS 是一个分布式文件系统&#xff0c;它是基于谷歌的GFS实现的开源系统&#xff0c;设计目的就是提供一个高度容错性和高吞吐量的海量数据存储解决方案。在经典的HDFS架构中有2个NameNode和多个DataNode&#xff0c;如下 从上面…

【C/C++ 02】希尔排序

希尔排序虽然是直接插入排序的升级版本&#xff0c;和插入排序有着相同的特性&#xff0c;即原始数组有序度越高则算法的时间复杂度越低&#xff08;预排序机制&#xff09;&#xff0c;但是是不稳定排序算法。 为了降低算法的时间复杂度&#xff0c;所以我们需要在排序之前尽…

3D效果图加树进去太卡,渲染太慢怎么办?

周末的时候&#xff0c;有个朋友私信来问&#xff1a;3dmax模型加树进去打开时特别的卡&#xff0c;是怎么回事。 不知道有没有朋友遇上这么个情况。 3dmax加树建议就用代理&#xff0c;这样相比于直接加而言&#xff0c;会流畅许多。 在3D效果图中&#xff0c;“树代理”是…