【MMDetection】MMDetection中AnchorGenerator学习笔记

news2024/11/13 21:20:41

文章目录

      • 初始化-AnchorGenerator()
      • Anchor平移-grid_priors
      • 计算有效anchor-valid_flags
      • 参考文献

初始化-AnchorGenerator()

@TASK_UTILS.register_module()
class AnchorGenerator:

    def __init__(self, strides, ratios, scales=None, base_sizes=None, scale_major=True, octave_base_scale=None, scales_per_octave=None, centers=None, center_offset=0., use_box_type=False):
    	
        # check center and center_offset
        if center_offset != 0:
            assert centers is None, 'center cannot be set when center_offset' \
                                    f'!=0, {centers} is given.'
        if not (0 <= center_offset <= 1):
            raise ValueError('center_offset should be in range [0, 1], '
                             f'{center_offset} is given.')
        if centers is not None:
            assert len(centers) == len(strides), \
                'The number of strides should be the same as centers, got ' \
                f'{strides} and {centers}'

        # calculate base sizes of anchors
        self.strides = [_pair(stride) for stride in strides]
        self.base_sizes = [min(stride) for stride in self.strides
                           ] if base_sizes is None else base_sizes
        assert len(self.base_sizes) == len(self.strides), \
            'The number of strides should be the same as base sizes, got ' \
            f'{self.strides} and {self.base_sizes}'

        # calculate scales of anchors
        assert ((octave_base_scale is not None
                 and scales_per_octave is not None) ^ (scales is not None)), \
            'scales and octave_base_scale with scales_per_octave cannot' \
            ' be set at the same time'
        if scales is not None:
            self.scales = torch.Tensor(scales)
        elif octave_base_scale is not None and scales_per_octave is not None:
            octave_scales = np.array(
                [2**(i / scales_per_octave) for i in range(scales_per_octave)])
            scales = octave_scales * octave_base_scale
            self.scales = torch.Tensor(scales)
        else:
            raise ValueError('Either scales or octave_base_scale with '
                             'scales_per_octave should be set')

        self.octave_base_scale = octave_base_scale
        self.scales_per_octave = scales_per_octave
        self.ratios = torch.Tensor(ratios)
        self.scale_major = scale_major
        self.centers = centers
        self.center_offset = center_offset
        self.base_anchors = self.gen_base_anchors()
        self.use_box_type = use_box_type

构造函数参数讲解
注意:这三个参数scale_major,center_offset,use_box_type我不是很清晰,如果你们看到了有懂的,可以评论告诉我一下,谢谢啦。

strides:           (list[int] | list[tuple[int, int]])  输入的各个特征图的stride步长,若为list[int],则经过_pair(stride)变为list[tuple[int, int]];若为list[tuple[int, int]],表示(w_stride,h_stride)。
ratios:            (list[float]) 每个grid上生成多个anchor的ratio,ratio=height/width,基于base_size变化。
scales:            (list[int] | None) 每个grid上生成多个anchor的scale,表示缩放比例,基于base_size变化,注意不可以与octave_base_scale、scales_per_octave同时指定。在RetinaNet模型中,指定了octave_base_scaleh和scales_per_octave,因此scales默认为None.
base_sizes:        (list[int] | None) 每一特征层的anchor的基本大小。若为None,则默认等于stride(若stride的长宽不一致,则选择短边) 。
scale_major:       (bool) 首先每个grid上会生成len(scales)*len(ratios)个base anchor。scale_major将确定base anchor的排列顺序!若为true,表示scale优先,即base anchors的每一行的scale相同;若为false,表示ratios优先,base anchors的每一行的ratio相同。在MMDetection2.0中,默认为True.
octave_base_scale: (int) The base scale of octave。
scales_per_octave: (int) Number of scales for each octave。octave_base_scale and scales_per_octave 用在retinanet中,注意不可以与scales同时指定,scale和octave_base_scale and scales_per_octave的转换公式为:scales = [2**(i / scales_per_octave) for i in range(scales_per_octave)]) * octave_base_scale。
centers:           (list[tuple[float, float]] | None) AnchorGenerator类中默认为None,若为None,则每个anchor中心与网格的左上角对齐!yolohead会设计center,使得anchor中心与网格中心对齐。
center_offset:     (float) The offset of center in proportion to anchors' width and height。
use_box_type:      (bool) Whether to warp anchors with the box type data structure. Defaults to False.
# 计算base_size
self.strides = [_pair(stride) for stride in strides] # 
self.base_sizes = [min(stride) for stride in self.strides] if base_sizes is None else base_sizes

