YOLO11改进|注意力机制篇|引入三重注意力机制Triplet Attention

news2024/11/26 13:58:01

在这里插入图片描述

目录

    • 一、【Triplet Attention】注意力机制
      • 1.1【Triplet Attention】注意力介绍
      • 1.2【Triplet Attention】核心代码
    • 二、添加【Triplet Attention】注意力机制
      • 2.1STEP1
      • 2.2STEP2
      • 2.3STEP3
      • 2.4STEP4
    • 三、yaml文件与运行
      • 3.1yaml文件
      • 3.2运行成功截图

一、【Triplet Attention】注意力机制

1.1【Triplet Attention】注意力介绍

在这里插入图片描述

下图是【Triplet Attention】的结构图,让我们简单分析一下运行过程和优势

处理过程

  • 通道池化与卷积:
  • 左侧的分支首先对输入特征进行 Channel Pooling,通过通道维度的池化操作,压缩特征图的空间维度,随后应用一个 7×7的卷积层提取多尺度特征。
  • 卷积后的特征图经过 批归一化(Batch Normalization) 和 Sigmoid 激活 生成通道加权系数。
  • 空间维度变换与卷积:
  • 中间和右侧的分支分别进行 Permute 操作,对输入特征的维度进行转换(类似转置操作),接着应用 Z-Pooling,进一步压缩空间维度。
  • 经过 7×7卷积和批归一化处理后,这些分支也生成了加权特征,并通过 Sigmoid 激活,生成空间加权系数。
  • 特征融合:
  • 三个并行路径生成的加权特征通过 加法操作(+) 和 通道平均池化(Avg) 进行融合,得到最终的权重。
  • 最终的输出特征是经过加权操作的特征图,其中每个特征通道和空间位置都被自适应加权。
    优势
  • 多尺度特征提取:
  • 通过不同路径的池化和卷积操作,模块能够从多个尺度上提取特征。这种多尺度处理使得模型在不同的空间分辨率下能够学习到更丰富的上下文信息,提高了对不同目标的敏感度。
  • 自适应加权:
  • 模块通过通道和空间上的加权机制,使得每个通道和空间位置都根据特征的重要性被自适应调整。这使得网络能够更加关注有用的特征,抑制冗余或不相关的信息,从而提升了模型的表达能力和泛化性。
  • 特征融合:
  • 通过不同路径的特征融合(加法操作和平均池化),该模块能够同时整合多维度的信息,提高了特征的多样性,增强了模型处理复杂场景的能力。
  • 有效的通道和空间交互:
  • 通过引入通道池化和维度转置,模型实现了通道和空间维度之间的有效交互,使得网络能够同时捕捉到全局和局部的信息。

在这里插入图片描述

1.2【Triplet Attention】核心代码

import torch
import torch.nn as nn


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
                 bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class ZPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class AttentionGate(nn.Module):
    def __init__(self):
        super(AttentionGate, self).__init__()
        kernel_size = 7
        self.compress = ZPool()
        self.conv = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.conv(x_compress)
        scale = torch.sigmoid_(x_out)
        return x * scale


class TripletAttention(nn.Module):
    def __init__(self, no_spatial=False):
        super(TripletAttention, self).__init__()
        self.cw = AttentionGate()
        self.hc = AttentionGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = AttentionGate()

    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.cw(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.hc(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.hw(x)
            x_out = 1 / 3 * (x_out + x_out11 + x_out21)
        else:
            x_out = 1 / 2 * (x_out11 + x_out21)
        return x_out

二、添加【Triplet Attention】注意力机制

2.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个TripletAttention.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

2.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

2.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

2.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

三、yaml文件与运行

3.1yaml文件

以下是添加【Triplet Attention】注意力机制在Backbone中的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128,3,2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256,3,2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512,3,2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024,3,2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1,1,TripletAttention,[]]
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 14], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 11], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[17, 20, 23], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

3.2运行成功截图

在这里插入图片描述

OK 以上就是添加【Triplet Attention】注意力机制的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2210469.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ARM嵌入式学习--第二天

-指令流水线 -基础知识 1.流水线技术通过多个功能部件并行工作来缩短程序执行时间,提高处理器的效率和吞吐率 2.增加流水线级数,可以简化流水线的各级逻辑,进一步提高了处理器的性能 3.以三级流水线分析: pc代表程序计数器&#x…

Graph Contrastive Learning 图对比学习GCL

Preamble GCL主要任务:学习一个编码器,可以编码出结构和结点特征信息,得到一个低维的表达 早期大部分GNN模型都是有监督的训练 自监督学习主要分成两种:生成式(用已有信息去预测自己的其他信息) and 对…

C++学习笔记----9、发现继承的技巧(一)---- 使用继承构建类(1)

在前面的章节中,你学到了继承关系是一种真实世界对象以层次存在的模式。在编程世界中,当需要写一个类基于其构建,或进行细微的修改的另一个类时,那种模式就有了关系。完成这个目标的一个方式是拷贝一个类的代码粘贴到另一个类中。…

一个月学会Java 第14天 内部类

Day14 内部类 类有外边的public class,然后还有一个文件多个的class,但是有没有想过,class可以作为成员也就是类内部的类,甚至作为方法内部的属性也就是类内部的方法的内部出现。除了这两个, 还有直接对着上节课讲的抽…

GeoScene Pro教程(008):GeoScenePro数据查询和检索

