【深度学习】CrossEntropyLoss需要手动softmax吗?
- 问题:CrossEntropyLoss需要手动softmax吗?
- 答案:不需要
- 官方文档
- 代码解释
问题:CrossEntropyLoss需要手动softmax吗?
之前用 pytorch 实现自己的网络时,使用CrossEntropyLoss的时候将网路输出经 softmax激活层后再计算CrossEntropyLoss。
答案:不需要
调用了损失函数CrossEntropyLoss,最后一层是不需要再加softmax函数激活的。
官方文档
官方文档链接:pytorch-crossentropyloss,相当于在输入上应用LogSoftmax,然后NLLLoss
代码解释
import torch
import torch.nn as nn
import torch.nn.functional as F
criterion = nn.CrossEntropyLoss()
# 模拟网络输出(未经过softmax)
logits = torch.tensor([[0.2447, 3, 1]], requires_grad=True)
# 模拟目标标签
target = torch.tensor([0])
# 使用Softmax + CrossEntropyLoss计算损失
softmax_layer = nn.Softmax(dim=1)
softmax_output = softmax_layer(logits)
loss_softmax_cross_entropy = criterion(softmax_output, target)
print("softmax + CrossEntropyLoss:", loss_softmax_cross_entropy.item())
# 直接使用CrossEntropyLoss计算损失
loss_cross_entropy = criterion(logits, target)
print("CrossEntropyLoss:", loss_cross_entropy.item())
# 使用LogSoftmax + NLLLoss计算损失
softmax_output = torch.softmax(logits, dim=1)
log_softmax_output = torch.log(softmax_output)
log_softmax_nll = F.nll_loss(log_softmax_output, target)
print("LogSoftmax + NLLLoss:", log_softmax_nll.item())