全新的聚焦线性注意力模块(Focused Linear Attention)是一种旨在提高计算效率和准确性的注意力机制。传统的自注意力机制在处理长序列数据时通常计算复杂度较高,限制了其在大规模数据上的应用。聚焦线性注意力模块则通过优化注意力计算的方式,显著降低了计算复杂度。
- 自然语言处理:在长文本或大规模语料库的处理上,聚焦线性注意力模块能够提供更高的效率和更低的延迟。
- 计算机视觉:在处理高分辨率图像或视频数据时,能够加速计算过程,提升模型的实时性。
关于Focused Linear Attention的详细介绍可以看论文:https://arxiv.org/pdf/2308.00442
本文将讲解如何将Focused Linear Attention融合进yolov8
2, 将Focused Linear Attention融合进yolov8
2.1 步骤一
找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个FLA.py文件,文件名字可以根据你自己的习惯起,然后将Focused Linear Attention的核心代码复制进去
import torch
import torch.nn as nn
from einops import rearrange
class FocusedLinearAttention(nn.Module):
    r"""Multi-head focused linear attention over a 2-D feature map.

    The module takes a ``(B, C, H, W)`` feature map, flattens it to ``H * W``
    tokens, and computes attention with a ReLU feature map sharpened by an
    element-wise power (``focusing_factor``). A learnable positional encoding
    is added to the keys, and a depthwise convolution over the per-head values
    is added to the attention output before the final projection. The result
    is returned in the input layout ``(B, C, H, W)``.

    Args:
        dim (int): Number of input channels. Must be divisible by ``num_heads``.
        window_size (Sequence[int]): ``(height, width)`` used only to size the
            positional encoding, which holds ``height * width`` token slots.
            Inputs may have fewer tokens than that, but not more.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Accepted for interface compatibility; not used.
        attn_drop (float, optional): Dropout ratio stored in ``self.attn_drop``;
            ``forward`` does not apply it. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        focusing_factor (int | float, optional): Exponent applied to the
            non-negative queries and keys. Default: 3
        kernel_size (int, optional): Kernel size of the depthwise convolution. Default: 5

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, window_size=[20, 20], num_heads=8, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 focusing_factor=3, kernel_size=5):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.focusing_factor = focusing_factor
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.window_size = window_size
        # One learnable vector per token slot; forward() uses the first N slots.
        self.positional_encoding = nn.Parameter(torch.zeros(size=(1, window_size[0] * window_size[1], dim)))
        self.softmax = nn.Softmax(dim=-1)
        # Depthwise convolution (groups == channels) applied to each head's values.
        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        # Learnable per-channel scale; passed through softplus in forward().
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))

    def _split_heads(self, t):
        """Reshape ``(B, N, heads * c)`` into ``(B * heads, N, c)``."""
        b, n, ch = t.shape
        t = t.reshape(b, n, self.num_heads, ch // self.num_heads).permute(0, 2, 1, 3)
        return t.reshape(b * self.num_heads, n, ch // self.num_heads)

    def forward(self, x, mask=None):
        """Apply focused linear attention to a feature map.

        Args:
            x (torch.Tensor): Input features with shape ``(B, C, H, W)``;
                ``C`` must equal ``dim``. ``H`` and ``W`` may differ.
            mask: Unused; kept so the call signature stays unchanged.

        Returns:
            torch.Tensor: Output features with shape ``(B, C, H, W)``.

        Raises:
            ValueError: If ``H * W`` exceeds the number of positional-encoding
                slots (``window_size[0] * window_size[1]``).
        """
        # Remember the true spatial size instead of assuming a square map.
        batch, _, height, width = x.shape
        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        tokens = x.flatten(2).transpose(1, 2)
        num_tokens, channels = tokens.shape[1], tokens.shape[2]
        max_tokens = self.positional_encoding.shape[1]
        if num_tokens > max_tokens:
            raise ValueError(
                f"feature map {height}x{width} has {num_tokens} tokens but the positional "
                f"encoding only covers {max_tokens}; increase window_size"
            )

        qkv = self.qkv(tokens).reshape(batch, num_tokens, 3, channels).permute(2, 0, 1, 3)
        q, k, v = qkv.unbind(0)
        k = k + self.positional_encoding[:, :num_tokens, :]

        # Non-negative feature map; the epsilon keeps norms and divisions away from zero.
        q = torch.relu(q) + 1e-6
        k = torch.relu(k) + 1e-6
        scale = nn.functional.softplus(self.scale)
        q = q / scale
        k = k / scale

        # Sharpen q and k with an element-wise power, then restore their original norms.
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
        focusing_factor = self.focusing_factor
        if float(focusing_factor) <= 6:
            q = q ** focusing_factor
            k = k ** focusing_factor
        else:
            # Divide by the per-token maximum before applying a large exponent.
            q = (q / q.max(dim=-1, keepdim=True)[0]) ** focusing_factor
            k = (k / k.max(dim=-1, keepdim=True)[0]) ** focusing_factor
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm

        q, k, v = (self._split_heads(t) for t in (q, k, v))
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]

        # Row-wise normaliser: 1 / (q_i . sum_j k_j).
        z = 1 / (torch.einsum("bic,bc->bi", q, k.sum(dim=1)) + 1e-6)
        # Both branches compute the same product; pick the cheaper multiplication order.
        if i * j * (c + d) > c * d * (i + j):
            kv = torch.einsum("bjc,bjd->bcd", k, v)
            out = torch.einsum("bic,bcd,bi->bid", q, kv, z)
        else:
            qk = torch.einsum("bic,bjc->bij", q, k)
            out = torch.einsum("bij,bjd,bi->bid", qk, v, z)

        # Depthwise convolution over the values, laid out on the real H x W grid.
        feature_map = v.reshape(-1, height, width, d).permute(0, 3, 1, 2)
        feature_map = self.dwc(feature_map).permute(0, 2, 3, 1).reshape(-1, num_tokens, d)
        out = out + feature_map

        # Merge heads: (B * heads, N, c) -> (B, N, heads * c).
        out = out.reshape(batch, self.num_heads, num_tokens, d).permute(0, 2, 1, 3)
        out = out.reshape(batch, num_tokens, channels)
        out = self.proj(out)
        out = self.proj_drop(out)

        # [B, HW, C] -> [B, C, H, W]
        return out.transpose(1, 2).reshape(batch, channels, height, width)
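2.2 步骤二
在'ultralytics/nn/tasks.py'中导入该模块(例如:from ultralytics.nn.modules.FLA import FocusedLinearAttention),以便解析yaml配置时能够找到FocusedLinearAttention类。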
2.3 步骤三
新建一个模型配置yaml文件(可在yolov8.yaml的基础上修改),在SPPF层之后加入FocusedLinearAttention层,完整配置如下:
# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers, 3157200 parameters, 3157184 gradients, 8.9 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients, 28.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients, 79.3 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
# YOLOv8.0n backbone
backbone:
# [from, repeats, module, args]
- [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
- [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
- [-1, 3, C2f, [128, True]]
- [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
- [-1, 6, C2f, [256, True]]
- [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
- [-1, 6, C2f, [512, True]]
- [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
- [-1, 3, C2f, [1024, True]]
- [-1, 1, SPPF, [1024, 5]] # 9
- [-1, 1, FocusedLinearAttention, [256]] # 10
# YOLOv8.0n head
head:
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 6], 1, Concat, [1]] # cat backbone P4
- [-1, 3, C2f, [512]] # 13
- [-1, 1, nn.Upsample, [None, 2, 'nearest']]
- [[-1, 4], 1, Concat, [1]] # cat backbone P3
- [-1, 3, C2f, [256]] # 16 (P3/8-small)
- [-1, 1, Conv, [256, 3, 2]]
- [[-1, 13], 1, Concat, [1]] # cat head P4
- [-1, 3, C2f, [512]] # 19 (P4/16-medium)
- [-1, 1, Conv, [512, 3, 2]]
- [[-1, 10], 1, Concat, [1]] # cat head P5
- [-1, 3, C2f, [1024]] # 22 (P5/32-large)
- [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)
如果forward中按正方形(num×num)还原特征图,那么当输入特征图的宽高不相等时会出现如下报错:
raise EinopsError(message + "\n {}".format(e))
einops.EinopsError: Error while processing rearrange-reduction pattern "b (w h) c -> b c w h".
Input tensor shape: torch.Size([128, 294, 32]). Additional info: {'w': 17, 'h': 17}.
Shape mismatch, 294 != 289
def build_dataset(self, img_path, mode='train', batch=None):
    """
    Build YOLO Dataset.

    Args:
        img_path (str): Path to the folder containing images.
        mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
        batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
    """
    # Stride passed to the dataset builder: the model's largest stride (0 if there is no model), floored at 32.
    gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
    # The unmodified line passed rect=(mode == 'val'). The tutorial turns rect off for every mode
    # as a workaround for attention modules that assume a square feature map.
    return build_yolo_dataset(self.args, img_path, batch, self.data, mode=mode, rect=False, stride=gs)