人工智能图像分割之Mask2former源码解读

news2025/3/18 2:43:35

环境搭建:

(1)首先本代码是下载的mmdetection-2022.9的,所以它的版本要配置好,本源码配置例如mmcv1.7,python3.7,pytorch1.13,cuda11.7。pytorch与python,cuda版本匹配可参考:https://www.jb51.net/python/3308342lx.htm。

(2)还有一个是先要安装一个vs2022版本或vs2019,其中确保工作负载下"使用C++的桌面开发"的内容基本安装上

(3)数据集就用coco集,在参数配置中指定目录,例如../configs/mask2former目录下的.py文件

一.Backbone获取多层级特征

(1)进入train.py读取各个配置参数,构建模型调用Mask2Former类,如下图:

有了model(Mask2Former),datasets(COCO数据集)和cfg后就传入下面的方法:

(2)进入Mask2Former模块,然后调用MaskFormer模块,如下图:

(3)进入MaskFormer类中,并进入def forward_train这个方法中,如下图:

然后

进入父类BaseDetector的forward_train方法:

回到maskformer.py中,如下图:

(4)进入head层,因为mask2former_head的父类是maskformer_head,所以会先调父类中的forward_train方法后再进入到了mask2former_head中的forward方法中,mask2former_head.py这个类很重要,如下图: 

二.多层级采样点初始化构建

对输入的每一张特征图加上位置编码,这个FOR中的特征图是256

上图中得到的level_embed就是按每个层级取出的值,它是一个256维的向量,下面用view转成四维向量后与位置编码(pos_embed)做加法。

然后调用point_generator中的single_level_grid_priors方法,如下图:

shift_x与shift_y都是32的一维矩阵,

做成网络后就是32*32=1024的一维矩阵,然后做合并得到棋盘上的位置

返回到msdeformattn_pixel_decoder.py中,如下图:

它是维度转变后1024在最前面的

三.多层级输入特征序列创建方法

上图中的所说的是进入transformer.py中的DetrTransformerEncoder类下的forward方法。

四.偏移量与权重计算并转换

还是在transformer.py中,对每一层(例如selfattention,bn,norm,ffn,全连接)进行for,如果是self_attn

执行selfattention层时,进入到了multi_scale_deform_attn.py中,

进入到multi_scale_deform_attn.py中的forward方法:

下面做了多头注意力机制,变成8头:

重点来了,对query做sampling_offsets与attention_weights方法后并调用view变形得到相应的每个点的偏移量与权重(softmax),偏移量到时候做采样时用到,最后还要对这二个值做乘法,如下图:

其中sampling_offsets这个是全连接层,如下图:

由这4个采样点找它们的偏移量,每个偏移量都有x,y二个值组成。

为什么这里权重值是96呢?因为每个偏移量是一个点,它只是这个点有x,y组成,但权重是指这个点的权重,所以就是上面偏移量输出的192/2=96了(也等以8*(3*4)=8*12)。

上图中的levels是指层级数,points是指采样点数。

五.Encoder特征构建方法实例

偏移量有了,现在我们执行特征计算操作了,而原来特征不准,现在要把新的偏移完的准确的特征拿到手,所以这里也要做特征的偏移。还是在multi_scale_deform_attn.py中,

上图中对齐特征是要重新进行采样的,这里的特征偏移范围是[-1,1]。

做完这一步后返回到transformer.py这里,发现就是做了一个self_attn层的操作,得到query值,如下图所示:

做完ffn全连接后再做norm,就算做完一个层级采样了,其它层级采样也是一样按这样流程的,最后把每一层级执行完后就得到特征值,返回到msdeformattn_pixel_decoder.py中,存到了“(3)多层级输入特征序列创建方法”所说的memory变量(它存的是编码完后特征)中,如下图:

总结一下这里encoder是做了什么:其实就是和可变形detr是一样的,就是对展开的序列提特征,我们是希望它是多层级,多头注意力机制和加上可变形的位置偏移,这样可以得到我们序列更好的特征。

