YOLOv7改进:GAMAttention注意力机制

news2024/11/25 11:28:17

1.背景介绍
为了提高各种计算机视觉任务的性能,人们研究了各种注意机制。然而,以往的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局调度机制,通过减少信息缩减和放大全局交互表示来提高深度神经网络的性能。我们沿着卷积空间注意子模块引入了用于通道注意的多层感知器3D置换。

论文题目:Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions
论文地址:https://paperswithcode.com/paper/global-attention-mechanism-retain-information

GAMAttention注意力机制原理图

对于ImageNet-1K,我们将图像预处理为224×224(He et al.[2016])。我们包括ResNet18和ResNet50(He et al.[2016]),以验证不同网络深度的方法推广。对于ResNet50,我们将其与群卷积进行了比较,以防止参数显著增加。我们将起始学习率设置为0.1,并每隔30个阶段降低一次。我们总共使用90个训练时段。在空间注意子模块中,我们将第一个块的第一步从1切换到2,以匹配特征的大小。为了进行公平比较,CBAM保留了其他设置,包括在空间注意子模块中使用最大池。3 MobileNet V2是用于图像分类的最高效的轻量级模型之一。我们对MobileNet V2使用相同的ResNet设置,只是使用了0.045的初始学习率和4×10的权重衰减−5.对ImageNet-1K的评估如表所示。它表明GAM可以稳定地提高不同神经架构的性能。尤其是对于ResNet18,GAM以更少的参数和更好的效率优于ABN。

相关实验结果

对ImageNet-1K的评估如表2所示,它表明GAM可以稳定地提高不同神经体系结构的性能。特别是,对于ResNet18,GAM的性能优于ABN,参数更少,效率更高。

 为了更好地理解空间注意和通道注意分别对消融的贡献,我们通过开启和关闭一种方式进行了消融研究。例如,ch表示空间注意力被关闭,而频道注意力被打开。SP表示通道关注已关闭,空间关注已打开。结果如表3所示。我们可以在两个开关实验中观察到性能的提高。结果表明,空间关注度和通道关注度对性能增益均有贡献。请注意,它们的组合进一步提高了性能。

 将GAM与CBAM在使用和不使用ResNet18最大池化的情况下进行比较。表4显示了结果。可以观察到,在这两种情况下,我们的方法都优于CBAM。

2.YOLOv7改进方法

2.1增加以下GAMAttention.yaml文件

# YOLOv7 🚀, GPL-3.0 license
# parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# anchors
anchors:
  - [12,16, 19,36, 40,28]  # P3/8
  - [36,75, 76,55, 72,146]  # P4/16
  - [142,110, 192,243, 459,401]  # P5/32

# yolov7 backbone by yoloair
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [32, 3, 1]],  # 0
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4 
   [-1, 1, CNeB, [128]], 
   [-1, 1, Conv, [256, 3, 2]], 
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 16-P3/8
   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [1024, 1, 1]],          
   [-1, 1, MP, []],
   [-1, 1, Conv, [512, 1, 1]],
   [-3, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, -3], 1, Concat, [1]],
   [-1, 1, CNeB, [1024]],
   [-1, 1, Conv, [256, 3, 1]],
  ]

# yolov7 head by yoloair
head:
  [[-1, 1, SPPCSPC, [512]],
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [31, 1, Conv, [256, 1, 1]],
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, C3C2, [128]],
   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [18, 1, Conv, [128, 1, 1]],
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, C3C2, [128]],
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, GAMAttention, [128]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3, 44], 1, Concat, [1]],
   [-1, 1, C3C2, [256]], 
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]], 
   [[-1, -3, 39], 1, Concat, [1]],
   [-1, 3, C3C2, [512]],

