YOLOv5/v7 添加注意力机制,30多种模块分析⑤,SOCA模块 ,SimAM模块

news2024/11/18 14:54:52

目录

    • 一、注意力机制介绍
      • 1、什么是注意力机制?
      • 2、注意力机制的分类
      • 3、注意力机制的核心
    • 二、SOCA模块
      • 1、SOCA模块的原理
      • 2、实验结果
      • 3、应用示例
    • 三、SimAM模块
      • 1、SimAM模块的原理
      • 2、实验结果
      • 3、应用示例

大家好,我是哪吒。

🏆本文收录于,目标检测YOLO改进指南。

本专栏均为全网独家首发,内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。


在机器学习和自然语言处理领域,随着数据的不断增长和任务的复杂性提高,传统的模型在处理长序列或大型输入时面临一些困难。传统模型无法有效地区分每个输入的重要性,导致模型难以捕捉到与当前任务相关的关键信息。为了解决这个问题,注意力机制(Attention Mechanism)应运而生。

一、注意力机制介绍

1、什么是注意力机制?

注意力机制(Attention Mechanism)是一种在机器学习和自然语言处理领域中广泛应用的重要概念。它的出现解决了模型在处理长序列或大型输入时的困难,使得模型能够更加关注与当前任务相关的信息,从而提高模型的性能和效果。

本文将详细介绍注意力机制的原理、应用示例以及应用示例。

2、注意力机制的分类

类别描述
全局注意力机制(Global Attention)在计算注意力权重时,考虑输入序列中的所有位置或元素,适用于需要全局信息的任务。
局部注意力机制(Local Attention)在计算注意力权重时,只考虑输入序列中的局部区域或邻近元素,适用于需要关注局部信息的任务。
自注意力机制(Self Attention)在计算注意力权重时,根据输入序列内部的关系来决定每个位置的注意力权重,适用于序列中元素之间存在依赖关系的任务。
Bahdanau 注意力机制全局注意力机制的一种变体,通过引入可学习的对齐模型,对输入序列的每个位置计算注意力权重。
Luong 注意力机制全局注意力机制的另一种变体,通过引入不同的计算方式,对输入序列的每个位置计算注意力权重。
Transformer 注意力机制自注意力机制在Transformer模型中的具体实现,用于对输入序列中的元素进行关联建模和特征提取。

3、注意力机制的核心

注意力机制的核心思想是根据输入的上下文信息来动态地计算每个输入的权重。这个过程可以分为三个关键步骤:计算注意力权重、对输入进行加权和输出。首先,计算注意力权重是通过将输入与模型的当前状态进行比较,从而得到每个输入的注意力分数。这些注意力分数反映了每个输入对当前任务的重要性。对输入进行加权是将每个输入乘以其对应的注意力分数,从而根据其重要性对输入进行加权。最后,将加权后的输入进行求和或者拼接,得到最终的输出。注意力机制的关键之处在于它允许模型在不同的时间步或位置上关注不同的输入,从而捕捉到与任务相关的信息。

🏆YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块

🏆YOLOv5/v7 添加注意力机制,30多种模块分析②,BAM模块,CBAM模块

🏆YOLOv5/v7 添加注意力机制,30多种模块分析③,GCN模块,DAN模块

🏆YOLOv5/v7 添加注意力机制,30多种模块分析④,CA模块,ECA模块

二、SOCA模块

1、SOCA模块的原理

SOCA(Second-order Channel Attention,二阶通道注意力)模块是一种用于图像超分辨率的注意力机制。它可以通过对输入特征张量进行协方差计算,并使用计算出的协方差矩阵作为权重,来提高模型对重要通道的关注度,从而提高模型的超分辨率效果。

在这里插入图片描述

