损失函数总结(十):TripletMarginLoss、TripletMarginWithDistanceLoss
- 1 引言
- 2 损失函数
- 2.1 TripletMarginLoss
- 2.2 TripletMarginWithDistanceLoss
- 3 总结
1 引言
在前面的文章中已经介绍了介绍了一系列损失函数 (L1Loss
、MSELoss
、BCELoss
、CrossEntropyLoss
、NLLLoss
、CTCLoss
、PoissonNLLLoss
、GaussianNLLLoss
、KLDivLoss
、BCEWithLogitsLoss
、MarginRankingLoss
、HingeEmbeddingLoss
、MultiMarginLoss
、MultiLabelMarginLoss
、SoftMarginLoss
、MultiLabelSoftMarginLoss
)。在这篇文章中,会接着上文提到的众多损失函数继续进行介绍,给大家带来更多不常见的损失函数的介绍。这里放一张损失函数的机理图:
2 损失函数
2.1 TripletMarginLoss
论文链接:Learning local feature descriptors with triplets and shallow convolutional neural networks
TripletMarginLoss 是一种损失函数,不同于交叉熵损失仅仅考虑
样本与类别标签之间误差,TripletMarginLoss 关注样本与其他样本
之间距离。在输入张量 x1、x2、x3 和边际值大于 0
的情况下,创建一个用于测量三元组损失的标准
。一个三元组由 a、p 和 n(即分别为锚
、正样本
和负样本
)组成。所有输入张量的形状应为 (N,D)
。小批量
中每个样品
的损失函数为::
L
(
a
,
p
,
n
)
=
m
a
x
{
d
(
a
i
,
p
i
)
−
d
(
a
i
,
n
i
)
+
m
a
r
g
i
n
,
0
}
L(a,p,n)=max\{d(a_i,p_i)−d(a_i,n_i)+margin,0\}
L(a,p,n)=max{d(ai,pi)−d(ai,ni)+margin,0}
其中:
d
(
x
i
,
y
i
)
=
∥
x
i
−
y
i
∥
p
d(x_i,y_i)=∥x_i−y_i∥_p
d(xi,yi)=∥xi−yi∥p
- 损失函数目标:最小化损失函数,使得锚点与正例的距离越小,与负例的距离越大。
- m a r g i n margin margin: 人为设置的常数。
代码实现(Pytorch):
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2, eps=1e-7)
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)
output = triplet_loss(anchor, positive, negative)
output.backward()
在siamese net或者Triplet net任务中被广泛使用。。。。
2.2 TripletMarginWithDistanceLoss
TripletMarginWithDistanceLoss 函数与 TripletMarginLoss功能基本一致
,只不过可以定制化
的传入不同的距离函数。当传入的距离函数是torch.nn.PairwiseDistance
时,两者完全一致。
使用自定义损失函数,代码实现(Pytorch):
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(20)
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
anchor = torch.randn(100, 128, requires_grad=True)
positive = torch.randn(100, 128, requires_grad=True)
negative = torch.randn(100, 128, requires_grad=True)
# Custom Distance Function
def l_infinity(x1, x2):
return torch.max(torch.abs(x1 - x2), dim=1).values
triplet_loss = nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5)
output = triplet_loss(anchor, positive, negative)
print(output.item())
# Custom Distance Function (Lambda)
triplet_loss = nn.TripletMarginWithDistanceLoss(
distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y))
output = triplet_loss(anchor, positive, negative)
print(output.item())
输出结果:
1.529929518699646
1.0007251501083374
在siamese net或者Triplet net任务中被广泛使用。。。。
3 总结
到此,使用 损失函数总结(十) 已经介绍完毕了!!! 如果有什么疑问欢迎在评论区提出,对于共性问题可能会后续添加到文章介绍中。如果存在没有提及的损失函数
也可以在评论区提出,后续会对其进行添加!!!!
如果觉得这篇文章对你有用,记得点赞、收藏并分享给你的小伙伴们哦😄。