# 检测头 -----------------------------
   [49, 1, RepConv, [256, 3, 1]],
   [55, 1, RepConv, [512, 3, 1]],
   [61, 1, RepConv, [1024, 3, 1]],

   [[62,63,64], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

2.2common.py配置

./models/common.py文件增加以下模块


import numpy as np
import torch
from torch import nn
from torch.nn import init

class GAMAttention(nn.Module):
       #https://paperswithcode.com/paper/global-attention-mechanism-retain-information
    def __init__(self, c1, c2, group=True,rate=4):
        super(GAMAttention, self).__init__()
        
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
        self.spatial_attention = nn.Sequential(
            nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3), 
            nn.BatchNorm2d(int(c1 /rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3), 
            nn.BatchNorm2d(c2)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        x = x * x_channel_att
 
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle 
        out = x * x_spatial_att
        return out  

def channel_shuffle(x, groups=2):
        B, C, H, W = x.size()
        out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
        out=out.view(B, C, H, W) 
        return out

2.3yolo.py配置

在 models/yolo.py文件夹下

  • 定位到parse_model函数中
  • for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):内部
  • 对应位置 下方只需要新增以下代码
elif m is GAMAttention:
    c1, c2 = ch[f], args[0]
    if c2 != no:
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, c2, *args[1:]]

修改完成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1049658.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

风光储一体化能源中心 | 数字孪生智慧能源

自“双碳”目标提出以来,我国能源产业不断朝着清洁低碳化、绿色化的方向发展。其中,风能、太阳能等可再生能源在促进全球能源可持续发展、共建清洁美丽世界中被寄予厚望。风能、太阳能具有波动性、间歇性、随机性等特点,主要通过转化为电能再…

This dependency was not found: vxe-table/lib/vxe-table in ./src/main.js

描述 使用时 安装 npm install xe-utils vxe-table 引入 import Vue from vue import xe-utils import VXETable from vxe-table import vxe-table/lib/style.css vxe-table是一个基于 vue 的 PC 端表格组件, 支持增删改查、虚拟滚动、懒加载、快捷菜单、数据校验…

微信公众平台怎么添加秒杀活动

微信公众平台是一个非常有用的工具,它可以帮助企业或个人建立自己的品牌形象,增加用户粘性,提高销售业绩等等。在微信公众平台上添加秒杀活动为主题可以吸引更多的用户关注,促进销售,提高品牌知名度等。下面我们将介绍…

uni-app 实现凸起的 tabbar 底部导航栏

效果图 在 pages.json 中设置隐藏自带的 tabbar 导航栏 "custom": true, // 开启自定义tabBar(不填每次原来的tabbar在重新加载时都回闪现) 新建一个 custom-tabbar.vue 自定义组件页面 custom-tabbar.vue <!-- 自定义底部导航栏 --> <template><v…

图像直方图的基础知识

直方图的概念 图像直方图反映了图像中的灰度分布规律。它描述每个灰度级具有的像元个数&#xff0c;但不包含这些像元在图像中的位置信息。任何一幅特定的图像都有唯一的直方图与之对应&#xff0c;但不同的图像可以有相同的直方图。如果一幅图像有两个不相连的区域组成&#…

ARM Linux DIY(十四)摄像头捕获画面显示到屏幕

前言 前期已经调试好了摄像头和屏幕&#xff0c;今天我们将摄像头捕获的画面显示到屏幕上。 原理 摄像头对应 /dev/video0&#xff0c;屏幕对应 /dev/fb0&#xff0c;所以我们只要写一个应用程序&#xff0c;读取 video0 写入到 fb0 就可以了。 应用程序代码实例 camera_d…

[PyTorch][chapter 55][WGAN]

前言&#xff1a; 前面讲到GAN 在训练生成器的时候&#xff0c;如果当前的Pr 和 Pg 的分布不重叠场景下&#xff1a; JS散度为一个固定值&#xff0c;梯度为0&#xff0c;导致无法更新生成器G WGAN的全称是WassersteinGAN&#xff0c;它提出了用Wasserstein距离&#xff08;也…

第2章 算法

2.1 开场白 2.2 数据结构与算法之间的关系 在“数据结构”课程中&#xff0c;就算谈到算法&#xff0c;也是为了帮助理解好数据结构&#xff0c;并不会详细谈及算法的方方面面。 2.3 两种算法的比较 2.4 算法的定义 算法是解决特定问题求解步骤的描述&#xff0c;在计算机…

【AI视野·今日Robot 机器人论文速览 第四十一期】Tue, 26 Sep 2023