SOCA模块的主要思想是将输入特征张量中每个通道之间的关系进行建模,即通过计算协方差矩阵来表示各个通道之间的相关性。具体来说,SOCA模块包括以下几个步骤:

  1. 将输入特征张量进行拉平操作,得到一个二维矩阵。
  2. 对该矩阵进行归一化处理,使得每个通道的均值为0、标准差为1。
  3. 计算该矩阵的协方差矩阵,并对其进行特征分解,得到其特征向量和特征值。
  4. 使用特征向量来构造一个注意力向量,并对输入特征张量进行加权平均。
  5. 将加权平均后的特征张量与原始特征张量相加,得到最终的输出特征张量。

2、实验结果

在这里插入图片描述

5.6×105次迭代中在Set5(4×)上的最佳PSNR(分贝)值。

3、应用示例

下面是使用SOCA模块的应用示例:

import torch.nn as nn
import torch.nn.functional as F

class SOCA(nn.Module):
    def __init__(self, in_channels):
        super(SOCA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(in_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.conv(y.view(b, c, 1, 1)).view(b, c, 1, 1)
        y = self.bn(y)
        y = self.sigmoid(y)
        return x * y

该代码定义了一个SOCA(Second-Order Channel Attention)模块的类,包括以下几个部分:

  1. __init__(self, in_channels):构造函数,接受输入通道数in_channels作为参数。
  2. avg_pool = nn.AdaptiveAvgPool2d(1):创建一个自适应平均池化层,将输入特征张量缩小到 1x1 的大小。
  3. conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0, bias=False):创建一个 1x1 的卷积层,用于降低通道维度。
  4. bn = nn.BatchNorm2d(in_channels):创建一个批归一化层,对输出进行标准化处理。
  5. sigmoid = nn.Sigmoid():创建一个 Sigmoid 激活函数,用于生成注意力权重。
  6. forward(self, x):前向传播函数,接受输入特征张量x作为参数,并返回加权后的特征张量。

在前向传播函数中,首先对输入特征张量进行自适应平均池化操作,然后通过卷积和批归一化层来降低通道维度,并使用 Sigmoid 激活函数生成注意力权重。最后将输入特征张量与注意力权重相乘,得到加权后的特征张量作为输出。

三、SimAM模块

1、SimAM模块的原理

在这里插入图片描述
SimAM模块是一种用于图像分类任务的自适应注意力机制。它的核心思想是利用相似性信息来调整每个通道的注意力权重,从而提升图像分类的性能。

SimAM模块首先通过两个全局池化操作获取特征图中每个通道的全局平均值和标准差。对于每个通道,SimAM将其与其他通道之间的相似性定义为该通道与其他通道的余弦相似度,并将这些相似性作为一个矩阵输入到一个子网络中。该子网络使用多层感知器(MLP)来学习如何将相似性信息转换为注意力权重。SimAM根据学习到的注意力权重对输入的特征进行加权求和,得到了调整后的特征表示。

SimAM模块的优点在于它能够灵活地适应不同的数据分布,从而提高图像分类的泛化性能。此外,由于它只依赖于全局平均值和标准差,SimAM的计算成本比较低,适合大规模图像分类任务。

2、实验结果

在这里插入图片描述

在ImageNet-1K上,ResNet-50有和没有我们的模块的训练曲线比较。左侧和右侧分别显示了两个网络的Top-1(%)和Top-5(%)准确性。可以看出,将SimAM集成到ResNet-50中,在训练和验证中都比基线模型效果更好。

在这里插入图片描述

使用经过训练的带有SimAM的ResNet-50的特征激活可视化。对于每个图像,从左到右的地图是SimAM之前和之后的特征GradCAM,注意力权重和注意力权重的GradCAM。注意力映射是通过沿通道维度平均3-D注意力权重获得的。

3、应用示例

下面是使用SimAM模块在YOLOv5中进行目标检测的应用示例:

(1)在YOLOv5模型中导入SimAM模块:

# 定义SimAM模块
class SimAM(nn.Module):
    def __init__(self, out_channels, groups=8):
        super().__init__()
        self.groups = groups
        self.out_channels = out_channels
        self.conv_theta = nn.Conv2d(out_channels, out_channels // groups, kernel_size=1, bias=False)
        self.conv_phi = nn.Conv2d(out_channels, out_channels // groups, kernel_size=1, bias=False)
        self.conv_g = nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False)
        self.conv_attn = nn.Conv2d(out_channels // groups, out_channels, kernel_size=1, bias=False)

    def forward(self, x):
        # 获取输入张量的高度、宽度和通道数
        b, c, h, w = x.shape

        # 计算query、key、value
        theta = self.conv_theta(x).view(b, self.groups, -1, h * w).permute(0, 1, 3, 2)  # b, g, hw, c//g
        phi = self.conv_phi(x).view(b, self.groups, -1, h * w)  # b, g, c//g, hw
        g = self.conv_g(x).view(b, self.out_channels, -1).permute(0, 2, 1)  # b, hw, c

        # 计算相似度矩阵
        attn = torch.matmul(theta, phi) / math.sqrt(c // self.groups)
        attn = F.softmax(attn, dim=-1)

        # 计算加权后的value
        attn_g = torch.matmul(attn, g).permute(0, 2, 1).reshape(b, self.out_channels // self.groups, h, w)
        attn_g = self.conv_attn(attn_g)

        # 输出结果
        return x + attn_g

以上代码是SimAM(Similarity Attention Module)模块的定义及其前向传播过程。代码说明如下:

  • 初始化函数中,out_channels表示输入张量的通道数,groups表示将输入张量的通道数分成几组进行相似度计算。
  • conv_theta、conv_phi、conv_g和conv_attn分别表示用于计算query、key、value和加权后的value的卷积层。
  • forward函数中,首先获取输入张量的高度、宽度和通道数,然后按照SimAM的原理,计算出query、key、value,并通过相似度矩阵计算出加权后的value。
  • 最后将加权后的value与输入张量相加作为输出结果。

(2)在YOLOv5模型中使用SimAM模块:

# 定义YOLOv5模型
class YOLOv5(nn.Module):
    def __init__(self, num_classes=80):
        super().__init__()
        self.num_classes = num_classes
        
        # 省略其余部分...
        
        # 添加SimAM模块
        self.sa1 = SimAM(256)
        self.sa2 = SimAM(512)
        self.sa3 = SimAM(1024)
        
    def forward(self, x):
        # 省略起始部分...
        
        # 使用SimAM模块提取特征
        x = self.sa1(x)
        x = self.conv3(x)
        x = self.sa2(x)
        x = self.conv4(x)
        x = self.sa3(x)
        x = self.conv5(x)
        
        # 省略其余部分...
        
        # 输出结果
        return x

代码说明如下:

  • 在初始化函数中,num_classes表示模型需要识别的物体类别数。
  • 添加了三个SimAM模块,分别用于提取不同层级的特征,即sa1对应backbone中的C3层,sa2对应C4层,sa3对应C5层。
  • 在forward函数中,首先省略掉起始部分(包括输入张量的大小调整、Backbone和FPN网络),然后使用SimAM模块提取特征,并通过卷积层预测目标框、置信度和类别信息。
  • 最后输出预测结果。

参考论文:

  1. https://openaccess.thecvf.com/content_CVPR_2019/papers/Dai_Second-Order_Attention_Network_for_Single_Image_Super-Resolution_CVPR_2019_paper.pdf
  2. http://proceedings.mlr.press/v139/yang21o/yang21o.pdf

在这里插入图片描述

🏆本文收录于,目标检测YOLO改进指南。

本专栏均为全网独家首发,🚀内附代码,可直接使用,改进的方法均是2023年最近的模型、方法和注意力机制。每一篇都做了实验,并附有实验结果分析,模型对比。

🏆华为OD机试(JAVA)真题(A卷+B卷)

每一题都有详细的答题思路、详细的代码注释、样例测试,订阅后,专栏内的文章都可看,可加入华为OD刷题群(私信即可),发现新题目,随时更新,全天CSDN在线答疑。

🏆哪吒多年工作总结:Java学习路线总结,搬砖工逆袭Java架构师。

🏆往期回顾:

YOLOv5/v7 添加注意力机制,30多种模块分析①,SE模块,SK模块

YOLOv5/v7 添加注意力机制,30多种模块分析②,BAM模块,CBAM模块

YOLOv5/v7 添加注意力机制,30多种模块分析③,GCN模块,DAN模块

YOLOv5/v7 添加注意力机制,30多种模块分析④,CA模块,ECA模块

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/660629.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue 面试题

一、对于MVVM的理解? MVVM 是 Model-View-ViewModel 的缩写。 1、Model 代表数据模型,也可以在Model中定义数据修改和操作的业务逻辑。 2、View 代表UI 组件,它负责将数据模型转化成UI 展现出来。 3、ViewModel 监听模型数据的改变和控制…

【分布式系统】分布式唯一ID生成方案总结

目录 UUID实现 数据库生成IDsegment号段模式美团Leaf-segment号段模式Redis生成ID实现 zookeeper生成ID实现 snowflake雪花算法实现 Leaf-snowflake雪花算法百度UidGenerator滴滴TinyID双号段缓存多DB支持tinyid-client 参考 UUID 基本方法之一。 UUID(Universally Unique Id…

Java阿里巴巴代码规范

目录 1 编程规约1.1 方法参数类型必须一致,不要出现自动装箱拆箱操作1.1.1 反例1.1.2 正例 1.2 SimpleDateFormat是线程不安全的1.2.1 反例1.2.2 正例 1.3 使用equals方法应该注意空指针1.3.1 反例1.3.2 正例 2 异常日志2.1 事务场景中如果异常被被捕获要注意回滚2.…

2023 年最适用于工业物联网领域的三款开源 MQTT Broker

MQTT 最初作为一种轻量级的发布/订阅消息传递协议而设计,如今已经成为工业物联网(IIoT)和工业 4.0 发展的重要基础。它的意义在于实现了各类工业设备与云端的无缝连接,促进了运营技术(OT)和信息技术&#x…

MySQL----事物与存储引擎

文章目录 一、事务介绍1.1 MySQL 事务的概念1.2 事务的ACID特点原子性一致性隔离性持久性 1.3 事务之间的相互影响1.4 设置隔离级别1.5事务控制语句1.6使用 set 设置控制事务 二、存储引擎介绍2.1查看系统支持的存储引擎2.2 修改存储引擎2.3InnoDB行锁与索引的关系 一、事务介绍…

Windows VMware安装RockyLinux9

前言,今天用虚拟机安装rockyLinux时碰到了一些坑,要么时无法联网,要么是无法使用ssh链 接,在这里记录下 准备工作 1. VMware Workstation 17 Pro 2. RockyLinux9.2阿里镜像站,这里无选择了最小版本Rocky-9-latest-x86…

我的内网渗透-Empire模块的使用(宏病毒主要)

目录 stager模块(payload) 宏病毒 理解 在word中的设置 宏病毒代码 运行 保存 监听模块 提权模块 持久化模块 stager模块(payload) 常用的windows类型windows/launcher_bat #生成bat类型,还是可以用的。但…

GPT学习笔记-Enterprise Knowledge Retrieval(企业知识检索)--私有知识库的集成

openai-cookbook/apps/enterprise-knowledge-retrieval at main openai/openai-cookbook GitHub 终于看到对于我解决现有问题的例子代码,对于企业私有知识库的集成。 我对"Retrieval"重新理解了一下,源自动词"retrieve"&#xf…

【Matlab】LM迭代估计法

简介 在最近的传感器校准算法学习中,有一些非线性的代价函数求解使用最小二乘法很难求解,使用LM算法求解会简单许多,因此学习了一下LM算法的基础记录一下。 LM 优化迭代算法时一种非线性优化算法,可以看作是梯度下降与高斯牛顿法…

【linux kernel】linux media子系统分析之media控制器设备

文章目录 一、抽象媒体设备模型二、媒体设备三、Entity四、Interfaces五、Pad六、Link七、Media图遍历八、使用计数和电源处理九、link设置十、Pipeline和Media流十一、链接验证十二、媒体控制器设备的分配器API 本文基于linux内核 4.19.4,抽象媒体设备模型框架的相…

day11_类中成员之方法

成员变量是用来存储对象的数据信息的,那么如何表示对象的行为功能呢?就要通过方法来实现 方法 概念: 方法也叫函数,是一个独立功能的定义,是一个类中最基本的功能单元。把一个功能封装为方法的目的是,可…

信息安全原理与技术期末复习(如学)

文章目录 一、前言(开卷打印)二、选择题三、简答题1、简述端口扫描技术原理(P136)2、分组密码工作方式(P61)3、木马攻击(P176)4、消息认证码(P84)5、非对称密…

华为OD机试真题B卷 JavaScript 实现【删除字符串中出现次数最少的字符】,附详细解题思路

一、题目描述 删除字符串中出现次数最少的字符,如果多个字符出现次数一样则都删除。 二、输入描述 一个字符串。 三、输出描述 删除字符串中出现次数最少的字符,如果多个字符出现次数一样则都删除,如果都被删除 则换为empty。 四、解题…

机器视觉初步5-1:图像平滑专题

在计算机视觉领域,图像平滑处理是一个重要的任务,用于降低噪声,提高图像质量。常见的图像平滑算法有均值滤波、中值滤波、高斯滤波等。本文将介绍这些算法的原理,并分别给出使用Python与Halcon实现的代码。(当前版本&a…

libface 人脸检测

于老师的项目地址GitHub - ShiqiYu/libfacedetection: An open source library for face detection in images. The face detection speed can reach 1000FPS. 关于如何使用,于老师写得很清楚: 测试代码 CMakeList.txt 和 三个face开头的cpp文件都是于老…

有趣的数学 数学建模入门一 从几个简单的示例入手

一、“变量”的概念 一个代数表达式(通常只有一个字母:x,y,z…,如果它取代了一个未知值(物理、经济、时间等),则称为“变量”。 变量的作用是占据一个值所在的位置,如果该…

设计模式之工厂方法模式笔记

设计模式之工厂方法模式笔记 说明Factory Method(工厂方法)目录UML抽象工厂示例类图咖啡抽象类美式咖啡类拿铁咖啡类 咖啡工厂接口美式咖啡工厂类拿铁咖啡工厂类 咖啡店类测试类 说明 记录下学习设计模式-工厂方法模式的写法。 Factory Method(工厂方法) 意图:定义一个用于创…

深度学习图像分类、目标检测、图像分割源码小项目

​demo仓库和视频演示: 银色子弹zg的个人空间-银色子弹zg个人主页-哔哩哔哩视频 卷积网路CNN分类的模型一般使用包括alexnet、DenseNet、DLA、GoogleNet、Mobilenet、ResNet、ResNeXt、ShuffleNet、VGG、EfficientNet和Swin transformer等10多种模型 目标检测包括…

Sourcetree 打开闪退怎么处理

问题描述:Sourcetree打开闪退,已管理员身份运行仍然闪退 解决方法; 在Sourcetree图标上右键,然后打开文件所在位置: 找到目录 xxxx\AppData\Local\Atlassian 删除箭头所指向文件即可。

2023年怎么移除微博粉丝 微博怎么批量移除粉丝方法

2023最新微博批量粉丝移除_手机微博粉丝怎么批量删除 使用微博粉丝移除工具:可以帮助用户快速批量移除粉丝。在微博管理工具中,用户可以根据自己的需要设置移除粉丝的数量,可以一键批量移除多个粉丝。此外,管理工具还提供了粉丝管…