【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块ProEnco网络解析

news2025/1/8 19:00:30

【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块PromptEncoder网络解析

Segment Anything:建立了迄今为止最大的分割数据集,在1100万张图像上有超过1亿个掩码,模型的设计和训练是灵活的,其重要的特点是Zero-shot(零样本迁移性)转移到新的图像分布和任务,一个图像分割新的任务、模型和数据集。SAM由三个部分组成:一个强大的图像编码器(Image encoder)计算图像嵌入,一个提示编码器(Prompt encoder)嵌入提示,然后将两个信息源组合在一个轻量级掩码解码器(Mask decoder)中来预测分割掩码。本博客将讲解Prompt encoder模块的深度学习网络代码。

文章目录

  • 【图像分割】【深度学习】SAM官方Pytorch代码-Prompt encoder模块PromptEncoder网络解析
  • 前言
  • PromptEncoder网络简述
    • SAM模型关于ProEnco网络的配置
    • ProEnco网络结构与执行流程
  • ProEnco网络基本步骤代码详解
    • Embed_Points
    • Embed_Boxes
    • Embed_Masks
    • PositionEmbeddingRandom
  • 总结


前言

在详细解析SAM代码之前,首要任务是成功运行SAM代码【win10下参考教程】,后续学习才有意义。本博客讲解Prompt encoder模块的深度网络代码,不涉及其他功能模块代码。


PromptEncoder网络简述

SAM模型关于ProEnco网络的配置

博主以sam_vit_b为例,详细讲解ViT网络的结构。
代码位置:segment_anything/build_sam.py

def build_sam_vit_b(checkpoint=None):
    return _build_sam(
        # 图像编码channel
        encoder_embed_dim=768,
        # 主体编码器的个数
        encoder_depth=12,
        # attention中head的个数
        encoder_num_heads=12,
        # 需要将相对位置嵌入添加到注意力图的编码器( Encoder Block)
        encoder_global_attn_indexes=[2, 5, 8, 11],
        # 权重
        checkpoint=checkpoint,
    )

sam模型中prompt_encoder模块初始化

prompt_encoder=PromptEncoder(
    # 提示编码channel(和image_encoder输出channel一致,后续会融合)
    embed_dim=prompt_embed_dim,
    # mask的编码尺寸(和image_encoder输出尺寸一致)
    image_embedding_size=(image_embedding_size, image_embedding_size),
    # 输入图像的标准尺寸
    input_image_size=(image_size, image_size),
    # 对输入掩码编码的通道数
    mask_in_chans=16,
),

ProEnco网络结构与执行流程

Prompt encoder源码位置:segment_anything/modeling/prompt_encoder.py
ProEnco网络(PromptEncoder类)结构参数配置。

