YOLOv11改进策略【注意力机制篇】| 添加SE、CBAM、ECA、CA、Swin Transformer等注意力和多头注意力机制

news2024/9/30 15:11:29

前言

这篇文章带来一个经典注意力模块的汇总,虽然有些模块已经发布很久了,但后续的注意力模块也都是在此基础之上进行改进的,对于初学者来说还是有必要去学习了解一下,以加深对模块,模型的理解。


文章目录

  • 前言
  • 一、为什么要引入注意力机制?
  • 二、SE
    • 2.1 SE的原理
    • 2.2 SE的实现代码
  • 三、CBAM
    • 3.1 CBAM的原理
    • 3.2 CBAM的实现代码
  • 四、ECA
    • 4.1 ECA的原理
    • 4.2 ECA的实现代码
  • 五、CA
    • 5.1 CA的原理
    • 5.2 CA的实现代码
  • 六、Swin Transformer
    • 6.1 Swin Transformer的原理
    • 6.2 Swin Transformer的实现代码
  • 七、添加步骤
    • 1. 修改ultralytics/nn/modules/block.py
    • 2. 修改ultralytics/nn/modules/__init__.py
    • 3. 修改ultralytics/nn/modules/tasks.py
  • 八、yaml模型文件
  • 九、成功运行结果


一、为什么要引入注意力机制?

来源:注意力机制的设计灵感来源于人类视觉系统。当我们在观察外界事物时,会自动将注意力集中在重要或感兴趣的区域,而忽略无关信息。计算机视觉中的注意力机制就是在试图模拟这一过程,以提高模型的感知和理解能力。

问题:随着图像数据量的增加,模型需要处理的信息量也随之增大。传统的卷积神经网络在处理大量数据时可能会遇到信息过载的问题,导致性能下降。注意力机制通过有选择地关注重要信息,帮助模型在海量数据中筛选出关键内容,从而提高检测精度。

好处

  • 注意力机制能够赋予输入数据的不同部分以不同的权重,使模型更加关注重要的特征信息。
  • 通过生成热力图,显示模型在做出决策时关注的具体区域,有助于更好地理解模型的决策过程,增强模型的可解释性。
  • 注意力机制使模型在处理不同数据集和任务时能够更灵活地调整其关注点。有助于提升模型的泛化能力,使其在面对新数据集或新任务时仍能保持较高的性能水平。

除了能够提升性能外,其最主要的还是其即插即用的特性,无论模块放在什么地方,都可以运行查看训练效果,更方便炼丹成功~

二、SE

2.1 SE的原理

通道注意力模块关注于网络中每个通道的重要性,通过为每个通道分配不同的权重,使得网络能够更加关注那些对任务更为关键的通道特征,从而提高模型的性能。其中主要涉及SqueezeExcitation两个操作。

  • Squeeze操作:通过全局平均池化将每个通道的特征图压缩为一个实数。
  • Excitation操作:利用两个全连接层(先降维后升维)和一个ReLU激活函数来学习通道间的依赖关系,并通过sigmoid函数生成权重向量。
  • Scale操作:将学习到的通道权重与原始特征图进行逐通道相乘,实现特征的重标定。
    在这里插入图片描述

论文:https://arxiv.org/abs/1709.01507
源码:https://github.com/hujie-frank/SENet

2.2 SE的实现代码

import torch.nn as nn


# SE
class SE(nn.Module):
    def __init__(self, c1, ratio=16):
        super(SE, self).__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)

注意❗:在7.2、7.3小节中的__init__.pytasks.py文件中需要声明的模块名称为:SE


三、CBAM

3.1 CBAM的原理

CBAM注意力模块通道注意力模块空间注意力模块两部分组成。它通过顺序地应用通道注意力和空间注意力,使得网络能够自适应地关注到输入特征图中最重要的通道和空间位置,从而提高模型的表征能力

