CV常用注意力机制总结

news2024/11/25 2:39:42

本文总结了近几年CV领域常用的注意力机制,包括:SE(Squeeze and Excitation)、ECA(Efficient Channel Attention)、CBAM(Convolutional Block Attention Module)、CA(Coordinate attention for efficient mobile network design)

一、SE

目的是给特征图中不同的通道赋予不同的权重,步骤如下:

  1. 对特征图进行Squeeze,该步骤是通过全局平均池化把特征图从大小为(N,C,H,W)转换为(N,C,1,1),这样就达到了全局上下文信息的融合
  2. Excitation操作,该步骤使用两个全连接层,其中第一个全连接层使用ReLU激活函数,第二个全连接层采用Sigmoid激活函数,目的是将权重中映射到(0,1)之间。值得注意的是,为了减少计算量进行降维处理,将第一个全连接的输出采用输入的1/4或者1/16
  3. 通过广播机制将权重与输入特征图相乘,得到不同权重下的特征图

代码实现如下, 

import torch
import torch.nn as nn


class Se(nn.Module):
    def __init__(self, in_channel, reduction=16):
        super(Se, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.fc = nn.Sequential(
            nn.Linear(in_features=in_channel, out_features=in_channel//reduction, bias=False),
            nn.ReLU(),
            nn.Linear(in_features=in_channel//reduction, out_features=in_channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self,x):
        out = self.pool(x)
        out = self.fc(out.view(out.size(0),-1))
        out = out.view(x.size(0),x.size(1),1,1)
        return out*x

二、ECA

可参考《注意力机制 ECA-Net 学习记录_eca注意力机制_chen_zn95的博客-CSDN博客》

ECA也是一个通道注意力机制,该算法是在SE的基础上做出了一定的改进,首先ECA作者认为SE中的全连接降维可以降低模型的复杂度,但这也破坏了通道与其权重之间的直接对应关系,先降维后升维,这样权重和通道的对应关系是间接的。为了解决以上问题,作者提出一维卷积的方法,避免了降维对数据的影响,步骤如下:

  1. 对特征图进行Squeeze,该步骤是通过全局平均池化把特征图从大小为(N,C,H,W)转换为(N,C,1,1),这样就达到了全局上下文信息的融合(与SE的步骤1相同)
  2. 计算自适应卷积核的大小,k=\left | \frac{log_{2}^{C}}{\gamma }+\frac{b}{\gamma } \right |,其中,C:输入通道数,b=1,γ=2;对经过步骤一处理的特征进行一维卷积操作(获得局部跨通道信息),再采用Sigmoid激活函数将权重映射在0到1之间
  3. 通过广播机制将权重与输入特征图相乘,得到不同权重下的特征图

代码实现如下,

import torch
import torch.nn as nn
import math


class ECA(nn.Module):
    def __init__(self, in_channel, gamma=2, b=1):
        super(ECA, self).__init__()
        k = int(abs((math.log(in_channel,2)+b)/gamma))
        kernel_size = k if k % 2 else k+1
        padding = kernel_size//2
        self.pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size, padding=padding, bias=False),
            nn.Sigmoid()
        )

    def forward(self,x):
        out=self.pool(x)
        out=out.view(x.size(0), 1, x.size(1))
        out=self.conv(out)
        out=out.view(x.size(0), x.size(1), 1, 1)
        return out*x

三、CBAM

CBAM是一种将通道与空间注意力机制相结合的算法,输入特征图先进行通道注意力机制再进行空间注意力机制操作,这样从通道和空间两个方面达到了强化感兴趣区域的目的

通道注意力机制的实现步骤如下:

  1. 对特征图进行Squeeze,该步骤分别采用全局平均池化和全局最大池化把特征图从大小为(N,C,H,W)转换为(N,C,1,1),这样就达到了全局上下文信息的融合
  2. 分别将全局最大池化和全局平均池化结果进行MLP操作,MLP在这里定义与SE的全连接层操作一样,为两层全连接层,中间采用ReLU激活,最后将两者相加后利用Sigmoid函数激活
  3. 通过广播机制将权重与输入特征图相乘,得到不同权重下的特征图

空间注意力机制的实现步骤如下:

  1. 将上述通道注意力操作的结果,分别在通道维度上进行最大池化和平均池化,即将经过通道注意力机制的特征图从(N,C,H,W)转换为(N,1,H,W),融合不同通道的信息,然后在通道维度上将最大池化与平均池化的结果concat起来
  2. 将叠加后2个通道的结果做卷积运算,输出通道为1,卷积核大小为7,最后将输出结果进行Sigmoid处理
  3. 通过广播机制将权重与输入特征图相乘,得到不同权重下的特征图

代码实现如下,

import torch
import torch.nn as nn
import math


class CBAM(nn.Module):
    def __init__(self, in_channel, reduction=16, kernel_size=7):
        super(CBAM, self).__init__()
        # 通道注意力机制
        self.max_pool = nn.AdaptiveMaxPool2d(output_size=1)
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)
        self.mlp = nn.Sequential(
            nn.Linear(in_features=in_channel, out_features=in_channel//reduction, bias=False),
            nn.ReLU(),
            nn.Linear(in_features=in_channel//reduction, out_features=in_channel,bias=False)
        )
        self.sigmoid = nn.Sigmoid()
        # 空间注意力机制
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size , stride=1, padding=kernel_size//2, bias=False)

    def forward(self,x):
        # 通道注意力机制
        max_out = self.max_pool(x)
        max_out = self.mlp(max_out.view(maxout.size(0), -1))
        avg_out = self.avg_pool(x)
        avg_out = self.mlp(avg_out.view(avgout.size(0), -1))
        channel_out = self.sigmoid(max_out+avg_out)
        channel_out = channel_out.view(x.size(0), x.size(1),1,1)
        channel_out = channel_out*x
        # 空间注意力机制
        max_out, _ = torch.max(channel_out, dim=1, keepdim=True)
        mean_out = torch.mean(channel_out, dim=1, keepdim=True)
        out = torch.cat((max_out, mean_out), dim=1)
        out = self.sigmoid(self.conv(out))
        out = out*channel_out
        return out

四、CA

SE对提升模型性能具有显著的有效性,但它们通常忽略了位置信息。CA利用x、y两个方向的全局池化,分别将垂直和水平方向上的输入特征聚合为两个独立的方向感知特征映射,将输入Feature Map的位置信息嵌入到通道注意力的聚合特征向量。这两个嵌入方向特定信息的特征图被分别编码到两个注意图中,然后通过乘法将这两种注意图应用于输入特征图,加强感兴趣区域的表示。实现步骤如下:

  1. 沿着x、y方向对输入特征图分别进行自适应池化操作,将特征图大小从(N,C,H,W)变为(N,C,H,1)、(N,C,1,W),对大小为(N,C,1,W)的特征进行permute操作,使得该特征大小变为(N,C,W,1),再沿着(dim=2)对这两个特征进行concat,得到大小为(N,C,H+W,1)的特征图
  2. 对步骤一处理后的特征进行1*1卷积操作(目的是降维),随后再经过BN和h_swish激活函数
  3. 沿着dim=2对步骤二处理后的特征进行分割操作(torch.split),将特征分为(N,C/r,H,1)、(N,C/r,W,1),其中r为通道降维因子。随后对大小为(N,C/r,W,1)的特征进行permute操作,变为(N,C/r,1,W)。最后对以上两个特征进行1*1卷积操作(目的是改变特征图的通道数)并经Sigmoid处理
  4. 通过广播机制将步骤三的两个特征先后与输入特征图相乘(out = identity * a_w * a_h),得到不同权重下的特征图

【下图有点小问题,应该是对1D Global Avg Pool(W)特征进行permute操作】

代码实现如下,

import torch
import torch.nn as nn
import math
import torch.nn.functional as F


class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        return self.relu(x + 3) / 6

class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        return x * self.sigmoid(x)

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))

        mip = max(8, inp // reduction)

        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        

    def forward(self, x):
        identity = x
        
        n,c,h,w = x.size()
        x_h = self.pool_h(x)
        x_w = self.pool_w(x).permute(0, 1, 3, 2)

        y = torch.cat([x_h, x_w], dim=2)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y) 
        
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)

        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()

        out = identity * a_w * a_h

        return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/746970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DevOps基础服务2——Jenkins

