loss盘点: asl loss (Asymmetric Loss) 代码解析详细版

news2024/9/20 16:36:52

1. BCE公式部分

可以简单浏览下这篇博客的文章:
https://blog.csdn.net/qq_14845119/article/details/114121003

这是多分类 经典 B C E L o s s BCELoss BCELoss 公式
L = − y L + − ( 1 − y ) L − L = -y L_{+} - (1-y) L_{-} L=yL+(1y)L

其中, L + / − L_{+/-} L+/ 是正负例预测概率的log值,即:

L + = l o g ( y ^ ) L − = l o g ( 1 − y ^ ) y ^ = s i g m o i d ( l o g i t ) \begin{aligned} L_{+} &= log( \hat{y} )\\ L_{-} &= log( 1- \hat{y} )\\ \hat{y} &= sigmoid( logit ) \end{aligned} L+Ly^=log(y^)=log(1y^)=sigmoid(logit)

实际上由于 l a b e l label label 标签 y y y 值,是一个 0 / 1 0/1 0/1 矩阵,实际上充当了一个掩码 m a s k mask mask 的作用,挑选出 L + L_{+} L+ 中正例部分 和 L − L_{-} L 中负例部分

假设:

y = [ 0 0 1 0 ] y = \begin{bmatrix} 0 & 0 \\ 1 & 0 \end{bmatrix} y=[0100]

y ^ = [ 0.5 0.1 0.3 0.2 ]   L + = [ − 0.6931 − 2.3026 − 1.2040 − 1.6094 ]   L − = [ − 0.6931 − 0.1054 − 0.3567 − 0.2231 ] \hat{y} = \begin{bmatrix} 0.5 & 0.1 \\ 0.3 & 0.2 \end{bmatrix} \ L_{+} = \begin{bmatrix} -0.6931 & -2.3026 \\ -1.2040 & -1.6094 \end{bmatrix} \ L_{-} = \begin{bmatrix} -0.6931 & -0.1054 \\ -0.3567 & -0.2231 \end{bmatrix} y^=[0.50.30.10.2] L+=[0.69311.20402.30261.6094] L=[0.69310.35670.10540.2231]

所以, L L L 左下角为 L + L_{+} L+对应的值的相反数,左上角和右上角和右下角为 L − L_{-} L对应的值的相反数

L = [ 0.6931 0.1054 1.2040 0.2231 ] L = \begin{bmatrix} 0.6931 & 0.1054 \\ 1.2040 & 0.2231 \end{bmatrix} L=[0.69311.20400.10540.2231]

代码验证:

x = torch.tensor([0.5, 0.1, 0.3, 0.2]).reshape(2, 2).float()
y = torch.tensor([0, 0, 1, 0]).reshape(2, 2).float()
torch.nn.functional.binary_cross_entropy(x, y, reduction='none')
tensor([[0.6931, 0.1054],
        [1.2040, 0.2231]])

(不要小看这个 mask 代码的操作,一会儿写 asl 代码会用的上)

2. focal loss 公式部分

基本公式依旧是这个:
L = − y L + − ( 1 − y ) L − L = -y L_{+} - (1-y) L_{-} L=yL+(1y)L

L + L_{+} L+ L − L_{-} L 如下:
L + = ( 1 − p ) γ ∗ l o g ( p ) L − = p γ ∗ l o g ( 1 − p ) p = s i g m o i d ( l o g i t ) \begin{aligned} L_{+} &= (1-p)^{\gamma} * log(p) \\ L_{-} &= p^{\gamma} * log(1-p) \\ p &= sigmoid(logit) \end{aligned} L+Lp=(1p)γlog(p)=pγlog(1p)=sigmoid(logit)

3. asl 公式部分

asl loss 是 focal loss的改进版

L + = ( 1 − p ) γ + ∗ l o g ( p ) L − = p m γ − ∗ l o g ( 1 − p m ) p = s i g m o i d ( l o g i t ) p m = m a x ( p − m , 0 ) \begin{aligned} L_{+} &= (1-p)^{\gamma_{+}} &*& log(p) \\ L_{-} &= p_m^{\gamma_{-}} &*& log(1-p_m) \\ p &= sigmoid(logit) \\ p_m &= max(p-m, 0) \end{aligned} L+Lppm=(1p)γ+=pmγ=sigmoid(logit)=max(pm,0)log(p)log(1pm)

由于 p m p_m pm 仅在 L − L_{-} L 中存在,而 p p p一般出现在 L + L_{+} L+中, ( 1 − p ) (1-p) (1p)一般出现在 L − L_{-} L中,所以将 p m p_m pm 做一些反向操作

先引入一个引理,显然成立,x和y都是函数(或者变量),二者中大的加上负号,就是二者相反数中小的