def __init__(
    self,
    embed_dim: int,                         # 提示编码channel
    image_embedding_size: Tuple[int, int],  # # mask的编码尺寸
    input_image_size: Tuple[int, int],      # 输入图像的标准尺寸
    mask_in_chans: int,                     # 输入掩码编码的通道数
    activation: Type[nn.Module] = nn.GELU,  # 激活层
) -> None:
    super().__init__()
    self.embed_dim = embed_dim              # 提示编码channel
    self.input_image_size = input_image_size                # 输入图像的标准尺寸
    self.image_embedding_size = image_embedding_size        # mask的编码尺寸
    self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
    self.num_point_embeddings: int = 4                      # 4个点:正负点,框的俩个点
    point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]   # 4个点的嵌入向量
    # nn.ModuleList它是一个存储不同module,并自动将每个module的parameters添加到网络之中的容器
    self.point_embeddings = nn.ModuleList(point_embeddings)                     # 4个点的嵌入向量添加到网络
    self.not_a_point_embed = nn.Embedding(1, embed_dim)                         # 不是点的嵌入向量
    self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])           # mask的输入尺寸
    self.mask_downscaling = nn.Sequential(                                                      # 输入mask时 4倍下采样
        nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
        LayerNorm2d(mask_in_chans // 4),
        activation(),
        nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
        LayerNorm2d(mask_in_chans),
        activation(),
        nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
    )
    self.no_mask_embed = nn.Embedding(1, embed_dim)                         # 没有mask输入时 嵌入向量

SAM模型中ProEnco网络结构如下图所示:

ProEnco网络(PromptEncoder类)在特征提取中的几个基本步骤:

  1. Embed_Points:标记点编码(标记点由点转变为向量)
  2. Embed_Boxes:标记框编码(标记框由点转变为向量)
  3. Embed_Masks:mask编码(mask下采样保证与Image encoder输出一致)
def forward(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # 获得 batchsize  当前predict为1
    bs = self._get_batch_size(points, boxes, masks)
    
    # -----sparse_embeddings----
    sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
    if points is not None:
        coords, labels = points
        point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
        sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
    if boxes is not None:
        box_embeddings = self._embed_boxes(boxes)
        sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
    # -----sparse_embeddings----
    
    # -----dense_embeddings----
    if masks is not None:
        dense_embeddings = self._embed_masks(masks)
    else:
        dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
            bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
        )
    # -----dense_embeddings----
    
    return sparse_embeddings, dense_embeddings

获取batchsize

def _get_batch_size(
    self,
    points: Optional[Tuple[torch.Tensor, torch.Tensor]],
    boxes: Optional[torch.Tensor],
    masks: Optional[torch.Tensor],
) -> int:
    if points is not None:
        return points[0].shape[0]
    elif boxes is not None:
        return boxes.shape[0]
    elif masks is not None:
        return masks.shape[0]
    else:
        return 1

获取设备型号

    def _get_device(self) -> torch.device:
        return self.point_embeddings[0].weight.device

ProEnco网络基本步骤代码详解

Embed_Points


标记点预处理,将channel由2变成embed_dim(PositionEmbeddingRandom),然后再加上位置编码权重。

2:坐标(h,w)
embed_dim:提示编码的channel

Embed_Points结构如下图所示:

def _embed_points(
    self,
    points: torch.Tensor,
    labels: torch.Tensor,
    pad: bool,
) -> torch.Tensor:
    # 移到像素中心
    points = points + 0.5
    # points和boxes联合则不需要pad
    if pad:
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)  # B,1,2
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)     # B,1
        points = torch.cat([points, padding_point], dim=1)                          # B,N+1,2
        labels = torch.cat([labels, padding_label], dim=1)                          # B,N+1
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)  # B,N+1,2f
    # labels为-1是非标记点,设为非标记点权重
    point_embedding[labels == -1] = 0.0
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    # labels为0是背景点,加上背景点权重
    point_embedding[labels == 0] += self.point_embeddings[0].weight
    # labels为1的目标点,加上目标点权重
    point_embedding[labels == 1] += self.point_embeddings[1].weight
    return point_embedding

个人理解:pad的作用相当于box占位符号,box和points可以联合标定完成图像分割的,但是此时的box只能有一个,不能有多个。

Embed_Boxes


标记框预处理,将channel由4到2再变成embed_dim(PositionEmbeddingRandom),然后再加上位置编码权重。

4:坐标(h1,w1,h2,w2) -->起始点与末位点
2:坐标(h,w)–>4 reshape 成 2×2
embed_dim:提示编码的channel

Embed_Boxes结构如下图所示:

def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
    # 移到像素中心
    boxes = boxes + 0.5
    coords = boxes.reshape(-1, 2, 2)
    corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)    #
    # 目标框起始点的和末位点分别加上权重
    corner_embedding[:, 0, :] += self.point_embeddings[2].weight
    corner_embedding[:, 1, :] += self.point_embeddings[3].weight
    return corner_embedding

个人理解:boxes reshape 后 batchsize是会增加的,B,N,4–>BN,2,2
因此这里可以得出box和points联合标定时,box为什么只能是一个,而不能是多个。

Embed_Masks


mask的输出尺寸是Image encoder模块输出的图像编码尺寸的4倍,因此为了保持一致,需要4倍下采样。
Embed_Masks结构如下图所示:

def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
    # mask下采样4倍
    mask_embedding = self.mask_downscaling(masks)
    return mask_embedding
# 在PromptEncoder的__init__定义
self.mask_downscaling = nn.Sequential(                                                      # 输入mask时 4倍下采样
    nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans // 4),
    activation(),
    nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
    LayerNorm2d(mask_in_chans),
    activation(),
    nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )

假设没有mask输入,则将no_mask_embed编码扩展到与图像编码一致的尺寸代替mask。

# 在PromptEncoder的forward定义
dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
    bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
)

PositionEmbeddingRandom


用于将标记点和标记框的坐标进行提示编码预处理。

def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
    super().__init__()
    if scale is None or scale <= 0.0:
        scale = 1.0
    # 理解为模型的常数 [2,f]
    self.register_buffer(
        "positional_encoding_gaussian_matrix",
        scale * torch.randn((2, num_pos_feats)),
    )

将标记点的坐标具体的位置转变为[0~1]之间的比例位置

