Understanding the TransUNet Code

2025/1/13 7:34:49

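First, use the figure from the paper to get an overview of the network architecture:
[Figure: overall TransUNet architecture, from the paper]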

The vit_seg_modeling module

Imports and module-level definitions:

# coding=utf-8
# __future__ imports let code running on older Python versions use features from newer versions
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage
from . import vit_seg_configs as configs
from .vit_seg_modeling_resnet_skip import ResNetV2

logger = logging.getLogger(__name__)

ATTENTION_Q = "MultiHeadDotProductAttention_1/query"
ATTENTION_K = "MultiHeadDotProductAttention_1/key"
ATTENTION_V = "MultiHeadDotProductAttention_1/value"
ATTENTION_OUT = "MultiHeadDotProductAttention_1/out"
FC_0 = "MlpBlock_3/Dense_0"
FC_1 = "MlpBlock_3/Dense_1"
ATTENTION_NORM = "LayerNorm_0"
MLP_NORM = "LayerNorm_2"

# Hyperparameter configs for each model variant
CONFIGS = {
    'ViT-B_16': configs.get_b16_config(),
    'ViT-B_32': configs.get_b32_config(),
    'ViT-L_16': configs.get_l16_config(),
    'ViT-L_32': configs.get_l32_config(),
    'ViT-H_14': configs.get_h14_config(),
    'R50-ViT-B_16': configs.get_r50_b16_config(),
    'R50-ViT-L_16': configs.get_r50_l16_config(),
    'testing': configs.get_testing(),
}

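Utility functions:
np2th converts a NumPy array to a PyTorch tensor; with conv=True it first transposes a convolution kernel from HWIO to OIHW layout.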

def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)
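To make the conv=True branch concrete, here is a NumPy-only sketch of the same transpose, assuming (as the docstring says) that checkpoint kernels are stored as HWIO (height, width, in_channels, out_channels). np2th_numpy is a hypothetical name; it returns a NumPy array instead of a tensor so the example runs without PyTorch.

```python
import numpy as np


def np2th_numpy(weights, conv=False):
    """Hypothetical NumPy version of np2th: HWIO -> OIHW when conv=True."""
    if conv:
        # axes (H, W, I, O) -> (O, I, H, W)
        weights = weights.transpose([3, 2, 0, 1])
    return weights


# A fake 3x3 kernel with 2 input channels and 4 output channels, in HWIO layout
kernel_hwio = np.arange(3 * 3 * 2 * 4, dtype=np.float32).reshape(3, 3, 2, 4)
kernel_oihw = np2th_numpy(kernel_hwio, conv=True)
print(kernel_oihw.shape)  # (4, 2, 3, 3)
# Same values, only the index order changes: new[o, i, h, w] == old[h, w, i, o]
print(kernel_oihw[1, 0, 2, 2] == kernel_hwio[2, 2, 0, 1])  # True
```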

swish is an activation function proposed by a team at Google; their experiments showed that on some challenging datasets it works better than ReLU.

def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
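A small NumPy sketch comparing swish with ReLU (the function names here are just illustrative re-implementations, not the repo's torch versions). Unlike ReLU, swish is smooth and slightly negative for negative inputs, and it approaches the identity for large positive inputs.

```python
import numpy as np


def swish(x):
    # x * sigmoid(x), written as x / (1 + exp(-x))
    return x / (1.0 + np.exp(-x))


def relu(x):
    return np.maximum(x, 0.0)


x = np.array([-4.0, -1.0, 0.0, 1.0, 4.0])
print(np.round(swish(x), 4))  # small negative values for x < 0, close to x for large x
print(relu(x))                # negative inputs are clipped to 0
```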

The code is explained top-down.
VisionTransformer is the whole model: it calls Transformer, DecoderCup and SegmentationHead, and load_from loads pretrained weights.

class VisionTransformer(nn.Module):
    def __init__(self, config, img_size=224, num_classes=21843, zero_head=False, vis=False):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.zero_head = zero_head
        self.classifier = config.classifier
        self.transformer = Transformer(config, img_size, vis)
        self.decoder = DecoderCup(config)
        self.segmentation_head = SegmentationHead(
            in_channels=config['decoder_channels'][-1],
            out_channels=config['n_classes'],
            kernel_size=3,
        )
        self.config = config

    def forward(self, x):
        if x.size()[1] == 1:
            x = x.repeat(1, 3, 1, 1)
        x, attn_weights, features = self.transformer(x)  # (B, n_patch, hidden)
        x = self.decoder(x, features)
        logits = self.segmentation_head(x)
        return logits

    def load_from(self, weights):
        # torch.no_grad() disables gradient tracking, so parameters can be overwritten in place with copy_() without autograd recording it
        with torch.no_grad():

            res_weight = weights
            self.transformer.embeddings.patch_embeddings.weight.copy_(np2th(weights["embedding/kernel"], conv=True))
            self.transformer.embeddings.patch_embeddings.bias.copy_(np2th(weights["embedding/bias"]))

            self.transformer.encoder.encoder_norm.weight.copy_(np2th(weights["Transformer/encoder_norm/scale"]))
            self.transformer.encoder.encoder_norm.bias.copy_(np2th(weights["Transformer/encoder_norm/bias"]))

            posemb = np2th(weights["Transformer/posembed_input/pos_embedding"])

            posemb_new = self.transformer.embeddings.position_embeddings
            if posemb.size() == posemb_new.size():
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            elif posemb.size()[1] - 1 == posemb_new.size()[1]:
                posemb = posemb[:, 1:]
                self.transformer.embeddings.position_embeddings.copy_(posemb)
            else:
                logger.info("load_pretrained: resized variant: %s to %s" % (posemb.size(), posemb_new.size()))
                ntok_new = posemb_new.size(1)
                if self.classifier == "seg":
                    _, posemb_grid = posemb[:, :1], posemb[0, 1:]
                gs_old = int(np.sqrt(len(posemb_grid)))
                gs_new = int(np.sqrt(ntok_new))
                print('load_pretrained: grid-size from %s to %s' % (gs_old, gs_new))
                posemb_grid = posemb_grid.reshape(gs_old, gs_old, -1)
                zoom = (gs_new / gs_old, gs_new / gs_old, 1)
                posemb_grid = ndimage.zoom(posemb_grid, zoom, order=1)  # th2np
                posemb_grid = posemb_grid.reshape(1, gs_new * gs_new, -1)
                posemb = posemb_grid
                self.transformer.embeddings.position_embeddings.copy_(np2th(posemb))

            # Encoder whole
            for bname, block in self.transformer.encoder.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, n_block=uname)

            if self.transformer.embeddings.hybrid:
                self.transformer.embeddings.hybrid_model.root.conv.weight.copy_(
                    np2th(res_weight["conv_root/kernel"], conv=True))
                # .view(-1) returns a flattened 1-D view of the tensor; the original tensor keeps its shape
                gn_weight = np2th(res_weight["gn_root/scale"]).view(-1)
                gn_bias = np2th(res_weight["gn_root/bias"]).view(-1)
                self.transformer.embeddings.hybrid_model.root.gn.weight.copy_(gn_weight)
                self.transformer.embeddings.hybrid_model.root.gn.bias.copy_(gn_bias)

                for bname, block in self.transformer.embeddings.hybrid_model.body.named_children():
                    for uname, unit in block.named_children():
                        unit.load_from(res_weight, n_block=bname, n_unit=uname)
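load_from picks one of three ways to load the position embeddings. This pure-Python sketch reproduces only that decision and the grid sizes (no tensors). posemb_plan is a hypothetical helper, and the token counts assume a checkpoint pretrained at 224×224 with patch size 16 plus one leading extra position (14 × 14 + 1 = 197), which is what the `posemb[:, 1:]` slicing suggests.

```python
import math


def posemb_plan(old_tokens, new_tokens):
    """Hypothetical helper mirroring the branch choice in load_from.

    old_tokens / new_tokens are dim 1 of the pretrained and the model's
    position embeddings. Returns (action, gs_old, gs_new).
    """
    if old_tokens == new_tokens:
        return ("copy", None, None)
    if old_tokens - 1 == new_tokens:
        # the pretrained table has one extra leading position: drop it
        return ("drop_first", None, None)
    # otherwise drop the first position and rescale the square grid (ndimage.zoom in the repo)
    gs_old = int(math.sqrt(old_tokens - 1))
    gs_new = int(math.sqrt(new_tokens))
    return ("resize", gs_old, gs_new)


print(posemb_plan(197, 196))   # ('drop_first', None, None): 224x224 input, 14x14 patches
print(posemb_plan(197, 1024))  # ('resize', 14, 32): 512x512 input, 32x32 patches
```

In the original code, posemb_grid is only assigned inside `if self.classifier == "seg":`, so the resize branch relies on the seg classifier setting; the zoom factor it uses is gs_new / gs_old along both grid axes.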


Next is the Transformer code.
Transformer consists of Embeddings and Encoder:

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output, features = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)  # (B, n_patch, hidden)
        return encoded, attn_weights, features

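Embeddings corresponds to this part of the figure:
[Figure: the embedding stage of TransUNet]
ResNetV2 (its code is covered in the last section) extracts features from the image with convolutions and returns the feature maps of its intermediate stages to Embeddings.
Embeddings then takes the last feature map returned by ResNetV2, splits it into patches, linearly projects each patch (patch_size*patch_size*channels values) to a vector of length hidden_size, and adds the position embeddings. This corresponds to this part of the figure:
[Figure: patch embedding plus position embedding]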

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings.
    """

    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        self.config = config
        # _pair turns img_size into a tuple: an int value becomes (value, value)
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:  # ResNet
            grid_size = config.patches["grid"]  # grid is a tuple: input image size // patch size
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
            n_patches = (img_size[0] // patch_size_real[0]) * (img_size[1] // patch_size_real[1])
            self.hybrid = True
        else:
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers, width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        # patch_embeddings uses a convolution to turn the input into (B, hidden_size, n_patches^(1/2), n_patches^(1/2))
        # hidden_size is the length of one token (the counterpart of one word in an NLP input)
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        # learnable position embedding for each patch vector
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches, config.hidden_size))

        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        if self.hybrid:
            x, features = self.hybrid_model(x)
        else:
            features = None
        x = self.patch_embeddings(x)  # (B, hidden, n_patches^(1/2), n_patches^(1/2))
        x = x.flatten(2)  # flatten from dim 2 onward: (B, hidden, n_patches)
        x = x.transpose(-1, -2)  # swap the last two dims: (B, n_patches, hidden)

        embeddings = x + self.position_embeddings  # add the position embeddings
        embeddings = self.dropout(embeddings)
        return embeddings, features
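The patch arithmetic in Embeddings.__init__ and the flatten/transpose in forward can be checked without PyTorch. In this sketch embedding_geometry is a hypothetical helper, and the numbers assume a 224×224 input with grid (14, 14) in hybrid mode (ResNetV2 has already downsampled by 16, so the patch convolution is 1×1) or 16×16 patches in plain ViT mode.

```python
import numpy as np


def embedding_geometry(img_size, grid=None, patch=16):
    """Hypothetical helper reproducing the patch arithmetic of Embeddings.__init__.

    grid: config.patches["grid"] in hybrid mode, or None for plain ViT.
    """
    if grid is not None:
        # hybrid: the CNN output is 1/16 of the input, so the conv patch is small
        patch_size = (img_size // 16 // grid[0], img_size // 16 // grid[1])
        patch_size_real = (patch_size[0] * 16, patch_size[1] * 16)
    else:
        patch_size = (patch, patch)
        patch_size_real = patch_size
    n_patches = (img_size // patch_size_real[0]) * (img_size // patch_size_real[1])
    return patch_size, patch_size_real, n_patches


print(embedding_geometry(224, grid=(14, 14)))  # ((1, 1), (16, 16), 196)
print(embedding_geometry(224))                 # ((16, 16), (16, 16), 196)

# flatten(2) + transpose(-1, -2) on a (B, hidden, h, w) map, done with NumPy
B, hidden, h, w = 2, 768, 14, 14
x = np.random.rand(B, hidden, h, w)
tokens = x.reshape(B, hidden, h * w).transpose(0, 2, 1)  # (B, n_patches, hidden)
print(tokens.shape)  # (2, 196, 768)
```

Token number i*w + j holds the feature vector at spatial position (i, j), which is why the decoder can later undo this reshape.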

Encoder is the encoding part of the model: it stacks num_layers Block modules and applies a final LayerNorm.

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
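        # nn.ModuleList is a list of modules. Unlike a plain Python list, it is itself an nn.Module, so the modules it holds are
        # registered with the parent module and their parameters are visible to it; but it is only a container and does not implement forward().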
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights

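Block consists of two structures, MSA (Multi-head Self-Attention) and MLP (Multi-Layer Perceptron), corresponding to this part of the figure:
[Figure: a Transformer layer with MSA and MLP blocks]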

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights

    def load_from(self, weights, n_block):
        ROOT = f"Transformer/encoderblock_{n_block}"
        with torch.no_grad():
            query_weight = np2th(weights[pjoin(ROOT, ATTENTION_Q, "kernel")]).view(self.hidden_size,
                                                                                   self.hidden_size).t()
            key_weight = np2th(weights[pjoin(ROOT, ATTENTION_K, "kernel")]).view(self.hidden_size, self.hidden_size).t()
            value_weight = np2th(weights[pjoin(ROOT, ATTENTION_V, "kernel")]).view(self.hidden_size,
                                                                                   self.hidden_size).t()
            out_weight = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "kernel")]).view(self.hidden_size,
                                                                                   self.hidden_size).t()

            query_bias = np2th(weights[pjoin(ROOT, ATTENTION_Q, "bias")]).view(-1)
            key_bias = np2th(weights[pjoin(ROOT, ATTENTION_K, "bias")]).view(-1)
            value_bias = np2th(weights[pjoin(ROOT, ATTENTION_V, "bias")]).view(-1)
            out_bias = np2th(weights[pjoin(ROOT, ATTENTION_OUT, "bias")]).view(-1)

            self.attn.query.weight.copy_(query_weight)
            self.attn.key.weight.copy_(key_weight)
            self.attn.value.weight.copy_(value_weight)
            self.attn.out.weight.copy_(out_weight)
            self.attn.query.bias.copy_(query_bias)
            self.attn.key.bias.copy_(key_bias)
            self.attn.value.bias.copy_(value_bias)
            self.attn.out.bias.copy_(out_bias)

            mlp_weight_0 = np2th(weights[pjoin(ROOT, FC_0, "kernel")]).t()
            mlp_weight_1 = np2th(weights[pjoin(ROOT, FC_1, "kernel")]).t()
            mlp_bias_0 = np2th(weights[pjoin(ROOT, FC_0, "bias")]).t()
            mlp_bias_1 = np2th(weights[pjoin(ROOT, FC_1, "bias")]).t()

            self.ffn.fc1.weight.copy_(mlp_weight_0)
            self.ffn.fc2.weight.copy_(mlp_weight_1)
            self.ffn.fc1.bias.copy_(mlp_bias_0)
            self.ffn.fc2.bias.copy_(mlp_bias_1)

            self.attention_norm.weight.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "scale")]))
            self.attention_norm.bias.copy_(np2th(weights[pjoin(ROOT, ATTENTION_NORM, "bias")]))
            self.ffn_norm.weight.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "scale")]))
            self.ffn_norm.bias.copy_(np2th(weights[pjoin(ROOT, MLP_NORM, "bias")]))
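A NumPy sketch of the pre-norm residual pattern in Block.forward (LayerNorm before each sub-layer, residual added after), using stand-in sub-layers; layer_norm, pre_norm_block and zero_sublayer are hypothetical names. The last lines illustrate why load_from calls .t() on the kernels, assuming (as those .t() calls suggest) that the checkpoint stores each Dense kernel as an (in, out) matrix applied as x @ K, while nn.Linear stores its weight as (out, in) and computes x @ W.T.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-6):
    # normalize over the last (hidden) axis, like nn.LayerNorm(hidden_size, eps=1e-6)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def pre_norm_block(x, attn, ffn, norm1, norm2):
    # same ordering as Block.forward: LayerNorm -> sub-layer -> add the residual
    x = x + attn(layer_norm(x, *norm1))
    x = x + ffn(layer_norm(x, *norm2))
    return x


def zero_sublayer(t):
    # stand-in sub-layer that outputs zeros
    return np.zeros_like(t)


rng = np.random.default_rng(0)
hidden = 8
x = rng.normal(size=(2, 5, hidden))  # (B, n_patch, hidden)
norm = (np.ones(hidden), np.zeros(hidden))  # (gamma, beta) at their initial values

out = pre_norm_block(x, zero_sublayer, zero_sublayer, norm, norm)
print(np.allclose(out, x))  # True: with zero sub-layers only the residual path remains

# Kernel layout: K of shape (in, out) is applied as x @ K,
# nn.Linear keeps W of shape (out, in) and computes x @ W.T
K = rng.normal(size=(hidden, 3 * hidden))
W = K.T  # what load_from copies into the Linear weight after .t()
print(W.shape, np.allclose(x @ K, x @ W.T))  # (24, 8) True
```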

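Attention corresponds to the MSA part of the figure. num_heads is the number of attention heads, and attention_head_size is the output size of each head. Multi-head self-attention uses several attention heads, but the implementation does not loop over them: every head applies the same computation and the heads differ only in their learned parameters, so the projections of all heads can be learned as one large weight matrix and then split into heads by reshaping. My understanding is illustrated below:
[Figure: the author's illustration of one large Q/K/V projection split into multiple heads]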

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # new_x_shape (B, n_patch, num_attention_heads, attention_head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
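        # view() reshapes a tensor: it returns a tensor with the same data but a different shape
        x = x.view(*new_x_shape)
        # permute reorders any number of dimensions at once; transpose swaps exactly two dimensions (of a tensor of any rank)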
        return x.permute(0, 2, 1, 3)  # return (B, num_attention_heads, n_patch, attention_head_size)

    def forward(self, hidden_states):
        # hidden_states (B, n_patch, hidden)
        # mixed_*  (B, n_patch, all_head_size)
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # torch.matmul: matrix multiplication, batched over the leading dimensions
        # key_layer.transpose(-1, -2): (B, num_attention_heads, attention_head_size, n_patch)
        # attention_scores: (B, num_attention_heads, n_patch, n_patch)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        # context_layer (B, num_attention_heads, n_patch, attention_head_size)
        context_layer = torch.matmul(attention_probs, value_layer)
        # context_layer (B, n_patch, num_attention_heads, attention_head_size)
        # contiguous() is usually paired with transpose/permute and view: after transpose or permute, call contiguous() before reshaping with view()
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # new_context_layer_shape (B, n_patch,all_head_size)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        # attention_output (B, n_patch,hidden_size)
        # Detail: attention_head_size = int(hidden_size / num_attention_heads) and all_head_size = num_attention_heads * attention_head_size,
        # so hidden_size must be divisible by num_attention_heads (otherwise all_head_size != hidden_size)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
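The claim that one large projection plus a reshape is equivalent to running each head separately can be checked numerically. This NumPy sketch treats the projections as plain (in, out) matrices and leaves out biases, dropout and the output projection; mha_vectorized and mha_loop are hypothetical names, not functions from the repo.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def mha_vectorized(h, Wq, Wk, Wv, num_heads):
    """All heads in one go, like transpose_for_scores (no biases, dropout or output projection)."""
    B, n, hidden = h.shape
    d = hidden // num_heads  # attention_head_size

    def split(t):  # (B, n, hidden) -> (B, heads, n, d)
        return t.reshape(B, n, num_heads, d).transpose(0, 2, 1, 3)

    q, k, v = split(h @ Wq), split(h @ Wk), split(h @ Wv)
    probs = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(d))  # (B, heads, n, n)
    ctx = probs @ v  # (B, heads, n, d)
    return ctx.transpose(0, 2, 1, 3).reshape(B, n, hidden)  # (B, n, all_head_size)


def mha_loop(h, Wq, Wk, Wv, num_heads):
    """Reference: a separate small attention per head, using that head's column slice."""
    d = h.shape[-1] // num_heads
    outs = []
    for i in range(num_heads):
        cols = slice(i * d, (i + 1) * d)
        q, k, v = h @ Wq[:, cols], h @ Wk[:, cols], h @ Wv[:, cols]
        probs = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(d))  # (B, n, n)
        outs.append(probs @ v)  # (B, n, d)
    return np.concatenate(outs, axis=-1)


rng = np.random.default_rng(0)
B, n, hidden, heads = 2, 6, 12, 3
h = rng.normal(size=(B, n, hidden))
Wq, Wk, Wv = (rng.normal(size=(hidden, hidden)) for _ in range(3))
print(np.allclose(mha_vectorized(h, Wq, Wk, Wv, heads), mha_loop(h, Wq, Wk, Wv, heads)))  # True
```

The two agree because reshaping the last axis into (heads, d) assigns columns i*d to (i+1)*d of each projection to head i, exactly the slice the loop uses.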

Mlp is simply a feed-forward network:

class Mlp(nn.Module):
    """
    Multi-Layer Perceptron (MLP)
    """
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # nn.init.xavier_uniform_ (Xavier/Glorot uniform init) helps avoid vanishing and exploding gradients when training deep networks
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        # nn.init.normal_ samples from a normal distribution; the tiny std here keeps the biases close to zero
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
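A NumPy sketch of the same MLP (fc1, GELU, fc2) with Xavier-uniform weights and near-zero normal biases, as in _init_weights. The helpers are hypothetical re-implementations, dropout is omitted (as in eval mode), GELU is the exact erf form, and 768/3072 are used as example ViT-B sizes.

```python
import math

import numpy as np


def xavier_uniform(fan_in, fan_out, rng):
    # Xavier/Glorot uniform: U(-a, a) with a = sqrt(6 / (fan_in + fan_out)).
    # Stored here as (fan_in, fan_out) for x @ w; nn.Linear stores (out, in) instead.
    a = math.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-a, a, size=(fan_in, fan_out)), a


def gelu(x):
    # exact GELU: x * Phi(x), Phi being the standard normal CDF
    return 0.5 * x * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))


def mlp(x, w1, b1, w2, b2):
    # fc1 -> GELU -> fc2; dropout omitted (as in eval mode)
    return gelu(x @ w1 + b1) @ w2 + b2


rng = np.random.default_rng(0)
hidden, mlp_dim = 768, 3072  # example sizes (ViT-B)
w1, a1 = xavier_uniform(hidden, mlp_dim, rng)
w2, a2 = xavier_uniform(mlp_dim, hidden, rng)
b1 = rng.normal(0.0, 1e-6, size=mlp_dim)  # like nn.init.normal_(bias, std=1e-6)
b2 = rng.normal(0.0, 1e-6, size=hidden)

x = rng.normal(size=(1, 4, hidden))  # (B, n_patch, hidden)
y = mlp(x, w1, b1, w2, b2)
print(y.shape)  # (1, 4, 768)
```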

This concludes the modules used by Transformer.


DecoderCup corresponds to the upsampling (decoding) part of the figure:
[Figure: the decoder (upsampling) part of TransUNet]

In the forward function:

B, n_patch, hidden = hidden_states.size()  # hidden_states: (B, n_patch, hidden)
h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
x = hidden_states.permute(0, 2, 1)  # x: (B, hidden, n_patch)
x = x.contiguous().view(B, hidden, h, w)  # x: (B, hidden, h, w)
x = self.conv_more(x)  # (B, hidden, h, w) ===> (B, 512, h, w)

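These lines first convert the Transformer output (B, n_patch, hidden) into (B, hidden, h, w), where $h = w = \sqrt{n\_patch} = \frac{H}{16} = \frac{W}{16}$ for an H×W input image, i.e.:
[Figure: the token sequence reshaped into a 2-D feature map]
The conv_more convolution (3×3, padding 1) then maps it to (B, 512, h, w):
[Figure: conv_more producing a 512-channel feature map]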
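The permute + contiguous().view() step can be mirrored in NumPy to confirm that token i*w + j ends up at spatial position (i, j). This sketch assumes a 224×224 input downsampled by 16 (so n_patch = 196) and random stand-in data.

```python
import numpy as np

B, H, W, hidden = 2, 224, 224, 768
n_patch = (H // 16) * (W // 16)  # 196 for a 224x224 input
hidden_states = np.random.rand(B, n_patch, hidden)

h = w = int(np.sqrt(n_patch))  # 14; assumes a square grid of patches
x = hidden_states.transpose(0, 2, 1)  # (B, hidden, n_patch), like permute(0, 2, 1)
x = np.ascontiguousarray(x).reshape(B, hidden, h, w)  # like .contiguous().view(B, hidden, h, w)
print(x.shape)  # (2, 768, 14, 14)
# token number i*w + j lands at spatial position (i, j)
print(np.allclose(x[0, :, 2, 7], hidden_states[0, 2 * w + 7]))  # True
```

Because h and w are both taken as sqrt(n_patch), this reshape only works for square patch grids, i.e. square input images.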

class DecoderCup(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        head_channels = 512
        self.conv_more = Conv2dReLU(
            config.hidden_size,
            head_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=True,
        )

        decoder_channels = config.decoder_channels  # decoder_channels (256, 128, 64, 16)
        in_channels = [head_channels] + list(decoder_channels[:-1])  # in_channels = [512, 256, 128, 64]
        out_channels = decoder_channels

        # config.n_skip = 3
        if self.config.n_skip != 0:
            skip_channels = self.config.skip_channels  # config.skip_channels = [512, 256, 64, 16]
            for i in range(4 - self.config.n_skip):  # re-select the skip channels according to n_skip
                skip_channels[3 - i] = 0  # ==> skip_channels = [512, 256, 64, 0]

        else:
            skip_channels = [0, 0, 0, 0]

        # in_channels = [512, 256, 128, 64] out_channels = (256, 128, 64, 16)
        blocks = [
            DecoderBlock(in_ch, out_ch, sk_ch) for in_ch, out_ch, sk_ch in zip(in_channels, out_channels, skip_channels)
        ]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, hidden_states, features=None):
        B, n_patch, hidden = hidden_states.size()  # hidden_states: (B, n_patch, hidden)
        h, w = int(np.sqrt(n_patch)), int(np.sqrt(n_patch))
        x = hidden_states.permute(0, 2, 1)  # x: (B, hidden, n_patch)
        x = x.contiguous().view(B, hidden, h, w)  # x: (B, hidden, h, w)
        x = self.conv_more(x)  # (B, hidden, h, w) ===> (B, 512, h, w)
        for i, decoder_block in enumerate(self.blocks):
            if features is not None:
                skip = features[i] if (i < self.config.n_skip) else None
            else:
                skip = None
            x = decoder_block(x, skip=skip)
        return x
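To see what DecoderCup builds, this pure-Python sketch lists the input, skip and output channels plus the output resolution of each DecoderBlock. decoder_plan is a hypothetical helper; its defaults are the values quoted in the comments above, and h=14 assumes a 224×224 input.

```python
def decoder_plan(n_skip, skip_channels=(512, 256, 64, 16),
                 decoder_channels=(256, 128, 64, 16), head_channels=512, h=14):
    """Hypothetical helper: (in_ch, skip_ch, out_ch, output_size) for each DecoderBlock."""
    skip = list(skip_channels)  # work on a copy instead of the config's own list
    if n_skip != 0:
        for i in range(4 - n_skip):  # zero out the channels of unused skips
            skip[3 - i] = 0
    else:
        skip = [0, 0, 0, 0]
    in_channels = [head_channels] + list(decoder_channels[:-1])
    plan, size = [], h
    for in_ch, out_ch, sk_ch in zip(in_channels, decoder_channels, skip):
        size *= 2  # every DecoderBlock upsamples by 2 before its convolutions
        plan.append((in_ch, sk_ch, out_ch, size))
    return plan


for row in decoder_plan(n_skip=3):
    print(row)
# (512, 512, 256, 28)
# (256, 256, 128, 56)
# (128, 64, 64, 112)
# (64, 0, 16, 224)
```

The sketch copies the list on purpose: the original code assigns into self.config.skip_channels directly, so if that is a plain list, building the model modifies the config in place.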

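DecoderBlock is one step of the layer-by-layer upsampling decoder: it first doubles H and W with bilinear interpolation (UpsamplingBilinear2d), then concatenates the result with the corresponding skip feature along the channel dimension and applies two convolutions, i.e.:
[Figure: one decoder block with upsampling, skip concatenation and convolutions]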

class DecoderBlock(nn.Module):
    def __init__(
            self,
            in_channels,
            out_channels,
            skip_channels=0,
            use_batchnorm=True,
    ):
        super().__init__()
        self.conv1 = Conv2dReLU(
            in_channels + skip_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.conv2 = Conv2dReLU(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            use_batchnorm=use_batchnorm,
        )
        self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, x, skip=None):
        x = self.up(x)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        return x
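A NumPy sketch of the first half of DecoderBlock.forward: 2× bilinear upsampling followed by channel-wise concatenation with a skip feature. It assumes the align_corners=True sampling rule that nn.UpsamplingBilinear2d uses; the function name and the stand-in skip feature are hypothetical.

```python
import numpy as np


def upsample_bilinear_align_corners(x, scale=2):
    """NumPy sketch of bilinear upsampling with align_corners=True on a (B, C, H, W) array."""
    H, W = x.shape[-2:]
    Ho, Wo = H * scale, W * scale
    # align_corners=True: output pixel i samples input coordinate i * (H - 1) / (Ho - 1)
    ys = np.linspace(0, H - 1, Ho)
    xs = np.linspace(0, W - 1, Wo)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, H - 1)
    x1 = np.minimum(x0 + 1, W - 1)
    wy = (ys - y0)[:, None]  # (Ho, 1) vertical weights
    wx = (xs - x0)[None, :]  # (1, Wo) horizontal weights
    rows0, rows1 = x[:, :, y0], x[:, :, y1]  # (B, C, Ho, W)
    top = rows0[..., x0] * (1 - wx) + rows0[..., x1] * wx
    bottom = rows1[..., x0] * (1 - wx) + rows1[..., x1] * wx
    return top * (1 - wy) + bottom * wy


# Decoder-block style step: upsample, then concatenate a skip feature on the channel axis
feat = np.arange(4, dtype=float).reshape(1, 1, 2, 2)  # [[0, 1], [2, 3]]
up = upsample_bilinear_align_corners(feat)  # (1, 1, 4, 4)
skip = np.zeros((1, 3, 4, 4))  # stand-in skip feature with 3 channels
merged = np.concatenate([up, skip], axis=1)  # like torch.cat([x, skip], dim=1)
print(up[0, 0])
print(merged.shape)  # (1, 4, 4, 4)
```

After concatenation the channel count is in_channels + skip_channels, which is why conv1 in DecoderBlock is built with that many input channels.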

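SegmentationHead corresponds to the segmentation head in the figure:
[Figure: the segmentation head producing per-pixel class scores]
nn.Identity() passes its input through unchanged. It is often used to replace the last layer of a classification network so that the network outputs the features from just before the classifier, which is common in transfer learning. Example: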

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18()
# replace the last linear layer with nn.Identity
model.fc = nn.Identity()

# get features for the input
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)
# prints: torch.Size([1, 512])

The SegmentationHead module:

class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsampling)
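Because padding is kernel_size // 2, the convolution keeps H and W for odd kernel sizes, and VisionTransformer builds the head with the default upsampling=1 (nn.Identity). This pure-Python sketch computes the resulting output shape; the helpers are hypothetical, and n_classes=9 is only an example value.

```python
def conv2d_out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    # output-size formula of nn.Conv2d along one spatial dimension
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def segmentation_head_shape(batch, in_hw, n_classes, kernel_size=3, upsampling=1):
    """Hypothetical helper: output shape of SegmentationHead for a (B, C, H, W) input."""
    h, w = in_hw
    pad = kernel_size // 2  # same padding rule as SegmentationHead
    h = conv2d_out_size(h, kernel_size, padding=pad)
    w = conv2d_out_size(w, kernel_size, padding=pad)
    if upsampling > 1:
        # UpsamplingBilinear2d(scale_factor=upsampling)
        h, w = h * upsampling, w * upsampling
    return (batch, n_classes, h, w)


print(segmentation_head_shape(2, (224, 224), n_classes=9))  # (2, 9, 224, 224)
```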

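Finally, the ResNetV2 module, which lives in the vit_seg_modeling_resnet_skip file, corresponds to this part of the figure:
[Figure: the CNN (ResNetV2) feature extractor]
Imports and utility functions of this module: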

import math

from os.path import join as pjoin
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

def np2th(weights, conv=False):
    """Possibly convert HWIO to OIHW."""
    if conv:
        weights = weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(weights)


class StdConv2d(nn.Conv2d):

    def forward(self, x):
        w = self.weight
        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
        w = (w - m) / torch.sqrt(v + 1e-5)
        return F.conv2d(x, w, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=3, stride=stride,
                     padding=1, bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=1, stride=stride,
                     padding=0, bias=bias)


class ResNetV2(nn.Module):
    """Implementation of Pre-activation (v2) ResNet model."""

    def __init__(self, block_units, width_factor):
        super().__init__()
        width = int(64 * width_factor)
        self.width = width

        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, width, kernel_size=7, stride=2, bias=False, padding=3)),
            ('gn', nn.GroupNorm(32, width, eps=1e-6)),
            ('relu', nn.ReLU(inplace=True)),
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0))
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width, cout=width*4, cmid=width))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*4, cout=width*4, cmid=width)) for i in range(2, block_units[0] + 1)],
                ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*4, cout=width*8, cmid=width*2, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*8, cout=width*8, cmid=width*2)) for i in range(2, block_units[1] + 1)],
                ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit1', PreActBottleneck(cin=width*8, cout=width*16, cmid=width*4, stride=2))] +
                [(f'unit{i:d}', PreActBottleneck(cin=width*16, cout=width*16, cmid=width*4)) for i in range(2, block_units[2] + 1)],
                ))),
        ]))

    def forward(self, x):
        features = []
        b, c, in_size, _ = x.size()
        x = self.root(x)
        features.append(x)
        x = nn.MaxPool2d(kernel_size=3, stride=2, padding=0)(x)
        for i in range(len(self.body)-1):
            x = self.body[i](x)
            right_size = int(in_size / 4 / (i+1))
            if x.size()[2] != right_size:
                pad = right_size - x.size()[2]
                assert pad < 3 and pad > 0, "x {} should {}".format(x.size(), right_size)
                feat = torch.zeros((b, x.size()[1], right_size, right_size), device=x.device)
                feat[:, :, 0:x.size()[2], 0:x.size()[3]] = x[:]
            else:
                feat = x
            features.append(feat)
        x = self.body[-1](x)
        return x, features[::-1]


class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout//4

        self.gn1 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv1 = conv1x1(cin, cmid, bias=False)
        self.gn2 = nn.GroupNorm(32, cmid, eps=1e-6)
        self.conv2 = conv3x3(cmid, cmid, stride, bias=False)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cout, eps=1e-6)
        self.conv3 = conv1x1(cmid, cout, bias=False)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride, bias=False)
            self.gn_proj = nn.GroupNorm(cout, cout)

    def forward(self, x):

        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(x)
            residual = self.gn_proj(residual)

        # Unit's branch
        y = self.relu(self.gn1(self.conv1(x)))
        y = self.relu(self.gn2(self.conv2(y)))
        y = self.gn3(self.conv3(y))

        y = self.relu(residual + y)
        return y

    def load_from(self, weights, n_block, n_unit):
        conv1_weight = np2th(weights[pjoin(n_block, n_unit, "conv1/kernel")], conv=True)
        conv2_weight = np2th(weights[pjoin(n_block, n_unit, "conv2/kernel")], conv=True)
        conv3_weight = np2th(weights[pjoin(n_block, n_unit, "conv3/kernel")], conv=True)

        gn1_weight = np2th(weights[pjoin(n_block, n_unit, "gn1/scale")])
        gn1_bias = np2th(weights[pjoin(n_block, n_unit, "gn1/bias")])

        gn2_weight = np2th(weights[pjoin(n_block, n_unit, "gn2/scale")])
        gn2_bias = np2th(weights[pjoin(n_block, n_unit, "gn2/bias")])

        gn3_weight = np2th(weights[pjoin(n_block, n_unit, "gn3/scale")])
        gn3_bias = np2th(weights[pjoin(n_block, n_unit, "gn3/bias")])

        self.conv1.weight.copy_(conv1_weight)
        self.conv2.weight.copy_(conv2_weight)
        self.conv3.weight.copy_(conv3_weight)

        self.gn1.weight.copy_(gn1_weight.view(-1))
        self.gn1.bias.copy_(gn1_bias.view(-1))

        self.gn2.weight.copy_(gn2_weight.view(-1))
        self.gn2.bias.copy_(gn2_bias.view(-1))

        self.gn3.weight.copy_(gn3_weight.view(-1))
        self.gn3.bias.copy_(gn3_bias.view(-1))

        if hasattr(self, 'downsample'):
            proj_conv_weight = np2th(weights[pjoin(n_block, n_unit, "conv_proj/kernel")], conv=True)
            proj_gn_weight = np2th(weights[pjoin(n_block, n_unit, "gn_proj/scale")])
            proj_gn_bias = np2th(weights[pjoin(n_block, n_unit, "gn_proj/bias")])

            self.downsample.weight.copy_(proj_conv_weight)
            self.gn_proj.weight.copy_(proj_gn_weight.view(-1))
            self.gn_proj.bias.copy_(proj_gn_bias.view(-1))

Since this code is only used in hybrid mode, I have not yet looked into why StdConv2d and GroupNorm are used here; I'll look for the answer in the ViT code later.
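For reference, here is what StdConv2d.forward does to its weight before convolving, as a NumPy sketch: each output filter is standardized to zero mean and unit variance over (in_channels, kH, kW), a technique usually called weight standardization. standardize_weight is a hypothetical name; the eps and the population variance match the code above.

```python
import numpy as np


def standardize_weight(w, eps=1e-5):
    """Per-output-filter standardization as in StdConv2d.forward.

    w has PyTorch layout (out_channels, in_channels, kH, kW).
    """
    m = w.mean(axis=(1, 2, 3), keepdims=True)
    v = w.var(axis=(1, 2, 3), keepdims=True)  # population variance, like unbiased=False
    return (w - m) / np.sqrt(v + eps)


rng = np.random.default_rng(0)
w = rng.normal(loc=0.5, scale=3.0, size=(8, 4, 3, 3))
ws = standardize_weight(w)
print(np.round(ws.mean(axis=(1, 2, 3)), 6))  # close to 0 for every filter
print(np.round(ws.std(axis=(1, 2, 3)), 3))   # close to 1 for every filter
```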

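This article was contributed by an internet user. The views expressed are the author's own and do not represent this site's position; the site only provides storage space, does not own the content and bears no related legal liability. If you repost it, please credit the source: http://www.coloradmin.cn/o/867186.html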