还是在msdeformattn_pixel_decoder.py中,总共3个层级,每个层级的大小(图像上的点的个数)分别是1024(它是由32*32得到),4096(它是由64*64得到)与16384(它是由128*128得到)如下图:

等下就用这个y用来预测一下是每个点是前景还是背景?

六.query要预测的任务解读

返回到mask2former_head.py类中,至此self.pixel_decoder方法调用结束(主要是transformer编码这一块),准备decoder解码了。

上图的query_feat中初始化的100是指decoder中会找100样东西,例下图:

下面开始调用forward_head方法:

这个会找到这100个的前景与背景分别是那些?

七.Decoder中的AttentionMask方法

上图中调用sigmoid()后的值是在0至1之间范围之间。

返回上面3张图的预测值

现在又跳回到transformer.py的baseTransformer当中,

八.损失模块输入参数分析

返回到mask2former_head.py中,如下图:

有了cls_pred_list,mask_pred_list这二个结果(10层)后,接下来就去计算每一层的损失函数结果了。返回到maskformer_head.py文件中,如下图:

上图中的81=80+1,其中80是实际类别,1表示背景。

gt_masks是指标注的信息。

九.标签分配策略解读

进入到mask2former_head.py中的loss_single方法中,它是对每一层进行处理(共10层),

get_targets方法目的就是找正负样本,这个方法是在maskformer_head.py中,对于每一张图像它的正负样本是不一样的。

调用上图的multi_apply会进入到mask_hungarian_assigner.py中的assign方法,这里主要进行标签分配,如下图:

上图100个-1的值当中,如果与标签匹配上就修改里面-1值。

上图中标签分配考虑的三方面其实就对应分类,mask,iou这三方面的损失。

十.正样本筛选损失计算

还是在mask_hungarian_assigner.py中,

上图中的gt_labels是10个标签。调用cls_cost方法后会进入到match_cost.py中,如下图:

上图中调用了softmax方法后即把cls_pred变成概率值了(0------1之间的值),

同时cls_cost的第2个维度是10了,即返回值变成了(100,10),cls_cost的值是负数来的,因为前面加了负号,如下图:

mask_hungarian_assigner.py中计算完类别损失后,现在计算mask损失,如下图:

上图中的12544就当作服从256*256正态分布的随机采样吧,调用上图的mask_cost方法后进入下图的类中:

上图中求pos与neg损失时,都调用了binary_cross_entropy_with_logits,首先我们知道二元交叉熵(Binary cross entropy)是二分类中常用的损失函数,它可以衡量两个概率分布的距离,二元交叉熵越小,分布越相似。相比F.binary_cross_entropy函数,F.binary_cross_entropy_with_logits函数在内部使用了sigmoid函数,也就是F.binary_cross_entropy_with_logits = sigmoid + F.binary_cross_entropy。

上图中因为它是有12544个采样点,它是累加后求平均,所以要除以12544,最后得到一个(100,10)的矩阵返回回去。为什么它是返回(100,10)?是因为它是100个query都分别与10个类别去计算得到的。

十一.标签分类匹配结果分析

这个mask算出来的cost值是正数来的。下面开始计算dice_cost(类似iou重合比例计算)

跳转到match_cost.py中,先拉长,再按dice系数公式得到被除数

Dice 系数可以计算两个字符串的相似度:Dice(s1,s2) = 2*comm(s1,s2)/(leng(s1)+leng(s2))。其中,comm(s1,s2) 是s1和s2中相同字符的个数; leng(s1)、leng(s2)是字符串s1、s2 的长度。

返回后得到如下图:

计算出三个损失值后就进行累加,如下图:

而上图中matched_col_inds的10个索引值是对应gt标签中的索引位置,matched_row_inds是100个中匹配到损失最小的对应的10个索引。这时我们就认为通过与标签匹配完成了,得到正样本了。

返回后回到了mask2former_head.py中,调用assign方法进行标签匹配就结束了,如下图:

然后调用sample主要是用来取正,负样本的索引值,如下图:

