博文总结:交叉熵损失函数与标签平滑

news2024/11/15 18:00:40

文章目录

  • 基本概念
  • 交叉熵损失函数
  • Pytorch代码实现
  • 参考文献

李宏毅机器学习2023作业04Self-attention、李宏毅机器学习2023作业03CNN和李宏毅机器学习2023作业02Classification都是分类问题,都涉及到了交叉熵损失函数以及起正则作用的标签平滑技巧,本次博文把以上两点整理总结下。

基本概念

1、信息量:在信息论中,一个不太可能发生的事件居然发生了,我们收到的信息要多于一个非常可能发生的事情发生。因此,事件包含的信息量应与其发生的概率负相关。数学表达式定义为:假设 X X X是取值集合为 { x 1 , x 2 , . . . , x n } \{x_1,x_2,...,x_n\} {x1,x2,...,xn} 的离散型随机变量,定义事件 X = x i X=x_i X=xi的信息量为 I ( x i ) = − log ⁡ 2 P ( X = x i ) I(x_i)=-\log_2P(X=x_i) I(xi)=log2P(X=xi)
  这里采用 log ⁡ \log log函数的形式主要是为了体现信息量的三个性质:事件发生的概率越低,信息量越大;事件发生的概率越高,信息量越低;多个事件同时发生的概率是多个事件概率相乘,总信息量是多个事件信息量相加。

2、:通常用熵对整个事件的平均信息量进行描述,即上述信息量定义关于概率分布 P P P的期望:
H ( P ) = E X ∼ P [ − log ⁡ 2 P ( x ) ] = − ∑ i = 1 n P ( x i ) ⋅ log ⁡ 2 P ( x i ) \mathrm{H}(P)=\mathbb{E}_{X \sim P}[-\log _2 P(x)]=-\sum_{i=1}^n P(x_i) \cdot \log_2P(x_i) H(P)=EXP[log2P(x)]=i=1nP(xi)log2P(xi)总而言之,信息熵是用来衡量事物不确定性的,信息熵越大,事物越具不确定性。通常接近确定性的分布(输出几乎可以确定)具有较低的熵,那些接近均匀分布的概率分布具有较高的熵。

3、KL散度(相对熵):一般被用于计算两个分布之间的不同
D K L ( P ∣ ∣ Q ) = E X ∼ P [ log ⁡ 2 P ( x ) Q ( x ) ] = ∑ i = 1 n P ( x i ) ⋅ log ⁡ 2 P ( x i ) Q ( x i ) \mathrm{D}_{KL}(P||Q)=\mathbb{E}_{X \sim P}[\log _2 \frac{P(x)}{Q(x)}]=\sum_{i=1}^n P(x_i) \cdot \log_2\frac{P(x_i)}{Q(x_i)} DKL(P∣∣Q)=EXP[log2Q(x)P(x)]=i=1nP(xi)log2Q(xi)P(xi)将上式展开之后可以发现前者是,而后者通常定义为交叉熵,因此KL散度(相对熵)=交叉熵-熵

4、交叉熵
H ( P , Q ) = E X ∼ P [ − log ⁡ 2 Q ( x ) ] = − ∑ i = 1 n P ( x i ) ⋅ log ⁡ 2 Q ( x i ) \mathrm{H}(P,Q)=\mathbb{E}_{X \sim P}[-\log _2 Q(x)]=-\sum_{i=1}^n P(x_i) \cdot \log_2Q(x_i) H(P,Q)=EXP[log2Q(x)]=i=1nP(xi)log2Q(xi)

交叉熵损失函数

1、在深度学习中,我们总是希望模型学到的分布 P ( m o d e l ) P(model) P(model)和真实数据的分布 P ( r e a l ) P(real) P(real)越接近越好,最直接的损失函数就是利用KL散度使得两个分布的差异性最小。但我们没有真实数据的分布,那么只能退而求其次,希望模型学到的分布和训练数据的分布 P ( t r a i n i n g ) P(training) P(training)尽量相同
2、由于训练数据是给定的,因此KL散度中的熵就是恒定的,那么,最小化交叉熵就是最小化KL散度
3、在分类任务中,训练数据的标签通常才有用one-hot编码的形式,假设类别总数为3类,给定一个样本的标签为 [ 1 , 0 , 0 ] [1,0,0] [1,0,0]的形式,该样本对应的模型输出为 [ 0.8 , 0.1 , 0.1 ] [0.8,0.1,0.1] [0.8,0.1,0.1]的形式,直接代入交叉熵的公式 − ( 1 ⋅ log ⁡ 2 ( 0.8 ) + 0 ⋅ log ⁡ 2 ( 0.1 ) + 0 ⋅ log ⁡ 2 ( 0.1 ) ) = − log ⁡ 2 ( 0.8 ) -(1\cdot \log_2(0.8)+0\cdot \log_2(0.1)+0\cdot \log_2(0.1))=-\log_2(0.8) (1log2(0.8)+0log2(0.1)+0log2(0.1))=log2(0.8)计算结果中只有one-hot编码形式的标签中为1的对应项。在具体的Pytorch代码中,几行代码就可以实现交叉熵损失函数的定义、计算损失、计算梯度:

