AlignPS中的TOIM损失

news2024/11/27 13:45:55

本文介绍了CVPR2021行人重识别领域中一篇名为AlignPS论文中的TOIM损失函数

论文链接:https://arxiv.org/abs/2109.00211

代码链接:GitHub - daodaofr/AlignPS: Code for CVPR 2021 paper: Anchor-Free Person Search

TOIM

TOIM Loss = OIM Loss + Triplet Loss

OIM Loss

步骤一、初始化两个查找表(Looking-Up Tabel,LUT),第一个用于存放有标注的行人特征,第二个用于存放无标注的行人特征,

self.labeled_matching_layer = LabeledMatchingLayerQueue(num_persons=num_person, feat_len=self.in_channels)
self.unlabeled_matching_layer = UnlabeledMatchingLayer(queue_size=queue_size, feat_len=self.in_channels)


# 用于存放有label匹配的embeddings
class LabeledMatchingLayerQueue(nn.Module):
    """
    Labeled matching of OIM loss function.
    """

    def __init__(self, num_persons=5532, feat_len=256):
        """
        Args:
            num_persons (int): Number of labeled persons.
            feat_len (int): Length of the feature extracted by the network.
        """
        super(LabeledMatchingLayerQueue, self).__init__()
        self.register_buffer("lookup_table", torch.zeros(num_persons, feat_len))

    def forward(self, features, pid_labels):
        """
        Args:
            features (Tensor[N, feat_len]): Features of the proposals.
            pid_labels (Tensor[N]): Ground-truth person IDs of the proposals.

        Returns:
            scores (Tensor[N, num_persons]): Labeled matching scores, namely the similarities
                                             between proposals and labeled persons.
        """
        scores, pos_feats, pos_pids = LabeledMatching.apply(features, pid_labels, self.lookup_table)
        return scores, pos_feats, pos_pids


# 用于存放无label匹配的embeddings
class UnlabeledMatchingLayer(nn.Module):
    """
    Unlabeled matching of OIM loss function.
    """

    def __init__(self, queue_size=5000, feat_len=256):
        """
        Args:
            queue_size (int): Size of the queue saving the features of unlabeled persons.
            feat_len (int): Length of the feature extracted by the network.
        """
        super(UnlabeledMatchingLayer, self).__init__()
        self.register_buffer("queue", torch.zeros(queue_size, feat_len))
        self.register_buffer("tail", torch.tensor(0))

    def forward(self, features, pid_labels):
        """
        Args:
            features (Tensor[N, feat_len]): Features of the proposals.
            pid_labels (Tensor[N]): Ground-truth person IDs of the proposals.

        Returns:
            scores (Tensor[N, queue_size]): Unlabeled matching scores, namely the similarities
                                            between proposals and unlabeled persons.
        """
        scores = UnlabeledMatching.apply(features, pid_labels, self.queue, self.tail)
        return scores

步骤二、将embeddings分别与两个LUT的转置进行矩阵乘法操作,得到(labeled_matching_scores, labeled_matching_reid, labeled_matching_ids)以及(unlabeled_matching_scores)

labeled_matching_scores, labeled_matching_reid, labeled_matching_ids = self.labeled_matching_layer(pos_reid, pos_reid_ids)


class LabeledMatching(Function):
    @staticmethod
    def forward(ctx, features, pid_labels, lookup_table, momentum=0.5):
        ctx.save_for_backward(features, pid_labels)
        ctx.lookup_table = lookup_table
        ctx.momentum = momentum

        scores = features.mm(lookup_table.t())
        pos_feats = lookup_table.clone().detach()
        pos_idx = pid_labels > 0
        pos_pids = pid_labels[pos_idx]
        pos_feats = pos_feats[pos_pids]
        
        return scores, pos_feats, pos_pids

    @staticmethod
    def backward(ctx, grad_output, grad_feat, grad_pids):
        features, pid_labels = ctx.saved_tensors
        lookup_table = ctx.lookup_table
        momentum = ctx.momentum

        grad_feats = None
        if ctx.needs_input_grad[0]:
            grad_feats = grad_output.mm(lookup_table)

        # Update lookup table, but not by standard backpropagation with gradients
        for indx, label in enumerate(pid_labels):
            if label >= 0:
                lookup_table[label] = (
                    momentum * lookup_table[label] + (1 - momentum) * features[indx]
                )

        return grad_feats, None, None, None
