《Mask2Former》算法详解

news2025/2/24 10:18:17

文章地址:《Masked-attention Mask Transformer for Universal Image Segmentation》
代码地址:https://github.com/facebookresearch/Mask2Former

文章为发表在CVPR2022的一篇文章。从名字可以看出文章像提出一个可以统一处理各种分割任务(全景分割、语义分割、实例分割)的网络。

这里稍微通俗的解释一下上述的几个分割任务:
全景分割:分割的结果有背景概念(天空、大海),有实例概念(person1、person2、person2)。
语义分割:只有类别概念,比如上述的person1、person2、person3都属于people这一类,不区分每个实例。且包含背景类别的识别。
实例分割:只有前景类别的概念,例如只有人、猫、狗等类别,没有天空大海这一类背景类别。且前景类别是有实例概念的。

更详细一点的说,在coco数据集里面定义,背景类称为stuff类别,这一类类别是没有边界的概念,例如一张图只有一片天空。前景类别称为things类别。

本文提出的网络就是可以一次性处理上述几个分割任务,而不用向之前的网络,一个任务去处理特定的一种任务。如下图所示

上图不仅可以看出不同任务的示意,还可以看出文章的网络在各个任务上表现都是SOTA的。

一、网络结构

文章采用的网络架构与MaskFormer 一致的。该类架构由三部分组成,一个backbone用于提取图片的特征,一个pixel decoder用于将主干网络提前的特征进行上采样生成高分辨率的图像特征,一个transformer decoder用于根据图像特征来处理object queries。最终网络根据pixel decoder输出的高分辨率的图像特征和transformer decoder输出的object queries生成最终的预测mask。
该结构能够很好的处理各种分割任务,原因就是输出对每个mask预测一个类别,这样不同的任务只是定义的不同类别而已。

具体的Mask2Former的示意图如下图所示,左边为整体的框架,右边为Transformer decoder with masked attention结构:

1.1 Transformer decoder with masked attention

有文章研究全局的特征信息对图像分割任务是非常重要的,但是也有文章证明对于transformer-based的结构来说,全局的特征信息会导致cross-attention收敛变慢,因为cross-attention需要很多轮的训练才能关注到需要关注的对应的物体区域上。

文章假设局部特征已经可以很好的去更新query feature了,而全局特征可以通过self-attention结构来学习。基于这假设,文章提出了masked attetion结构。

标准的cross-attetion结构用公式表示如下所示:
X l = s o f t m a x ( Q l K l T ) V l + X l − 1 X_l = softmax(Q_lK^T_l)V_l + X_{l-1} Xl=softmax(QlKlT)Vl+Xl1
其中,l表示当前层的索引, X l ∈ R N × C X_l\in R^{N\times C} XlRN×C表示l层的N个C维的query features,而 Q l = f Q ( X l − 1 ) ∈ R N × C Q_l=f_{Q}(X_{l-1})\in R^{N\times C} Ql=fQ(Xl1)RN×C. X 0 X_0 X0表示Transformer decoder的输入。 K l , V l ∈ R H l W l × C K_l,V_l\in R^{H_l W_l \times C} Kl,VlRHlWl×C为图像特征经过 f K ( ⋅ ) f_K({\cdot}) fK() f V ( ⋅ ) f_V({\cdot}) fV()变化后的结果,其中 H l H_l Hl W l W_l Wl是图像特征的分辨率。上述的 f Q f_Q fQ f K f_K fK f V f_V fV都是线性变换层。