在这里插入图片描述

  • 通道注意力模块(CAM)
    此部分的操作步骤与SE通道注意力模块的步骤一致。

  • 空间注意力模块(SAM)

    • 特征提取:在通道注意力模块处理后的特征图上,分别进行基于通道维度的最大池化平均池化操作,以生成两个新的特征图。
    • 特征融合:将两个池化后的特征图在通道维度上进行拼接(concatenate),然后通过一个卷积层进行特征融合,生成空间注意力权重图。
    • 特征增强:将空间注意力权重图与原始特征图进行逐元素相乘,实现特征的增强,使得模型能够关注到更重要的空间位置信息。

在这里插入图片描述

论文:https://arxiv.org/abs/1807.06521
源码:https://github.com/Jongchan/attention-module

3.2 CBAM的实现代码

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        # 2*h*w
        x = self.conv(x)
        # 1*h*w
        return self.sigmoid(x)


class CBAM(nn.Module):
    def __init__(self, c1, c2, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

注意❗:在7.2、7.3小节中的__init__.pytasks.py文件中需要声明的模块名称为:CBAM


四、ECA

4.1 ECA的原理

ECA注意力模块的核心思想是在不增加过多计算成本和参数的情况下,通过引入一种有效的通道注意力机制,来增强网络对关键特征的关注能力。它避免了通道注意力机制中可能存在的降维操作带来的性能损失,通过一种自适应的跨通道交互策略来实现通道权重的生成。步骤如下:

  • 特征压缩:这里和SE注意力中的Squeeze操作一致,省略啦。
  • 特征学习ECA使用一维卷积来代替SE注意力中的全连接层,来学习通道间的依赖关系。这里的一维卷积核大小是自适应的,与通道维度成正比,以确保不同通道数的特征图都能有效地进行跨通道交互。通过一维卷积,ECA能够直接捕获局部跨通道交互信息,而无需进行复杂的降维和升维操作。
  • 特征重标定:这里也和SE注意力中的Scale操作一致。

在这里插入图片描述

论文:https://arxiv.org/abs/1910.03151
源码:https://github.com/BangguWu/ECANet

4.2 ECA的实现代码

import torch
import torch.nn as nn


class ECA(nn.Module):
    def __init__(self, c1, c2, k_size=3):
        super(ECA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(
            1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        y = self.sigmoid(y)
        return x * y.expand_as(x)

注意❗:在7.2、7.3小节中的__init__.pytasks.py文件中需要声明的模块名称为:ECA


五、CA

5.1 CA的原理

CA注意力模块的核心思想是将位置信息嵌入到通道注意力中,以更精确地捕捉到图像中的空间分布特征。与普通的通道注意力机制不同的是,CA不仅关注通道间的依赖关系,还通过引入坐标信息来增强模型对空间细节的敏感度。步骤如下:

  • 特征分解:输入特征图通常具有C(通道数)、H(高度)、W(宽度)三个维度。CA首先通过两个并行的全局平均池化操作,分别沿垂直(高度)和水平(宽度)方向聚合输入特征,生成两个包含方向特定信息的特征图。这两个特征图分别捕捉了高度和宽度方向上的空间信息。
  • 特征编码:将两个池化后的特征图在通道维度上拼接,并通过一个1x1的二维卷积层来融合和转换特征。然后,对卷积后的特征图进行批量归一化非线性激活
  • 注意力图生成:将批量归一化和激活后的特征图分裂为两个特征图,分别对应于高度和宽度方向。接着,通过另外两个1x1的二维卷积层分别处理这两个特征图,并应用Sigmoid激活函数,生成两个注意力图。这两个注意力图分别沿宽度和高度方向对输入特征图进行重标定。
  • 特征重加权:将通过Sigmoid激活的注意力图与原始的输入特征图相乘,以重新加权原始特征。这样,重要的特征会被放大,而不重要的特征则会减弱,从而增强模型对关键信息的关注能力。

论文:https://arxiv.org/abs/2103.02907v1
源码:https://github.com/houqb/CoordAttention

5.2 CA的实现代码

import torch
import torch.nn as nn


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        # c*1*W
        x_h = self.pool_h(x)
        # c*H*1
        # C*1*h
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        # C*1*(h+w)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out
        

注意❗:在7.2、7.3小节中的__init__.pytasks.py文件中需要声明的模块名称为:CoordAtt


六、Swin Transformer

6.1 Swin Transformer的原理

Swin Transformer通过分层设计结合多个等级的窗口划分来降低计算复杂度,并提出位移窗口使相邻的窗口之间进行交互,从而达到全局建模的能力。在Swin Transformer模型中最重要的是模块是窗口多头自注意力(W-MSA)移动窗口多头自注意力(SW-MSA),用于自注意力的计算。

  • 窗口多头自注意力(W-MSA)
    • 划分窗口:将特征图划分为多个固定大小的窗口。
    • 自注意力计算:在每个窗口内独立计算多头自注意力,此时计算复杂度与窗口内的小块数量成线性关系,从而降低了整体计算复杂度。
    • 输出:得到每个窗口内的自注意力特征图。
  • 移动窗口多头自注意力(SW-MSA)
    • 窗口移动:在W-MSA之后,通过移动窗口的方式改变窗口的划分,使得相邻的窗口之间能够产生交互。
    • 自注意力计算:在新的窗口划分下再次计算多头自注意力。
    • 输出:得到移动窗口后的自注意力特征图。

在这里插入图片描述

论文:https://arxiv.org/abs/2103.14030
源码:https://github.com/microsoft/Swin-Transformer

6.2 Swin Transformer的实现代码

class WindowAttention(nn.Module):

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):

        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # print(attn.dtype, v.dtype)
        try:
            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        except:
            #print(attn.dtype, v.dtype)
            x = (attn.half() @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class Mlp(nn.Module):

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

def window_partition(x, window_size):

    B, H, W, C = x.shape
    assert H % window_size == 0, 'feature map h and w can not divide by window size'
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows, window_size, H, W):
    
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class SwinTransformerLayer(nn.Module):

    def __init__(self, dim, num_heads, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.SiLU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # if min(self.input_resolution) <= self.window_size:
        #     # if window size is larger than input resolution, we don't partition windows
        #     self.shift_size = 0
        #     self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def create_mask(self, H, W):
        # calculate attention mask for SW-MSA
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        # reshape x[b c h w] to x[b l c]
        _, _, H_, W_ = x.shape

        Padding = False
        if min(H_, W_) < self.window_size or H_ % self.window_size!=0 or W_ % self.window_size!=0:
            Padding = True
            # print(f'img_size {min(H_, W_)} is less than (or not divided by) window_size {self.window_size}, Padding.')
            pad_r = (self.window_size - W_ % self.window_size) % self.window_size
            pad_b = (self.window_size - H_ % self.window_size) % self.window_size
            x = F.pad(x, (0, pad_r, 0, pad_b))

        # print('2', x.shape)
        B, C, H, W = x.shape
        L = H * W
        x = x.permute(0, 2, 3, 1).contiguous().view(B, L, C)  # b, L, c

        # create mask from init to forward
        if self.shift_size > 0:
            attn_mask = self.create_mask(H, W).to(x.device)
        else:
            attn_mask = None

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        x = x.permute(0, 2, 1).contiguous().view(-1, C, H, W)  # b c h w

        if Padding:
            x = x[:, :, :H_, :W_]  # reverse padding

        return x


class SwinTransformerBlock(nn.Module):
    def __init__(self, c1, c2, num_heads, num_layers, window_size=8):
        super().__init__()
        self.conv = None
        if c1 != c2:
            self.conv = Conv(c1, c2)

        # remove input_resolution
        self.blocks = nn.Sequential(*[SwinTransformerLayer(dim=c2, num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2) for i in range(num_layers)])

    def forward(self, x):
        if self.conv is not None:
            x = self.conv(x)
        x = self.blocks(x)
        return x


class STCSPA(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPA, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.m(self.cv1(x))
        y2 = self.cv2(x)
        return self.cv3(torch.cat((y1, y2), dim=1))


class STCSPB(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPB, self).__init__()
        c_ = int(c2)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        x1 = self.cv1(x)
        y1 = self.m(x1)
        y2 = self.cv2(x1)
        return self.cv3(torch.cat((y1, y2), dim=1))


class STCSPC(nn.Module):
    # CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(STCSPC, self).__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(c_, c_, 1, 1)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        num_heads = c_ // 32
        self.m = SwinTransformerBlock(c_, c_, num_heads, n)
        #self.m = nn.Sequential(*[Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)])

    def forward(self, x):
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(torch.cat((y1, y2), dim=1))

注意❗:在7.2、7.3小节中的__init__.pytasks.py文件中需要声明的模块名称为:STCSPA, STCSPB, STCSPC,在模型中使用哪个选哪个就行。


七、添加步骤

此处在模型配置中以SE通道注意力为例,列举的其他注意力模块添加步骤与此完全一致

1. 修改ultralytics/nn/modules/block.py

此处需要修改的文件是ultralytics/nn/modules/block.py

common.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

SE添加后如下:

在这里插入图片描述

2. 修改ultralytics/nn/modules/init.py

此处需要修改的文件是ultralytics/nn/modules/__init__.py

__init__.py文件中定义了所有模块的初始化,我们只需要将block.py中的新的模块命添加到对应的函数即可。

SEblock.py中实现,所有要添加在:

from .block import (
    C1,
    C2,
    ...
    SE
)

在这里插入图片描述

3. 修改ultralytics/nn/modules/tasks.py

tasks.py文件中,需要在两处位置添加各模块类名称。

首先:在函数声明中引入SE

在这里插入图片描述

在这里插入图片描述

其次:在parse_model函数中注册SE模块

在这里插入图片描述

在这里插入图片描述


八、yaml模型文件

在代码配置完成后,配置模型的YAML文件。

此处以models/detect/yolov10m.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov10m-SE.yaml

yolov10m.yaml中的内容复制到yolov10m-SE.yaml文件下,修改nc数量等于自己数据中目标的数量。
在骨干网络的最后一层添加SE模块只需要填入一个参数,通道数,和前一层通道数一致还需要注意的是,由于PAN+FPN的颈部模型结构存在,层之间的匹配也要记得修改,维度要匹配上

📌 放在此处的目的是让网络能够学习到更深层的语义信息,因为此时特征图尺寸小,包含全局信息。若是希望网络能够更加关注局部信息,可尝试将注意力模块添加到网络的浅层。

📌 当然由于其即插即用的特性,加在哪里都是可以的,但是想要真的有效,还需要根据模型结构,数据集特性等多方面因素,多做实验进行验证。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# yolo task=detect mode=train model=yolov11m.yaml data=data.yaml device=0 epochs=300 batch=16 imgsz=640 workers=10

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SE, [1024]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)


九、成功运行结果

打印网络模型可以看到SE模块已经加入到模型中,并可以进行训练了。

其他模块如:CBAM、ECA、CA、Swin Transformer这些模块和SE的添加步骤完全一致。

并且对于这里未提到的注意力模块的添加步骤也是一样的,只要加入模块代码,并将其添加到模型中即可。

                   from  n    params  module                                       arguments                     
  0                  -1  1      1856  ultralytics.nn.modules.conv.Conv             [3, 64, 3, 2]                 
  1                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]               
  2                  -1  1    111872  ultralytics.nn.modules.block.C3k2            [128, 256, 1, True, 0.25]     
  3                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
  4                  -1  1    444928  ultralytics.nn.modules.block.C3k2            [256, 512, 1, True, 0.25]     
  5                  -1  1   2360320  ultralytics.nn.modules.conv.Conv             [512, 512, 3, 2]              
  6                  -1  1   1380352  ultralytics.nn.modules.block.C3k2            [512, 512, 1, True]           
  7                  -1  1   2360320  ultralytics.nn.modules.conv.Conv             [512, 512, 3, 2]              
  8                  -1  1   1380352  ultralytics.nn.modules.block.C3k2            [512, 512, 1, True]           
  9                  -1  1      1024  ultralytics.nn.modules.block.SE              [512, 512]                    
 10                  -1  1    656896  ultralytics.nn.modules.block.SPPF            [512, 512, 5]                 
 11                  -1  1    990976  ultralytics.nn.modules.block.C2PSA           [512, 512, 1]                 
 12                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 13             [-1, 6]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 14                  -1  1   1642496  ultralytics.nn.modules.block.C3k2            [1024, 512, 1, True]          
 15                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 16             [-1, 4]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 17                  -1  1    542720  ultralytics.nn.modules.block.C3k2            [1024, 256, 1, True]          
 18                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 19            [-1, 14]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 20                  -1  1   1511424  ultralytics.nn.modules.block.C3k2            [768, 512, 1, True]           
 21                  -1  1   2360320  ultralytics.nn.modules.conv.Conv             [512, 512, 3, 2]              
 22            [-1, 11]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 23                  -1  1   1642496  ultralytics.nn.modules.block.C3k2            [1024, 512, 1, True]          
 24        [17, 20, 23]  1   1411795  ultralytics.nn.modules.head.Detect           [1, [256, 512, 512]]          
YOLOv11m-SE summary: 415 layers, 20,054,803 parameters, 20,054,787 gradients, 68.2 GFLOPs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2179931.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

uniapp生物识别示例(人脸识别、指纹识别)

准备工作&#xff1a; mainfest.json设置勾选&#xff1a; 勾选完成后打 App自定义调试基座测试包 示例代码&#xff1a; <template><view class"content"><button v-if"supportSoterAuthenticationArray.includes(facial)" click"…

QT使用qss控制样式实现动态换肤

文章目录 设计QSS样式表动态加载QSS文件主函数调用QT提供了一种非常灵活的方式来使用QSS(Qt Style Sheet,类似于 CSS 的样式表),实现界面的动态换肤功能。QSS可以改变Qt应用程序中几乎所有可视组件的外观,包括颜色、字体、边框等。下面介绍一下如何通过QSS实现动态换肤。 设…

大模型时代的企业AI发展趋势浅析

在当前技术飞速进步的时代背景下&#xff0c;生成式人工智能与大型模型正逐渐成为推动产业变革的关键力量。随着人工智能技术的持续成熟与普及&#xff0c;其应用范围已从个人领域拓展至企业层面&#xff0c;广泛渗透至各个行业。那么&#xff0c;这些新兴技术究竟将为产业界带…

手把手教你使用YOLOv11训练自己数据集(含环境搭建 、数据集查找、模型训练)

一、前言 本文内含YOLOv11网络结构图 训练教程 推理教程 数据集获取等有关YOLOv11的内容&#xff01; 官方代码地址&#xff1a;https://github.com/ultralytics/ultralytics/tree/main/ultralytics/cfg/models/11 二、整体网络结构图 三、环境搭建 项目环境如下&#xf…

天融信运维安全审计系统 synRequest 远程命令执行漏洞复现

0x01 产品描述&#xff1a; 天融信运维安全审计系统TopSAG是基于自主知识产权NGTOS安全操作系统平台和多年网络安全防护经验积累研发而成&#xff0c;系统以4A管理理念为基础、安全代理为核心&#xff0c;在运维管理领域持续创新&#xff0c;为客户提供事前预防、事中监控、事后…

一文了解构建工具——Maven与Gradle的区别

目录 一、Maven和Gradle是什么&#xff1f; 构建工具介绍 Maven介绍 Gradle介绍 二、使用时的区别&#xff1a; 1、新建项目 Maven&#xff1a; Gradle&#xff1a; 2、配置项目 Maven&#xff1a; Gradle&#xff1a; 3、构建项目——生成项目的jar包 Gradle&…

Linux之实战命令20:split应用实例(五十四)

简介&#xff1a; CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a; 多媒体系统工程师系列【…

【C++】692.前K个高频单词

692. 前K个高频单词 - 力扣&#xff08;LeetCode&#xff09; 思路分析&#xff1a; 使用map统计单词的次数。map是按单词从小到大排序的。对单词再按照次数从大到小排序。有两种方法&#xff1a; 将pair<string&#xff0c;int>键值对放到vector中&#xff0c;用sort排序…

【Linux系统编程】第二十五弹---Shell编程入门:打造一个简易版Shell

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】 目录 1、简易的shell 1.1、输出一个命令行 1.2、获取用户命令字符串 1.3、命令行字符串分割 1.4、检查命令是否是内建命令 1.5、…

LeetCode24. 两两交换链表中的节点(2024秋季每日一题 32)

给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点交换&#xff09;。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4] 输出&#xff1a;[2,1,…

Llama微调以及Ollama部署

1 Llama微调 在基础模型的基础上&#xff0c;通过一些特定的数据集&#xff0c;将具有特定功能加在原有的模型上。 1.1 效果对比 特定数据集 未使用微调的基础模型的回答 使用微调后的回答 1.2 基础模型 基础大模型我选择Mistral-7B-v0.3-Chinese-Chat-uncensored&#x…

Label-Studio ML利用yolov8模型实现自动标注

引言 Label Studio ML 后端是一个 SDK&#xff0c;用于包装您的机器学习代码并将其转换为 Web 服务器。Web 服务器可以连接到正在运行的 Label Studio 实例&#xff0c;以自动执行标记任务。我们提供了一个示例模型库&#xff0c;您可以在自己的工作流程中使用这些模型&#x…

[Cocoa]_[初级]_[绘制文本如何设置断行方式]

场景 在开发Cocoa程序时&#xff0c;表格NSTableView是经常使用的控件。其基于View Base的视图单元格模式就是使用NSCell或其子类来控制每个单元格的呈现。当一个单元格里的文字过多时&#xff0c;需要截断超出宽度的文字&#xff0c;怎么实现&#xff1f; 说明 Cocoa下的文本…

演讲干货整理:泛能网能碳产业智能平台基于 TDengine 的升级之路

在 7 月 26 日的 TDengine 用户大会上&#xff0c;新奥数能 / 物联和数据技术召集人袁文科进行了题为《基于新一代时序数据库 TDengine 助力泛能网能碳产业智能平台底座升级》的主题演讲。他从泛能网能碳产业智能平台的业务及架构痛点出发&#xff0c;详细分享了在数据库选型、…

【多线程奇妙屋】能把进程和线程讲的这么透彻的,没有20年功夫还真不行【0基础也能看懂】

本篇会加入个人的所谓鱼式疯言 ❤️❤️❤️鱼式疯言:❤️❤️❤️此疯言非彼疯言 而是理解过并总结出来通俗易懂的大白话, 小编会尽可能的在每个概念后插入鱼式疯言,帮助大家理解的. &#x1f92d;&#x1f92d;&#x1f92d;可能说的不是那么严谨.但小编初心是能让更多人…

OpenGL ES 顶点缓冲区和布局(3)

OpenGL ES 顶点缓冲区和布局(3) 简述 顶点缓冲区的本质就是一段GPU上的显存&#xff0c;我们通过绑定顶点缓冲区的方式来将数据从CPU传到GPU。 我们之前在绘制三角形的例子中&#xff0c;我们往顶点缓冲区只传入了坐标&#xff0c;但是其实顶点是可以包含很多数据的&#xff…

指定PDF或图片多个识别区域,识别区域文字,并导出到Excel文件中

常见场景 用户有大量图片/PDF文件&#xff0c;期望能将图片/PDF中的多个区域中的文字批量识别出来&#xff0c;并导入到Excel文件中。期望工具可以批量处理、离线识别&#xff08;保证数据安全性&#xff09;。手工操作麻烦。具体场景&#xff1a;用户有工程现场照片&#xff…

xgboost cross validation

在R中使用xgboost 假设X为训练数据&#xff0c;y为label&#xff0c;为0或者1.用xgboost建立分类模型代码如下 调用caret包中的createFolds方法&#xff0c;进行10倍交叉验证 最后画出AUC曲线 library(xgboost) library(caret) library(caTools) library(pROC)set.seed(123) …

【北京迅为】《STM32MP157开发板嵌入式开发指南》- 第十一章 Linux 帮助手册讲解

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器&#xff0c;既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构&#xff0c;主频650M、1G内存、8G存储&#xff0c;核心板采用工业级板对板连接器&#xff0c;高可靠&#xff0c;牢固耐…

3DGS中Densification梯度累计策略的改进——绝对梯度策略(Gaussian Opacity Fields)

在学习 StreetGS 代码中发现了其中的 Densification 策略与原 3DGS 不太一样&#xff0c;其是使用的 Gaussian Opacity Fields 中的一个的策略 我们先来回忆一下 3DGS 中一个比较重要 contribution&#xff1a;自适应密度控制 1 自适应密度控制 其具体步骤如下&#xff1a; …