YOLOv5改进算法之添加CA注意力机制模块

news2024/9/28 15:19:26

目录

1.CA注意力机制

2.YOLOv5添加注意力机制

送书活动


1.CA注意力机制

CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系。如下图:

1. 输入特征: CA 注意力机制的输入通常是一个特征图,它通常是卷积神经网络(CNN)中的某一层的输出,具有以下形状:[C, H, W],其中:

  • C 是通道数,表示特征图中的不同特征通道。
  • H 是高度,表示特征图的垂直维度。
  • W 是宽度,表示特征图的水平维度。

2. 全局平均池化: CA 注意力机制首先对输入特征图进行两次全局平均池化,一次在宽度方向上,一次在高度方向上。这两次操作分别得到两个特征映射:

  • 在宽度方向上的平均池化得到特征映射 [C, H, 1]
  • 在高度方向上的平均池化得到特征映射 [C, 1, W]

这两个特征映射分别捕捉了在宽度和高度方向上的全局特征。

3. 合并宽高特征: 将上述两个特征映射合并,通常通过简单的堆叠操作,得到一个新的特征层,形状为 [C, 1, H + W],其中 H + W 表示在宽度和高度两个方向上的维度合并在一起。

4. 卷积+标准化+激活函数: 对合并后的特征层进行卷积操作,通常是 1x1 卷积,以捕捉宽度和高度维度之间的关系。然后,通常会应用标准化(如批量标准化)和激活函数(如ReLU)来进一步处理特征,得到一个更加丰富的表示。

5. 再次分开: 分别从上述特征层中分离出宽度和高度方向的特征:

  • 一个分支得到特征层 [C, 1, H]
  • 另一个分支得到特征层 [C, 1, W]

6. 转置: 对分开的两个特征层进行转置操作,以恢复宽度和高度的维度,得到两个特征层分别为 [C, H, 1][C, 1, W]

7. 通道调整和 Sigmoid: 对两个分开的特征层分别应用 1x1 卷积,以调整通道数,使其适应注意力计算。然后,应用 Sigmoid 激活函数,得到在宽度和高度维度上的注意力分数。这些分数用于指示不同位置的重要性。

8. 应用注意力: 将原始输入特征图与宽度和高度方向上的注意力分数相乘,得到 CA 注意力机制的输出。

2.YOLOv5添加注意力机制

在models/common.py文件中增加以下模块:

import torch
import torch.nn as nn
import torch.nn.functional as F


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6


class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)


class CoordAtt(nn.Module):
    def __init__(self, inp, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()

        self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        identity = x

        n, c, h, w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)

        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out

在models/yolo.py文件下里的parse_model函数将类名加入进去,如下图:

 创建添加CA模块的YOLOv5的yaml配置文件如下:

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Focus, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, CoordAtt, []],
   [-1, 1, Conv, [512, 3, 2]],  # 6-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, CoordAtt, []],
   [-1, 1, Conv, [1024, 3, 2]],  # 9-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, CoordAtt, []],
   [-1, 1, SPPF, [1024, 5]],  # 12
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 8], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 5], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 17], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 13], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[20, 23, 26], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

送书活动

用ChatGPT轻松玩转机器学习与深度学习

突破传统学习束缚,借助ChatGPT的神奇力量,解锁AI无限可能!

关键点

(1)利用ChatGPT,轻松理解机器学习和深度学习的概念和技术。

(2)提供实用经验和技巧,更好地掌握机器学习和深度学习的基本原理和方法。

(3)系统全面、易于理解,不需要过多的数学背景,只需掌握基本的编程知识即可上手。

内容简介

随着机器学习和深度学习技术的不断发展和进步,它们的复杂性也在不断增强。对于初学者来说,学习这两个领域可能会遇到许多难题和挑战,如理论知识的缺乏、数据处理的困难、算法选择的不确定性等。此时,ChatGPT可以提供强有力的帮助。利用ChatGPT,读者可以更轻松地理解机器学习和深度学习的概念和技术,并解决学习过程中遇到的各种问题和疑惑。此外,ChatGPT还可以为读者提供更多的实用经验和技巧,帮助他们更好地掌握机器学习和深度学习的基本原理和方法。本书主要内容包括探索性数据分析、有监督学习(线性回归、SVM、决策树等)、无监督学习(降维、聚类等),以及深度学习的基础原理和应用等。

