MedMamba代码解释及用于糖尿病视网膜病变分类

news2024/11/25 1:12:11

MedMamba原理和用于糖尿病视网膜病变检测尝试

1.MedMamba原理

image-20241010110028101

MedMamba发表于2024.9.28,是构建在Vision Mamba基础之上,融合了卷积神经网的架构,结构如下图:

image-20241010110201286

原理简述就是图片输入后按通道输入后切分为两部分,一部分走二维分组卷积提取局部特征,一部分利用Vision Mamba中的SS2D模块提取所谓的全局特征,两个分支的输出通过通道维度的拼接后,经过channel shuffle增加信息融合。

2.代码解释

模型代码就在源码的MedMamba.py文件下,对涉及到的代码我进行了详细注释:

  • mamba部分

    基本上是使用Vision Mamaba的SS2D:

class SS2D(nn.Module):
    def __init__(
        self,
        d_model,
        d_state=16,
        # d_state="auto", # 20240109
        d_conv=3,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        dropout=0.,
        conv_bias=True,
        bias=False,
        device=None,
        dtype=None,
        **kwargs,
    ):
        # 设置设备和数据类型的关键参数
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model # 模型维度
        self.d_state = d_state # 状态维度
        # self.d_state = math.ceil(self.d_model / 6) if d_state == "auto" else d_model # 20240109
        self.d_conv = d_conv # 卷积核的大小
        self.expand = expand  # 扩展因子
        self.d_inner = int(self.expand * self.d_model)  # 内部维度,等于模型维度乘以扩展因子
        # 时间步长的秩,默认为模型维度除以16
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        # 输入投影层,将模型维度投影到内部维度的两倍,用于后续操作
        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
        # 深度卷积层,输入和输出通道数相同,组数等于内部维度,用于空间特征提取
        self.conv2d = nn.Conv2d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            groups=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            padding=(d_conv - 1) // 2, # 保证输出的空间维度与输入相同
            **factory_kwargs,
        )
        self.act = nn.SiLU() # 激活函数使用 SiLU
        # 定义多个线性投影层,将内部维度投影到不同大小的向量,用于时间步长和状态
        self.x_proj = (
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
            nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs), 
        )
        # 将四个线性投影层的权重合并为一个参数,方便计算
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K=4, N, inner)
        # 删除单独的投影层以节省内存
        del self.x_proj
        # 初始化时间步长的线性投影,定义四组时间步长投影参数
        self.dt_projs = (
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
            self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs),
        )
        # 将时间步长的权重和偏置参数合并为可训练参数
        self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K=4, inner, rank)
        self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K=4, inner)
        del self.dt_projs
        # 初始化 S4D 的 A 参数,用于状态更新计算
        self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True) # (K=4, D, N)
        # 初始化 D 参数,用于跳跃连接的计算
        self.Ds = self.D_init(self.d_inner, copies=4, merge=True) # (K=4, D, N)
        # 选择核心的前向计算函数版本,默认为 forward_corev0
        # self.selective_scan = selective_scan_fn
        self.forward_core = self.forward_corev0
        # 输出层的层归一化,归一化到内部维度
        self.out_norm = nn.LayerNorm(self.d_inner)
        # 输出投影层,将内部维度投影回原始模型维度
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
        # 设置 dropout 层,如果 dropout 参数大于 0,则应用随机失活以防止过拟合
        self.dropout = nn.Dropout(dropout) if dropout > 0. else None

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs):
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)
        # 初始化用于时间步长计算的线性投影层
        # Initialize special dt projection to preserve variance at initialization
        # 特殊初始化方法,用于保持初始化时的方差不变
        dt_init_std = dt_rank**-0.5 * dt_scale
        if dt_init == "constant": # 初始化为常数
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random": # 初始化为均匀随机数
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        # 初始化偏置,以便在使用 F.softplus 时,结果处于 dt_min 和 dt_max 之间
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        # softplus 的逆操作,确保偏置初始化在合适范围内
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)  # 设置偏置参数
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        dt_proj.bias._no_reinit = True # 将该偏置标记为不重新初始化
        
        return dt_proj
  • SS_Conv_SSM

    这部分就是论文提出的创新点,图片中的结构

    class SS_Conv_SSM(nn.Module):
        def __init__(
            self,
            hidden_dim: int = 0,
            drop_path: float = 0,
            norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
            attn_drop_rate: float = 0,
            d_state: int = 16,
            **kwargs,
        ):
            super().__init__()
            # 初始化第一个归一化层,归一化的维度是隐藏维度的一半
            self.ln_1 = norm_layer(hidden_dim//2)
            # 初始化自注意力模块 SS2D,输入维度为隐藏维度的一半
            self.self_attention = SS2D(d_model=hidden_dim//2,
                                       dropout=attn_drop_rate,
                                       d_state=d_state,
                                       **kwargs)
            # DropPath 层,用于随机丢弃路径,提高模型的泛化能力
            self.drop_path = DropPath(drop_path)
            # 定义卷积模块,由多个卷积层和批量归一化层组成,用于特征提取
            self.conv33conv33conv11 = nn.Sequential(
                nn.BatchNorm2d(hidden_dim // 2),
                nn.Conv2d(in_channels=hidden_dim//2,out_channels=hidden_dim//2,kernel_size=3,stride=1,padding=1),
                nn.BatchNorm2d(hidden_dim//2),
                nn.ReLU(),
                nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(hidden_dim // 2),
                nn.ReLU(),
                nn.Conv2d(in_channels=hidden_dim // 2, out_channels=hidden_dim // 2, kernel_size=1, stride=1),
                nn.ReLU()
            )
            # 注释掉的最终卷积层,可能用于进一步调整输出维度
            # self.finalconv11 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1)
        def forward(self, input: torch.Tensor):
            # 将输入张量沿最后一个维度分割为左右两部分
            input_left, input_right = input.chunk(2,dim=-1)
            # 对右侧输入进行归一化和自注意力操作,之后应用 DropPath 随机丢弃
            x = self.drop_path(self.self_attention(self.ln_1(input_right)))
            # 将左侧输入从 (batch_size, height, width, channels)
            # 转换为 (batch_size, channels, height, width) 以适应卷积操作
            input_left = input_left.permute(0,3,1,2).contiguous()
            input_left = self.conv33conv33conv11(input_left)
            # 将卷积后的左侧输入转换回原来的形状 (batch_size, height, width, channels)
            input_left = input_left.permute(0,2,3,1).contiguous()
            # 将左侧和右侧的输出在最后一个维度上拼接起来
            output = torch.cat((input_left,x),dim=-1)
            # 对拼接后的输出进行通道混洗,增加特征的融合
            output = channel_shuffle(output,groups=2)
            # 返回最终的输出,增加残差连接,将输入与输出相加
            return output+input
    
  • VSSLayer

    有以上结构堆叠构成网络结构

    class VSSLayer(nn.Module):
        """ A basic Swin Transformer layer for one stage.
        Args:
            dim (int): Number of input channels.
            depth (int): Number of blocks.
            drop (float, optional): Dropout rate. Default: 0.0
            attn_drop (float, optional): Attention dropout rate. Default: 0.0
            drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
            norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
            downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
            use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        """
    
        def __init__(
            self, 
            dim, 
            depth, 
            attn_drop=0.,
            drop_path=0., 
            norm_layer=nn.LayerNorm, 
            downsample=None, 
            use_checkpoint=False, 
            d_state=16,
            **kwargs,
        ):
            super().__init__()
            # 设置输入通道数
            self.dim = dim
            # 是否使用检查点
            self.use_checkpoint = use_checkpoint
            # 创建 SS_Conv_SSM 块列表,数量为 depth
            self.blocks = nn.ModuleList([
                SS_Conv_SSM(
                    hidden_dim=dim, # 隐藏层维度等于输入维度
                    # 处理随机深度的丢弃率
                    drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                    norm_layer=norm_layer, # 使用的归一化层
                    attn_drop_rate=attn_drop, # 注意力丢弃率
                    d_state=d_state, # 状态维度
                )
                for i in range(depth)]) # 重复 depth 次构建块
            # 初始化权重 (暂时没有真正初始化,可能在后续被重写)
            # 确保这一初始化应用于模型 (在 VSSM 中被覆盖)
            if True: # is this really applied? Yes, but been overriden later in VSSM!
                # 对每个模块的参数进行初始化
                def _init_weights(module: nn.Module):
                    for name, p in module.named_parameters():
                        if name in ["out_proj.weight"]:
                            # 克隆并分离参数 p,用于保持随机数种子一致
                            p = p.clone().detach_() # fake init, just to keep the seed ....
                            # 使用 Kaiming 均匀初始化方法
                            nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                # 应用初始化函数到整个模型
                self.apply(_init_weights)
            # 如果提供了下采样层,则使用该层,否则设置为 None
            if downsample is not None:
                self.downsample = downsample(dim=dim, norm_layer=norm_layer)
            else:
                self.downsample = None
    
    
        def forward(self, x):
            # 逐块应用 SS_Conv_SSM 模块
            for blk in self.blocks:
                # 如果使用检查点,则通过检查点执行前向传播,节省内存
                if self.use_checkpoint:
                    x = checkpoint.checkpoint(blk, x)
                else:
                    # 否则直接进行前向传播
                    x = blk(x)
            # 如果存在下采样层,则应用下采样层
            if self.downsample is not None:
                x = self.downsample(x)
            # 返回最终的输出张量
            return x
    
  • 最终的网络模型类

    class VSSM(nn.Module):
        def __init__(self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 4, 2], depths_decoder=[2, 9, 2, 2],
                     dims=[96,192,384,768], dims_decoder=[768, 384, 192, 96], d_state=16, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                     norm_layer=nn.LayerNorm, patch_norm=True,
                     use_checkpoint=False, **kwargs):
            super().__init__()
            self.num_classes = num_classes # 设置分类的类别数目
            self.num_layers = len(depths)  # 设置层的数量,即编码器层的数量
            # 如果 dims 是一个整数,则自动扩展为一个包含每一层维度的列表
            if isinstance(dims, int):
                dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)]
            self.embed_dim = dims[0]  # 嵌入维度等于第一层的维度
            self.num_features = dims[-1] # 特征维度等于最后一层的维度
            self.dims = dims # 记录每一层的维度
            # 初始化补丁嵌入模块,将输入图像分割成补丁并进行线性投影
            self.patch_embed = PatchEmbed2D(patch_size=patch_size, in_chans=in_chans, embed_dim=self.embed_dim,
                norm_layer=norm_layer if patch_norm else None)
    
            # WASTED absolute position embedding ======================
            # 是否使用绝对位置编码,默认情况下不使用
            self.ape = False
            # self.ape = False
            # drop_rate = 0.0
            # 如果使用绝对位置编码,则初始化位置编码参数
            if self.ape:
                self.patches_resolution = self.patch_embed.patches_resolution
                # 创建位置编码的可训练参数,并进行截断正态分布初始化
                self.absolute_pos_embed = nn.Parameter(torch.zeros(1, *self.patches_resolution, self.embed_dim))
                trunc_normal_(self.absolute_pos_embed, std=.02)
            # 位置编码的 Dropout 层
            self.pos_drop = nn.Dropout(p=drop_rate)
            # 使用线性函数生成每层的随机深度丢弃率
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 随机深度衰减规则
            # 解码器部分的随机深度衰减
            dpr_decoder = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths_decoder))][::-1]
            # 初始化编码器的层列表
            self.layers = nn.ModuleList()
            for i_layer in range(self.num_layers):  # 创建每一层的 VSSLayer
                layer = VSSLayer(
                    dim=dims[i_layer], # 输入维度
                    depth=depths[i_layer], # 当前层包含的块数量
                    d_state=math.ceil(dims[0] / 6) if d_state is None else d_state, # 状态维度
                    drop=drop_rate,  # Dropout率
                    attn_drop=attn_drop_rate, # 注意力 Dropout率
                    # 当前层的随机深度丢弃率
                    drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                    # 归一化层类型
                    norm_layer=norm_layer,
                    # 下采样层,最后一层不进行下采样
                    downsample=PatchMerging2D if (i_layer < self.num_layers - 1) else None,
                    # 是否使用检查点技术节省内存
                    use_checkpoint=use_checkpoint,
                )
                # 将层添加到层列表中
                self.layers.append(layer)
    
    
            # self.norm = norm_layer(self.num_features)
            # 平均池化层,用于将特征池化为单个值
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            # 分类头部,使用线性层将特征映射到类别数目
            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
            # 初始化模型权重
            self.apply(self._init_weights)
            # 对模型中的卷积层进行 Kaiming 正态分布初始化
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        def _init_weights(self, m: nn.Module):
            """
            out_proj.weight which is previously initilized in SS_Conv_SSM, would be cleared in nn.Linear
            no fc.weight found in the any of the model parameters
            no nn.Embedding found in the any of the model parameters
            so the thing is, SS_Conv_SSM initialization is useless
            
            Conv2D is not intialized !!!
            """
            # 对线性层和归一化层进行权重初始化
            if isinstance(m, nn.Linear):
                # 对线性层的权重使用截断正态分布初始化
                trunc_normal_(m.weight, std=.02)
                # 如果存在偏置,则将其初始化为 0
                if isinstance(m, nn.Linear) and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                # 对 LayerNorm 层的偏置和权重初始化
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
    
        @torch.jit.ignore
        def no_weight_decay(self):
            # 返回不需要权重衰减的参数名
            return {'absolute_pos_embed'}
    
        @torch.jit.ignore
        def no_weight_decay_keywords(self):
            # 返回不需要权重衰减的关键字
            return {'relative_position_bias_table'}
    
        def forward_backbone(self, x):
            # 使用补丁嵌入模块处理输入张量
            x = self.patch_embed(x)
            if self.ape:
                # 如果使用绝对位置编码,则将位置编码加到输入特征上
                x = x + self.absolute_pos_embed
            # 位置编码之后应用 Dropout
            x = self.pos_drop(x)
            # 逐层通过编码器层
            for layer in self.layers:
                x = layer(x)
            return x
    
        def forward(self, x):
            # 通过骨干网络提取特征
            x = self.forward_backbone(x)
            # 变换维度以适应池化操
            x = x.permute(0,3,1,2)
            # 使用自适应平均池化将特征降维
            x = self.avgpool(x)
            # 展平成一个向量
            x = torch.flatten(x,start_dim=1)
            # 通过分类头进行最终的类别预测
            x = self.head(x)
            return x
    

    作者在原文中尝试了大中小三个不同的参数版本

    medmamba_t = VSSM(depths=[2, 2, 4, 2],dims=[96,192,384,768],num_classes=6).to("cuda")
    medmamba_s = VSSM(depths=[2, 2, 8, 2],dims=[96,192,384,768],num_classes=6).to("cuda")
    medmamba_b = VSSM(depths=[2, 2, 12, 2],dims=[128,256,512,1024],num_classes=6).to("cuda")
    

    总体论文原理比较简单,但是论文实验做得很扎实,感兴趣查看原文。