criterion = nn.CrossEntropyLoss()
......
output_tensor = model(input_tensor)
loss = criterion(output_tensor, target_tensor)
loss.backward()
...... 

Pytorch代码实现

1、在Pytorch1.9.0中,交叉熵损失函数的定义形式如下:
交叉熵损失函数
紧接着“This criterion combines LogSoftmax and NLLLoss in one single class.”的描述给出了两点信息:一是包含了LogSoftmax和NLLLoss两个函数,二是用于单类别问题(即一个样本只对应一个类别)
2、函数中第一个参数 w e i g h t weight weight是可手动定义的1D Tensor,假如分类问题类别总数为 C C C,参数 w e i g h t weight weight的长度就是 C C C,在训练集中各个类别占比不平衡时通过设置不同的权重会特别有用;第二个参数 s i z e _ a v e r a g e size\_average size_average和第四个参数 r e d u c e reduce reduce已经被替代为第五个参数 r e d u c t i o n reduction reduction,默认为 ′ m e a n ′ 'mean' mean,对一个 b a t c h batch batch范围内所有样本的交叉熵损失求平均,也可以取值 ′ s u m ′ 'sum' sum对一个 b a t c h batch batch范围内所有样本的交叉熵损失求总和以及取值 ′ n o n e ′ 'none' none保持交叉熵损失的尺寸,即与 T a r g e t Target Target的尺寸一致; i g n o r e _ i n d e x ignore\_index ignore_index表示该类别对应的样本对最终的交叉熵损失没有任何贡献。
3、函数包括2个输入: I n p u t Input Input T a r g e t Target Target和1个输出: O u t p u t Output Output,对于 I n p u t Input Input而言,它是来自深度学习模型未归一化的原始的各个类别的置信度,通常对应的尺寸是 b a t c h × C batch\times C batch×C或者 b a t c h × C × d 1 × d 2 × . . . × d K batch\times C\times d_1\times d_2\times ... \times d_K batch×C×d1×d2×...×dK,在2023作业02Classification、2023作业03CNN和2023作业04Self-attention中对应的是 b a t c h × C batch\times C batch×C,其中2023作业02Classification的BossBaseline方法是 b a t c h × C × S e q L e n g t h batch\times C\times SeqLength batch×C×SeqLength,当然如果不嫌麻烦的话,可以把 b a t c h × C × d 1 × d 2 × . . . × d K batch\times C\times d_1\times d_2\times ... \times d_K batch×C×d1×d2×...×dK通过维度转化,把后面的数据尺寸维度合并到 b a t c h batch batch维转换成二维的形式
4、对于 T a r g e t Target Target而言,通常对应的尺寸是 b a t c h batch batch或者 b a t c h × d 1 × d 2 × . . . × d K batch\times d_1\times d_2\times ... \times d_K batch×d1×d2×...×dK,取值为 [ 0 , C − 1 ] [0, C-1] [0,C1],最终交叉熵损失函数就是像上文提到的例子一样,当给定 b a t c h batch batch中的一个样本时,计算结果中只有one-hot编码形式的标签中为1的对应项,下式中 j j j是类别索引,取值范围为 [ 0 , C − 1 ] [0, C-1] [0,C1]
在这里插入图片描述
如果考虑参数 w e i g h t weight weight,那么会在上式中乘以对应标签类别的权重
在这里插入图片描述
通常在一个 b a t c h batch batch范围内对所有样本取平均,就是如下的形式:
在这里插入图片描述
可以看出,分母是一个 b a t c h batch batch范围内所有样本对应的真实标签类别的权重之和,意味着对每个样本的交叉熵损失进行了加权平均。
5、自Pytorch1.10开始,torch.nn.CrossEntropyLoss内置了标签平滑的参数 l a b e l _ s m o o t h i n g label\_smoothing label_smoothing带标签平滑的交叉熵损失函数有点没整明白(网上资料挺多的,但是有点杂),先做个标记吧! 如下的代码是可以直接用的,和官方的Pytorch代码结果是一致的:

import torch
import torch.nn as nn
import torch.nn.functional as F


