R2O语义分割: Refine and Represent: Region-to-Object Representation Learning

news2024/12/23 11:00:48

paper: arxiv.org/pdf/2208.11821v2.pdf

repo link: KKallidromitis/r2o: PyTorch implementation of Refine and Represent: Region-to-Object Representation Learning. (github.com)

摘要:

在本文中提出了区域到对象表示学习(Region-to-Object Representation Learning,R2O),它在预测分割掩码和使用这些掩码预训练编码网络之间振荡。R2O通过对编码特征进行聚类来确定分割掩码。R2O然后通过执行区域到区域的相似性学习来预训练编码网络,其中编码网络获取图像的不同视图,并将分割的区域映射到相似的编码特征。

引言:

R2O是一个框架,通过在整个预训练过程中进化预测分割掩码来学习基于区域和以对象为中心的特征,并鼓励与预测掩码内容相对应的特征的表示相似性。R2O使用掩模预测模块来使用所学习的特征将多个小规模图像区域转换为较大的以对象为中心的区域。在预训练过程中,预测掩码被用于学习表示,这反过来又会在未来的时代中导致更准确的、以对象为中心的分割(见图1)。这使得R2O能够在不依赖于分割启发式的情况下训练以对象为中心的特征。

图1. R2O通过联合学习发现和表示区域和对象,将基于区域和以对象为中心的自监督学习统一起来。以前的以对象为中心的预训练方法使用现成的分割先验来学习以对象为核心的表示。相反,R2O通过在预测分割掩模(这取决于表示)和训练表示(使用分割掩模)之间进行振荡来学习以对象为中心的表示。具体而言,R2O汇集与分割区域相对应的特征(区域相似性损失),并在视图之间强制这些特征的表示相似性。然后,习得的表示会导致改进的、越来越以对象为中心的掩膜,进而导致改进的表示。

论文贡献总结如下:

我们提出了R2O:一种自监督预训练方法,通过预测分割掩模、使用学习的图像特征和学习掩模内内容的表示来学习基于区域和以对象为中心的表示。

•我们为R2O引入了一个区域到对象的预训练课程,该课程从训练与简单图像区域相对应的局部特征开始,例如共享相似颜色值的相邻像素,然后逐渐学习以对象为中心的特征。

•与以前的方法相比,R2O预训练提高了ImageNet在MS COCO 1×(+0:2 APbb,+0:3 APmk)和2×(+0:9 APbb,+0.0:4 APmk)。在对COCO进行预训练后,R2O的PASCAL VOC和Cityscapes分别比早期方法高出+1:9mIOU和+2:4mIOU。此外,R2O在加州理工大学UCSD Birds 200-2011(CUB200-2011)上进一步提高了最先进的无监督分割性能+3:3mIOU,尽管没有对CUB200-202011进行微调。

相关工作:

基于区域和以对象为中心的自监督预训练

我们的工作扩展了基于区域的预训练,首先训练区域级特征,然后逐渐发现以对象为中心的分割,最终训练以对象为核心的特征。Odin[31]和SlotCon[62]等并行工作已经提出了对象级预训练方法,该方法不依赖于分割启发式,而是使用学习的特征来分割输入图像。与Odin类似,R2O对图像特征使用Kmeans聚类来帮助发现类对象区域。然而,与只关注训练对象级特征的Odin或SlotCon不同,R2O预训练涉及训练局部特征和对象级特征。

区域聚类

在这项工作中,我们使用聚类将小图像区域细化为以对象为中心的掩码。聚类在无监督语义分割中有着悠久的历史[1,13,14,17,21,33]。这一趋势在深度学习时代仍在继续。当前的方法利用聚类分配来学习空间和语义一致的嵌入[12,34,36,37]。例如,一种流行的方法[34, 37]是对像素嵌入进行聚类训练,并使用聚类分配作为伪标签。其他方法 [12, 36] 则使用相似性损失和特定任务数据增强来训练嵌入。我们的工作并不侧重于无监督语义分割任务,而是侧重于使用聚类来定位对象,以便学习可迁移的表征。我们的方法使用 K-means 聚类[39],因为它简单有效。

方法:

我们提出了R2O,这是一种自监督的预训练方法,可以实现区域级和以对象为中心的表示学习。在高水平上,R2O通过掩模预测模块将区域级图像先验转换为以对象为中心的掩模,并鼓励与所发现的掩模的内容相对应的特征的代表不变性(见图2)。

