YOLOv11改进 | 注意力篇 | YOLOv11引入GAM注意力机制

news2024/10/1 15:18:03

1.GAM介绍

摘要:为了提高各种计算机视觉任务的性能,人们研究了各种注意机制。然而,现有的方法忽略了保留通道和空间信息以增强跨维交互的重要性。因此,我们提出了一种通过减少信息减少和放大全球交互表示来提高深度神经网络性能的全球驻留机制。我们引入了具有多层单个Ceptron的3D置换用于信道注意,同时还引入了卷积空间注意子模块。对 CIFAR-100和lmageNet-1K上图像分类任务的拟议机制的评估表明我们的方法稳定地优于ResNet和轻量级的 MobileNet的几个最近的注意机制。

官方论文地址:https://ar5iv.labs.arxiv.org/html/2112.05561

官方代码地址:https://github.com/dengbuqi/GAM_Pytorch/blob/main/CAM.py

简单介绍: GAM旨在通过设计一种机制,减少信息损失并放大全局维度互动特征,从而解决传统注意力机制在通道和空间两个维度上保留信息不足的问题。GAM采用了顺序的通道-空间注意力制,并对子模块进行了重新设计。具体来说,通道注意力子模块使用3D排列来跨三个维度保留信息,并通过一个两层的MLP增强跨维度的通道-空间依赖性。在空间注意力子模块中,为了更好地关注空间信息,采用了两个卷积层进行空间信息融合,同时去除了可能导致信息减少的最大池化操作。

GAM模块结构图如下:

2.核心代码

import torch
import torch.nn as nn
 
 
class GAM(nn.Module):
    def __init__(self, in_channels, rate=4):
        super().__init__()
        out_channels = in_channels
        in_channels = int(in_channels)
        out_channels = int(out_channels)
        inchannel_rate = int(in_channels/rate)
 
 
        self.linear1 = nn.Linear(in_channels, inchannel_rate)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(inchannel_rate, in_channels)
        
 
        self.conv1=nn.Conv2d(in_channels, inchannel_rate,kernel_size=7,padding=3,padding_mode='replicate')
 
        self.conv2=nn.Conv2d(inchannel_rate, out_channels,kernel_size=7,padding=3,padding_mode='replicate')
 
        self.norm1 = nn.BatchNorm2d(inchannel_rate)
        self.norm2 = nn.BatchNorm2d(out_channels)
        self.sigmoid = nn.Sigmoid()
 
    def forward(self,x):
        b, c, h, w = x.shape
        # B,C,H,W ==> B,H*W,C
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        
        # B,H*W,C ==> B,H,W,C
        x_att_permute = self.linear2(self.relu(self.linear1(x_permute))).view(b, h, w, c)
 
        # B,H,W,C ==> B,C,H,W
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
 
        x = x * x_channel_att
 
        x_spatial_att = self.relu(self.norm1(self.conv1(x)))
        x_spatial_att = self.sigmoid(self.norm2(self.conv2(x_spatial_att)))
        
        out = x * x_spatial_att
 
        return out
 
if __name__ == '__main__':
    img = torch.rand(1,64,32,48)
    b, c, h, w = img.shape
    net = GAM(in_channels=c, out_channels=c)
    output = net(img)
    print(output.shape)

3.YOLOv11中添加GAM方式 

3.1 在ultralytics/nn下新建Extramodule

 

 3.2 在Extramodule里创建GAM

在GAM.py文件里添加给出的GAM代码

添加完GAM代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用

3.3 在task.py里引用

在ultralytics/nn/tasks.py文件里引用Extramodule

在tasks.py找到parse_model(ctrl+f可以直接搜索parse_model位置

添加如下代码:

        elif m in {GAM}:
            c2 = ch[f]
            args = [c2, *args]

4.新建一个yolo11GAM.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)
  - [-1, 1, GAM, []]

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)
  - [-1, 1, GAM, []]

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)
  - [-1, 1, GAM, []]

  - [[16, 20, 25], 1, Detect, [nc]] # Detect(P3, P4, P5)

