VM-UNet:视觉Mamba UNet用来医学图像分割 论文及代码解读

news2024/12/28 12:11:14

VM-UNet

  • 期刊分析
    • 摘要
    • 贡献
    • 方法
      • 整体框架
      • 1. Vision Mamba UNet (VM-UNet)
      • 2. VSS block
      • 3. Loss function
    • 实验
      • 1.对比实验
      • 2.消融实验
  • 可借鉴参考
  • 代码使用

期刊分析

期刊名: Arxiv
论文主页: VM-UNet
代码: Code

摘要

在医学图像分割领域,基于 CNN 和基于 Transformer 的模型都得到了广泛的探索。然而,CNN 在远程建模能力方面表现出局限性,而 Transformer 则受到其二次计算复杂性的限制。 最近,以 Mamba 为代表的状态空间模型 (SSM) 已成为一种有前途的方法。它们不仅擅长对远程交互进行建模,而且还保持线性计算复杂性。在本文中,利用状态空间模型,我们提出了一种用于医学图像分割的 Ushape 架构模型,名为 Vision Mamba UNet (VM-UNet)。具体来说,引入视觉状态空间(VSS)块作为基础块来捕获广泛的上下文信息,并构建不对称的编码器-解码器结构。我们在 ISIC17、ISIC18 和 Synapse 数据集上进行了全面的实验,结果表明 VM-UNet 在医学图像分割任务中表现出竞争力。据我们所知,这是第一个基于纯SSM模型构建的医学图像分割模型。我们的目标是建立一个基线,并为未来开发更高效、更有效的基于 SSM 的分割系统提供有价值的见解。我们的代码可在 https://github.com/JCruan519/VM-UNet 上获取。


贡献

  1. 提出了 VM-UNet,标志着首次探索纯粹基于 SSM 的模型在医学图像分割中的潜在应用。
  2. 在三个数据集上进行综合实验,结果表明VM-UNet表现出相当的竞争力。
  3. 我们为医学图像分割任务中基于纯SSM的模型建立了基线,提供了宝贵的见解,为开发更高效、更有效的基于SSM的分割方法铺平了道路。

方法

整体框架

整体依旧是传统UNet的结构,将原先的卷积模块换成了视觉状态空间模块(VSS block),将原先的上下采样换成了Patch Merging和Patch Expanding模块,因此跟着代码一个一个模块攻克即可!
在这里插入图片描述

1. Vision Mamba UNet (VM-UNet)

1. Patch Embedding层首先将图像划分为4×4的非重叠patch,然后将图像通道数转化为C
2. 编码器由四个阶段组成,在前三个阶段的末尾应用补丁合并操作,以减少输入特征的高度和宽度,同时增加通道数。
3. 类似地,解码器分为四个阶段。在最后三个阶段的开始,利用补丁扩展操作来减少特征通道的数量并增加高度和宽度。
4. 在解码器之后,采用最终投影层来恢复特征的大小以匹配分割目标。
5. 对于跳跃连接,采用简单的加法操作,没有花哨的东西,因此没有引入任何额外的参数。
疑问: 为什么在选择使用对称性的结构时,在最后一阶段解码器中只使用了1个VSS block呢?

2. VSS block

1. 源自 VMamaba 的 VSS 块是 VM-UNet 的核心模块,如图 1 (b) 所示。经过层归一化后,输入被分为两个分支。在第一个分支中,输入通过线​​性层,然后通过激活函数。在第二个分支中,输入通过线​​性层、深度可分离卷积和激活函数进行处理,然后输入 2D 选择性扫描 (SS2D) 模块以进行进一步的特征提取。随后,使用层归一化对特征进行归一化,然后使用第一个分支的输出执行按元素生成以合并两个路径。最后,使用线性层混合特征,并将此结果与残差连接相结合以形成 VSS 块的输出。本文默认采用SiLU[14]作为激活函数。