上图中mask_targets是只考虑匹配上的10个。到此时,第一张图片就处理完了,然后按上面的逻辑做第二张图片的处理了。

十二.最终损失计算流程

这时(9)中所说的调用get_targets方法就结束了。前面算的损失是关于标签分配的。而下面将算实际的损失了。

上图中labels这个标签是更新了正样本索引后的标签值,80是默认的初始化的负样本索引。上图中这三个都是200,那就可以做交叉熵损失了,调下图的方法self.loss_cls方法。

上图中二个极端值是指概率值(上图画的y轴值)要么靠近0,要么靠近1,这样就表示确定性越大(高),而不确定性越强的是概率值靠近0.5那种。

十三.汇总所有损失完成迭代

get_uncertainty方法返回就得到不确定性的点索引,如下图

在分割中也许是边界点比较模拟二可(即可能是背景也可能是前景),它的不确定性就比较高。

这里三个loss_cls,loss_dice,loss_mask做完返回值,重复迭代10次求这三个值。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2295355.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp 编译生成鸿蒙正式app步骤

1,在最新版本DevEco-Studio工具新建一个空项目并生成p12和csr文件(构建-生成私钥和证书请求文件) 2,华为开发者平台 根据上面生成的csr文件新增cer和p7b文件,分发布和测试 3,在最新版本DevEco-Studio工具 文…

2024最新版Java面试题及答案,【来自于各大厂】

发现网上很多Java面试题都没有答案,所以花了很长时间搜集整理出来了这套Java面试题大全~ 篇幅限制就只能给大家展示小册部分内容了,需要完整版的及Java面试宝典小伙伴点赞转发,关注我后在【翻到最下方,文尾点击名片】即可免费获取…

Excel 融合 deepseek

效果展示 代码实现 Function QhBaiDuYunAIReq(question, _Optional Authorization "Bearer ", _Optional Qhurl "https://qianfan.baidubce.com/v2/chat/completions")Dim XMLHTTP As ObjectDim url As Stringurl Qhurl 这里替换为你实际的URLDim postD…

21.2.6 字体和边框

版权声明:本文为博主原创文章,转载请在显著位置标明本文出处以及作者网名,未经作者允许不得用于商业目的。 通过设置Rang.Font对象的几个成员就可以修改字体,设置Range.Borders就可以修改边框样式。 【例 21.6】【项目&#xff…

OpenFeign远程调用返回的是List<T>类型的数据

在使用 OpenFeign 进行远程调用时,如果接口返回的是 List 类型的数据,可以通过以下方式处理: 直接定义返回类型为List Feign 默认支持 JSON 序列化/反序列化,如果服务端返回的是 List的JSON格式数据,可以直接在 Feig…

三维模拟-机械臂自翻车

机械仿真 前言效果图后续 前言 最近在研究Unity机械仿真,用Unity实现其运动学仿真展示的功能,发现一个好用的插件“MGS-Machinery-master”,完美的解决了Unity关节定义缺少液压缸伸缩关节功能,内置了多个场景,讲真的&…

网络安全治理架构图 网络安全管理架构

网站安全攻防战 XSS攻击 防御手段: - 消毒。 因为恶意脚本中有一些特殊字符,可以通过转义的方式来进行防范 - HttpOnly 对cookie添加httpOnly属性则脚本不能修改cookie。就能防止恶意脚本篡改cookie 注入攻击 SQL注入攻击需要攻击者对数据库结构有所…

调用deepseek的API接口使用,对话,json化,产品化

背景 最近没咋用chatgpt了,deepseek-r1推理模型写代码质量是很高。deepseek其输出内容的质量和效果在国产的模型里面来说确实算是最强的,并且成本低,它的API接口生态也做的非常好,和OpenAI完美兼容。所以我们这一期来学一下怎么调…

DeepSeek大模型本地部署实战

1. 下载并安装Ollama 打开浏览器:使用你常用的浏览器(如Chrome、Firefox等)访问Ollama的官方网站。无需特殊网络环境,直接搜索“Ollama”即可找到。 登录与下载:进入Ollama官网后,点击右上角的“Download…

