yolov5、YOLOv7、YOLOv8改进:注意力机制CA

news2024/10/5 23:30:04

论文题目:《Coordinate Attention for Efficient Mobile NetWork Design》
论文地址:  https://arxiv.org/pdf/2103.02907.pdf

请添加图片描述

 

本文中,作者通过将位置信息嵌入到通道注意力中提出了一种新颖的移动网络注意力机制,将其称为“Coordinate Attention”。与通过2维全局池化将特征张量转换为单个特征向量的通道注意力不同,Coordinate注意力将通道注意力分解为两个1维特征编码过程,分别沿2个空间方向聚合特征。这样,可以沿一个空间方向捕获远程依赖关系,同时可以沿另一空间方向保留精确的位置信息。然后将生成的特征图分别编码为一对方向感知和位置敏感的attention map,可以将其互补地应用于输入特征图,以增强关注对象的表示。


不同于通道注意力将输入通过2D全局池化转化为单个特征向量,CoordAttention将通道注意力分解为两个沿着不同方向聚合特征的1D特征编码过程。这样的好处是可以沿着一个空间方向捕获长程依赖,沿着另一个空间方向保留精确的位置信息。然后,将生成的特征图分别编码,形成一对方向感知和位置敏感的特征图,它们可以互补地应用到输入特征图来增强感兴趣的目标的表示
 

下面附上改进代码

YOLOv5改进:

common中加入

 

class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)
class CA(nn.Module):
    # Coordinate Attention for Efficient Mobile Network Design
    '''
        Recent studies on mobile network design have demonstrated the remarkable effectiveness of channel attention (e.g., the Squeeze-and-Excitation attention) for lifting
    model performance, but they generally neglect the positional information, which is important for generating spatially selective attention maps. In this paper, we propose a
    novel attention mechanism for mobile iscyy networks by embedding positional information into channel attention, which
    we call “coordinate attention”. Unlike channel attention
    that transforms a feature tensor to a single feature vector iscyy via 2D global pooling, the coordinate attention factorizes channel attention into two 1D feature encoding 
    processes that aggregate features along the two spatial directions, respectively
    '''
    def __init__(self, inp, oup, reduction=32):
        super(CA, self).__init__()

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        

    def forward(self, x):
        identity = x
        
        n,c,h,w = x.size()
        pool_h = nn.AdaptiveAvgPool2d((h, 1))
        pool_w = nn.AdaptiveAvgPool2d((1, w))
        x_h = pool_h(x)
        x_w = pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y) 
        
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out   

 

在yolo.py中注册

找到parse.model模块   加入下列代码

 

        elif m in [CA]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not outputss
                c2 = make_divisible(c2 * gw, 8)
            args = [c1, c2, *args[1:]]

添加配置文件

# YOLOv5 🚀, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth iscyy multiple
width_multiple: 0.50  # layer channel iscyy multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, CA, [1024]],

   [[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

 到此YOLOv5就改好了

下面介绍yolov7

同样在common中加入以下代码

 

class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)
class CA(nn.Module):
    # Coordinate Attention for Efficient Mobile Network Design
    '''
        Recent studies on mobile network design have demonstrated the remarkable effectiveness of channel attention (e.g., the Squeeze-and-Excitation attention) for lifting
    model performance, but they generally neglect the positional information, which is important for generating spatially selective attention maps. In this paper, we propose a
    novel attention mechanism for mobile iscyy networks by embedding positional information into channel attention, which
    we call “coordinate attention”. Unlike channel attention
    that transforms a feature tensor to a single feature vector iscyy via 2D global pooling, the coordinate attention factorizes channel attention into two 1D feature encoding 
    processes that aggregate features along the two spatial directions, respectively
    '''
    def __init__(self, inp, oup, reduction=32):
        super(CA, self).__init__()

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        

    def forward(self, x):
        identity = x
        
        n,c,h,w = x.size()
        pool_h = nn.AdaptiveAvgPool2d((h, 1))
        pool_w = nn.AdaptiveAvgPool2d((1, w))
        x_h = pool_h(x)
        x_w = pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y) 
        
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out   

 

在yolo.py中注册

找到parse.model模块   加入下列代码

 

        elif m in [CA]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not outputss
                c2 = make_divisible(c2 * gw, 8)
            args = [c1, c2, *args[1:]]

添加配置文件

# YOLOv7 🚀, GPL-3.0 license
# parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel iscyy multiple