def linear_combination(x, y, epsilon):
    return epsilon * x + (1 - epsilon) * y


def reduce_loss(loss, reduction='mean'):
    return loss.mean() if reduction == 'mean' else loss.sum() if reduction == 'sum' else loss


class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, epsilon: float = 0.1, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction

    def forward(self, preds, target):
        # 如果数据样本除了batch,class_num外还有其他维度,需要先转换成batch*class_num的二维形式
        # 这样就和官方版本的代码完全一致了
        if len(preds.size()) >= 3:
            batch, class_num = preds.size()[0: 2]
            preds = preds.transpose(0, 1).reshape(class_num, -1).transpose(0, 1)
            target = target.reshape(-1)
        n = preds.size()[-1]
        log_preds = F.log_softmax(preds, dim=-1)
        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        return linear_combination(loss / n, nll, self.epsilon)


criterion1 = LabelSmoothingCrossEntropy()
output = torch.randn(3, 5, 10, requires_grad=True)
target = torch.empty(3, 10, dtype=torch.long).random_(5)
loss1 = criterion1(output, target)
criterion2 = nn.CrossEntropyLoss(label_smoothing=0.1)
loss2 = criterion2(output, target)
print(loss1)
print(loss2)

参考文献

1.PyTorch中的Loss Fucntion
2.百度百科
3.交叉熵损失函数(Cross Entropy Loss)
4.为什么交叉熵(cross-entropy)可以用于计算代价?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1164054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目实战:给首页上库存名称添加超链接然后带fid跳转到edit页面

1、提取公共方法common.js function $(key){if(key){if(key.startsWith("#")){key key.substring(1)return document.getElementById(key)}else{let nodeList document.getElementsByName(key)return Array.from(nodeList)}} } 2、 给库存名称添加超链接 2.1、inde…

Qt Creator创建新项目警告问题

这里可以看见如果你是一些高版本会出现各种警告,但是可以编译通过,这是ClangCodeModel模块导致 解决办法 help -> About Plugins..->C ->ClangCodeModel 帮助 -> 关于插件 -> c ->ClangCodeModel取消勾选 然后重启Qt即可

【Java初阶练习题】-- 循环+递归练习题

循环练习题02 打印X图形计算1/1-1/21/3-1/41/5 …… 1/99 - 1/100 的值输出一个整数的每一位如:123的每一位是3,2,1模拟登录使用方法求最大值求斐波那契数列的第n项。(迭代实现)求和的重载求最大值方法的重载递归求N阶乘递归求 1 2 3 ...…

Redis的安装及基本使用

⭐⭐ Redis专栏:Redis专栏 ⭐⭐ 个人主页:个人主页 目录 一.Redis的简介 ⭐ 拓展:NO-SQL数据库与SQL数据库 二.Redis的安装 2.1linux版安装 下载Redis Desktop 2.2 Windows安装 三.redis的基本使用 3.1 String 字符串类…

如何从嘉立创下单一个PCB打板(免费)

文章目录 设计PCB下单制作PCB领取优惠券 设计PCB 由于我刚接触PCB设计,并不会自己设计,因此直接选择了一个开源硬件平台中的一个项目进行下载,下载链接如下: ESP32管灯熊猫 - 嘉立创EDA开源硬件平台 (oshwhub.com) 根据其中的视…

jeecg-uniapp 转成小程序的过程 以及报错 uniapp点击事件

uniapp 点击事件 tap: 单击事件 confirm: 回车事件 blur:失去焦点事件 touchstart: 触摸开始事件 touchmove: 触摸移动事件。 touchend: 触摸结束事件。 longpress: 长按事件。 input: 输入框内容变化事件。 change: 表单元素值变化事件。 submit: 表单提交事件。 scroll: 滚动…

程序员有哪些规避风险的合法兼职渠道?

近期,承德程序员事件冲上热搜,这对许多程序员的心灵是多么大的伤害啊! 人人自危,大家开始顾虑自己接私活、找兼职的方式和前景了。毕竟,谁也不想”辛辛苦苦几十年,一把回到解放前“。那有什么办法既可以接私…

【自动控制原理】数学模型:系统框图及其化简、控制系统传递函数

文章目录 第2章 数学模型2.1 控制系统的运动微分方程2.2 拉氏变换和反变换2.3 传递函数2.4 系统框图2.4.1 系统框图2.4.2 系统框图的简化2.4.3 梅森公式2.4.4 例题答案解析——梅森公式 2.5 控制系统传递函数2.5.1 闭环系统的开环传递函数2.5.2 参考输入R(s)作用下的闭环传递函…

ONNX的结构与转换