unlabeled_matching_scores = self.unlabeled_matching_layer(pos_reid, pos_reid_ids)


class UnlabeledMatching(Function):
    @staticmethod
    def forward(ctx, features, pid_labels, queue, tail):
        ctx.save_for_backward(features, pid_labels)
        ctx.queue = queue
        ctx.tail = tail

        scores = features.mm(queue.t())
        return scores

    @staticmethod
    def backward(ctx, grad_output):
        features, pid_labels = ctx.saved_tensors
        queue = ctx.queue
        tail = ctx.tail

        grad_feats = None
        if ctx.needs_input_grad[0]:
            grad_feats = grad_output.mm(queue.data)

        """
        只将无label行人的前64维特征进行存储, 如果存储的无label行人数量大于queue_size 
        则对queue进行类似push和pop操作, 使queue的大小维持在queue_size
        """
        for indx, label in enumerate(pid_labels):
            if label == -1:
                queue[tail, :64] = features[indx, :64]
                tail += 1
                if tail >= queue.size(0):
                    tail -= queue.size(0)

        return grad_feats, None, None, None

步骤三、将步骤二得到的labeled_matching_scores和unlabeled_matching_scores分别乘以10后,沿着dim=1进行concat,得到matching_scores。对matching_scores进行softmax处理,得到p_i,对应论文中的公式如下,

labeled_matching_scores *= 10
unlabeled_matching_scores *= 10
matching_scores = torch.cat((labeled_matching_scores, unlabeled_matching_scores), dim=1)
p_i = F.softmax(matching_scores, dim=1)

 根据p_i的大小,对p_i进行加权处理(类似focal loss),把较大的权重因子给到较小的p_i,得到focal_p_i,

focal_p_i = (1 - p_i)**2 * p_i.log()

步骤四、对focal_p_i以及对应的label求负对数似然,便可得到OIM Loss

loss_oim = F.nll_loss(focal_p_i, pid_labels, reduction='none', ignore_index=-1)

步骤五、反向传播时,会对存放有label行人特征的LUT进行更新,更新的方式如下,

lookup_table[label] = (momentum * lookup_table[label] + (1 - momentum) * features[indx])

Triplet Loss

步骤一、将求OIM Loss过程中得到的labeled_matching_reid和labeled_matching_ids分别与pos_reid和pid_labels进行concat(相当于扩大了batch size,让triplet loss在更大的样本空间中寻找困难样本对),

pos_reid = torch.cat((pos_reid, labeled_matching_reid), dim=0)
pid_labels = torch.cat((pid_labels, labeled_matching_ids), dim=0)

步骤二、根据pos_reid和pid_labels求得Triplet Loss,

 

loss_tri = self.loss_tri(pos_reid, pid_labels)


class TripletLossFilter(nn.Module):
    """Triplet loss with hard positive/negative mining.

    Reference:
    Hermans et al. In Defense of the Triplet Loss for Person Re-Identification. arXiv:1703.07737.

    Code imported from https://github.com/Cysu/open-reid/blob/master/reid/loss/triplet.py.

    Args:
        margin (float): margin for triplet.
    """
    def __init__(self, margin=0.3):
        super(TripletLossFilter, self).__init__()
        self.margin = margin
        self.ranking_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, inputs, targets):
        """
        Does not calculate noise inputs with label -1
        Args:
            inputs: feature matrix with shape (batch_size, feat_dim)
            targets: ground truth labels with shape (num_classes)
        """
        inputs_new = []
        targets_new = []
        targets_value = []
        for i in range(len(targets)):
            if targets[i] == -1:
                continue
            else:
                inputs_new.append(inputs[i])
                targets_new.append(targets[i])
                targets_value.append(targets[i].cpu().numpy().item())
        if len(set(targets_value)) < 2:
            tmp_loss = torch.zeros(1)
            tmp_loss = tmp_loss[0]
            tmp_loss = tmp_loss.to(targets.device)
            return tmp_loss
        
        inputs_new = torch.stack(inputs_new)
        targets_new = torch.stack(targets_new)
        n = inputs_new.size(0)

        # Compute pairwise distance, replace by the official when merged
        dist = torch.pow(inputs_new, 2).sum(dim=1, keepdim=True).expand(n, n)
        dist = dist + dist.t()
        dist.addmm_(1, -2, inputs_new, inputs_new.t())
        dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

        # For each anchor, find the hardest positive and negative
        mask = targets_new.expand(n, n).eq(targets_new.expand(n, n).t())
        dist_ap, dist_an = [], []
        for i in range(n):
            dist_ap.append(dist[i][mask[i]].max())
            dist_an.append(dist[i][mask[i] == 0].min())

        dist_ap = torch.stack(dist_ap)
        dist_an = torch.stack(dist_an)
        # Compute ranking hinge loss
        y = torch.ones_like(dist_an)
        loss = self.ranking_loss(dist_an, dist_ap, y)
        return loss

