Processing BraTS Challenge Data with U-Mamba and nnUNetv2
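- [Deep Learning Summary] Processing BraTS Challenge Data with U-Mamba and nnUNetv2
- Introduction to U-Mamba
- Dataset download
- Environment setup
- Dataset preparation
- Running
- Others
- 2D network structures
- Model structure of UMambaBot
- Model structure of UMambaEnc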
Code: U-Mamba
Introduction to U-Mamba
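After Mamba was introduced, its efficient long-range modeling and selective scan mechanism led many researchers to apply it to medical image segmentation, and U-Mamba is one of these works. Structurally, it appends an SSM (Mamba) block after the convolutional layers to perform feature selection over the whole feature map.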
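Dataset download
Here we use the BraTS 2019 dataset from the brain tumor segmentation challenge, which has been held for many years, so I won't go into detail about it.
Download the BraTS 2019 training set from the Baidu AI Studio (PaddlePaddle) community: https://aistudio.baidu.com/aistudio/datasetdetail/67772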
Environment setup
After cloning the U-Mamba repository, set up the environment as follows:
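- First install causal-conv1d and mamba-ssm. There are two ways:
  - Install with pip (quote the requirement, otherwise the shell treats >= as an output redirection):
    pip install "causal-conv1d>=1.2.0"
    pip install mamba-ssm --no-cache-dir
  - Install from wheel files (if the first method fails):
    Download the matching whl file from the releases page of the official causal-conv1d repository. The PyTorch version, CUDA version and Python version of the wheel must all match your environment. Then run:
    pip install <your whl file>
    Next, download the mamba-ssm whl file from its releases page and install it with the same pip command.
- Finally, enter the umamba folder of the U-Mamba repository and install it in editable mode:
  cd umamba
  pip install -e .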
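After installation, a quick check helps confirm that the wheels match your environment. The script below is a minimal sketch, not part of U-Mamba: it prints the Python tag (e.g. cp310), the PyTorch version, the CUDA version PyTorch was built with and the C++11 ABI flag, which are the fields encoded in the names of the prebuilt causal-conv1d / mamba-ssm wheels, and reports the installed versions of both packages. It assumes the pip distribution names causal-conv1d and mamba-ssm; torch._C._GLIBCXX_USE_CXX11_ABI is a private PyTorch attribute and may change between releases.

```python
"""Minimal environment check for U-Mamba (a sketch, not part of the repository)."""
import sys
from importlib import metadata


def python_tag(version_info=sys.version_info) -> str:
    """CPython wheel tag for a Python version, e.g. (3, 10) -> 'cp310'."""
    return f"cp{version_info[0]}{version_info[1]}"


def _numeric(version: str) -> tuple:
    """Leading numeric release parts: '1.2.0.post2' -> (1, 2, 0); '+cu118' local tags are dropped."""
    parts = []
    for piece in version.split("+")[0].split("."):
        if not piece.isdigit():
            break  # stop at suffixes such as 'post2' or '0rc1'
        parts.append(int(piece))
    return tuple(parts)


def version_at_least(installed: str, required: str) -> bool:
    """Numeric comparison of release versions, padding the shorter one with zeros."""
    a, b = _numeric(installed), _numeric(required)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def installed_version(dist_name: str):
    """Installed version of a pip distribution, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def report() -> None:
    print("python tag:", python_tag())
    try:
        import torch
        print("torch:", torch.__version__)
        print("torch built with CUDA:", torch.version.cuda)
        print("cxx11abi:", getattr(torch._C, "_GLIBCXX_USE_CXX11_ABI", "unknown"))  # private attribute
        print("CUDA available:", torch.cuda.is_available())
    except ImportError:
        print("torch: not installed")
    for dist, minimum in (("causal-conv1d", "1.2.0"), ("mamba-ssm", None)):
        version = installed_version(dist)
        if version is None:
            print(f"{dist}: not installed")
        elif minimum is not None and not version_at_least(version, minimum):
            print(f"{dist}: {version} (older than {minimum})")
        else:
            print(f"{dist}: {version}")


if __name__ == "__main__":
    report()
```

Run it inside the same environment you will train in; if a package shows up as "not installed" or the tags differ from the wheel name you downloaded, pick another wheel.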
Dataset preparation
The dataset must first be converted into a form nnUNet can process. The relevant paths are set in U-Mamba/umamba/nnunetv2/paths.py, which defines three folders:
- nnUNet_raw: input datasets in the format nnUNet expects
- nnUNet_preprocessed: output location of the preprocessed datasets
- nnUNet_results: trained models and results
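In upstream nnUNetv2, paths.py reads these three locations from environment variables of the same names. Assuming U-Mamba's copy does the same, the sketch below creates the folders under a base directory of your choice and prints matching export lines; the helper names and the base path /data/umamba are placeholders. Setting os.environ inside Python only affects that process and its children, so for the nnUNetv2_* command-line tools put the printed export lines into your shell profile (e.g. ~/.bashrc).

```python
import os
from pathlib import Path

NNUNET_DIRS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def nnunet_paths(base: str) -> dict:
    """Map each nnU-Net variable name to a folder under `base` (nothing is created)."""
    return {name: str(Path(base) / name) for name in NNUNET_DIRS}


def setup_nnunet_dirs(base: str) -> dict:
    """Create the three folders and set the variables for the current process."""
    paths = nnunet_paths(base)
    for name, folder in paths.items():
        Path(folder).mkdir(parents=True, exist_ok=True)
        os.environ[name] = folder  # visible to this process and its children only
    return paths


if __name__ == "__main__":
    # Placeholder base directory; this only prints, it creates nothing.
    for name, folder in nnunet_paths("/data/umamba").items():
        print(f'export {name}="{folder}"')
```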
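The raw data goes into the raw folder (nnUNet_raw), with one sub-folder per dataset named DatasetXXX_Name, where XXX is a three-digit dataset ID. Each dataset folder contains imagesTr (training images), labelsTr (training labels) and a dataset.json file describing the channels and labels; imagesTs (test images) is optional.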
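Before preprocessing, it can help to check that a dataset folder has this layout. A minimal sketch, assuming the naming used by the BraTS conversion script in this post (images named {case}_{channel:04d}.nii.gz with channels _0000 to _0003, labels named {case}.nii.gz); check_raw_dataset is a hypothetical helper, and nnUNetv2_plan_and_preprocess --verify_dataset_integrity performs a more thorough check of its own.

```python
from pathlib import Path


def check_raw_dataset(dataset_dir, num_channels=4, ext=".nii.gz"):
    """Return a list of layout problems in an nnU-Net raw dataset folder (empty if none found)."""
    root = Path(dataset_dir)
    problems = []
    if not (root / "dataset.json").is_file():
        problems.append("missing dataset.json")
    images, labels = root / "imagesTr", root / "labelsTr"
    if not images.is_dir() or not labels.is_dir():
        return problems + ["missing imagesTr or labelsTr"]
    for label in sorted(labels.glob("*" + ext)):
        case = label.name[: -len(ext)]
        # every labelled case needs one image file per channel: case_0000, case_0001, ...
        for channel in range(num_channels):
            image = images / f"{case}_{channel:04d}{ext}"
            if not image.is_file():
                problems.append(f"{case}: missing {image.name}")
    return problems


if __name__ == "__main__":
    import os
    # Assumes nnUNet_raw is set and the converted dataset is called Dataset043_BraTS19.
    folder = Path(os.environ.get("nnUNet_raw", ".")) / "Dataset043_BraTS19"
    print(check_raw_dataset(folder) or "layout looks fine")
```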
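Since I am using the BraTS 2019 dataset, which is not organized this way, the data has to be converted first.
First download the conversion script Dataset043_BraTS19.py from the official nnUNet repository and put it in the local dataset_conversion folder.
As the code shows, the converted files are saved to nnUNet_raw: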
out_base = join(nnUNet_raw, foldername)
imagestr = join(out_base, "imagesTr")
labelstr = join(out_base, "labelsTr")
One change is needed: the files in this copy of the dataset end in .nii.gz, so every .nii suffix in the script has to be changed to .nii.gz. The modified copy loops look like this:
# Channel order written for nnU-Net: _0000 = T1, _0001 = T1ce, _0002 = T2, _0003 = FLAIR
print("copying hggs")
for c in tqdm(case_ids_hgg):
    shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz'))
    shutil.copy(join(brats_data_dir, "HGG", c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz'))
    shutil.copy(join(brats_data_dir, "HGG", c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz'))
    shutil.copy(join(brats_data_dir, "HGG", c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz'))
    # Remap the BraTS label values and write the label file to labelsTr
    copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "HGG", c, c + "_seg.nii.gz"),
                                                         join(labelstr, c + '.nii.gz'))

print("copying lggs")
for c in tqdm(case_ids_lgg):
    shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1.nii.gz"), join(imagestr, c + '_0000.nii.gz'))
    shutil.copy(join(brats_data_dir, "LGG", c, c + "_t1ce.nii.gz"), join(imagestr, c + '_0001.nii.gz'))
    shutil.copy(join(brats_data_dir, "LGG", c, c + "_t2.nii.gz"), join(imagestr, c + '_0002.nii.gz'))
    shutil.copy(join(brats_data_dir, "LGG", c, c + "_flair.nii.gz"), join(imagestr, c + '_0003.nii.gz'))
    copy_BraTS_segmentation_and_convert_labels_to_nnUNet(join(brats_data_dir, "LGG", c, c + "_seg.nii.gz"),
                                                         join(labelstr, c + '.nii.gz'))
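copy_BraTS_segmentation_and_convert_labels_to_nnUNet is defined in the conversion script itself. nnU-Net wants consecutive label values, while BraTS uses 0, 1, 2 and 4 (1: necrotic and non-enhancing tumor core, 2: peritumoral edema, 4: enhancing tumor). Assuming the helper uses the same mapping as nnU-Net's other BraTS conversion scripts (2 → 1, 1 → 2, 4 → 3), the remapping looks like this on a flat list of voxel values; the real helper applies it to the whole volume with NumPy and SimpleITK, so check your copy of the script.

```python
# Assumed BraTS -> nnU-Net label mapping (verify against the helper in your copy of the script).
# BraTS: 0 background, 1 necrotic/non-enhancing core, 2 edema, 4 enhancing tumor
BRATS_TO_NNUNET = {0: 0, 2: 1, 1: 2, 4: 3}


def convert_brats_labels(voxels):
    """Remap BraTS label values to consecutive nnU-Net labels; reject unexpected values."""
    converted = []
    for value in voxels:
        if value not in BRATS_TO_NNUNET:
            raise ValueError(f"unexpected BraTS label: {value}")
        converted.append(BRATS_TO_NNUNET[value])
    return converted


print(convert_brats_labels([0, 1, 2, 4, 4]))  # [0, 2, 1, 3, 3]
```

With this ordering, the nested tumor regions used later for training become simple label sets: whole tumor {1, 2, 3}, tumor core {2, 3} and enhancing tumor {3}.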
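Then run nnUNetv2_plan_and_preprocess -d 43 --verify_dataset_integrity. If it appears to hang at this step for a long time, open U-Mamba/umamba/nnunetv2/experiment_planning/plan_and_preprocess_entrypoints.py and, in the plan_and_preprocess_entry function, change the default of npfp (the number of processes used for fingerprint extraction) from 8 to 2; after that the output proceeds normally. Since npfp is a command-line argument, passing it directly (nnUNetv2_plan_and_preprocess -d 43 -npfp 2 --verify_dataset_integrity) has the same effect without editing the source.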
parser.add_argument('-npfp', type=int, default=2, required=False,
                    help='[OPTIONAL] Number of processes used for fingerprint extraction. Default: 2')
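The preprocessed output is written to nnUNet_preprocessed; among the generated files, nnUNetPlans.json is the configuration (plans) file.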
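To see what the planner decided (for example the 2D patch size the network is trained on), nnUNetPlans.json can be read directly. A minimal sketch, assuming the usual nnUNetv2 plans layout with a top-level configurations dictionary whose entries carry patch_size and batch_size; configurations that only inherit from another one (inherits_from) may not define these keys, so they show up as None. The folder name Dataset043_BraTS19 is assumed to match what the conversion script created, and summarize_plans is a hypothetical helper.

```python
import json
import os
from pathlib import Path


def summarize_plans(plans_file):
    """Return {configuration name: (patch_size, batch_size)} from an nnUNetPlans.json file."""
    plans = json.loads(Path(plans_file).read_text())
    return {
        name: (config.get("patch_size"), config.get("batch_size"))
        for name, config in plans.get("configurations", {}).items()
    }


if __name__ == "__main__":
    plans_file = Path(os.environ.get("nnUNet_preprocessed", ".")) / "Dataset043_BraTS19" / "nnUNetPlans.json"
    if plans_file.is_file():
        for name, (patch_size, batch_size) in summarize_plans(plans_file).items():
            print(f"{name}: patch_size={patch_size}, batch_size={batch_size}")
    else:
        print("plans file not found:", plans_file)
```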
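Running
Start training with the following command:
nnUNetv2_train 43 2d all -tr nnUNetTrainerUMambaEnc
Here 43 is the dataset ID, 2d the configuration, all the fold (train on all training cases instead of one of the five cross-validation folds), and -tr selects the trainer class; use nnUNetTrainerUMambaBot instead to train the UMambaBot variant. When training starts successfully, nnUNet prints the training loss, validation loss and pseudo Dice for every epoch.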
- The main training code is in the run_training function of U-Mamba/umamba/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py; the gradient-descent step is in the train_step function.
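The control flow of run_training can be summarized with a toy class. This is a simplified sketch of the structure, not nnU-Net's code: each epoch runs a fixed number of training iterations (train_step) and then a few validation iterations (validation_step), pulling batches from infinite data loaders. The attribute names follow nnUNetTrainer; as far as I know its defaults are 1000 epochs, 250 training iterations and 50 validation iterations per epoch, and the real loop also calls on_*_start / on_*_end hooks for logging, learning-rate updates and checkpointing, which are omitted here. MiniTrainer is a hypothetical name.

```python
import itertools


class MiniTrainer:
    """Toy trainer mirroring the epoch/iteration structure of nnUNetTrainer.run_training."""

    def __init__(self, num_epochs=2, num_iterations_per_epoch=3, num_val_iterations_per_epoch=1):
        self.current_epoch = 0
        self.num_epochs = num_epochs
        self.num_iterations_per_epoch = num_iterations_per_epoch
        self.num_val_iterations_per_epoch = num_val_iterations_per_epoch
        self.history = []  # (mean train loss, mean val loss) per epoch

    def train_step(self, batch):
        # In nnU-Net this is where the gradient step happens: forward pass, loss,
        # backward pass, gradient clipping and optimizer step.
        return {"loss": float(batch)}

    def validation_step(self, batch):
        # Forward pass and loss only, no parameter update.
        return {"loss": float(batch)}

    def run_training(self, train_loader, val_loader):
        # Starting at current_epoch is what allows resuming from a checkpoint.
        for epoch in range(self.current_epoch, self.num_epochs):
            train_out = [self.train_step(next(train_loader)) for _ in range(self.num_iterations_per_epoch)]
            val_out = [self.validation_step(next(val_loader)) for _ in range(self.num_val_iterations_per_epoch)]
            self.history.append((
                sum(o["loss"] for o in train_out) / len(train_out),
                sum(o["loss"] for o in val_out) / len(val_out),
            ))
            self.current_epoch = epoch + 1


trainer = MiniTrainer()
trainer.run_training(itertools.repeat(1.0), itertools.repeat(0.5))
print(trainer.history)  # [(1.0, 0.5), (1.0, 0.5)]
```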
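Others
Dataset loading lives in nnunetv2/training/dataloading/nnunet_dataset.py. There, properties_file points to the per-case .pkl file holding the case's metadata (such as spacing, cropping information and class_locations, the sampled voxel locations of each foreground class or region), and the _seg.npy file holds the case's preprocessed segmentation (cropped and, where needed, resampled), not the original label file.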
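Per case, the preprocessed folder (for the 2D configuration typically nnUNet_preprocessed/Dataset043_BraTS19/nnUNetPlans_2d) holds a compressed {case}.npz, the unpacked {case}.npy and {case}_seg.npy arrays and the {case}.pkl properties. A minimal sketch for inspecting one case, assuming this naming; case_files and load_properties are hypothetical helpers. Only unpickle files you created yourself, since pickle can execute code.

```python
import pickle
from pathlib import Path


def case_files(folder, case_id):
    """File paths for one preprocessed case (assumed nnUNetv2 naming)."""
    folder = Path(folder)
    return {
        "packed": folder / f"{case_id}.npz",      # data + seg, compressed
        "data": folder / f"{case_id}.npy",        # unpacked image channels
        "seg": folder / f"{case_id}_seg.npy",     # unpacked preprocessed segmentation
        "properties": folder / f"{case_id}.pkl",  # per-case metadata
    }


def load_properties(folder, case_id):
    """Unpickle the per-case properties dictionary (trusted files only)."""
    with open(case_files(folder, case_id)["properties"], "rb") as f:
        return pickle.load(f)
```

For example, sorted(load_properties(folder, case)) lists the stored fields, and np.load(case_files(folder, case)["seg"], mmap_mode="r") opens the segmentation without reading it fully into memory.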
Training ultimately uses region labels. The conversion from labels to regions happens in nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py:
tr_transforms.append(ConvertSegmentationToRegionsTransform(list(regions) + [ignore_label]
if ignore_label is not None else regions,
'target', 'target'))
Before that, the seg key is renamed to target:
tr_transforms.append(RenameTransform('seg', 'target', True))
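ConvertSegmentationToRegionsTransform turns the label map into one binary channel per region. A pure-Python sketch on a flat list of voxels, assuming the three nested BraTS regions whole tumor {1, 2, 3}, tumor core {2, 3} and enhancing tumor {3}, which is what nnU-Net's BraTS conversion scripts usually write into dataset.json; check the labels entry of your own dataset.json. segmentation_to_regions is a hypothetical helper.

```python
# Assumed nnU-Net labels after conversion: 1 edema, 2 necrotic/non-enhancing core, 3 enhancing tumor.
# Regions (tuples of labels): whole tumor, tumor core, enhancing tumor.
BRATS_REGIONS = ((1, 2, 3), (2, 3), (3,))


def segmentation_to_regions(seg, regions=BRATS_REGIONS):
    """Return one binary mask per region: 1 where the voxel label belongs to that region."""
    return [[1 if label in region else 0 for label in seg] for region in regions]


print(segmentation_to_regions([0, 1, 2, 3]))
# [[0, 1, 1, 1], [0, 0, 1, 1], [0, 0, 0, 1]]
```

Because the regions overlap, each output channel is an independent binary prediction (sigmoid rather than softmax), which fits the 3-channel seg_layers with no background channel in the model printouts.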
2D network structures
If you would rather not train with nnUNetv2, the network structures below may be helpful:
Model structure of UMambaBot
UMambaBot: UMambaBot(
(encoder): UNetResEncoder(
(stem): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(stages): Sequential(
(0): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(1): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(2): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(3): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
)
)
(4): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
)
)
(5): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
)
)
)
)
(mamba_layer): MambaLayer(
(norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
(mamba): Mamba(
(in_proj): Linear(in_features=512, out_features=2048, bias=False)
(conv1d): Conv1d(1024, 1024, kernel_size=(4,), stride=(1,), padding=(3,), groups=1024)
(act): SiLU()
(x_proj): Linear(in_features=1024, out_features=64, bias=False)
(dt_proj): Linear(in_features=32, out_features=1024, bias=True)
(out_proj): Linear(in_features=1024, out_features=512, bias=False)
)
)
(decoder): UNetResDecoder(
(encoder): UNetResEncoder(
(stem): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(stages): Sequential(
(0): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(1): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(2): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(3): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
)
)
(4): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
)
)
(5): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
(norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
)
)
)
)
(stages): ModuleList(
(0): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(1): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(2): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(3): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
)
(4): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
)
)
)
(upsample_layers): ModuleList(
(0): UpsampleLayer(
(conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
)
(1): UpsampleLayer(
(conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
)
(2): UpsampleLayer(
(conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
)
(3): UpsampleLayer(
(conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
)
(4): UpsampleLayer(
(conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
)
)
(seg_layers): ModuleList(
(0): Conv2d(512, 3, kernel_size=(1, 1), stride=(1, 1))
(1): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
(2): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
(3): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
(4): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
)
)
)
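The sizes inside each Mamba block follow from its hyperparameters. A minimal sketch, assuming the mamba_ssm defaults d_state=16, d_conv=4, expand=2 and dt_rank="auto" (ceil(d_model / 16)). With d_model=512 it reproduces the printed bottleneck layer: in_proj 512 → 2048 (x branch and gate z, 1024 each), a depthwise Conv1d over 1024 channels with kernel 4 and padding 3, x_proj 1024 → 64 (dt rank 32 plus B and C of size 16 each), dt_proj 32 → 1024 and out_proj 1024 → 512. Before the Mamba block, MambaLayer flattens the feature map into a token sequence and applies LayerNorm. mamba_shapes is a hypothetical helper.

```python
import math


def mamba_shapes(d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto"):
    """(in, out) sizes of the layers in a Mamba block, assuming mamba_ssm's default scheme."""
    d_inner = expand * d_model
    rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
    return {
        "in_proj": (d_model, 2 * d_inner),        # x branch and gate z
        "conv1d": (d_inner, d_conv, d_conv - 1),  # depthwise channels, kernel size, padding
        "x_proj": (d_inner, rank + 2 * d_state),  # dt (low rank), B and C
        "dt_proj": (rank, d_inner),
        "out_proj": (d_inner, d_model),
    }


for d_model in (512, 64, 256, 30):  # UMambaBot bottleneck and the three UMambaEnc layers
    print(d_model, mamba_shapes(d_model))
```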
Model structure of UMambaEnc
feature_map_sizes: [[192, 160], [96, 80], [48, 40], [24, 20], [12, 10], [6, 5]]
do_channel_token: [False, False, False, False, True, True]
MambaLayer: dim: 64
MambaLayer: dim: 256
MambaLayer: dim: 30
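The log lines printed before the UMambaEnc model can be reproduced from the plans. A minimal sketch, assuming a 2D patch size of 192×160, a stride of 1 in the first stage and 2 afterwards, and 32, 64, 128, 256, 512, 512 features per stage (all read off the printouts), plus two assumed rules: a stage switches to channel tokens when its pixel count is no larger than its channel count, and a MambaLayer sits on every other stage as in the printed mamba_layers. In channel-token mode the sequence runs over channels and each token has H·W entries, which explains dim 30 (6×5) for the last MambaLayer. stage_layout is a hypothetical helper, not U-Mamba's code.

```python
import math


def stage_layout(patch_size, strides, features_per_stage):
    """Feature map sizes, channel-token flags and MambaLayer dims per stage (assumed rule)."""
    sizes, channel_token = [], []
    size = list(patch_size)
    for stride, features in zip(strides, features_per_stage):
        size = [s // st for s, st in zip(size, stride)]
        sizes.append(size)
        # few pixels but many channels: use the channels as the sequence instead
        channel_token.append(math.prod(size) <= features)
    mamba_dims = [
        math.prod(sizes[s]) if channel_token[s] else features_per_stage[s]
        for s in range(len(sizes))
        if s % 2 == 1  # MambaLayer on stages 1, 3, 5; Identity elsewhere
    ]
    return sizes, channel_token, mamba_dims


sizes, tokens, dims = stage_layout(
    patch_size=[192, 160],
    strides=[[1, 1]] + [[2, 2]] * 5,
    features_per_stage=[32, 64, 128, 256, 512, 512],
)
print(sizes)   # [[192, 160], [96, 80], [48, 40], [24, 20], [12, 10], [6, 5]]
print(tokens)  # [False, False, False, False, True, True]
print(dims)    # [64, 256, 30]
```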
UMambaEnc: UMambaEnc(
(encoder): ResidualMambaEncoder(
(stem): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(conv2): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(all_modules): Sequential(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
)
)
(nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
)
)
(mamba_layers): ModuleList(
(0): Identity()
(1): MambaLayer(
(norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(mamba): Mamba(
(in_proj): Linear(in_features=64, out_features=256, bias=False)
(conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
(act): SiLU()
(x_proj): Linear(in_features=128, out_features=36, bias=False)
(dt_proj): Linear(in_features=4, out_features=128, bias=True)
(out_proj): Linear(in_features=128, out_features=64, bias=False)
)
)
(2): Identity()
(3): MambaLayer(
(norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
(mamba): Mamba(
(in_proj): Linear(in_features=256, out_features=1024, bias=False)
(conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
(act): SiLU()
(x_proj): Linear(in_features=512, out_features=48, bias=False)
(dt_proj): Linear(in_features=16, out_features=512, bias=True)
(out_proj): Linear(in_features=512, out_features=256, bias=False)
)
)
(4): Identity()
(5): MambaLayer(
(norm): LayerNorm((30,), eps=1e-05, elementwise_affine=True)
(mamba): Mamba(
(in_proj): Linear(in_features=30, out_features=120, bias=False)
(conv1d): Conv1d(60, 60, kernel_size=(4,), stride=(1,), padding=(3,), groups=60)
(act): SiLU()
(x_proj): Linear(in_features=60, out_features=34, bias=False)
(dt_proj): Linear(in_features=2, out_features=60, bias=True)
(out_proj): Linear(in_features=60, out_features=30, bias=False)
)
)
)
(stages): ModuleList(
(0): Sequential(
(0): BasicResBlock(
(conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act1): LeakyReLU(negative_slope=0.01, inplace=True)
(conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(act2): LeakyReLU(negative_slope=0.01, inplace=True)
(conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
)
(1): BasicBlockD(
(conv1): ConvDropoutNormReLU(
(conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
(nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
(all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (2): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (3): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
        )
      )
      (4): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
        )
      )
      (5): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
        )
      )
    )
  )
  (decoder): UNetResDecoder(
    (encoder): ResidualMambaEncoder(
      (stem): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(4, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(4, 32, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (mamba_layers): ModuleList(
        (0): Identity()
        (1): MambaLayer(
          (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (mamba): Mamba(
            (in_proj): Linear(in_features=64, out_features=256, bias=False)
            (conv1d): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
            (act): SiLU()
            (x_proj): Linear(in_features=128, out_features=36, bias=False)
            (dt_proj): Linear(in_features=4, out_features=128, bias=True)
            (out_proj): Linear(in_features=128, out_features=64, bias=False)
          )
        )
        (2): Identity()
        (3): MambaLayer(
          (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (mamba): Mamba(
            (in_proj): Linear(in_features=256, out_features=1024, bias=False)
            (conv1d): Conv1d(512, 512, kernel_size=(4,), stride=(1,), padding=(3,), groups=512)
            (act): SiLU()
            (x_proj): Linear(in_features=512, out_features=48, bias=False)
            (dt_proj): Linear(in_features=16, out_features=512, bias=True)
            (out_proj): Linear(in_features=512, out_features=256, bias=False)
          )
        )
        (4): Identity()
        (5): MambaLayer(
          (norm): LayerNorm((30,), eps=1e-05, elementwise_affine=True)
          (mamba): Mamba(
            (in_proj): Linear(in_features=30, out_features=120, bias=False)
            (conv1d): Conv1d(60, 60, kernel_size=(4,), stride=(1,), padding=(3,), groups=60)
            (act): SiLU()
            (x_proj): Linear(in_features=60, out_features=34, bias=False)
            (dt_proj): Linear(in_features=2, out_features=60, bias=True)
            (out_proj): Linear(in_features=60, out_features=30, bias=False)
          )
        )
      )
      (stages): ModuleList(
        (0): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (1): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (2): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2))
          )
          (1): BasicBlockD(
            (conv1): ConvDropoutNormReLU(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
              (all_modules): Sequential(
                (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
                (2): LeakyReLU(negative_slope=0.01, inplace=True)
              )
            )
            (conv2): ConvDropoutNormReLU(
              (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (all_modules): Sequential(
                (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
                (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              )
            )
            (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
        (3): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2))
          )
        )
        (4): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(neg ative_slope=0.01, inplace=True)
            (conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
          )
        )
        (5): Sequential(
          (0): BasicResBlock(
            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act1): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (act2): LeakyReLU(negative_slope=0.01, inplace=True)
            (conv3): Conv2d(512, 512, kernel_size=(1, 1), stride=(2, 2))
          )
        )
      )
    )
    (stages): ModuleList(
      (0): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (1): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (2): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): BasicBlockD(
          (conv1): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (nonlin): LeakyReLU(negative_slope=0.01, inplace=True)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
              (2): LeakyReLU(negative_slope=0.01, inplace=True)
            )
          )
          (conv2): ConvDropoutNormReLU(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            (all_modules): Sequential(
              (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
            )
          )
          (nonlin2): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
      (3): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (4): Sequential(
        (0): BasicResBlock(
          (conv1): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act1): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (norm2): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (act2): LeakyReLU(negative_slope=0.01, inplace=True)
          (conv3): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
    (upsample_layers): ModuleList(
      (0): UpsampleLayer(
        (conv): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (1): UpsampleLayer(
        (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (2): UpsampleLayer(
        (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
      )
      (3): UpsampleLayer(
        (conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
      )
      (4): UpsampleLayer(
        (conv): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      )
    )
    (seg_layers): ModuleList(
      (0): Conv2d(512, 3, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1))
      (2): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
      (3): Conv2d(64, 3, kernel_size=(1, 1), stride=(1, 1))
      (4): Conv2d(32, 3, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)
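Every BasicResBlock in the printout has the same layout: two 3×3 convs on the main path (conv1, conv2) and a 1×1 conv3 on the shortcut, where conv3 uses the same stride as conv1. That is what lets the residual addition line up when a stage downsamples (stride 2) or changes the channel count. The sketch below checks the spatial side of this with the standard Conv2d output-size formula. The helper names are made up for illustration, and the 160×192 patch is a hypothetical input, not necessarily the patch size nnU-Net planned for this dataset.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a Conv2d along one axis (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def res_block_shapes(h, w, stride):
    """Spatial shape of a BasicResBlock's main path and of its 1x1 shortcut."""
    # main path: conv1 (3x3, may downsample) followed by conv2 (3x3, stride 1)
    main = tuple(conv_out(conv_out(d, 3, stride, 1), 3, 1, 1) for d in (h, w))
    # shortcut: conv3 (1x1, same stride as conv1, no padding)
    skip = tuple(conv_out(d, 1, stride, 0) for d in (h, w))
    return main, skip


h, w = 160, 192  # hypothetical 2D patch size, not read from the plans file
for stage in range(1, 6):  # encoder stages (1)-(5) downsample with stride 2
    main, skip = res_block_shapes(h, w, stride=2)
    assert main == skip, "residual add would fail"
    h, w = main
    print(f"stage ({stage}): {h}x{w} -> {h * w} positions")
```

Both paths reduce each axis to ceil(size / 2), so the shapes always agree, even for odd sizes. With this assumed patch the deepest map would be 5×6 = 30 positions.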
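Only the odd entries of mamba_layers, (1), (3) and (5), hold a MambaLayer; the even ones are Identity(). The sizes inside each Mamba block follow from its width d_model (the LayerNorm size). The sketch below recomputes them in plain Python as a sanity check. It assumes mamba-ssm's default hyperparameters (d_state=16, d_conv=4, expand=2, dt_rank = ceil(d_model / 16)), and the helper name mamba_layer_shapes is made up for this illustration. Note that the last MambaLayer has d_model=30 rather than 512. 30 looks like the number of spatial positions of the deepest feature map, which suggests that at this stage the layer scans over channels as tokens and uses the flattened H×W as the feature dimension. This is an interpretation of the printout, not something it states directly.

```python
import math


def mamba_layer_shapes(d_model, d_state=16, d_conv=4, expand=2):
    """Recompute the sub-layer sizes of one Mamba block.

    Assumption: mamba-ssm defaults, with dt_rank = ceil(d_model / 16).
    """
    d_inner = expand * d_model          # width of the expanded inner stream
    dt_rank = math.ceil(d_model / 16)   # rank of the low-rank dt projection
    return {
        # in_proj produces the x branch and the gate z in one matmul
        "in_proj": (d_model, 2 * d_inner),
        # depthwise causal conv: (channels, kernel_size, padding)
        "conv1d": (d_inner, d_conv, d_conv - 1),
        # x_proj produces dt (dt_rank values) plus B and C (d_state values each)
        "x_proj": (d_inner, dt_rank + 2 * d_state),
        "dt_proj": (dt_rank, d_inner),
        "out_proj": (d_inner, d_model),
    }


# d_model values taken from the LayerNorm sizes of mamba_layers (1), (3), (5)
for d_model in (64, 256, 30):
    print(d_model, mamba_layer_shapes(d_model))
```

For d_model=64 this gives in_proj (64, 256), x_proj (128, 36) and dt_proj (4, 128), which matches mamba_layers (1) above. The values for 256 and 30 match layers (3) and (5) in the same way.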
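The whole ResidualMambaEncoder appears a second time under (decoder). This is apparently because UNetResDecoder keeps a reference to the encoder as a submodule. It is the same module shown twice, not a second set of weights. The decoder itself mirrors the encoder. Each UpsampleLayer presumably resizes the deeper map and applies the 1×1 conv shown above, projecting it to the skip's channel count. The result is concatenated with the encoder skip, and the stage's first BasicResBlock reduces 2×C back to C. The printout also shows that stage (4) takes only 32 input channels, so the full-resolution skip apparently is not concatenated there.

The sketch below reproduces the printed channel numbers from the encoder widths [32, 64, 128, 256, 512, 512]. decoder_channel_plan is a hypothetical helper, and the "no concat at the last stage" rule is read off the printout rather than taken from the UMamba source.

```python
def decoder_channel_plan(features, concat_last_skip=False):
    """Channel counts per decoder stage, deepest first.

    Returns tuples (upsample_in, upsample_out, stage_in, stage_out).
    Assumption (read off the printout): the last stage skips concatenation.
    """
    plan = []
    n = len(features)
    for s in range(1, n):
        below = features[-s]          # channels arriving from the level below
        skip = features[-(s + 1)]     # channels of the matching encoder skip
        is_last = s == n - 1
        # concatenation doubles the channels, except at the last stage here
        stage_in = skip if (is_last and not concat_last_skip) else 2 * skip
        plan.append((below, skip, stage_in, skip))
    return plan


encoder_features = [32, 64, 128, 256, 512, 512]  # widths of encoder stages (0)-(5)
for s, row in enumerate(decoder_channel_plan(encoder_features)):
    print(s, row)
```

The first row, (512, 512, 1024, 512), corresponds to upsample_layers (0) and the Conv2d(1024, 512) of decoder stage (0). The last row, (64, 32, 32, 32), corresponds to upsample_layers (4) and stage (4). Each decoder stage also has its own 1×1 head in seg_layers, which nnU-Net uses for deep supervision. The 3 output channels are presumably the three BraTS tumor regions defined by the conversion script's dataset.json.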