【目标检测】LLA: Loss-aware label assignment for dense pedestrian detection【标签分配】

news2025/1/30 16:33:06

总结

本文提出了一种用于行人目标检测的标签分配策略,具体来说,主要有以下几步流程。

  1. 构建代价矩阵。通过网络的前向传播得到网络的输出, C c l s C^{cls} Ccls, C r e g C^{reg} Creg,构建代价矩阵 C = C c l s + λ ∗ C r e g C=C^{cls}+\lambda*C^{reg} C=Ccls+λCreg
  2. 选取代价矩阵中的前TOP K个候选框(即 loss比较小的),作为正样本,其他的为负样本。
  3. 为了加速收敛,强制正样本候选区域在gt框内。

本文的作者和YOLOX是同一个作者,YOLOX的标签分配策略,可以看做在本文上面进行了稍微的更改。

更多的细节

  1. TOP K,超参的敏感性
    作者通过做实验发现,TOP K在一定范围内是不敏感的在这里插入图片描述
  2. 代价矩阵中各部分消融实验研究
    在这里插入图片描述
  3. 可视化结果
    在这里插入图片描述

代码

参考连接

def get_lla_assignments_and_losses(self, shifts, targets, box_cls, box_delta, box_iou):

	gt_classes = []

	box_cls = [permute_to_N_HWA_K(x, self.num_classes) for x in box_cls]
	box_delta = [permute_to_N_HWA_K(x, 4) for x in box_delta]
	box_iou = [permute_to_N_HWA_K(x, 1) for x in box_iou]

	box_cls = torch.cat(box_cls, dim=1)
	box_delta = torch.cat(box_delta, dim=1)
	box_iou = torch.cat(box_iou, dim=1)

	losses_cls = []
	losses_box_reg = []
	losses_iou = []

	num_fg = 0

	for shifts_per_image, targets_per_image, box_cls_per_image, \
			box_delta_per_image, box_iou_per_image in zip(
			shifts, targets, box_cls, box_delta, box_iou):

		shifts_over_all = torch.cat(shifts_per_image, dim=0)

		gt_boxes = targets_per_image.gt_boxes
		gt_classes = targets_per_image.gt_classes

		deltas = self.shift2box_transform.get_deltas(
			shifts_over_all, gt_boxes.tensor.unsqueeze(1))
		is_in_boxes = deltas.min(dim=-1).values > 0.01

		shape = (len(targets_per_image), len(shifts_over_all), -1)
		box_cls_per_image_unexpanded = box_cls_per_image
		box_delta_per_image_unexpanded = box_delta_per_image

		box_cls_per_image = box_cls_per_image.unsqueeze(0).expand(shape)
		gt_cls_per_image = F.one_hot(
			torch.max(gt_classes, torch.zeros_like(gt_classes)), self.num_classes
		).float().unsqueeze(1).expand(shape)

		with torch.no_grad():
			loss_cls = sigmoid_focal_loss_jit(
				box_cls_per_image,
				gt_cls_per_image,
				alpha=self.focal_loss_alpha,
				gamma=self.focal_loss_gamma).sum(dim=-1)
			loss_cls_bg = sigmoid_focal_loss_jit(
				box_cls_per_image_unexpanded,
				torch.zeros_like(box_cls_per_image_unexpanded),
				alpha=self.focal_loss_alpha,
				gamma=self.focal_loss_gamma).sum(dim=-1)
			box_delta_per_image = box_delta_per_image.unsqueeze(0).expand(shape)
			gt_delta_per_image = self.shift2box_transform.get_deltas(
				shifts_over_all, gt_boxes.tensor.unsqueeze(1))
			loss_delta = iou_loss(
				box_delta_per_image,
				gt_delta_per_image,
				box_mode="ltrb",
				loss_type='iou')

			ious = get_ious(
				box_delta_per_image,
				gt_delta_per_image,
				box_mode="ltrb",
				loss_type='iou')

			loss = loss_cls + self.reg_cost * loss_delta + 1e3 * (1 - is_in_boxes.float())
			loss = torch.cat([loss, loss_cls_bg.unsqueeze(0)], dim=0)

			num_gt = loss.shape[0] - 1
			num_anchor = loss.shape[1]

			# Topk
			matching_matrix = torch.zeros_like(loss)
			_, topk_idx = torch.topk(loss[:-1], k=self.topk, dim=1, largest=False)
			matching_matrix[torch.arange(num_gt).unsqueeze(1).repeat(1,
			   self.topk).view(-1), topk_idx.view(-1)] = 1.

			# make sure one anchor with one gt
			anchor_matched_gt = matching_matrix.sum(0)
			if (anchor_matched_gt > 1).sum() > 0:
				loss_min, loss_argmin = torch.min(loss[:-1, anchor_matched_gt > 1], dim=0)
				matching_matrix[:, anchor_matched_gt > 1] *= 0.
				matching_matrix[loss_argmin, anchor_matched_gt > 1] = 1.
				anchor_matched_gt = matching_matrix.sum(0)
			num_fg += matching_matrix.sum()
			matching_matrix[-1] = 1. - anchor_matched_gt  # assignment for Background
			assigned_gt_inds = torch.argmax(matching_matrix, dim=0)

			gt_cls_per_image_bg = gt_cls_per_image.new_zeros(
				(gt_cls_per_image.size(1), gt_cls_per_image.size(2))).unsqueeze(0)
			gt_cls_per_image_with_bg = torch.cat(
				[gt_cls_per_image, gt_cls_per_image_bg], dim=0)
			cls_target_per_image = gt_cls_per_image_with_bg[
				assigned_gt_inds, torch.arange(num_anchor)]

			# Dealing with Crowdhuman ignore label
			gt_classes_ = torch.cat([gt_classes, gt_classes.new_zeros(1)])
			anchor_cls_labels = gt_classes_[assigned_gt_inds]
			valid_flag = anchor_cls_labels >= 0

			pos_mask = assigned_gt_inds != len(targets_per_image)  # get foreground mask
			valid_fg = pos_mask & valid_flag
			assigned_fg_inds = assigned_gt_inds[valid_fg]
			range_fg = torch.arange(num_anchor)[valid_fg]
			ious_fg = ious[assigned_fg_inds, range_fg]

		anchor_loss_cls = sigmoid_focal_loss_jit(
			box_cls_per_image_unexpanded[valid_flag],
			cls_target_per_image[valid_flag],
			alpha=self.focal_loss_alpha,
			gamma=self.focal_loss_gamma).sum(dim=-1)

		delta_target = gt_delta_per_image[assigned_fg_inds, range_fg]
		anchor_loss_delta = 2. * iou_loss(
			box_delta_per_image_unexpanded[valid_fg],
			delta_target,
			box_mode="ltrb",
			loss_type=self.iou_loss_type)

		anchor_loss_iou = 0.5 * F.binary_cross_entropy_with_logits(
			box_iou_per_image.squeeze(1)[valid_fg],
			ious_fg,
			reduction='none')

		losses_cls.append(anchor_loss_cls.sum())
		losses_box_reg.append(anchor_loss_delta.sum())
		losses_iou.append(anchor_loss_iou.sum())

	if self.norm_sync:
		dist.all_reduce(num_fg)
		num_fg = num_fg.float() / dist.get_world_size()

	return {
		'loss_cls': torch.stack(losses_cls).sum() / num_fg,
		'loss_box_reg': torch.stack(losses_box_reg).sum() / num_fg,
		'loss_iou': torch.stack(losses_iou).sum() / num_fg
	}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/39860.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小程序环境切换自定义组件