ONNX的结构与转换 1. 背景2. ONNX结构分析与修改工具2.1. ONNX结构分析2.2. ONNX的兼容性问题2.3. 修改ONNX模型 3. 各大深度学习框架如何转换到ONNX?3.1. MXNet转换ONNX3.2. TensorFlow模型转ONNX3.3. PyTorch模型转ONNX3.4. PaddlePaddle模型转ONNX3.4.1. 简介3.4…

zabbix6.4监控centos

1、关闭防火墙 setenforce 0 #关闭SELinux sed -i "s/SELINUX=enforcing/SELINUX=disabled/g" /etc/selinux/config #设置永久关闭SELinux systemctl stop firewalld.service #关闭防火墙 systemctl disable firewalld.service …

nodejs express vue 点餐外卖系统源码

开发环境及工具: nodejs,vscode(webstorm),大于mysql5.5 技术说明: nodejs express vue elementui 功能介绍: 用户端: 登录注册 首页显示搜索菜品,轮播图&#xf…

DNS 域名解析协议

作用 将域名转化位IP地址 域名 用’ . ’ 隔开的字符串,如:www.badu.com,就是为了赋予IP特殊含义。 一级域名 .com :公用 .cn:中国 .gov:政府 .us:美国 .org:组织 .net:网站 对应一级…

MATLAB和S7-1200PLC OPC通信(激活S7-1200PLC OPC UA服务器)

MATLAB和SMART PLC OPC通信请参考下面文章博客: MATLAB和西门子SMART PLC OPC通信-CSDN博客文章浏览阅读123次。西门子S7-200SMART PLC OPC软件的下载和使用,请查看下面文章Smart 200PLC PC Access SMART OPC通信_基于pc access smart的opc通信_RXXW_Dor的博客-CSDN博客OPC是…

【算法挨揍日记】day18——746. 使用最小花费爬楼梯、91. 解码方法

746. 使用最小花费爬楼梯 746. 使用最小花费爬楼梯 题目描述: 给你一个整数数组 cost ,其中 cost[i] 是从楼梯第 i 个台阶向上爬需要支付的费用。一旦你支付此费用,即可选择向上爬一个或者两个台阶。 你可以选择从下标为 0 或下标为 1 的台阶开始爬…

进程优先级(nice值,top指令),独立性,竞争性,进程切换(时间片),抢占与出让,并发并行概念

目录 优先级 引入 为什么会存在优先级 特点 优先级值 nice值 更改nice值 top指令 独立性 竞争性 进程切换 引入 时间片 上下文切换 调度器 抢占与出让 强占 出让 并发和并行 并发 并行 优先级 引入 与权限不同的是,权限是能不能做的问题,优先级是什…

有谁知道怎么下载微信视频号视频吗?

抖音视频下载、某站视频下载都很常见,那你知道怎么下载V信视频号视频吗/今天给大家分享两种简单方便的办法,继续往下看吧!一、犀牛视频下载机器人犀牛视频下载器可以直接解析并下载视频号短视频。您只需转发视频到机器人即可下载。此方法也是…

Redis安装-常用命令及操作

目录 一.Redis简介 二.redis安装 1.1安装Linux版本 1.2安装 windows版本 三.redis的常用命令 Redis哈希(Hash) 一.Redis简介 Redis是一个开源(BSD许可),内存存储的数据结构服务器,可用作数据库,高速缓存和消息队…

LLMs之RAG之IncarnaMind:IncarnaMind的简介(提高RAG召回率的两个优化技巧=滑块遍历反向查找+独立查询)、安装、使用方法之详细攻略

LLMs之RAG之IncarnaMind:IncarnaMind的简介(提高RAG召回率的两个优化技巧滑块遍历反向查找独立查询)、安装、使用方法之详细攻略 导读:在IncarnaMind项目中,提出了几个优化技巧,是非常值得我们在优化RAG系统的时候,进行…

分享66个工作总结PPT,总有一款适合您

分享66个工作总结PPT,总有一款适合您 66个工作总结PPT下载链接:https://pan.baidu.com/s/1g8AWl42-tLdFYXEHZUYyGQ?pwd8888 提取码:8888 Python采集代码下载链接:采集代码.zip - 蓝奏云 立冬PPTPPT模板 西藏信仰PPT模板 古镇丽…

金麟国际用工-全新蓝领跨境就业服务平台

金麟国际用工-全新蓝领跨境就业服务平台 金麟国际用工平台是一个引领时代的蓝领跨境就业服务平台,专为蓝领求职者和雇主提供一个全面、便捷、高效的就业对接环境。这个平台通过其强大的数字化系统,包括客户管理系统、岗位信息系统和智能营销工具等&…