SMA2:代码实现详解——Image Encoder篇(Hiera章)

news2024/12/23 12:34:35

SMA2:代码实现详解——Image Encoder篇(Hiera)

写在前面

大家在SMA2:代码实现详解——Image Encoder篇(FpnNeck)下的留言我已收到,感谢大家的支持,后面如果遇到比较难以讲清的部分可能会使用视频的形式。博主最近要准备秋招,更新可能会慢许多,希望大家能谅解。

言归正传,在SMA2:代码实现详解——Image Encoder篇(FpnNeck)中,我们已经知道了SMA2的整体架构,并且介绍了Image Encoder组件中的FpnNeck。这一篇博客我们就来详细介绍Image Encoder的基本骨架backbone——Hiera

Hiera介绍

Hiera是文章Hiera: A Hierarchical Vision Transformer without the Bells-and-Whistles中提出的一种分层视觉Transformer架构。它不仅可以处理图像,而且这个架构可以应用于视频。Hiera是一个纯粹的简单分层ViT模型,不存在任何卷积、移位或者十字窗口操作,仅有Transformer结构组件。它比之前跨多个模型大小、领域和任务的工作更快、更准确。
在这里插入图片描述

Hiera与MAE(Masked AutoEncoder)

MAE(Masked AutoEncoder, 掩码自编码器)

图像MAE由论文Masked Autoencoders Are Scalable Vision Learners提出,它表明,MAE是计算机视觉的可扩展自监督学习器。方法非常简单:屏蔽输入图像的随机Patch并重建丢失的像素。它基于两个核心设计。首先,作者开发了一种非对称编码器-解码器架构,其中的编码器仅对Patch的可见子集(没有掩码标记)进行操作,而轻量级解码器可根据潜在表示和掩码标记重建原始图像。作者发现屏蔽高比例的输入图像(例如 75%)会产生一项不简单且有意义的自我监督任务。将这两种设计结合起来能够高效且有效地训练大型模型:加速训练(3 倍或更多)并提高准确性。可扩展方法允许学习泛化良好的高容量模型:例如,在仅使用 ImageNet-1K 数据的方法中,普通 ViT-Huge 模型实现了最佳准确率 (87.8%)。下游任务中的传输性能优于监督预训练,并显示出有希望的扩展行为。

Hiera便使用了MAE的方式进行训练。

Hiera架构

在这里插入图片描述

选择使用像MAE(如图所示)这样的强代理任务(pretext task)来教导模型。 Hiera完全由标准ViT块组成。为了提高效率,在前两个阶段使用“掩模单元”内的局部注意力,其余阶段使用全局注意力(Global Attention)。在每个阶段转换中,Q和跳跃连接的特征通过线性层加倍,空间维度通过2×2最大池池化。

SMA2中Hiera(HieraDet)的实现

class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(self, ...):
        ...
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            ...
            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )
            embed_dim = dim_out
            self.blocks.append(block)

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

首先,Hiera先将图片划分并映射为patch嵌入向量(上述代码62行),然后计算位置信息并相加(代码第66行)。值得注意的是,SMA2在实现Hiera中位置嵌入时,参照了Window Attention is Bugged: How not to Interpolate Position Embeddings一文,他们发现在使用窗口注意力的同时插值位置嵌入是错误的。Hiera和ViTDet两者确实都存在此错误。于是作者提出了一种简单的绝对窗口位置嵌入策略,它彻底解决了Hiera中的错误,并提高了ViTDet中模型的速度和性能。

代码的68-75行实际上就是Hiera主体ViT块的处理,值得关注的只有带有Q pooling的ViT块,这是在MultiScaleBlock中实现的。

class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, ...] = (7, 7),
        stride: Tuple[int, ...] = (4, 4),
        padding: Tuple[int, ...] = (3, 3),
        in_chans: int = 3,
        embed_dim: int = 768,
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int):  embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x

PatchEmbed模块将图片的形状(B,C,H,W)转化为更常见的适用于Transformer处理的形状(B, H, W, C),因为后面经过VIT块时会要求(B,L,C)的形式。实际上,这个模块的卷积映射继承了ViT的做法,直接利用了卷积的特性,通过指定Kernel_size与strides隐式划分了窗口,并且完成了线性变换得到patch enmbedding