背景: 最近一直有参与小程序的项目,发现切换环境时经常要上传然后再设置为体验版,比较麻烦,所以尝试做了个切换环境的组件,分享给大家,希望大家能用得上,提点建议 组件长这个样子 展开后 功能&a…

JVM的垃圾回收机制(GC)

系列文章目录 JVM的内存区域划分_crazy_xieyi的博客-CSDN博客 JVM类加载(类加载过程、双亲委派模型)_crazy_xieyi的博客-CSDN博客 文章目录 一、什么是垃圾回收?二、java的垃圾回收,要回收的内存是哪些?三、回收堆上…

FPGA Base Xilinx跨时钟域宏XPM_CDC

FPGA Base Xilinx跨时钟域宏XPM_CDC最近看手底下的小伙子们写代码,对于跨时钟域的处理极度的不规范,还是放下这句话基础不牢,地动山摇 其实Xilinx公司已经为用户提供了宏定义,实现跨时钟域处理,见截图 XPM_CDC在命名上…

关于旅游景点主题的HTML网页设计——青岛民俗 7页 带登录注册

⛵ 源码获取 文末联系 ✈ Web前端开发技术 描述 网页设计题材,DIVCSS 布局制作,HTMLCSS网页设计期末课程大作业 | 游景点介绍 | 旅游风景区 | 家乡介绍 | 等网站的设计与制作| HTML期末大学生网页设计作业 HTML:结构 CSS:样式 在操作方面上运…

【C++】简化源码——vector的模拟实现

文章目录一、前言二、无参构造&析构三、基础接口1.empty和clear2.size和capacity3.[]和iterator四、resize和reserve五、尾插尾删六、其他构造七、迭代器失效1.insert2.erase八、memcpy问题九、vector.h一、前言 本篇的目的很简单,只有一个:模拟实现…

C语言刷题(一)

🐒博客名:平凡的小苏 📚学习格言:别人可以拷贝我的模式,但不能拷贝我不断往前的激情 目录 用递归法求一个整数一维数组a的最大元素 猴子吃桃问题 奇偶数换位问题 水仙花数(0-100000) 换啤酒…

web前端电影项目作业源码 大学生影视主题网页制作电影网页设计模板 学生静态网页作业成品 dreamweaver电影HTML网站制作