图2:R2O架构。R2O由两个相互依存的步骤组成:(1)掩模预测,其中使用学习的特征将小图像区域(由区域级先验给定)转换为以对象为中心的掩模;(2)表示学习,其中学习对象级表示。

步骤1通过计算由区域级先验给出的每个区域的特征并对这些区域级嵌入执行K-means聚类来产生以对象为中心的分割。

步骤2鼓励与每个视图中的掩码的内容相对应的特征的表示相似性。通过这个过程,我们发现通过第二步学习到的特征会导致越来越多的以对象为中心的分割,从而训练以对象为核心的表示。考虑到区域级特征在我们的掩模预测过程中的重要性,R2O采用了一种区域到对象的课程(备注:课程学习curriculum learning),该课程在掩模预测期间将聚类(K)的数量从非常高的值(K=128)逐渐减少到较低的值(K=4),正如我们所示,这使得能够在训练期间早期进行基于区域的预训练,并慢慢发展为以对象为中心的关联。我们将在线网络和目标网络分别表示为ON,TN。

图3. COCO预训练过程中掩码预测的可视化。在后期的预训练中,R2O预测的掩码能够发现以对象为中心的区域,即使是在多对象、以场景为中心的数据集中,如COCO。

结果

表1. 在ImageNet预训练后的PASCAL VOC和Cityscapes的COCO对象检测和实例分割以及语义分割方面的性能。所有方法都预训练了ResNet-50主干,并微调了用于COCO的Mask RCNN(R50-FPN)和用于PASCAL VOC和Cityscapes的FCN。*:表示并行工作。†:重新实施的结果。

表2. COCO预训练后PASCAL VOC和Cityscapes语义分割(mIOU)的性能。在对以场景为中心的数据进行预训练后,R2O在PASCAL VOC和Cityscapes上展示了最先进的语义分割传输性能。*:表示并行工作。†:重新实施的结果。

表3. CUB200-2011细分市场表现。结果报告了Caltech UCSD Birds 200-2011(CUB200-2011)测试集中图像前景背景分割的平均联合交集(mIOU)。有趣的是,在不微调ImageNet预训练编码器的情况下,我们能够优于现有方法[4,5,7,58]。†:重新实施的结果。

消融实验

表4. R2O组件的重要性。我们烧蚀了R2O的关键组件,即:掩模预测模块、先验(SLIC)的使用以及区域到对象调度。我们在ImageNet-100上评估了PASCAL VOC语义分割(mIOU)和k-NN分类(%)的性能。如图所示,使用Mask Prediction、SLIC Prior和Region to Object Schedule的组合对于获得最佳性能结果(62.3 mIOU)是必要的——删除任何组件都会影响性能至少−2:6 mIOU和−5:1%。

Code_R2O