− m a x ( x , y ) = = m i n ( − x , − y ) -max(x, y) == min(-x, -y) max(x,y)==min(x,y)

所以:
p m = m a x ( p − m , 0 ) = − m i n ( m − p , 0 ) − p m = m i n ( m − p , 0 ) 1 − p m = m i n ( m − p , 0 ) + 1 1 − p m = m i n ( m − p + 1 , 1 ) 1 − p m = m i n ( m + 1 − p , 1 ) 1 − p m = n p . c l i p ( m + 1 − p , m a x = 1 ) \begin{aligned} p_m &= max(p-m, 0) \\ &= -min(m-p, 0) \\ -p_m &= min(m-p, 0) \\ 1-p_m &= min(m-p, 0) + 1 \\ 1-p_m &= min(m-p+ 1, 1) \\ 1-p_m &= min(m+ 1-p, 1) \\ 1-p_m &= np.clip(m+ 1-p, max=1) \\ \end{aligned} pmpm1pm1pm1pm1pm=max(pm,0)=min(mp,0)=min(mp,0)=min(mp,0)+1=min(mp+1,1)=min(m+1p,1)=np.clip(m+1p,max=1)

这一行咱等会要用到

4. asl 代码

看看 asl loss 的代码,torch代码来自:
https://github.com/Alibaba-MIIL/ASL/blob/main/src/loss_functions/losses.py

  • self.gamma_neg γ − \gamma_{-} γ
  • self.gamma_pos γ + \gamma_{+} γ+
  • self.eps 是用作 log 函数内部,防止溢出
