PointRend 原理与代码解析

news2024/11/30 2:25:57

paper:PointRend: Image Segmentation as Rendering

code1:https://github.com/facebookresearch/detectron2/tree/main/projects/PointRend 

code2:https://github.com/open-mmlab/mmsegmentation/tree/master/configs/point_rend 

创新点

本文的中心思想是将图像分割视为一个渲染问题,具体做法是使用subvision策略来自适应地选择一个非均匀的点集来计算标签。说人话就是针对图像分割中边缘分割不准的情况,提出了一种新的优化方法,具体就是选取网络输出特征图上少数难分像素点,这些难分点大概率分布在物体边缘附近,然后加了一个小的子网络去学习这些难分点的特征,最终提升模型在物体轮廓边缘处的分割精度。

方法介绍

PointRend主要包含三个部分

  1. a point selection strategy
  2. a point-wise feature representation
  3. a point head

Point Selection

推理阶段 对于网络的输出特征图,挑选 \(N\) 个最难分最不确定的点(比如对于只有一类前景的分割问题,概率越接近0.5说明越难分),计算这些难分点的point-wise feature representation,然后根据这些特征去预测这些点的类别。这 \(N\) 个点之外的其它易分点就直接以coarset level即模型输出的最小feature map的预测结果为准。然后上采样,重复该步骤。

训练阶段 推理阶段采用的逐步上采样然后每次都选择 \(N\) 个最难分点的方式不适合训练阶段,因此训练阶段采用non-iterative的方法,首先基于均匀分布随机过采样 \(kN(k>1)\) 个点,然后从中选出 \(\beta N(\beta \in[0,1])\) 个最不确定的点,然后再从剩下的点中基于均匀分布挑出 \((1-\beta)N\) 个点,然后接一个point head subnetwork去学习这 \(N\) 个点的特征,point head的预测和损失计算也只针对这 \(N\) 个点。

Point-wise Representation

Fine-grained features. 为了让PointRend学习精细的分割细节,要从CNN的特征图中提取每个采样点的特征向量,并且要采用浅层的分辨率大的包含丰富细节特征的特征图。下面的实现中采用的是neck输出中分辨率最大的特征图。

Coarse prediction features. 细粒度特征包含了丰富的细节特征,但只有细粒度特征还不够,一是因为当一个点被两个物体的bounding box同时覆盖时,这两个物体在这一点有相同的细粒度特征,但这个点只能被预测为其中一个物体,因为对于实例分割,还需要额外的region-specific特征。二是因为细粒度特征只包含了低维信息,更多的上下文和语义特征可能会有帮助,这对实例分割和语义分割都有帮助。下面的实现中采用的是fpn head的最终预测输出。

将fine-grained特征和coarse prediction特征拼接到一起,就得到了这些采样点的最终特征表示。

Point head

在得到了采样点的特征表示后,PointRend采用了一个多层感知器(MLP)来进行点分割预测,预测每个点的分割类别后,根据对应的标签计算损失进行训练。

代码解析

接下来以mmsegmentation中的PointRend实现为例,讲解一下具体实现。

只有一类前景。假设batch_size=4,input_shape=(4, 3, 480, 480)。backbone=ResNetV1c,backbone的输出为[(4, 256, 120, 120), (4, 512, 60, 60), (4, 1024, 30, 30), (4, 2048, 15, 15)]。neck=FPN,neck后的输出为[(4, 256, 120, 120), (4, 512, 60, 60), (4, 1024, 30, 30), (4, 2048, 15, 15)]。pointrend中有两个head,因此用cascade_encoder_decoder将两个head串联起来。第一个head是FPN head,借鉴了Panoptic Feature Pyramid Networks中的Semantic FPN,这里就不具体介绍了,输出为(4, 2, 120, 120),然后计算这个head的损失,loss采用的交叉熵损失。

第二个head是point_head,point_head的输入包括neck的最大分辨率输出(4, 256, 120, 120),以及FPN head的输出(4, 2, 120, 120)。

选择难分点,这里prev_output是fpn head的输出

with torch.no_grad():
    points = self.get_points_train(
        prev_output, calculate_uncertainty, cfg=train_cfg)  # (4,2,120,120) -> (4,2048,2)

评价难分程度的函数如下,具体就是计算每个点top1得分和top2得分的差,差越小说明越难分。注意这里计算的是top2-top1,值为负,所以值越大说明越难分。