class R2OModel(torch.nn.Module):
    def __init__(self, config):
        super().__init__()

        # online network
        self.online_network = EncoderwithProjection(config)

        # target network
        self.target_network = EncoderwithProjection(config)
        
        # predictor
        self.predictor = Predictor(config)
        self.over_lap_mask = config['data'].get('over_lap_mask',True)
        self._initializes_target_network()
        self._fpn = None # not actual FPN, but pesudoname to get c4, TODO: Change the confusing name
        self.slic_only = True
        self.slic_segments = config['data']['slic_segments']
        self.n_kmeans = config['data']['n_kmeans']
        if self.n_kmeans < 9999:
            self.kmeans = KMeans(self.n_kmeans,)
        else:
            self.kmeans = None
        self.rank = config['rank']
        self.agg = AgglomerativeClustering(affinity='cosine',linkage='average',distance_threshold=0.2,n_clusters=None)
        self.agg_backup = AgglomerativeClustering(affinity='cosine',linkage='average',n_clusters=16)

    @torch.no_grad()
    def _initializes_target_network(self):
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False     # not update by gradient

    @torch.no_grad()
    def _update_target_network(self, mm):
        """Momentum update of target network"""
        for param_q, param_k in zip(self.online_network.parameters(), self.target_network.parameters()):
            param_k.data.mul_(mm).add_(1. - mm, param_q.data)

    @torch.no_grad()
    def _update_mask_network(self, mm):
        """Momentum update of maks network"""
        for param_q, param_k in zip(self.online_network.encoder.parameters(), self.masknet.encoder.parameters()):
            param_k.data.mul_(mm).add_(1. - mm, param_q.data)

    @property
    def fpn(self):
        if self._fpn:
            return self._fpn
        else:
            self._fpn = IntermediateLayerGetter(self.target_network.encoder, return_layers={'7':'out','6':'c4'})
            return self._fpn
            
    def handle_flip(self,aligned_mask,flip):
        '''
        aligned_mask: B X C X 7 X 7
        flip: B 
        '''
        _,c,h,w = aligned_mask.shape
        b = len(flip)
        flip = flip.repeat(c*h*w).reshape(c,h,w,b) # C X H X W X B
        flip = flip.permute(3,0,1,2)
        flipped = aligned_mask.flip(-1)
        out = torch.where(flip==1,flipped,aligned_mask)
        return out

    def get_label_map(self,masks):
        #
        b,c,h,w = masks.shape
        batch_data = masks.permute(0,2,3,1).reshape(b,h*w,32).detach().cpu()
        labels = []
        for data in batch_data:
            agg = self.agg.fit(data)
            if np.max(agg.labels_)>15:
                agg = self.agg_backup.fit(data)
            label = agg.labels_.reshape(h,w)
            labels.append(label)
        labels = np.stack(labels)
        labels = torch.LongTensor(labels).cuda()
        return labels

    def do_kmeans(self,raw_image,slic_mask):
        b = raw_image.shape[0]
        feats = self.fpn(raw_image)['c4']
        super_pixel = to_binary_mask(slic_mask,-1,resize_to=(14,14))
        pooled, _ = maskpool(super_pixel,feats) #pooled B X 100 X d_emb
        super_pixel_pooled = pooled.view(-1,1024).detach()
        super_pixel_pooled_large = super_pixel_pooled
        labels = self.kmeans.fit_transform(F.normalize(super_pixel_pooled_large,dim=-1)) # B X 100
        labels = labels.view(b,-1)
        raw_mask_target =  torch.einsum('bchw,bc->bchw',to_binary_mask(slic_mask,-1,(56,56)) ,labels).sum(1).long().detach()

        raw_masks = torch.ones(b,1,0,0).cuda() # to make logging happy
        converted_idx = raw_mask_target
        converted_idx_b = to_binary_mask(converted_idx,self.n_kmeans)
        return converted_idx_b,converted_idx
    
    def forward(self, view1, view2, mm,raw_image,roi_t,slic_mask,clustering_k=64):
        im_size = view1.shape[-1]
        b = view1.shape[0] # batch size
        assert im_size == 224
        idx = torch.LongTensor([1,0,3,2]).cuda()
        # reset k means if necessary
        if self.n_kmeans != clustering_k:
            self.n_kmeans = clustering_k
            if self.n_kmeans < 9999:
                self.kmeans = KMeans(self.n_kmeans,)
            else:
                self.kmeans = None
        # Get spanning view embeddings
        with torch.no_grad():
            if self.n_kmeans < 9999:
                converted_idx_b,converted_idx = self.do_kmeans(raw_image,slic_mask) # B X C X 56 X 56, B X 56 X 56
            else:
                converted_idx_b = to_binary_mask(slic_mask,-1,(56,56))
                converted_idx = torch.argmax(converted_idx_b,1)
            raw_masks = torch.ones(b,1,0,0).cuda()
            raw_mask_target = converted_idx
            mask_dim = 56

            rois_1 = [roi_t[j,:1,:4].index_select(-1, idx)*mask_dim for j in range(roi_t.shape[0])]
            rois_2 = [roi_t[j,1:2,:4].index_select(-1, idx)*mask_dim for j in range(roi_t.shape[0])]
            flip_1 = roi_t[:,0,4]
            flip_2 = roi_t[:,1,4]
            aligned_1 = self.handle_flip(ops.roi_align(converted_idx_b,rois_1,7),flip_1) # mask output is B X 16 X 7 X 7
            aligned_2 = self.handle_flip(ops.roi_align(converted_idx_b,rois_2,7),flip_2) # mask output is B X 16 X 7 X 7
            mask_b,mask_c,h,w =aligned_1.shape
            aligned_1 = aligned_1.reshape(mask_b,mask_c,h*w).detach()
            aligned_2 = aligned_2.reshape(mask_b,mask_c,h*w).detach()
        
        mask_ids = None
        masks = torch.cat([aligned_1, aligned_2])
        masks_inv = torch.cat([aligned_2, aligned_1])
        num_segs = torch.FloatTensor([x.unique().shape[0] for x in converted_idx]).mean()
        q,pinds = self.predictor(*self.online_network(torch.cat([view1, view2], dim=0),masks.to('cuda'),mask_ids,mask_ids))
        # target network forward
        with torch.no_grad():
            self._update_target_network(mm)
            target_z, tinds = self.target_network(torch.cat([view2, view1], dim=0),masks_inv.to('cuda'),mask_ids,mask_ids)
            target_z = target_z.detach().clone()
        

        return q, target_z, pinds, tinds,masks,raw_masks,raw_mask_target,num_segs,converted_idx

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1308019.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows10下jdk安装

