Pytorch 中Label Smoothing CrossEntropyLoss实现

news2025/3/1 11:33:13

一. 前言

一般情况下我们都是直接调用Pytorch自带的交叉熵损失函数计算loss,但涉及到魔改以及优化时,我们需要自己动手实现loss function,在这个过程中如果能对交叉熵损失的代码实现有一定的了解会帮助我们写出更优美的代码。

其次是标签平滑这个trick通常简单有效,只需要改改损失函数既可带来性能上的提升,通常与交叉熵配合食用。

因此,本文基于这两个出发点,介绍基于Pytorch框架下的交叉熵损失实现以及标签平滑的实现。

二. CrossEntropyLoss

相信大家对于如何计算交叉熵已经非常熟悉,常规步骤是①计算softmax得到各类别置信度;②计算交叉熵损失。但其实从Pytorch的官方文档可以看出,还有更一步到位的方法,如下:
CE
这避免了softmax的计算。

三. 代码实现

class CELoss(nn.Module):
    ''' Cross Entropy Loss'''
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        ''' 
        Args:
            pred: prediction of model output    [N, M]
            target: ground truth of sampler [N]
        '''
        eps = 1e-12
      	# standard cross entropy loss
        loss = -1.*pred.gather(1, target.unsqueeze(-1)).reshape(-1,1) + torch.log(torch.exp(pred+eps).sum(dim=1)).reshape(-1,1)

        return loss.mean()

具体细节参考我前面的文章 Pytorch中CrossEntropyLoss()详解。

四. Label Smoothing

Label Smoothing也称之为标签平滑,其实是一种防止过拟合的正则化方法。传统的分类loss采用softmax loss,先对全连接层的输出计算softmax,视为各类别的置信度概率,再利用交叉熵计算损失。
Label Smooth
Label Smooth

在这个过程中尽可能使得各样本在正确类别上的输出概率为1,这要使得对应的z值为+∞,这拉大了其与其他类别间的距离

现在假设一个多分类任务标签是[1,0,0],如果它本身的label的出现了问题,这对模型的伤害是非常大的,因为在训练的过程中强行学习一个非本类的样本,并且让其概率非常高,这会影响对后验概率的估计。并且有时候类与类之间的并不是毫无关联,如果鼓励输出的概率间相差过大,这会导致一定程度上的过拟合

因此Label Smoothing的想法是让目标不再是one-hot标签,而是变为如下形式:
Label Smooth
其中ε为一个较小的常数,这使得softmax损失中的概率优目标不再为1和0,同时z值的最优解也不再是正无穷大,而是一个具体的数值。这在一定程度上避免了过拟合,也缓解了错误标签带来的影响。

五. Label Smoothing CrossEntropyLoss实现

基于上一节的交叉熵实现增加标签平滑功能,代码如下:

class CELoss(nn.Module):
    ''' Cross Entropy Loss with label smoothing '''
    def __init__(self, label_smooth=None, class_num=137):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num

    def forward(self, pred, target):
        ''' 
        Args:
            pred: prediction of model output    [N, M]
            target: ground truth of sampler [N]
        '''
        eps = 1e-12
        
        if self.label_smooth is not None:
            # cross entropy loss with label smoothing
            logprobs = F.log_softmax(pred, dim=1)	# softmax + log
            target = F.one_hot(target, self.class_num)	# 转换成one-hot
            
            # label smoothing
            # 实现 1
            # target = (1.0-self.label_smooth)*target + self.label_smooth/self.class_num 	
            # 实现 2
            # implement 2
            target = torch.clamp(target.float(), min=self.label_smooth/(self.class_num-1), max=1.0-self.label_smooth)
            loss = -1*torch.sum(target*logprobs, 1)
        
        else:
            # standard cross entropy loss
            loss = -1.*pred.gather(1, target.unsqueeze(-1)).reshape(-1,1) + torch.log(torch.exp(pred+eps).sum(dim=1)).reshape(-1,1)

        return loss.mean()