补充一下,torch.nn.MarginRankingLoss(margin=margin)的公式如下,

对应到以上代码中,

Loss(d_{an},d_{ap},y)=max(0,d_{ap}-d_{an}+margin) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/764240.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【STM32零基础入门教程02】STM32环境获取及搭建

(14条消息) 【STM32零基础入门教程01】STM32入门基础知识_小超电子笔记的博客-CSDN博客 工欲善其事必先利其器&#xff0c;在上一章了解完STM32的一些基础知识之后&#xff0c;我们需要对使用的环境进行获取和安装。 一、MDK&#xff08;KEIL&#xff09;的获取 Keil是一种…

CCF 202209-2 何以包邮? (01背包动态规划练习)

一、先温习一下01背包问题 有N件物品和一个容量为V的背包。第i件物品的体积是c[i]&#xff0c;价值是w[i]。求解将哪些物品装入背包可使价值总和最大。 条件汇总 -------- 背包限制容量&#xff1a;Z 此时背包容量&#xff1a;C 物品&#xff1a;1 , i ... …

WSA - root,frida与ida测试

本文旨在配置windows subsystem for android(win安卓子系统)来作为win在开启了hyper-v的情况下的一种轻量的安卓模拟器方案。使用MagiskOnWsa设置root权限&#xff0c;最终使其正常与开发环境、frida、ida打通。 1. Root的WSA 常用的Wsa版本在目前是没有默认root的。在物理机…

matlab滤波器设计-IIR滤波器的设计与仿真

matlab滤波器设计-IIR滤波器的设计与仿真 1 引言 在现代通信系统中&#xff0c;由于信号中经常混有各种复杂成分&#xff0c;所以很多信号的处理和分析都是基于滤波器而进行的。但是&#xff0c;传统的数字滤波器的设计使用繁琐的公式计算&#xff0c;改变参数后需要重新计…

【Matlab】智能优化算法_猎豹优化算法CO)

【Matlab】智能优化算法_猎豹优化算法CO 1.背景介绍2.数学模型2.1 搜索策略2.2 坐等策略2.3 攻击策略2.4 假设 3.文件结构4.伪代码5.详细代码及注释5.1 CO.m5.2 CO_VectorBased.m5.3 Get_Functions_details.m 6.运行结果7.参考文献 1.背景介绍 猎豹&#xff08;Achinonyx jubat…

Kubernetes部署服务到集群中的指定节点

# kubectl get node NAME STATUS ROLES AGE VERSION k8s-master Ready master 25h v1.17.3 k8s-node2 Ready <none> 25h v1.17.3 集群只有两个节点&#xff0c;这里打算将应用部署在k8s-node2节点上&#xff0c;需要先记下这个节点的…

【3】Vite Vue3 用户、角色、岗位选择组件封装

在当今前端开发的领域里&#xff0c;快速、高效的项目构建工具以及使用最新技术栈是非常关键的。ViteVue3 组合为一体的项目实战示例专栏将带领你深入了解和掌握这一最新的前端开发工具和框架。 作为下一代前端构建工具&#xff0c;Vite 在开发中的启动速度和热重载方面具有突…

攻不下dfs不参加比赛(十三)

标题 为什么练dfs题目为什么练dfs 相信学过数据结构的朋友都知道dfs(深度优先搜索)是里面相当重要的一种搜索算法,可能直接说大家感受不到有条件的大家可以去看看一些算法比赛。这些比赛中每一届或多或少都会牵扯到dfs,可能提到dfs大家都知道但是我们为了避免眼高手低有的东…

24 - 数组和广义表 - 二维数组

