交叉熵损失函数
交叉熵损失函数(Cross-Entropy Loss),也称为对数损失(Log Loss),是机器学习和深度学习中常用的损失函数之一,尤其在分类问题中。它衡量的是模型预测的概率分布与真实标签的概率分布之间的差异。
为什么要写一期交叉熵损失函数,是因为在基于pytorch本地部署微调bert模型(yelp文本分类数据集)-CSDN博客这个项目的coding过程中,错误的使用了pytorch的CrossEntropyLoss这个函数,在一周多的时间内自我怀疑Bert的微调是否能够成功,最后发现是损失函数使用错误的问题
计算公式
代码实现
import torch
import torch.nn as nn
# 假设模型输出的logits
logits = torch.tensor([[2.0, 1.0, 0.5]]) # 形状为 [1, 3]
# 真实标签,类别为2
labels = torch.tensor([2]) # 形状为 [1]
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 计算损失
loss = criterion(logits, labels)
print(loss.item()) # 输出损失值
千万注意,函数中的两个参数是不同的维度,label会在函数内部自行进行one-hot编码,而logits(一般认为是模型的输出值)是一维的向量,每一个批次输出分类数量的向量表示,所以对于logits是不需要进行处理的!