在这里插入图片描述
2. SS2D 由三个部分组成:扫描扩展操作(Scan Expanding)、S6 块和扫描合并操作(Scan Merging)这里不要和和Patch merging 和 Patch Expanding 弄混。如图2(a)所示,扫描扩展操作沿四个不同方向(左上到右下、右下到左上、右上到左下、左下到右上)展开输入图像)进入序列。然后,这些序列由 S6 块进行处理以进行特征提取,确保来自各个方向的信息得到彻底扫描,从而捕获不同的特征。随后,如图2(b)所示,扫描合并操作对来自四个方向的序列进行求和并合并,将输出图像恢复到与输入相同的大小。 S6 模块源自 Mamba [16],通过根据输入调整 SSM 的参数,在 S4 [17] 之上引入了一种选择机制。这使得模型能够区分并保留相关信息,同时过滤掉不相关的信息。 S6 块的伪代码如算法 1 所示。
3. SS2D中的核心S6
在这里插入图片描述
首先使用nn.linear()和x生成 Δ \Delta Δ,B,C;然后借助 Δ \Delta Δ,A,B生成 A ˉ \bar{A} Aˉ B ˉ \bar{B} Bˉ,最后借助公式计算出输出值y,具体解析见公式注释。

3. Loss function

在这里插入图片描述

1. VM-UNet的引入旨在验证纯基于SSM的模型在医学图像分割任务中的应用潜力。因此,我们专门利用最基本的二元交叉熵和骰子损失(BceDice 损失)以及交叉熵和骰子损失(CeDice 损失)分别作为二元和多类分割任务的损失函数,如方程 5 所示和 6.
2.
3.

实验

在本节中,我们在 VM-UNet 上针对皮肤病变和器官分割任务进行全面的实验。具体来说,我们评估了 VM-UNet 在 ISIC17、ISIC18 和 Synapse 数据集上的医学图像分割任务上的性能。 对于这两个数据集,我们提供了对多个指标的详细评估,包括并集平均交集 (mIoU)、Dice 相似系数 (DSC)、准确性 (Acc)、灵敏度 (Sen) 和特异性 (Spe)。

继之前的工作[28,9]之后,我们将 ISIC17 和 ISIC18 数据集中的图像大小调整为 256×256,将 Synapse 数据集中的图像大小调整为 224×224。为了防止过度拟合,采用了数据增强技术,包括随机翻转和随机旋转。 ISIC17和ISIC18数据集采用BceDice损失函数,Synapse数据集采用CeDice损失函数。

我们将批量大小设置为 32,并使用 AdamW [23] 优化器,初始学习率为 1e-3。 CosineAnnealingLR [22] 用作调度器,最多 50 次迭代,最小学习率为 1e-5。训练周期设置为 300。对于 VM-UNet,我们使用 VMamba-S [20] 的权重初始化编码器和解码器的权重,VMamba-S [20] 在 ImageNet-1k 上预训练。所有实验均在单个 NVIDIA RTX A6000 GPU 上进行。

1.对比实验

在这里插入图片描述
在这里插入图片描述
将 VM-UNet 与一些最先进的模型进行比较,实验结果如表 1 和表 2 所示。对于 ISIC17 和 ISIC18 数据集,我们的 VM-UNet 在 mIoU、DSC 和 Acc 方面优于其他模型指标。对于Synapse数据集,VM-UNet也取得了有竞争力的性能。例如,我们的模型在 DSC 和 HD95 指标上超过了 Swin-UNet(第一个纯基于 Transformer 的模型)1.95% 和 2.34mm。结果证明了基于SSM的模型在医学图像分割任务中的优越性。


找来了ISIC2018数据集上结果榜单如下:
在这里插入图片描述

VM-UNet 距离SOTA还有一点距离!!

2.消融实验

在这里插入图片描述

在本节中,我们使用 ISIC17 和 ISIC18 数据集对 VMUNet 的初始化进行消融实验。我们分别使用 VMamba-T 和 VMamba-S 的预训练权重初始化 VM-UNet1。实验结果如表 3 所示,表明更有效的预训练权重显着增强了 VM-UNet 的下游性能,表明 VM-UNet很大程度上受预训练权重的影响。