下图是RetinaNet网络中的base_size和stride.
在这里插入图片描述

# 得到scales. 注意RetinaNet网络中scales为None
if scales is not None:       
    self.scales = torch.Tensor(scales)
elif octave_base_scale is not None and scales_per_octave is not None:
    octave_scales = np.array([2**(i / scales_per_octave) for i in range(scales_per_octave)])
    scales = octave_scales * octave_base_scale
    self.scales = torch.Tensor(scales)

下图是RetinaNet网络中的octave_base_scale和octave_scales .
在这里插入图片描述
在这里插入图片描述

 self.octave_base_scale = octave_base_scale # RetinaNet中为4
 self.scales_per_octave = scales_per_octave # RetinaNet中为3
 self.ratios = torch.Tensor(ratios) # RetinaNet中为[0.5, 1.0, 2.0]
 self.scale_major = scale_major # RetinaNet中为True
 self.centers = centers # RetinaNet中为None
 self.center_offset = center_offset # RetinaNet中为0 
 self.base_anchors = self.gen_base_anchors() # 在下面会重点讲
 self.use_box_type = use_box_type # # RetinaNet中为False
# gen_base_anchors 调用了 gen_single_level_base_anchors,得到多尺度的anchor. gen_single_level_base_anchors 会在下面详细讲。
def gen_base_anchors(self):
    multi_level_base_anchors = []
    for i, base_size in enumerate(self.base_sizes):
        center = None
        if self.centers is not None:
            center = self.centers[i]
        multi_level_base_anchors.append(self.gen_single_level_base_anchors(base_size,vscales=self.scales, ratios=self.ratios, center=center))
    return multi_level_base_anchors

下面以RetinaNet为例,讲解一下

def gen_single_level_base_anchors(self, base_size, scales, ratios, center=None):

    w = base_size 
    h = base_size
    
    if center is None:
        x_center = self.center_offset * w
        y_center = self.center_offset * h
    else:
        x_center, y_center = center
        
	# h/w = ratios
    h_ratios = torch.sqrt(ratios)
    w_ratios = 1 / h_ratios
    
    if self.scale_major:
    
        ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
    else:
        ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
        hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
        
    base_anchors = [ x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws, y_center + 0.5 * hs]
    
    base_anchors = torch.stack(base_anchors, dim=-1)

    return base_anchors
# 当前特征图的w和h
 w = base_size 
 h = base_size
# 计算anchor中心点位置,默认为(0,0)
 if center is None:
     x_center = self.center_offset * w
     y_center = self.center_offset * h
 else:
     x_center, y_center = center

在这里插入图片描述

# 保证高宽比为ratios,注意下述操作是对tensor的操作
 h_ratios = torch.sqrt(ratios)
 w_ratios = 1 / h_ratios

在这里插入图片描述

ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)

注意:这里通过引入None,扩充维度。
在这里插入图片描述
在这里插入图片描述

# 生成当前层的base_anchor
base_anchors = [ x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws, y_center + 0.5 * hs]  
base_anchors = torch.stack(base_anchors, dim=-1)

需要注意scale_major变量的作用,用于确定base anchor的排列顺序。若为true,那先乘以ratios,再乘以scales。举个例子,scales=[1,2],ratios=[0.5,1],base size为(32,32)。那么

当scale_major为true, 则返回[ [( 32 2 , 32 2 32\sqrt2,\frac{32}{\sqrt2} 322 ,2 32), (32,32)] , [( 64 2 , 64 2 64\sqrt2,\frac{64}{\sqrt2} 642 ,2 64),(64,64)] ]
当scale_major为false,则返回[ [(32,32),(64,64)] , [( 32 2 , 32 2 32\sqrt2,\frac{32}{\sqrt2} 322 ,2 32),( 64 2 , 64 2 64\sqrt2,\frac{64}{\sqrt2} 642 ,2 64)] ]

Anchor平移-grid_priors

与anchor初始化一样,平移anchor的操作主要在single_level_grid_priors函数中,下面重点讲解这个函数。