class AsymmetricLossOptimized(nn.Module):
    ''' Notice - optimized version, minimizes memory allocation and gpu uploading,
    favors inplace operations'''

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
        super(AsymmetricLossOptimized, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
        self.eps = eps

        # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations
        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y):
        """"
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # 分别计算正负例的概率
        self.xs_pos = torch.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # 非对称裁剪
        if self.clip is not None and self.clip > 0:
            self.xs_neg.add_(self.clip).clamp_(max=1)  # 给 self.xs_neg 加上 clip 值

        # 先进行基本交叉熵计算
        self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
        self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(False)
                
            # 以下 4 行相当于做了个并行操作
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                                          self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)
            
            if self.disable_torch_grad_focal_loss:
                torch.set_grad_enabled(True)
            self.loss *= self.asymmetric_w

        return -self.loss.sum()

来咱单独看一下代码:

# 非对称裁剪
if self.clip is not None and self.clip > 0:
	self.xs_neg.add_(self.clip).clamp_(max=1)  # 给 self.xs_neg 加上 clip 值

这两行用于计算:
1 − p m = n p . c l i p ( m + 1 − p , m a x = 1 ) \begin{aligned} 1-p_m &= np.clip(m+ 1-p, max=1) \end{aligned} 1pm=np.clip(m+1p,max=1)

# 先进行基本交叉熵计算
self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps))
self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps)))

这两行用于计算红框部分:
在这里插入图片描述
注意 self.targetsself.anti_targets 都相当于掩码 mask 的作用,此处的 self.loss 矩阵的shape是和 self.targets 一样的 shape,不理解可以回忆一下 BCE公式部分 的计算

而前面的 幂 相当于权重,就是代码中的 self.asymmetric_w,也就是此处的:
在这里插入图片描述

self.asymmetric_w 是这样计算的,这部分很妙!

self.xs_pos = self.xs_pos * self.targets
self.xs_neg = self.xs_neg * self.anti_targets
self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg,
                              self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets)

插一句 torch.pow 该函数会将两个shape相同的张量的对应位置做幂运算,看这个例子

>>> x = torch.tensor([1, 2, 3, 4])
>>> y = torch.tensor([2, 2, 3, 1])
>>> torch.pow(x, y)
tensor([ 1,  4, 27,  4])

计算 self.asymmetric_w 时,只需将pow的 x x x 参数对应位置写成 ( 1 − p ) (1-p) (1p) 或者 p m p_m pm,将pow的 y y y 参数对应位置写成 γ − \gamma_{-} γ 或者 γ + \gamma_{+} γ+ 即可,先看简单的, y y y 参数这里计算:

self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets

也是通过 self.targets 的 mask 操作来进行的,而 x x x 参数这样计算:

1 - self.xs_pos - self.xs_neg

当计算 L + L_{+} L+ 时,self.xs_neg==0 x x x 参数对应位置就是 1 - self.xs_pos (1-p)
当计算 L − L_{-} L 时,self.xs_pos==0 x x x 参数对应位置就是 1 - self.xs_neg ( 1 − ( 1 − p m ) ) = p m (1-(1-p_m))=p_m (1(1pm))=pm

通过一个 torch.pow 巧妙的计算了 self.asymmetric_w NICE!

之后二者对应位置相乘即可

self.loss *= self.asymmetric_w

5. asl 代码 Paddle 实现

class AsymmetricLossOptimizedWithLogit(nn.Layer):
    ''' Notice - optimized version, minimizes memory allocation and gpu uploading,
    favors inplace operations'''

    def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-5, disable_paddle_grad_focal_loss=False):
        super(AsymmetricLossOptimizedWithLogit, self).__init__()

        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.disable_paddle_grad_focal_loss = disable_paddle_grad_focal_loss
        self.eps = eps

        self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None

    def forward(self, x, y, weights=None):
        """"
        Parameters
        ----------
        x: input logits
        y: targets (multi-label binarized vector)
        """

        self.targets = y
        self.anti_targets = 1 - y

        # Calculating Probabilities
        self.xs_pos = F.sigmoid(x)
        self.xs_neg = 1.0 - self.xs_pos

        # Asymmetric Clipping
        if self.clip is not None and self.clip > 0:
            # self.xs_neg.add_(self.clip).clip_(max=1)
            self.xs_neg = (self.xs_neg + self.clip).clip_(max=1)

        # Basic CE calculation
        self.loss = self.targets * paddle.log(self.xs_pos.clip(min=self.eps))
        self.loss.add_(self.anti_targets * paddle.log(self.xs_neg.clip(min=self.eps)))

        # Asymmetric Focusing
        if self.gamma_neg > 0 or self.gamma_pos > 0:
            if self.disable_paddle_grad_focal_loss:
                paddle.set_grad_enabled(False)
            self.xs_pos = self.xs_pos * self.targets
            self.xs_neg = self.xs_neg * self.anti_targets
            self.asymmetric_w = paddle.pow(1 - self.xs_pos - self.xs_neg,
                                            (self.gamma_pos * self.targets + \
                                            self.gamma_neg * self.anti_targets).astype("float32"))
            if self.disable_paddle_grad_focal_loss:
                paddle.set_grad_enabled(True)
            self.loss *= self.asymmetric_w
       
        if weights is not None:
            self.loss *= weights
            
        _loss = -self.loss.sum()

        return _loss
    
    
if __name__ == "__main__":
    
    np.random.seed(11070109)
    x = np.random.randn(3, 3)
    x = paddle.to_tensor(x).cast("float32")
    y = (x > 0.5).cast("float32")
    
    loss = AsymmetricLossOptimizedWithLogit()
    out = loss(x, y)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/142078.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker保姆级学习教程

文章目录1、什么是Docker1.1、容器技术1.2、容器与虚拟机比较1.3、Docker特点1、更高效的利用系统资源2、更快速的启动时间3、一致的运行环境4、持续支付和部署5、更轻松的迁移6、更轻松的维护和扩展2、Docker组件学习2.1、Docker客户端和服务器2.2、Docker镜像2.3、Registry&a…

奇怪的知识——Windows下怎么修改进程的名称?

📢欢迎点赞 :👍 收藏 ⭐留言 📝 如有错误敬请指正,赐人玫瑰,手留余香!📢本文作者:由webmote 原创📢作者格言:无尽的折腾后,终于又回到…

element-plus的form表单form-item的prop怎么写才能正确校验,实现逻辑是怎么样的?

不管是element-plus还是上一个版本的element-ui,都是一个使用很广泛的基于csshtmljs的ui组件库,它的form表单自带强大的校验功能,form-item的prop怎么写才正确,实现逻辑是怎么样的?element-plus的form表单的model、for…

聚观早报 | 苹果市值跌破2万亿美元大关;卢伟冰晋升小米集团总裁

今日要闻:苹果市值跌破2万亿美元大关;卢伟冰晋升小米集团总裁;京东方拿下iPhone 15订单;英伟达与富士康达成合作;哪吒汽车旗下车型价格调整苹果市值跌破2万亿美元大关 1 月 4 日消息,据国外媒体报道&#x…

C51单片机连接wifi模块,发送AT指令

一、AT指令AT 指令集是从终端设备( Terminal Equipment , TE) 或 数据终端设备 ( Data TerminalEquipment , DTE) 向终端适配器 (Terminal Adapter , TA) 或 数据电路终端设备 (Data CircuitTerminal Equipment &#…

CDGA|企业数字化转型进展得越快就越好吗?

数据治理并不是一件简单的事情。即使是行业知名公司,在高度重视和确保投入的情况下,完成全公司“数据底座”/“数据中台”的所耗时间也往往以年计。并且,还需要注意到,数据规范只是数字化转型的一个维度而已: 在国标《…

SQL INSERT INTO 语句

INSERT INTO 语句用于向表中插入新记录。 SQL INSERT INTO 语句 INSERT INTO 语句用于向表中插入新记录。 SQL INSERT INTO 语法 INSERT INTO 语句可以有两种编写形式。 第一种形式无需指定要插入数据的列名,只需提供被插入的值即可: INSERT INTO t…

Python爬虫常用哪些库?

经常游弋在互联网爬虫行业的程序员来说,如何快速的实现程序自动化,高效化都是自身技术的一种沉淀的结果,那么使用Python爬虫都会需要那些数据库支持?下文就是有关于我经常使用的库的一些见解。 请求库: 1、urllib&a…

matlab复杂函数多元函数拟合

简介 本文介绍了基于matlab实现的复杂函数以及多元函数的拟合。在工程和研究中偶尔会遇到要用一个非常复杂的数学公式来拟合实验测量数据,对这些复杂的数学公式拟合时,采用常见的拟合方法往往会失败,或者得不到足够精确的结果。本文以笔者多…

AVL树:高度平衡的二叉搜索树

AVL树 AVL树和BST树的联系   答:BST树(二叉排序树)当节点的数据key有序时是一棵单支树,查找时效率直接降低到O(N)而不是树高,为了使树尽量两边均匀,设计出了AVL树,AVL树的左右高度差不超过1。…

sql语句练习题1

1、选择部门30中的所有员工; 要注意到查的是所有员工 代码如下: mysql> select * from emp where deptno 30;2、列出所有办事员(CLERK)的姓名,编号和部门编号; 注意的是要查的是姓名,编号和部门编号 范围限定的是…

并发编程的原子性 != 事务ACID的原子性

△Hollis, 一个对Coding有着独特追求的人△这是Hollis的第 412 篇原创分享作者 l Hollis来源 l Hollis(ID:hollischuang)关于原子性,很多人在多个地方都听说过,大家也都背的很熟悉。在事务的ACID中,有原子性…

儒家思想和道家思想的三个主要差异

孔子、孟子、老子、庄子,这四位古代思想家被称为“中国四哲”,他们分别代表了儒家和道家思想。这两大思想流派,是数千年来中国人智慧的结晶和文化的瑰宝。01先秦儒家思想的发展,经过了三个阶段,第一阶段是孔子&#xf…

CHAPTER 7 Ansible playbook(四)

ansible-playbook7.1 roles(角色)7.1.1 Ansible Roles 介绍7.1.2 Roles结构7.1.3 存储和查找角色7.1.4 制作一个Role7.1.5 使用角色7.1.5.1 经典方法7.1.5.2 import_role7.1.6 如何使用Galaxy7.1 roles(角色) 7.1.1 Ansible Role…

windows docker安装prometheus和grafana

文章目录docker安装prometheusdocker安装grafanawindows安装windows_exporterprometheus配置新增windows_exporter的job,配置grafana导入windows模板即可出现酷炫大屏出现酷炫画面完成docker安装prometheus 拉取镜像,在D盘下创建prometheus.yml配置文件,映射到docker里面d:/se…

【pandas】教程:8-如何组合多个表格的数据

Pandas 组合多个表格的数据 本节使用的数据为 data/air_quality_no2_long.csv,链接为 pandas案例和教程所使用的数据-机器学习文档类资源-CSDN文库 导入数据 NO2NO_2NO2​ import pandas as pd air_quality_no2 pd.read_csv("data/air_quality_no2_long.cs…

二、python编程进阶02:模块和包

目录 1. python中的模块是什么 2. 导入模块: 学习import语句 2.1 import语句介绍 2.2 import导入模块的语法 2.3 导入自己的模块 2.4 导入数字开头或者带空格的模块 3. 编写自定义模块 3.1 给自定义模块编写测试代码 3.2 给自定义模块模块编写说明文档 4. 模块的搜索…

1215. 小朋友排队(树状数组应用 -- 逆序对个数)

题目如下: 思路 or 题解 我们可以得出交换的次数 > 逆序对个数 kkk 我们可以发现 所有 位置 左边大于它的个数 右边小于它的个数和 kik_iki​ 等于 k∗2k*2k∗2 我们可以简单证明出(感觉出):答案就是 ∑1n(1ki)∗ki2\sum^n_1 \frac{(1 k_i) * k_i}…

JavaScript 错误

文章目录JavaScript 错误 - throw、try 和 catchJavaScript 错误JavaScript 抛出(throw)错误JavaScript try 和 catchThrow 语句实例实例JavaScript 错误 - throw、try 和 catch try 语句测试代码块的错误。 catch 语句处理错误。 throw 语句创建自定义错…

string的使用介绍

目录 标准库中的string类 string类(了解) 编码介绍 string类的常用接口说明 Member functions 测试一:创建对象 测试二:遍历字符串 Iterators 测试三:反向迭代器(Iterators) Capacity 测试四:容器相关(Capacity) 测试…