本书旨在为广大读者提供一个系统全面、易于理解的机器学习和深度学习入门教程。不需要过多的数学背景,只需掌握基本的编程知识即可轻松上手。

作者简介

段小手,曾供职于百度、敦煌网、慧聪网、方正集团等知名IT企业。有多年的科技项目管理及开发经验。负责的项目曾获得“国家发改委电子商务示范项目”“中关村现代服务业试点项目”“北京市信息化基础设施提升专项”“北京市外贸公共服务平台”等多项政策支持。著有《深入浅出Python机器学习》《深入浅出Python量化交易实战》等著作,在与云南省公安厅合作期间,使用机器学习算法有效将某类案件发案率大幅降低。

当当网链接:《用ChatGPT轻松玩转机器学习与深度学习 突破传统学习束缚,借助ChatGPT的神奇力量,解锁AI无限可能 段小手》(段小手)【简介_书评_在线阅读】 - 当当图书

京东的链接:京东安全

 关注博主、点赞、收藏、

评论区评论 “ 人生苦短,我爱python”

  即可参与送书活动!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/984728.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

火热的低代码,是时候系统的来学一学了!

一、前言 低代码诞生至今,大家各抒己见,也不乏有针锋相对的意思。古时的治国之术有百家争鸣,如今的低代码也有“诸子论道”,这本质上是一件有助于推动低代码发展的事情。 业内的朋友们一定知道,关于低代码的热点不止发…

数字内容风控行业首本白皮书正式发布,打造长效安全的数字内容生态

数字内容包含文本、图片、视频等多种形式,起源于计算机问世,并随着互联网、智能手机快速发展,如今,数字内容已经成为个人及企业建立形象、传播价值的必要途径。 2022年起,随着ChatGPT的火爆出圈,AI大模型强…

Kotlin+MVVM 构建todo App 应用

作者:易科 项目介绍 使用KotlinMVVM实现的todo app,功能界面参考微软的Todo软件(只实现了核心功能,部分功能未实现)。 功能模块介绍 项目模块:添加/删除项目,项目负责管理todo任务任务模块&a…

执行上下文-通俗易懂版