大家根据自己的数据集实际情况,修改nc大小。

5.模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r'D:\yolo\yolov11\ultralytics-main\datasets\yolo11GAM.yaml')
    model.train(data=r'D:\yolo\yolov11\ultralytics-main\datasets\data.yaml',
                cache=False,
                imgsz=640,
                epochs=100,
                single_cls=False,  # 是否是单类别检测
                batch=8,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )

模型结构打印,成功运行 :

6.本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的,后期我会根据各种前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

YOLOv11有效涨点专栏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2182597.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3 实现拖拽排序效果 sortablejs

效果图 依赖安装 npm i sortablejs -S <template><div class"warp"><div class"parent-box" v-for"pItem in sortData" :key"pItem.name"><h2 class"parent-name">{{ pItem.name }}</h2>&l…

程序计数器(学习笔记)

程序计数器是一块较小的内存空间&#xff0c;它的作用可以看做是当前线程所执行的字节码的信号指示器&#xff08;偏移地址&#xff09;&#xff0c;Java编译过程中产生的字节码有点类似编译原理的指令&#xff0c;程序计数器的内存空间存储的是当前执行的字节码的偏移地址 因为…

唱响红色志愿,赞歌献给祖国——杭州建德市庆祝中华人民共和国成立75周年联欢盛宴纪实

作者&#xff1a;华夏之音/李望 通讯员&#xff1a;王江平 9月30日上午&#xff0c;金桂的香气与红旗的鲜艳交相辉映&#xff0c;杭州建德市党群服务中心、建德市新时代文明实践中心内洋溢着一股浓厚的节日氛围。在这里&#xff0c;一场名为“唱响红色志愿、赞歌献给祖国”的联…

企业架构系列(15)ArchiMate第13节:战略视角

战略视角提供了对企业高层战略方向和构成的不同视角建模&#xff0c;使建模者能够专注于某些特定方面。 一、战略视角概览 战略视角主要包括&#xff1a; 战略视角&#xff1a;提供企业战略、能力、价值流和资源以及预期成果的高层概述。能力地图视角&#xff1a;提供企业能力…

MySQL基础篇 part1

为什么使用数据库和数据库基本概念 想在vscode用markdown了&#xff0c;为什么不直接拿pdf版本呢&#xff1f; DB:数据库(Database) 即存储数据的“仓库”&#xff0c;其本质是一个文件系统。它保存了一系列有组织的数据。 DBMS:数据库管理系统(Database Management System)…

Oracle控制文件全部丢失如何使用RMAN智能恢复?

1.手动删除所有控制文件模拟故障产生 2.此时启动数据库发现控制文件丢失 3.登录rman 4.列出故障 list failure; 5.让RMAN列举恢复建议 advise failure; 6.使用RMAN智能修复 repair failure;

当AI遇上金融科技,创新业务场景和案例涌现

大家好&#xff0c;我是Shelly&#xff0c;一个专注于输出AI工具和科技前沿内容的AI应用教练&#xff0c;体验过300款以上的AI应用工具。关注科技及大模型领域对社会的影响10年。关注我一起驾驭AI工具&#xff0c;拥抱AI时代的到来。 在这个信息爆炸的时代&#xff0c;我们每天…

【路径规划】使用 RRT、RRT* 和 BIT* 进行网格图的路径规划

摘要 本文比较了三种路径规划算法&#xff1a;快速随机树&#xff08;RRT&#xff09;、快速随机树星&#xff08;RRT* &#xff09;和批量信息树&#xff08;BIT*&#xff09;&#xff0c;在网格图环境中进行路径规划的效果。通过仿真分析这些算法在路径质量、计算效率和收敛…

程序员哪里累了?

程序员是最不累的&#xff0c;最不辛苦的职业&#xff0c;非要说有什么门槛&#xff0c;那只需要你有点智力而已。 在这么多的职业中&#xff0c;比程序员轻松的职业可不多&#xff0c;跟程序员的比起来&#xff0c;大部分的职业更苦、更累。 这些问题经常在网上谈论来谈论去&…