文章目录 windows10下jdk安装说明what安装包下载执行安装包验证是否安装成功 windows10下jdk安装 说明 操作系统&#xff1a;windows10 版本&#xff1a;1.8 what JDK(Java Development Kit) 是 Java 语言的软件开发工具包 安装包下载 https://www.oracle.com/java/techn…

4.11 构建onnx结构模型-Clip

前言 构建onnx方式通常有两种&#xff1a; 1、通过代码转换成onnx结构&#xff0c;比如pytorch —> onnx 2、通过onnx 自定义结点&#xff0c;图&#xff0c;生成onnx结构 本文主要是简单学习和使用两种不同onnx结构&#xff0c; 下面以 Clip 结点进行分析 方式 方法一…

下一站 Gen AI 城市巡展指南来了!“码”上出发,Let‘s 构!

亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术&#xff0c;观点&#xff0c;和项目&#xff0c;并将中国优秀开发者或技术推荐给全球云社区。如果你还没有关注/收藏…

【Java用法】Hutool树结构工具-TreeUtil快速构建树形结构的两种方式 + 数据排序

Hutool树结构工具-TreeUtil快速构建树形结构的两种方式 数据排序 一、业务场景二、Hutool官网树结构工具2.1 介绍2.2 使用2.2.1 定义结构2.2.2 构建Tree2.2.3 自定义字段名 2.3 说明 三、具体的使用场景3.1 实现的效果3.2 业务代码3.3 实现自定义字段的排序 四、踩过的坑4.1 坑…

Android studio如何安装ai辅助工具

引言 在没有翻墙的情况下&#xff0c;即单纯在公司打工&#xff0c;经测试&#xff0c;大部分ai工具都是使用不了的&#xff08;比如各种gpt,codeium,copilot&#xff09;&#xff0c;根本登录不了账号&#xff0c;但有一个国内的codegeex是可以使用的&#xff0c;在这里不对各…

DS冲刺整理做题定理(二)线性表、栈、队列的套路

继续归纳套路&#xff0c;做题练习非常重要&#xff0c;王道的基本上足够了&#xff0c;学有余力可以做一下数据结构1800~ DS冲刺整理做题定理&#xff08;一&#xff09;二叉树专题https://blog.csdn.net/jsl123x/article/details/134949736?spm1001.2014.3001.5501 目录 一…

Spring Boot--Freemarker渲染技术+实际案例

目录 Freemarker 1.1.什么是Freemarker 1.2.Freemarker模板组成部分 1.3.优点 FreeMarker常见的方法&#xff1a; 2.2.2.数值 2.2.3.布尔值 2.2.4.日期 2.3.常见指令 2.3.1.处理不存在的值 assign 2.3.4.list 2.3.5.include SpringBoot整合Freemarker Freemarker…

回归预测 | MATLAB实现CHOA-BiLSTM黑猩猩优化算法优化双向长短期记忆网络回归预测 (多指标,多图)

回归预测 | MATLAB实现CHOA-BiLSTM黑猩猩优化算法优化双向长短期记忆网络回归预测 &#xff08;多指标&#xff0c;多图&#xff09; 目录 回归预测 | MATLAB实现CHOA-BiLSTM黑猩猩优化算法优化双向长短期记忆网络回归预测 &#xff08;多指标&#xff0c;多图&#xff09;效果…

数字孪生 5G时代的重要应用场景 - 读书笔记

作者&#xff1a;陈根 第1章&#xff1a;数字孪生概述 数字孪生&#xff1a;对物理世界&#xff0c;构建数字化实体&#xff0c;实现了解、分析和优化集成技术&#xff1a;AI、机器学习、大数据分析构成&#xff1a;传感器、数据、集成、分析、促动器&#xff08;可以人工干预…

