Axial attention

2025/1/16 10:58:05

Medical Transformer: Gated Axial-Attention for Medical Image Segmentation
Paper walkthrough:
https://zhuanlan.zhihu.com/p/408662947


[Figure: experimental results]

0 Introduction

0.1 The original attention mechanism

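For reference, standard self-attention on a feature map x (the formulation the MedT paper starts from; written out here because the original figure is missing) computes every output pixel from all H·W positions, with the softmax taken over all (h, w) and q, k, v obtained by learned linear projections:

```latex
y_{ij} = \sum_{h=1}^{H} \sum_{w=1}^{W} \operatorname{softmax}\!\left(q_{ij}^{\top} k_{hw}\right) v_{hw},
\qquad q = W_Q\, x,\quad k = W_K\, x,\quad v = W_V\, x
```

Each of the HW outputs attends to all HW positions, so the cost grows as O((HW)^2), which is what motivates axial attention.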

0.2 Axial attention + relative positional encoding

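Axial attention factorizes the 2-D attention into two 1-D passes, one along the height axis and one along the width axis. Along the width axis, with learned relative positional encodings r^q, r^k, r^v for queries, keys and values, the formulation used by Axial-DeepLab and MedT can be sketched as follows (r_{iw} denotes the encoding of the offset between positions (i, j) and (i, w)):

```latex
y_{ij} = \sum_{w=1}^{W} \operatorname{softmax}\!\left(q_{ij}^{\top} k_{iw} + q_{ij}^{\top} r^{q}_{iw} + k_{iw}^{\top} r^{k}_{iw}\right)\left(v_{iw} + r^{v}_{iw}\right)
```

Each output now attends to only W positions in the width pass (and H in the height pass), so the total cost drops to O(HW(H + W)).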

0.3 Axial attention + gating units

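Gated axial attention introduces four gates that together form the gating mechanism. They control how much information the relative positional encodings contribute to the keys, queries, and values, and thereby how strongly the relative positional encodings influence the encoding of non-local context.
Depending on whether the information obtained from a relative positional encoding is useful, the corresponding gate parameter converges either to 0 or to some larger value. If a relative positional encoding is learned accurately, the gating mechanism assigns it a higher weight than encodings that are not learned accurately.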
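Written out, the gated variant multiplies each positional term by a learned scalar gate: y_ij = Σ_w softmax(q_ijᵀk_iw + G_Q·q_ijᵀr^q_iw + G_K·k_iwᵀr^k_iw)(G_V1·v_iw + G_V2·r^v_iw). The pure-Python sketch below implements that formula along a single row, simplified to one attention head with no learned projections and no BatchNorm. The function name and the layout of the relative embeddings (offset p − o stored at index p − o + W − 1) are illustrative assumptions, not taken from the MedT code.

```python
import math


def softmax(xs):
    # Subtract the max for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def gated_axial_attention_row(q, k, v, r_q, r_k, r_v,
                              g_q=1.0, g_k=1.0, g_v1=1.0, g_v2=1.0):
    """Gated axial attention along one row of width W (single head).

    q, k: W query/key vectors of size d; v: W value vectors of size d_v.
    r_q, r_k (size d) and r_v (size d_v): 2W - 1 relative embeddings; the
    offset p - o between key p and query o is stored at index p - o + W - 1
    (an illustrative layout, not MedT's exact tensor layout).
    """
    width = len(q)
    out = []
    for o in range(width):
        logits = []
        for p in range(width):
            rel = p - o + width - 1
            logits.append(dot(q[o], k[p])
                          + g_q * dot(q[o], r_q[rel])   # gated query-position term
                          + g_k * dot(k[p], r_k[rel]))  # gated key-position term
        weights = softmax(logits)
        y = [0.0] * len(v[0])
        for p, a in enumerate(weights):
            rel = p - o + width - 1
            for c in range(len(y)):
                # Gated value plus gated value-position term.
                y[c] += a * (g_v1 * v[p][c] + g_v2 * r_v[rel][c])
        out.append(y)
    return out
```

With all four gates at 1 this reduces to the ungated axial attention of 0.2; driving G_Q, G_K and G_V2 toward 0 removes the positional terms entirely, which is the behaviour described above for poorly learned encodings.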

1. axialAttentionUNet

1.1 The original axialAttentionUNet

model = ResAxialAttentionUNet(AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)

  1. A U-Net built from the original axial attention and residual blocks.
ResAxialAttentionUNet(
  (conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer2): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock(
      (conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): AxialBlock(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): AxialBlock(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (decoder1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (decoder2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
  (soft): Softmax(dim=1)
)
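The channel numbers in the printout above follow from the width multiplier s = 0.125. The sketch below reproduces them under assumptions inferred from the printout itself rather than copied from the MedT source: base widths 64/128/256/512/1024 scaled by s, an AxialBlock expansion factor of 2, and 8 attention groups per AxialAttention (8 groups × 3 similarity maps = 24, matching every bn_similarity). The function names are hypothetical, and 128 × 128 is only an example input size.

```python
def axial_unet_channels(s=0.125, expansion=2, groups=8):
    """Reproduce the channel widths seen in print(model) for width multiplier s."""
    stem = int(64 * s)                                     # conv1 / bn1 output
    planes = [int(b * s) for b in (128, 256, 512, 1024)]   # conv_down width per stage
    stages = []
    for p in planes:
        group_planes = p // groups
        stages.append({
            "conv_down_out": p,
            "conv_up_out": p * expansion,
            # Per group: q and k get group_planes // 2 channels each, v gets group_planes.
            "qkv_out": 2 * p,
            "bn_similarity": 3 * groups,   # qk, qr and kr similarity maps per group
            "q_k_v_per_group": (group_planes // 2, group_planes // 2, group_planes),
        })
    return stem, stages


def stage_output_sizes(img_size=128):
    """conv1 has stride 2; the first block of layers 2-4 halves the size again."""
    size = img_size // 2
    sizes = [size]
    for _ in range(3):
        size //= 2
        sizes.append(size)
    return sizes


stem, stages = axial_unet_channels()
print(stem, [st["conv_up_out"] for st in stages])  # 8 [32, 64, 128, 256]
print(stage_output_sizes(128))                     # [64, 32, 16, 8]
```

For layer1 this gives qkv_transform(16, 32) with 1 query, 1 key and 2 value channels per group, which is why the first-stage attention is very narrow at s = 0.125.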

1.2 Axial attention network with gating units

model = ResAxialAttentionUNet(AxialBlock_dynamic, [1, 2, 4, 1], s= 0.125, **kwargs)

In the gated axial attention network:
1. All axial attention layers are replaced with gated axial attention layers (AxialBlock_dynamic / AxialAttention_dynamic in the printout).

ResAxialAttentionUNet(
  (conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): AxialBlock_dynamic(
      (conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer2): Sequential(
    (0): AxialBlock_dynamic(
      (conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock_dynamic(
      (conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3): Sequential(
    (0): AxialBlock_dynamic(
      (conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock_dynamic(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): AxialBlock_dynamic(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): AxialBlock_dynamic(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4): Sequential(
    (0): AxialBlock_dynamic(
      (conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (decoder1): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (decoder2): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
  (soft): Softmax(dim=1)
)
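Apart from the class names, this printout is identical to the one in 1.1. In PyTorch, print(model) lists child modules only, so parameters registered directly on a module do not show up. Assuming, as the paper's description of four gates suggests, that each AxialAttention_dynamic holds G_Q, G_K, G_V1 and G_V2 as one scalar parameter each, the hypothetical helper below counts how many gate scalars the gated variant adds for layers [1, 2, 4, 1].

```python
def count_gate_scalars(layers=(1, 2, 4, 1), attn_per_block=2, gates_per_attn=4):
    """Count attention modules and the gate scalars they would add.

    attn_per_block=2: each AxialBlock_dynamic has a hight_block and a width_block.
    gates_per_attn=4: G_Q, G_K, G_V1, G_V2, assumed to be one scalar each.
    """
    blocks = sum(layers)
    attn_modules = blocks * attn_per_block
    return attn_modules, attn_modules * gates_per_attn


modules, gates = count_gate_scalars()
print(modules, gates)  # 16 64
```

Under this assumption the gated model adds only a few dozen scalars; the difference lies in how the positional terms are weighted, not in model size.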

2. Medical Transformer

Note that during training the gating units are not activated for the first 10 epochs; they are only switched on after epoch 10.
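A minimal sketch of that schedule, assuming "not activated" means the gate values are held fixed (not updated) during epochs 0-9 and become trainable from epoch 10 on. In PyTorch this would typically be done by toggling requires_grad_ on the gate parameters; here a toy SGD loop on one scalar gate, with an illustrative initial value and made-up gradients, shows the effect. The helper names are hypothetical.

```python
WARMUP_EPOCHS = 10  # gates stay frozen for the first 10 epochs (epochs 0-9)


def gates_trainable(epoch, warmup_epochs=WARMUP_EPOCHS):
    """Return True once the gating parameters should start being updated."""
    return epoch >= warmup_epochs


def train_gate(init_value, grads, lr=0.1, warmup_epochs=WARMUP_EPOCHS):
    """Toy SGD on one scalar gate: one (made-up) gradient per epoch."""
    gate = init_value
    history = []
    for epoch, grad in enumerate(grads):
        if gates_trainable(epoch, warmup_epochs):
            gate -= lr * grad
        history.append(gate)
    return history


history = train_gate(0.1, [1.0] * 12)
print(history[9], history[10])  # 0.1 0.0
```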

2.1 local + global

model = medt_net(AxialBlock,AxialBlock, [1, 2, 4, 1], s= 0.125, **kwargs)

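LoGo network (local + global), built as follows:

  1. It uses a U-Net built from the original axial attention, without the gated axial attention units proposed in the paper.
  2. It uses the local + global training strategy.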
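In the medt_net printout, the global branch (conv1 … layer2, decoder4, decoder5) sees the whole image through only two encoder stages, while the *_p modules (conv1_p, layer1_p, …) form the local branch that works on image patches. The pure-Python sketch below shows only the patch bookkeeping of local-global training: split the input into a grid of patches (4 × 4 here, our assumption based on the MedT setup as we understand it), run each patch through the local branch, stitch the results back in place, and add them to the global output. global_branch and local_branch are hypothetical stand-ins for the two sub-networks.

```python
def split_into_patches(image, grid=4):
    """Split an H x W image (list of rows) into grid x grid equal patches."""
    h, w = len(image), len(image[0])
    assert h % grid == 0 and w % grid == 0, "image size must be divisible by grid"
    ph, pw = h // grid, w // grid
    patches = {}
    for i in range(grid):
        for j in range(grid):
            patches[(i, j)] = [row[j * pw:(j + 1) * pw]
                               for row in image[i * ph:(i + 1) * ph]]
    return patches


def stitch_patches(patches, grid=4):
    """Inverse of split_into_patches: put each patch back at its grid cell."""
    ph = len(patches[(0, 0)])
    pw = len(patches[(0, 0)][0])
    out = [[0.0] * (pw * grid) for _ in range(ph * grid)]
    for (i, j), patch in patches.items():
        for r, row in enumerate(patch):
            out[i * ph + r][j * pw:(j + 1) * pw] = row
    return out


def local_global(image, global_branch, local_branch, grid=4):
    """Add the global output to the stitched per-patch local outputs."""
    global_out = global_branch(image)
    local_patches = {pos: local_branch(p)
                     for pos, p in split_into_patches(image, grid).items()}
    local_out = stitch_patches(local_patches, grid)
    return [[g + l for g, l in zip(g_row, l_row)]
            for g_row, l_row in zip(global_out, local_out)]


def identity(x):
    return x


img = [[float(8 * r + c) for c in range(8)] for r in range(8)]
out = local_global(img, identity, identity)  # every pixel doubled: global + local
```

With identity branches the result is simply twice the input, which is a quick way to check that the split and stitch steps are exact inverses.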
medt_net(
  (conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer2): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock(
      (conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
  (soft): Softmax(dim=1)
  (conv1_p): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (conv2_p): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2_p): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_p): ReLU(inplace=True)
  (layer1_p): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer2_p): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock(
      (conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3_p): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): AxialBlock(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): AxialBlock(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4_p): Sequential(
    (0): AxialBlock(
      (conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention(
        (qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention(
        (qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (decoder1_p): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (decoder2_p): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder4_p): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder5_p): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoderf): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (adjust_p): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
  (soft_p): Softmax(dim=1)
)
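The channel sizes in this printout follow a few fixed rules, and the small helper below reproduces them. It is a hypothetical sketch: `axial_attention_channels` is not a function in the MedT code. It assumes `groups=8` attention heads, which is consistent with `bn_similarity` having 24 = 3 × 8 channels. It also assumes that the `qkv_transform` output is split per head into q, k and v as described in the comments. Setting `positional=False` is meant to mirror the `AxialAttention_wopos` layers in the `medt_net` printout, where `bn_similarity` has 8 channels and `bn_output` has 16.

```python
def axial_attention_channels(in_planes: int, out_planes: int,
                             groups: int = 8, positional: bool = True) -> dict:
    """Expected channel counts of one axial-attention layer (hypothetical helper).

    Assumed layout: a single 1-D conv (qkv_transform) produces 2 * out_planes
    channels, split per head into q and k (group_planes // 2 each) and
    v (group_planes).
    """
    assert in_planes % groups == 0 and out_planes % groups == 0
    group_planes = out_planes // groups      # channels handled by one head
    qkv_out = 2 * out_planes                 # q + k + v over all heads
    per_head = (group_planes // 2, group_planes // 2, group_planes)

    # With positional encoding, three similarity maps per head are stacked
    # (q.k, q.r_q, k.r_k) before BatchNorm; without it only q.k remains.
    n_similarity = 3 if positional else 1
    # With positional encoding, the value term and the positional value term
    # are stacked before BatchNorm (then summed); without it there is one.
    n_output = 2 if positional else 1

    return {
        "qkv_transform": (in_planes, qkv_out),
        "bn_qkv": qkv_out,
        "bn_similarity": groups * n_similarity,
        "bn_output": out_planes * n_output,
        "per_head_qkv": per_head,
    }


# layer1_p above: AxialAttention with in_planes = out_planes = 16
print(axial_attention_channels(16, 16))
# → {'qkv_transform': (16, 32), 'bn_qkv': 32, 'bn_similarity': 24, 'bn_output': 32, 'per_head_qkv': (1, 1, 2)}
```

The same rules give `qkv_transform(128, 256)` and `bn_output` 256 for the `layer4_p` block, matching the printout.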

2.2 Medical Transformer (MedT)

model = medt_net(AxialBlock_dynamic, AxialBlock_wopos, [1, 2, 4, 1], s=0.125, **kwargs)

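The approach:

  1. The global branch uses the proposed gated axial-attention units (AxialBlock_dynamic), while the local branch uses the original axial attention without positional encoding (AxialBlock_wopos).

  2. The model is trained with the local-global (LoGo) training strategy.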
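The local-global (LoGo) strategy can be sketched in plain Python. In the MedT paper, the global branch sees the whole image. The local branch runs on a grid of patches, 16 patches of size I/4 × I/4 for an I × I image. The output maps of the two branches are then added. The sketch below only shows that data flow on a single-channel image stored as nested lists. `global_branch` and `local_branch` are hypothetical stand-ins for the two sub-networks, and the 4 × 4 grid is an assumed default. The real model's final convolutions and softmax are left out.

```python
from typing import Callable, List

Image = List[List[float]]  # single-channel H x W image


def split_patches(img: Image, grid: int) -> List[List[Image]]:
    """Cut an H x W image into a grid x grid array of equal patches."""
    h, w = len(img), len(img[0])
    assert h % grid == 0 and w % grid == 0, "image must divide evenly"
    ph, pw = h // grid, w // grid
    return [[[row[j * pw:(j + 1) * pw] for row in img[i * ph:(i + 1) * ph]]
             for j in range(grid)]
            for i in range(grid)]


def stitch_patches(patches: List[List[Image]]) -> Image:
    """Inverse of split_patches: reassemble a grid of patches into one image."""
    out: Image = []
    for patch_row in patches:
        for r in range(len(patch_row[0])):
            out.append([v for patch in patch_row for v in patch[r]])
    return out


def logo_forward(img: Image,
                 global_branch: Callable[[Image], Image],
                 local_branch: Callable[[Image], Image],
                 grid: int = 4) -> Image:
    """Full image through the global branch, each patch through the shared
    local branch, then add the two output maps element-wise."""
    g = global_branch(img)
    patches = split_patches(img, grid)
    local = stitch_patches([[local_branch(p) for p in row] for row in patches])
    return [[a + b for a, b in zip(ga, la)] for ga, la in zip(g, local)]
```

With identity functions for both branches, the output is exactly twice the input. This is a quick way to check that splitting and stitching are exact inverses before plugging in real sub-networks.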

medt_net(
  (conv1): Conv2d(3, 8, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (conv2): Conv2d(8, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv3): Conv2d(128, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): AxialBlock_dynamic(
      (conv_down): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer2): Sequential(
    (0): AxialBlock_dynamic(
      (conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock_dynamic(
      (conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_dynamic(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (decoder4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder5): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (adjust): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
  (soft): Softmax(dim=1)
  (conv1_p): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (conv2_p): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn2_p): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (bn3_p): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu_p): ReLU(inplace=True)
  (layer1_p): Sequential(
    (0): AxialBlock_wopos(
      (conv_down): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(16, 32, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (layer2_p): Sequential(
    (0): AxialBlock_wopos(
      (conv_down): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock_wopos(
      (conv_down): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv1): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(32, 64, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer3_p): Sequential(
    (0): AxialBlock_wopos(
      (conv_down): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): AxialBlock_wopos(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (2): AxialBlock_wopos(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (3): AxialBlock_wopos(
      (conv_down): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(64, 128, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv_up): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
  )
  (layer4_p): Sequential(
    (0): AxialBlock_wopos(
      (conv_down): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (hight_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (width_block): AxialAttention_wopos(
        (qkv_transform): qkv_transform(128, 256, kernel_size=(1,), stride=(1,), bias=False)
        (bn_qkv): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_similarity): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (bn_output): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pooling): AvgPool2d(kernel_size=2, stride=2, padding=0)
      )
      (conv_up): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
  )
  (decoder1_p): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (decoder2_p): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder3_p): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder4_p): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoder5_p): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (decoderf): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (adjust_p): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))
  (soft_p): Softmax(dim=1)
)
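In this printout, the global branch (`layer1`, `layer2`) uses `AxialAttention_dynamic`, which is gated and has relative positional encoding. The local branch (`layer1_p` to `layer4_p`) uses `AxialAttention_wopos`, which has no positional encoding. The pure-Python sketch below shows what such a layer computes along one axis, following the gated formula of the MedT paper. The attention weights are softmax(qᵀk + G_Q·qᵀr^q + G_K·kᵀr^k), and they are applied to (G_V1·v + G_V2·r^v). This is a simplified, single-head illustration. The function names, the dict-based relative embeddings, and the identity q/k/v projections in the 2-D helper are assumptions. The gates are plain floats here rather than learned parameters. The BatchNorm layers (`bn_qkv`, `bn_similarity`, `bn_output`) and the multi-head grouping of the real module are omitted.

```python
import math
from typing import Dict, List, Optional, Sequence

Vec = List[float]
RelEmb = Dict[int, Vec]  # relative offset (j - i) -> embedding vector


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def softmax(xs: Sequence[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    e = [math.exp(x - m) for x in xs]
    s = sum(e)
    return [x / s for x in e]


def gated_axial_attention_1d(q: List[Vec], k: List[Vec], v: List[Vec],
                             r_q: Optional[RelEmb] = None,
                             r_k: Optional[RelEmb] = None,
                             r_v: Optional[RelEmb] = None,
                             g_q: float = 1.0, g_k: float = 1.0,
                             g_v1: float = 1.0, g_v2: float = 1.0) -> List[Vec]:
    """Single-head attention along one axis of length L.

    y_i = sum_j softmax_j(q_i.k_j + g_q*q_i.r_q[j-i] + g_k*k_j.r_k[j-i])
                * (g_v1*v_j + g_v2*r_v[j-i])
    Passing None for r_q / r_k / r_v drops that positional term.
    """
    L = len(q)
    out = []
    for i in range(L):
        logits = []
        for j in range(L):
            s = dot(q[i], k[j])
            if r_q is not None:
                s += g_q * dot(q[i], r_q[j - i])
            if r_k is not None:
                s += g_k * dot(k[j], r_k[j - i])
            logits.append(s)
        w = softmax(logits)
        y = [0.0] * len(v[0])
        for j in range(L):
            val = [g_v1 * x for x in v[j]]
            if r_v is not None:
                val = [a + g_v2 * b for a, b in zip(val, r_v[j - i])]
            y = [acc + w[j] * b for acc, b in zip(y, val)]
        out.append(y)
    return out


def axial_attention_2d(x: List[List[Vec]]) -> List[List[Vec]]:
    """Height pass, then width pass, over an H x W grid of C-dim vectors.

    Simplifications (assumptions): q = k = v = x, one head, no positional
    terms (the position-free variant), no BatchNorm.
    """
    H, W = len(x), len(x[0])
    # height axis: each column attends only within itself
    cols = []
    for w in range(W):
        col = [x[h][w] for h in range(H)]
        cols.append(gated_axial_attention_1d(col, col, col))
    x = [[cols[w][h] for w in range(W)] for h in range(H)]
    # width axis: each row attends only within itself
    return [gated_axial_attention_1d(row, row, row) for row in x]
```

Setting G_Q = G_K = G_V2 = 0 and G_V1 = 1, or passing no embeddings, reduces the gated layer to plain position-free attention. The gates let the model decide how much to trust the learned relative positional encodings. Running the height pass and then the width pass gives every position a receptive field covering its full row and column, at O(H·W·(H + W)) cost instead of O((H·W)²) for full 2-D self-attention.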

3. References

3.1 Criss-cross attention

https://github.com/yearing1017/CCNet_PyTorch/tree/master/CCNet

https://github.com/speedinghzl/CCNet

3.2 Axial attention

https://github.com/lucidrains/axial-attention

Axial Attention in Multidimensional Transformers

3.3 Applications of axial attention

MetNet: A Neural Weather Model for Precipitation Forecasting

Medical Transformer:

MedT paper walkthrough

Axial attention networks:
https://blog.csdn.net/hxxjxw/article/details/121445561

https://blog.csdn.net/weixin_43718675/article/details/106760382

https://zhuanlan.zhihu.com/p/408662947

Recommended reading
https://blog.csdn.net/weixin_43718675/article/details/106760382#t4

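This article was submitted by an internet user. The views expressed are solely the author's and do not represent the position of this site, which only provides information storage space, does not own the copyright, and bears no related legal liability. If you repost it, please credit the source: http://www.coloradmin.cn/o/428488.html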

【论文阅读_序列推荐】Intent Contrastive Learning for Sequential Recommendation

【论文阅读_序列推荐】Intent Contrastive Learning for Sequential Recommendation 文章目录【论文阅读_序列推荐】Intent Contrastive Learning for Sequential Recommendation1. 来源2. 介绍3. 准备工作3.1 问题定义3.2 用于下一个项目预测的深度 SR 模型3.3 SR中的对比SSL …

基于springboot和ajax的简单项目 06 日志界面的delete功能(根据选择的checkbox)

01.这次后台开始&#xff1b; 顺序依次是dao->xml->service->serviceimpl->controller->html 02.dao接口 public int doDeleteObjects(Param("ids") Integer... ids);03.xml文件 <update id"doDeleteObjects" >delete from sys_lo…

七项新发布,亚马逊云科技Amazon S3持续进化

17年前的3月14日&#xff0c;亚马逊云科技推出了一项“非常简单的”对象存储服务&#xff08;Amazon Simple Storage Service&#xff09;。该服务允许开发人员创建、列出和删除私有存储空间&#xff08;称为存储桶&#xff09;、上传和下载文件以及管理其访问权限。当时&#…

C++刷题--选择题1

文章目录选择题选择题 1&#xff0c; 以下for循环的执行次数是&#xff08;&#xff09; for(int x 0, y 0; (y 123) && (x < 4); x);A 是无限循环 B 循环次数不定 C 4次 D 3次 解析 &#xff1a; C&#xff0c;for循环y 123 是赋值语句&#xff0c; 也就是一…

PSO算法

&#x1f34e;道阻且长&#xff0c;行则将至。&#x1f353; 目录1.PSO算法主要步骤&#x1f331;2.PSO更新方法&#x1f33e;3.PSO求解TSP问题&#x1f334;粒子群算法&#xff08;Particle Swarm Optimization&#xff0c;简称PSO&#xff09;是一种优化算法&#xff0c;模拟…

美国全力打击币圈 “一套花式组合拳”,打得从业者透不过气

银行危机“平息”过后&#xff0c;美国监管机构对币圈接连出手&#xff0c;一套花式组合拳打得从业者透不过气&#xff0c;也使得加密行业在政府的拳头之下风声鹤唳。 首先&#xff0c;切断加密货币与传统金融机构的联系。美国金融体系陷入混乱之际&#xff0c;一系列历史性的银…