值得注意的是位置嵌入的计算:

class Hiera(nn.Module):
 
    def __init__(...)
        super().__init__()
        ...
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )
        ...

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

代码第18行是计算全局的可学习位置嵌入。第19行加号的右边window_embed.tile(...)是计算每个window内的局部位置编码,每个window的位置编码都是相同的。我们可以使用matplotlib做一个可视化的样例,可能更容易理解。示例如下(由于代码中是零初始化,不太好展示,这里我选择随机初始化来展示):

在这里插入图片描述

从左到右依次为全局编码、局部编码和最终位置编码。

接下来我们来看MultiScaleBlock的实现:

class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.
    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = (
        x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    )
    return windows, (Hp, Wp)

def do_pool():... #(B, H, W, C) -> (B, H', W' C)

MultiScaleBlockMultiScaleAttentionMLP构成,有经验的小伙伴看到注意力机制和MLP,显然得出它是一个Transformer。第60-63行代码就是根据每个stage给定的window size划分patch。
在这里插入图片描述

而且针对于每个Stage的交界,都使用Q pooling,这在MultiScaleAttention中实现。

class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        x = F.scaled_dot_product_attention(
            q.transpose(1, 2),
            k.transpose(1, 2),
            v.transpose(1, 2),
        )
        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x

代码19-23以及31-41都是比较传统的自注意力机制的计算了。
而所谓的Q pooling在26-30行,只是对Q向量转换为宽高的形状(B, H*W,)->(B, H, W, …),然后进行池化。其实对于H和W,它们应该是我们之前指定的window size。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2137342.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jsp+sevlet+mysql图书管理系统

jspsevletmysql图书管理系统 一、系统介绍二、功能展示1.图书查询(学生)2.借阅信息(学生)3.借阅历史(学生)4.借阅历史(管理员)5.读者管理(管理员)6.图书分类(管理员)7.图书借阅信息(管理员)8.图书归还信息(管理员) 四、其它1.其他系统实现 一、系统介绍 系统主要功能&#xff…

《Linux运维总结:基于ARM64+X86_64架构CPU使用docker-compose一键离线部署mongodb 7.0.14容器版副本集群》

总结:整理不易,如果对你有帮助,可否点赞关注一下? 更多详细内容请参考:《Linux运维篇:Linux系统运维指南》 一、部署背景 由于业务系统的特殊性,我们需要面向不通的客户安装我们的业务系统&…

机器学习和深度学习的常见概念总结(多原创图)

目录 使用说明一、未分类损失函数(Loss Function)1. **损失函数的作用**2. **常见的损失函数**2.1. **均方误差(MSE, Mean Squared Error)**2.2. **均方根误差(RMSE, Root Mean Squared Error)**2.3. **平均…

【云原生安全篇】一文掌握Harbor集成Trivy应用实践

【云原生安全篇】一文掌握Harbor集成Trivy应用实践 目录 1 概念 1.1 什么是 Harbor 和 Trivy? 1.1.1 Harbor 1.1.2 Trivy 1.2 Harbor 与 Trivy 的关系 Trivy 在 Harbor 中的作用: 1.3 镜像扫描工作流程 2 实战案例:在Harbor 配置 Trivy …

SafaRi:弱监督引用表达式分割的自适应序列转换器

引用表达式分割(reference Expression Segmentation, RES)旨在提供文本所引用的图像(即引用表达式)中目标对象的分割掩码。 目前存在的挑战 1)现有的方法需要大规模的掩码注释。 2)此外,这种方法不能很好地推广到未见/零射击场景 改进 1)提出了一个弱…

探索自动化的魔法:Python中的pyautogui库

文章目录 探索自动化的魔法:Python中的 pyautogui 库背景:为什么选择pyautogui?pyautogui是什么?如何安装pyautogui?五个简单的库函数使用方法场景应用常见Bug及解决方案总结 探索自动化的魔法:Python中的 …

VirtualBox桥接网卡消失,安装Docker后导致桥接网卡服务消失问题解决记录

问题记录:VirtualBox虚拟机的桥接网卡消失 记录时间:2024.9.14 系统:win10 问题已解决。 原因: 猜测是由于安装Docker,也会使用我们的网卡进行虚拟化,导致网卡与virtualbox的桥接服务丢失。 解决方案…