def forward_with_coords(
    self, coords_input: torch.Tensor, image_size: Tuple[int, int]
) -> torch.Tensor:
    coords = coords_input.clone()
    # 将坐标位置缩放到[0~1]之间
    coords[:, :, 0] = coords[:, :, 0] / image_size[1]
    coords[:, :, 1] = coords[:, :, 1] / image_size[0]
    # B,N+1,2-->B,N+1,2f
    return self._pe_encoding(coords.to(torch.float))

标记点位置编码

因为sin和cos,编码的值归一化至 [-1,1],源码注释是[0,1],博主经过实验发现注释不对

def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
    # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
    coords = 2 * coords - 1
    # B,N+1,2 × 2,f --> B,N+1,f
    coords = coords @ self.positional_encoding_gaussian_matrix
    coords = 2 * np.pi * coords
    # outputs d_1 x ... x d_n x C shape
    # B,N+1,2f
    return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

总结

尽可能简单、详细的介绍SAM中Prompt encoder模块的ProEnco网络的代码。后续会讲解SAM的其他模块的代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/467600.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

北邮22信通:二叉树层序遍历的非递归算法:A Story Between Two Templates

北邮22信通一枚~ 跟随课程进度每周更新数据结构与算法的代码和文章 持续关注作者 解锁更多邮苑信通专属代码~ 获取更多文章 请访问专栏~ 北邮22信通_青山如墨雨如画的博客-CSDN博客 目录 一.总纲 二.用队列存储 2.1用模板类实现队列 2.1.1核心思路&#xff1a; …

丁鹿学堂:使用vite手动构建vue项目的注意事项和步骤总结

使用yarn 默认安装了nodeJS环境&#xff0c;使用yarn&#xff0c;比npm更好用。 npm install --global yarn使用yarn按钻过vite yarn add -D vite使用yarn初始化项目 yarn init -y安装vite yarn add vite -D安装vue yarn add vue项目目录&#xff1a; 创建index.html sr…

分享一个有意思的文字飞入动画(模仿水滴融合)

先上效果图&#xff1a; 代码如下&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><style>* {margin: 0;padding: 0;box-sizing: border-box;}:root {--text-…

前后端分离demo 旅馆管理系统(Angular+Springboot)

模型设计 旅馆管理系统&#xff0c;主要涉及到登记入住&#xff0c;退房以及客房和客人信息管理&#xff1b;经过分析抽像出涉及到的实体以及各实体之间的关系&#xff1a;   可以看出整个业务以客房为中心&#xff0c;入住&#xff0c;退房&#xff0c;定价&#xff0c;收费…

Build an SAP Fiori App(一)后面更新中

1.登录 SAP BTP Trial 地址&#xff1a; https://account.hanatrial.ondemand.com 流程可以参考 点击 serviced marketplace 搜索studio 点击创建 点击创建&#xff0c;点击view subscription 点击go to application 创建完成后 添加新链接 Field Value Name ES5 - if you’…

Shell基础入门实战

写在前面 好久没在项目内做自动化了&#xff0c;主要是现阶段在项目内做自动化收益不大&#xff0c;最近开发做batch run的正好缺人&#xff0c;我看了一下代码&#xff0c;就是通过代码读取jar包和远程服务器连接&#xff0c;然后通过shell脚本&#xff0c;向数据库插入数据&a…

如何成为一名优秀的接口自动化测试工程师?了解这些技能是关键

摘要&#xff1a; 随着互联网行业的不断发展&#xff0c;越来越多的应用程序通过API接口提供服务。因此&#xff0c;接口自动化测试成为了保障软件质量的重要环节。本文将介绍接口自动化测试所需掌握的技能&#xff0c;以及相关的历史进程。 B站首推&#xff01;2023最详细自…

什么是 Java 垃圾回收器~

什么是 Java 垃圾回收器 Java 垃圾回收器是 Java 虚拟机 (JVM) 的三个重要模块 (另外两个是解释器和多线程机制) 之一&#xff0c;为应用程序提供内存的自动分配 (Memory Allocation)、自动回收 (Garbage Collect) 功能&#xff0c;这两个操作都发生在 Java 堆上 (一段内存快)…

sqoop安装

文章目录 1. 上传安装包至虚拟机2. 解压安装包到指定路径3. 修改目录名4. 配置环境变量5. 修改配置文件6. 拷贝mysql驱动包7. 验证安装是否成功8. 测试sqoop连接mysql 注&#xff1a;sqoop安装的前提条件是环境已安装java和hadoop 1. 上传安装包至虚拟机 上传安装包sqoop-1.4.…

信通初试第一:无科研无竞赛一战上岸上海交大819学硕感悟

