导读
Dice Loss是由 Dice 系数而得名的,Dice系数是一种用于评估两个样本相似性的度量函数,其值越大意味着这两个样本越相似,Dice系数的数学表达式如下:
Dice
=
2
∣
X
∩
Y
∣
∣
X
∣
+
∣
Y
∣
\text { Dice }=\frac{2|X \cap Y|}{|X|+|Y|}
Dice =∣X∣+∣Y∣2∣X∩Y∣
其中, ∣ X ∩ Y ∣ |X \cap Y| ∣X∩Y∣ 表示 X 和 Y 之间交集元素的个数, ∣ X ∣ |X| ∣X∣ 和 ∣ Y ∣ |Y| ∣Y∣ 分别表示 X,Y 中元素的个数。Dice Loss 表达式如下:
DiceLoss = 1 − Dice = 1 − 2 ∣ X ∩ Y ∣ ∣ X ∣ + ∣ Y ∣ \text { DiceLoss }=1-\text { Dice }=1-\frac{2|X \cap Y|}{|X|+|Y|} DiceLoss =1− Dice =1−∣X∣+∣Y∣2∣X∩Y∣
语义分割下的dice loss
Dice Loss常用于语义分割问题中,计算公式不变,但是变量含义有所改变。
X
\mathrm{X}
X 表示真实分割图像的像素标签,
Y
\mathrm{Y}
Y 表示模型预测分割图像的像素类别。
Y
\mathrm{Y}
Y 只有 0,1 两个值,0表示像素点不是目标类,1表示像素点是目标类。
∣
X
∩
Y
∣
|X \cap Y|
∣X∩Y∣ 近似为预测图像的像素与真实标签图像的像素之间的点乘,并将点乘结果相加,
∣
X
∣
|X|
∣X∣ 和
∣
Y
∣
|Y|
∣Y∣ 分别近似为它们各自对应图像中的像素相加。故有公式:
DiceLoss
=
1
−
2
∑
i
=
1
N
y
i
y
^
i
∑
i
=
1
N
y
i
+
∑
i
=
1
N
y
i
^
\text { DiceLoss }=1-\frac{2 \sum_{i=1}^N y_i \hat{y}_i}{\sum_{i=1}^N y_i+\sum_{i=1}^N \hat{y_i}}
DiceLoss =1−∑i=1Nyi+∑i=1Nyi^2∑i=1Nyiy^i
注意,dice loss通常是不计算背景类的。
对于多分类问题,对 label 进行 one hot 编码,生成多个 label 图,每个类别对应一个二分类label图。通过计算每个类别的 Dice Loss 损失,最后再求均值即得到多分类的 Dice Loss 损失。
等价F1-score
假设有两个集合
A
A
A 和
B
B
B , Dice系数定义为:
Dice
(
A
,
B
)
=
2
∣
A
∩
B
∣
∣
A
∣
+
∣
B
∣
\operatorname{Dice}(A, B) =\frac{2|A \cap B|}{|A|+|B|}
Dice(A,B)=∣A∣+∣B∣2∣A∩B∣
A
∩
B
A \cap B
A∩B 表示预测结果与真实标签的交集,在二分类问题中等于正确预测为正类的数量 TP 。而
F
P
FP
FP 表示预测为正类但实际上是负类的数量 (属于A,但不属于B) ,
F
N
FN
FN 表示预测为负类但实际上是正类的数量 (属于B,但不属于A) ,故又有
2
∗
T
P
+
F
P
+
F
N
=
∣
A
∣
+
∣
B
∣
2* \mathrm{TP}+\mathrm{FP}+\mathrm{FN} = |A|+|B|
2∗TP+FP+FN=∣A∣+∣B∣
故有,
D
i
c
e
=
2
∗
T
P
2
∗
T
P
+
F
P
+
F
N
Dice = \frac{2 * TP}{2 * TP + FP + FN}
Dice=2∗TP+FP+FN2∗TP
,该公式正好等于 F1-score。
二分类例子
假设模型输出的预测值如下
标签 label 如下(0 即对应背景,表示不属于某一类,1 表示属于某一类):
计算类别1的dice:
∣
X
∩
Y
∣
=
[
0.5322
0.4932
0.1764
0.3107
0.5297
0.1604
0.3841
0.3537
0.3574
0.3323
0.8301
0.6436
]
⋆
[
0
0
0
0
0
0
1
1
1
1
1
1
]
=
[
0.0000
0.0000
0.0000
0.0000
0.0000
0.0000
0.3841
0.3537
0.3574
0.3323
0.8301
0.6436
]
→
2.9012
(求和)
\begin{aligned} |\mathrm{X} \cap \mathrm{Y}| & =\left[\begin{array}{lll} 0.5322 & 0.4932 & 0.1764 \\ 0.3107 & 0.5297 & 0.1604 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \star\left[\begin{array}{lll} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{array}\right] \\ \\ & =\left[\begin{array}{lll} 0.0000 & 0.0000 & 0.0000 \\ 0.0000 & 0.0000 & 0.0000 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \rightarrow 2.9012 \text { (求和) } \end{aligned}
∣X∩Y∣=
0.53220.31070.38410.33230.49320.52970.35370.83010.17640.16040.35740.6436
⋆
001100110011
=
0.00000.00000.38410.33230.00000.00000.35370.83010.00000.00000.35740.6436
→2.9012 (求和)
∣ X ∣ = [ 0.5322 0.4932 0.1764 0.3107 0.5297 0.1604 0.3841 0.3537 0.3574 0.3323 0.8301 0.6436 ] → 5.1038 |\mathrm{X}|=\left[\begin{array}{lll} 0.5322 & 0.4932 & 0.1764 \\ 0.3107 & 0.5297 & 0.1604 \\ 0.3841 & 0.3537 & 0.3574 \\ 0.3323 & 0.8301 & 0.6436 \end{array}\right] \rightarrow 5.1038 ∣X∣= 0.53220.31070.38410.33230.49320.52970.35370.83010.17640.16040.35740.6436 →5.1038
∣ Y ∣ = [ 0 0 0 0 0 0 1 1 1 1 1 1 ] → 6 |Y|=\left[\begin{array}{lll} 0 & 0 & 0 \\ 0 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & 1 & 1 \end{array}\right] \rightarrow 6 ∣Y∣= 001100110011 →6
所以 Dice 系数为
D
=
2
∗
∣
X
∩
Y
∣
+
1
∣
X
∣
+
∣
Y
∣
+
1
=
2
∗
2.9012
+
1
5.1038
+
6
+
1
=
0.5620
\mathrm{D}=\frac{2 *|\mathrm{X} \cap \mathrm{Y}|+1}{|\mathrm{X}|+|\mathrm{Y}|+1}=\frac{2 * 2.9012+1}{5.1038+6+1}=0.5620
D=∣X∣+∣Y∣+12∗∣X∩Y∣+1=5.1038+6+12∗2.9012+1=0.5620
所以 Dice 损失
L
=
1
−
D
=
0.4380
\mathrm{L}=1-\mathrm{D}=0.4380
L=1−D=0.4380 。
优点
Dice Loss 可以缓解样本中前景背景(面积)不平衡带来的消极影响,前景背景不平衡也就是说图像中大部分区域是不包含目标的,只有一小部分区域包含目标。Dice Loss训练更关注对前景区域的挖掘,即保证有较低的FN,但会存在损失饱和问题,而CE Loss是平等地计算每个像素点的损失,当前点的损失只和当前预测值与真实标签值的距离有关,这会导致一些问题(见Focal Loss)。因此单独使用Dice Loss往往并不能取得较好的结果,需要进行组合使用,比如Dice Loss+CE Loss或者Dice Loss+Focal Loss等。
代码
def Dice_loss(inputs, target, beta=1, smooth = 1e-5):
n, c, h, w = inputs.size()
nt, ht, wt, ct = target.size()
if h != ht and w != wt:
inputs = F.interpolate(inputs, size=(ht, wt), mode="bilinear", align_corners=True)
temp_inputs = torch.softmax(inputs.transpose(1, 2).transpose(2, 3).contiguous().view(n, -1, c),-1)
temp_target = target.view(n, -1, ct)
#--------------------------------------------#
# 计算dice loss
#--------------------------------------------#
tp = torch.sum(temp_target[...,:-1] * temp_inputs, axis=[0,1])
fp = torch.sum(temp_inputs , axis=[0,1]) - tp
fn = torch.sum(temp_target[...,:-1] , axis=[0,1]) - tp
score = ((1 + beta ** 2) * tp + smooth) / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
dice_loss = 1 - torch.mean(score)
return dice_loss