IB 公式解析

news2024/11/28 12:35:16

公式

3.2. Influence Function

影响函数允许我们在移除样本时估计模型参数的变化,而无需实际移除数据并重新训练模型。

 3.3 影响平衡加权因子

 3.4 影响平衡损失

 3.5 类内重加权

m代表一个批次(batch)的大小,这意味着公式对一个批次中的所有样本进行计算,然后去平均值。

 代码

criterion_ib = IBLoss(weight=per_cls_weights, alpha=1000).cuda()
def ib_loss(input_values, ib):
    """Computes the focal loss"""
    loss = input_values * ib
    return loss.mean()
class IBLoss(nn.Module):
    def __init__(self, weight=None, alpha=10000.):
        super(IBLoss, self).__init__()
        assert alpha > 0
        self.alpha = alpha
        self.epsilon = 0.001
        self.weight = weight

    def forward(self, input, target, features):
        grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, num_classes)),1) # N * 1
        ib = grads * features.reshape(-1)
        ib = self.alpha / (ib + self.epsilon)
        return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)

1.计算梯度 grads

grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, num_classes)), 1) # N * 1
  • 计算 softmax 概率分布F.softmax(input, dim=1) 将模型的输出转换为概率分布。
  • 计算 one-hot 编码F.one_hot(target, num_classes) 将目标标签转换为 one-hot 编码。
  • 计算绝对差值:通过计算 softmax 输出与 one-hot 编码之间的绝对差值,得到每个样本的梯度,表示样本对模型的损失贡献。

 2. 计算影响平衡因子(IB Factor)

ib = grads * features.reshape(-1)
ib = self.alpha / (ib + self.epsilon)

影响平衡因子(IB Factor)确实与梯度成反比。梯度越大,IB因子越小,分配给该样本的权重越小;梯度越小,IB因子越大,分配给该样本的权重越大。这一机制确保了模型在处理不平衡数据时,能够更有效地减小对多数类样本的过拟合,提升对少数类样本的泛化能力。

 3. 计算最终损失

return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)

将论文中的公式与代码对应起来

论文中的公式:

对应代码

首先,我们来看影响平衡损失 IBLoss 的代码实现:

class IBLoss(nn.Module):
    def __init__(self, weight=None, alpha=10000.):
        super(IBLoss, self).__init__()
        assert alpha > 0
        self.alpha = alpha
        self.epsilon = 0.001
        self.weight = weight

    def forward(self, input, target, features):
        grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, num_classes)), 1) # N * 1
        ib = grads * features.reshape(-1)
        ib = self.alpha / (ib + self.epsilon)
        return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)

对应关系

  1. 批次大小 m

    在代码中,批次大小由 train_loadertest_loader 的批次大小参数决定。

  2. 数据集 𝐷𝑚

    代码中的 train_loadertest_loader 提供了批次数据。

  3. 类别权重因子 𝜆𝑘

    在代码中,通过 per_cls_weights 来实现:

per_cls_weights = 1.0 / np.array(cls_num_list)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

4.基本损失函数 𝐿(𝑦,𝑓(𝑥,𝑤))

代码中使用 torch.nn.CrossEntropyLoss 计算交叉熵损失:

base_loss = F.cross_entropy(input, target, reduction='none', weight=self.weight)

5.模型输出 𝑓(𝑥,𝑤)f(x,w)

在代码中,模型的输出为 input

output, features = model(images)

6.模型输出与真实标签的 L1 范数 ∥𝑓(𝑥,𝑤)−𝑦∥1

在代码中,通过以下方式计算:

grads = torch.sum(torch.abs(F.softmax(input, dim=1) - F.one_hot(target, num_classes)), 1) # N * 1

7.隐藏层特征向量 ℎ 和其 L1 范数 ∥ℎ∥1**

在代码中,通过以下方式计算隐藏层特征向量的 L1 范数:

features = torch.sum(torch.abs(feats), 1).reshape(-1, 1)

8.最终影响平衡因子 IB

在代码中,通过以下方式计算:

ib = grads * features.reshape(-1)
ib = self.alpha / (ib + self.epsilon)

 9.最终影响平衡损失 𝐿IB(𝑤)

通过自定义的 ib_loss 函数计算:

return ib_loss(F.cross_entropy(input, target, reduction='none', weight=self.weight), ib)

 为什么类别权重因子要这样实现

per_cls_weights = 1.0 / np.array(cls_num_list)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

类别权重因子的实现旨在通过加权样本来处理类别不平衡问题。以下是详细解释为什么要这样实现 per_cls_weights 以及每一步的作用:

代码实现

per_cls_weights = 1.0 / np.array(cls_num_list)
per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()

每一步的解释

计算每个类别的逆频率