实现1采用了 (1.0-self.label_smooth)*target +self.label_smooth/self.class_num 实现,与原始公式不太一样
后续在了解到pytorch的clamp接口后,发现能够利用其能正确实现原公式,见实现2

六. 试验验证

① 交叉熵损失正确率,与标准的交叉熵比较:

	loss1 = nn.CrossEntropyLoss()
    loss2 = CELoss(label_smooth=None, class_num=3)

    x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
    y = torch.tensor([1, 2])

    print(loss1(x, y), loss2(x, y))
	# tensor(0.0018) tensor(0.0018)

② 标签平滑结果展示:

	loss1 = nn.CrossEntropyLoss()
    loss2 = CELoss(label_smooth=0.05, class_num=3)

    x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
    y = torch.tensor([1, 2])

    print(loss1(x, y), loss2(x, y))
	# tensor(0.0018) tensor(0.2352)

另一组结果:

	x = torch.tensor([[0.1, 8, 0.1], [0.1, 0.1, 8]], dtype=torch.float)
    y = torch.tensor([1, 2])

    print(loss1(x, y), loss2(x, y))
    # tensor(0.0007) tensor(0.2641)

分析:拉大模型输出数值间的差距后,原始的交叉熵会变小,而增加了标签平滑的反而变大。这也反映了标签平滑后,并不是概率越接近于1越好,而是接近某个小于1的值,这使得模型的输出不再是越高(+∞)越好。

七. 参考链接

Pytorch:交叉熵损失(CrossEntropyLoss)以及标签平滑(LabelSmoothing)的实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/38874.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Day13--自定义组件-封装自定义属性和click事件

提出问题: 当前我们search搜索框的背景颜色和圆角边框都是写死的,使用者没有办法修改器背景颜色和圆角尺寸。那么为了让这个组件更加通用性强一些。 ***********************************************************************************************…

用DIV+CSS技术设计的个人电影网站(web前端网页制作课作业)

HTML实例网页代码, 本实例适合于初学HTML的同学。该实例里面有设置了css的样式设置,有div的样式格局,这个实例比较全面,有助于同学的学习,本文将介绍如何通过从头开始设计个人网站并将其转换为代码的过程来实践设计。 文章目录一、网页介绍一…

框架体系——Spring

Spring IOC IOC控制反转 IOC 控制反转,全称Inverse of Control,是一种设计理念由代理人来创建和管理对象,消费者通过代理人来获取对象Ioc的目的是降低对象之间的耦合通过加入Ioc容器将对象统一管理,将对象关联变为弱耦合。 DI…

MyBatis中有哪些注解呢?

转自: MyBatis中有哪些注解呢? 为了简化 XML 的配置,MyBatis 提供了注解。我们可以通过 MyBatis 的 jar 包查看注解,如下图所示。 以上注解主要分为三大类,即 SQL 语句映射、结果集映射和关系映射 下面分别进行讲解 一、SQL 语句…

Allegro中如何进行尺寸标注

摘要本文介绍了如何在Allegro中进行尺寸标注,包含各种标注样式的区别、如何设置参数、如何显示单位、如何导出带尺寸的PDF与DXF等信息。 一. 为什么要尺寸标注PCB尺寸标注的作用: 方便设计人员明确板子的大小,以及安装位置的各种细节&#xf…

react学习笔记3--数据双向绑定,组件通信