可以很明显看出,大一点的模型预训练权重就是厉害!

总体看下来:

  1. 作为mamba模型最重要的突破点,线性计算复杂度没有做实验证明,真的可惜!
  2. 作者选择在对称结构上使用不对称的VSS block数量,也是经过实验得到的(对称好肯定就用对称的了)。因此感觉将mamba模型放到分割任务中,只超过一些经典模型,对于SOTA模型还有一些距离,因此需要做好攻克的准备。
  3. 作者公开了代码,而且代码结构简单好用,太棒了!!

可借鉴参考

  1. 提供了一个baseline

  2. 提供了mamba模型的介绍流程

  3. 阅读U-Mamba 2024
    Ma, J., Li, F., Wang, B.: U-mamba: Enhancing long-range dependency for biomedical image segmentation. arXiv preprint arXiv:2401.04722 (2024)

  4. 阅读SegMamba 2024
    Xing, Z., Ye, T., Yang, Y., Liu, G., Zhu, L.: Segmamba: Long-range sequential modeling mamba for 3d medical image segmentation. arXiv preprint arXiv:2401.13560 (2024)

代码使用

当前mamba_ssm包只发布了Linux版本,因此在Win系统下无法运行

  1. Linux下安装mamba_ssm

pip install mamba_ssm

  1. 复制代码使用
import time
import math
from functools import partial
from typing import Optional, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
try:
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
except:
    pass

# an alternative for mamba_ssm (in which causal_conv1d is needed)
# try:
#     from selective_scan import selective_scan_fn as selective_scan_fn_v1
#     from selective_scan import selective_scan_ref as selective_scan_ref_v1
# except:
#     pass

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"