# anchors
anchors:
  - [12,16, 19,36, 40,28]  # P3/8
  - [36,75, 76,55, 72,146]  # P4/16
  - [142,110, 192,243, 459,401]  # P5/32

# yolov7 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [32, 3, 1]],  # 0
   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4 
   [-1, 1, C3, [128]], 
   [-1, 1, Conv, [256, 3, 2]], 
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 16-P3/8
   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1]],
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [1024, 1, 1]],          
   [-1, 1, MP, []],
   [-1, 1, Conv, [512, 1, 1]],
   [-3, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, -3], 1, Concat, [1]],
   [-1, 1, C3, [1024]],
   [-1, 1, Conv, [256, 3, 1]],
  ]

# yolov7 head by iscyy
head:
  [[-1, 1, SPPCSPC, [512]],
   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [31, 1, Conv, [256, 1, 1]],
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, C3, [128]],
   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [18, 1, Conv, [128, 1, 1]],
   [[-1, -2], 1, Concat, [1]],
   [-1, 1, C3, [128]],
   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, CA, [128]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3, 44], 1, Concat, [1]],
   [-1, 1, C3, [256]], 
   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]], 
   [[-1, -3, 39], 1, Concat, [1]],
   [-1, 3, C3, [512]],

# 检测头 -----------------------------
   [49, 1, RepConv, [256, 3, 1]],
   [55, 1, RepConv, [512, 3, 1]],
   [61, 1, RepConv, [1024, 3, 1]],

   [[62,63,64], 1, IDetect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

至此v7就配置完成了

v8的配置同v5是一样的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/871567.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

msvcp120.dll丢失的解决方法?分享三种常见解决方法

msvcp120.dll是一个动态链接库文件,它是Microsoft Visual C Redistributable包中的一个组成部分。它是用于支持C编程语言的运行时库文件之一。它包含了许多标准库函数、容器类、算法和其他与C语言相关的功能。这些功能包括内存管理、字符串处理、数学计算、文件操作…

每日一题 206反转链表

题目 给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1]示例 2: 输入:head [1,2] 输出:[2,1]示例 3: …

浅谈XML配置实现逻辑

XML简介 什么是XML? xml是可扩展的标记语言 XML的作用 主要作用: 1.用来保存数据,而且这些数据具有自我描述性 2.他可以作为项目或者模块的配置文件 3.还可以作为网络传输数据的格式(现在JSON为主) 第一个实例 命…

计算机组成原理-笔记-汇总

📚 前言 本人在备考408,王道讲得的确不错,本人之前也看过哈工大【刘宏伟老师】的课,两者对比下来。 王道——更加基础,对小白更加友好哈工大——偏实践偏硬件(会将更多的代码硬件设计) PS&#…

SpringBoot2-Tomcat部署

1.排除内置 Tomcat 在pom.xml文件中的下添加以下代码&#xff0c;用于排除SpringBoot内置Tomcat <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId><exclusions><exclusion&…

什么是CSS中的渐变(gradient)?如何使用CSS创建线性渐变和径向渐变?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 渐变&#xff08;Gradient&#xff09;在CSS中的应用⭐ 线性渐变&#xff08;Linear Gradient&#xff09;语法&#xff1a;示例&#xff1a; ⭐ 径向渐变&#xff08;Radial Gradient&#xff09;语法&#xff1a;示例&#xff1a; ⭐ 写…

Kotlin和Java互操作时的可空性

注&#xff1a;文中demo的kt版本是1.7.10 一、kotlin语言中的可空性设计 在Java语言中的NPE&#xff08;NullPointerException&#xff09;可以说非常常见&#xff0c;而且诟病已久。 kotlin做为后起之秀&#xff0c;在空指针的问题上进行了升级&#xff0c;即&#xff1…

day9 10-牛客67道剑指offer-JZ66、19、20、75、23、76、8、28、77、78

文章目录 1. JZ66 构建乘积数组暴力解法双向遍历 2. JZ19 正则表达式匹配3. JZ20 表示数值的字符串有限状态机遍历 4. JZ75 字符流中第一个不重复的字符5. JZ23 链表中环的入口结点快慢指针哈希表 6. JZ76 删除链表中重复的结点快慢指针三指针如果只保留一个重复结点 7. JZ8 二…

HttpRunner自动化工具之设置代理和请求证书验证

httprunner设置代理&#xff1a; httprunner 库本身没有提供设置代理的接口&#xff0c;但是底层使用了urllib.requests 等库&#xff0c;可以设置HTTP_PROXY 和HTTPS_PROXY 环境变量&#xff0c;常用的网络库会自动识别这些环境变量。 日常调试使用代理&#xff08;如charles…

