目录
1. 交叉熵损失 CrossEntropyLoss
2. ignore_index 参数
3. weight 参数
4. 例子
1. 交叉熵损失 CrossEntropyLoss
CrossEntropyLoss 交叉熵损失可函数以用于分类或者分割任务中,这里主要介绍分割任务
建立如下的数据,pred是预测样本,label是真实标签
分割中,使用交叉熵损失的话,需要保证label的维度比pred维度少1,也就是没有channel维度。并且,label的类型是int
正常计算损失结果为:
手动计算一下,pred的softmax为
所以,loss = -(ln0.69+ln0.3543+ln0.5987)/3 = -(ln0.1464) / 3 = 0.6406
后面的是计算产生的误差,这里用数学方法简化计算了
one-hot 编码,只计算label的 ln 预测值
2. ignore_index 参数
在分割任务中,经常有像素点是认为不感兴趣的,所以这里ignore_index可以将那些不感兴趣的像素点排除
import torch
import torch.nn as nn
import torch.nn.functional as F
pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]]) # 预测值 size = 3*2, dtype = torch.float32
label = torch.LongTensor([0, 1, 0]) # 真实值 size = 3 , dtype = torch.int64
loss = nn.CrossEntropyLoss(ignore_index=1)
out = loss(pred,label)
print(out) # tensor(0.4421)
这里将label = 1的像素点排除,手动计算一下
loss = (-ln0.69-ln0.5987) / 2 = 0.4421
这里将label = 1的忽略了,下面是pred的softmax值
3. weight 参数
当涉及到样本的个数不平衡的时候,可以将样本少的label,w加大点
import torch
import torch.nn as nn
import torch.nn.functional as F
pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]]) # 预测值 size = 3*2, dtype = torch.float32
label = torch.LongTensor([0, 1, 0]) # 真实值 size = 3 , dtype = torch.int64
w = torch.FloatTensor([1,2])
loss = nn.CrossEntropyLoss(weight=w)
out = loss(pred,label)
print(out) # tensor(0.7398)
计算方法是:
loss =- ( 1*ln0.69 + 2*ln0.3543+1*ln0.5987) / 4 = (0.3711 + 2.0741+ 0.5130) / 4= 0.7396
可以发现答案是类似的,这里保留了四位小数进行计算,所以有误差
因为,label = 1有一个,label = 0 有两个,所以1的样本较少,这里就对label = 1设置权重大点。可以发现,计算出来的loss确实比不加loss的大,下图为不加w的
如果将w改成[2,1]的话,loss会更低,不利于loss的下降
所以,在样本不均衡的情况下,加label少的样本,w加大,可以将loss变大,从而梯度下降的时候可以更好的弥补样本不平衡的问题
注意:w的类型是float
4. 例子
测试代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
pred = torch.Tensor([[0.9, 0.1,0.2],[0.8, 0.2,0.1],[0.7, 0.3,0.5],[0.1,0.5,0.6]])
label = torch.LongTensor([2, 1, 0,1])
s = F.softmax(pred,dim=1)
print(s)
w = torch.FloatTensor([2,1,2])
loss = nn.CrossEntropyLoss(weight=w,ignore_index=2)
out = loss(pred,label)
print(out) # tensor(1.0401)
其中,pred的softmax如下:
label 为:2 1 0 1
可以发现,label 是 0 1 2 三类,这里将label = 2的忽略,并且对0 1 2施加的权重为 2 1 2
所以手动计算的公式为,这里精确到六位小数
label = 0 的损失 = - ln0.4018 = 0.911801
label = 1 的损失 = (- ln0.2683 - ln0.3603 ) / 2 = (1.315650 + 1.020818)/2 = 1.168234
label = 2 的损失 = - ln0.2552 = 1.365708
这里忽略了label = 2,所以还剩:
label = 0 的损失 = - ln0.4018 = 0.911801
label = 1 的损失 = (- ln0.2683 - ln0.3603 ) / 2 = (1.315650 + 1.020818)/2 = 1.168234
并且对0 1 进行加权2 1
所以总的loss = (0.911801 *2 + 1.315650*1+1.020818*1) /(2+1+1) = 4.16007/4=1.0400175
可以发现结果是一样的,这里最后是精度问题