3.在糖尿病视网膜数据上实验一下效果

数据集情况

采用开源的retino_data糖尿病视网膜病变数据集:

image-20241010113951487

环境安装

这部分主要是vision mamba的环境安装不要出错,参考官方Github会有问题:

  • Python 3.10.13

    • conda create -n vim python=3.10.13
  • torch 2.1.1 + cu118

    • pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
  • Requirements: vim_requirements.txt

    • pip install -r vim/vim_requirements.txt

wget https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.1.3.post1/causal_conv1d-1.1.3.post1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
wget https://github.com/state-spaces/mamba/releases/download/v1.1.1/mamba_ssm-1.1.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

  • pip install causal_conv1d-1.1.3.post1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

  • pip install mamba_ssm-1.1.1+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

  • 然后用官方项目里的mamba_ssm替换安装在conda环境里的mamba_ssm

    • 用conda env list 查看刚才安装的mamba环境的路径,我的mamba环境在/home/aic/anaconda3/envs/vim

    • 用官方项目里的mamba_ssm替换安装在conda环境里的mamba_ssm
      cp -rf mamba-1p1p1/mamba_ssm /home/aic/anaconda3/envs/vim/lib/python3.10/site-packages

代码编写

编写一个检查数据集均值和方差的代码,不用Imagenet的:

