神经网络学习小记录78——Keras CA(Coordinate attention)注意力机制的解析与代码详解

news2025/1/16 3:41:14

神经网络学习小记录78——Keras CA(Coordinate attention)注意力机制的解析与代码详解

  • 学习前言
  • 代码下载
  • CA注意力机制的概念与实现
  • 注意力机制的应用

学习前言

CA注意力机制是最近提出的一种注意力机制,全面关注特征层的空间信息和通道信息。
在这里插入图片描述

代码下载

Github源码下载地址为:
https://github.com/bubbliiiing/yolov4-tiny-keras

复制该路径到地址栏跳转。

CA注意力机制的概念与实现

请添加图片描述
该文章的作者认为现有的注意力机制(如CBAM、SE)在求取通道注意力的时候,通道的处理一般是采用全局最大池化/平均池化,这样会损失掉物体的空间信息。作者期望在引入通道注意力机制的同时,引入空间注意力机制,作者提出的注意力机制将位置信息嵌入到了通道注意力中。

CA注意力的实现如图所示,可以认为分为两个并行阶段:

将输入特征图分别在为宽度和高度两个方向分别进行全局平均池化,分别获得在宽度和高度两个方向的特征图。假设输入进来的特征层的形状为[H, W, C],在经过宽方向的平均池化后,获得的特征层shape为[H, 1, C],此时我们将特征映射到了高维度上;在经过高方向的平均池化后,获得的特征层shape为[1, W, C],此时我们将特征映射到了宽维度上。

然后将两个并行阶段合并,将宽和高转置到同一个维度,然后进行堆叠,将宽高特征合并在一起,此时我们获得的特征层为:[1, H+W, C],利用卷积+标准化+激活函数获得特征。需要注意的是,这里的卷积通道数一般会小一点,做一个缩放,可以减少参数量。卷积后的特征层的shape为[1, H+W, C/r],其中r为缩放系数。

之后再次分开为两个并行阶段,再将宽高分开成为:[1, H, C/r]和[1, W, C/r],之后进行转置。获得两个特征层[H, 1, C/r]和[1, W, C/r]。

然后利用1x1卷积调整通道数后取sigmoid获得宽高维度上的注意力情况,前者在宽上拓展一下,后者在高上拓展一下,然后一起乘上原有的特征就是CA注意力机制。

实现的python代码为:

