Modding YOLOv8 | Using MB-TaylorFormer to Improve YOLOv8 Detection on Hazy Images

news2025/1/18 10:51:29

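1. Introduction

MB-TaylorFormer is a multi-branch linear Transformer network for image dehazing. It uses a Taylor expansion to approximate softmax-attention, and it combines multi-branch and multi-scale structures to capture multi-level, multi-scale information. According to the paper, it compares favorably with conventional methods in performance, computational cost, and model size.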

For a detailed description of MB-TaylorFormer, see the paper: https://arxiv.org/pdf/2308.14036.pdf
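As a rough illustration of the idea (not the paper's full method), a first-order Taylor expansion replaces exp(q·k) with 1 + q·k. The sums over keys then no longer depend on the query, so they can be computed once and reused. The sketch below uses plain Python lists; the function names and the toy vectors are made up for this example.

```python
import math

def softmax_attention(q, keys, values):
    """Exact softmax attention for a single query vector."""
    scores = [math.exp(sum(a * b for a, b in zip(q, k))) for k in keys]
    total = sum(scores)
    dim = len(values[0])
    return [sum(s * v[d] for s, v in zip(scores, values)) / total for d in range(dim)]

def taylor_attention(q, keys, values):
    """First-order Taylor approximation: exp(q.k) ~ 1 + q.k.

    Numerator:   sum_j v_j + q . (sum_j k_j v_j^T)
    Denominator: N + q . sum_j k_j
    The sums over j are shared by every query.
    """
    n, dk, dv = len(keys), len(q), len(values[0])
    sum_v = [sum(v[d] for v in values) for d in range(dv)]
    sum_k = [sum(k[i] for k in keys) for i in range(dk)]
    # kv[i][d] = sum_j k_j[i] * v_j[d]
    kv = [[sum(k[i] * v[d] for k, v in zip(keys, values)) for d in range(dv)]
          for i in range(dk)]
    denom = n + sum(q[i] * sum_k[i] for i in range(dk))
    return [(sum_v[d] + sum(q[i] * kv[i][d] for i in range(dk))) / denom
            for d in range(dv)]

# Toy data with small dot products, where the approximation is reasonable.
q = [0.1, -0.05]
keys = [[0.2, 0.1], [-0.1, 0.3], [0.05, -0.2]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

exact = softmax_attention(q, keys, values)
approx = taylor_attention(q, keys, values)
max_err = max(abs(a - b) for a, b in zip(exact, approx))
print(max_err)
```

The numerator and denominator here have the same form as `out_numerator` and `out_denominator` in the `Attention.forward` method of the listing in section 2.1, which works on batched tensors instead of lists.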

This article explains how to integrate MB-TaylorFormer into YOLOv8.

Without further ado, here is the code!

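2. Integrating MB-TaylorFormer into YOLOv8

2.1 Step 1

First locate the directory 'ultralytics/nn' and create an 'Addmodules' folder inside it. In that folder, create a file named TaylorFormer.py (any file name you prefer works), then copy the MB-TaylorFormer core code below into it.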

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops.deform_conv import DeformConv2d
import numbers
import math
from einops import rearrange
import numpy as np
 
 
__all__ = ['MB_TaylorFormer']
 
freqs_dict = dict()
 
 
##########################################################################
 
def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')
 
 
def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
 
 
class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
 
        assert len(normalized_shape) == 1
 
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape
 
    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5) * self.weight
 
 
class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
 
        assert len(normalized_shape) == 1
 
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape
 
    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias
 
 
class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)
 
    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)
 
 
##########################################################################
## Gated-Dconv Feed-Forward Network (GDFN)
class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()
 
        hidden_features = int(dim * ffn_expansion_factor)
 
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
 
        self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1,
                                groups=hidden_features * 2, bias=bias)
 
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
 
    def forward(self, x):
        x = self.project_in(x)
        x1, x2 = self.dwconv(x).chunk(2, dim=1)
        x = F.gelu(x1) * x2
        x = self.project_out(x)
        return x
 
 
class refine_att(nn.Module):
    """Convolutional relative position encoding."""
 
    def __init__(self, Ch, h, window):
 
        super().__init__()
 
        if isinstance(window, int):
            # Set the same window size for all attention heads.
            window = {window: h}
            self.window = window
        elif isinstance(window, dict):
            self.window = window
        else:
            raise ValueError("window must be an int or a dict of {window_size: head_count}")
 
        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1  # Use dilation=1 at default.
            padding_size = (cur_window + (cur_window - 1) *
                            (dilation - 1)) // 2
            cur_conv = nn.Conv2d(
                cur_head_split * Ch * 2,
                cur_head_split,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split,
            )
 
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        self.channel_splits = [x * Ch * 2 for x in self.head_splits]
 
    def forward(self, q, k, v, size):
        """forward function"""
        B, h, N, Ch = q.shape
        H, W = size
 
        # We don't use CLS_TOKEN
        q_img = q
        k_img = k
        v_img = v
        # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].
        q_img = rearrange(q_img, "B h (H W) Ch -> B h Ch H W", H=H, W=W)
        k_img = rearrange(k_img, "B h Ch (H W) -> B h Ch H W", H=H, W=W)
        qk_concat = torch.cat((q_img, k_img), 2)
        qk_concat = rearrange(qk_concat, "B h Ch H W -> B (h Ch) H W", H=H, W=W)
        # Split according to channels.
        qk_concat_list = torch.split(qk_concat, self.channel_splits, dim=1)
        qk_att_list = [
            conv(x) for conv, x in zip(self.conv_list, qk_concat_list)
        ]
        qk_att = torch.cat(qk_att_list, dim=1)
        # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
        qk_att = rearrange(qk_att, "B (h Ch) H W -> B h (H W) Ch", h=h)
        return qk_att
##########################################################################
## Multi-DConv Head Transposed Self-Attention (MDTA)
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias, shared_refine_att=None, qk_norm=1):
        super(Attention, self).__init__()
        self.norm = qk_norm
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        # self.Leakyrelu=nn.LeakyReLU(negative_slope=0.01,inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        if num_heads == 8:
            crpe_window = {
                3: 2,
                5: 3,
                7: 3
            }
        elif num_heads == 1:
            crpe_window = {
                3: 1,
            }
        elif num_heads == 2:
            crpe_window = {
                3: 2,
            }
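        elif num_heads == 4:
            crpe_window = {
                3: 2,
                5: 2,
            }
        else:
            # Without this branch an unsupported head count fails later with a NameError.
            raise ValueError(f"Unsupported num_heads={num_heads}; expected 1, 2, 4, or 8.")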
        self.refine_att = refine_att(Ch=dim // num_heads,
                                     h=num_heads,
                                     window=crpe_window)
    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.qkv_dwconv(self.qkv(x))
        q, k, v = qkv.chunk(3, dim=1)
        q = rearrange(q, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head (h w) c', head=self.num_heads)
        # q = torch.nn.functional.normalize(q, dim=-1)
        q_norm = torch.norm(q, p=2, dim=-1, keepdim=True) / self.norm + 1e-6
        q = torch.div(q, q_norm)
        k_norm = torch.norm(k, p=2, dim=-2, keepdim=True) / self.norm + 1e-6
        k = torch.div(k, k_norm)
        # k = torch.nn.functional.normalize(k, dim=-2)
        refine_weight = self.refine_att(q, k, v, size=(h, w))
        # refine_weight=self.Leakyrelu(refine_weight)
        refine_weight = self.sigmoid(refine_weight)
        attn = k @ v
        # attn = attn.softmax(dim=-1)
        # print(torch.sum(k, dim=-1).unsqueeze(3).shape)
        out_numerator = torch.sum(v, dim=-2).unsqueeze(2) + (q @ attn)
        out_denominator = torch.full((h * w, c // self.num_heads), h * w).to(q.device) \
                          + q @ torch.sum(k, dim=-1).unsqueeze(3).repeat(1, 1, 1, c // self.num_heads) + 1e-6
        # out=torch.div(out_numerator,out_denominator)*self.temperature*refine_weight
        out = torch.div(out_numerator, out_denominator) * self.temperature
        out = out * refine_weight
        out = rearrange(out, 'b head (h w) c-> b (head c) h w', head=self.num_heads, h=h, w=w)
        out = self.project_out(out)
        return out
##########################################################################
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type, shared_refine_att=None, qk_norm=1):
        super(TransformerBlock, self).__init__()
        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias, shared_refine_att=shared_refine_att, qk_norm=qk_norm)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
class MHCAEncoder(nn.Module):
    """Multi-Head Convolutional self-Attention Encoder comprised of `MHCA`
    blocks."""
    def __init__(
            self,
            dim,
            num_layers=1,
            num_heads=8,
            ffn_expansion_factor=2.66,
            bias=False,
            LayerNorm_type='BiasFree',
            qk_norm=1
    ):
        super().__init__()
        self.num_layers = num_layers
        self.MHCA_layers = nn.ModuleList([
            TransformerBlock(
                dim,
                num_heads=num_heads,
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
                qk_norm=qk_norm
            ) for idx in range(self.num_layers)
        ])
    def forward(self, x, size):
        """forward function"""
        H, W = size
        B = x.shape[0]
        # return x's shape : [B, N, C] -> [B, C, H, W]
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
 
        for layer in self.MHCA_layers:
            x = layer(x)
 
        return x
 
 
class ResBlock(nn.Module):
    """Residual block for convolutional local feature."""
 
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.Hardswish,
            norm_layer=nn.BatchNorm2d,
    ):
        super().__init__()
 
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # self.act0 = act_layer()
        self.conv1 = Conv2d_BN(in_features,
                               hidden_features,
                               act_layer=act_layer)
        self.dwconv = nn.Conv2d(
            hidden_features,
            hidden_features,
            3,
            1,
            1,
            bias=False,
            groups=hidden_features,
        )
        # self.norm = norm_layer(hidden_features)
        self.act = act_layer()
        self.conv2 = Conv2d_BN(hidden_features, out_features)
        self.apply(self._init_weights)
 
    def _init_weights(self, m):
        """
        initialization
        """
        if isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
 
    def forward(self, x):
        """forward function"""
        identity = x
        # x=self.act0(x)
        feat = self.conv1(x)
        feat = self.dwconv(feat)
        # feat = self.norm(feat)
        feat = self.act(feat)
        feat = self.conv2(feat)
 
        return identity + feat
 
 
class MHCA_stage(nn.Module):
    """Multi-Head Convolutional self-Attention stage comprised of `MHCAEncoder`
    layers."""
 
    def __init__(
            self,
            embed_dim,
            out_embed_dim,
            num_layers=1,
            num_heads=8,
            ffn_expansion_factor=2.66,
            num_path=4,
            bias=False,
            LayerNorm_type='BiasFree',
            qk_norm=1
 
    ):
        super().__init__()
 
        self.mhca_blks = nn.ModuleList([
            MHCAEncoder(
                embed_dim,
                num_layers,
                num_heads,
                ffn_expansion_factor=ffn_expansion_factor,
                bias=bias,
                LayerNorm_type=LayerNorm_type,
                qk_norm=qk_norm
 
            ) for _ in range(num_path)
        ])
 
        self.aggregate = SKFF(embed_dim, height=num_path)
 
        # self.InvRes = ResBlock(in_features=embed_dim, out_features=embed_dim)
 
    # self.aggregate = Conv2d_aggregate(embed_dim * (num_path + 1),
    #                           out_embed_dim,
    #                           act_layer=nn.Hardswish)
 
    def forward(self, inputs):
        """forward function"""
        # att_outputs = [self.InvRes(inputs[0])]
        att_outputs = []
 
        for x, encoder in zip(inputs, self.mhca_blks):
            # [B, C, H, W] -> [B, N, C]
            _, _, H, W = x.shape
            x = x.flatten(2).transpose(1, 2).contiguous()
            att_outputs.append(encoder(x, size=(H, W)))
 
        # out_concat = torch.cat(att_outputs, dim=1)
        out = self.aggregate(att_outputs)
 
        return out
 
 
##########################################################################
## Overlapped image patch embedding with 3x3 Conv
class Conv2d_BN(nn.Module):
 
    def __init__(
            self,
            in_ch,
            out_ch,
            kernel_size=1,
            stride=1,
            pad=0,
            dilation=1,
            groups=1,
            bn_weight_init=1,
            norm_layer=nn.BatchNorm2d,
            act_layer=None,
    ):
        super().__init__()
 
        self.conv = torch.nn.Conv2d(in_ch,
                                    out_ch,
                                    kernel_size,
                                    stride,
                                    pad,
                                    dilation,
                                    groups,
                                    bias=False)
        # self.bn = norm_layer(out_ch)
 
        # torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        # torch.nn.init.constant_(self.bn.bias, 0)
 
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # Note that there is no bias due to BN
                fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out))
 
        self.act_layer = act_layer() if act_layer is not None else nn.Identity()
 
    def forward(self, x):
 
        x = self.conv(x)
        # x = self.bn(x)
        x = self.act_layer(x)
 
        return x
 
 
class SKFF(nn.Module):
    def __init__(self, in_channels, height=2, reduction=8, bias=False):
        super(SKFF, self).__init__()
 
        self.height = height
        d = max(int(in_channels / reduction), 4)
 
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv_du = nn.Sequential(nn.Conv2d(in_channels, d, 1, padding=0, bias=bias), nn.PReLU())
 
        self.fcs = nn.ModuleList([])
        for i in range(self.height):
            self.fcs.append(nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias))
 
        self.softmax = nn.Softmax(dim=1)
 
    def forward(self, inp_feats):
        batch_size = inp_feats[0].shape[0]
        n_feats = inp_feats[0].shape[1]
 
        inp_feats = torch.cat(inp_feats, dim=1)
        inp_feats = inp_feats.view(batch_size, self.height, n_feats, inp_feats.shape[2], inp_feats.shape[3])
 
        feats_U = torch.sum(inp_feats, dim=1)
        feats_S = self.avg_pool(feats_U)
        feats_Z = self.conv_du(feats_S)
 
        attention_vectors = [fc(feats_Z) for fc in self.fcs]
        attention_vectors = torch.cat(attention_vectors, dim=1)
        attention_vectors = attention_vectors.view(batch_size, self.height, n_feats, 1, 1)
        # stx()
        attention_vectors = self.softmax(attention_vectors)
 
        feats_V = torch.sum(inp_feats * attention_vectors, dim=1)
 
        return feats_V
 
 
class DWConv2d_BN(nn.Module):
 
    def __init__(
            self,
            in_ch,
            out_ch,
            kernel_size=1,
            stride=1,
            norm_layer=nn.BatchNorm2d,
            act_layer=nn.Hardswish,
            bn_weight_init=1,
            offset_clamp=(-1, 1)
    ):
        super().__init__()
        # dw
        # self.conv=torch.nn.Conv2d(in_ch,out_ch,kernel_size,stride,(kernel_size - 1) // 2,bias=False,)
        # self.mask_generator = nn.Sequential(nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3,
        #                                                 stride=1, padding=1, bias=False, groups=in_ch),
        #                                       nn.Conv2d(in_channels=in_ch, out_channels=9,
        #                                                 kernel_size=1,
        #                                                 stride=1, padding=0, bias=False)
        #                                      )
        self.offset_clamp = offset_clamp
        self.offset_generator = nn.Sequential(nn.Conv2d(in_channels=in_ch, out_channels=in_ch, kernel_size=3,
                                                        stride=1, padding=1, bias=False, groups=in_ch),
                                              nn.Conv2d(in_channels=in_ch, out_channels=18,
                                                        kernel_size=1,
                                                        stride=1, padding=0, bias=False)
 
                                              )
        self.dcn = DeformConv2d(
            in_channels=in_ch,
            out_channels=in_ch,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            groups=in_ch
        )  # .cuda(7)
        self.pwconv = nn.Conv2d(in_ch, out_ch, 1, 1, 0, bias=False)
 
        # self.bn = norm_layer(out_ch)
        self.act = act_layer() if act_layer is not None else nn.Identity()
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2.0 / n))
                if m.bias is not None:
                    m.bias.data.zero_()
                # print(m)
 
        #   elif isinstance(m, nn.BatchNorm2d):
        #     m.weight.data.fill_(bn_weight_init)
        #      m.bias.data.zero_()
 
    def forward(self, x):
 
        # x=self.conv(x)
        # x = self.bn(x)
        # x = self.act(x)
        # mask= torch.sigmoid(self.mask_generator(x))
        # print('1')
        offset = self.offset_generator(x)
        # print('2')
        if self.offset_clamp:
            offset = torch.clamp(offset, min=self.offset_clamp[0], max=self.offset_clamp[1])
        # print(offset)
        # print('3')
        # x=x.cuda(7)
        x = self.dcn(x, offset)
        # x=x.cpu()
        # print('4')
        x = self.pwconv(x)
        # print('5')
        # x = self.bn(x)
        x = self.act(x)
        return x
 
 
class DWCPatchEmbed(nn.Module):
    """Depthwise Convolutional Patch Embedding layer Image to Patch
    Embedding."""
 
    def __init__(self,
                 in_chans=3,
                 embed_dim=768,
                 patch_size=16,
                 stride=1,
                 idx=0,
                 act_layer=nn.Hardswish,
                 offset_clamp=(-1, 1)):
        super().__init__()
 
        self.patch_conv = DWConv2d_BN(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            act_layer=act_layer,
            offset_clamp=offset_clamp
        )
        """
        self.patch_conv = DWConv2d_BN(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=stride,
            act_layer=act_layer,
        )
        """
 
    def forward(self, x):
        """forward function"""
        x = self.patch_conv(x)
 
        return x
 
 
class Patch_Embed_stage(nn.Module):
    """Depthwise Convolutional Patch Embedding stage comprised of
    `DWCPatchEmbed` layers."""
 
    def __init__(self, in_chans, embed_dim, num_path=4, isPool=False, offset_clamp=(-1, 1)):
        super(Patch_Embed_stage, self).__init__()
 
        self.patch_embeds = nn.ModuleList([
            DWCPatchEmbed(
                in_chans=in_chans if idx == 0 else embed_dim,
                embed_dim=embed_dim,
                patch_size=3,
                stride=1,
                idx=idx,
                offset_clamp=offset_clamp
            ) for idx in range(num_path)
        ])
 
    def forward(self, x):
        """forward function"""
        att_inputs = []
        for pe in self.patch_embeds:
            x = pe(x)
            att_inputs.append(x)
 
        return att_inputs
 
 
class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_c=3, embed_dim=48, bias=False):
        super(OverlapPatchEmbed, self).__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias)
        # self.proj_dw = nn.Conv2d(in_c, in_c, kernel_size=3, stride=1, padding=1,groups=in_c, bias=bias)
        # self.proj_pw = nn.Conv2d(in_c, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias)
 
        # self.bn=nn.BatchNorm2d(embed_dim)
        # self.act=nn.Hardswish()
 
    def forward(self, x):
        x = self.proj(x)
        # x = self.proj_dw(x)
        # x= self.proj_pw(x)
        # x=self.bn(x)
        # x=self.act(x)
        return x
 
 
##########################################################################
## Resizing modules
class Downsample(nn.Module):
    def __init__(self, input_feat, out_feat):
        super(Downsample, self).__init__()
 
        self.body = nn.Sequential(  # nn.Conv2d(n_feat, n_feat // 2, kernel_size=3, stride=1, padding=1, bias=False),
            # dw
            nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False, ),
            # pw-linear
            nn.Conv2d(input_feat, out_feat // 4, 1, 1, 0, bias=False),
            # nn.BatchNorm2d(n_feat // 2),
            # nn.Hardswish(),
            nn.PixelUnshuffle(2))
 
    def forward(self, x):
        return self.body(x)
 
 
class Upsample(nn.Module):
    def __init__(self, input_feat, out_feat):
        super(Upsample, self).__init__()
 
        self.body = nn.Sequential(  # nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
            # dw
            nn.Conv2d(input_feat, input_feat, kernel_size=3, stride=1, padding=1, groups=input_feat, bias=False, ),
            # pw-linear
            nn.Conv2d(input_feat, out_feat * 4, 1, 1, 0, bias=False),
            # nn.BatchNorm2d(n_feat*2),
            # nn.Hardswish(),
            nn.PixelShuffle(2))
 
    def forward(self, x):
        return self.body(x)
 
 
##########################################################################
##---------- MB-TaylorFormer -----------------------
class MB_TaylorFormer(nn.Module):
    def __init__(self,
                 inp_channels=3,
                 dim=[6, 12, 24, 36],
                 num_blocks=[1, 1, 1, 1],
                 heads=[1, 1, 1, 1],
                 bias=False,
                 dual_pixel_task=True,  ## If True, forward() adds skip_conv(inp_enc_level1) before the output conv and skips the "+ inp_img" residual
                 num_path=[1, 1, 1, 1],  ## Number of parallel branches per level
                 qk_norm=1,
                 offset_clamp=(-1, 1)
                 ):
 
        super(MB_TaylorFormer, self).__init__()
 
        self.patch_embed = OverlapPatchEmbed(inp_channels, dim[0])
        self.patch_embed_encoder_level1 = Patch_Embed_stage(dim[0], dim[0], num_path=num_path[0], isPool=False,
                                                            offset_clamp=offset_clamp)
        self.encoder_level1 = MHCA_stage(dim[0], dim[0], num_layers=num_blocks[0], num_heads=heads[0],
                                         ffn_expansion_factor=2.66, num_path=num_path[0],
                                         bias=False, LayerNorm_type='BiasFree', qk_norm=qk_norm)
 
        self.down1_2 = Downsample(dim[0], dim[1])  ## From Level 1 to Level 2
 
        self.patch_embed_encoder_level2 = Patch_Embed_stage(dim[1], dim[1], num_path=num_path[1], isPool=False,
                                                            offset_clamp=offset_clamp)
        self.encoder_level2 = MHCA_stage(dim[1], dim[1], num_layers=num_blocks[1], num_heads=heads[1],
                                         ffn_expansion_factor=2.66,
                                         num_path=num_path[1], bias=False, LayerNorm_type='BiasFree', qk_norm=qk_norm)
 
        self.down2_3 = Downsample(dim[1], dim[2])  ## From Level 2 to Level 3
 
        self.patch_embed_encoder_level3 = Patch_Embed_stage(dim[2], dim[2], num_path=num_path[2],
                                                            isPool=False, offset_clamp=offset_clamp)
        self.encoder_level3 = MHCA_stage(dim[2], dim[2], num_layers=num_blocks[2], num_heads=heads[2],
                                         ffn_expansion_factor=2.66,
                                         num_path=num_path[2], bias=False, LayerNorm_type='BiasFree', qk_norm=qk_norm)
 
        self.down3_4 = Downsample(dim[2], dim[3])  ## From Level 3 to Level 4
 
        self.patch_embed_latent = Patch_Embed_stage(dim[3], dim[3], num_path=num_path[3],
                                                    isPool=False, offset_clamp=offset_clamp)
        self.latent = MHCA_stage(dim[3], dim[3], num_layers=num_blocks[3], num_heads=heads[3],
                                 ffn_expansion_factor=2.66, num_path=num_path[3], bias=False,
                                 LayerNorm_type='BiasFree', qk_norm=qk_norm)
 
        self.up4_3 = Upsample(int(dim[3]), dim[2])  ## From Level 4 to Level 3
        self.reduce_chan_level3 = nn.Sequential(
            nn.Conv2d(dim[2] * 2, dim[2], 1, 1, 0, bias=bias),
            # nn.BatchNorm2d(dim * 2**2),
            # nn.Hardswish(),
        )
 
        self.patch_embed_decoder_level3 = Patch_Embed_stage(dim[2], dim[2], num_path=num_path[2],
                                                            isPool=False, offset_clamp=offset_clamp)
        self.decoder_level3 = MHCA_stage(dim[2], dim[2], num_layers=num_blocks[2], num_heads=heads[2],
                                         ffn_expansion_factor=2.66, num_path=num_path[2], bias=False,
                                         LayerNorm_type='BiasFree', qk_norm=qk_norm)
 
        self.up3_2 = Upsample(int(dim[2]), dim[1])  ## From Level 3 to Level 2
        self.reduce_chan_level2 = nn.Sequential(
            nn.Conv2d(dim[1] * 2, dim[1], 1, 1, 0, bias=bias),
            # nn.BatchNorm2d( dim * 2),
            # nn.Hardswish(),
        )
 
        self.patch_embed_decoder_level2 = Patch_Embed_stage(dim[1], dim[1], num_path=num_path[1],
                                                            isPool=False, offset_clamp=offset_clamp)
        self.decoder_level2 = MHCA_stage(dim[1], dim[1], num_layers=num_blocks[1], num_heads=heads[1],
                                         ffn_expansion_factor=2.66, num_path=num_path[1], bias=False,
                                         LayerNorm_type='BiasFree', qk_norm=qk_norm)
 
        self.up2_1 = Upsample(int(dim[1]), dim[0])  ## From Level 2 to Level 1  (NO 1x1 conv to reduce channels)
 
        self.patch_embed_decoder_level1 = Patch_Embed_stage(dim[1], dim[1], num_path=num_path[0],
                                                            isPool=False, offset_clamp=offset_clamp)
        self.decoder_level1 = MHCA_stage(dim[1], dim[1], num_layers=num_blocks[0], num_heads=heads[0],
                                         ffn_expansion_factor=2.66, num_path=num_path[0], bias=False,
                                         LayerNorm_type='BiasFree', qk_norm=qk_norm)
 
        self.patch_embed_refinement = Patch_Embed_stage(dim[1], dim[1], num_path=num_path[0],
                                                        isPool=False, offset_clamp=offset_clamp)
        self.refinement = MHCA_stage(dim[1], dim[1], num_layers=num_blocks[0], num_heads=heads[0],
                                     ffn_expansion_factor=2.66, num_path=num_path[0], bias=False,
                                     LayerNorm_type='BiasFree', qk_norm=qk_norm)
 
        #### For Dual-Pixel Defocus Deblurring Task ####
        self.dual_pixel_task = dual_pixel_task
        if self.dual_pixel_task:
            self.skip_conv = nn.Conv2d(dim[0], dim[1], kernel_size=1, bias=bias)
        ###########################
 
        # self.output = nn.Conv2d(dim*2**1, 3, kernel_size=3, stride=1, padding=1, bias=False)
 
        self.output = nn.Sequential(  # nn.Conv2d(n_feat, n_feat*2, kernel_size=3, stride=1, padding=1, bias=False),
            # nn.BatchNorm2d(dim*2),
            # nn.Hardswish(),
            nn.Conv2d(dim[1], 3, kernel_size=3, stride=1, padding=1, bias=False, ),
 
        )
 
    def forward(self, inp_img):
        inp_enc_level1 = self.patch_embed(inp_img)
 
        inp_enc_level1_list = self.patch_embed_encoder_level1(inp_enc_level1)
 
        out_enc_level1 = self.encoder_level1(inp_enc_level1_list) + inp_enc_level1
 
        # out_enc_level1 = self.encoder_level1(inp_enc_level1_list)
        inp_enc_level2 = self.down1_2(out_enc_level1)
 
        inp_enc_level2_list = self.patch_embed_encoder_level2(inp_enc_level2)
        out_enc_level2 = self.encoder_level2(inp_enc_level2_list) + inp_enc_level2
        inp_enc_level3 = self.down2_3(out_enc_level2)
 
        inp_enc_level3_list = self.patch_embed_encoder_level3(inp_enc_level3)
        out_enc_level3 = self.encoder_level3(inp_enc_level3_list) + inp_enc_level3
        inp_enc_level4 = self.down3_4(out_enc_level3)
 
        inp_latent = self.patch_embed_latent(inp_enc_level4)
        latent = self.latent(inp_latent) + inp_enc_level4
 
        inp_dec_level3 = self.up4_3(latent)
        inp_dec_level3 = torch.cat([inp_dec_level3, out_enc_level3], 1)
        inp_dec_level3 = self.reduce_chan_level3(inp_dec_level3)
        inp_dec_level3_list = self.patch_embed_decoder_level3(inp_dec_level3)
        out_dec_level3 = self.decoder_level3(inp_dec_level3_list) + inp_dec_level3
 
        inp_dec_level2 = self.up3_2(out_dec_level3)
 
        inp_dec_level2 = torch.cat([inp_dec_level2, out_enc_level2], 1)
 
        inp_dec_level2 = self.reduce_chan_level2(inp_dec_level2)
 
        inp_dec_level2_list = self.patch_embed_decoder_level2(inp_dec_level2)
 
        out_dec_level2 = self.decoder_level2(inp_dec_level2_list) + inp_dec_level2
 
        inp_dec_level1 = self.up2_1(out_dec_level2)
 
        inp_dec_level1 = torch.cat([inp_dec_level1, out_enc_level1], 1)
 
        inp_dec_level1_list = self.patch_embed_decoder_level1(inp_dec_level1)
 
        out_dec_level1 = self.decoder_level1(inp_dec_level1_list) + inp_dec_level1
 
        inp_latent_list = self.patch_embed_refinement(out_dec_level1)
 
        out_dec_level1 = self.refinement(inp_latent_list) + out_dec_level1
 
        # nn.Hardswish()
 
        #### For Dual-Pixel Defocus Deblurring Task ####
        if self.dual_pixel_task:
            out_dec_level1 = out_dec_level1 + self.skip_conv(inp_enc_level1)
            out_dec_level1 = self.output(out_dec_level1)
        ###########################
        else:
            out_dec_level1 = self.output(out_dec_level1) + inp_img
 
        return out_dec_level1
 
 
def count_param(model):
    """Return the total number of elements across all model parameters."""
    return sum(param.numel() for param in model.parameters())
 
 
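if __name__ == "__main__":
    # Quick smoke test: count the parameters and check the output shape.
    model = MB_TaylorFormer()
    model.eval()
    print("params", count_param(model))
    inputs = torch.randn(1, 3, 640, 640)
    with torch.no_grad():  # inference only, so skip gradient tracking
        output = model(inputs)
    print(output.size())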

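2.2 Step 2

In the Addmodules folder, create a new Python file named '__init__.py' and, inside it, import the contents of the TaylorFormer.py file created in Step 1.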
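
The listing for this step is not shown, so what follows is only a sketch of what such an `__init__.py` usually contains: a single star import, `from .TaylorFormer import *`, which re-exports whatever `TaylorFormer.py` lists in `__all__`. To keep the example runnable without the real repository, it builds a throwaway package in a temporary directory. The stub class body and the temporary paths are assumptions made for illustration.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Build a throwaway package that mirrors the layout from Step 1.
# The stub class stands in for the real MB_TaylorFormer code.
root = Path(tempfile.mkdtemp())
pkg = root / "Addmodules"
pkg.mkdir()

(pkg / "TaylorFormer.py").write_text(
    "__all__ = ['MB_TaylorFormer']\n"
    "\n"
    "class MB_TaylorFormer:  # stub, not the real network\n"
    "    pass\n"
)

# Assumed content of Addmodules/__init__.py: re-export the public names.
(pkg / "__init__.py").write_text("from .TaylorFormer import *\n")

sys.path.insert(0, str(root))
addmodules = importlib.import_module("Addmodules")

# The class is now reachable directly from the package.
exported = [name for name in dir(addmodules) if not name.startswith("_")]
print(exported)
```

In the real project only the one-line `__init__.py` is needed; the rest of the script exists to show that the star import exposes `MB_TaylorFormer` at package level.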

2.3 Step 3

Import the Addmodules package in tasks.py (ultralytics/nn/tasks.py).
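
The import itself is not shown in this step. A common form, assumed here, is a single line placed next to the other imports in `ultralytics/nn/tasks.py`: `from .Addmodules import *`. It matters because the model builder has to turn the module name written in the YAML into a class, and it can only do that if the name is visible where the model is built. The sketch below imitates that lookup with a stub class; `resolve_module` is a hypothetical helper, not Ultralytics code.

```python
class MB_TaylorFormer:
    """Stub standing in for the class made visible by the star import."""


def resolve_module(name, namespace):
    """Look up a YAML module name in a namespace (simplified illustration)."""
    try:
        return namespace[name]
    except KeyError:
        raise KeyError(f"'{name}' is not imported where the model is built") from None


# One row of the YAML: [from, repeats, module, args]
layer = [-1, 1, "MB_TaylorFormer", []]
cls = resolve_module(layer[2], globals())
module = cls(*layer[3])  # empty args, so the class defaults are used
print(type(module).__name__)  # prints MB_TaylorFormer
```

If the import is missing, the same lookup fails with a `KeyError` on the module name, which is the usual symptom of an unregistered custom module.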

The module is now registered. Copy the YAML file below and run it directly.
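
Running the model can be sketched as follows, assuming the YAML below is saved as `yolov8n-MBTaylorFormer.yaml` (a hypothetical file name; the `n` after `yolov8` selects scale `n`) and that a dataset config such as `coco8.yaml` is available. `YOLO(...)` and `model.train(...)` are the standard Ultralytics entry points; the epoch count and image size are placeholder values.

```python
def train_dehaze_detector(cfg="yolov8n-MBTaylorFormer.yaml", data="coco8.yaml",
                          epochs=100, imgsz=640):
    """Build the model from the custom YAML and start training."""
    # Imported lazily so this file can be loaded without ultralytics installed.
    from ultralytics import YOLO

    model = YOLO(cfg)  # builds the network from the YAML definition
    return model.train(data=data, epochs=epochs, imgsz=imgsz)


# Example call (the dataset file name is a placeholder):
# train_dehaze_detector(data="my_hazy_dataset.yaml", epochs=50)
```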

YAML file


# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, MB_TaylorFormer, []]  # 0 dehazing front end (keeps the input resolution)
  - [-1, 1, Conv, [64, 3, 2]]  # 1-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 2-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 4-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 6-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 8-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 10
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 7], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 5], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)
 
  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)
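
Because MB_TaylorFormer is inserted as layer 0, every backbone layer sits one index later than it would without it, and the absolute indices used in the head (`7`, `5`, `13`, `10`, and `[16, 19, 22]`) have to account for that shift. The sketch below re-derives the layer numbers from the listing above and reports which module each absolute reference lands on. The layer lists are transcribed from the YAML, and the helper name is made up for the example.

```python
# (from, module) pairs transcribed from the YAML, in order.
backbone = [
    (-1, "MB_TaylorFormer"), (-1, "Conv"), (-1, "Conv"), (-1, "C2f"),
    (-1, "Conv"), (-1, "C2f"), (-1, "Conv"), (-1, "C2f"),
    (-1, "Conv"), (-1, "C2f"), (-1, "SPPF"),
]
head = [
    (-1, "nn.Upsample"), ([-1, 7], "Concat"), (-1, "C2f"),
    (-1, "nn.Upsample"), ([-1, 5], "Concat"), (-1, "C2f"),
    (-1, "Conv"), ([-1, 13], "Concat"), (-1, "C2f"),
    (-1, "Conv"), ([-1, 10], "Concat"), (-1, "C2f"),
    ([16, 19, 22], "Detect"),
]
layers = backbone + head


def absolute_refs(layers):
    """Return {layer_index: [(ref_index, ref_module), ...]} for non-negative refs."""
    refs = {}
    for i, (src, _) in enumerate(layers):
        sources = src if isinstance(src, list) else [src]
        picked = [(s, layers[s][1]) for s in sources if s >= 0]
        if picked:
            refs[i] = picked
    return refs


refs = absolute_refs(layers)
for index, targets in refs.items():
    print(index, layers[index][1], "->", targets)
```

The same check is worth repeating whenever a layer is added to or removed from the backbone, since relative references (`-1`) follow automatically but absolute ones do not.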