# -*- coding: utf-8 -*-
# 作者: cskywit
# 文件名: mean_std.py
# 创建时间: 2024-10-07
# 文件描述:计算数据集的均值和方差


# 导入必要的库
from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

# 定义函数get_mean_and_std,用于计算训练数据集的均值和标准差
def get_mean_and_std(train_data):
  # 创建DataLoader,用于批量加载数据
  train_loader = torch.utils.data.DataLoader(
      train_data, batch_size=1, shuffle=False, num_workers=0,
      pin_memory=True)
  # 初始化均值和标准差
  mean = torch.zeros(3)
  std = torch.zeros(3)
  # 遍历数据集中的每个批次
  for X, _ in train_loader:
      # 遍历RGB三个通道
      for d in range(3):
          # 计算每个通道的均值和标准差
          mean[d] += X[:, d, :, :].mean()
          std[d] += X[:, d, :, :].std()
  # 计算最终的均值和标准差
  mean.div_(len(train_data))
  std.div_(len(train_data))
  # 返回均值和标准差列表
  return list(mean.numpy()), list(std.numpy())

# 判断是否为主程序
if __name__ == '__main__':
  root_path = '/home/aic/deep_learning_data/retino_data/train'
  # 使用ImageFolder加载训练数据集
  train_dataset = ImageFolder(root=root_path, transform=transforms.ToTensor())
  # 打印训练数据集的均值和标准差
  print(get_mean_and_std(train_dataset))
  # ([0.41586006, 0.22244255, 0.07565845],
  # [0.23795983, 0.13206834, 0.05284985])