per_cls_weights = 1.0 / np.array(cls_num_list)
  • cls_num_list 是每个类别的样本数量列表。例如,如果有三个类别,且每个类别的样本数量为 [100, 200, 50],则 cls_num_list = [100, 200, 50]
  • 通过取倒数 1.0 / np.array(cls_num_list),我们得到了每个类别的逆频率。例如,结果将是 [0.01, 0.005, 0.02]
  • 逆频率反映了类别数量的稀少程度,样本数量少的类别(少数类)将得到更高的权重。

归一化权重

per_cls_weights = per_cls_weights / np.sum(per_cls_weights) * len(cls_num_list)
  • 首先,计算权重的总和 np.sum(per_cls_weights)。根据前面的例子,总和为 0.01 + 0.005 + 0.02 = 0.035
  • 然后,将每个类别的权重除以总和,使得所有权重的和为 1。这是标准化步骤,使得权重变为 [0.01/0.035, 0.005/0.035, 0.02/0.035],即 [0.2857, 0.1429, 0.5714]
  • 接下来,将这些标准化权重乘以类别的数量 len(cls_num_list)。在这个例子中,类别数量是 3。因此,最终的权重变为 [0.2857*3, 0.1429*3, 0.5714*3],即 [0.8571, 0.4286, 1.7143]

这一步的作用是确保每个类别的权重和类别数量成正比,同时保持权重的总和为类别数量。

转换为 PyTorch 张量

per_cls_weights = torch.FloatTensor(per_cls_weights).cuda()
  • 将 NumPy 数组转换为 PyTorch 张量,以便在 PyTorch 中使用这些权重。
  • 将权重张量移动到 GPU(如果可用),以加速计算。

归一化步骤的原因

归一化权重的目的是确保类别权重的相对比例合理,并且所有权重的总和与类别数量一致。这有助于避免某些类别被赋予过高或过低的权重,从而确保训练过程中的稳定性和效果。

处理类别不平衡的原因

类别不平衡问题是指在数据集中,不同类别的样本数量差异很大。在这种情况下,传统的损失函数往往会被多数类主导,导致模型在少数类上的性能较差。通过加权样本,特别是对少数类样本赋予更高的权重,可以平衡各类样本对损失的贡献,从而改善模型在少数类上的表现。

总结

  • 逆频率权重:通过取样本数量的倒数,使得样本数量少的类别得到更高的权重。
  • 归一化:将权重标准化,并确保权重的总和与类别数量一致,保持权重比例的合理性。
  • 转换为张量:将权重转换为 PyTorch 张量,以便在训练过程中使用。

这种权重计算方法确保了在处理类别不平衡问题时,少数类样本对损失函数的贡献增加,从而提高模型对少数类的识别能力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1668435.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【qt】最快的开发界面效率——混合编程

混合编程 一.准备工作1.创建项目2.添加项目资源 二.ui界面设计1.menuBar菜单栏2.action ▲3.toolBar工具栏4.中心组件 三.代码界面设计1.toolBar添加组件2.statusBar状态栏添加组件 四.完成界面的功能1.对action配置信号槽2.对action转到信号槽3.代码添加的组件手动关联槽函数 …

鸿蒙内核源码分析(Shell编辑篇) | 两个任务,三个阶段

系列篇从内核视角用一句话概括shell的底层实现为:两个任务,三个阶段。其本质是独立进程,因而划到进程管理模块。每次创建shell进程都会再创建两个任务。 客户端任务(ShellEntry): 负责接受来自终端(控制台)敲入的一个个字符&…

第五步->手撕spring源码之资源加载器解析到注册

本步骤目标 在完成 Spring 的框架雏形后,现在我们可以通过单元测试进行手动操作 Bean 对象的定义、注册和属性填充,以及最终获取对象调用方法。但这里会有一个问题,就是如果实际使用这个 Spring 框架,是不太可能让用户通过手动方式…

PD-L1表达与免疫逃逸和免疫响应

免疫检查点信号转导和癌症免疫治疗(文献)-CSDN博客https://blog.csdn.net/hx2024/article/details/137470621?ops_request_misc%257B%2522request%255Fid%2522%253A%2522171551954416800184136566%2522%252C%2522scm%2522%253A%252220140713.130102334.…

webrtc windows 编译,以及peerconnection_client

webrtc windows环境编译,主要参考webrtc官方文档,自备梯子 depot tools 安装 Install depot_tools 因为我用的是windows,这里下载bundle 的安装包,然后直接解压,最后设置到环境变量PATH。 执行gn等命令不报错&…

A计算机上的程序与B计算机上部署的vmware上的虚拟机的程序通讯 如何配置?

环境: 在A计算机上运行着Debian11.3 Linux操作系统;在B计算机上运行着Windows10操作系统,并且安装了VMware软件,然后在VMware上创建了虚拟机C并安装了CentOS 6操作系统 需求: 现在A计算机上的程序需要同虚拟机C上的软…

【递归、回溯和剪枝】全排列 子集

