paper:Focal and Global Knowledge Distillation for Detectors
official implementation:https://github.com/yzd-v/FGD
存在的问题
如图1所示,前景区域教师和学生注意力之间的差异非常大,背景区域则相对较小。此外通道注意力的差异也非常明显。
作者还设计了实验解耦了蒸馏过程中的前景和背景,结果如表1所示,令人惊讶的是,前景背景一起进行蒸馏的效果是最差的,比单独蒸馏前景或背景还差。
上述结果表明,特征图中的不均匀差异会对蒸馏产生负面效果。这种不均匀差异不仅存在于前背景之间,也存在于不同像素位置和通道之间。
本文的创新点
针对前背景、空间位置、通道之间的差异,本文提出了focal distillation,在分离前背景的同时,还计算了教师特征不同空间位置和通道的注意力,使得学生专注于学习教师的关键像素和通道。
但是只关注关键信息还不够,在检测任务中全局语义信息也很重要。为了弥补focal蒸馏中缺失的全局信息,作者还提出了global distillation,其中利用GcBlock来提取不同像素之间的关系,然后传递给学生。
方法介绍
Focal Distillation
首先用一个binary mask \(M\) 来分离前背景
其中 \(r\) 是ground truth box,\(i,j\) 表示像素位置的坐标。
为了消除不同大小的gt box的尺度的影响和不同图片中前背景比例的差异,作者又设置了一个scale mask \(S\)
其中 \(H_{r},W_{r}\) 表示gt box \(r\) 的高和宽,如果一个像素属于不同的target,选择最小的box来计算 \(S\)。
接着作者借鉴SENet和CBAM的方法提取通道注意力和空间注意力
\(G^{S},G^{C}\) 分别表示空间和通道attention map,然后attention mask按下式计算
其中 \(T\) 是温度系数。
利用binary mask \(M\)、scale mask \(S\)、attention mask \(A^{S},A^{C}\),特征损失 \(L_{fea}\) 如下
其中 \(A^{S},A^{C}\) 表示教师的空间和通道attention mask,\(F^{T},F^{S}\) 分别表示教师和学生的feature map,\(\alpha, \beta\) 是balance超参。
此外作者还提出了注意力损失 \(L_{at}\) 让学生模仿教师的attention mask
\(l\) 表示L1损失。
完整的focal损失就是特征损失和注意力损失的和
Global Distillation
如图4所示,作者用GcBlock来提取全局关系信息,关于GcBlock的详细介绍可以参考GCNet: Global Context Network(ICCV 2019)原理与代码解析
全局损失 \(L_{global}\) 如下
\(W_{k},W_{v1},W_{v2}\) 是卷积层,\(LN\) 表示layer normalization,\(N_{p}\) 是特征中所有像素个数,\(\lambda\) 是balance超参。
Overall loss
完整的损失函数如下,包括原本的训练损失和蒸馏损失,蒸馏损失又包括focal损失和global损失
实验结果
其中inheriting strategry是《Instance-conditional knowledge distillation for object detection》这篇文章中提出的用教师的neck和head参数初始化学生网络,可以得到更好的效果。
代码解析
主要实现在mmdet/distillation/losses/fgd.py中,函数forward中,首先教师和学生的attention mask,即文中的式(5)~(8)
S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp) # (N,H,W),(N,C)
S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)
def get_attention(self, preds, temp):
""" preds: Bs*C*W*H """
N, C, H, W = preds.shape
value = torch.abs(preds)
# Bs*W*H
fea_map = value.mean(axis=1, keepdim=True)
S_attention = (H * W * F.softmax((fea_map / temp).view(N, -1), dim=1)).view(N, H, W)
# Bs*C
channel_map = value.mean(axis=2, keepdim=False).mean(axis=2, keepdim=False)
C_attention = C * F.softmax(channel_map / temp, dim=1)
return S_attention, C_attention
接下来为了减小不同target尺度和前背景比例的影响,计算scale mask,即文中的式(2)~式(4)。其中内层的for循环是当一个像素属于不同的target时,选择最小的box来计算。
Mask_fg = torch.zeros_like(S_attention_t)
Mask_bg = torch.ones_like(S_attention_t)
wmin, wmax, hmin, hmax = [], [], [], []
for i in range(N):
new_boxxes = torch.ones_like(gt_bboxes[i])
new_boxxes[:, 0] = gt_bboxes[i][:, 0] / img_metas[i]['img_shape'][1] * W
new_boxxes[:, 2] = gt_bboxes[i][:, 2] / img_metas[i]['img_shape'][1] * W
new_boxxes[:, 1] = gt_bboxes[i][:, 1] / img_metas[i]['img_shape'][0] * H
new_boxxes[:, 3] = gt_bboxes[i][:, 3] / img_metas[i]['img_shape'][0] * H
wmin.append(torch.floor(new_boxxes[:, 0]).int())
wmax.append(torch.ceil(new_boxxes[:, 2]).int())
hmin.append(torch.floor(new_boxxes[:, 1]).int())
hmax.append(torch.ceil(new_boxxes[:, 3]).int())
area = 1.0 / (hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1)) / (wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1))
for j in range(len(gt_bboxes[i])):
Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1] = \
torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1], area[0][j])
Mask_bg[i] = torch.where(Mask_fg[i] > 0, 0, 1)
if torch.sum(Mask_bg[i]):
Mask_bg[i] /= torch.sum(Mask_bg[i])
接着就是完整的feature损失,即文中的式(9)
fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
C_attention_s, C_attention_t, S_attention_s, S_attention_t)
def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
loss_mse = nn.MSELoss(reduction='sum')
Mask_fg = Mask_fg.unsqueeze(dim=1)
Mask_bg = Mask_bg.unsqueeze(dim=1)
C_t = C_t.unsqueeze(dim=-1)
C_t = C_t.unsqueeze(dim=-1)
S_t = S_t.unsqueeze(dim=1)
fea_t = torch.mul(preds_T, torch.sqrt(S_t))
fea_t = torch.mul(fea_t, torch.sqrt(C_t))
fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))
fea_s = torch.mul(preds_S, torch.sqrt(S_t))
fea_s = torch.mul(fea_s, torch.sqrt(C_t))
fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))
fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(Mask_fg)
bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(Mask_bg)
return fg_loss, bg_loss
文中作者还提出了用L1 loss的attention损失,即式(10)
mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
def get_mask_loss(self, C_s, C_t, S_s, S_t):
mask_loss = torch.sum(torch.abs((C_s - C_t))) / len(C_s) + torch.sum(torch.abs((S_s - S_t))) / len(S_s)
return mask_loss
feature loss和attention loss一起组成的focal loss,为了弥补全局语义信息的缺失,作者又引入了全局蒸馏损失,其中用到了GcBlock,即式(12)
rela_loss = self.get_rela_loss(preds_S, preds_T)
def get_rela_loss(self, preds_S, preds_T):
loss_mse = nn.MSELoss(reduction='sum')
context_s = self.spatial_pool(preds_S, 0)
context_t = self.spatial_pool(preds_T, 1)
out_s = preds_S
out_t = preds_T
channel_add_s = self.channel_add_conv_s(context_s)
out_s = out_s + channel_add_s
channel_add_t = self.channel_add_conv_t(context_t)
out_t = out_t + channel_add_t
rela_loss = loss_mse(out_s, out_t) / len(out_s)
return rela_loss
def spatial_pool(self, x, in_type):
batch, channel, width, height = x.size()
input_x = x
# [N, C, H * W]
input_x = input_x.view(batch, channel, height * width)
# [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
if in_type == 0:
context_mask = self.conv_mask_s(x)
else:
context_mask = self.conv_mask_t(x)
# [N, 1, H * W]
context_mask = context_mask.view(batch, 1, height * width)
# [N, 1, H * W]
context_mask = F.softmax(context_mask, dim=2)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N, 1, C, 1]
context = torch.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)
return context
self.channel_add_conv_s = nn.Sequential(
nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
nn.LayerNorm([teacher_channels//2, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
self.channel_add_conv_t = nn.Sequential(
nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
nn.LayerNorm([teacher_channels//2, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))