YOLOv11改进 | 注意力篇 | YOLOv11引入ACmix注意力机制

news2024/12/24 8:12:43

1. ACmix介绍

1.1  摘要:卷积和自注意力是表示学习的两种强大技术,它们通常被认为是两种彼此不同的同行方法。 在本文中,我们表明它们之间存在很强的潜在关系,从某种意义上说,这两种范式的大量计算实际上是通过相同的操作完成的。 具体来说,我们首先证明内核大小为 k×k 的传统卷积可以分解为 k2 个单独的 1×1 卷积,然后进行移位和求和操作。 然后,我们将自注意力模块中的查询、键和值的投影解释为多个 1×1 卷积,然后计算注意力权重和值的聚合。 因此,两个模块的第一阶段都包括类似的操作。 更重要的是,与第二级相比,第一级贡献了主要的计算复杂性(通道大小的平方)。 这种观察自然地导致了这两个看似不同的范式的优雅整合,即一个混合模型,它既享受自注意力和卷积(ACmix)的好处,同时与纯卷积或自注意力对应物相比具有最小的计算开销。 大量实验表明,我们的模型在图像识别和下游任务方面比竞争基线取得了持续改进的结果。

官方论文地址:https://arxiv.org/pdf/2111.14556v1.pdf

官方代码地址:https://github.com/Panxuran/ACmix

1.2  简单介绍:  

          ACmix模块是一种结合了自注意力和卷积技术的混合模型,旨在通过最小化计算开销来整合这两种看似不同的范式。该模块的核心在于揭示了自注意力和卷积之间存在的强大内在联系,这种联系主要体现在两者在第一阶段的计算复杂性上。具体而言,ACmix首先使用1×1卷积将输入特征图投影,以获得一组丰富的中间特征。然后,这些中间特征被重用,并分别遵循自注意力和卷积的方式聚合。

          在ACmix中,两个主要阶段(I和II)共同作用:在第一阶段,输入特征通过三个1×1卷积进行投影,生成包含3×N特征图的丰富中间特征集;第二阶段则根据不同范式(即自注意力和卷积方式)使用这些中间特征。此外,为了提高模型的灵活性和效率,ACmix还引入了几个改进措施,包括使用多个组卷积分解复杂的位移操作以及采用可学习的卷积分类器初始化固定核。

1.3  ACmix模块结构图

2. 核心代码

import torch
import torch.nn as nn


def position(H, W, type, is_cuda=True):
    if is_cuda:
        loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1).to(type)
        loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W).to(type)
    else:
        loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)
        loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)
    loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)
    return loc


def stride(x, stride):
    b, c, h, w = x.shape
    return x[:, :, ::stride, ::stride]


def init_rate_half(tensor):
    if tensor is not None:
        tensor.data.fill_(0.5)


def init_rate_0(tensor):
    if tensor is not None:
        tensor.data.fill_(0.)


class ACmix(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
        super(ACmix, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = stride
        self.dilation = dilation
        self.rate1 = torch.nn.Parameter(torch.Tensor(1))
        self.rate2 = torch.nn.Parameter(torch.Tensor(1))
        self.head_dim = self.out_planes // self.head

        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)

        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)
        self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)
        self.softmax = torch.nn.Softmax(dim=1)

        self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)
        self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,
                                  kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,
                                  stride=stride)

        self.reset_parameters()

    def reset_parameters(self):
        init_rate_half(self.rate1)
        init_rate_half(self.rate2)
        kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)
        for i in range(self.kernel_conv * self.kernel_conv):
            kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
        kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)
        self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)
        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)

    def forward(self, x):
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        scaling = float(self.head_dim) ** -0.5
        b, c, h, w = q.shape
        h_out, w_out = h // self.stride, w // self.stride

        pe = self.conv_p(position(h, w, x.dtype, x.is_cuda))

        q_att = q.view(b * self.head, self.head_dim, h, w) * scaling
        k_att = k.view(b * self.head, self.head_dim, h, w)
        v_att = v.view(b * self.head, self.head_dim, h, w)

        if self.stride > 1:
            q_att = stride(q_att, self.stride)
            q_pe = stride(pe, self.stride)
        else:
            q_pe = pe

        unfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,
                                                         self.kernel_att * self.kernel_att, h_out,
                                                         w_out)  # b*head, head_dim, k_att^2, h_out, w_out
        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,
                                                        w_out)  # 1, head_dim, k_att^2, h_out, w_out

        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(
            1)
        att = self.softmax(att)

        out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,
                                                        h_out, w_out)
        out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)

        f_all = self.fc(torch.cat(
            [q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),
             v.view(b, self.head, self.head_dim, h * w)], 1))
        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])

        out_conv = self.dep_conv(f_conv)

        return self.rate1 * out_att + self.rate2 * out_conv


