paper: arxiv.org/pdf/2208.11821v2.pdf
repo link: KKallidromitis/r2o: PyTorch implementation of Refine and Represent: Region-to-Object Representation Learning. (github.com)
摘要:
在本文中提出了区域到对象表示学习(Region-to-Object Representation Learning,R2O),它在预测分割掩码和使用这些掩码预训练编码网络之间振荡。R2O通过对编码特征进行聚类来确定分割掩码。R2O然后通过执行区域到区域的相似性学习来预训练编码网络,其中编码网络获取图像的不同视图,并将分割的区域映射到相似的编码特征。
引言:
R2O是一个框架,通过在整个预训练过程中进化预测分割掩码来学习基于区域和以对象为中心的特征,并鼓励与预测掩码内容相对应的特征的表示相似性。R2O使用掩模预测模块来使用所学习的特征将多个小规模图像区域转换为较大的以对象为中心的区域。在预训练过程中,预测掩码被用于学习表示,这反过来又会在未来的时代中导致更准确的、以对象为中心的分割(见图1)。这使得R2O能够在不依赖于分割启发式的情况下训练以对象为中心的特征。
图1. R2O通过联合学习发现和表示区域和对象,将基于区域和以对象为中心的自监督学习统一起来。以前的以对象为中心的预训练方法使用现成的分割先验来学习以对象为核心的表示。相反,R2O通过在预测分割掩模(这取决于表示)和训练表示(使用分割掩模)之间进行振荡来学习以对象为中心的表示。具体而言,R2O汇集与分割区域相对应的特征(区域相似性损失),并在视图之间强制这些特征的表示相似性。然后,习得的表示会导致改进的、越来越以对象为中心的掩膜,进而导致改进的表示。
论文贡献总结如下:
•我们提出了R2O:一种自监督预训练方法,通过预测分割掩模、使用学习的图像特征和学习掩模内内容的表示来学习基于区域和以对象为中心的表示。
•我们为R2O引入了一个区域到对象的预训练课程,该课程从训练与简单图像区域相对应的局部特征开始,例如共享相似颜色值的相邻像素,然后逐渐学习以对象为中心的特征。
•与以前的方法相比,R2O预训练提高了ImageNet在MS COCO 1×(+0:2 APbb,+0:3 APmk)和2×(+0:9 APbb,+0.0:4 APmk)。在对COCO进行预训练后,R2O的PASCAL VOC和Cityscapes分别比早期方法高出+1:9mIOU和+2:4mIOU。此外,R2O在加州理工大学UCSD Birds 200-2011(CUB200-2011)上进一步提高了最先进的无监督分割性能+3:3mIOU,尽管没有对CUB200-202011进行微调。
相关工作:
基于区域和以对象为中心的自监督预训练
我们的工作扩展了基于区域的预训练,首先训练区域级特征,然后逐渐发现以对象为中心的分割,最终训练以对象为核心的特征。Odin[31]和SlotCon[62]等并行工作已经提出了对象级预训练方法,该方法不依赖于分割启发式,而是使用学习的特征来分割输入图像。与Odin类似,R2O对图像特征使用Kmeans聚类来帮助发现类对象区域。然而,与只关注训练对象级特征的Odin或SlotCon不同,R2O预训练涉及训练局部特征和对象级特征。
区域聚类
在这项工作中,我们使用聚类将小图像区域细化为以对象为中心的掩码。聚类在无监督语义分割中有着悠久的历史[1,13,14,17,21,33]。这一趋势在深度学习时代仍在继续。当前的方法利用聚类分配来学习空间和语义一致的嵌入[12,34,36,37]。例如,一种流行的方法[34, 37]是对像素嵌入进行聚类训练,并使用聚类分配作为伪标签。其他方法 [12, 36] 则使用相似性损失和特定任务数据增强来训练嵌入。我们的工作并不侧重于无监督语义分割任务,而是侧重于使用聚类来定位对象,以便学习可迁移的表征。我们的方法使用 K-means 聚类[39],因为它简单有效。
方法:
我们提出了R2O,这是一种自监督的预训练方法,可以实现区域级和以对象为中心的表示学习。在高水平上,R2O通过掩模预测模块将区域级图像先验转换为以对象为中心的掩模,并鼓励与所发现的掩模的内容相对应的特征的代表不变性(见图2)。
图2:R2O架构。R2O由两个相互依存的步骤组成:(1)掩模预测,其中使用学习的特征将小图像区域(由区域级先验给定)转换为以对象为中心的掩模;(2)表示学习,其中学习对象级表示。
步骤1通过计算由区域级先验给出的每个区域的特征并对这些区域级嵌入执行K-means聚类来产生以对象为中心的分割。
步骤2鼓励与每个视图中的掩码的内容相对应的特征的表示相似性。通过这个过程,我们发现通过第二步学习到的特征会导致越来越多的以对象为中心的分割,从而训练以对象为核心的表示。考虑到区域级特征在我们的掩模预测过程中的重要性,R2O采用了一种区域到对象的课程(备注:课程学习curriculum learning),该课程在掩模预测期间将聚类(K)的数量从非常高的值(K=128)逐渐减少到较低的值(K=4),正如我们所示,这使得能够在训练期间早期进行基于区域的预训练,并慢慢发展为以对象为中心的关联。我们将在线网络和目标网络分别表示为ON,TN。
图3. COCO预训练过程中掩码预测的可视化。在后期的预训练中,R2O预测的掩码能够发现以对象为中心的区域,即使是在多对象、以场景为中心的数据集中,如COCO。
结果
表1. 在ImageNet预训练后的PASCAL VOC和Cityscapes的COCO对象检测和实例分割以及语义分割方面的性能。所有方法都预训练了ResNet-50主干,并微调了用于COCO的Mask RCNN(R50-FPN)和用于PASCAL VOC和Cityscapes的FCN。*:表示并行工作。†:重新实施的结果。
表2. COCO预训练后PASCAL VOC和Cityscapes语义分割(mIOU)的性能。在对以场景为中心的数据进行预训练后,R2O在PASCAL VOC和Cityscapes上展示了最先进的语义分割传输性能。*:表示并行工作。†:重新实施的结果。
表3. CUB200-2011细分市场表现。结果报告了Caltech UCSD Birds 200-2011(CUB200-2011)测试集中图像前景背景分割的平均联合交集(mIOU)。有趣的是,在不微调ImageNet预训练编码器的情况下,我们能够优于现有方法[4,5,7,58]。†:重新实施的结果。
消融实验
表4. R2O组件的重要性。我们烧蚀了R2O的关键组件,即:掩模预测模块、先验(SLIC)的使用以及区域到对象调度。我们在ImageNet-100上评估了PASCAL VOC语义分割(mIOU)和k-NN分类(%)的性能。如图所示,使用Mask Prediction、SLIC Prior和Region to Object Schedule的组合对于获得最佳性能结果(62.3 mIOU)是必要的——删除任何组件都会影响性能至少−2:6 mIOU和−5:1%。
Code_R2O
class R2OModel(torch.nn.Module):
def __init__(self, config):
super().__init__()
# online network
self.online_network = EncoderwithProjection(config)
# target network
self.target_network = EncoderwithProjection(config)
# predictor
self.predictor = Predictor(config)
self.over_lap_mask = config['data'].get('over_lap_mask',True)
self._initializes_target_network()
self._fpn = None # not actual FPN, but pesudoname to get c4, TODO: Change the confusing name
self.slic_only = True
self.slic_segments = config['data']['slic_segments']
self.n_kmeans = config['data']['n_kmeans']
if self.n_kmeans < 9999:
self.kmeans = KMeans(self.n_kmeans,)
else:
self.kmeans = None
self.rank = config['rank']
self.agg = AgglomerativeClustering(affinity='cosine',linkage='average',distance_threshold=0.2,n_clusters=None)
self.agg_backup = AgglomerativeClustering(affinity='cosine',linkage='average',n_clusters=16)
@torch.no_grad()
def _initializes_target_network(self):
for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
param_k.data.copy_(param_q.data) # initialize
param_k.requires_grad = False # not update by gradient
@torch.no_grad()
def _update_target_network(self, mm):
"""Momentum update of target network"""
for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
param_k.data.mul_(mm).add_(1. - mm, param_q.data)
@torch.no_grad()
def _update_mask_network(self, mm):
"""Momentum update of maks network"""
for param_q, param_k in zip(self.online_network.encoder.parameters(), self.masknet.encoder.parameters()):
param_k.data.mul_(mm).add_(1. - mm, param_q.data)
@property
def fpn(self):
if self._fpn:
return self._fpn
else:
self._fpn = IntermediateLayerGetter(self.target_network.encoder, return_layers={'7':'out','6':'c4'})
return self._fpn
def handle_flip(self,aligned_mask,flip):
'''
aligned_mask: B X C X 7 X 7
flip: B
'''
_,c,h,w = aligned_mask.shape
b = len(flip)
flip = flip.repeat(c*h*w).reshape(c,h,w,b) # C X H X W X B
flip = flip.permute(3,0,1,2)
flipped = aligned_mask.flip(-1)
out = torch.where(flip==1,flipped,aligned_mask)
return out
def get_label_map(self,masks):
#
b,c,h,w = masks.shape
batch_data = masks.permute(0,2,3,1).reshape(b,h*w,32).detach().cpu()
labels = []
for data in batch_data:
agg = self.agg.fit(data)
if np.max(agg.labels_)>15:
agg = self.agg_backup.fit(data)
label = agg.labels_.reshape(h,w)
labels.append(label)
labels = np.stack(labels)
labels = torch.LongTensor(labels).cuda()
return labels
def do_kmeans(self,raw_image,slic_mask):
b = raw_image.shape[0]
feats = self.fpn(raw_image)['c4']
super_pixel = to_binary_mask(slic_mask,-1,resize_to=(14,14))
pooled, _ = maskpool(super_pixel,feats) #pooled B X 100 X d_emb
super_pixel_pooled = pooled.view(-1,1024).detach()
super_pixel_pooled_large = super_pixel_pooled
labels = self.kmeans.fit_transform(F.normalize(super_pixel_pooled_large,dim=-1)) # B X 100
labels = labels.view(b,-1)
raw_mask_target = torch.einsum('bchw,bc->bchw',to_binary_mask(slic_mask,-1,(56,56)) ,labels).sum(1).long().detach()
raw_masks = torch.ones(b,1,0,0).cuda() # to make logging happy
converted_idx = raw_mask_target
converted_idx_b = to_binary_mask(converted_idx,self.n_kmeans)
return converted_idx_b,converted_idx
def forward(self, view1, view2, mm,raw_image,roi_t,slic_mask,clustering_k=64):
im_size = view1.shape[-1]
b = view1.shape[0] # batch size
assert im_size == 224
idx = torch.LongTensor([1,0,3,2]).cuda()
# reset k means if necessary
if self.n_kmeans != clustering_k:
self.n_kmeans = clustering_k
if self.n_kmeans < 9999:
self.kmeans = KMeans(self.n_kmeans,)
else:
self.kmeans = None
# Get spanning view embeddings
with torch.no_grad():
if self.n_kmeans < 9999:
converted_idx_b,converted_idx = self.do_kmeans(raw_image,slic_mask) # B X C X 56 X 56, B X 56 X 56
else:
converted_idx_b = to_binary_mask(slic_mask,-1,(56,56))
converted_idx = torch.argmax(converted_idx_b,1)
raw_masks = torch.ones(b,1,0,0).cuda()
raw_mask_target = converted_idx
mask_dim = 56
rois_1 = [roi_t[j,:1,:4].index_select(-1, idx)*mask_dim for j in range(roi_t.shape[0])]
rois_2 = [roi_t[j,1:2,:4].index_select(-1, idx)*mask_dim for j in range(roi_t.shape[0])]
flip_1 = roi_t[:,0,4]
flip_2 = roi_t[:,1,4]
aligned_1 = self.handle_flip(ops.roi_align(converted_idx_b,rois_1,7),flip_1) # mask output is B X 16 X 7 X 7
aligned_2 = self.handle_flip(ops.roi_align(converted_idx_b,rois_2,7),flip_2) # mask output is B X 16 X 7 X 7
mask_b,mask_c,h,w =aligned_1.shape
aligned_1 = aligned_1.reshape(mask_b,mask_c,h*w).detach()
aligned_2 = aligned_2.reshape(mask_b,mask_c,h*w).detach()
mask_ids = None
masks = torch.cat([aligned_1, aligned_2])
masks_inv = torch.cat([aligned_2, aligned_1])
num_segs = torch.FloatTensor([x.unique().shape[0] for x in converted_idx]).mean()
q,pinds = self.predictor(*self.online_network(torch.cat([view1, view2], dim=0),masks.to('cuda'),mask_ids,mask_ids))
# target network forward
with torch.no_grad():
self._update_target_network(mm)
target_z, tinds = self.target_network(torch.cat([view2, view1], dim=0),masks_inv.to('cuda'),mask_ids,mask_ids)
target_z = target_z.detach().clone()
return q, target_z, pinds, tinds,masks,raw_masks,raw_mask_target,num_segs,converted_idx