YOLOv11改进 | 注意力篇 | YOLOv11引入Polarized Self-Attention注意力机制

news2024/12/23 22:57:34

1. Polarized Self-Attention介绍

1.1  摘要:像素级回归可能是细粒度计算机视觉任务中最常见的问题,例如估计关键点热图和分割掩模。 这些回归问题非常具有挑战性,特别是因为它们需要在低计算开销的情况下对高分辨率输入/输出的长期依赖性进行建模,以估计高度非线性的像素语义。 虽然深度卷积神经网络(DCNN)中的注意力机制在增强远程依赖性方面已变得流行,但特定于元素的注意力(例如非局部块)的学习非常复杂且对噪声敏感,并且大多数简化的注意力混合体试图达到 多种类型任务之间的最佳折衷方案。 在本文中,我们提出了偏振自注意力(PSA)模块,它结合了高质量像素级回归的两个关键设计:(1)偏振过滤:在通道和空间注意力计算中保持高内部分辨率,同时完全折叠输入张量 它们的对应尺寸。 (2) 增强:构建直接拟合典型细粒度回归输出分布的非线性,例如二维高斯分布(关键点热图)或二维二项分布(二元分割掩模)。 PSA 似乎已经耗尽了其仅通道和仅空间分支内的表示能力,因此其顺序布局和并行布局之间仅存在边际度量差异。 实验结果表明,PSA 将标准基线提高了 2−4 个点,并将 2D 姿态估计和语义分割基准的最先进技术提高了 1−2 个点。

官方论文地址:https://export.arxiv.org/pdf/2107.00782.pdf

官方代码地址:https://github.com/DeLightCMU/PSA

1.2  简单介绍:  

          Polarized Self-Attention (PSA) 模块是一种用于高质量像素级回归的自注意力机制。它主要通过两个关键设计来提高模型性能:极化滤波和增强处理。

           极化滤波:PSA在通道和空间注意力计算中保持高内部分辨率,同时沿着其对应的维度完全折叠输入张量。这有助于在不增加太多计算负担的情况下,保留重要的空间细节和通道信息。

           增强处理:PSA采用非线性组合直接拟合典型细粒度回归的输出分布,如二维高斯分布(关键点热图)或二维二项分布(二值分割掩膜)。这种设计使得PSA能够更好地适应不同任务的特定需求,并提高整体表现。

           此外,PSA模块在实验中显示出显著的性能提升,相比标准基线模型,它在2D姿态估计和语义分割基准测试中分别提高了2到4个百分点和1到2个百分点的表现。这表明PSA不仅能有效处理高分辨率信息,还能通过其独特的非线性组合进一步提高模型的预测精度。

1.3  Polarized Self-Attention模块结构图

2. 核心代码