Spring Boot Actuator使用

说明&#xff1a;本文介绍Spring Boot Actuator的使用&#xff0c;关于Spring Boot Actuator介绍&#xff0c;下面这篇博客写得很好&#xff0c;珠玉在前&#xff0c;我就不多介绍了。 Spring Boot Actuator 简单使用 项目里引入下面这个依赖 <!--Spring Boot Actuator依…

[css] 黑白主题切换

link动态引入 类名切换 css滤镜 var 类名切换 v-bind css预处理器mixin类名切换 【前端知识分享】CSS主题切换方案

阿里云专有云网络架构学习

阿里云专有云网络架构 叶脊&#xff08;spine-leaf&#xff09;网络和传统三层网络拓扑对比 阿里云网络架构V3拓扑角色介绍推荐设备设备组网举例带外管理网络带外网和带内网对比设备介绍 安全网络设备介绍 参考 后续更新流量分析叶脊&#xff08;spine-leaf&#xff09;网络和传…

【AIGC】冷启动数据与多阶段训练在 DeepSeek 中的作用

博客主页&#xff1a; [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: AIGC | ChatGPT 文章目录 &#x1f4af;前言&#x1f4af;冷启动数据的作用冷启动数据设计 &#x1f4af;多阶段训练的作用阶段 1&#xff1a;冷启动微调阶段 2&#xff1a;推理导向强化学习&#xff08;RL&#xff0…

GenAI + 电商:从单张图片生成可动态模拟的3D服装

在当今数字化时代,电子商务和虚拟现实技术的结合正在改变人们的购物体验。特别是在服装行业,消费者越来越期待能够通过虚拟试衣来预览衣服的效果,而无需实际穿戴。Dress-1-to-3 技术框架正是为此而生,它利用生成式AI模型(GenAI)和物理模拟技术,将一张普通的穿衣照片转化…

harmonyOS生命周期详述

harmonyOS的生命周期分为app(应用)的生命周期和页面的生命周期函数两部分 应用的生命周期-app应用 在app.js中写逻辑,具体有哪些生命周期函数呢,请看下图: onCreated()、onShow()、onHide()、onDestroy()这五部分 页面及组件生命周期 着重说下onShow和onHide,分别代表是不是…

记一次调整磁盘分区大小的经验

背景 redhat 6 系统 根目录挂载的逻辑卷满了&#xff0c;系统都不能正常运行了 但是/home目录挂载的另外一个逻辑卷却占用只有4% 所以想把/home挂的逻辑卷分一部分给/ 挂的逻辑卷 备份 先把系统整盘备份一下&#xff0c;用clonezilla做一个磁盘镜像&#xff0c;免得失误了搞…

软件测试就业

文章目录 2.6 初识一、软件测试理论二、软件的生产过程三、软件测试概述四、软件测试目的五、软件开发与软件测试的区别&#xff1f;六、学习内容 2.7 理解一、软件测试的定义二、软件测试的生命周期三、软件测试的原则四、软件测试分类五、软件的开发与测试模型1.软件开发模型…

后缀表达式(蓝桥杯19I)

有减于号时 假设有n个大于0从大到小的数&#xff0c;加减符号数为n-1&#xff1a;a,b,c,d,。。。。。&#xff0c;e sum求最大&#xff1a;(max )-(min ) a - (e - ( ) -&#xff08;&#xff09;)( ( )( ) ( ) 。。。。 ) 当序列中有负数时&#xff1a; a -&am…

mac环境下,ollama+deepseek+cherry studio+chatbox本地部署

春节期间&#xff0c;deepseek迅速火爆全网&#xff0c;然后回来上班&#xff0c;我就浅浅的学习一下&#xff0c;然后这里总结一下&#xff0c;我学习中&#xff0c;总结的一些知识点吧&#xff0c;分享给大家。具体的深度安装部署&#xff0c;这里不做赘述&#xff0c;因为网…

TypeScript 中的联合类型:灵活的类型系统

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…