This post walks through the TOIM loss function from AlignPS, a CVPR 2021 paper on person search.
Paper link: https://arxiv.org/abs/2109.00211
Code link: GitHub - daodaofr/AlignPS: Code for CVPR 2021 paper: Anchor-Free Person Search
TOIM
TOIM Loss = OIM Loss + Triplet Loss
OIM Loss
Step 1. Initialize two storage structures: a look-up table (LUT) that stores the features of labeled persons, and a circular queue that stores the features of unlabeled persons.
self.labeled_matching_layer = LabeledMatchingLayerQueue(num_persons=num_person, feat_len=self.in_channels)
self.unlabeled_matching_layer = UnlabeledMatchingLayer(queue_size=queue_size, feat_len=self.in_channels)
# Stores the embeddings of labeled (identity-annotated) persons
class LabeledMatchingLayerQueue(nn.Module):
    """
    Labeled matching of OIM loss function.
    """

    def __init__(self, num_persons=5532, feat_len=256):
        """
        Args:
            num_persons (int): Number of labeled persons.
            feat_len (int): Length of the feature extracted by the network.
        """
        super(LabeledMatchingLayerQueue, self).__init__()
        self.register_buffer("lookup_table", torch.zeros(num_persons, feat_len))

    def forward(self, features, pid_labels):
        """
        Args:
            features (Tensor[N, feat_len]): Features of the proposals.
            pid_labels (Tensor[N]): Ground-truth person IDs of the proposals.

        Returns:
            scores (Tensor[N, num_persons]): Labeled matching scores, namely the similarities
                between proposals and labeled persons.
        """
        scores, pos_feats, pos_pids = LabeledMatching.apply(features, pid_labels, self.lookup_table)
        return scores, pos_feats, pos_pids
# Stores the embeddings of unlabeled persons
class UnlabeledMatchingLayer(nn.Module):
    """
    Unlabeled matching of OIM loss function.
    """

    def __init__(self, queue_size=5000, feat_len=256):
        """
        Args:
            queue_size (int): Size of the queue saving the features of unlabeled persons.
            feat_len (int): Length of the feature extracted by the network.
        """
        super(UnlabeledMatchingLayer, self).__init__()
        self.register_buffer("queue", torch.zeros(queue_size, feat_len))
        self.register_buffer("tail", torch.tensor(0))

    def forward(self, features, pid_labels):
        """
        Args:
            features (Tensor[N, feat_len]): Features of the proposals.
            pid_labels (Tensor[N]): Ground-truth person IDs of the proposals.

        Returns:
            scores (Tensor[N, queue_size]): Unlabeled matching scores, namely the similarities
                between proposals and unlabeled persons.
        """
        scores = UnlabeledMatching.apply(features, pid_labels, self.queue, self.tail)
        return scores
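As a quick sanity check, both layers can be exercised on random inputs; the sizes below (and the choice of -1 for unlabeled proposals) are illustrative assumptions, not values from the AlignPS config:

import torch

feat_len = 256
labeled_layer = LabeledMatchingLayerQueue(num_persons=10, feat_len=feat_len)
unlabeled_layer = UnlabeledMatchingLayer(queue_size=20, feat_len=feat_len)

feats = torch.randn(4, feat_len)      # 4 proposal embeddings
pids = torch.tensor([1, 3, -1, -1])   # -1 marks unlabeled persons

scores, pos_feats, pos_pids = labeled_layer(feats, pids)  # scores: [4, 10]
u_scores = unlabeled_layer(feats, pids)                   # u_scores: [4, 20]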
Step 2. Matrix-multiply the embeddings with the transposes of the LUT and of the queue, respectively, obtaining (labeled_matching_scores, labeled_matching_reid, labeled_matching_ids) and (unlabeled_matching_scores):
labeled_matching_scores, labeled_matching_reid, labeled_matching_ids = self.labeled_matching_layer(pos_reid, pos_reid_ids)
class LabeledMatching(Function):
    @staticmethod
    def forward(ctx, features, pid_labels, lookup_table, momentum=0.5):
        ctx.save_for_backward(features, pid_labels)
        ctx.lookup_table = lookup_table
        ctx.momentum = momentum
        scores = features.mm(lookup_table.t())
        pos_feats = lookup_table.clone().detach()
        pos_idx = pid_labels > 0
        pos_pids = pid_labels[pos_idx]
        pos_feats = pos_feats[pos_pids]
        return scores, pos_feats, pos_pids

    @staticmethod
    def backward(ctx, grad_output, grad_feat, grad_pids):
        features, pid_labels = ctx.saved_tensors
        lookup_table = ctx.lookup_table
        momentum = ctx.momentum
        grad_feats = None
        if ctx.needs_input_grad[0]:
            grad_feats = grad_output.mm(lookup_table)
        # Update lookup table, but not by standard backpropagation with gradients
        for indx, label in enumerate(pid_labels):
            if label >= 0:
                lookup_table[label] = (
                    momentum * lookup_table[label] + (1 - momentum) * features[indx]
                )
        return grad_feats, None, None, None
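A note on the backward pass above: the forward computes scores $= x V^\top$, where $V$ is the LUT, so the gradient with respect to the features is

$$\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial \mathrm{scores}} \cdot V,$$

i.e., grad_output.mm(lookup_table), the only gradient backward returns; the LUT itself is refreshed by the momentum rule in the loop rather than by backpropagated gradients.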
unlabeled_matching_scores = self.unlabeled_matching_layer(pos_reid, pos_reid_ids)
class UnlabeledMatching(Function):
    @staticmethod
    def forward(ctx, features, pid_labels, queue, tail):
        ctx.save_for_backward(features, pid_labels)
        ctx.queue = queue
        ctx.tail = tail
        scores = features.mm(queue.t())
        return scores

    @staticmethod
    def backward(ctx, grad_output):
        features, pid_labels = ctx.saved_tensors
        queue = ctx.queue
        tail = ctx.tail
        grad_feats = None
        if ctx.needs_input_grad[0]:
            grad_feats = grad_output.mm(queue.data)
        # Only the first 64 dimensions of each unlabeled person's feature are stored.
        # When more unlabeled persons arrive than queue_size, the queue is updated with
        # push/pop-like (circular) writes so its size stays at queue_size.
        for indx, label in enumerate(pid_labels):
            if label == -1:
                queue[tail, :64] = features[indx, :64]
                tail += 1
                if tail >= queue.size(0):
                    tail -= queue.size(0)
        return grad_feats, None, None, None
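A minimal standalone sketch of the circular write pattern above, with a made-up queue_size of 3:

import torch

queue = torch.zeros(3, 64)  # queue_size = 3
tail = torch.tensor(0)
for step in range(5):
    queue[tail, :64] = torch.full((64,), float(step))
    tail += 1
    if tail >= queue.size(0):
        tail -= queue.size(0)
# After 5 writes, the rows hold the features from steps 3, 4, and 2:
# the oldest entries (steps 0 and 1) were overwritten in place.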
Step 3. Multiply the labeled_matching_scores and unlabeled_matching_scores from Step 2 by 10 each (a softmax temperature of τ = 0.1), then concatenate them along dim=1 to get matching_scores. Applying softmax to matching_scores gives p_i, corresponding to the following formula in the paper:
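With $x$ a proposal embedding, $v_j$ the LUT entries, $u_k$ the queue entries, $L$ the number of labeled identities, $Q$ the queue size, and $\tau = 0.1$ (reconstructed here from the OIM formulation):

$$p_i = \frac{\exp(v_i^\top x / \tau)}{\sum_{j=1}^{L} \exp(v_j^\top x / \tau) + \sum_{k=1}^{Q} \exp(u_k^\top x / \tau)}$$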
labeled_matching_scores *= 10
unlabeled_matching_scores *= 10
matching_scores = torch.cat((labeled_matching_scores, unlabeled_matching_scores), dim=1)
p_i = F.softmax(matching_scores, dim=1)
Weight p_i in a focal-loss-like fashion, assigning larger weighting factors to smaller p_i (for example, an easy sample with p_i = 0.9 gets weight (1 − 0.9)² = 0.01, while a hard one with p_i = 0.1 gets 0.81), which gives focal_p_i:
focal_p_i = (1 - p_i)**2 * p_i.log()
Step 4. Feed focal_p_i (whose log term is already applied) together with the corresponding labels into a negative log-likelihood loss to obtain the OIM Loss; proposals labeled -1 are ignored:
loss_oim = F.nll_loss(focal_p_i, pid_labels, reduction='none', ignore_index=-1)
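Putting Steps 3 and 4 together, a minimal self-contained sketch (sizes are made up; -1 in pid_labels marks unlabeled proposals) looks like this:

import torch
import torch.nn.functional as F

N, L, Q = 4, 10, 20                      # proposals, labeled IDs, queue size (made up)
labeled_matching_scores = torch.randn(N, L)
unlabeled_matching_scores = torch.randn(N, Q)
pid_labels = torch.tensor([0, 3, -1, 2])

labeled_matching_scores *= 10            # temperature tau = 0.1
unlabeled_matching_scores *= 10
matching_scores = torch.cat((labeled_matching_scores, unlabeled_matching_scores), dim=1)
p_i = F.softmax(matching_scores, dim=1)
focal_p_i = (1 - p_i) ** 2 * p_i.log()   # focal-style down-weighting of easy samples
loss_oim = F.nll_loss(focal_p_i, pid_labels, reduction='none', ignore_index=-1)
print(loss_oim.shape)                    # torch.Size([4]); entry 2 is zeroed (label -1)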
Step 5. During backpropagation, the LUT that stores labeled person features is updated as follows:
lookup_table[label] = (momentum * lookup_table[label] + (1 - momentum) * features[indx])
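In equation form, with momentum $\mu = 0.5$ by default, $v_{pid}$ the LUT entry of the proposal's identity, and $x$ the proposal feature:

$$v_{pid} \leftarrow \mu \, v_{pid} + (1 - \mu) \, x$$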
Triplet Loss
Step 1. Concatenate labeled_matching_reid and labeled_matching_ids, obtained while computing the OIM Loss, with pos_reid and pid_labels respectively. This effectively enlarges the batch, letting the triplet loss mine hard sample pairs over a larger sample space:
pos_reid = torch.cat((pos_reid, labeled_matching_reid), dim=0)
pid_labels = torch.cat((pid_labels, labeled_matching_ids), dim=0)
Step 2. Compute the Triplet Loss from pos_reid and pid_labels:
loss_tri = self.loss_tri(pos_reid, pid_labels)
class TripletLossFilter(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
        Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.
    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
    """

    def __init__(self, margin=0.3):
        super(TripletLossFilter, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Does not calculate noise inputs with label -1.

        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (batch_size)
        """
        # Filter out noise proposals labeled -1
        inputs_new = []
        targets_new = []
        targets_value = []
        for i in range(len(targets)):
            if targets[i] == -1:
                continue
            inputs_new.append(inputs[i])
            targets_new.append(targets[i])
            targets_value.append(targets[i].item())
        # Need at least two distinct identities to form a triplet
        if len(set(targets_value)) < 2:
            return torch.zeros(1, device=targets.device)[0]
        inputs_new = torch.stack(inputs_new)
        targets_new = torch.stack(targets_new)
        n = inputs_new.size(0)
        # Compute pairwise Euclidean distance matrix
        dist = torch.pow(inputs_new, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(inputs_new, inputs_new.t(), beta=1, alpha=-2)
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
        # For each anchor, find the hardest positive and negative
        mask = targets_new.expand(n, n).eq(targets_new.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max())
            dist_an.append(dist[i][mask[i] == 0].min())
        dist_ap = torch.stack(dist_ap)
        dist_an = torch.stack(dist_an)
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss
As a supplement, torch.nn.MarginRankingLoss(margin=margin) implements

$$\mathrm{loss}(x_1, x_2, y) = \max\big(0, -y \cdot (x_1 - x_2) + \mathrm{margin}\big)$$

In the code above, $x_1$ is dist_an, $x_2$ is dist_ap, and $y = 1$, so the loss reduces to $\max(0, \mathrm{dist\_ap} - \mathrm{dist\_an} + \mathrm{margin})$: each anchor's hardest negative is pushed to be at least margin farther away than its hardest positive.
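A quick usage sketch of TripletLossFilter (shapes and labels are made up for illustration):

import torch

criterion = TripletLossFilter(margin=0.3)
feats = torch.randn(6, 256, requires_grad=True)  # 6 proposal embeddings
labels = torch.tensor([0, 0, 1, 1, -1, 2])       # the -1 entry is filtered out
loss = criterion(feats, labels)
loss.backward()  # gradients flow only through the kept proposals
print(loss.item())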