def grid_priors(self, featmap_sizes, device='cuda'):
    assert self.num_levels == len(featmap_sizes)
    multi_level_anchors = []
    for i in range(self.num_levels):
        anchors = self.single_level_grid_priors(
            self.base_anchors[i].to(device),
            featmap_sizes[i],
            self.strides[i],
            device=device)
        multi_level_anchors.append(anchors)
    return multi_level_anchors # 返回list[num_levels * tensor(H*W*num_anchors,4)]

def single_level_grid_priors(self, base_anchors, featmap_size, stride=(16, 16), device='cuda'):

	base_anchors = self.base_anchors[level_idx].to(device).to(dtype)
	feat_h, feat_w = featmap_size
	stride_w, stride_h = self.strides[level_idx]

	shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
	shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h
	
	shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
	shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
		
	all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
	all_anchors = all_anchors.view(-1, 4)

	if self.use_box_type:
	    all_anchors = HorizontalBoxes(all_anchors)
	return all_anchors

def _meshgrid(self, x, y, row_major=True):
       # 获得网格点
       xx = x.repeat(len(y))
       yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
       if row_major:
           return xx, yy # xx和yy的shape为(rows*cols,)
       else:
           return yy, xx
# 获取当前层的base_anchors
base_anchors = self.base_anchors[level_idx].to(device).to(dtype)

在这里插入图片描述

# 当前层的特征图大小和步长
feat_h, feat_w = featmap_size
stride_w, stride_h = self.strides[level_idx]

在这里插入图片描述

# 乘以stride,映射回原图
shift_x = torch.arange(0, feat_w, device=device).to(dtype) * stride_w
shift_y = torch.arange(0, feat_h, device=device).to(dtype) * stride_h

在这里插入图片描述
在这里插入图片描述

# 获取anchor的中心点
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)

shift_xx如下,每隔feat_w重复
在这里插入图片描述
shift_yy如下,每隔feat_hstride
在这里插入图片描述

shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)

shifts如下,(0,1)和(2,3)一致,是因为左上角和右上角坐标移动的时候要同时移动。
在这里插入图片描述
非常简洁的代码实现

# 得到特征图上所有的anchors
all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
base_anchors[None, :, :] 扩充维度为(1,9,4)
shifts[:, None, :] 扩成维度为(15200,1,4)
相加的时候,base_anchors(1,9,4)会广播为(15200,9,4),即将(9,4)赋值为15200份。shifts(15200,1,4)会广播为(15200,9,4),即将(1,4)复制为9份。

base_anchors如下
在这里插入图片描述
shifts[:, None, :]如下,相加的时候会将一行复制为9行。
在这里插入图片描述

# (15200,9,4)变为(136800,4)
all_anchors = all_anchors.view(-1, 4)

计算有效anchor-valid_flags

由于在数据预处理时,填充了大量黑边,所以在黑边上的anchor不用计算loss,可以忽略,节省算力。因此valid_flags返回有效的anchor索引。

def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):

        # pad_shape是有效的特征图大小,是指Pad后的size,collate之前
        assert self.num_levels == len(featmap_sizes)
        multi_level_flags = []
        for i in range(self.num_levels):
            anchor_stride = self.strides[i]
            feat_h, feat_w = featmap_sizes[i]
            h, w = pad_shape[:2]
            valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h) # 获得有效的特征图
            valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w) # 获得有效的特征图
            flags = self.single_level_valid_flags((feat_h, feat_w),
                                                  (valid_feat_h, valid_feat_w),
                                                  self.num_base_anchors[i],
                                                  device=device)
            multi_level_flags.append(flags)  # 有效位置设置为1,否则为0
        return multi_level_flags

    def single_level_valid_flags(self,
                                 featmap_size,
                                 valid_size,
                                 num_base_anchors,
                                 device='cuda'):
        feat_h, feat_w = featmap_size
        valid_h, valid_w = valid_size
        assert valid_h <= feat_h and valid_w <= feat_w
        # 使用填桶法生成有效的位置
        valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
        valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
        # 有效的位置填1
        valid_x[:valid_w] = 1
        valid_y[:valid_h] = 1
        
        valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
        
        valid = valid_xx & valid_yy # tensor(H*W,) bool
        
        valid = valid[:, None].expand(valid.size(0),num_base_anchors).contiguous().view(-1)
        # tensor(H*W*num_base_anchors,) bool
        return valid