def calculate_uncertainty(seg_logits):
    top2_scores = torch.topk(seg_logits, k=2, dim=1)[0]  # (4,2,6144) -> (4,2,6144)
    return (top2_scores[:, 1] - top2_scores[:, 0]).unsqueeze(1)  # (4,6144) -> (4,1,6144)

具体选择用于训练的难分点实现如下,其中随机采样是通过mmcv中的函数point_sample实现的,而point_sample中首先将随机采样的坐标point_coords由[0, 1]转化到[-1, 1]区间,然后通过F.grid_sample根据归一化的坐标位置进行插值采样,F.grid_sample的用法见F.grid_sample 用法解读_00000cj的博客-CSDN博客。在训练过程中,如上文所述,point selection阶段 \(N=2048,k=3,\beta =0.75\)。实际挑出训练的点有2048个,首先随机采样3x2048=6144个点,然后挑选出最难分的0.75x2048=1536个点,剩下的2048-1536=512个数随机挑选。

def get_points_train(self, seg_logits, uncertainty_func, cfg):
    """Sample points for training.

    Sample points in [0, 1] x [0, 1] coordinate space based on their
    uncertainty. The uncertainties are calculated for each point using
    'uncertainty_func' function that takes point's logit prediction as
    input.

    Args:
        seg_logits (Tensor): Semantic segmentation logits, shape (
            batch_size, num_classes, height, width).
        uncertainty_func (func): uncertainty calculation function.
        cfg (dict): Training config of point head.

    Returns:
        point_coords (Tensor): A tensor of shape (batch_size, num_points,
            2) that contains the coordinates of ``num_points`` sampled
            points.
    """
    num_points = cfg.num_points  # 2048
    oversample_ratio = cfg.oversample_ratio  # 3
    importance_sample_ratio = cfg.importance_sample_ratio  # 0.75
    assert oversample_ratio >= 1
    assert 0 <= importance_sample_ratio <= 1
    batch_size = seg_logits.shape[0]  # (4,2,120,120)
    num_sampled = int(num_points * oversample_ratio)  # 2048x3=6144
    point_coords = torch.rand(
        batch_size, num_sampled, 2, device=seg_logits.device)  # (4,6144,2)
    point_logits = point_sample(seg_logits, point_coords)  # (4,2,6144)

    # It is crucial to calculate uncertainty based on the sampled
    # prediction value for the points. Calculating uncertainties of the
    # coarse predictions first and sampling them for points leads to
    # incorrect results.  To illustrate this: assume uncertainty func(
    # logits)=-abs(logits), a sampled point between two coarse
    # predictions with -1 and 1 logits has 0 logits, and therefore 0
    # uncertainty value. However, if we calculate uncertainties for the
    # coarse predictions first, both will have -1 uncertainty,
    # and sampled point will get -1 uncertainty.
    point_uncertainties = uncertainty_func(point_logits)  # (4,1,6144)
    num_uncertain_points = int(importance_sample_ratio * num_points)  # 0.75x2048=1536
    num_random_points = num_points - num_uncertain_points  # 512
    idx = torch.topk(
        point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]  # (4,1536)
    shift = num_sampled * torch.arange(
        batch_size, dtype=torch.long, device=seg_logits.device)  # (4,), (0,6144,12288,18432)
    idx += shift[:, None]  # (4,1536) += (4,1) -> (4,1536)
    # (4,6144,2)->(24576,2)[(4,1536)->(6144), :] -> (6144,2) -> (4,1536,2)
    point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
        batch_size, num_uncertain_points, 2)
    if num_random_points > 0:
        rand_point_coords = torch.rand(
            batch_size, num_random_points, 2, device=seg_logits.device)
        point_coords = torch.cat((point_coords, rand_point_coords), dim=1)  # (4,2048,2)
    return point_coords

在得到待训练点的坐标后,分别从neck最大分辨率输出(4, 256, 120, 120)和FPN head的预测结果(4, 2, 120, 120)上插值得到对应的fine feature和coarse feature。其中内部实现还是通过point_sample。

fine_grained_point_feats = self._get_fine_grained_point_feats(x, points)  # (4,256,2048)
coarse_point_feats = self._get_coarse_point_feats(prev_output, points)  # (4,2,2048)

然后将fine feature和coarse feature拼接起来,最终的point head是一个MLP,层数为3,最终再经过一个卷积层得到这2048个点的分类结果。

point_logits = self.forward(fine_grained_point_feats, coarse_point_feats)  # (4,2,2048)