def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32
    
    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] 
    """
    import numpy as np
    
    # fvcore.nn.jit_handles
    def get_flops_einsum(input_shapes, equation):
        np_arrs = [np.zeros(s) for s in input_shapes]
        optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1]
        for line in optim.split("\n"):
            if "optimized flop" in line.lower():
                # divided by 2 because we count MAC (multiply-add counted as one flop)
                flop = float(np.floor(float(line.split(":")[-1]) / 2))
                return flop
    

    assert not with_complex

    flops = 0 # below code flops = 0
    if False:
        ...
        """
        dtype_in = u.dtype
        u = u.float()
        delta = delta.float()
        if delta_bias is not None:
            delta = delta + delta_bias[..., None].float()
        if delta_softplus:
            delta = F.softplus(delta)
        batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
        is_variable_B = B.dim() >= 3
        is_variable_C = C.dim() >= 3
        if A.is_complex():
            if is_variable_B:
                B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
            if is_variable_C:
                C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
        else:
            B = B.float()
            C = C.float()
        x = A.new_zeros((batch, dim, dstate))
        ys = []
        """

    flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln")
    if with_Group:
        flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln")
    else:
        flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln")
    if False:
        ...
        """
        deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
        if not is_variable_B:
            deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
        else:
            if B.dim() == 3:
                deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
            else:
                B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
                deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
        if is_variable_C and C.dim() == 4:
            C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
        last_state = None
        """
    
    in_for_flops = B * D * N   
    if with_Group:
        in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd")
    else:
        in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd")
    flops += L * in_for_flops 
    if False:
        ...
        """
        for i in range(u.shape[2]):
            x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
            if not is_variable_C:
                y = torch.einsum('bdn,dn->bd', x, C)
            else:
                if C.dim() == 3:
                    y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
                else:
                    y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
            if i == u.shape[2] - 1:
                last_state = x
            if y.is_complex():
                y = y.real * 2
            ys.append(y)
        y = torch.stack(ys, dim=2) # (batch dim L)
        """

    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L
    if False:
        ...
        """
        out = y if D is None else y + u * rearrange(D, "d -> d 1")
        if z is not None:
            out = out * F.silu(z)
        out = out.to(dtype=dtype_in)
        """
    
    return flops


class PatchEmbed2D(nn.Module):
    r""" Image to Patch Embedding
    Args:
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """
    def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None, **kwargs):
        super().__init__()
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is not None:
            x = self.norm(x)
        return x


class PatchMerging2D(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        B, H, W, C = x.shape

        SHAPE_FIX = [-1, -1]
        if (W % 2 != 0) or (H % 2 != 0):
            print(f"Warning, x.shape {x.shape} is not match even ===========", flush=True)
            SHAPE_FIX[0] = H // 2
            SHAPE_FIX[1] = W // 2

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C

        if SHAPE_FIX[0] > 0:
            x0 = x0[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x1 = x1[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x2 = x2[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
            x3 = x3[:, :SHAPE_FIX[0], :SHAPE_FIX[1], :]
        
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, H//2, W//2, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x
    

class PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=2, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim*2
        self.dim_scale = dim_scale
        self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):
        B, H, W, C = x.shape
        x = self.expand(x)

        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
        x= self.norm(x)

        return x
    

class Final_PatchExpand2D(nn.Module):
    def __init__(self, dim, dim_scale=4, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.dim_scale = dim_scale
        self.expand = nn.Linear(self.dim, dim_scale*self.dim, bias=False)
        self.norm = norm_layer(self.dim // dim_scale)

    def forward(self, x):
        B, H, W, C = x.shape
        x = self.expand(x)

        x = rearrange(x, 'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=self.dim_scale, p2=self.dim_scale, c=C//self.dim_scale)
        x= self.norm(x)

        return x


class SS2D(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        # d_state="auto", # 20240109
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2,
            **factory_kwargs,
        )
        self.act = nn.SiLU()

        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
        )
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
        del self.x_proj

        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
        )
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
        del self.dt_projs
        
        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N)

        # self.selective_scan = selective_scan_fn
        self.forward_core = self.forward_corev0

        self.out_norm = nn.LayerNorm(self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True
        
        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 1:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=1, device=None, merge=True):
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 1:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D

    def forward_corev0(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn
        
        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)

        xs = xs.float().view(B, -1, L) # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1) # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

        out_y = self.selective_scan(
            xs, dts, 
            As, Bs, Cs, Ds, z=None,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
            return_last_state=False,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    # an alternative to forward_corev1
    def forward_corev1(self, x: torch.Tensor):
        self.selective_scan = selective_scan_fn

        B, C, H, W = x.shape
        L = H * W
        K = 4

        x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
        xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l)

        x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
        # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1)
        dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
        dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
        # dts = dts + self.dt_projs_bias.view(1, K, -1, 1)

        xs = xs.float().view(B, -1, L) # (b, k * d, l)
        dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l)
        Bs = Bs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Cs = Cs.float().view(B, K, -1, L) # (b, k, d_state, l)
        Ds = self.Ds.float().view(-1) # (k * d)
        As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)  # (k * d, d_state)
        dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d)

        out_y = self.selective_scan(
            xs, dts, 
            As, Bs, Cs, Ds,
            delta_bias=dt_projs_bias,
            delta_softplus=True,
        ).view(B, K, -1, L)
        assert out_y.dtype == torch.float

        inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
        wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
        invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)

        return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y

    def forward(self, x: torch.Tensor, **kwargs):
        B, H, W, C = x.shape

        xz = self.in_proj(x)
        x, z = xz.chunk(2, dim=-1) # (b, h, w, d)

        x = x.permute(0, 3, 1, 2).contiguous()
        x = self.act(self.conv2d(x)) # (b, d, h, w)
        y1, y2, y3, y4 = self.forward_core(x)
        assert y1.dtype == torch.float32
        y = y1 + y2 + y3 + y4
        y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
        y = self.out_norm(y)
        y = y * F.silu(z)
        out = self.out_proj(y)
        if self.dropout is not None:
            out = self.dropout(out)
        return out


class VSSBlock(nn.Module):
    def __init__(
        self,
        hidden_dim: int = 0,
        drop_path: float = 0,
        norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
        attn_drop_rate: float = 0,
        d_state: int = 16,
        **kwargs,
    ):
        super().__init__()
        self.ln_1 = norm_layer(hidden_dim)
        self.self_attention = SS2D(d_model=hidden_dim, dropout=attn_drop_rate, d_state=d_state, **kwargs)
        self.drop_path = DropPath(drop_path)

    def forward(self, input: torch.Tensor):
        x = input + self.drop_path(self.self_attention(self.ln_1(input)))
        return x


class VSSLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self, 
        dim, 
        depth, 
        attn_drop=0.,
        drop_path=0., 
        norm_layer=nn.LayerNorm, 
        downsample=None, 
        use_checkpoint=False, 
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])
        
        if True: # is this really applied? Yes, but been overriden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_() # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None


    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        
        if self.downsample is not None:
            x = self.downsample(x)

        return x
    


class VSSLayer_up(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        depth (int): Number of blocks.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(
        self, 
        dim, 
        depth, 
        attn_drop=0.,
        drop_path=0., 
        norm_layer=nn.LayerNorm, 
        upsample=None, 
        use_checkpoint=False, 
        d_state=16,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.use_checkpoint = use_checkpoint

        self.blocks = nn.ModuleList([
            VSSBlock(
                hidden_dim=dim,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                attn_drop_rate=attn_drop,
                d_state=d_state,
            )
            for i in range(depth)])
        
        if True: # is this really applied? Yes, but been overriden later in VSSM!
            def _init_weights(module: nn.Module):
                for name, p in module.named_parameters():
                    if name in ["out_proj.weight"]:
                        p = p.clone().detach_() # fake init, just to keep the seed ....
                        nn.init.kaiming_uniform_(p, a=math.sqrt(5))
            self.apply(_init_weights)

        if upsample is not None:
            self.upsample = upsample(dim=dim, norm_layer=norm_layer)
        else:
            self.upsample = None


    def forward(self, x):
        if self.upsample is not None:
            x = self.upsample(x)
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        return x
    


class VSSM(nn.Module):
    def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], depths_decoder=[2, 9, 2, 2],
                 dims=[96, 192, 384, 768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, **kwargs):
        super().__init__()
        self.num_classes = num_classes
        self.num_layers = len(depths)
        if isinstance(dims, int):
            dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
        self.embed_dim = dims[0]
        self.num_features = dims[-1]
        self.dims = dims

        self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
            norm_layer=norm_layer if patch_norm else None)

        # WASTED absolute position embedding ======================
        self.ape = False
        # self.ape = False
        # drop_rate = 0.0
        if self.ape:
            self.patches_resolution = self.patch_embed.patches_resolution
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1]

        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer(
                dim=dims[i_layer],
                depth=depths[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
                drop=drop_rate, 
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                norm_layer=norm_layer,
                downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers.append(layer)

        self.layers_up = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = VSSLayer_up(
                dim=dims_decoder[i_layer],
                depth=depths_decoder[i_layer],
                d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 20240109
                drop=drop_rate, 
                attn_drop=attn_drop_rate,
                drop_path=dpr_decoder[sum(depths_decoder[:i_layer]):sum(depths_decoder[:i_layer + 1])],
                norm_layer=norm_layer,
                upsample=PatchExpand2D if (i_layer != 0) else None,
                use_checkpoint=use_checkpoint,
            )
            self.layers_up.append(layer)

        self.final_up = Final_PatchExpand2D(dim=dims_decoder[-1], dim_scale=4, norm_layer=norm_layer)
        self.final_conv = nn.Conv2d(dims_decoder[-1]//4, num_classes, 1)

        # self.norm = norm_layer(self.num_features)
        # self.avgpool = nn.AdaptiveAvgPool1d(1)
        # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)

    def _init_weights(self, m: nn.Module):
        """
        out_proj.weight which is previously initilized in VSSBlock, would be cleared in nn.Linear
        no fc.weight found in the any of the model parameters
        no nn.Embedding found in the any of the model parameters
        so the thing is, VSSBlock initialization is useless
        
        Conv2D is not intialized !!!
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x):
        skip_list = []
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            skip_list.append(x)
            x = layer(x)
        return x, skip_list
    
    def forward_features_up(self, x, skip_list):
        for inx, layer_up in enumerate(self.layers_up):
            if inx == 0:
                x = layer_up(x)
            else:
                x = layer_up(x+skip_list[-inx])

        return x
    
    def forward_final(self, x):
        x = self.final_up(x)
        x = x.permute(0,3,1,2)
        x = self.final_conv(x)
        return x

    def forward_backbone(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)
        return x

    def forward(self, x):
        x, skip_list = self.forward_features(x)
        x = self.forward_features_up(x, skip_list)
        x = self.forward_final(x)
        
        return x