文章目录 一、基本了解1.1 CI/CD介绍1.2 基于Docker的CI/CD 二、安装jenkins三、页面管理3.1 当前系统用户配置3.2 系统配置3.3 全局工具配置3.4 插件管理3.4.1 安装插件3.4.2 上传插件 3.5 用户设置3.6 查看日志3.7 汉化设置 一、基本了解 DEVOPS概念: DevOps是一种…

经典指针与数组笔试题——C语言

学习这片文章中的知识点,可以加深大家对指针应用的理解,让大家更能轻松知道指针在各种情况下指向那个内存地址。    文章开始之前 ,我们先来介绍一下一些必要的知识点 📢 : 以下代码都是在64位编译器下测试的 经典…

orcle报错:无监听程序,解决方法

orcle报错:无监听程序,解决方法 报错页面: 打开桌面侧边安装orcle的列表,找到Net Configuration Assistant,双击(这个可以重新配置监听) ![]](https://img-blog.csdnimg.cn/3ba6bd6bd0af413ca5…

nginx 开机自启

0x00 前言 简单的记录下 0x01 正文 cd /lib/systemd/system/ vim nginx.service [Unit] Descriptionnginx service Afternetwork.target [Service] Typeforking ExecStart/usr/local/nginx/sbin/nginx ExecReload/usr/local/nginx/sbin/nginx -s reload ExecStop/usr/lo…

pt18CSS

CSS 基础使用 CSS全称为: Cascading Style Sheets ,意为层叠样式表 ,与HTML相辅相成,实现网页的排版布局与样式美化 CSS使用方式 行内样式/内联样式 使用简单,但不推荐,大面积使用,很累 借…

ESP32(MicroPython) 网页显示温湿度+RGB点阵控制

本程序整合了RGB点阵可交互超声波云台网页显示温湿度程序和网页控制WS2812程序的功能,对一些参数进行了调整。去掉了图标以加快加载速度,去掉了超声波云台和按键控制以简化接线并实现RGB点阵更新周期可调,由于RGB点阵更新周期相对较长&#x…

vue3前端模拟https安全策略同局域网内测试方法-local-ssl-proxy

文章目录 前言建议全局安装运行安全策略模拟运行效果如果其他客户端不能访问 直接在cmd跑即可,不过我们应该先运行项目 前言 为什么要用https安全策略呢,因为http浏览器策略访问权限有限,不能使用navigator的激活“用户音频或视频”的方法&a…

windows上的mysql服务突然消失: 10061 Unkonwn error

问题描述 windows10 系统,今天早晨系统自己更新了下,也没啥问题,突然发现电脑上安装的mysql 服务没了… 原因分析: 我是安装的解压版的mysql 虽然服务没了,但是文件夹,包括数据啥的都在最重要的就是数据啦,还好都在 解决方案: 打开mysql 的bin所在目录…

Git撤销已合并提交的多种姿势

#Git撤销已合并提交的多种姿势 在Git中,合并分支是一个常见的操作,但有时候可能会意外地将错误的提交合并到了主分支。这时候需要撤销已合并的提交并恢复到正确的状态。本文将介绍的是如何在Git中撤销已合并的提交,无论这个提交记录是最新的还…

包管理工具:npm

安装Node的过程会自动安装npm工具 比如npm install dayjs后 const dayjsrequire("dayjs")console.log(dayjs()) 直接运行 生成package.json文件  方式一:手动从零创建项目,npm init –y  方式二:通过脚手架创建项目&#xf…

Sentinel服务器容错简介

spring gateway 详解 服务容错高并发带来的问题服务雪崩效应常见容错方案常见的容错思路1、隔离2、超时3、限流4、熔断5、降级 常见的容错组件 SentinelSentinel 具有以下特征:Sentinel概念和功能相关概念1、资源2、规则 重要功能 服务容错 高并发带来的问题 在微服务架构中&…

Vivado使用误区与进阶系列(七)用Tcl定制Vivado设计实现流程

01 基本的FPGA设计实现流程 FPGA 的设计流程简单来讲,就是从源代码到比特流文件的实现过程。大体上跟 IC 设计流程类似,可以分为前端设计和后端设计。其中前端设计是把源代码综合为对应的门级网表的过程,而后端设计则是把门级网表布局布线到…

orcle报错:TNS 监听程序无法为请求的服务器类型找到可用的处理程序

orcle报错:TNS 监听程序无法为请求的服务器类型找到可用的处理程序 方法一:配置文件修改 服务端的数据库是专用服务器,但是在客户端的tnsname.ora里配置中设置了连接方式为shared,这种情况下打开tnsnames.ora, 找到安装orcle的安装目录,点…

MSP430F5529,超声波,距离检测报警,倒车雷达,SR-04模块

文章目录 硬件连接功能实物效果代码 硬件连接 /* OLED----MSP430VCC-----3.3VGND-----GNDSCL------P3.1SDA------P3.0 */ /* 蜂鸣器----MSP430VCC-----3.3VGND-----GNDDAT------P2.4 */ /* 超声波----MSP430VCC-----3.3VGND-----GNDTRIG------P1.3ECHO------P1.2 */ /* …

模板学堂|数据关系和AntV、ECharts图表解析

DataEase开源数据可视化分析平台于2022年6月正式发布模板市场(https://dataease.io/templates/)。模板市场旨在为DataEase用户提供专业、美观、拿来即用的仪表板模板,方便用户根据自身的业务需求和使用场景选择对应的仪表板模板,并…

PHP 音乐欣赏网站mysql数据库web结构apache计算机软件工程网页wamp

一、源码特点 PHP音乐欣赏网站 是一套完善的web设计系统,对理解php编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。 代码下载 https://download.csdn.net/download/qq_41221322/88041034https://download.…

Sentry 监控 Docker 方式部署

一、简介 根据主篇 Sentry 监控部署与使用 流程,使用 Docker 方式 方式进行部署。 docker 方式 部署操作比较简单,也是 Sentry 官方 比较推崇的方式,直接按 Sentry On-Premise 提供的方式按部就班部署就好了。或者可直接参考 Docker 部署 Se…

关于根据文件名以及内容查找文件存放路径

1 根据文件名字查找文件存放路径 1.1 命令如下(先切换到存放该文件的顶级父目录下): find /path/to/search -name "filename"​​ 1.2 案例如下 2 根据内容查找包含该内容的文件存放路径 2.1 命令如下(先切换到存放该文…

《大大简化每次运行bochs的命令行》ubuntu里安装vscode + makefile文件基本编写 + shell命令

📍安装vscode 启动vscode 如图打开商店,在搜索栏里输入visual studio code,安装即可 在随便一个命令行里输入code即可打开vscode 📍makefile文件基本编写 在实验项目文件夹里创建makefile文件(vscode直接能快捷创…

禁止22H2版windows10出现windows11的跨版本升级提示

近期微软为了强推windows11,笔者所用的笔记本又出现了升级windows11的提示,烦人不说,还担心一不小心点错了,系统就给升了,赶紧禁止了跨版本升级,相关设置记录如下: 一、问题情况 系统出现了升…