深度学习中的attention机制

news2025/1/4 19:09:43

SE

文章   https://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.pdfhttps://openaccess.thecvf.com/content_cvpr_2018/papers/Hu_Squeeze-and-Excitation_Networks_CVPR_2018_paper.pdf

class SELayer(nn.Module):
    def __init__(self, channel, reduction=4):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(channel, channel // reduction),
                nn.ReLU(inplace=True),
                nn.Linear(channel // reduction, channel),        )
 
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        y = torch.clamp(y, 0, 1)
        return x * y
class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        super(SEBasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, 1)
        self.bn2 = nn.BatchNorm2d(planes)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SEBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None,
                 *, reduction=16):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

 

代码

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

插入位置:

SE模块的插入位置
通过Resnet的基础模块和bottleneck模块 可以看出SE模块插入到,跳连结构add之前,对前面特征提取之后的特征图给与不同的权重,再与shortcut跳连分支相加。

CBAM

文章

https://arxiv.org/pdf/1807.06521.pdf

github:GitHub - Jongchan/attention-module: Official PyTorch code for "BAM: Bottleneck Attention Module (BMVC2018)" and "CBAM: Convolutional Block Attention Module (ECCV2018)"

第一步:利用SE同时进行全局最大和全局平均池化,生成1x1xC通道注意力图,不同通道加权;

第二步: 通过channel池化,生成两个HxW的特征图,再卷积成1xHxW,对spatial进行逐点加权。

注意机制(CBAM)理解_Tc.小浩的博客-CSDN博客_cbam注意力机制

 分两个注意力,CAM和SAM。

CAM

CAM和SE的不同在于:同时进行了全局平均池化,和全局最大池化。

SAM


# (1)通道注意力机制
class channel_attention(nn.Module):
    # ratio代表第一个全连接的通道下降倍数
    def __init__(self, in_channel, ratio=4):
        super().__init__()

        # 全局最大池化 [b,c,h,w]==>[b,c,1,1]
        self.max_pool = nn.AdaptiveMaxPool2d(output_size=1)
        # 全局平均池化 [b,c,h,w]==>[b,c,1,1]
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)

        # 第一个全连接层, 通道数下降4倍(可以换成1x1的卷积,效果相同)
        self.fc1 = nn.Linear(in_features=in_channel, out_features=in_channel // ratio, bias=False)
        # 第二个全连接层, 恢复通道数(可以换成1x1的卷积,效果相同)
        self.fc2 = nn.Linear(in_features=in_channel // ratio, out_features=in_channel, bias=False)

        # relu激活函数
        self.relu = nn.ReLU()

        # sigmoid激活函数
        self.sigmoid = nn.Sigmoid()

    # 前向传播
    def forward(self, inputs):
        b, c, h, w = inputs.shape

        # 输入图像做全局最大池化 [b,c,h,w]==>[b,c,1,1]
        max_pool = self.max_pool(inputs)

        # 输入图像的全局平均池化 [b,c,h,w]==>[b,c,1,1]
        avg_pool = self.avg_pool(inputs)

        # 调整池化结果的维度 [b,c,1,1]==>[b,c]
        max_pool = max_pool.view([b, c])
        avg_pool = avg_pool.view([b, c])

        # 第一个全连接层下降通道数 [b,c]==>[b,c//4]

        x_maxpool = self.fc1(max_pool)
        x_avgpool = self.fc1(avg_pool)

        # 激活函数
        x_maxpool = self.relu(x_maxpool)
        x_avgpool = self.relu(x_avgpool)

        # 第二个全连接层恢复通道数 [b,c//4]==>[b,c]
        # (可以换成1x1的卷积,效果相同)
        x_maxpool = self.fc2(x_maxpool)
        x_avgpool = self.fc2(x_avgpool)

        # 将这两种池化结果相加 [b,c]==>[b,c]
        x = x_maxpool + x_avgpool

        # sigmoid函数权值归一化
        x = self.sigmoid(x)

        # 调整维度 [b,c]==>[b,c,1,1]
        x = x.view([b, c, 1, 1])

        # 输入特征图和通道权重相乘 [b,c,h,w]
        outputs = inputs * x

        return outputs

# (2)空间注意力机制
class spatial_attention(nn.Module):
    # 卷积核大小为7*7
    def __init__(self, kernel_size=7):
        super().__init__()

        # 为了保持卷积前后的特征图shape相同,卷积时需要padding
        padding = kernel_size // 2

        # 7*7卷积融合通道信息 [b,2,h,w]==>[b,1,h,w]
        self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size,
                              padding=padding, bias=False)
        # sigmoid函数
        self.sigmoid = nn.Sigmoid()

    # 前向传播
    def forward(self, inputs):
        # 在通道维度上最大池化 [b,1,h,w]  keepdim保留原有深度
        # 返回值是在某维度的最大值和对应的索引
        x_maxpool, _ = torch.max(inputs, dim=1, keepdim=True)

        # 在通道维度上平均池化 [b,1,h,w]
        x_avgpool = torch.mean(inputs, dim=1, keepdim=True)
        # 池化后的结果在通道维度上堆叠 [b,2,h,w]
        x = torch.cat([x_maxpool, x_avgpool], dim=1)

        # 卷积融合通道信息 [b,2,h,w]==>[b,1,h,w]
        x = self.conv(x)

        # 空间权重归一化
        x = self.sigmoid(x)

        # 输入特征图和空间权重相乘
        outputs = inputs * x

        return outputs

# (3)CBAM注意力机制
class CBAM(nn.Module):
    # 初始化,in_channel和ratio=4代表通道注意力机制的输入通道数和第一个全连接下降的通道数
    # kernel_size代表空间注意力机制的卷积核大小
    def __init__(self, in_channel, ratio=4, kernel_size=7):
        super().__init__()
        # 实例化通道注意力机制
        self.channel_attention = channel_attention(in_channel=in_channel, ratio=ratio)
        # 实例化空间注意力机制
        self.spatial_attention = spatial_attention(kernel_size=kernel_size)

    # 前向传播
    def forward(self, inputs):
        # 先将输入图像经过通道注意力机制
        x = self.channel_attention(inputs)

        # 然后经过空间注意力机制
        x = self.spatial_attention(x)

        return x

CA

文章:https://openaccess.thecvf.com/content/CVPR2021/papers/Hou_Coordinate_Attention_for_Efficient_Mobile_Network_Design_CVPR_2021_paper.pdf

code: GitHub - houqb/CoordAttention: Code for our CVPR2021 paper coordinate attention

参考文档:

 CA(Coordinate attention) 注意力机制 - 知乎

2021CVPR-Coordinate Attention for Efficient Mobile Network Design 坐标注意力机制_小哈蒙德的博客-CSDN博客_坐标注意力机制

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/194403.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java工具包类

java.util包有很多实用的类、接口和异常。 向量类,堆栈类,哈希表,枚举接口,日历类,随机函数类,映射接口和属性类。 Vector类 vector是异构的,可以存储不同的对象,同时可以动态增加…

【工具】国内苹果市场已上架 新一代社交产品 damus

国内苹果市场可下载 2月1日,Twitter 联合创始人 Jack Dorsey 发布推文表示,基于分布式社交媒体协议 Nostr 的社交产品 Damus 和 Amethyst 正式在苹果 App Store 和谷歌 Google Play Store 上线。 目前为止,Damus 在国内苹果应用市场是可以直…

远程超大功率森林防火喊话与应急广播系统方案

北京恒星科通发布于2023-2-2 一、引言 随着消灭宜林荒山和实现全面绿化,造林事业不断发展,林地面积、林业蓄积量逐年增加,如何加强森林防火、保护环境,是全国当前面临的一项重大任务。 森林火灾是一种突发性和破坏性极强的自然…

Spring Security(新版本)实现权限认证与授权

学习新版SpringSecurity详细配置一、Spring Security介绍1、Spring Security简介2、历史3、同款产品对比3.1、Spring Security3.2、 Shiro二、Spring Security实现权限1、SpringSecurity入门1.1 添加依赖1.2、启动项目测试2、用户认证2.1、用户认证核心组件2.2、用户认证2.2.1、…

CrossOver虚拟机软件2023最新版Mac运行切换Windows

CrossOver2023版是专为苹果电脑用户打造的一款实用工具,这款工具主要方便用户在mac上运行windows系列的应用程序,用户不需要安装虚拟机就可以实现各种应用程序的直接应用,并且可以实现无缝集成,实现跨平台的复制粘贴和文件互通等&…

我为什么抢不到票?!全国最难抢线路揭晓

随着疫情防控策略的转变,不少多年未归的朋友选择在今年返乡团聚。那么2023年春运抢票难度是否会因此而飙升?本期文章,我们通过数据分析,观察比较哪条线路的票最难抢,给还没有买到票的朋友提供参考。 根据往年央视报道和…

浅析晶体管放大电路的负载线

晶体管放大电路的负载线包括直流负载线和交流负载线,描述了输出端电压、电流与负载之间的关系。大学期间曾经学习过相关知识,本文将与大家重温所学内容,并介绍直流工作点对功率放大器性能的影响。 直流负载线 以场效应管为例,图…

基于python3实现Azure机器学习最接近人声的文本转语音功能

上期文章,我们介绍了如何使用Azure来创建一个语音服务API,哪里,我们得到了API的key,以及语音服务的基本信息,包含地区等,这些都是本期代码需要的参数 听了那么多AI合成的语音,Azure机器学习的文本转语音最接近人声https://blog.csdn.net/weixin_44782294/article/detai…

如何实现大文件上传:秒传、断点续传、分片上传

前言 文件上传是一个老生常谈的话题了,在文件相对比较小的情况下,可以直接把文件转化为字节流上传到服务器,但在文件比较大的情况下,用普通的方式进行上传,这可不是一个好的办法,毕竟很少有人会忍受&#…

在字节跳动干了5年的软件测试,2月无情被辞,想给划水的兄弟提个醒

前几天,一个认识了好几年在大厂工作做软件测试的朋友,年近30了,却被大厂以“人员优化”的名义无情被辞,据他说,有一个月散伙饭都吃了好几顿…… 在很多企业,都有KPI考核,然后在此基础上还会弄个…

Ubuntu设置静态IP

Ubuntu设置静态IP1.当前环境2. 设置前准备3.前提准备4.修改VMware中的网络配置5.修改Ubuntu配置文件6.查看网关信息7.Xshell远程连接1.当前环境 VMware16、Xshell7 2. 设置前准备 VMware16设置快照,配置出错可以返回到初始状态 3.前提准备 查看Ubuntu是否安装vi…

小程序项目学习--第七章:播放页布局-歌曲进度控制-歌词的展示

第七章:播放页布局-歌曲进度控制-歌词的展示 01_(了解)之前页面的回顾和播放页的介绍 功能介绍 02_(掌握)播放页-点击Item跳转到播放页和传入ID 功能概览 1.创建页面music-player 2.监听item的点击 方式一:直接写在子组件上 绑定监听点击 需要获取…

关于xxl-job中的慢sql引发的磁盘I/O飙升导致拖垮整个数据库服务

背景: 某天突然发现服务探测接口疯狂告警、同时数据库CPU消耗也告警,最后系统都无法访问; 查看服务端日志,发现大量的报错如下: CommunicationsException: Communications link failure :The last packet s…

dapr入门与本地托管模式尝试

1 简介 Dapr是一个可移植的、事件驱动的运行时,它使任何开发人员能够轻松构建出弹性的、无状态和有状态的应用程序,并可运行在云平台或边缘计算中,它同时也支持多种编程语言和开发框架。Dapr支持的语言很多,包括C/Go/Java/JavaSc…

antv/x6 2.x 搭建流程图编辑页面(1)

进来闲来无事,看到x6 2.x版本也更新了有几个月了,便想着熟悉下2.x版本 一、首先搭建项目基础框架。 // yarn 方式 yarn create vitejs/app v3-ts --template vue-ts cd v3-ts yarn yarn dev// npm npm init vitejs/app v3-ts --template vue-ts cd v3…

人工神经网络BP神经网络结构及优化原理单隐层,多隐层及反向传播梯度下降释义

神经网络:人工神经网络(Artificial Neural Networks,简写为ANNs)也简称为神经网络(NNs)或称作连接模型(Connection Model),它是一种模仿动物神经网络行为特征&#xff0c…

ASEMI整流模块MDA110-16参数,MDA110-16规格

编辑-Z ASEMI整流模块MDA110-16参数: 型号:MDA110-16 最大重复峰值反向电压(VRRM):1600V 最大RMS电桥输入电压(VRMS):1700V 最大平均正向整流输出电流(IF&#xff0…

多线程synchronized对象锁与静态锁之间的8锁问题理解

目录8锁问题锁1:多个线程使用同一对象分别调用不同带有带synchronized关键字且非静态的方法锁2:在锁1基础上,增加A线程执行的方法的执行时间,使得B有机会参与执行锁3:多个线程使用同一对象,一个线程执行带有…

三菱Q系列QJ71C24N模块 MODBUS通信(含完整步骤+源代码)

MODBUS通信的其它相关基础知识请参看下面的文章链接: SMART S7-200PLC MODBUS通信_RXXW_Dor的博客-CSDN博客MODBUS 是 OSI 模型第 7 层上的应用层报文传输协议,它在连接至不同类型总线或网络的设备之间提供客户机/服务器通信。自从 1979 年出现工业串行链路的事实标准以来,…

7z压缩包有可能被破解吗,需要多久,暴力穷举和字典破解分别的速度分析

开门见山,我看到网上有很多此类软件,功能就是来破解7zip格式的压缩包,但都没有认真进行测试,这里认真进行判断和测试 首先,目前世界最快的计算机为Frontier,算力1,685.65 PFlop/s。目前最高的算力为全网比…