import torch
from torch import nn


class VMUNet(nn.Module):
    def __init__(self, 
                 input_channels=3, 
                 num_classes=1,
                 depths=[2, 2, 9, 2], 
                 depths_decoder=[2, 9, 2, 2],
                 drop_path_rate=0.2,
                 load_ckpt_path=None,
                ):
        super().__init__()

        self.load_ckpt_path = load_ckpt_path
        self.num_classes = num_classes

        self.vmunet = VSSM(in_chans=input_channels,
                           num_classes=num_classes,
                           depths=depths,
                           depths_decoder=depths_decoder,
                           drop_path_rate=drop_path_rate,
                        )
    
    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1,3,1,1)
        logits = self.vmunet(x)
        if self.num_classes == 1: return torch.sigmoid(logits)
        else: return logits
    
    def load_from(self):
        if self.load_ckpt_path is not None:
            model_dict = self.vmunet.state_dict()
            modelCheckpoint = torch.load(self.load_ckpt_path)
            pretrained_dict = modelCheckpoint['model']
            # 过滤操作
            new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
            model_dict.update(new_dict)
            # 打印出来,更新了多少的参数
            print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
            self.vmunet.load_state_dict(model_dict)

            not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
            print('Not loaded keys:', not_loaded_keys)
            print("encoder loaded finished!")

            model_dict = self.vmunet.state_dict()
            modelCheckpoint = torch.load(self.load_ckpt_path)
            pretrained_odict = modelCheckpoint['model']
            pretrained_dict = {}
            for k, v in pretrained_odict.items():
                if 'layers.0' in k: 
                    new_k = k.replace('layers.0', 'layers_up.3')
                    pretrained_dict[new_k] = v
                elif 'layers.1' in k: 
                    new_k = k.replace('layers.1', 'layers_up.2')
                    pretrained_dict[new_k] = v
                elif 'layers.2' in k: 
                    new_k = k.replace('layers.2', 'layers_up.1')
                    pretrained_dict[new_k] = v
                elif 'layers.3' in k: 
                    new_k = k.replace('layers.3', 'layers_up.0')
                    pretrained_dict[new_k] = v
            # 过滤操作
            new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
            model_dict.update(new_dict)
            # 打印出来,更新了多少的参数
            print('Total model_dict: {}, Total pretrained_dict: {}, update: {}'.format(len(model_dict), len(pretrained_dict), len(new_dict)))
            self.vmunet.load_state_dict(model_dict)
            
            # 找到没有加载的键(keys)
            not_loaded_keys = [k for k in pretrained_dict.keys() if k not in new_dict.keys()]
            print('Not loaded keys:', not_loaded_keys)
            print("decoder loaded finished!")