def forward(self, fine_grained_point_feats, coarse_point_feats):
    x = torch.cat([fine_grained_point_feats, coarse_point_feats], dim=1)  # (4,258,2048)
    for fc in self.fcs:
        x = fc(x)
        if self.coarse_pred_each_layer:  # True
            x = torch.cat((x, coarse_point_feats), dim=1)  # (4,258,2048)
    return self.cls_seg(x)  # (4,2,2048)

在得到难分点的预测结果后,因为采样点的坐标不是整数,特征是从feature map上插值得到的,对应的标签也要插值得到,只不过特征插值时是采用的bilinear,而标签采用的是nearest。point head的loss也是交叉熵损失。 

point_label = point_sample(
    gt_semantic_seg.float(),  # (4,1,480,480)
    points,
    mode='nearest',
    align_corners=self.align_corners)  # (4,1,2048)
point_label = point_label.squeeze(1).long()  # (4,2048)
losses = self.losses(point_logits, point_label)

实验结果 

下图是一些示例,可以看出PointRend对边缘的分割更加精细。

因为只采样少数难分点,而对于大部分易分点比如远离图像边缘的区域,coarse prediction就足够了,因此增加point head后增加的计算量有限,如下 

下面分别是在DeeplabV3和SemanticFPN中加入PointRend,精度都得到了提升 

分割的标注通常不够精确,因此实际效果的提升可能比上表中的更大。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/48977.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[附源码]Python计算机毕业设计Django的花店售卖系统的设计与实现

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

[附源码]SSM计算机毕业设计学习资源共享与在线学习系统JAVA

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

全局路由拦截、局部路由拦截

引入&#xff1a; 看下面这个效果&#xff1a; 每次我们在点击一个功能时&#xff0c;它就会跳转到登录页面&#xff0c;意思就是让我们先登录&#xff0c;登录之后再进行功能操作&#xff1b;但是如果我们登录了&#xff0c;它就不会跳转&#xff0c;这是什么原理呢&#x…

vue3Blog首页基础布局样式规划

思考:我们已经安装了一个ant-design-vue的组件库,是否还可以安装其他的UI组件库混合使用? 答案是可以的,比如这个组件库没有要用到的组件,但另外一个组件有,我们完全可以再安装,单独将某一个使用到的组件引入即可,当项目打包的时候也不会说把所有的安装的都打包进去,只…

VS Code + Vue 开发环境搭建

1、下载并安装 Visual Studio Code 2019 2、Visual Studio Code 2019安装成功后&#xff0c;打开VS Code 工具点击左侧【扩展】菜单&#xff0c;在搜索栏中输入 Chinese 查找中文语言汉化包插件下载安装&#xff0c;然后重启VS Code 3、点击左侧【扩展】菜单&#xff0c;在搜…

【JS】数据结构之图结构

文章目录基本概念图的实现基本概念 什么是图&#xff1f; 图是一种数据结构&#xff0c;与树结构有些相似&#xff08;在数学的概念上&#xff0c;树是图的一种&#xff09;树结构是一对多的关系&#xff0c;而图结构是多对多的关系比如导航的最优路径&#xff1a;可以看成多个…

【Pandas数据处理100例】(一百):Pandas中使用filter()过滤器实现数据筛选

前言 大家好,我是阿光。 本专栏整理了《Pandas数据分析处理》,内包含了各种常见的数据处理,以及Pandas内置函数的使用方法,帮助我们快速便捷的处理表格数据。 正在更新中~ ✨ 🚨 我的项目环境: 平台:Windows10语言环境:python3.7编译器:PyCharmPandas版本:1.3.5N…

Oracle的学习心得和知识总结(八)|Oracle数据库PL/SQL语言GOTO语句技术详解

目录结构 注&#xff1a;提前言明 本文借鉴了以下博主、书籍或网站的内容&#xff0c;其列表如下&#xff1a; 1、参考书籍&#xff1a;《Oracle Database SQL Language Reference》 2、参考书籍&#xff1a;《PostgreSQL中文手册》 3、EDB Postgres Advanced Server User Guid…

Java常见漏洞——整数溢出漏洞、硬编码密码漏洞、不安全的随机数生成器

目录 前言&#xff1a; &#xff08;一&#xff09;整数溢出漏洞 0x01 整数溢出漏洞介绍 1.1 上界溢出 1.2 下界溢出 0x02 整数溢出漏洞修复 &#xff08;二&#xff09;硬编码密码漏洞 修复案例&#xff1a; &#xff08;三&#xff09;不安全的随机数生成器 前言&a…