import torch
import torch.nn as nn
 
 
class PolarizedSelfAttention(nn.Module):
    def __init__(self, channel=512):
        super().__init__()
        self.ch_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.ch_wq = nn.Conv2d(channel, 1, kernel_size=(1, 1))
        self.softmax_channel = nn.Softmax(1)
        self.softmax_spatial = nn.Softmax(-1)
        self.ch_wz = nn.Conv2d(channel // 2, channel, kernel_size=(1, 1))
        self.ln = nn.LayerNorm(channel)
        self.sigmoid = nn.Sigmoid()
        self.sp_wv = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.sp_wq = nn.Conv2d(channel, channel // 2, kernel_size=(1, 1))
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
 
    def forward(self, x):
        b, c, h, w = x.size()
 
        # Channel-only Self-Attention
        channel_wv = self.ch_wv(x)  # bs,c//2,h,w
        channel_wq = self.ch_wq(x)  # bs,1,h,w
        channel_wv = channel_wv.reshape(b, c // 2, -1)  # bs,c//2,h*w
        channel_wq = channel_wq.reshape(b, -1, 1)  # bs,h*w,1
        channel_wq = self.softmax_channel(channel_wq)
        channel_wz = torch.matmul(channel_wv, channel_wq).unsqueeze(-1)  # bs,c//2,1,1
        channel_weight = self.sigmoid(self.ln(self.ch_wz(channel_wz).reshape(b, c, 1).permute(0, 2, 1))).permute(0, 2, 1).reshape(b, c, 1, 1)  # bs,c,1,1
        channel_out = channel_weight * x
 
        # Spatial-only Self-Attention
        spatial_wv = self.sp_wv(x)  # bs,c//2,h,w
        spatial_wq = self.sp_wq(x)  # bs,c//2,h,w
        spatial_wq = self.agp(spatial_wq)  # bs,c//2,1,1
        spatial_wv = spatial_wv.reshape(b, c // 2, -1)  # bs,c//2,h*w
        spatial_wq = spatial_wq.permute(0, 2, 3, 1).reshape(b, 1, c // 2)  # bs,1,c//2
        spatial_wq = self.softmax_spatial(spatial_wq)
        spatial_wz = torch.matmul(spatial_wq, spatial_wv)  # bs,1,h*w
        spatial_weight = self.sigmoid(spatial_wz.reshape(b, 1, h, w))  # bs,1,h,w
        spatial_out = spatial_weight * x
        out = spatial_out + channel_out
        return out

3. YOLOv11中添加Polarized Self-Attention

3.1 在ultralytics/nn下新建Extramodule

 3.2 在Extramodule里创建PSA

在PSA.py文件里添加给出的PSA代码

添加完PSA代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用

3.3 在tasks.py里引用

在ultralytics/nn/tasks.py文件里引用Extramodule

在tasks.py找到parse_model(ctrl+f可以直接搜索parse_model位置

添加如下代码:

        elif m in {PolarizedSelfAttention}:
            c2 = ch[f]
            args = [c2, *args]

4. 新建一个yolo11PSA.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13
  - [-1, 1, PolarizedSelfAttention, []]

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, PolarizedSelfAttention, []]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, PolarizedSelfAttention, []]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [-1, 1, PolarizedSelfAttention, []]

  - [[17, 21, 26], 1, Detect, [nc]] # Detect(P3, P4, P5)

大家根据自己的数据集实际情况,修改nc大小。

5. 模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r'D:\yolo\yolov11\ultralytics-main\datasets\yolo11PSA.yaml')
    model.train(data=r'D:\yolo\yolov11\ultralytics-main\datasets\data.yaml',
                cache=False,
                imgsz=640,
                epochs=100,
                single_cls=False,  # 是否是单类别检测
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )

模型结构打印,成功运行:

6.本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的,后期我会根据各种前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

YOLOv11有效涨点专栏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186209.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

最新BurpSuite2024.9专业中英文开箱即用版下载

1、工具介绍 本版本更新介绍 此版本对 Burp Intruder 进行了重大改进,包括自定义 Bambda HTTP 匹配和替换规则以及对扫描 SOAP 端点的支持。我们还进行了其他改进和错误修复。 Burp Intruder 的精简布局我们对 Burp Intruder 进行了重大升级。现在,您可…

0基础学习CSS(十四)填充

CSS padding(填充) CSS padding(填充)是一个简写属性,定义元素边框与元素内容之间的空间,即上下左右的内边距。 padding(填充) 当元素的 padding(填充)内边距…

深入理解 Solidity 中的支付与转账:安全高效的资金管理攻略

在 Solidity 中,支付和转账是非常常见的操作,尤其是在涉及资金的合约中,比如拍卖、众筹、托管等。Solidity 提供了几种不同的方式来处理 Ether 转账,包括 transfer、send 和 call,每种方式的安全性、灵活性和复杂度各有…

【通配符】粗浅学习

1 背景说明 首先要注意,通配符中的符号和正则表达式中的特殊符号具备不同的匹配意义,例如:*在正则表达式中表示里面是指匹配前面的子表达式0次或者多次,而在通配符领域则是表示代表0个到无穷个任意字符。 此外,要注意…

大学城就餐推荐系统小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,餐厅信息管理,美食类型管理,餐厅美食管理,评价信息管理,系统管理 微信端账号功能包括:系统首页,餐厅信息&a…

JavaScript for循环语句

for循环 循环语句用于重复执行某个操作,for语句就是循环命令,可以指定循环的起点、终点和终止条件。它的格式如下 for(初始化表达式;条件;迭代因子){语句} for语句后面的括号里面,有三个表达式 初始化表达式(initialize):确定循环变量的初始…

OpenAI开发者大会派礼包:大幅降低模型成本 AI语音加持App

美东时间10月1日周二,OpenAI举行了年度开发者大会DevDay,今年的大会并没有任何重大的产品发布,相比去年大会显得更低调,但OpenAI也为开发者派发了几个大“礼包”,对现有的人工智能(AI)工具和API…

Spring(学习笔记)

<context:annotation-config/>是 Spring 配置文件中的一个标签&#xff0c;用于开启注解配置功能。这个标签可以让 Spring 容器识别并处理使用注解定义的 bean。例如&#xff0c;可以使用 Autowired 注解自动装配 bean&#xff0c;或者使用 Component 注解将类标记为 bea…

四.网络层(上)

目录 4.1网络层功能概述 4.2 SDN基本概念 4.3 路由算法与路由协议 4.3.1什么是路由协议&#xff1f; 4.3.2什么是路由算法&#xff1f; 4.3.3路由算法分类 (1)静态路由算法 (2)动态路由算法 ①全局性 OSPF协议与链路状态算法 ②分散性 RIP协议与距离向量算法 4.3.…

netty之Netty使用Protobuf传输数据

前言 在netty数据传输过程中可以有很多选择&#xff0c;比如&#xff1b;字符串、json、xml、java对象&#xff0c;但为了保证传输的数据具备&#xff1b;良好的通用性、方便的操作性和传输的高性能&#xff0c;我们可以选择protobuf作为我们的数据传输格式。目前protobuf可以支…

(作业)第三期书生·浦语大模型实战营(十一卷王场)–书生基础岛第1关---书生大模型全链路开源体系

观看本关卡视频和官网https://internlm.intern-ai.org.cn/后&#xff0c;写一篇关于书生大模型全链路开源开放体系的笔记发布到知乎、CSDN等任一社交媒体&#xff0c;将作业链接提交到以下问卷&#xff0c;助教老师批改后将获得 100 算力点奖励&#xff01;&#xff01;&#x…

V3D——从单一图像生成 3D 物体

导言 论文地址&#xff1a;https://arxiv.org/abs/2403.06738 源码地址&#xff1a;https://github.com/heheyas/V3D.git 人工智能的最新进展使得自动生成 3D 内容的技术成为可能。虽然这一领域取得了重大进展&#xff0c;但目前的方法仍面临一些挑战。有些方法速度较慢&…

深刻理解Redis集群(中):Redis主从数据同步模式

背景 目前实现Redis高可用的模式主要有三种&#xff1a;主从模式、哨兵模式、集群模式。今天我们先来聊一下主从模式。 Redis 提供的主从模式&#xff0c;是通过复制的方式&#xff0c;将主服务器上的Redis的数据同步复制一份到从 Redis 服务器&#xff0c;这种做法很常见&…

函数式接口在Java中的应用与实践

1. 引言 函数式接口是Java 8引入的一个概念&#xff0c;它是指只有一个抽象方法的接口。函数式接口可以被用作lambda表达式的目标类型。在函数式接口中&#xff0c;除了抽象方法外&#xff0c;还可以有默认方法和静态方法。 函数式接口的引入是为了支持函数式编程&#xff0c…

SpringBoot 源码解读与自动装配原理结合Actuator讲解

Spring Boot 作为简化 Spring 应用开发的重要框架&#xff0c;能够通过“约定大于配置”的方式&#xff0c;使开发者无需大量的 XML 或配置类即可完成复杂的配置过程。这背后的核心机制之一就是 自动装配 (Auto-Configuration)&#xff0c;其依赖 Spring 的 依赖注入 (DI) 和 注…

AI通用大模型编程需要的能力

这几天研究通过通义千问AI大模型编程&#xff0c;有三点感受&#xff0c;分享给大家。如果将来有新的感受&#xff0c;会继续分享。 1、清晰的提示词指令&#xff0c;让输出的成功率更高 2、了解点代码知识&#xff0c;虽不会写&#xff0c;但能看的懂 3、定位代码问题的能力…

数据库软题5-SQL语言

一、DDL数据定义语言 题 1-创建视图 建立视图属于DDL的知识 建立视图要用到CREATE AS CREATE View Computer-BOOK ASSELECT 图书编号、图书名称、作者、出版社、出版日期FROM 图书WHERE 图书类型计算机 WITH CHEEK OPTION&#xff1b;二、DQL数据查询语言 题1-交 查询平均…

SAP 和 Carahsoft 的调查范围扩大到与近 100 家机构

美国司法部正在扩大对德国软件公司SAP和经销商Carahsoft的价格操纵调查&#xff0c;涉及近100个政府机构。这项调查最初集中在两家公司是否在2014年以来向美国国防部和其他政府部门收取过高费用&#xff0c;涉及金额超过20亿美元。最新的法院文件显示&#xff0c;调查范围已扩展…

HTTPS协议详解:从原理到流程,全面解析安全传输的奥秘

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐&#xff1a;「storm…

精准农业中遥感技术应用(六)- 作物长势分析和展示

橙蜂智能公司致力于提供先进的人工智能和物联网解决方案&#xff0c;帮助企业优化运营并实现技术潜能。公司主要服务包括AI数字人、AI翻译、领域知识库、大模型服务等。其核心价值观为创新、客户至上、质量、合作和可持续发展。 橙蜂智农的智慧农业产品涵盖了多方面的功能&…