0.回溯算法介绍 什么是回溯算法 回溯算法是⼀种经典的递归算法,通常⽤于解决组合问题、排列问题和搜索问题等。 回溯算法的基本思想:从⼀个初始状态开始,按照⼀定的规则向前搜索,当搜索到某个状态⽆法前进时,回退到前…

docker容器实现https访问

前言: 【云原生】docker容器实现https访问_docker ssl访问-CSDN博客 一术语介绍 ①key 私钥 明文--自己生成(genrsa ) ②csr 公钥 由私钥生成 ③crt 证书 公钥 签名(自签名或者由CA签名) ④证书&#xf…

Eclipse下载安装教程(包含JDK安装)【保姆级教学】【2024.4已更新】

目录 文章最后附下载链接 第一步:下载Eclipse,并安装 第二步:下载JDK,并安装 第三步:Java运行环境配置 安装Eclipse必须同时安装JDK !!! 文章最后附下载链接 第一步&#xf…

Go编程语言的调试器Delve | Goland远程连接Linux开发调试(go远程开发)

文章目录 Go编程语言的调试器一、什么是Delve二、delve 安装安装报错cgo: C compiler "gcc" not found: exec: "gcc": executable file not found in $PATH解决 三、delve命令行使用delve 常见的调试模式常用调试方法todo调试程序代码与动态库加载程序运行…

Unity编辑器如何多开同一个项目?

在联网游戏的开发过程中,多开客户端进行联调是再常见不过的需求。但是Unity并不支持编辑器多开同一个项目,每次都得项目打个包(耗时2分钟以上),然后编辑器开一个进程,exe 再开一个,真的有够XX的。o(╥﹏╥)o没错&#…

Windows下安装 Emscripten 详细过程

背景 最近研究AV1编码标准的aom编码器,编译的过程中发现需要依赖EMSDK,看解释EMSDK就是Emscripten 的相应SDK,所以此博客记录下EMSDK的安装过程;因为之前完全没接触过Emscripten 。 Emscripten Emscripten 是一个用于将 C 和 …

JAVA基础--IO

IO 什么是IO 任何事物提到分类都必须有一个分类的标准,例如人,按照肤色可分为:黄的,白的,黑的;按照性别可分为:男,女,人妖。IO流的常见分类标准是按照*流动方向*和*操作…

k8s 数据流向 与 核心概念详细介绍

目录 一 k8s 数据流向 1,超级详细版 2,核心主键及含义 3,K8S 创建Pod 流程 4,用户访问流程 二 Kubernetes 核心概念 1,Pod 1.1 Pod 是什么 1.2 pod 与容器的关系 1.3 pod中容器 的通信 2, …

Servlet讲解

Servlet生命周期 我们只需要继承Servlet接口,查看方法即可看出Servlet的生命周期 import java.io.IOException;import javax.servlet.Servlet; import javax.servlet.ServletConfig; import javax.servlet.ServletException; import javax.servlet.ServletRequest…

哈希(构造哈希函数)

哈希 哈希也可以叫散列 画一个哈希表 哈希冲突越多&#xff0c;哈希表效率越低。 闭散列开放定址法: 1.线性探测&#xff0c;依次往后去找下一个空位置。 2.二次探测&#xff0c;按2次方往后找空位置。 #pragma once #include<vector> #include<iostream> #i…

Oracle11g账户频繁被锁定的3种解决办法

方法1&#xff1a;创建触发器 方法1&#xff1a;数据库中创建触发器(只记录失败)&#xff0c;但是需要开发同意或者开发自己创建。找到密码输入错误的服务器&#xff0c;进行数据源配置的更改。 该方法适用于要求找到密码错误用户所在服务器的场景下。 CREATE OR REPLACE TR…

Llama 3 + Groq 是 AI 天堂

我们将为生成式人工智能新闻搜索创建一个后端。我们将使用通过 Groq 的 LPU 提供服务的 Meta 的 Llama-3 8B 模型。 关于Groq 如果您还没有听说过 Groq&#xff0c;那么让我为您介绍一下。 Groq 正在为大型语言模型 (LLM) 中文本生成的推理速度设定新标准。 Groq 提供 LPU&…

kettle经验篇:MongoDB-delete插件问题

目录 项目场景 问题分析 解决方案 MongoDB Delete插件使用总结 项目场景 项目使用的ODS层数据库是MongoDB&#xff1b;在数据中心从DB层向ODS层同步数据过程中&#xff0c;发现有张ODS表在同步过程中&#xff0c;数据突然发生锐减&#xff0c;甚至于该ODS表数据清0。 同步…

dragonbones 5.6.3不能导出的解决办法

问题描述 使用dragonbones 5.6.3导出资源时无反应。 解决方法 第一步安装node.js&#xff0c;我这里使用的是V18.16.0第二步进入到DragonBonesPro\egretlauncher\server\win目录&#xff0c;然后把里面的node.exe替换为刚刚下载的node文件夹即可&#xff0c;如下图&#xff…