然后编写train

# -*- coding: utf-8 -*-
# 作者: cskywit
# 文件名: train_DR.py
# 创建时间: 2024-10-10
# 文件描述:
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdm

from MedMamba import VSSM as medmamba # import model
import warnings
import os,sys



warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES']="0"

# 设置随机因子
def seed_everything(seed=42):
  os.environ['PYHTONHASHSEED'] = str(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.backends.cudnn.deterministic = True

def main():
  # 设置随机因子
  seed_everything()
  # 一些超参数设定
  num_classes = 2
  BATCH_SIZE = 64
  num_of_workers = min([os.cpu_count(), BATCH_SIZE if BATCH_SIZE > 1 else 0, 8])  # number of workers
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  epochs = 300
  best_acc = 0.0
  save_path = './{}.pth'.format('bestmodel')
  # 数据预处理
  transform = transforms.Compose([
      transforms.RandomRotation(10),
      transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 3.0)),
      transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.41593555, 0.22245076, 0.075719066],
                           std=[0.23819199, 0.13202211, 0.05282707])

  ])
  transform_test = transforms.Compose([
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.41593555, 0.22245076, 0.075719066],
                           std=[0.23819199, 0.13202211, 0.05282707])
  ])
  # 加载数据集
  root_path = '/home/aic/deep_learning_data/retino_data'
  train_path = os.path.join(root_path, 'train')
  valid_path = os.path.join(root_path, 'valid')
  test_path = os.path.join(root_path, 'test')
  dataset_train = datasets.ImageFolder(train_path, transform=transform)
  dataset_valid = datasets.ImageFolder(valid_path, transform=transform_test)
  dataset_test = datasets.ImageFolder(test_path, transform=transform_test)
  class_labels = {0: 'Diabetic Retinopathy', 1: 'No Diabetic Retinopathy'}
  val_num = len(dataset_valid)

  train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE,
                                             num_workers=num_of_workers,
                                             shuffle=True,
                                             drop_last=True)
  valid_loader = torch.utils.data.DataLoader(dataset_valid,
                                             batch_size=BATCH_SIZE,
                                             num_workers=num_of_workers,
                                             shuffle=False,
                                             drop_last=True)
  test_loader = torch.utils.data.DataLoader(dataset_test,
                                            batch_size=BATCH_SIZE,
                                            shuffle=False)
  print('Using {} dataloader workers every process'.format(num_of_workers))

  # 模型定义
  net = medmamba(num_classes=num_classes).to(device)
  loss_function = nn.CrossEntropyLoss()
  optimizer = optim.Adam(net.parameters(), lr=0.0001)
  train_steps = len(train_loader)

  for epoch in range(epochs):
      # train
      net.train()
      running_loss = 0.0
      train_bar = tqdm(train_loader, file=sys.stdout)
      for step, data in enumerate(train_bar):
          images, labels = data
          optimizer.zero_grad()
          outputs = net(images.to(device))
          loss = loss_function(outputs, labels.to(device))
          loss.backward()
          optimizer.step()

          # print statistics
          running_loss += loss.item()

          train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                   epochs,
                                                                   loss)

      # validate
      net.eval()
      acc = 0.0  # accumulate accurate number / epoch
      with torch.no_grad():
          val_bar = tqdm(valid_loader, file=sys.stdout)
          for val_data in val_bar:
              val_images, val_labels = val_data
              outputs = net(val_images.to(device))
              predict_y = torch.max(outputs, dim=1)[1]
              acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

      val_accurate = acc / val_num
      print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
            (epoch + 1, running_loss / train_steps, val_accurate))

      if val_accurate > best_acc:
          best_acc = val_accurate
          torch.save(net.state_dict(), save_path)

  print('Finished Training')