在这里插入图片描述
valid_x如下
在这里插入图片描述
valid_y如下
在这里插入图片描述

参考文献

  1. mmdetection源码阅读笔记:prior generator

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/67862.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

numpy的部分通用函数浅谈

numpy的部分通用函数 1.数组算术运算符 运算符对应的通用函数描述np.add加法运算&#xff08;即112)-np.substract减法运算&#xff08;即3-21&#xff09;-np.negative负数运算&#xff08;即-2&#xff09;*Nnp.multiply乘法运算&#xff08;即2*36&#xff09;/np.divide除…

Optional用法与争议点

Optional用法与争议点 简介 要说Java中什么异常最容易出现&#xff0c;我想NullPointerException一定当仁不让&#xff0c;为了解决这种null值判断问题&#xff0c;Java8中提供了一个新的工具类Optional&#xff0c;用于提示程序员注意null值&#xff0c;并在特定场景中简化代…

软件测试8年,却被应届生踩在头上,是应届生太牛了,还是我们太弱了?

前几天有个朋友向我哭诉&#xff0c;说他在公司干了8年的软件测试&#xff0c;却被一个实习生代替了&#xff0c;该何去何从? 这是一个值得深思的问题&#xff0c;作为职场人员&#xff0c;我们确实该思考&#xff0c;我们的工作会被实习生代替吗?这是一个很尖锐的问题&…

MFC基于对话框——仿照Windows计算器制作C++简易计算器

目录 一、界面设计 二、设置成员变量 三、初始化成员变量 四、初始化对话框 ​五、添加控件代码 1.各个数字的代码&#xff08;0~9&#xff09; 2.清除功能的代码 3.退格功能的代码 4.加减乘除功能的代码 5.小数点功能的代码 6.正负号功能的代码 7.等于功能的代码…

算法day42|背包问题

目录 01背包问题 二维 01背包问题 一维 416. 分割等和子集 背包问题分为&#xff1a;01背包&#xff0c;完全背包&#xff0c;多种背包01背包指的是有n种物品&#xff0c;每种物品只能取一个完全背包指的是有n种物品,每种物品可以取无限个多种背包指的是有n种物品&#xff0c;每…

公众号网课搜题接口

公众号网课搜题接口 本平台优点&#xff1a; 多题库查题、独立后台、响应速度快、全网平台可查、功能最全&#xff01; 1.想要给自己的公众号获得查题接口&#xff0c;只需要两步&#xff01; 2.题库&#xff1a; 查题校园题库&#xff1a;查题校园题库后台&#xff08;点击…

常用的在线工具网站

1&#xff0c;在线Photoshop软件 https://www.uupoop.com/ PS在线图片编辑器是一个专业精简的在线ps图片照片制作处理软件工具,绿色免安装,免下载,直接在浏览器打开就可用它修正,调整和美化图像。 2&#xff0c;bilibili视频编辑器 https://bilibili.clipchamp.com/ 由哔哩哔哩…

(保姆级)国内1块钱注册火爆全网的OpenAI-ChatGPT机器人

下面有给出完整的注册流程。首先介绍一下它是什么&#xff0c;如果只想看注册往下翻&#xff01; 1块钱注册火爆全网的OpenAI-ChatGPT机器人OpengAI-ChatGPT能做什么如何注册1块钱收取验证码使用注册的账号登录ChatGPTOpengAI-ChatGPT能做什么 我作为一个程序员用了一段时间&a…

金蝶云星空生产管理(冲刺学习)

物料“基本”和“生产”相关属性字段介绍 物料属性&#xff1a;生产中常用的物料属性包括自制、委外、外购、虚拟、配置、特征。 自制&#xff1a;一般是指由企业自己生产的物料&#xff0c;一般会建BOM、生产订单&#xff1b; 委外&#xff1a;是指委托给其他加工单位进行加工…

DevTools 无法加载来源映射:无法加载 chrome-extension: 警告的原因以及如何去除(全网最全 最详细解决方案)

是类似这样的一个警告。每次都有看着还是挺难受的。 这个警告的原因是你的浏览器插件造成的。例如警告已经很明确的告诉你是chrome-extension&#xff0c;也就是谷歌插件的问题。后面的字符串其实就是这个插件的id。 chrome-extension://cfhdojbkjhnklbpkdaibdccddilifddb/br…