【python-wrf】绘制wrf中的土地利用报错内容及其解决方法

从该代码处绘制wrf中的土地利用报错内容及其解决方法 1.报错内容&#xff1a; 微信公众平台 (qq.com)https://mp.weixin.qq.com/s/Cn0vhvfroVADPnT237LXNw --------------------------------------------------------------------------- AttributeError …

支付功能如何测试?

一、测试思维 要分析测试点之前&#xff0c;我们先来梳理一下测试思维。总结来说&#xff0c;任何事物的测试思路都可以总结如下&#xff1a; 第一步&#xff1a;梳理产品的核心业务流程&#xff1a;明白这是个什么项目&#xff0c;实现了什么业务&#xff0c;以及是怎么实现…

青少年CTF-Crypto(Morse code/ASCII和凯撒)

FLAG&#xff1a;你这一生到底想干嘛 专研方向: Web安全 &#xff0c;Md5碰撞 每日emo&#xff1a;不要因为别人都交卷了&#xff0c;就乱选答案 文章目录 1.Morse code2、ASCII和凯撒的约定 1.Morse code 题目提示摩尔斯电码&#xff0c;这个是给的附件 直接用摩尔斯解密&am…

java实现局域网内视频投屏播放(四)投屏实现

代码链接​​​​​​​​​​​​​​​​​​​​​ 设备发现 上一篇文章说过&#xff0c;设备的发现有两种情况&#xff0c;主动和被动&#xff0c;下面我们来用java实现这两种模式 主动发现 构建一个UDP请求发送到239.255.255.250:1900获取设备信息&#xff0c;UDP包的…

记录汇川:套接字TCP通信-梯形图

H5U集成一路以太网接口。使用AutoShop可以通过以太网方便、快捷对H5U进行行监控、下载、上载以及调试等操作。同时也可以通过以太网与网络中的其他设备进行数据交互。H5U集成了Modbus-TCP协议&#xff0c;包括服务器与客户端。可轻松实现与支持Modbus-TCP的设备进行通讯与数据交…

redis-学习笔记(Jedis string 简单命令)

mset & mget 批量设置和获取键值对 可以看出,参数都是可变参数 (就是说, 可以写任意个) 代码演示 getrange & setrange 获取和设置 string 类型中 某一区间的值 代码演示 append 往字符串的末尾拼接字符串 代码演示 incr & decr 如果 string 中为数字的话, 可以进行…

最新版ES8的client API操作 Elasticsearch Java API client 8.0

作者&#xff1a;ChenZhen 本人不常看网站消息&#xff0c;有问题通过下面的方式联系&#xff1a; 邮箱&#xff1a;1583296383qq.comvx: ChenZhen_7 我的个人博客地址&#xff1a;https://www.chenzhen.space/&#x1f310; 版权&#xff1a;本文为博主的原创文章&#xff…

大数据机器学习与深度学习——回归模型评估

大数据机器学习与深度学习——回归模型评估 回归模型的性能的评价指标主要有&#xff1a;MAE(平均绝对误差)、MSE(平均平方误差)、RMSE(平方根误差)、R2_score。但是当量纲不同时&#xff0c;RMSE、MAE、MSE难以衡量模型效果好坏&#xff0c;这就需要用到R2_score。 平均绝对…

怎么去评估数据资产?一个典型的政务数据资产评估案例

据中国资产评估协会《数据资产评估指导意见》&#xff0c;数据资产评估主要是三个方法&#xff1a;市场法、成本法和收益法。之前小亿和大家分享了数据资产评估方法以及价值发挥的路径&#xff0c;今天结合一个案例来具体讲解一下怎么去评估数据资产。 这个案例是一个典型的一个…

【LeetCode刷题】-- 165.比较版本号

165.比较版本号 方法&#xff1a;使用双指针 class Solution {public int compareVersion(String version1, String version2) {//使用双指针int n version1.length(),m version2.length();int i 0,j 0;while(i<n || j <m){int x 0;for(; i < n && vers…

做数据分析为何要学统计学(6)——什么问题适合使用卡方检验?

卡方检验作为一种非常著名的非参数检验方法&#xff08;不受总体分布因素的限制&#xff09;&#xff0c;在工程试验、临床试验、社会调查等领域被广泛应用。但是也正是因为使用的便捷性&#xff0c;造成时常被误用。本文参阅相关的文献&#xff0c;对卡方检验的适用性进行粗浅…