pytorch小记(十五):pytorch中 交叉熵损失详解:为什么logits比targets多一个维度?
- PyTorch交叉熵损失详解:为什么logits比targets多一个维度?
- 一、前言:新手常见困惑
- 二、核心概念:从考试得分到概率分布
- 1. logits:原始得分矩阵
- 2. targets:正确答案索引
- 三、维度差异的本质原因
- 1. 分类任务的数学需求
- 2. 维度对照表
- 3. 错误用法解析
- 四、手把手计算交叉熵损失
- 1. 输入数据
- 2. 计算步骤
- 步骤1:Softmax归一化
- 步骤2:提取正确类别的概率
- 步骤3:计算交叉熵
- 五、设计哲学深度解析
- 1. 为何不直接使用概率?
- 2. 多任务场景对照表
- 六、常见问题解答
- Q1:二分类能否用形状[N]的logits?
- Q2:如何处理多标签分类?
- Q3:为什么我的loss计算很慢?
- 七、总结
PyTorch交叉熵损失详解:为什么logits比targets多一个维度?
关键词:PyTorch交叉熵损失、logits维度、分类任务原理、深度学习基础
一、前言:新手常见困惑
许多初学PyTorch的朋友在使用交叉熵损失函数时,都会对logits
和targets
的维度关系感到困惑。典型的报错场景如下:
# 正确用法
logits = torch.tensor([[1.2, -0.5], [0.3, 2.1]]) # 形状 [2, 2]
targets = torch.tensor([0, 1]) # 形状 [2]
# 错误用法(触发维度错误)
logits_error = torch.tensor([0.5, 1.2]) # 形状 [2]
targets_error = torch.tensor([0, 1]) # 形状 [2]
loss = F.cross_entropy(logits_error, targets_error) # 报错!
本文将用生活实例+手把手计算的方式,带你彻底理解交叉熵损失的维度设计逻辑。
二、核心概念:从考试得分到概率分布
1. logits:原始得分矩阵
想象你正在参加一场有2道选择题的考试,每道题有A、B两个选项。模型对每个选项给出原始得分:
logits = torch.tensor([
[-1.0, 1.0], # 第1题:A得-1分,B得1分
[-0.5, 1.5], # 第2题:A得-0.5分,B得1.5分
[-0.5, 1.5] # 第3题(新增):同上
])
- 形状[3, 2]:3个样本(题目),每个样本2个类别(选项)
- 物理意义:未经归一化的"信心分数",数值越大表示模型越倾向该选项
2. targets:正确答案索引
targets = torch.tensor([0, 1, 1])
# 含义:第1题正确答案是A(索引0),第2、3题是B(索引1)
- 形状[3]:3个样本各对应一个正确答案位置
三、维度差异的本质原因
1. 分类任务的数学需求
- 模型需要为每个可能的类别提供判断依据
- 即使正确答案只有一个,也必须比较所有选项的"证据强度"
2. 维度对照表
张量 | 形状 | 物理意义 |
---|---|---|
logits | [N, C] | N个样本,每个样本C个类别的得分 |
targets | [N] | N个样本的正确类别索引(n在0~c-1之间) |
3. 错误用法解析
若logits
与targets
同维度:
logits_error = torch.tensor([0.2, 0.7, 0.5]) # 形状[3]
targets = torch.tensor([0, 1, 1]) # 形状[3]
此时模型无法判断:
- 每个数值对应哪个类别?
- 如何进行多类别比较?
四、手把手计算交叉熵损失
以具体例子演示计算全过程:
1. 输入数据
logits = torch.tensor([
[-1.0, 1.0],
[-0.5, 1.5],
[-0.5, 1.5]
]) # 形状[3,2]
targets = torch.tensor([0, 1, 1]) # 形状[3]
2. 计算步骤
步骤1:Softmax归一化
将原始得分转换为概率分布(每行和为1):
第1个样本([-1.0, 1.0]):
exp(-1.0) = 0.3679
exp(1.0) = 2.7183
总合 = 0.3679 + 2.7183 = 3.0862
概率 = [0.3679/3.0863 ≈ 0.1192, 2.7183/3.0863 ≈ 0.8808]
第2个样本([-0.5, 1.5]):
exp(-0.5) ≈ 0.6065
exp(1.5) ≈ 4.4817
总合 = 0.6065 + 4.4817 ≈ 5.0882
概率 = [0.6065/5.0882 ≈ 0.1192, 4.4817/5.0882 ≈ 0.8808]
步骤2:提取正确类别的概率
根据targets
索引:
样本1:取索引0 → 0.1192
样本2:取索引1 → 0.8808
样本3:取索引1 → 0.8808
步骤3:计算交叉熵
公式:loss = -平均(ln(正确概率))
loss = -(ln(0.1192) + ln(0.8808) + ln(0.8808)) / 3
= -[(-2.127) + (-0.127) + (-0.127)] / 3
≈ 0.7937
验证PyTorch计算结果:
print(loss.item()) # 输出 0.7937
五、设计哲学深度解析
1. 为何不直接使用概率?
- 数值稳定性:直接处理指数运算易导致溢出
- 梯度优化:logits的线性特性更利于反向传播
2. 多任务场景对照表
任务类型 | logits形状 | targets形状 | 损失函数 |
---|---|---|---|
二分类(2个选项) | [N,2] | [N] | CrossEntropyLoss |
多标签分类 | [N,C] | [N,C] | BCEWithLogitsLoss |
回归任务 | [N] | [N] | MSELoss |
六、常见问题解答
Q1:二分类能否用形状[N]的logits?
可以,但需配合sigmoid
:
# 二分类特例
logits = torch.tensor([0.8, -0.3]) # 形状[2]
prob = torch.sigmoid(logits) # 转换为概率
loss = F.binary_cross_entropy(prob, targets)
Q2:如何处理多标签分类?
当每个样本可能有多个正确标签时:
logits = torch.tensor([[1.2, -0.5], [0.3, 2.1]]) # 形状[2,2]
targets = torch.tensor([[1, 0], [0, 1]]) # 形状[2,2] (one-hot)
loss = F.binary_cross_entropy_with_logits(logits, targets)
Q3:为什么我的loss计算很慢?
- 检查是否误用了for循环逐个样本计算
- 正确的向量化计算可加速百倍以上
七、总结
理解logits与targets的维度差异,关键在于把握分类任务的本质需求:
- logits提供全类别的判断依据 → 需要二维结构
- targets只需指出正确位置 → 一维索引足矣
掌握这一设计哲学后,你就能:
✅ 正确构建分类模型的输出层
✅ 快速调试维度相关的错误
✅ 深入理解损失函数的工作原理
练习建议:在Jupyter Notebook中复现本文的计算示例,尝试修改logits值观察loss变化。
相关阅读:
- PyTorch官方文档:CrossEntropyLoss
如有疑问欢迎留言讨论!