文章目录 1、工具分类2、数据常用工具2.1 加载数据2.2 查询需求2.2.1 按照属性查询查询1:人口大于300万的城市有哪些查询2:自治州有哪些查询3:城市名字中带有“荆”的有哪些补充2.2.2 按照位置查询需求2:导出湖北省境内的铁路数据需求3:武汉市共有多少条铁路2.2.3 空间连接…

48 Redis

48 Redis 前言 Redis(Remote Dictionary Server ),即远程字典服务。是一个开源的使用ANSI C语言编写、支持网络、可基于内存亦可持久化的日志型、Key-Value数据库,并提供多种语言的API。 redis会周期性的把更新的数据写入磁盘或者把修改操…

RTSP与ONVIF协议的区别及其在EasyCVR视频汇聚平台中的应用

在视频监控和物联网设备领域,RTSP(Real Time Streaming Protocol)和ONVIF(Open Network Video Interface Forum)是两个重要的协议,它们各自在视频流的传输和控制上发挥着不同的作用,并在实际应用…

网络安全之XXE攻击

0x01 什么是 XXE 个人认为,XXE 可以归结为一句话:构造恶意 DTD 介绍 XXE 之前,我先来说一下普通的 XML 注入,这个的利用面比较狭窄,如果有的话应该也是逻辑漏洞。 既然能插入 XML 代码,那我们肯定不能善罢…

基于Nodemcu的手机控制小车

基于Nodemcu的手机控制小车 一、项目说明二、项目材料三、代码与电路设计四、轮子和车体五、电路连接六、使用方法 一、项目说明 嗨,机器人项目制造者们!在这个项目中,我制作了这辆简单但快速的遥控车,它可以通过智能手机控制&am…

gaussdb 主备版本8 SQL参考 学习

SQL参考 1 数据类型 1.1 货币类型 1.1.1 货币类型存储带有固定小数精度的货币金额。 1.2 布尔类型 1.2.1 true:真 1.2.2 false:假 1.2.3 null:未知(unknown) 1.3 日期/时间类型 1.3.1 DATE 输出格式:仅支…

MySQL-06.DDL-表结构操作-创建

一.DDL(表操作) create database db01;use db01;create table tb_user(id int comment ID,唯一标识,username varchar(20) comment 用户名,name varchar(10) comment 姓名,age int comment 年龄,gender char(1) comment 性别 ) comment 用户表; 此时并没有限制ID为…

圈子系统APP小程序H5该如何设置IM?

搭建圈子系统的常见问题,以及圈子论坛系统的功能特点 社交圈子论坛系统的概念 圈子小程序源码 多客圈子系统 圈子是什么软件 跟进圈一个系统的软件 为圈子系统APP小程序H5设置IM(即时通讯),需要遵循一系列步骤来确保通讯功能的稳定、安全和高…

企业架构之从理论指南到实践指导企业数字化转型

理论与实践结合的数字化转型之道 在当今的全球化经济中,企业面临着前所未有的数字化转型压力。数字化转型不仅是技术的更新换代,更是业务、组织、文化和战略的系统性重塑。对于企业来说,如何将理论转换为有效的实践路径,是推动数…

STM32 通用同步/异步收发器

目录 串行通信基础 串行异步通信数据格式 USART介绍 USART的主要特性 USART的功能 USART的通信时序 USART的中断 串行通信基础 在串行通信中,参与通信的两台或多台设备通常共享一条物理通路。发送者依次逐位发送一串数据信号,按一定的约定规则被接…

乐鑫ESP32-S3无线方案,AI大模型中控屏智能升级,提升智能家居用户体验

在这个由数据驱动的时代,人工智能正以其前所未有的速度和规模改变着我们的世界。随着技术的不断进步,AI已经从科幻小说中的概念,转变为我们日常生活中不可或缺的一部分。 特别是在智能家居领域,AI的应用已成为提升生活质量、增强…

linux下编译鸿蒙版curl、openssl

一.环境准备 1.参考说明 NDK开发介绍:https://docs.openharmony.cn/pages/v5.0/zh-cn/application-dev/napi/ndk-development-overview.md 2.NDK下载 点击介绍页面中的链接可以跳转到相应下载页面: 下载相应版本: 下载完毕后解压到指定目…

Matlab详细学习教程 MATLAB使用教程与知识点总结

Matlab语言教程 章节目录 一、Matlab简介与基础操作 二、变量与数据类型 三、矩阵与数组操作 四、基本数学运算与函数 五、图形绘制与数据可视化 六、控制流与逻辑运算 七、脚本与函数编写 八、数据导入与导出 九、Matlab应用实例分析 一、Matlab简介与基础操作 重点内容知识…

第2章 STM32最小系统介绍

第2章 STM32最小系统介绍 1. STM32最小系统组成 2. STM32启动模式 1.STM32最小系统组成 (1)电源电路 (2)复位电路 (3) 晶振电路 (4)下载电路 可打开开发板原理图查看 2.STM32启动模式 在STM…

Postman最新V11版本关键更新一览

Postman作为接口测试中,被广泛应用的一款主流工具,以其丰富的功能,灵活方便的使用方式,广受欢迎。最新发布的V11版本则在向协作平台转型的过程中一路狂奔,增加大量全新的协作支持。下面我们就一起来看看都有哪些变化吧…

基于Arduino的泡茶机器人

打造你的完美泡茶助手 引言 你是否曾遇到过泡出的茶太淡或太苦?通过这个项目,你可以创建一个设备,为你的茶包提供完美的浸泡时间。只需附上一个茶包并放置你的杯子,设备就会开始工作!它将完美地按照你的喜好浸泡你的…