1 介绍
- 创建一个衡量三元组损失的标准,给定输入张量 x1、x2 和 x3 以及一个大于0的间距值。
- 这用于测量样本之间的相对相似性。一个三元组由a、p和n组成(锚点、正例和负例)。所
- 有输入张量的形状都应为 (N,D)
2 基本使用方法
torch.nn.TripletMarginLoss(
margin=1.0,
p=2.0,
eps=1e-06,
swap=False,
size_average=None,
reduce=None,
reduction='mean')
3 参数
margin(float,可选) | 默认为1 |
p(int,可选) | 用于成对距离的范数度。默认为2 |
reduction(str,可选) | none/mean/sum,默认是mean |
4 举例
import torch
import torch.nn as nn
anchor = torch.tensor([0.5, -0.5, 0.1], requires_grad=True)
pos = torch.tensor([0.7, 0.2, 0.1])
neg= torch.tensor([0.8, 0.9, 0.2])
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=1)
triplet_loss(anchor,pos,neg)
#tensor(0.1000, grad_fn=<MeanBackward0>)
'''
(0.2-0.3)+(0.7-1.4)+(0-0.1)+1=0.1000
''
import torch
import torch.nn as nn
anchor = torch.tensor([0.5, -0.5, 0.1], requires_grad=True)
pos = torch.tensor([0.7, 0.2, 0.1])
neg= torch.tensor([0.8, 0.9, 0.2])
triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
triplet_loss(anchor,pos,neg)
#tensor(0.2927, grad_fn=<MeanBackward0>)
'''
np.sqrt((0.2-0.3)**2+(0.7-1.4)**2+(0-0.1)**2)
'''