x = torch.randn(3, 3, 256, 256).to("cuda:0")
net = VMUNet().to("cuda:0")
print(net(x).shape)

直接复制使用,相较于原始代码,因为没安装成功selective_scan包,扫描代码发现其中只用到了一处,将其改成了mamba_scan中的函数了,实测效果不错,在自己最近做的任务中,分数可以排进前五!!
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1472106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JANGOW: 1.0.1

kali:192.168.223.128 主机发现 nmap -sP 192.168.223.0/24 端口扫描 nmap -p- 192.168.223.154 开启了21 80端口 web看一下&#xff0c;有个busque.php参数是buscar,但是不知道输入什么&#xff0c;尝试文件包含失败 扫描目录 dirsearch -u http://192.168.223.154 dirse…

mysql 分表实战

本文主要介绍基于range分区的相关 1、业务需求&#xff0c;每日160w数据&#xff0c;每月2000w;解决大表数据读写性能问题。 2、数据库mysql 8.0.34&#xff0c;默认innerDB;mysql自带的逻辑分表 3、分表的目的:解决大表性能差&#xff0c;小表缩小查询单位的特点(其实优化的精…

人事|人事管理系统|基于Springboot的人事管理系统设计与实现(源码+数据库+文档)

人事管理系统目录 目录 基于Springboot的人事管理系统设计与实现 一、前言 二、系统功能设计 三、系统实现 1、管理员登录 2、员工管理 3、公告信息管理 4、公告类型管理 5、培训管理 6、培训类型管理 四、数据库设计 1、实体ER图 五、核心代码 六、论文参考 七、…

