1. Introduction to MLAttention
(1) Multi-scale convolution: MLAttention strengthens feature representation at different scales through multi-scale convolution. It applies depthwise separable convolutions with several kernel sizes (e.g. 5x5, 1x7, 7x1, 1x11, 11x1, 1x21, 21x1) to capture features over different receptive fields: smaller kernels excel at fine detail, while larger kernels cover broader context. This multi-scale processing lets the network model local detail and global structure at the same time (a quick shape check of these strip convolutions is sketched after this list).
(2) Multi-level feature fusion: the successive groups of convolutions in MLAttention extract features at different levels, which are fused by step-wise accumulation. Summing the features from the different branches not only improves the capture of complex patterns but also strengthens the perception of targets at different scales, a clear advantage for feature extraction in complex scenes, especially those containing multi-scale objects.
(3) Linear attention mechanism: MLAttention incorporates an attention branch that produces query (Q), key (K) and value (V) features via linear projections, letting local and global image information interact. By adaptively learning the correlations between features, attention highlights the features of key regions while suppressing redundant or unimportant information. In this implementation, the attention weights are computed with a Softmax over the query-key products, and the values are combined by weighted summation to amplify the important features, which further improves the accuracy and robustness of feature extraction.
(4) Effective feature enhancement: MLAttention fuses the convolutional features with the attention features and adds the result to the input at the output. This design preserves the basic information of the original features while strengthening them through attention, highlighting the key features on top of the retained base information and further improving the model's representational power.
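As a quick check on the multi-scale design in point (1), the minimal standalone sketch below (kernel sizes taken from the module; everything else is illustrative) confirms that a stride-1 strip convolution with kernel size k and padding (k - 1) / 2 preserves the spatial dimensions, which is what allows all branches to be summed element-wise:

import torch
import torch.nn as nn

dim = 64
x = torch.randn(1, dim, 32, 32)

# A stride-1 convolution keeps H and W when padding = (k - 1) // 2,
# e.g. the (1, 21) kernel needs padding (0, 10).
for k in (7, 11, 21):
    conv_h = nn.Conv2d(dim, dim, (1, k), padding=(0, (k - 1) // 2), groups=dim)
    conv_v = nn.Conv2d(dim, dim, (k, 1), padding=((k - 1) // 2, 0), groups=dim)
    assert conv_v(conv_h(x)).shape == x.shape
print("all strip convolutions are size-preserving")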
2. Core Code
import torch
import torch.nn as nn
from torch.nn import functional as F


class MLAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # Multi-scale branch: a 5x5 depthwise base conv followed by three pairs of
        # depthwise strip convolutions (7, 11 and 21, each decomposed into 1xk and kx1)
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv3 = nn.Conv2d(dim, dim, 1)  # 1x1 conv to mix channels after fusion
        # Linear Attention: per-pixel linear projections for Q, K, V
        self.phi_q = nn.Linear(dim, dim)
        self.phi_k = nn.Linear(dim, dim)
        self.phi_v = nn.Linear(dim, dim)

    def forward(self, x):
        u = x.clone()  # keep the input for the residual connection

        # Multi-scale convolution branch: sum the base and the three strip-conv paths
        attn = self.conv0(x)
        attn_0 = self.conv0_1(attn)
        attn_0 = self.conv0_2(attn_0)
        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)
        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)
        attn = attn + attn_0 + attn_1 + attn_2
        attn = self.conv3(attn)

        # Linear Attention branch: Softmax attention over the N = H * W positions
        B, C, H, W = x.shape
        x_flat = x.view(B, C, -1).permute(0, 2, 1)  # (B, N, C)
        Q = self.phi_q(x_flat)  # (B, N, C)
        K = self.phi_k(x_flat)  # (B, N, C)
        V = self.phi_v(x_flat)  # (B, N, C)
        K_T = K.permute(0, 2, 1)  # (B, C, N)
        attn_weights = torch.matmul(Q, K_T)  # (B, N, N)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = torch.matmul(attn_weights, V)  # (B, N, C)
        attn_output = attn_output.permute(0, 2, 1).reshape(B, C, H, W)  # back to (B, C, H, W)

        # Modulate the attention output with the multi-scale map, then add the residual
        return attn_output * attn + u
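A quick smoke test (shapes chosen arbitrarily) shows that the module is shape-preserving, which is what allows it to be inserted between existing layers in the yaml further below:

if __name__ == '__main__':
    block = MLAttention(dim=64)
    x = torch.randn(2, 64, 40, 40)  # (B, C, H, W)
    y = block(x)
    print(y.shape)  # torch.Size([2, 64, 40, 40]), same shape as the input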
3. Adding MLAttention to YOLOv11
3.1 Create an Extramodule directory under ultralytics/nn
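After this step and the next one, the layout under ultralytics/nn should look roughly like this (MLAttention.py is created in step 3.2):

ultralytics/
└── nn/
    ├── Extramodule/
    │   ├── __init__.py
    │   └── MLAttention.py
    └── tasks.py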
3.2 Create MLAttention inside Extramodule
Add the MLAttention code given above to an MLAttention.py file.
After adding the code, reference it in the ultralytics/nn/Extramodule/__init__.py file.
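A minimal sketch of that __init__.py, assuming the file is named MLAttention.py as above:

# ultralytics/nn/Extramodule/__init__.py
from .MLAttention import MLAttention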
3.3 Reference it in tasks.py
Import Extramodule in the ultralytics/nn/tasks.py file.
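For example, next to the existing imports at the top of tasks.py (a wildcard import is one option; importing MLAttention by name works just as well):

# at the top of ultralytics/nn/tasks.py
from ultralytics.nn.Extramodule import *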
Then find parse_model in tasks.py (Ctrl+F will locate it quickly) and add the following branch:
        elif m in {MLAttention}:
            c2 = ch[f]          # output channels = input channels
            args = [c2, *args]  # prepend the channel count so the layer is built as MLAttention(c2)
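The reasoning behind these two lines: MLAttention preserves the channel count, so the branch's output channels c2 equal the input channels ch[f], and prepending c2 to args makes parse_model instantiate the layer as MLAttention(c2). A standalone toy check of that argument handling (the values are hypothetical):

ch = [64, 128, 256]  # hypothetical per-layer output channels tracked by parse_model
f, args = -1, []     # 'from' index and yaml args of an MLAttention line ([-1, 1, MLAttention, []])
c2 = ch[f]           # output channels = input channels of the previous layer
args = [c2, *args]   # -> [256], so the layer is built as MLAttention(256)
print(c2, args)      # 256 [256]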
4. Create a yolo11MLAttention.yaml file
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10
# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, MLAttention, []] # 14

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 17 (P3/8-small)
  - [-1, 1, MLAttention, []] # 18

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 21 (P4/16-medium)
  - [-1, 1, MLAttention, []] # 22

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 25 (P5/32-large)
  - [-1, 1, MLAttention, []] # 26

  - [[18, 22, 26], 1, Detect, [nc]] # Detect(P3, P4, P5) from the MLAttention outputs
Adjust nc to match the number of classes in your own dataset.
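Before training, it is worth confirming that the yaml parses and the MLAttention layers land where expected (the path is assumed; use wherever you saved the file):

from ultralytics import YOLO

model = YOLO('yolo11MLAttention.yaml')  # adjust the path to your file
model.info()  # prints the layer table; four MLAttention entries should appear in the head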
5. Model Training
import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r'D:\yolo\yolov11\ultralytics-main\datasets\yolo11MLAttention.yaml')
    model.train(data=r'D:\yolo\yolov11\ultralytics-main\datasets\data.yaml',
                cache=False,
                imgsz=640,
                epochs=100,
                single_cls=False,  # single-class detection or not
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )
If everything is in place, the model structure prints and the run starts successfully.
6. Summary
That wraps up the main content of this post. I would like to recommend my YOLOv11 effective-improvements column: it is newly opened, and going forward I will reproduce modules from recent top-conference papers and also backfill some older improvement mechanisms. If this post helped you, consider subscribing to the column and following for more updates.
YOLOv11 effective-improvements column