SQL-每日一题【】

题目 Employees 表&#xff1a; EmployeeUNI 表&#xff1a; 展示每位用户的 唯一标识码&#xff08;unique ID &#xff09;&#xff1b;如果某位员工没有唯一标识码&#xff0c;使用 null 填充即可。 你可以以 任意 顺序返回结果表。 返回结果的格式如下例所示。 示例 1&a…

使用ip2region获取客户端地区

目录 从gitee拉取ip2region.xdb资源文件 写测试类 注意要写对资源路径 本地测试结果 ​编辑 远端测试结果 从gitee拉取ip2region.xdb资源文件 git clone https://gitee.com/lionsoul/ip2region.git 将xdb放入resources资源文件夹 写测试类 private Searcher searcher;GetMap…

funbox3靶场渗透笔记

funbox3靶场渗透笔记 靶机地址 https://download.vulnhub.com/funbox/Funbox3.ova 信息收集 fscan找主机ip192.168.177.199 .\fscan64.exe -h 192.168.177.0/24___ _/ _ \ ___ ___ _ __ __ _ ___| | __/ /_\/____/ __|/ __| __/ _ |/ …

数据结构-栈的实现(C语言版)

前言 栈是一种特殊的线性表&#xff0c;只允许在固定的一端进行插入和删除的操作&#xff0c;进行数据插入和删除的一端叫做栈顶&#xff0c;另一端叫做栈底。 栈中的数据元素遵循后进先出的的原则。 目录 1.压栈和出栈 2. 栈的实现 3.测试代码 1.压栈和出栈 压栈&#xff…

最小生成树 — Prim算法

同Kruskal算法一样&#xff0c;Prim算法也是最小生成树的算法&#xff0c;但与Kruskal算法有较大的差别。 Prim算法整体是通过“解锁” “选中”的方式&#xff0c;点 -> 边 -> 点 -> 边。 因为是最小生成树&#xff0c;所以针对的也是无向图&#xff0c;所以可以随意…

带你了解SpringBoot支持的复杂参数--自定义对象参数-自动封装

&#x1f600;前言 本篇博文是关于SpringBoot 在响应客户端请求时支持的复杂参数和自定义对象参数&#xff0c;希望您能够喜欢&#x1f60a; &#x1f3e0;个人主页&#xff1a;晨犀主页 &#x1f9d1;个人简介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章…

Docker知识(详细笔记)

概览图 文章目录 概览图docker 知识速查1. 初识 Docker1.1 概念1.2 特点1.3 架构1.4 应用场景1.5 安装 Docker1.6 配置 Docker 镜像 2. Docker 命令2.1 Docker 进程相关命令2.2 Docker 镜像相关命令2.3 Docker 容器相关命令 3. Docker 容器的数据卷3.1 数据卷概念及作用3.1.1 概…

小米平板6Max14即将发布:自研G1 电池管理芯片,支持33W反向快充

明天晚上7点&#xff08;8 月 14 日&#xff09;&#xff0c;雷军将进行年度演讲&#xff0c;重点探讨“成长”主题。与此同时&#xff0c;小米将推出一系列全新产品&#xff0c;其中包括备受瞩目的小米MIX Fold 3折叠屏手机和小米平板6 Max 14。近期&#xff0c;小米官方一直在…

java 9新特效解读(4)

目录 InputStream 加强 增强的 Stream API takeWhile()的使用 dropWhile()的使用 ofNullable()的使用 iterate()重载的使用 Optional获取Stream的方法 Optional类中stream()的使用 javascript引擎升级&#xff1a;Nashorn InputStream 加强 InputStream 终于有了…

supervisor因为依赖安装失败的解决方法

安装FEATA时报错情况 下列软件包有未满足的依赖关系&#xff1a;supervisor : 依赖: python-pkg-resources 但是它将不会被安装依赖: python-meld3 但是它将不会被安装依赖: python:any (< 2.8)依赖: python:any (> 2.7.5-5~) E: 无法修正错误&#xff0c;因为您要求某些…

图书管理系统-Java

目录 一、图书管理系统样式 二、图书管理系统具体实现 2.1 book包 2.2 user包 2.3 Main类 2.4 operation包 一、图书管理系统样式 管理系统效果图&#xff1a; 通过效果图我们可以看到&#xff0c;首先要实现登录&#xff0c;得到用户姓名和身份&#xff0c;依据用户的身份不…