2024年环境安全科学、材料工程与制造国际学术会议(ESSMEM2024)

【EI检索】2024年环境安全科学、材料工程与制造国际学术会议&#xff08;ESSMEM2024) 会议简介 我们很高兴邀请您参加将在三亚举行的2024年环境安全科学、材料工程和制造国际学术会议&#xff08;ESSMEM 2024&#xff09;。 ESSMEM2024将汇集世界各国和地区的研究人员&…

CrackRTF1

由此可知为六位密码且<100000自动退出 发现是哈希加密&#xff0c;经过网上查找后了解到 0x8004u 为 SHA1 加密&#xff0c;故从网上寻找脚本并进行爆破 密码为123321 找到一个取巧的网站 MD5免费在线解密破解_MD5在线加密-SOMD5 Flag{N0_M0re_Free_Bugs}

Mysql索引讲解及创建

索引介绍 1.什么是索引&#xff1f; 索引是存储引擎中一种数据结构&#xff0c;或者说数据的组织方式&#xff0c;又称之为键key&#xff0c;是存储引擎用于快速找到记录的 一种数据结构。 为数据建立索引就好比是为书建目录&#xff0c;或者说是为字典创建音序表&#xff0c;…

视频号视频下载教程:如何把微信视频号的视频下载下来

视频号下载相信不少人都多少有一些了解&#xff0c;但今天我们就来细说一下关于视频号视频下载的相关疑问&#xff0c;以及大家经常会问到底如何把微信视频号的视频下载下来&#xff1f; 视频号视频下载教程 视频号链接提取器详细使用指南&#xff0c;教你轻松下载号视频&…

LabVIEW燃料电池船舶电力推进监控系统

LabVIEW燃料电池船舶电力推进监控系统 随着全球经济一体化的推进&#xff0c;航运业的发展显得尤为重要&#xff0c;大约80%的世界贸易依靠海上运输实现。传统的船舶推进系统主要依赖于柴油机&#xff0c;这不仅耗能高&#xff0c;而且排放严重&#xff0c;对资源和环境的影响…

uniapp播放mp4省流方案

背景&#xff1a; 因为项目要播放一个宣传和讲解视频&#xff0c;视频文件过大&#xff0c;同时还为了节省存储流量&#xff0c;想到了一个方案&#xff0c;用m3u8切片替代mp4。 m3u8&#xff1a;切片播放&#xff0c;可以理解为一个1G的视频文件&#xff0c;自行设置文…

Antd 嵌套子表格 defaultExpandAllRows 默认展开不生效

问题描述 在使用 antd 嵌套子表格时&#xff0c;想要默认展开所有子列表&#xff0c;设置属性:defaultExpandAllRows“true”&#xff0c;但是子列表没有展开 原因分析 defaultExpandAllRows 属性官网定义&#xff1a; 从官网定义可知&#xff0c;defaultExpandAllRows 属性…

P2680 [NOIP2015 提高组] 运输计划 第一个测试点信息 || 被卡常,链式前向星应该解决的是vector的push_back频繁扩容的耗时