3. YOLOv11中添加ACmix

3.1 在ultralytics/nn下新建Extramodule

 3.2 在Extramodule里创建ACmix

在ACmix.py文件里添加给出的ACmix代码

添加完ACmix代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用

3.3 在tasks.py里引用

在ultralytics/nn/tasks.py文件里引用Extramodule

在tasks.py找到parse_model(ctrl+f可以直接搜索parse_model位置

添加如下代码:

        elif m in {ACmix}:
            args = [ch[f],  ch[f]]

4. 新建一个yolo11ACmix.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, ACmix, []]

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, ACmix, []]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, ACmix, []]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [-1, 1, ACmix, []]

  - [[17, 21, 26], 1, Detect, [nc]] # Detect(P3, P4, P5)

大家根据自己的数据集实际情况,修改nc大小。

5.模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r'D:\yolo\yolov11\ultralytics-main\datasets\yolo11ACmix.yaml')
    model.train(data=r'D:\yolo\yolov11\ultralytics-main\datasets\data.yaml',
                cache=False,
                imgsz=640,
                epochs=100,
                single_cls=False,  # 是否是单类别检测
                batch=8,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )

模型结构打印,成功运行 :

6.本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的,后期我会根据各种前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

YOLOv11有效涨点专栏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2185305.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

特殊的加法和除法(考察点为位操作符)

目录 一简介: 二例题讲解: 2.1不用加号的加法: 2.1.1题目: 2.1.2思路汇总: 2.1.3代码解答: 2.2两数相除: 2.2.1题目: 2.2.2思路汇总: 2.2.3代码解答&#xff1a…

第 13 章 常用类

第 13 章 常用类 文章目录 <center>第 13 章 常用类13.1 包装类13.1.1 包装类的分类13.1.2 包装类和基本数据的转换13.1.3 案例演示13.1.4 课堂测试题13.1.5 包装类型和 String 类型的相互转换13.1.6 Integer 类和 Character13.1.7 Integer 类面试题 113.1.8 Intege 类面…

【算法】0/1背包问题

背包中有一些物品&#xff0c;每件物品有它的价值与重量&#xff0c;给定一个重量&#xff0c;在该重量范围内取物品&#xff08;每件物品不可重复取&#xff09;&#xff0c;求最大价值。 将需求转化为表格&#xff0c;每一行中的每个格子代表可选哪些下标的物品在总重量限额内…

【c++】 模板初阶

泛型编程 写一个交换函数&#xff0c;在学习模板之前&#xff0c;为了匹配不同的参数类型&#xff0c;我们可以利用函数重载来实现。 void Swap(int& a, int& b) {int c a;a b;b c; } void Swap(char& a, char& b) {char c a;a b;b c; } void Swap(dou…

Linux开发讲课45--- 链表

Linux内核代码中广泛使用了数据结构和算法,其中最常用的有链表、队列kfifo、红黑树、基数树和位图。 链表 Linux内核代码大量使用了链表这种数据结构。链表是在解决数组不能动态扩展这个缺陷而产生的一种数据结构。 链表所包含的元素可以动态创建并插入和删除。链表的每个元素…

AR 领域的突破——微型化显示屏为主流 AR 眼镜铺平道路

概述 多年来&#xff0c;增强现实 (AR) 技术一直吸引着人们的想象力&#xff0c;有望将数字信息与我们的物理世界无缝融合。通过将计算机生成的图像叠加到现实世界的视图上&#xff0c;AR 有可能彻底改变我们与环境的互动方式。从增强游戏体验到协助手术室的外科医生&#xff…

【Linux】进程间关系与守护进程

超出能力之外的事&#xff0c; 如果永远不去做&#xff0c; 那你就永远无法进步。 --- 乌龟大师 《功夫熊猫》--- 进程间关系与守护进程 1 进程组2 会话3 控制终端4 作业控制5 守护进程 1 进程组 之前我们提到了进程的概念&#xff0c; 其实每一个进程除了有一个进程 ID(P…

2024/10/2 408 20题

c d d b b a b c b b a d c d a c

【C++】C++基础

目录 一. C关键字(C98) 二、C的第一个程序 三、命名空间 3.1.namespace的价值 3.2.namespace的定义 3.2.命名空间使用 总结&#xff1a;在项目当中第一、第二种方法搭配使用&#xff0c;第三种冲突风险非常大&#xff0c;仅适合练习使用。 四、C输入&输出 五、缺省…