一、表单处理 1、受控组件-input元素 通过设置input元素的value值(或复选框的checked值)实现Getter,通过监听onChange事件实现Setter,从而实现数据双向绑定。 class element extends React.Component {state {txt:""…

vulhub靶场搭建与使用

vulhub靶场搭建与使用1.前言2.配置yum源2.1备份原来的源文件2.2 配置阿里源2.3重置yum源2.4更新yum源3.安装docket3.1安装docket3.2启动docket3.3设置国内镜像源3.4重启docket4.安装docker-compose4.1安装dockers-compose4.2提升权限5.安装vulhub5.1安装git5.2下载vulhub5.3下载…

自知识蒸馏(知识蒸馏二)

自知识蒸馏(知识蒸馏二)自知识蒸馏(知识蒸馏二)Born-Again Neural Networks(ICML2018)方法为什么有效实验结果Training Deep Neural Networks in Generations: A More Tolerant Teacher Educates Better St…

MyBatis工作原理

MyBatis工作流程: 具体介绍: (1) MyBatis 读取核心配置文件mybatis-config.xml mybatis-config.xml核心配置文件主要配置了MyBatis的运行环境等信息。 (2)加载映射文件Mapper.xml Mapexm文件即SQL映射文件,该文件配置了操作数据库的SOL语句&a…

Python+Appium移动端自动化测试框架实现

一、Appium 概述 1、Appium 简介 Appium是一个开源的自动化测试框架,可以用来测试基于iOS、Android和Firefox OS 平台的原生与混合的应用。 该框架使用Selenium WebDriver,在执行测试时用于和Selenium Server 通信的是JSON Wire Protocol。在Selenium 2中,Appium将取代 i…

【CNN】经典网络LeNet——最早发布的卷积神经网络之一

前言 LeNet是Yann LeCun于1988年提出的用于数字识别的网络结构,可以说LeNet是深度CNN网络的基石,AlexNet、VGG、GoogLeNet、ResNet等都是在VGG基础上加入各类激活函数或加深网络演变而来的,所以理解LeNet对于现在主流CNN深度学习架构的理解有…

制作一个简单HTML电影网页设计(HTML+CSS)

HTML实例网页代码, 本实例适合于初学HTML的同学。该实例里面有设置了css的样式设置,有div的样式格局,这个实例比较全面,有助于同学的学习,本文将介绍如何通过从头开始设计个人网站并将其转换为代码的过程来实践设计。 文章目录一、网页介绍一…

基于蚁群算法的多配送中心的车辆调度问题的研究(Matlab代码实现)

👨‍🎓个人主页:研学社的博客 💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜…

【图像处理】小波编码图像中伪影和纹理的检测附Matlab代码和报告

✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。 🍎个人主页:Matlab科研工作室 🍊个人信条:格物致知。 更多Matlab仿真内容点击👇 智能优化算法 …

如果各位同学还对时间复杂度有疑问?看这一篇就可以啦!

🎇🎇🎇作者: 小鱼不会骑车 🎆🎆🎆专栏: 《java练级之旅》 🎓🎓🎓个人简介: 一名专科大一在读的小比特,努力学习编程是我…

chrome浏览器一键切换搜索引擎,一键切换谷歌和百度搜索

chrome浏览器一键切换搜索引擎,一键切换谷歌和百度搜索 背景 有么有办法在谷歌和百度之间(或其他引擎或非引擎,如Youtube、B站、Bing等)之间切换。我们当然是不想重新输入keyword,甚至点击浏览器插件的图标后再选择引…

Scala010--Scala中的常用集合函数及操作Ⅰ

之前我们已经知道了Scala中的数据结果有哪些,并且能够使用for循环取到该数据中的元素,现在我们再进一步的去了解更加方便及常用的函数操作,使得我们能够对集合更好的利用。 目录 一,foreach函数 1,遍历一维数组 1&…

Pytorch中CrossEntropyLoss()详解

一、损失函数 nn.CrossEntropyLoss() 交叉熵损失函数 nn.CrossEntropyLoss() ,结合了 nn.LogSoftmax() 和 nn.NLLLoss() 两个函数。 它在做分类(具体几类)训练的时候是非常有用的。 二. 什么是交叉熵 交叉熵主要是用来判定实际的输出与期望…

HTML CSS个人网页设计与实现——人物介绍丁真(学生个人网站作业设计)

🎉精彩专栏推荐👇🏻👇🏻👇🏻 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业…

SpringBoot SpringBoot 原理篇 1 自动配置 1.8 bean 的加载方式【六】

SpringBoot 【黑马程序员SpringBoot2全套视频教程,springboot零基础到项目实战(spring boot2完整版)】 SpringBoot 原理篇 文章目录SpringBootSpringBoot 原理篇1 自动配置1.8 bean 的加载方式【六】1.8.1 ImportSelector1 自动配置 1.8 b…