本文提出的masked attetion模块,用公式表示如下:
X l = s o f t m a x ( M l − 1 + Q l K l T ) V l + X l − 1 X_l = softmax(M_{l-1}+Q_lK^T_l)V_l + X_{l-1} Xl=softmax(Ml1+QlKlT)Vl+Xl1
其中attetion mask M_{l-1}中位置(x,y)的值用如下公式计算得到:
M l − 1 ( x , y ) = { 0 i f M l − 1 ( x , y ) = 1 − ∞ o t h e r w i s e M_{l-1}(x, y)=\left\{ \begin{aligned} 0 \quad if M_{l-1}(x,y) = 1\\ -\infty \quad otherwise \end{aligned} \right. Ml1(x,y)={0ifMl1(x,y)=1otherwise
这里 M l − 1 ∈ 0 , 1 N × H l W l M_{l-1}\in {0, 1}^{N\times H_l W_l} Ml10,1N×HlWl是根据阈值为0.5对Transformer decoder l-1层的输出进行resize后的二值化的结果。 resize后的分辨率大小和 K l K_l Kl一样。 M 0 M_0 M0是通过 X 0 X_0 X0二值化得到的。

1.2 High-resolution features

高分辨率的特征能够改善模型的效果,但是也每次都采用高分辨率的特征对于计算量要求也非常大。为了提升效率,文章输入给Transformer decoder层的特征采用不同分辨率的图片特征。
更详细说明,pixel decoder输出的图像特征大小分别为原图的1/32, 1/16, 1/8。对于每个分辨率的图片,在给到Transformer decoder之前,会加入sinusoidal positional embedding e p o s ∈ R H l W l × C e_{pos}\in R^{H_l W_l \times C} eposRHlWl×C和一个可学习的scale-level embedding e l v l ∈ R 1 × C e_{lvl}\in R^{1\times C} elvlR1×C。Transformer decoder对这种三层Transformer decoder结构重复L次。

1.3 Optimization improvements

这里针对普通的Transformer decoder layer进行改进。普通的Transformer decoder layer处理query features的顺序为self-attention module, cross-attention module,feed-forward network。query feature( X 0 X_0 X0)是初始化为0的特征。dropout用在residual connections和attention maps结构中。

文章对上述三点进行改进,文章认为self-attention只有图片特征的输入,没啥信息可以学习,为了提高计算效率,将self-attention、cross-attention调换了顺序。query feature( X 0 X_0 X0)变成可学习的特征。去除dropout。

二、提升训练效率

因为对高分辨率的mask进行预测,对显存的消耗很大,例如上一版的MaskFormer一个图片训练需要32G的显存。
受到PoinRend和Implicit PointRend文章的启发,训练分割任务的网络时,不需要计算整个mask的loss,只需要计算K个随机采样点的loss即可。
在训练时,有matching-loss(Transformer结构预测类别时特有的匹配loss)和final loss(匹配好后,计算预测结果和gt的loss)。
在计算matching-loss时,采用均匀采样采相同的K个点计算loss。
在计算final loss时,采用importance sampling,给每个不同的预测结果采不同的K个点进行计算loss。
这样的loss计算方式可以减少三倍的显存占用量,从而提高网络训练效率。

三、网络具体实现

  1. Pixel decoder. 采用multi-scale deformable attention(MSDeformAttn)做为pixel decoder结构,采用6层MSDeformAttn处理1/8,1/16,1/32大小的图片feature,并用一个上采样生成1/4的图片feature。
  2. Transformer decoder. L=3(共9层),100个queries(N=100), 在Transformer decoder layer的每个中间层度有一个辅助loss(9层的输出都有一个辅助loss来指导学习1.1中的M)
  3. Loss weights. 对于mask loss,文中采用binary cross-entropy loss和 dice loss一起,即 L m a s k = λ c e L c e + λ d i c e L d i c e L_{mask}=\lambda_{ce}L_{ce}+\lambda_{dice}L_{dice} Lmask=λceLce+λdiceLdice,其中 λ c e = 5.0 , λ d i c e = 5.0 \lambda_{ce}=5.0, \lambda_{dice}=5.0 λce=5.0,λdice=5.0. final loss是mask loss和classfication loss一起计算,即 L m a s k + λ c l s L c l s L_{mask}+\lambda_{cls}L_{cls} Lmask+λclsLcls,其中当有匹配的gt时 λ c l s = 2.0 \lambda_{cls}=2.0 λcls=2.0,当匹配的为no object时, λ c l s = 0.1 \lambda_{cls}=0.1 λcls=0.1
  4. post-processing. 对于全景和语义分割来说,后处理方式同MaskFormer,输出对应的mask以及其对应的类别。对于实例分割,为了输出对应实例的分割,采用类别的分数和mask的平均分数相乘得到每个实例的分数。

到这里该算法的基本内容都介绍完了,具体的训练参数还有训练数据以及数据结果可以查看文章找到更详细的信息。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1641372.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++学习第二十二课:STL映射类的深入解析

C学习第二十二课:STL映射类的深入解析 在C标准模板库(STL)中,映射类(std::map和std::multimap)是用来存储关联数据的容器。与集合类不同,映射类中的每个元素都是一个键值对(key-val…

十四、网络编程

目录 一、二、网络通讯要素三、IP和端口号四、网络协议1、网络通信协议2、TCP/IP协议簇1)TCP协议2)UDP 3、Socket 五、TCP网络编程1、基于Socket的TCP编程1)客户端创建socket对象2) 服务器端建立 ServerSocket对象 2、UDP网络通信…

17 内核开发-内核内部内联汇编学习

​ 17 内核开发-内核内部内联汇编学习 课程简介: Linux内核开发入门是一门旨在帮助学习者从最基本的知识开始学习Linux内核开发的入门课程。该课程旨在为对Linux内核开发感兴趣的初学者提供一个扎实的基础,让他们能够理解和参与到Linux内核的开发过程中…

【 书生·浦语大模型实战营】学习笔记(六):Lagent AgentLego 智能体应用搭建

🎉AI学习星球推荐: GoAI的学习社区 知识星球是一个致力于提供《机器学习 | 深度学习 | CV | NLP | 大模型 | 多模态 | AIGC 》各个最新AI方向综述、论文等成体系的学习资料,配有全面而有深度的专栏内容,包括不限于 前沿论文解读、…

MySQL技能树学习——数据库组成

数据库组成: 数据库是一个组织和存储数据的系统,它由多个组件组成,这些组件共同工作以确保数据的安全、可靠和高效的存储和访问。数据库的主要组成部分包括: 数据库管理系统(DBMS): 数据库管理系…

node.js中path模块-路径处理,语法讲解

