神经网络学习小记录73——Pytorch CA(Coordinate attention)注意力机制的解析与代码详解

news2025/1/26 15:36:41

神经网络学习小记录73——Pytorch CA(Coordinate attention)注意力机制的解析与代码详解

  • 学习前言
  • 代码下载
  • CA注意力机制的概念与实现
  • 注意力机制的应用

学习前言

CA注意力机制是最近提出的一种注意力机制,全面关注特征层的空间信息和通道信息。
在这里插入图片描述

代码下载

Github源码下载地址为:
https://github.com/bubbliiiing/yolov4-tiny-pytorch

复制该路径到地址栏跳转。

CA注意力机制的概念与实现

在这里插入图片描述
该文章的作者认为现有的注意力机制(如CBAM、SE)在求取通道注意力的时候,通道的处理一般是采用全局最大池化/平均池化,这样会损失掉物体的空间信息。作者期望在引入通道注意力机制的同时,引入空间注意力机制,作者提出的注意力机制将位置信息嵌入到了通道注意力中。

CA注意力的实现如图所示,可以认为分为两个并行阶段:

将输入特征图分别在为宽度和高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[C, H, W],在经过宽方向的平均池化后,获得的特征层shape为[C, H, 1],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[C, 1, W],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[C, 1, H+W],利用卷积+标准化+激活函数获得特征。

之后再次分开为两个并行阶段,再将宽高分开成为:[C, 1, H]和[C, 1, W],之后进行转置。获得两个特征层[C, H, 1]和[C, 1, W]。

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况。乘上原有的特征就是CA注意力机制。

实现的python代码为:

class CA_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(CA_Block, self).__init__()
        
        self.conv_1x1 = nn.Conv2d(in_channels=channel, out_channels=channel//reduction, kernel_size=1, stride=1, bias=False)
 
        self.relu   = nn.ReLU()
        self.bn     = nn.BatchNorm2d(channel//reduction)
 
        self.F_h = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)
        self.F_w = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)
 
        self.sigmoid_h = nn.Sigmoid()
        self.sigmoid_w = nn.Sigmoid()
 
    def forward(self, x):
        _, _, h, w = x.size()
        
        x_h = torch.mean(x, dim = 3, keepdim = True).permute(0, 1, 3, 2)
        x_w = torch.mean(x, dim = 2, keepdim = True)
 
        x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))
 
        x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3)
 
        s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))
        s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))
 
        out = x * s_h.expand_as(x) * s_w.expand_as(x)
        return out

注意力机制的应用

注意力机制是一个即插即用的模块,理论上可以放在任何一个特征层后面,可以放在主干网络,也可以放在加强特征提取网络。

由于放置在主干会导致网络的预训练权重无法使用,本文以YoloV4-tiny为例,将注意力机制应用加强特征提取网络上。

如下图所示,我们在主干网络提取出来的两个有效特征层上增加了注意力机制,同时对上采样后的结果增加了注意力机制
在这里插入图片描述
实现代码如下:

attention_block = [se_block, cbam_block, eca_block, CA_Block]

#---------------------------------------------------#
#   特征层->最后的输出
#---------------------------------------------------#
class YoloBody(nn.Module):
    def __init__(self, anchors_mask, num_classes, phi=0):
        super(YoloBody, self).__init__()
        self.phi            = phi
        self.backbone       = darknet53_tiny(None)

        self.conv_for_P5    = BasicConv(512,256,1)
        self.yolo_headP5    = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)

        self.upsample       = Upsample(256,128)
        self.yolo_headP4    = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)

        if 1 <= self.phi and self.phi <= 3:
            self.feat1_att      = attention_block[self.phi - 1](256)
            self.feat2_att      = attention_block[self.phi - 1](512)
            self.upsample_att   = attention_block[self.phi - 1](128)

    def forward(self, x):
        #---------------------------------------------------#
        #   生成CSPdarknet53_tiny的主干模型
        #   feat1的shape为26,26,256
        #   feat2的shape为13,13,512
        #---------------------------------------------------#
        feat1, feat2 = self.backbone(x)
        if 1 <= self.phi and self.phi <= 3:
            feat1 = self.feat1_att(feat1)
            feat2 = self.feat2_att(feat2)

        # 13,13,512 -> 13,13,256
        P5 = self.conv_for_P5(feat2)
        # 13,13,256 -> 13,13,512 -> 13,13,255
        out0 = self.yolo_headP5(P5) 

        # 13,13,256 -> 13,13,128 -> 26,26,128
        P5_Upsample = self.upsample(P5)
        # 26,26,256 + 26,26,128 -> 26,26,384
        if 1 <= self.phi and self.phi <= 3:
            P5_Upsample = self.upsample_att(P5_Upsample)
        P4 = torch.cat([P5_Upsample,feat1],axis=1)

        # 26,26,384 -> 26,26,256 -> 26,26,255
        out1 = self.yolo_headP4(P4)
        
        return out0, out1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/587934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity随手问题记录(持续更新~~)