[NOIP2015 提高组] 运输计划 - 洛谷 目录 测试点信息 P2680_1.in P2680_1.out 图&#xff1a; 50分参考代码&#xff08;开了n^2的数组&#xff0c;MLE了&#xff09;&#xff1a; 测试点信息 Subtask #0 #1 P2680_1.in 100 1 7 97 4 89 2 0 40 91 1 70 84 1 36 92 3 …

ASLR 和 PIE

前言 ASLR&#xff08;Address Space Layout Randomization&#xff0c;地址空间随机化&#xff09;是一种内存攻击缓解技术&#xff0c;是一种操作系统用来抵御缓冲区溢出攻击的内存保护机制。这种技术使得系统上运行的进程的内存地址无法被预测&#xff0c;使得与这些进程有…

Spring Cloud学习

1、什么是SpringCloud Spring cloud 流应用程序启动器是基于 Spring Boot 的 Spring 集成应用程序&#xff0c;提供与外部系统的集成。Spring cloud Task&#xff0c;一个生命周期短暂的微服务框架&#xff0c;用于快速构建执行有限数据处理的应用程序。Spring cloud 流应用程…

【Java设计模式】二、单例模式

文章目录 1、懒汉式2、双重检查3、静态内部类4、饿汉式5、枚举6、单例模式的破坏&#xff1a;序列化和反序列化7、单例模式的破坏&#xff1a;反射 单例模式即在程序中想要保持一个实例对象&#xff0c;让某个类只有一个实例单例类必须自己创建自己唯一的实例&#xff0c;并对外…

mysql的增删改查(常用)

增(insert) 语法&#xff1a; insert into 表名&#xff08;字段&#xff09; values( 字段对应的值) 案例&#xff1a; 创建一个学生表 结构如下&#xff1a; create table student(id int ,name varchar(20),age int); 向表中插入2条数据 create table student(id int ,n…

802.11局域网的 MAC 层协议、CSMA/CA

目录 802.11 局域网的 MAC 层协议 1 CSMA/CA 协议 无线局域网不能使用 CSMA/CD 无线局域网可以使用 CSMA 协议 802.11 的 MAC 层 分布协调功能 DCF 点协调功能 PCF CSMA/CA 协议的要点 2 时间间隔 DIFS 的重要性 SIFS DIFS 3 争用信道的过程 时隙长度的确定 退避…

Java+SpringBoot+Vue+MySQL构建银行客户管理新平台

✍✍计算机毕业编程指导师 ⭐⭐个人介绍&#xff1a;自己非常喜欢研究技术问题&#xff01;专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目&#xff1a;有源码或者技术上的问题欢迎在评论区一起讨论交流&#xff01; ⚡⚡ Java、…

HUAWEI Programming Contest 2024(AtCoder Beginner Contest 342)

D - Square Pair 题目大意 给一长为的数组&#xff0c;问有多少对&#xff0c;两者相乘为非负整数完全平方数 解题思路 一个数除以其能整除的最大的完全平方数&#xff0c;看前面有多少个与其余数相同的数&#xff0c;两者乘积满足条件&#xff08;已经是完全平方数的部分无…

2-22 方法、面向对象、类、JVM内存、构造方法

文章目录 方法的重载面向对象类、属性和方法成员变量默认值属性JVM简单内存分析栈空间堆空间 构造方法执行过程构造器注意点 方法的重载 一个类中名称相同&#xff0c;但是参数列表不同的方法 参数列表不同是指&#xff1a; 形参类型形参个数形参顺序 面向对象 field —— …

JAVA工程师面试专题-《JVM篇》

目录 一、运行时数据区 1、说一下JVM的主要组成部分及其作用&#xff1f; 2、说一下 JVM 运行时数据区 &#xff1f; 3、说一下堆栈的区别 4、成员变量、局部变量、类变量分别存储在什么地方&#xff1f; 5、类常量池、运行时常量池、字符串常量池有什么区别&#xff1f;…