1、Shell 概述

文章目录1、Shell 概述1.1 Linux 提供的 Shell 解析器有1.2 bash 和 sh 的关系1.3 Centos 默认的解析器是 bash尚硅谷2022版Linux扩展篇Shell教程-讲师&#xff1a;武晟然 壁立千仞 无欲则刚 1、Shell 概述 硬件–>操作系统核心&#xff08;Linux内核&#xff09;–>解释…

ubuntu2004 有线与另一个Ubuntu系统通信

在Ubuntu2004&#xff08;从机&#xff09;打开一个终端&#xff0c;输入如下配置有线网络ip&#xff0c;其中eth0 为有线网络的名称,up使能有线网络eth0&#xff1a; ifconfig eth0 192.169.10.2 up 并在.bashrc文件中输入 export ROS_MASTER_URIhttp://192.169.10.1:11311 …

Java多线程之线程池

Java高并发应用开发过程中会频繁的创建和销毁线程&#xff0c;为了节约成本和提升性能&#xff0c;往往会使用线程池来统一管理线程&#xff0c;使用线程池主要有以下几点优势 降低资源消耗&#xff1a;重复利用已创建的线程降低线程创建和销毁造成的消耗 提高响应速度&#xf…

ImageNet classification with deep convolutional neural networks

使用深度卷积神经网络进行ImageNet图像分类 目录 1.引言 2.网络结构 2.1 小细节 2.2 代码部分 3. 创新点 3.1 非线性激活函数ReLU&#xff08;提速&#xff09; 3.2 多GPU训练&#xff08;提速&#xff09; 3.3局部响应归一化&#xff08;增强泛化能力&#xff0c;已不…

我国天宫空间站以及各个仓位介绍

一、天宫空间站 天宫空间站&#xff08;China Space Station&#xff09;是中国从2021年开始建设的一个模块化空间站系统&#xff0c;为人类自1986年的和平号空间站及1998年的国际空间站后所建造的第三座大型在轨空间实验平台&#xff0c;基本构型由天和核心舱、问天实验舱和梦…

Head First设计模式(阅读笔记)-07.适配器模式

火鸡冒充鸭子 现在缺少一个绿头鸭对象&#xff0c;需要用野生火鸡对象冒充一下&#xff0c;但是二者的接口都不一样该怎么去冒充呢&#xff1f; // 鸭子接口 public interface Duck{public void quack(); // 呱呱叫public void fly(); // 飞行 } // 火鸡接口 public interfac…

应力奇异,你是一个神奇的应力!

在用ANSYS进行压力容器应力分析计算的时候&#xff0c;总会出现一些应力集中的问题&#xff0c;而且&#xff0c;有些应力集中点竟然没办法采用倒圆角的办法消除&#xff0c;采用网格加密方法时&#xff0c;甚至应力值比之前更大。这个情况&#xff0c;大家通常称为应力奇异。 …

springboot-mybatisplus-redis二级缓存

前言 mybatis可以自己带有二级缓存的实现&#xff0c;这里加上redis是想把东西缓存到redis中&#xff0c;而不是mybaits自带的map中。这也就构成了我们看到的springboot mybatisplus redis实现二级缓存的题目。 具体步骤如下&#xff1a; 首先加入需要的依赖 <dependenc…

《InnoDB引擎六》InnoDB 1.0.x版本之前的Master Thread

Master Thread 工作方式 在后台线程中提到&#xff0c;Master Thread是核心的后台线程。InnoDB存储引擎的主要工作都是在一个单独线程中完成的。 InnoDB 1.0.x版本之前的Master Thread Master Thread具有最高的线程优先级别。内部由多个循环组成&#xff1a;主循环(loop)、后台…

[附源码]SSM计算机毕业设计医院仪器设备管理系统JAVA

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

基于萤火虫算法优化的BP神经网络预测模型(Matlab代码实现)

目录 1 概述 2 萤火虫算法 3 萤火虫算法优化BP神经网络的算法设计 3.1 基本思想 3.2 萤火虫算法优化BP神经网络算法 4 运行结果 5 参考文献 6 Matlab代码及文章 1 概述 现实的世界中混沌现象无处不在,大至宇宙,小到基本粒子,都受到混沌理论支配.如气候变化会出 现混沌…