损失函数总结(八):MultiMarginLoss、MultiLabelMarginLoss
- 1 引言
- 2 损失函数
- 2.1 MultiMarginLoss
- 2.2 MultiLabelMarginLoss
- 3 总结
1 引言
在前面的文章中已经介绍了介绍了一系列损失函数 (L1Loss
、MSELoss
、BCELoss
、CrossEntropyLoss
、NLLLoss
、CTCLoss
、PoissonNLLLoss
、GaussianNLLLoss
、KLDivLoss
、BCEWithLogitsLoss
、MarginRankingLoss
、HingeEmbeddingLoss
)。在这篇文章中,会接着上文提到的众多损失函数继续进行介绍,给大家带来更多不常见的损失函数的介绍。这里放一张损失函数的机理图:
2 损失函数
2.1 MultiMarginLoss
MultiMarginLoss 是一种损失函数,通常用于多分类问题
,其中每个样本只能属于一个类别
。这个损失函数的主要目标是鼓励模型将正确类别的得分与错误类别的得分之间的间隔(差距)最小化
。通常,这个损失函数被用于训练神经网络模型
,以确保正确的类别获得高的分数,而错误的类别获得低的分数。MultiMarginLoss 的数学表达式如下:
L
o
s
s
(
x
,
y
)
=
∑
i
w
[
y
]
∗
m
a
x
(
0
,
m
a
r
g
i
n
−
x
[
y
]
+
x
[
i
]
)
p
x
.
s
i
z
e
(
0
)
Loss(x, y)=\frac{\sum_iw[y]*max(0, margin-x[y]+x[i])^p}{x.size(0)}
Loss(x,y)=x.size(0)∑iw[y]∗max(0,margin−x[y]+x[i])p
其中:
- p p p: 默认值为1,仅可选1或者2。
- m a r g i n margin margin: 默认值为1.
-
w
[
y
]
w[y]
w[y]: 为各类别的weight。weight必须是float类型的tensor,其长度要于
类别C
一致,即每一个类别都要设置有weight
。
代码实现(Pytorch):
loss = nn.MultiMarginLoss()
x = torch.tensor([[0.1, 0.2, 0.4, 0.8]])
y = torch.tensor([3])
# 0.25 * ((1-(0.8-0.1)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
loss(x, y)
在siamese net或者Triplet net任务中被广泛使用。。。。
2.2 MultiLabelMarginLoss
MultiLabelMarginLoss 是一种损失函数,通常用于多标签分类
问题,其中每个样本
可以属于多个类别
。它有助于训练模型以将样本正确分类到其相关类别
,并在训练中惩罚不正确的分类
。MultiLabelMarginLoss 的数学表达式如下:
L
o
s
s
(
x
,
y
)
=
∑
i
j
m
a
x
(
0
,
1
−
(
x
[
y
[
j
]
]
−
x
[
i
]
)
)
x
.
s
i
z
e
(
0
)
Loss(x, y) = \sum_{ij}\frac{max(0,1-(x[y[j]] - x[i]))}{x.size(0)}
Loss(x,y)=ij∑x.size(0)max(0,1−(x[y[j]]−x[i]))
其中:
- x [ y [ j ] ] x[y[j]] x[y[j]]: 表示 样本x所属类的输出值。
- x [ i ] x[i] x[i]: 表示不等于该类的输出值。 并且,对于所有的 i i i 和 j j j, i ≠ y [ j ] i\neq y[j] i=y[j]。
代码实现(Pytorch):
loss = nn.MultiLabelMarginLoss()
x = torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]])
# for target y, only consider labels 3 and 0, not after label -1
y = torch.LongTensor([[3, 0, -1, 1]])
# 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))
loss(x, y)
MultiLabelMarginLoss 是 MultiMarginLoss的多标签版本。。。。
3 总结
到此,使用 损失函数总结(八) 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。如果存在没有提及的损失函数
也可以在评论区提出,后续会对其进行添加!!!!
如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。