目录 1.将摄像机定位到模型实际中心点前边&#xff08;防止有些模型中心点和实际模型中心位置偏移很大的情况&#xff09; 2.获取当鼠标在RawImage上时&#xff0c;鼠标位置对应的图像坐标(简单粗暴方式) 3.设置脚本运行顺序 4.当plugins底下出现dll文件识别不到的情况&#xf…

LeetCode 1110. Delete Nodes And Return Forest【二叉树,DFS,哈希表】中等

本文属于「征服LeetCode」系列文章之一&#xff0c;这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁&#xff0c;本系列将至少持续到刷完所有无锁题之日为止&#xff1b;由于LeetCode还在不断地创建新题&#xff0c;本系列的终止日期可能是永远。在这一系列刷题文章…

SKD240

SKD240 系列智能电力仪表 SKD240 系列智能电力仪表是陕西斯科德智能科技有限公司自主研发、生产的。 产品概述 - 点击了解详情 SKD240采用先进的微处理器和数字信号处理技术&#xff08;内置主芯片采用32位单片机, 采用32位浮点型真有效值处理数据&#xff09;&#xff0c;测量…

图像采集卡的基本原理、应用领域和发展趋势

图像采集卡是一种硬件设备&#xff0c;用于将模拟视频信号转换为数字信号&#xff0c;并将其传输到计算机中进行处理和存储。它通常用于监控、视频会议、医学图像等领域。本文将介绍图像采集卡的基本原理、应用领域和发展趋势。 一、图像采集卡的基本原理 图像采集卡的基本原…

视频里的声音怎么转换成音频?

视频里的声音怎么转换成音频&#xff1f;这样我们就能把视频里的想要的声音在其他音频平台播放或是用于其他视频。其实视频提取音频是一种将视频文件中的音频数据分离出来的技术。该技术可以将视频中的音频转换为不同的格式&#xff0c;让我们可以在无需视频的情况下使用音频文…

Web 应用程序防火墙 (WAF) 相关知识介绍

Web应用程序防火墙 (WAF) 如何工作&#xff1f; Web应用防护系统&#xff08;也称为&#xff1a;网站应用级入侵防御系统。英文&#xff1a;Web Application Firewall&#xff0c;简称&#xff1a;WAF&#xff09;。利用国际上公认的一种说法&#xff1a;Web应用防火墙是通过执…

【Zero to One系列】在WSL linux系统上,使用docker运行Mysql与Nacos,以及如何启动与停止WSL

前期回顾&#xff1a; 【Zero to One系列】window系统安装Linux、docker 1、下载docker-compose 1.下载&#xff1a; curl -SL https://github.com/docker/compose/releases/download/v2.17.2/docker-compose-linux-x86_64 -o /usr/local/bin/docker-compose 2.授予权限&a…

大数据治理入门系列:元数据管理

在介绍数据治理一文中&#xff0c;我们曾用在图书馆找书的例子解释为什么需要进行数据治理。数据治理在某种程度上类似于图书管理。元数据管理作为数据治理的重要一环&#xff0c;也可以进行这种类比。 在图书管理过程中&#xff0c;需要根据相应的制度购买、记录、存放、借还…

【LeetCode: 486. 预测赢家 | 暴力递归=>记忆化搜索=>动态规划 】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

【T+】畅捷通T+设置收入成本配比结转

【问题需求】 收入成本配比原则是指&#xff1a; 取得的销售收入应与为取得该收入所发生的成本相匹配&#xff0c; 即先出库后销货时需要等收到销售发票才能确认成本&#xff0c; 先销货后出库时要先确认虚拟成本。 【解决方案】 重点&#xff1a;业务流程选择【单据立账】的情…