if __name__ == '__main__':
  main()

感觉Mamaba系列的通病了吧,显存占用不算高,GPU利用率超高:

image-20241010112042331

可能是没有用任何的训练调参技巧,经过几个epoch后,验证集准确率很快提升到了92.3%,然后就没有继续上升了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2204161.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

每日论文18-24ISCAS采用磁调谐变压器低温CMOS28GHzVCO

《28 GHz VCO Using Magnetically Tuning Trifilar Transformer in Cryogenic CMOS Application 》24ISCAS 瞟到了这篇文章&#xff0c;开关真的是可以加在任何地方哈哈哈&#xff0c;还挺特别 通过改变电感偏压来改变Var的偏压&#xff0c;来拓宽带宽&#xff0c;其实是个挺简…

processing像素画教程

前提&#xff1a;各位已经安装了processing 第一步&#xff1a;创建一个简单的网格 我们首先创建一个网格来定义我们作品的像素画布。网格将帮助您在适当的位置绘制每个像素。 int gridSize 20; // 每个像素的大小 int cols, rows; void setup() {size(400, 400); // 设置画…

k8s 的网络通信

目录 1 k8s通信整体架构 2 flannel 网络插件 2.1 flannel 插件组成 2.2 flannel 插件的通信过程 2.3 flannel 支持的后端模式 3 calico 网络插件 3.1 calico 简介 3.2 calico 网络架构 3.3 部署 calico 1 k8s通信整体架构 k8s通过CNI接口接入其他插件来实现网络通讯。目前比较…