node中的path 模块是node.js的基础语法,实际开发中,我们通过使用 path 模块来得到绝对路径,避免因为相对路径带来的找不到资源的问题。 具体来说:Node.js 执行 JS 代码时,代码中的路径都是以终端所在文件夹出发查找相…

服务器被攻击,为什么后台任务管理器无法打开?

在服务器遭受DDoS攻击后,当后台任务管理器由于系统资源耗尽无法打开时,管理员需要依赖间接手段来进行攻击类型的判断和解决措施的实施。由于涉及真实代码可能涉及到敏感操作,这里将以概念性伪代码和示例指令的方式来说明。 判断攻击类型 步…

DHCPv4_CLIENT_ALLOCATING_04: 发送DHCPREQUEST - 头部值‘secs‘字段

测试目的: 验证客户端发送的DHCPREQUEST消息是否使用了与原始DHCPDISCOVER消息相同的’secs’字段值。 描述: 本测试用例旨在确保DHCP客户端在发送DHCPREQUEST消息时,使用了与它之前发送的DHCPDISCOVER消息相同的’secs’字段值。这是DHCP…

国产数据库的发展势不可挡

前言 新的一天又开始了,光头强强总不紧不慢地来到办公室,准备为今天一天的工作,做一个初上安排。突然,熊二直接进入办公室,说:“强总老大,昨天有一个数据库群炸了锅了,有一位姓虎的…

【LLM 论文】UPRISE:使用 prompt retriever 检索 prompt 来让 LLM 实现 zero-shot 解决 task

论文:UPRISE: Universal Prompt Retrieval for Improving Zero-Shot Evaluation ⭐⭐⭐⭐ EMNLP 2023, Microsoft Code:https://github.com/microsoft/LMOps 一、论文速读 这篇论文提出了 UPRISE,其思路是:训练一个 prompt retri…

Git可视化工具tortoisegit 的下载与使用

一、tortoisegit 介绍 TortoiseGit 是一个非常实用的版本控制工具,主要用于与 Git 版本控制系统配合使用。 它的主要特点包括: 图形化界面:提供了直观、方便的操作界面,让用户更易于理解和管理版本控制。与 Windows 资源管理器…

Flutter笔记:Widgets Easier组件库(9)使用弹窗

Flutter笔记 Widgets Easier组件库(9):使用弹窗 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress o…

美国站群服务器的定义、功能以及在网站运营中的应用

美国站群服务器的定义、功能以及在网站运营中的应用 在当今互联网的蓬勃发展中,站群服务器已成为网站运营和SEO优化中不可或缺的重要工具之一。尤其是美国站群服务器,在全球范围内备受关注。本文将深入探讨美国站群服务器的定义、功能以及在网站运营中的…

Go实战训练之Web Server 与路由树

Server & 路由树 Server Web 核心 对于一个 Web 框架,至少要提供三个抽象: Server:代表服务器的抽象Context:表示上下文的抽象路由树 Server 从特性上来说,至少要提供三部分功能: 生命周期控制&…

基于SSM的宠物领养平台(有报告)。Javaee项目。ssm项目。

演示视频: 基于SSM的宠物领养平台(有报告)。Javaee项目。ssm项目。 项目介绍: 采用M(model)V(view)C(controller)三层体系结构,通过Spring Spri…

《自动机理论、语言和计算导论》阅读笔记:p215-p351

《自动机理论、语言和计算导论》学习第 11 天,p215-p351总结,总计 37 页。 一、技术总结 1.constrained problem 2.Fermat’s lats theorem Fermat’s Last Theorem states that no three positive integers a, b and c satisfy the equation a^n b…

【数据结构(邓俊辉)学习笔记】列表01——从向量到列表

文章目录 0.概述1. 从向量到列表1.1 从静态到动态1.2 从向量到列表1.3 从秩到位置1.4 列表 2. 接口2.1 列表节点2.1.1 ADT接口2.1.2 ListNode模板类 2.2 列表2.2.1 ADT接口2.2.2 List模板类 0.概述 学习了向量,再介绍下列表。先介绍下列表里的概念和语义&#xff0…

C++ | Leetcode C++题解之第66题加一

题目&#xff1a; 题解&#xff1a; class Solution { public:vector<int> plusOne(vector<int>& digits) {int n digits.size();for (int i n - 1; i > 0; --i) {if (digits[i] ! 9) {digits[i];for (int j i 1; j < n; j) {digits[j] 0;}return …

平平科技工作室-Python-超级玛丽

一.准备图片 放在文件夹取名为images 二.准备一些音频和文字格式 放在文件夹media 三.编写代码 import sys, os sys.path.append(os.getcwd()) # coding:UTF-8 import pygame,sys import os from pygame.locals import* import time pygame.init() # 设置一个长为1250,宽为…

Python | Leetcode Python题解之第65题有效数字

题目&#xff1a; 题解&#xff1a; from enum import Enumclass Solution:def isNumber(self, s: str) -> bool:State Enum("State", ["STATE_INITIAL","STATE_INT_SIGN","STATE_INTEGER","STATE_POINT","STATE_…