HTML实例网页代码, 本实例适合于初学HTML的同学。该实例里面有设置了css的样式设置,有div的样式格局,这个实例比较全面,有助于同学的学习,本文将介绍如何通过从头开始设计个人网站并将其转换为代码的过程来实践设计。 文章目录一、网页介绍一…

redis命令行操作库、键、和五大数据类型详解

一、数据库操作命令 redis默认有16个数据库,类似数组下标从0开始,初始默认使用0号库。 1.1 测试是否连通 ping测试服务器是否连通 返回pone就是连通了 1.2 切换数据库 select index1.3 数据移动 move key db1.4 显示数据总量 dbsize1.5 数据清除 …

Maven 跳过测试的几种方式

在 Maven 对项目进行编译的时候,我们通常可能会希望跳过复杂的测试。 尤其是在开始项目还不是非常稳定的阶段。 命令行中使用 -Dmaven.test.skiptrue 在命令行,只要简单的给任何目标添加 maven.test.skip 属性就能跳过测试: mvn install …

leetcode:6251. 统计回文子序列数目【dp + 统计xy子序列出现的个数】

目录题目截图题目分析ac code总结题目截图 题目分析 固定了中间的数i后从两边选xy 和 yx对于x y的情况,比较简单预处理每个数字出现的index为ids然后看看两边x各自的个数n1 n2n1和n2必须大于等于2左边可以选n1 * (n1 - 1) // 2右边可以选n2 * (n2 - 1) // 2两边乘…

【C++】通过哈希表实现map和set

前言 在前面,我们通过红黑树这一底层结构实现了map和set。它们是关联式容器。而现在,我们将通过哈希表这一数据结构重新实现map和set,即unordered系列的关联式容器。因为它们的遍历是无序的,和平衡二叉树不同,不能做到…

APOLLO UDACITY自动驾驶课程笔记——规划、控制

1、路径规划使用三个输入,第一个输入为地图,Apollo提供的地图数据包括公路网和实时交通信息。第二个输入为我们当前在地图上的位置。第三个输入为我们的目的地,目的地取决于车辆中的乘客。 2、将地图转为图形 该图形由“节点”(node)和“边缘…

直流潮流计算matlab程序

一、直流潮流计算原理 直流潮流发的特点是用电力系统的交流潮流(有功功率和无功功率)等值的直流电流来代替。甚至只用直流电路的解析法来分析电力系统的有功潮流,而不考虑无功分布对有功的影响。这样一来计算速度加快,但计算的准确…

Rocket MQ : 拒绝神化零拷贝

注: 本文绝非对零拷贝机制的否定笔者能力有限,理解偏差请大家多多指正不可否认零拷贝对于Rocket MQ的高性能表现有着积极正面的作用,但是笔者认为只是锦上添花,并非决定性因素。Rocket MQ性能卓越的原因绝非零拷贝就可以一言以蔽之。 笔者企图…

第146篇 笔记-智能合约介绍

定义:当满足某些预定义条件时,智能合约是一种在区块链网络上运行的防篡改程序。 1.什么是智能合约 智能合约是在区块链网络上托管和执行的计算机程序。每个智能合约都包含指定预定条件的代码,这些条件在满足时会触发并产生结果。通过在去中…

IDEA热部署插件JRebel and XRebel

IDEA热部署插件JRebel and XRebel嘚吧嘚下载安装激活配置使用嘚吧嘚 刚开始用过一段时间的eclipse,其他方面没感觉,但是eclipse的热部署真的是深得我心啊😊。 后来换了IDEA,瞬间就心动了,各个方面真的很好用&#xf…

U3D VideoPlayer播放视频和坑点

最近做的游戏里,需要先播放一段几秒钟的工作室LOGO片头,拿到的视频是AVI格式,以前没在U3D里用到过视频,本以为很简单,没想到都2022年了,U3D播放视频还这么烂。。。 插件最好用的是AVPro,除非你有大量的视频要播放,否则没必要用插件,一个是贵,另一个插件很大。 首先…

Python爬虫从入门到进阶

前言 董伟明,国内某知名Python应用网站高级产品开发工程师,《 Python Web 开发实战》作者,本书目前已经售出 17k 余本,另外也已经在台湾地区上市。在 2012 和 2014 年分别通过 2 个爬虫免试获得 2 个业界知名公司 offer&#xff…

MyBatis缓存机制之一级缓存

MyBatis缓存机制之一级缓存 前言 MyBatis内部封装了JDBC,简化了加载驱动、创建连接、创建statement等繁杂的过程,是我们常见的持久性框架。缓存是在计算机内存中保存的临时数据,读取时无需再从磁盘中读取,从而减少数据库的查询次…

Node.js 入门教程 1 Node.js 简介

Node.js 入门教程 Node.js官方入门教程 Node.js中文网 本文仅用于学习记录,不存在任何商业用途,如侵删 文章目录Node.js 入门教程1 Node.js 简介1.1 大量的库1.2 Node.js 应用程序的示例1.3 Node.js框架和工具1 Node.js 简介 Node.js 是一个开源和跨平台…