(1) js引擎执行代码时候/前,在堆内存创建一个全局对象,该对象 所有的作用域(scope)都可以访问,里面会包含Date、Array、String、Number、setTimeout、setInterval等等,其中还有一个window属性指向自己 (2…

C++数组类的自实现,使其可以保存学生成绩,并进行降序排列

类的封装 #ifndef ARRAY_H #define ARRAY_Hclass DoubArray { private:int m_length;double* m_pointer;public:DoubArray(int len);DoubArray(const DoubArray& obj);int length();bool get(int index, double& value);bool set(int index, double value);void sort(…

尚硅谷大数据项目《在线教育之离线数仓》笔记007

视频地址:尚硅谷大数据项目《在线教育之离线数仓》_哔哩哔哩_bilibili 目录 第12章 报表数据导出 P112 01、创建数据表 02、修改datax的jar包 03、ads_traffic_stats_by_source.json文件 P113 P114 P115 P116 P117 P118 P119 P120 P121 P122【122_在…

Hadoop:HDFS--分布式文件存储系统

目录 HDFS的基础架构 VMware虚拟机部署HDFS集群 HDFS集群启停命令 HDFS Shell操作 hadoop 命令体系: 创建文件夹 -mkdir 查看目录内容 -ls 上传文件到hdfs -put 查看HDFS文件内容 -cat 下载HDFS文件 -get 复制HDFS文件 -cp 追加数据到HDFS文件中 -appendTo…

第 3 章 栈和队列(汉诺塔问题递归解法)

1. 背景说明 假设有 3 个分别命名为 X、Y 和 Z 的塔座,在塔座 X 上插有 n 个直径大小各不相同、依小到大编号为 1, 2,…,n 的圆盘。 现要求将 X 轴上的 n 个圆盘移至塔座 Z 上并仍按同样顺序叠排,圆盘移动时必须遵循下列规则&…

伦敦金的走势高低的规律

伦敦金市场是一个流动性很强的市场,其价格走势会在诸多因素的影响下,出现反复的上下波动,如果投资者能够在这些高低走势中找到一定的规律,在相对有利的时机入场和离场,就能够通过不断的交易,累积大量的财富…

浏览器渲染原理及流程

浏览器主要组成与浏览器线程 浏览器组件 浏览器大体上由以下几个组件组成,各个浏览器可能有一点不同。 界面控件 – 包括地址栏,前进后退,书签菜单等窗口上除了网页显示区域以外的部分浏览器引擎 – 查询与操作渲染引擎的接口渲染引擎 – …

记录vite下使用require报错和解决办法

前情提要 我们现在项目用的是vite4react18开发的项目、但是最近公司有个睿智的人让我把webpack中的bpmn组件迁移过来、结果就出现问题啦:因为webpack是commonjs规范、但是vite不是、好像是es吧、可想而知各种报错 废话不多说啦 直接上代码: 注释是之前c…

生成式AI爆发,安全问题如何解决?

在生成式AI浪潮下,如何为行业用户提供符合实际应用场景需求的生成式AI服务,是行业数字化转型的下一个重点。《亚马逊云科技AIGC加速企业创新指南》白皮书指出,AIGC在游戏、零售电商、金融、媒体娱乐、医疗健康等行业都有典型应用场景。作为 A…

冠达管理:股票退市整理期?

近些年来,随着我国股市的发展,股票市场的出资者逐渐增多。但在出资过程中,退市股票的问题也成为了备受重视的论题。那么,股票退市收拾期到底是什么?如何应对退市股票? 首要,什么是股票退市收拾…

设备管理系统有什么功能?它有什么用?

设备管理系统已成为现代化大规模研究所,信息化管理体系建设中最为关键的要素。随着工业设备的机械化、自动化、大型化、高速化以及复杂化等因素不断叠加,设备设施对于工业生产的作用和影响越来越大,其各项制度和流程也涉及面广、内容繁杂。  …

中企绕道突破封锁,防不胜防 | 百能云芯

韩国的财经媒体Business Korea最新报道指出,尽管美方在《通胀削减法案》(IRA)的补贴中排除了中国,但中国企业正通过多种方式积极应对美国在半导体和电动汽车电池领域的封锁,这包括建立合资企业、设立生产基地以及开展技…

STDF-Viewer 解析工具说明

一、简介 1. 概述 STDF(Standard Test Data Format)(标准测试数据格式)是半导体测试行业的最主要的数据格式,包含了summary信息和所有测试项的测试结果;是半导体行业芯片测试数据的存储规范。 在半导体行业…

城市排水监测方案(dtu终端配合工业路由器精准监测)

台风季节,暴雨容易导致城市内涝积水。为有效监测排水状况,预警和防控积水灾害,星创易联推出智慧排水监测解决方案。 解决方案采用星创易联DTU300作为水位数据采集终端,它可挂载在河道及排水井等地点,实时监测水位变化,一旦超过预警阈值,立即通过4G网络传输报警信息,实现对水位…

使用Kmeans进行图像聚类

Kmeans可以用于与发现聚类相关的其他任务 介绍 聚类是一种无监督机器学习技术。这意味着您的数据集没有标签,即与解释变量发现的模式关联的目标变量。 无监督学习是找到看似相似的模式并将它们放入同一个桶中的过程。 最常用的无监督学习算法之一是Kmeans&#xff…

冠达管理:紧盯必要性 追问合理性 再融资问询透露监管新动向

在“活泼资本市场,提振出资者决心”一系列办法落地之后,再融资市场整体已明确收紧,但审阅尺度、相关细则还有待进一步观察。有接受采访的投行人士指出,现在存量项目仍在持续推进,监管审阅要点已在问询环节有较为充沛的…

2.7 PE结构:重定位表详细解析

重定位表(Relocation Table)是Windows PE可执行文件中的一部分,主要记录了与地址相关的信息,它在程序加载和运行时被用来修改程序代码中的地址的值,因为程序在不同的内存地址中加载时,程序中使用到的地址也…