基于python+django+vue的鲜花商城系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【2025最新】基于pythondjangovueMySQL的线…

三维点云处理(C++)学习记录——PDAL

一、OSGeo4W简概 OSGeo4W是一个基于Windows系统(版本7-11)的开源地理软件二进制包发布平台。OSGeo4W包括开源GIS桌面应用程序(QGIS、GRASS GIS)、地理空间库(PROJ、GDAL/OGR、GEOS、SpatiaLite、SAGA GIS)、…

org.flowable.bpmn.exceptions.XMLException: 元素类型 必须由匹配的结束标记

flowable在流程部署时经常汇报这个错误: org.flowable.bpmn.exceptions.XMLException: 元素类型... 必须由匹配的结束标记 经检查发现是数据库存的中午乱码导致xml结构异常了 解决办法如下: 在catalina.bat文件中找到如下地方,加入 -Dfile.…

Python爬取某猫投诉数据(含signature参数分析与算法还原)

文章目录 1. 写在前面2. 接口分析3. 爬虫实现 【🏠作者主页】:吴秋霖 【💼作者介绍】:擅长爬虫与JS加密逆向分析!Python领域优质创作者、CSDN博客专家、阿里云博客专家、华为云享专家。一路走来长期坚守并致力于Python…

电商数据采集分析全流程分享||电商数据API接口

电商数据监测,能为品牌发展提供参考依据,已经成为了业内共识。依托智能系统,将电商数据转换为有价值的营销情报,只需三步: 数据采集 可采集30多个电商平台数据,采集字段高达40多个,包含标题、价…

网络穿透:TCP 打洞、UDP 打洞与 UPnP

在现代网络中,很多设备都处于 NAT(网络地址转换)或防火墙后面,这使得直接访问这些设备变得困难。在这种情况下,网络穿透技术就显得非常重要。本文将介绍三种常用的网络穿透技术:TCP 打洞、UDP 打洞和 UPnP。…

数据库运维实操优质文章文档分享(含Oracle、MySQL等) | 2024年8月刊

本文为大家整理了墨天轮数据社区2024年8月发布的优质技术文章/文档,主题涵盖Oracle、MySQL、PostgreSQL等主流数据库系统以及国产数据库的技术实操,从基础的安装配置到复杂的故障排查,再到性能优化的实用技巧及常用脚本等,分享给大…

【Python电商项目汇报总结】**采集10万+淘宝商品详情数据注意事项总结汇报**

大家好,今天我想和大家聊聊我们在采集10万淘宝商品详情数据时需要注意的一些关键问题。这不仅仅是一个技术活,更是一场细心与合规的较量。下面,我就用咱们都听得懂的话,一一给大家说道说道。 **一、明确目标,有的放矢…

vue前端实现下载导入模板文件

1.需要导出的文件放置public文件夹中 2.在.vue页面中添加下载代码 <a href"./exportTemplate.xlsx" download"导入数据模板.xlsx" target"_blank" style"color: #2967e9;">导入数据模板.xlsx</a><!-- 如使用element框…

linux使用命令行编译qt.cpp

步骤&#xff1a; mkdir qttestcd qttestvim hello.cpp #include <QApplication> #include <QDialog> #include <QLabel> int main(int argc,char* argv[]) {QApplication a(argc,argv);QLabel label("aaa");label.resize(100,100);label.show()…

在conda虚拟环境中安装cv2(试错多次总结)

首先保证你创建好了虚拟环境&#xff0c;并在anaconda命令窗口激活虚拟环境 依次输入下列命令&#xff1a; pip install opencv-python3.4.1.15 pip install opencv-contrib-python3.4.1.15 pip install dlib19.6.1 然后测试cv2是否可以使用&#xff0c;输入python 运行pyth…

二叉搜索树的判断+平衡二叉树的判断

一、认识二叉树 二叉树 二叉树 二叉树 二叉搜索树 满二叉树 平衡二…

SpringBoot万级并发-jemeter-Address already in use: connect

一、场景 用Jmeter压力单测接口的时候&#xff0c;发现报 Response code:Non HTTP response code: java.net.BindException Response message:Non HTTP response message: Address already in use: connect 然后我这边是wondows的电脑操作压测的&#xff0c;操作系统win10&…