深入理解二分类和多分类CrossEntropy Loss和Focal Loss
二分类交叉熵
在二分的情况下,模型最后需要预测的结果只有两种情况,对于每个类别我们的预测得到的概率为
p
p
p和
1
−
p
1-p
1−p,此时表达式为( 的
log
\log
log底数是
e
e
e):
L
=
1
N
∑
i
L
i
=
1
N
∑
i
−
[
y
i
⋅
log
(
p
i
)
+
(
1
−
y
i
)
⋅
log
(
1
−
p
i
)
]
L=\frac{1}{N} \sum_{i} L_i =\frac{1}{N} \sum_{i} -[y_i \cdot \log (p_i) +(1-y_i) \cdot \log (1-p_i)]
L=N1i∑Li=N1i∑−[yi⋅log(pi)+(1−yi)⋅log(1−pi)]
其中:
- y i y_i yi —— 表示样本 i i i的label,正类为1 ,负类为0
- p i p_i pi—— 表示样本 i i i预测为正类的概率
由于二分类交叉熵很容易理解,在此就不做举例了。
多分类交叉熵
多分类交叉熵就是对二分类交叉熵的扩展,在计算公式中和二分类稍微有些许区别,但是还是比较容易理解,具体公式如下所示:
L
=
1
N
∑
i
L
i
=
−
1
N
∑
i
∑
c
=
1
M
y
i
c
log
(
p
i
c
)
L=\frac{1}{N} \sum_{i} L_i=-\frac{1}{N} \sum_{i} \sum_{c=1}^M y_{ic} \log(p_{ic})
L=N1i∑Li=−N1i∑c=1∑Myiclog(pic)
其中:
- M M M——类别的数量
- y i c y_{ic} yic——符号函数(0或1 ),如果样本 i i i的真实类别等于 c c c取 1,否则取 0
- p i c p_{ic} pic——观测样本 i i i属于类别 c c c的预测概率
举例说明
预测(已经经过softmax归一化) | 真实 |
---|---|
0.1 0.2 0.7 | 0 0 1 |
0.3 0.4 0.3 | 0 1 0 |
0.1 0.2 0.7 | 1 0 0 |
现在我们利用这个表达式计算上面例子中的损失函数值:
sample 1 loss
=
−
(
0
×
log
0.1
+
0
×
log
0.2
+
1
×
log
0.7
)
=
0.35
,
sample 2 loss
=
−
(
0
×
log
0.1
+
1
×
log
0.7
+
0
×
log
0.2
)
=
0.35
,
sample 3 loss
=
−
(
1
×
log
0.3
+
0
×
log
0.4
+
0
×
log
0.4
)
=
1.20
,
L
=
0.35
+
0.35
+
1.2
3
=
0.63
\text{sample 1 loss}=-(0 \times \log 0.1+0 \times \log 0.2 + 1 \times \log 0.7)=0.35 ,\\ \text{sample 2 loss}=-(0 \times \log 0.1+1 \times \log 0.7 + 0 \times \log 0.2)=0.35 ,\\ \text{sample 3 loss}=-(1 \times \log 0.3+0 \times \log 0.4 + 0 \times \log 0.4)=1.20,\\ L=\frac{0.35+0.35+1.2}{3}=0.63
sample 1 loss=−(0×log0.1+0×log0.2+1×log0.7)=0.35,sample 2 loss=−(0×log0.1+1×log0.7+0×log0.2)=0.35,sample 3 loss=−(1×log0.3+0×log0.4+0×log0.4)=1.20,L=30.35+0.35+1.2=0.63
其实可以看到,多分类交叉熵只计算正确标签对应概率的损失值,相对错误标签其
y
i
c
=
0
y_{ic}=0
yic=0,所以导致错误标签对应的损失值为0。
Pytorch的CrossEntropyLoss分析
参数设定
CrossEntropyLoss在Pytorch官网中,我们可以看到整个文档已经对该函数CrossEntropyLoss进行了较充分的解释。所以我们简要介绍其参数和传入的值的格式,特别是针对多分类的情况。
常见的传入参数如下所示:
-
weight
:传入的是一个list或者tensor,其检索对应位置的值为该类的权重。注意,如果是GPU的环境下,则传入的值必须是tensor,并且其应该在GPU中。 -
reduction
:传入的是一个字符串,有三种形式可以选择,分别是mean
/sum
/none
,默认是mean
。mean
和sum
如字面意思所示,代表损失值取平均,损失值求和的形式。none
是计算每个位置对应的损失值,返回和label对应的形状。
更多参数解释如下图所示:
使用方法
CrossEntropyLoss传入的值为两个,分别是input
和target
。输出只有一个Output
。
-
input
的形状为 ( N , C ) / ( N , C , d 1 , d 1 , … ) (N,C)/(N,C,d_1,d_1,\ldots) (N,C)/(N,C,d1,d1,…),前者对应二维情况,后者对应高维情况,值得注意的是 C C C是在dim=1
的位置上,可能在高维的情况下很多人都以为默认应该是最后一个维度dim=-1
。 -
target
的形状为 ( N ) / ( N , d 1 , d 1 , … ) (N)/(N,d_1,d_1,\ldots) (N)/(N,d1,d1,…),前者对应二维情况,后者对应高维情况。注意的是target
的值对应的是类别对应的索引,不是one-hot的形式。 -
Output
的形状和target
的形状一致。
更多参数解释如下图所示:
二维情况下对应的5分类交叉熵损失计算(官网示例):
>>> # Example of target with class indices
>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.empty(3, dtype=torch.long).random_(5)
>>> output = loss(input, target)
>>> output.backward()
>>>
>>> # Example of target with class probabilities
>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randn(3, 5).softmax(dim=1)
>>> output = loss(input, target)
>>> output.backward()
高维情况下对应的交叉熵计算:
input = torch.randn(2,3,5,5,4)#最后一个维度对应的是类别
target = torch.empty(2,3,5,5, dtype=torch.long).random_(4) #四分类
loss_fn=CrossEntropyLoss(reduction='sum')
_input=torch.permute(input,dims=(0,-1,1,2,3))
loss=loss_fn(_input,target)#输入的类别一定是在dim=1的位置上
print(loss)
# 当然也可以将输入先转为2维的形式在计算,结果是一样的
_input=input.view(-1,4)
_target=target.view(-1)
loss=loss_fn(_input,_target)
print(loss)
内在原理
Pytorch中的CrossEntropyLoss()
是将logSoftmax()
和NLLLoss()
函数进行合并的,也就是说其内在实现就是基于logSoftmax()
和NLLLoss()
这两个函数。
input=torch.rand(3,5)
target=torch.empty(3,dtype=torch.long).random_(5)
loss_fn=CrossEntropyLoss(reduction='sum')
loss=loss_fn(input,target)
print(loss)
_input=torch.nn.LogSoftmax(dim=1)(input)
loss=torch.nn.NLLLoss(reduction='sum')(_input,target)
print(loss)
其实也就是和官网上所说的一样,CrossEntropyLoss()
是对输出计算softmax()
,在对结果取log()
对数,最后使用NLLLoss()
得到对应位置的索引值。
Focal Loss原理和实现
Focal Loss来自于论文Focal Loss for Dense Object Detection,用于解决类别样本不平衡以及困难样本挖掘的问题,其公式非常简洁:
F
L
(
p
t
)
=
−
α
t
(
1
−
p
t
)
γ
log
(
p
t
)
FL(p_t)=- \alpha_t (1-p_t) ^{\gamma} \log (p_t)
FL(pt)=−αt(1−pt)γlog(pt)
p
t
p_t
pt是模型预测的结果的类别概率值。
−
log
(
p
t
)
- \log (p_t)
−log(pt)和交叉熵损失函数一致,因此当前样本类别对应的那个
p
t
p_t
pt如果越小,说明预测越不准确, 那么
(
1
−
p
t
)
γ
(1-p_t)^{\gamma}
(1−pt)γ 这一项就会增大,这一项也作为困难样本的系数,预测越不准,Focal Loss越倾向于把这个样本当作困难样本,这个系数也就越大,目的是让困难样本对损失和梯度的贡献更大。
前面的 α t \alpha_t αt是类别权重系数。如果你有一个类别不平衡的数据集,那么你肯定想对数量少的那一类在loss贡献上赋予一个高权重,这个 α t \alpha_t αt就起到这样的作用。因此, α t \alpha_t αt应该是一个向量,向量的长度等于类别的个数,用于存放各个类别的权重。一般来说 α t \alpha_t αt中的值为每一个类别样本数量的倒数,相当于平衡样本的数量差距。
这里提供一个二维/高维的Focal Loss的实现:
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=torch.tensor([0.2, 0.3, 0.5,1])):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, input, target):
logpt = nn.functional.log_softmax(input, dim=1) #计算softmax后在计算log
pt = torch.exp(logpt) #对log_softmax去exp,把log取消就是概率
alpha=self.alpha[target].unsqueeze(dim=1) # 去取真实索引类别对应的alpha
logpt = alpha*(1 - pt) ** self.gamma * logpt #focal loss计算公式
loss = nn.functional.nll_loss(logpt, target,reduction='sum') # 最后选择对应位置的元素
return loss
参考资料
CrossEntropy官网详细说明。
Pytorch中的CrossEntropyLoss()函数案例解读和结合one-hot编码计算Loss
详解PyTorch实现多分类Focal Loss——带有alpha简洁实现
最近工作