笔者来自通信考研小马哥23上交819全程班学员 信通初试第一&#xff1a;无科研无竞赛一战上岸上海交大819学硕感悟 原创2023-04-27 11:04通信考研小马哥 笔者来自通信考研小马哥23上交819全程班学员 本人情况&#xff1a; 本人是19届交本&#xff0c;本科成绩很差&#xff0c;…

赎金信(Hash的应用)

给你两个字符串&#xff1a;ransomNote 和 magazine &#xff0c;判断 ransomNote 能不能由 magazine 里面的字符构成。 如果可以&#xff0c;返回 true &#xff1b;否则返回 false 。 magazine 中的每个字符只能在 ransomNote 中使用一次。 来源&#xff1a;力扣&#xff0…

Java 实现 YoloV7 目标检测

1 OpenCV 环境的准备 这个项目中需要用到 opencv 进行图片的读取与处理操作&#xff0c;因此我们需要先配置一下 opencv 在 java 中运行的配置。 首先前往 opencv 官网下载 opencv-4.6 &#xff1a;点此下载&#xff1b;下载好后仅选择路径后即可完成安装。 此时将 opencv\b…

WMS是什么?

WMS&#xff08;Warehouse Management System&#xff09;中文译作仓库管理系统&#xff0c;是一种专用于物流仓储管理的IT系统。它主要应用于企业物流中心、配送中心、供应商物料储备中心、电子商务配送中心等仓库管理过程中。 WMS系统可以帮助企业管理和控制其物流仓储流程。…

线程池的设计

一.什么是线程池? 线程池就是创建若干个可执行的线程放到容器中&#xff0c;有任务处理时&#xff0c;会提交到线程池中的任务队列中&#xff0c;线程处理完不是销毁&#xff0c;而是阻塞等待下一个任务。 二.为何要使用线程池? 降低资源消耗。重复利用创建好的线程减少线…

NLP原理和应用入门:paddle(梯度裁剪、ONNX协议、动态图转静态图、推理部署)

目录 一、梯度裁剪 1.1设定范围值裁剪 1. 全部参数裁剪&#xff08;默认&#xff09; 2. 部分参数裁剪 1.2 通过L2范数裁剪 1.3通过全局L2范数裁剪 二. 模型导出ONNX协议 三、动态图转静态图 3.1两种图定义 3.2 什么场景下需要动态图转静态图 3.3为什么动态图模式越来…

k8s 部署 seata1.6.0 集群 基于 nacos 注册中心 + mysql 数据库

k8s 部署 seata1.6.0 集群 基于 nacos 注册中心 mysql 数据库 大纲 1 镜像制作2 准备configmap3 准备deploy 部署文件4 部署seata到k8s 镜像制作 下载seata 选择1.6.0。下载后得到 seata-server-1.6.0.zip 已经上传到百度云盘 下载地址&#xff1a;http://seata.io/zh-cn…

Maven 依赖下载失败解决方案——配置国内源 + 具体解决办法

目录 前言 一、配置 Maven 国内源 二、重新下载jar包 三、其他问题 前言 最近发现 spring-boot 框架更新到 2.7.11 了&#xff0c;由于以前一直使用的是 2.7.9 &#xff0c;所以一直出现依赖下载失败的问题&#xff0c;实际上这是由于 IDEA 会先加载之前下载好的依赖&#xf…

openharmony内核中不一样的双向链表

不一样的双向链表 链表初识别遍历双向链表参考链接 链表初识别 最近看openharmony的内核源码时看到一个有意思的双向链表&#xff0c;结构如下 typedef struct LOS_DL_LIST{struct LOS_DL_LIST *pstPrev; //前驱节点struct LOS_DL_LIST *pstNext; //后继节点 }LOS_DL_LIST;不…

FPGA入门系列12--RAM的使用

文章简介 本系列文章主要针对FPGA初学者编写&#xff0c;包括FPGA的模块书写、基础语法、状态机、RAM、UART、SPI、VGA、以及功能验证等。将每一个知识点作为一个章节进行讲解&#xff0c;旨在更快速的提升初学者在FPGA开发方面的能力&#xff0c;每一个章节中都有针对性的代码…

Spring IOC DI - 整合MyBatis

Spring IOC目录 主要内容Spring 框架介绍Spring 框架的优势(对比以前项目的缺点)Spring 框架引入历史发展框架学习三要素Spring 模块介绍 Spring IoC/DI - 引入IoC/DI 概念辨析使用IoC/DI的好处IoC/DI具体应用场景 Spring IoC/DI - 代码实现环境准备Spring 框架环境搭建创建Mav…