【深度学习】注意力机制(一)

news2025/1/16 1:05:44

本文介绍一些注意力机制的实现,包括SE/ECA/GE/A2-Net/GC/CBAM。

目录

一、SE(Squeeze-and-Excitation)

二、ECA(Efficient Channel Attention)

三、GE(Gather-Excite)

四、A2-Net(Double Attention Networks)

五、GCNet(Global Context)

六、CBAM(Convolutional Block Attention Module)


一、SE(Squeeze-and-Excitation)

SE是通道注意力机制,论文地址:论文地址

SE模块流程:

1、输入特征图经过自适应池化变为NC11的特征图,特征图resize为NC;

2、经过全连接层和Relu、sigmoid生成权重;

3、将权重和输入特征图相乘。

如下所示:

torch代码实现:

import numpy as np
import torch
from torch import nn
from torch.nn import init


class SEAttention(nn.Module):

    def __init__(self, channel=512,reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

二、ECA(Efficient Channel Attention)

ECA是通道注意力机制,论文:论文地址

ECA模块过程:

1、使用自适应池化将NCHW的特征图变为N1C的特征图(自适应池化、squeeze、transpose);

2、使用1D卷积生成N1C的特征图(在C通道做卷积),将经过1D卷积的特征图变为NC11(transpose、unsqueeze);

3、特征图通过sigmoid,生成NC11的权重,将权重与原特征图相乘;

如下图:

torch代码:

import torch
from torch import nn
from torch.nn.parameter import Parameter

class ECALayer(nn.Module):
    """Constructs a ECA module.

    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) 
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)

        # Two different branches of ECA module
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # Multi-scale information fusion
        y = self.sigmoid(y)

        return x * y.expand_as(x)
        

三、GE(Gather-Excite)

GE是空间注意力机制,论文:论文地址

该机制较为简单,有四种方式,总体流程如下(看图理解比较好,不多说了):

可以通过timm轻松调用该模块,timm实现的源码:

import math

from torch import nn as nn
import torch.nn.functional as F

from .create_act import create_act_layer, get_act_layer
from .create_conv2d import create_conv2d
from .helpers import make_divisible
from .mlp import ConvMlp


class GatherExcite(nn.Module):
    """ Gather-Excite Attention Module
    """
    def __init__(
            self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
            rd_ratio=1./16, rd_channels=None,  rd_divisor=1, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
        super(GatherExcite, self).__init__()
        self.add_maxpool = add_maxpool
        act_layer = get_act_layer(act_layer)
        self.extent = extent
        if extra_params:
            self.gather = nn.Sequential()
            if extent == 0:
                assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
                self.gather.add_module(
                    'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
                if norm_layer:
                    self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))
            else:
                assert extent % 2 == 0
                num_conv = int(math.log2(extent))
                for i in range(num_conv):
                    self.gather.add_module(
                        f'conv{i + 1}',
                        create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
                    if norm_layer:
                        self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
                    if i != num_conv - 1:
                        self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
        else:
            self.gather = None
            if self.extent == 0:
                self.gk = 0
                self.gs = 0
            else:
                assert extent % 2 == 0
                self.gk = self.extent * 2 - 1
                self.gs = self.extent

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        size = x.shape[-2:]
        if self.gather is not None:
            x_ge = self.gather(x)
        else:
            if self.extent == 0:
                # global extent
                x_ge = x.mean(dim=(2, 3), keepdims=True)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
            else:
                x_ge = F.avg_pool2d(
                    x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
        x_ge = self.mlp(x_ge)
        if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
            x_ge = F.interpolate(x_ge, size=size)
        return x * self.gate(x_ge)

四、A2-Net(Double Attention Networks)

双重注意力网络(A2-Nets)方法引入了新的关系函数用于非局部(NL)块,依次使用两个连续的注意力块。论文地址:论文地址

其计算过程类似于SelfAttention模块,可以看diamagnetic对照理解。

如下图:

代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DoubleAtten(nn.Module):
    """
    A2-Nets: Double Attention Networks. NIPS 2018
    """
    def __init__(self,in_c):
        """
        :param
        in_c: 进行注意力refine的特征图的通道数目;
        原文中的降维和升维没有使用
        """
        super(DoubleAtten,self).__init__()
        self.in_c = in_c
        """
        以下对同一输入特征图进行卷积,产生三个尺度相同的特征图,即为文中提到A, B, V
        """
        self.convA = nn.Conv2d(in_c,in_c,kernel_size=1)
        self.convB = nn.Conv2d(in_c,in_c,kernel_size=1)
        self.convV = nn.Conv2d(in_c,in_c,kernel_size=1)
    def forward(self,input):

        feature_maps = self.convA(input)
        atten_map = self.convB(input)
        b, _, h, w = feature_maps.shape

        feature_maps = feature_maps.view(b, 1, self.in_c, h*w) # 对 A 进行reshape
        atten_map = atten_map.view(b, self.in_c, 1, h*w)       # 对 B 进行reshape 生成 attention_aps
        global_descriptors = torch.mean((feature_maps * F.softmax(atten_map, dim=-1)),dim=-1) # 特征图与attention_maps 相乘生成全局特征描述子

        v = self.convV(input)
        atten_vectors = F.softmax(v.view(b, self.in_c, h*w), dim=-1) # 生成 attention_vectors
        out = torch.bmm(atten_vectors.permute(0,2,1), global_descriptors).permute(0,2,1) # 注意力向量左乘全局特征描述子

        return out.view(b, _, h, w)

五、GCNet(Global Context)

全局上下文网络(GC-Net)方法使用复杂的基于置换的操作将NL-块和SE块集成,以捕捉长期依赖关系。论文:论文地址

可以看出GC模块是对SE的改进,如下图:

该实现的初始化依赖于mmcv,代码如下:

import torch
from mmcv.cnn import constant_init, kaiming_init
from torch import nn


def last_zero_init(m):
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)


class ContextBlock(nn.Module):

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out

六、CBAM(Convolutional Block Attention Module)

CBAM是通道-空间注意力机制,论文:论文地址

很简单的通道注意力和空间注意力融合。

如下图:

代码如下:

import numpy as np
import torch
from torch import nn
from torch.nn import init



class ChannelAttention(nn.Module):
    def __init__(self,channel,reduction=16):
        super().__init__()
        self.maxpool=nn.AdaptiveMaxPool2d(1)
        self.avgpool=nn.AdaptiveAvgPool2d(1)
        self.se=nn.Sequential(
            nn.Conv2d(channel,channel//reduction,1,bias=False),
            nn.ReLU(),
            nn.Conv2d(channel//reduction,channel,1,bias=False)
        )
        self.sigmoid=nn.Sigmoid()
    
    def forward(self, x) :
        max_result=self.maxpool(x)
        avg_result=self.avgpool(x)
        max_out=self.se(max_result)
        avg_out=self.se(avg_result)
        output=self.sigmoid(max_out+avg_out)
        return output

class SpatialAttention(nn.Module):
    def __init__(self,kernel_size=7):
        super().__init__()
        self.conv=nn.Conv2d(2,1,kernel_size=kernel_size,padding=kernel_size//2)
        self.sigmoid=nn.Sigmoid()
    
    def forward(self, x) :
        max_result,_=torch.max(x,dim=1,keepdim=True)
        avg_result=torch.mean(x,dim=1,keepdim=True)
        result=torch.cat([max_result,avg_result],1)
        output=self.conv(result)
        output=self.sigmoid(output)
        return output



class CBAMBlock(nn.Module):

    def __init__(self, channel=512,reduction=16,kernel_size=49):
        super().__init__()
        self.ca=ChannelAttention(channel=channel,reduction=reduction)
        self.sa=SpatialAttention(kernel_size=kernel_size)


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        residual=x
        out=x*self.ca(x)
        out=out*self.sa(out)
        return out+residual

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1299505.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Stable Diffusion AI绘画系列【23】:赛博朋克-机甲美女系列

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

任意文件上传漏洞实战和防范

文件上传漏洞广泛存在于Web1.0时代,恶意攻击者的主要攻击手法是将可执行脚本(WebShell)上传至目标服务器,以达到控制目标服务器的目的。 此漏洞成立的前提条件至少有下面两个: 1.可以上传对应的脚本文件,…

Java 8 新特性深度解析:探索 Lambda 表达式、Stream API 和函数式编程的革新之路

Java8 新特性 Java 8 的革新之路 自 1995 年首次发布以来,Java 已经成为世界上最广泛使用的编程语言之一。随着时间的推移,Java 经历了多次版本更新,其中最具里程碑意义的便是 Java 8 的发布。这个版本引入了许多重大变革,包括 …

vivado时序方法检查11

TIMING-47 &#xff1a; 同步时钟之间的伪路径、异步时钟组或仅最 大延迟数据路径约束 在 <clock_group> 与 <clock_group> 这两个时钟之间设置了 <message_string> 时序约束 &#xff08; 请参阅 VivadoIDE 的“ Timing Constraint ”窗口中的约束位…

Oracle(2-13) RMAN Complete Recove

文章目录 一、基础知识1、Restoration Using RMAN利用RMAN进行恢复2、Relocate a Tablespace 重新定位表空间 二、基础操作1、恢复前的准备2、恢复数据库3、恢复单个数据文件4、在数据库打开的情况下恢复 RMAN Complete Recove RMAN完全恢复 目标&#xff1a; 了解RMAN用于恢复…

HarmonyOS应用开发者基础认证考试(稳过)

判断题 ​​​​​​​ 1. Web组件对于所有的网页都可以使用zoom(factor: number)方法进行缩放。错误(False) 2. 每一个自定义组件都有自己的生命周期正确(True) 3. 每调用一次router.pushUrl()方法&#xff0c;默认情况下&#xff0c;页面栈数量会加1&#xff0c;页面栈支持的…

基于ssm校园教务系统论文

摘 要 互联网发展至今&#xff0c;无论是其理论还是技术都已经成熟&#xff0c;而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播&#xff0c;搭配信息管理工具可以很好地为人们提供服务。针对校园教务信息管理混乱&#xff0c;出错率高&#xff0c;信息安全性差…

Fall in love with English

Fall in love with English 爱上英语 Hiding behind the loose dusty curtain, a teenager packed up his overcoat into the suitcase. 躲藏在布满尘土的松软的窗帘后边&#xff0c;一个年轻人打包他的外套到行李箱中。 He planned to leave home at dusk though there was th…

Java-JDBC操作MySQL

Java-JDBC操作MySQL 文章目录 Java-JDBC操作MySQL一、Java-JDBC-MySQL的关系二、创建连接三、登录MySQL四、操作数据库1、返回型操作2、无返回型操作 练习题目及完整代码 一、Java-JDBC-MySQL的关系 #mermaid-svg-B7qjXrosQaCOwRos {font-family:"trebuchet ms",verd…

Spring Cloud Alibaba及其Nacos初学习

目录 1、Spring Cloud Alibaba是什么 2、Spring Cloud 微服务体系​ 3、Nacos 3.1 Nacos概述 3.2 下载nacos 3.3 解压nacos 3.4 配置文件properties和sql 3.5 访问web管理页面 4、Spring Cloud整合Nacos 4.1 nacos创建命名空间namespace 4.2 配置管理 4.3 Spring C…

记一次堆内外内存问题的排查和优化

为优化淘宝带宽成本&#xff0c;我们在网关 SDK&#xff08;Java&#xff09;统一使用 ZSTD 替代 GZIP 压缩以获取更高的压缩比&#xff0c;从而得到更小的响应包。具体实现采用官方推荐的 zstd-jni 库。zstd-jni 会调用 zstd 的 c 库。 背景 在性能压测和优化过程中&#xff0…

错题总结(四)

1.【一维数组】输入10个整数&#xff0c;求平均值 编写一个程序&#xff0c;从用户输入中读取10个整数并存储在一个数组中。然后&#xff0c;计算并输出这些整数的平均值。 int main() {int arr[10];int sum 0;for (int n 0; n < 10; n){scanf("%d", &arr…

数据库连接池Druid

在 Spring Boot 项目中&#xff0c;数据库连接池已经成为标配&#xff0c;然而&#xff0c;我曾经遇到过不少连接池异常导致业务错误的事故。很多经验丰富的工程师也可能不小心在这方面出现问题。 在这篇文章中&#xff0c;我们将探讨数据库连接池&#xff0c;深入解析其实现机…

c语言:理解和避免野指针

野指针的定义&#xff1a; 野指针是指一个指针变量存储了一个无效的地址&#xff0c;通常是一个未初始化的指针或者指向已经被释放的内存地址。当程序尝试使用野指针时&#xff0c;可能会导致程序崩溃、内存泄漏或者其他不可预测的行为。因此&#xff0c;在编程中需要特别注意…

MEMS制造的基本工艺介绍——晶圆键合

晶圆键合是一种晶圆级封装技术&#xff0c;用于制造微机电系统 (MEMS)、纳米机电系统 (NEMS)、微电子学和光电子学&#xff0c;确保机械稳定和气密密封。用于 MEMS/NEMS 的晶圆直径范围为 100 毫米至 200 毫米&#xff08;4 英寸至 8 英寸&#xff09;&#xff0c;用于生产微电…

实用篇 | 3D建模中Blender软件的下载及使用[图文详情]

本文基于数字人系列的3D建模工具Blender软件的安装及使用&#xff0c;还介绍了图片生成3D模型的AI工具~ 目录 1.Blender的下载 2.Blender的使用 3.安装插件(通过压缩包安装) 4.实例 4.1.Blender使用MB-Lab插件快速人体模型建构 4.1.1.点击官网&#xff0c;进行下载 4.1.…

Mybatis、Mybatis整合Spring的流程图

Mybatis 注意MapperProxy里面有invoke方法&#xff0c;当进到invoker方法会拿到 二、mybatis整合Spring 1、当我们的拿到的【Dao】其实就是【MapperProxy】&#xff0c;执行Dao的方法时&#xff0c;会被MapperProxy的【Invoke方法拦截】 2、图上已经标注了MapperProxy包含哪些…

深入理解 Go Channel:解密并发编程中的通信机制

一、Channel管道 1、Channel说明 共享内存交互数据弊端 单纯地将函数并发执行是没有意义的。函数与函数间需要交互数据才能体现编发执行函数的意义虽然可以使用共享内存进行数据交换&#xff0c;但是共享内存在不同的goroutine中容易发送静态问题为了保证数据交换的正确性&am…

基于ssm应急资源管理系统论文

摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本应急资源管理系统就是在这样的大环境下诞生&#xff0c;其可以帮助管理者在短时间内处理完毕庞大的数据信息…

入门Redis学习总结

记录之前刚学习Redis 的笔记&#xff0c; 主要包括Redis的基本数据结构、Redis 发布订阅机制、Redis 事务、Redis 服务器相关及采用Spring Boot 集成Redis 实现增删改查基本功能 一&#xff1a;常用命令及数据结构 1.Redis 键(key) # 设置key和value 127.0.0.1:6379> set …