AI视野今日CS.Robotics 机器人学论文速览 Tue, 26 Sep 2023 Totally 73 papers &#x1f449;上期速览✈更多精彩请移步主页 Daily Robotics Papers Extreme Parkour with Legged Robots Authors Xuxin Cheng, Kexin Shi, Ananye Agarwal, Deepak Pathak人类可以通过以高度动态…

华为智能企业远程办公安全解决方案(1)

华为智能企业远程办公安全解决方案&#xff08;1&#xff09; 课程地址方案背景需求分析企业远程办公业务概述企业远程办公安全风险分析企业远程办公环境搭建需求分析 方案设计组网架构设备选型方案亮点 课程地址 本方案相关课程资源已在华为O3社区发布&#xff0c;可按照以下…

shell脚本学习笔记

shell脚本重点记录 判断文件或者文件夹是否存在 if [ ! -d "log" ];thenchmod 707 $file1一个文件的权限包括读取、写入、执行&#xff0c;权限范围包含所有者、所属组、其他人&#xff0c;可以通过数字或者字母描述一个文件的权限&#xff1a;读取权限对应r或4&a…

高速,低延,任意频丨庚顿新一代实时数据库鼎力支撑电力装备服务数字化

产品同质化日趋严重以及市场需求不断迭代等内外形势下&#xff0c;电力装备制造业自身赢利需求不断增涨&#xff0c;电力等下游产业数字化发展形成倒逼之态&#xff0c;作为国家未来发展的高端装备创新工程主战场&#xff0c;电力装备智能化以及服务型转型升级已经成为装备制造…

在nodejs中如何防止ssrf攻击

在nodejs中如何防止ssrf攻击 什么是ssrf攻击 ssrf&#xff08;server-side request forgery&#xff09;是服务器端请求伪造&#xff0c;指攻击者能够从易受攻击的Web应用程序发送精心设计的请求的对其他网站进行攻击。(利用一个可发起网络请求的服务当作跳板来攻击其他服务)…

mac docker部署hadoop集群

1. 安装docker 确保电脑已经安装docker docker安装过程可自行查找资料&#xff0c;mac下docker可以使用brew命令安装 安装之后&#xff0c;查看docker版本&#xff0c;确认安装成功 docker -v2. 下载jdk 最好下载jdk-8&#xff0c;jdk的版本过高可能hadoop2.x不支持jdk-8的下…

掌握 JavaScript 数组方法:了解如何操作和优化数组

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

Android Logcat 命令行工具

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、商业变现、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、日常用法3.1 面板介绍3.2 日志过滤…

零代码编程:用ChatGPT批量将多个文件夹中的视频转为音频

有多个文件夹中的 视频&#xff0c;都要批量转换成音频格式。 转换完成后要删除视频。虽然现在已经有很多格式转换软件可以实现这个功能&#xff0c;但是需要一个个文件夹的操作&#xff0c;还要手动去删除视频。用ChatGPT来写一个批量自动操作程序吧&#xff1a; 输入提示词如…

获取el-select选中的下标

accountZbList:[ ]:下拉列表已通过接口获取数据 <el-row><el-col :span"12"><el-form-item label"账簙" prop"accountTook" class"itemzb"><el-select v-model"tableForm.accountTook" placeholder&…

软件测试基础学习

注意&#xff1a; 各位同学们&#xff0c;今年本人求职目前遇到的情况大体是这样了&#xff0c;开发太卷&#xff0c;学历高的话优势非常的大&#xff0c;公司会根据实际情况考虑是否值得培养&#xff08;哪怕技术差一点&#xff09;&#xff1b;学历稍微低一些但是技术熟练的…

改进的最大内切圆算法求裂缝轮廓宽度

前段时间我将网上最大内切圆算法进行了代码的整理&#xff0c;原先博主上传的代码稍微有点乱&#xff0c;可能也是它自己使用&#xff0c;大家可以看这篇整理好的&#xff1a;最大内切圆算法计算裂缝宽度。 最大内切圆算法详解 一个圆与给定的多边形或曲线的每一条边或曲线都…