永磁电机与普通电机的比较:结构、原理、性能及应用场景分析

创作不易&#xff0c;您的打赏、关注、点赞、收藏和转发是我坚持下去的动力&#xff01; 永磁电机和普通电机在结构、运行原理、性能以及应用场景上都有较大的不同。为了详细回答这些问题&#xff0c;先分别介绍两种电机的基本特点&#xff0c;再分析其异同点及适用场景。 一…

YOLOv11,地瓜RDK X5开发板,TROS端到端140FPS!

YOLOv11 Detect YOLOv11 Detect YOLO介绍性能数据 (简要) RDK X5 & RDK X5 Module 模型下载地址输入输出数据公版处理流程优化处理流程步骤参考 环境、项目准备导出为onnxPTQ方案量化转化使用hb_perf命令对bin模型进行可视化, hrt_model_exec命令检查bin模型的输入输出情况…

录屏+GIF一键生成,2024年费软件大揭秘

视频和 GIF 动图那可都是咱日常生活和工作里少不了的东西。不管是教学的时候用用、直播打个游戏&#xff0c;还是在社交媒体上分享点啥&#xff0c;高质量的录屏和 GIF 制作工具那可老重要了。今天呢&#xff0c;咱就一起来瞅瞅三款在 2024 年特别受推崇的免费录屏和 GIF 制作软…

安装pymssql

一、pycharm 安装pymssql 要在PyCharm中安装pymssql&#xff0c;你需要打开PyCharm的终端或者是Python解释器的交互模式。以下是安装pymssql的步骤&#xff1a; 打开PyCharm。 确保你正在使用的是正确的Python解释器。你可以在PyCharm的右下角看到当前使用的解释器。 点击顶…

SpringBoot实现社区医院数据集成解决方案

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及&#xff0c;互联网成为人们查找信息的重要场所&#xff0c;二十一世纪是信息的时代&#xff0c;所以信息的管理显得特别重要。因此&#xff0c;使用计算机来管理社区医院信息平台的相关信息成为必然。开发…

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-09-30

计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-09-30 目录 文章目录 计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-09-30目录1. Proof Automation with Large Language Models概览&#xff1a;论文研究背景&#xff1a;技术挑战&#xff1a;如何破局…

在Linux中将设备驱动的地址映射到用户空间

本期主题&#xff1a; MMU的简单介绍&#xff0c;以及如何实现设备地址映射到用户空间 往期链接&#xff1a; Linux内核链表零长度数组的使用inline的作用嵌入式C基础——ARRAY_SIZE使用以及踩坑分析Linux下如何操作寄存器&#xff08;用户空间、内核空间方法讲解&#xff09;…

利用SpringBoot构建高效社区医院平台

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

【初阶数据结构】排序——交换排序

目录 前言冒泡排序快速排序Hoare版前后指针版优化三数取中法取随机数做基准值小区间优化 快排非递归版 前言 对于常见的排序算法有以下几种&#xff1a; 下面这节我们来看交换排序算法。 冒泡排序 基本思想&#xff1a; 在待排序序列中&#xff0c;每一次将相邻的元素进行两…

CSS内边距

内边距&#xff08;padding&#xff09;是指元素内容区与边框之间的区域&#xff0c;与外边距不同&#xff0c;内边距会受到背景属性的影响。您可以通过下面的属性来设置元素内边距的尺寸&#xff1a; padding-top&#xff1a;设置元素内容区上方的内边距&#xff1b;padding-…

2024-09-06 深入JavaScript高级语法十六——JS的内存管理和闭包

目录 1、JS内存管理1.1、认识内存管理1.2、JS的内存管理1.3、JS的垃圾回收1.3.1、常见的 GC 算法 - 引用计数1.3.2、常见的 GC 算法&#xfe63;标记清除 2、JS闭包2.1、JS中函数是一等公民2.2、JS中闭包的定义2.3、闭包的访问过程2.4、闭包的内存泄漏2.5、JS闭包内存泄漏案例2…