QT笔记——QSlider滑动条滚轮事件和点击鼠标位置事件问题

需求&#xff1a;我们需要对一个滑动条 滚轮事件 和 点击到滑动条的位置 实时显示 问题&#xff1a;其中在做的时候遇到了很多的问题&#xff0c;一开始感觉很简单&#xff0c;现在将这些问题记录下来 ui图&#xff1a; 问题1&#xff1a;处理QSlider 滚轮事件的时候 这里有…

AlphaFold2源码解析(8)--模型之三维坐标构建

AlphaFold2源码解析(8)–模型之三维坐标构建 这个模块我们讲解AlphaFold的Structure module模块&#xff0c;该结构模块将蛋白质结构的抽象表示映射为具体的三维原子坐标。 Evoformer的单一表征被用作初始单一表征siinitial{s^{initial}_i }siinitial​&#xff0c;siinitial∈…

同步整流 降压恒流 输入4-40V 功率可达40W 电流3.6A 原理图

◆PCB 布线参考PCB 布局应遵循如下规则以确保芯片的正常工作。1:功率线包括地线&#xff0c;LX线和VIN线应该尽量做到短、 直和宽。2:输入电容应尽可能靠近芯片管脚&#xff08;VIN 和 &#xff09;。输入电源引脚可增加一个 0.1uF 的陶瓷电容以增强芯片的抗高频噪声能力。3:功…

小迪-day13(MySQL注入)

一、information_schema information_schema 数据库跟 performance_schema 一样&#xff0c;都是 MySQL 自带的信息数据库。其中 performance_schema 用于性能分析&#xff0c;而 information_schema 用于存储数据库元数据(关于数据的数据)&#xff0c;例如数据库名、表名、列…

信号和电源隔离的有效设计技术

介绍 如今&#xff0c;电子产品设计师比以往任何时候都更面临着一系列共同的目标&#xff1a;实现更高的吞吐量、更高的分辨率、更高效的系统和缩短上市时间。在工业自动化、医疗电子或电信系统等领域&#xff0c;通常需要电隔离多个信号&#xff0c;以使子系统能够共享数据或…

农民歌唱家大衣哥外出商演,大衣嫂在家晒麦子,真是一对金童玉女

在中国华语乐坛&#xff0c;曾经有很多对模范夫妻&#xff0c;比如说任静和付笛声&#xff0c;他们也是音乐领域的金童玉女。其实大家都忽略了一对夫妻&#xff0c;农民歌唱家大衣哥&#xff0c;和他的结发妻子玉华&#xff0c;同样是中国华语乐坛的骄傲。 只是因为大衣哥过于低…

计算机网络复习(一~三)

第一章 基本概念 1-01.计算机网络可以向用户提供哪些服务&#xff1f; 答&#xff1a;例如音频&#xff0c;视频&#xff0c;游戏等&#xff0c;但本质是提供连通性和共享这两个功能。连通性&#xff1a;计算机网络使上网用户之间可以交换信息&#xff0c;好像这些用户的计算…

RDPCrystal EDI SDK 10.0.4.X Crack

关于 RDPCrystal EDI 库 使用 .NET、NodeJS、JavaScript 或 .NET Core 创建、查看和验证 EDI 数据。 RDPCrystal EDI 库是一套 EDI 组件&#xff08;.NET、NodeJS/JavaScript 和 .NET Core&#xff09;&#xff0c;可以创建和操作任何 X12 标准文件。功能包括解析、连接、拆分、…

【Unity】填坑,Unity接入Epic Online Service上架Epic游戏商城

EOS SDK For Unity地址&#xff1a;https://github.com/PlayEveryWare/eos_plugin_for_unity_upm Epic是虚幻游戏引擎开发商&#xff0c;2018年12月Epic宣布推出Epic游戏商城至今刚好三年&#xff0c;Epic将平台分成定为12%(远低于当时Steam的30%)&#xff0c;并且频繁推出各种…

每天一个面试题:四种引用,弱引用防止内存泄漏

每天一个面试题&#xff1a;四种引用四种引用基本介绍实例Demo- 虚引用弱引用防止内存泄漏弱引用Debug分析源码开始全新的学习&#xff0c;沉淀才会有产出&#xff0c;一步一脚印&#xff01; 面试题系列搞起来&#xff0c;这个专栏并非单纯的八股文&#xff0c;我会在技术底层…