def ca_block(input_feature, ratio=16, name=""):
	channel = input_feature._keras_shape[-1]
	h		= input_feature._keras_shape[1]
	w		= input_feature._keras_shape[2]
 
	x_h = Lambda(lambda x: K.mean(x, axis=2, keepdims=True))(input_feature)
	x_h = Lambda(lambda x: K.permute_dimensions(x, [0, 2, 1, 3]))(x_h)
	x_w = Lambda(lambda x: K.max(x, axis=1, keepdims=True))(input_feature)
	
	x_cat_conv_relu = Concatenate(axis=2)([x_w, x_h])
	x_cat_conv_relu = Conv2D(channel // ratio, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv1_"+str(name))(x_cat_conv_relu)
	x_cat_conv_relu = BatchNormalization(name = "ca_block_bn_"+str(name))(x_cat_conv_relu)
	x_cat_conv_relu = Activation('relu')(x_cat_conv_relu)
 
	x_cat_conv_split_h, x_cat_conv_split_w = Lambda(lambda x: tf.split(x, num_or_size_splits=[h, w], axis=2))(x_cat_conv_relu)
	x_cat_conv_split_h = Lambda(lambda x: K.permute_dimensions(x, [0, 2, 1, 3]))(x_cat_conv_split_h)
	x_cat_conv_split_h = Conv2D(channel, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv2_"+str(name))(x_cat_conv_split_h)
	x_cat_conv_split_h = Activation('sigmoid')(x_cat_conv_split_h)
 
	x_cat_conv_split_w = Conv2D(channel, kernel_size=1, strides=1, use_bias=False, name = "ca_block_conv3_"+str(name))(x_cat_conv_split_w)
	x_cat_conv_split_w = Activation('sigmoid')(x_cat_conv_split_w)
 
	output = multiply([input_feature, x_cat_conv_split_h])
	output = multiply([output, x_cat_conv_split_w])
	return output

注意力机制的应用

注意力机制是一个即插即用的模块,理论上可以放在任何一个特征层后面,可以放在主干网络,也可以放在加强特征提取网络。

由于放置在主干会导致网络的预训练权重无法使用,本文以YoloV4-tiny为例,将注意力机制应用加强特征提取网络上。

如下图所示,我们在主干网络提取出来的两个有效特征层上增加了注意力机制,同时对上采样后的结果增加了注意力机制
在这里插入图片描述
实现代码如下:

attention = [se_block, cbam_block, eca_block, ca_block]

#---------------------------------------------------#
#   特征层->最后的输出
#---------------------------------------------------#
def yolo_body(input_shape, anchors_mask, num_classes, phi = 0):
    inputs = Input(input_shape)
    #---------------------------------------------------#
    #   生成CSPdarknet53_tiny的主干模型
    #   feat1的shape为26,26,256
    #   feat2的shape为13,13,512
    #---------------------------------------------------#
    feat1, feat2 = darknet_body(inputs)
    if phi >= 1 and phi <= 4:
        feat1 = attention[phi - 1](feat1, name='feat1')
        feat2 = attention[phi - 1](feat2, name='feat2')

    # 13,13,512 -> 13,13,256
    P5 = DarknetConv2D_BN_Leaky(256, (1,1))(feat2)
    # 13,13,256 -> 13,13,512 -> 13,13,255
    P5_output = DarknetConv2D_BN_Leaky(512, (3,3))(P5)
    P5_output = DarknetConv2D(len(anchors_mask[0]) * (num_classes+5), (1,1))(P5_output)
    
    # 13,13,256 -> 13,13,128 -> 26,26,128
    P5_upsample = compose(DarknetConv2D_BN_Leaky(128, (1,1)), UpSampling2D(2))(P5)
    if phi >= 1 and phi <= 4:
        P5_upsample = attention[phi - 1](P5_upsample, name='P5_upsample')

    # 26,26,256 + 26,26,128 -> 26,26,384
    P4 = Concatenate()([P5_upsample, feat1])
    
    # 26,26,384 -> 26,26,256 -> 26,26,255
    P4_output = DarknetConv2D_BN_Leaky(256, (3,3))(P4)
    P4_output = DarknetConv2D(len(anchors_mask[1]) * (num_classes+5), (1,1))(P4_output)
    
    return Model(inputs, [P5_output, P4_output])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1458958.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

花费200元,我用全志H616和雪糕棒手搓了一台可UI交互的视觉循迹小车

常见的视觉循迹小车都具备有路径识别、轨迹跟踪、转向避障、自主决策等基本功能&#xff0c;如果不采用红外避障的方案&#xff0c;那么想要完全满足以上这些功能&#xff0c;摄像头、电机、传感器这类关键部件缺一不可&#xff0c;由此一来小车成本也就难以控制了。 但如果&a…

欢迎来到IT时代----盘点曾经爆火全网的计算机电影

计算机专业必看的几部电影 计算机专业必看的几部电影&#xff0c;就像一场精彩的编程盛宴&#xff01;《黑客帝国》让你穿越虚拟世界&#xff0c;感受高科技的魅力&#xff1b;《社交网络》揭示了互联网巨头的创业之路&#xff0c;《源代码》带你穿越时间解救世界&#xff0c;这…

LabVIEW声速测定实验数据处理

LabVIEW声速测定实验数据处理 介绍了一个基于LabVIEW的声速测定实验数据处理系统的应用。该系统利用LabVIEW的强大数据处理和分析能力&#xff0c;通过设计友好的用户界面和高效的算法&#xff0c;有效提高了声速测定实验的数据处理效率和准确性。通过这个案例&#xff0c;可以…

2024最新uniapp基础腾讯IM

即时通信 IM 快速入门&#xff08;uniapp vue2/vue3&#xff09;-快速入门-文档中心-腾讯云

行人重识别综述

Deep Learning for Person Re-identification: A Survey and Outlook 论文地址https://arxiv.org/pdf/2001.04193 1. 摘要 we categorize it into the closed-world and open-world settings. closed-world&#xff1a;学术环境下 open-world &#xff1a;实际应用场景下 2…

瑞_23种设计模式_适配器模式

文章目录 1 适配器模式&#xff08;Adapter Pattern&#xff09;1.1 介绍1.2 概述1.3 适配器模式的结构 2 类适配器模式2.1 案例2.2 代码实现 3 对象适配器模式&#xff08;推荐&#xff09;★3.1 案例3.2 代码实现 4 拓展——JDK源码解析 &#x1f64a; 前言&#xff1a;本文章…

Java项目,营销抽奖系统设计实现

作者&#xff1a;小傅哥 博客&#xff1a;https://bugstack.cn 项目&#xff1a;https://gaga.plus 沉淀、分享、成长&#xff0c;让自己和他人都能有所收获&#xff01;&#x1f604; 大家好&#xff0c;我是技术UP主&#xff0c;小傅哥。 经过这个假期的嘎嘎卷&#x1f9e8;…

VMware下安装银河麒麟V10操作系统

VMware下安装银河麒麟V10操作系统 文章目录 下载在VMware中应用编辑虚拟机设置 在麒麟系统内安装 下载 官网下载&#xff1a;https://www.kylinos.cn/ 银河麒麟、中标麒麟、开放麒麟、星光麒麟 在VMware中应用 1.新建虚拟机 2.稍后安装操作系统 3.新建虚拟机向导&#xff0…

设计模式三:工厂模式

工厂模式包括简单工厂模式、工厂方法模式和抽象工厂模式&#xff0c;其中后两者属于23中设计模式 各种模式中共同用到的实体对象类&#xff1a; //汽车类&#xff1a;宝马X3/X5/X7&#xff1b;发动机类&#xff1a;B48TU、B48//宝马汽车接口 public interface BMWCar {void s…

CSS-基础-MDN文档学习笔记

CSS构建基础 查看更多学习笔记&#xff1a;GitHub&#xff1a;LoveEmiliaForever MDN中文官网 CSS选择器 选择器是什么 CSS 选择器是 CSS 规则的第一部分&#xff0c;它用来选择HTML元素&#xff0c;选择器所选择的元素&#xff0c;叫做选择器的对象 选择器列表 如果有多…

盘点3款实用的音频文件转文字工具!

在信息爆炸的时代&#xff0c;我们每天都面临着海量的信息输入和输出。其中&#xff0c;音频信息作为一种重要的信息传播方式&#xff0c;如何高效地将其转化为文字&#xff0c;成为许多人和企业迫切的需求。本文将为您盘点几款实用的音频转文字工具&#xff0c;让声音瞬间转化…

通过闭包表解决无限极代理分销

闭包表设计 闭包表是解决分层存储一个简单而又优雅的解决方案&#xff0c;它记录了表中所有的节点关系&#xff0c;并不仅仅是直接的父子关系。   在闭包表的设计中&#xff0c;额外创建了一张节点关系表(空间换取时间)&#xff0c;它包含两列&#xff0c;每一列都是一个指向…

facebook群控如何做?使用静态住宅ip代理有什么好处?

在进行Facebook群控时&#xff0c;ip地址的管理是非常重要的&#xff0c;因为Facebook通常会检测ip地址的使用情况&#xff0c;如果发现有异常的使用行为&#xff0c;比如从同一个ip地址频繁进行登录、发布内容或者在短时间内进行大量的活动等等&#xff0c;就会视为垃圾邮件或…

我的NPI项目之Android USB 系列(一) - 遥望和USB的相识

和USB应该是老朋友了&#xff0c;从2011年接触Android开发开始&#xff0c;就天天和USB打交道了。那时候还有不 对称扁头的usb/方口的usb&#xff0c;直到如今使用广泛的防反插USB3.0 type-C。 但是&#xff0c;一直有一个不是很清楚的问题萦绕在心头&#xff0c;那就是。先有…

Vue3 学习笔记(Day1)

「写在前面」 本文为尚硅谷禹神 Vue3 教程的学习笔记。本着自己学习、分享他人的态度&#xff0c;分享学习笔记&#xff0c;希望能对大家有所帮助。 目录 0 课程介绍 1 Vue3 简介 2 创建 Vue3 工程 2.1 基于 vue-cli 创建 2.2 基于 vite 创建&#xff08;推荐&#xff09; 2.3 …

[word] word正反面打印应该怎么设置呢? #知识分享#学习方法#职场发展

word正反面打印应该怎么设置呢&#xff1f; word文档打印时&#xff0c;如果页数比较多&#xff0c;出于格式要求或为了节省纸张&#xff0c;通常需要正反面打印&#xff0c;那怎么操作正反双面打印呢&#xff1f;通常有两种方法打印。 1、选择“打印”对话框底部的“打印”下…

linux 安装、删除 JTAG驱动

安装 安装驱动需要sudo访问权限&#xff0c;所以得手动安装。 在petalinux安装目录下&#xff1a; 文件的路径。 cd tools/xsct/data/xicom/cable_drivers/lin64/install_script/install_drivers 然后执行文件 install_drivers。 sudo ./install_drivers安装成功。 删除 …

FFmpeg进阶-给视频添加马赛克效果

很多时候为了隐藏视频中的敏感信息如人脸、身份证号、车牌号等,我们会采用马赛克算法对视频帧中的一部分内容进行处理。这里介绍一下如何采用FFmpeg实现马赛克效果。 马赛克效果算法的原理如下: 1.分块处理:首先将图像划分为多个小块或区域 2.像素替换:对于每个小块,算法会将…

sentinel的资源数据指标是如何采集

资源数据采集 之前的NodeSelectorSlot和ClusterBuilderSlot已经完成了对资源调用树的构建, 现在则是要对资源进行收集, 核心点就是这些资源数据是如何统计 LogSlot 作用: 记录异常请求日志, 用于故障排查 public class LogSlot extends AbstractLinkedProcessorSlot<Def…

鸿蒙 状态管理-组件装饰器

前提&#xff1a;基于官网3.1/4.0文档。参考官网文档 基于Android开发体系来进行比较和思考。&#xff08;或有偏颇&#xff0c;自行斟酌&#xff09; 1.概念 Android中使用过Jetpack MVVM框架知道状态管理&#xff0c;包括React前端所使用的状态管理框架&#xff0c;都有所设…