前面我们学习了一维数组、今天来看看二维数组,比一维数组更加复杂! 数组的特点 存储的空间连续 存储类型相同 可以使用地址+偏移快速访问 二维数组定义 二维数组本质上是以数组作为数组元素的数组,即“数组的数组”,语法格式如下: 类型说明符 数组名[常量表达式][常量表达…

977.有序数组的平方

977.有序数组的平方 1.暴力排序 这道题最直观的方法在于&#xff0c;将数组中的每个数平方之后&#xff0c;排个序 public int[] sortedSquares(int[] nums) {int[]ans new int[nums.length];for(int i0;i<nums.length;i){ans[i] nums[i]*nums[i];} Arrays.sort(ans);ret…

人工智能LLM模型:奖励模型的训练、PPO 强化学习的训练、RLHF

人工智能LLM模型&#xff1a;奖励模型的训练、PPO 强化学习的训练 1.奖励模型的训练 1.1大语言模型中奖励模型的概念 在大语言模型完成 SFT 监督微调后&#xff0c;下一阶段是构建一个奖励模型来对问答对作出得分评价。奖励模型源于强化学习中的奖励函数&#xff0c;能对当前…

高通芯片android进入EDL模式 下载 热启动 串口指令

参考&#xff1a;高通方案的Android设备几种开机模式的进入与退出_edl模式怎么退出_Rookie20190715的博客-CSDN博客 切换为EDL模式 向串口发送 4b 65 01 00 54 0f 7e 或者adb reboot edl

Ceph的安装部署

文章目录 一、存储基础1.1 单机存储设备1.2 单机存储的问题1.3分布式存储&#xff08;软件定义的存储 SDS&#xff09; 二、Ceph 简介2.1 Ceph 优势2.2 Ceph 架构2.3 Ceph 核心组件2.4 Pool、PG 和 OSD 的关系&#xff1a;2.5 OSD 存储后端2.6 Ceph 数据的存储过程2.7 Ceph 版本…

PID控制系列--(1、最形象的PID)

目录 1、 比例控制系统的标准结构2、最简单的例子3、第二个例子4、积分控制器6、微分控制7 总结 今天 看到了B站上一个叫洋葱auto的UP主搬来的介绍PID控制的视频&#xff0c;感觉讲得形象易懂&#xff0c;为便于让和我一样看了无数文章还是不能很好理解PID控制本质的人共同分享…

2. DATASETS DATALOADERS

2. DATASETS & DATALOADERS PyTorch提供了两个数据基元&#xff1a;torch.utils.data.DataLoader和torch.uutils.data.data集&#xff0c;允许使用预加载的数据集以及自己的数据。数据集存储样本及其相应的标签&#xff0c;DataLoader在数据集周围包装了一个可迭代项&…

Sentinel整合OpenFegin

之前学习了openFeign的使用&#xff0c;我是超链接 现在学习通过Sentinel来进行整合OpenFegin。 引入OpenFegin 我们需要在当前的8084项目中引入对应的依赖 <dependency><groupId>org.springframework.cloud</groupId><artifactId>spring-cloud-sta…

网络套接字编程(一)(UDP)

gitee仓库&#xff1a;https://gitee.com/WangZihao64/linux/tree/master/chat_udp 预备知识 源IP地址和目的IP地址 它是用来标识网络中不同主机的地址。两台主机进行通信时&#xff0c;发送方需要知道自己往哪一台主机发送&#xff0c;这就需要知道接受方主机的的IP地址&am…

【数学建模】利用C语言来实现 太阳赤纬 太阳高度角 太阳方位角 计算和求解分析 树木树冠阴影面积与种植间距的编程计算分析研究

太阳赤纬的计算 #include <stdio.h> #include <math.h>double calculateDelta(int year, int month, int day, int hour, int minute, int second) {int n, n0;double t, theta, delta;// 计算n和n0n month * 30 day;n0 79.6764 0.2422 * (year - 1985) - ((y…

35+大龄程序员从焦虑到收入飙升:我的搞钱副业分享。

37岁大龄程序员&#xff0c;一度觉得自己的职场生涯到头了。既没有晋升和加薪的机会&#xff0c;外面的公司要么接不住我的薪资&#xff0c;要么就是卷得不行&#xff0c;无法兼顾工作和家庭&#xff0c;感觉陷入了死局…… 好在我又重新振作起来&#xff0c;决定用副业和兼职填…

2.3Listbox列表部件

2.3Listbox列表部件 创建主窗口 window tk.Tk() window.title(my window) window.geometry(200x200)创建一个label用于显示 var1 tk.StringVar() #创建变量 l tk.Label(window,bgyellow,width4,textvariablevar1) l.pack()创建一个方法用于按钮的点击事件 def print_s…