QTableView加入小灯泡

通过重载QAbstractTableModel中的data方法 QVariant CTblModel::data(const QModelIndex &index, int role) const { if (!index.isValid()) return QVariant(); int col index.column(); if (col ledColIndex && role Qt::DecorationRole) { return Q…

股指期货的杠杆是怎么体现和使用的?

股指期货的杠杆效应是通过保证金交易实现的。投资者只需支付合约价值的一小部分作为保证金&#xff0c;即可控制整个合约的价值。例如&#xff0c;如果一个股指期货合约的价值为100,000元&#xff0c;而保证金比例为10%&#xff0c;那么投资者只需支付10,000元即可控制这个合约…

PPT分享:埃森哲-业务流程BPM能力框架体系

PPT下载链接见文末~ 业务流程管理&#xff08;BPM, Business Process Management&#xff09;的能力框架体系是一个全面、系统的流程管理方法论和工具集&#xff0c;旨在帮助企业优化和持续改进其业务流程&#xff0c;从而提升运营效率和市场竞争力。 一、BPM能力框架体系概述…

云计算的江湖,风云再起

大数据产业创新服务媒体 ——聚焦数据 改变商业 还记得当年英特尔的广告语吗&#xff1f;“Intel Inside”&#xff0c;这个标志性的标签几乎成了计算设备的象征。然而&#xff0c;随着AI大模型的迅速崛起&#xff0c;计算的核心从CPU悄然转向了GPU。一场前所未有的技术革命正…

【学术会议征稿】第四届公共管理与大数据分析国际学术会议(PMBDA 2024)

第四届公共管理与大数据分析国际学术会议(PMBDA 2024) 2024 4th International Conference on Public Management and Big Data Analysis 第四届公共管理与大数据分析国际学术会议 &#xff08;PMBDA 2024&#xff09;将于2024年12月20-22日在中国青岛召开。会议主题主要围绕…

MySQL-表相关(DDL DML)

文章目录 表的基本操作表的创建表的删除 MySQL中的数据类型整数类型浮点数类型定点数类型日期和时间类型字符串类型charvarchartext 二进制类型 DDL语句查看建表语句修改表名新增字段修改字段(名类型)修改字段(仅类型)删除字段 表的基本操作 在介绍DDL和DQL的操作语句之前, 我…

HCIP-HarmonyOS Application Developer 习题(六)

&#xff08;多选&#xff09;1、Harmonyos多窗口交互能力提供了以下哪几种交互方式? A. 平行视界 B.全局消息通知 C.分屏 D.悬浮窗 答案&#xff1a;ACD 分析&#xff1a;系统提供了悬浮窗、分屏、平行视界三种多窗口交互&#xff0c;为用户在大屏幕设备上的多任务并行、便捷…

V2M2引擎传奇全套源码2024BLUE最新版 可自定义UI

特点优势是最新XE10.4或者XE12编辑器&#xff0c;微端&#xff0c;各种自定义UI 无限仿GOM引擎功能下载地址:BlueCodePXL_415.rar官方版下载丨最新版下载丨绿色版下载丨APP下载-123云盘 提取码: AuX7BlueCodePXL_415.rar官方版下载丨最新版下载丨绿色版下载丨APP下载-123云盘…

无需复杂计算!如何用“加法”打造高效而低功耗的语言模型

当我们聊到人工智能特别是语言模型时,大家脑海中可能浮现的都是庞大的计算能力、高能耗的服务器群。然而,最近有一篇有趣的论文《Addition Is All You Need for Energy-Efficient Language Models》(加法才是低能耗语言模型的关键)却颠覆了我们对语言模型的传统认知。那么,…

Redis高级篇 —— 分布式缓存

Redis高级篇 —— 分布式缓存 文章目录 Redis高级篇 —— 分布式缓存1 Redis持久化1.1 RDB1.2 RDB的fork原理1.3 RDB总结1.4 AOF持久化1.5 RDB和AOF的对比 2 Redis主从2.1 搭建主从架构2.2 数据同步原理2.2.1 全量同步2.2.2 增量同步 3 Redis哨兵3.1 哨兵的作用和原理3.1.1 哨兵…

基于IOU匹配的DeepSort目标跟踪与匈牙利算法解析

在多目标跟踪任务中&#xff0c;如何将检测框与已有轨迹进行关联&#xff0c;进而维持目标的连续跟踪&#xff0c;是一个关键问题。DeepSort&#xff08;Deep Simple Online and Realtime Tracking&#xff09;是一种常用的多目标跟踪算法&#xff0c;它结合了IOU&#xff08;交…

Linux搭建Hadoop集群(详细步骤)

前言 Hadoop是一个由Apache基金会所开发的分布式系统基础架构。用户可以在不了解分布式底层细节的情况下&#xff0c;开发分布式程序。充分利用集群的威力进行高速运算和存储。 说白了就是实现一个任务可以在多个电脑上计算的过程。 一&#xff1a;准备工具 1.1 VMware 1.2L…

利用内部知识库优化SOP与HR培训效果评估

在当今快速变化的商业环境中&#xff0c;企业运营的高效性和员工的综合能力成为决定竞争力的关键因素。SOP作为确保业务一致性和质量的基础&#xff0c;其有效执行至关重要。同时&#xff0c;HR培训作为提升员工技能和知识的重要手段&#xff0c;其效果直接影响到企业的整体绩效…

【顶刊核心变量】中国地级市绿色金融试点改革试验区名单数据(2010-2023年)

一、测算方式&#xff1a; 参考《中国工业经济》崔惠玉&#xff08;2023&#xff09;老师的研究&#xff0c;2017 年&#xff0c;国务院决定将浙江、广东、江西、贵州和新疆的部分地区作为绿色金融改革创新试验 区的首批试点地区。试点地区在顶层设计、组织体系、产品创新、配…

Docker容器简介及部署方法

1.1 Docker简介 Docker之父Solomon Hykes&#xff1a;Docker就好比传统的货运集装箱 2008 年LXC(LinuX Contiainer)发布&#xff0c;但是没有行业标准&#xff0c;兼容性非常差 docker2013年首次发布&#xff0c;由Docker, Inc开发 1.1.1什么是Docker Docker是管理容器的引…

java脚手架系列4--测试用例、拦截器

异常处理、拦截器、数据库连接 1 测试用例 单元测试是一个老生常谈的问题&#xff0c;无论是后端对自己的代码质量把的第一道关也好&#xff0c;也是对测试减缓压力。这里就不过多讲述测试用例的重要性&#xff0c;但是有2个框架我们必须了解一下。 1.1 JUnit和mockito 我们…

【gRPC】4—gRPC与Netty

gRPC与Netty ⭐⭐⭐⭐⭐⭐ Github主页&#x1f449;https://github.com/A-BigTree 笔记链接&#x1f449;https://github.com/A-BigTree/Code_Learning ⭐⭐⭐⭐⭐⭐ 如果可以&#xff0c;麻烦各位看官顺手点个star~&#x1f60a; &#x1f4d6;RPC专栏&#xff1a;https://…