【数据库】揭秘Oracle中不朽的scott用户:起源、影响与技术启示

标题&#xff1a;【数据库探秘】揭秘Oracle中不朽的scott用户&#xff1a;起源、影响与技术启示 摘要 本文将带你深入了解Oracle数据库中一个传奇的用户——scott。从scott用户的起源到其在数据库发展中的影响&#xff0c;我们将探索这个经典用户账户背后的故事。此外&#xf…

[动态规划] 二叉树中的最大路径和##树形DP#DFS

标题&#xff1a;[动态规划] 二叉树中的最大路径和##树形DP#DFS 个人主页水墨不写bug &#xff08;图片来源于网络&#xff09; 目录 一 、什么是树形DP 二、题目描述&#xff08;点击题目转跳至题目&#xff09; NC6 二叉树中的最大路径和 算法思路&#xff1a; 讲解与参考代…

SpringCloud-基于Docker和Docker-Compose的项目部署

一、初始化环境 1. 卸载旧版本 首先&#xff0c;卸载可能已存在的旧版本 Docker。如果您不确定是否安装过&#xff0c;可以直接执行以下命令&#xff1a; sudo yum remove docker \docker-client \docker-client-latest \docker-common \docker-latest \docker-latest-logro…

NeRF2: Neural Radio-Frequency Radiance Fields 笔记

任务&#xff1a;用 NeRF 对无线信号的传播进行建模&#xff0c;建模完成后可以用NeRF网络生成新位置下的信号。生成的信号用于指纹定位、信道估计等下游任务。 核心思路 在视觉 NeRF 的基础上&#xff0c;根据无线信号的特点修改了隐式场模型、渲染函数&#xff0c;网络的输…

C++初阶:STL详解(十)——priority_queue的介绍,使用以及模拟实现

✨✨小新课堂开课了&#xff0c;欢迎欢迎~✨✨ &#x1f388;&#x1f388;养成好习惯&#xff0c;先赞后看哦~&#x1f388;&#x1f388; 所属专栏&#xff1a;C&#xff1a;由浅入深篇 小新的主页&#xff1a;编程版小新-CSDN博客 一.priority_queue的介绍 优先级队列被实现…

calibre-web的翻译translations

calibre-web的翻译translations Windows安装calibre-web&#xff0c;Python-CSDN博客文章浏览阅读539次&#xff0c;点赞10次&#xff0c;收藏11次。pip install calibreweb报错&#xff1a;error: Microsoft Visual C 14.0 or greater is required. Get it with "Microso…

Oracle 12c在Windows环境下安装

适合初学者使用的Oracle 12c在Windows环境下安装步骤、参数配置、常见问题及参数调优的详细补充说明。 一、Oracle 12c安装步骤 1. 准备工作 在安装Oracle 12c之前&#xff0c;确保你的系统满足以下要求&#xff1a; 操作系统&#xff1a;Oracle 12c支持的Windows版本包括Wi…

掌控物体运动艺术:图扑 Easing 函数实践应用

现如今&#xff0c;前端开发除了构建功能性的网站和应用程序外&#xff0c;还需要创建具有吸引力且尤为流畅交互的用户界面&#xff0c;其中动画技术在其中发挥着至关重要的作用。在数字孪生领域&#xff0c;动画的应用显得尤为重要。数字孪生技术通过精确模拟现实世界中的对象…

【C++】树形结构的关联式容器:set、map、multiset、multimap的使用

&#x1f33b;个人主页&#xff1a;路飞雪吖~ ✨专栏&#xff1a;C/C 目录 一、set的简单介绍和使用 &#x1f31f;set的介绍 &#x1f525;注意&#xff1a; &#x1f320;小贴士&#xff1a; &#x1f31f;set的使用 ✨set的构造 ✨set的迭代器 ​编辑 ✨set的容量 …

结构光—格雷码构造代码

本篇文章主要给出生成格雷码的代码&#xff0c;鉴于自身水平所限&#xff0c;如有错误&#xff0c;欢迎批评指正。&#xff08;欢迎进Q群交流&#xff1a;874653199&#xff09; #include <iostream> #include <fstream> #include <Windows.h>using…

vue2老项目打包优化:优化脚本生成的代码

前言 上次讲到在一个 vue-cli 的老项目中&#xff0c;修改 vue.config.js 的以下参数&#xff0c;将打包时间从 40min &#xff0c;降到了 12min {parallel: true, // 多核处理&#xff0c;按理说默认应该生效&#xff0c;但我的文件被设置成了falseruntimeCompiler: false, …