ArcGis系列-java发布空间表为地图服务(map)

1,实现思路 使用java调用cmd命令执行python脚本python环境使用arcgis pro安装目录下的 \ArcGIS\Pro\bin\Python\envs\arcgispro-py3作为地图服务应该可以支持添加样式文件发布表需要用到sde文件,使用java创建sde的代码可以看这里发布表时,先在本地的空项目模板中添加数据库表作…

vue模拟el-table演示插槽用法

vue模拟el-table演示插槽用法 转载自&#xff1a;www.javaman.cn 很多人知道插槽分为三种&#xff0c;但是实际到elementui当中为什么这么用&#xff0c;就一脸懵逼&#xff0c;接下来就跟大家聊一聊插槽在elementui中的应用&#xff0c;并且自己写一个类似el-table的组件 vue…

回归问题里的数学

假设一个简单的案例 投入的广告费越多&#xff0c;广告的点击量就越高&#xff0c;进而带来访问数的增加&#xff0c;不过点击量经常变化&#xff0c;投入同样的广告费未必能带来同样的点击量。根据广告费和实际点击量的对应关系数据&#xff0c;可以将两个变量用下面的图展示…

CASA模型:生态系统NPP及碳源、碳汇模拟、土地利用变化、未来气候变化、空间动态模拟

查看原文>>>生态系统NPP及碳源、碳汇模拟、土地利用变化、未来气候变化、空间动态模拟实践技术应用 目录 第一章 CASA模型介绍&#xff08;讲解案例实践&#xff09; 第二章 CASA初步操作 第三章 CASA数据制备&#xff08;一&#xff09; 第四章 CASA数据制备&am…

4_回归算法(算法原理推导+实践)

文章目录 1 线性回归1.1 定义1.2 题目分析1.3 误差项分析1.4 目标函数推导1.5 线性回归求解1.6 最小二乘法的参数最优解 2 目标函数&#xff08;loss/cost function&#xff09;3 模型效果判断4 机器学习调参5 梯度下降算法5.1 梯度方向5.2 批量梯度下降算法&#xff08;BGD&am…

Spring IOC容器及DI相关概念

文章目录 一、组件、框架、容器的相关概念1.组件2.框架3.容器4.总结 二、IOC与DI简介1.IOC入门案例2.DI入门案例 一、组件、框架、容器的相关概念 1.组件 组件是为了代码的重用而对代码进行隔离封装&#xff0c;组件的呈现方式是单个或多个.class文件&#xff0c;或者打包的.…

Flutter的手势识别功能实现GestureDetector

GestureDetector简介 GestureDetector 是 Flutter 中一个非常常用的小部件&#xff0c;它提供了许多手势识别的功能&#xff0c;包括点击、双击、长按、拖动、缩放等等。 使用方法 GestureDetector 可以包裹其他部件&#xff0c;当用户在这些部件上进行手势操作时&#xff0…

基于SSM的网辩平台的设计与实现

摘 要 线上作为当前信息的重要传播形式之一&#xff0c;线上辩论系统具有显著的方便性&#xff0c;是人类快捷了解辩论信息、资讯等相关途径。但在新时期特殊背景下&#xff0c;随着网辩的进一步优化&#xff0c;辩论赛结合网络平台融合创新强度也随之增强。本文就网辩平台进…

尧泰汉海五城联动,“益”起圆梦!用爱守护成长,助力502名孩子实现心愿

公益的力量让孩子们的梦想被看见。 文具套装、书包、篮球 、益智积木、生日蛋糕......一个个看似小小的心愿&#xff0c;对于城市里的孩子来说是平常不过的礼物&#xff0c;但却成了许多正处于困境孩子的期待。 本次活动由重庆市慈善总会指导&#xff0c;Home尧泰汉海慈善专项…

【项目】ROS下使用乐视深度相机LeTMC-520

本文主要记录如何在ros下使用乐视深度相机。乐视三合一体感摄像头LeTMC-520其实就是奥比中光摄像头&#xff08;Orbbec Astra Pro&#xff09; 系统&#xff1a;Ubuntu20.04 这款相机使用uvc输入彩色信息&#xff0c;需要使用libuvc、libuvc_ros才能在ROS上正常使用彩色功能。…