【Faster R-CNN】之 RPN Head 代码精读
- 1、前言
- 2、RPN Head 网络结构
- 3、代码
- 4、相关问题
1、前言
在上一篇文章 【Faster R-CNN】之 backbone 代码精读 中,我们学习了创建 backbone,并将 batch 中的图片通过backbone 获得了 feature maps。 batch 的 feature map size 类似为 [batch_size, out_channel, ]
其中:
- batch_size, 是根据 gpu性能,我们自己指定的
- out_channel : 是由 backbone 决定的,比如我们采用的是 resnet18 前部分到 layer4 的结构作为 backbone, 输出的 channel 就是 512
- height, width : 是由 batch中的图片尺寸 和 backbone 共同决定的。 值得注意的是,同一个 batch 中的图像尺寸是相同的,不同 batch 间的图像尺寸是不一样的。所以,不同 batch 获得的 feature map 的 高和宽 是不一样的。
2、RPN Head 网络结构
在获得 feature map 之后,我们就要考虑 怎么预测 候选框(proposals) 的坐标,以及候选框的置信度了。 使用网络结构 如下:
这里网络相对简单,就不详细介绍了
3、代码
class RPNHead(nn.Module):
def __init__(self, in_channels, num_anchors):
super(RPNHead, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
# 预测 bounding box 中有 object 的概率 (为前景的概率)
self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
# 预测 bounding box regression 参数
self.bbox_pred = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=1, stride=1)
# 初始化模型参数
for layer in self.children():
if isinstance(layer, nn.Conv2d):
torch.nn.init.normal_(layer.weight, std=0.01)
torch.nn.init.constant_(layer.bias, 0)
def forward(self, x):
t = F.relu(self.conv(x))
return self.cls_logits(t), self.bbox_pred(t)
假设 feature map 的 shape =(8, 512, 38, 40),经过 conv 3x3 得到的尺寸同样是 (8, 512, 38, 40),最后通过 conv 1x1 得到的 object 概率的 shap 为 (8, 15, 38, 40), 经过 conv 1x1 得到的 bounding box regression 的 shap 为 (8, 60, 38, 40)。
4、相关问题
问:为什么 RPN Head 的两个输出的 shape分别是 (8, 15, 38, 40) 和 (8, 60, 38, 40)呢?
强烈建议 结合下一篇文章 【Faster R-CNN】之 AnchorGenerator 代码精读 一起看,下篇文章 介绍了 anchor 与 feature map 、原图、bounding box 之间的关系。
-
object 概率的 shap 为 (8, 15, 38, 40),其中:8是batch size,表示batch 中有8张图像;(38,40)对应着 feature map 的每一个像素,15 对应着每个anchor; 矩阵中的每个值就 表示每张 feature map 中的每个像素 的 每个anchor 为 前景的概率。
-
bounding box resresison 的 shape 为(8, 60, 38, 40),其中:8是batch size,表示batch 中有8张图像;(38,40)对应着 feature map 的每一个像素,60 对应着 15 个 bounding box 的 4 个坐标位置(xmin, ymin, xmax, ymax); 矩阵中的每个值就 表示每张 feature map 中的每个像素 对应在原图中的 bounding box 的 4个坐标之一。