【深度学习】注意力机制(二)

news2024/9/25 7:30:58

本文介绍一些注意力机制的实现,包括EA/MHSA/SK/DA/EPSA。

【深度学习】注意力机制(一)

【深度学习】注意力机制(三)

目录

一、EA(External Attention)

二、Multi Head Self Attention

三、SK(Selective Kernel Networks)

四、DA(Dual Attention)

五、EPSA(Efficient Pyramid Squeeze Attention)


一、EA(External Attention)

EA可以关注全局的空间信息,论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init

class External_attention(nn.Module):
    '''
    Arguments:
        c (int): The input and output channel number.
    '''
    def __init__(self, c):
        super(External_attention, self).__init__()
        
        self.conv1 = nn.Conv2d(c, c, 1)

        self.k = 64
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)

        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
        self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)        
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            norm_layer(c))        
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()
 

    def forward(self, x):
        idn = x
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h*w
        x = x.view(b, c, h*w)   # b * c * n 

        attn = self.linear_0(x) # b, k, n
        attn = F.softmax(attn, dim=-1) # b, k, n

        attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) #  # b, k, n
        x = self.linear_1(attn) # b, c, n

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn
        x = F.relu(x)
        return x

二、Multi Head Self Attention

注意力机制的经典,Transformer的基石。论文:论文地址

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init



class ScaledDotProductAttention(nn.Module):
    '''
    Scaled dot-product attention
    '''

    def __init__(self, d_model, d_k, d_v, h,dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(ScaledDotProductAttention, self).__init__()
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout=nn.Dropout(dropout)

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes
        :param queries: Queries (b_s, nq, d_model)
        :param keys: Keys (b_s, nk, d_model)
        :param values: Values (b_s, nk, d_model)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return:
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)
        att = torch.softmax(att, -1)
        att=self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out

三、SK(Selective Kernel Networks)

SK是通道注意力机制。论文地址:论文连接

如下图:

代码如下(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from collections import OrderedDict



class SKAttention(nn.Module):

    def __init__(self, channel=512,kernels=[1,3,5,7],reduction=16,group=1,L=32):
        super().__init__()
        self.d=max(L,channel//reduction)
        self.convs=nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv',nn.Conv2d(channel,channel,kernel_size=k,padding=k//2,groups=group)),
                    ('bn',nn.BatchNorm2d(channel)),
                    ('relu',nn.ReLU())
                ]))
            )
        self.fc=nn.Linear(channel,self.d)
        self.fcs=nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d,channel))
        self.softmax=nn.Softmax(dim=0)



    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs=[]
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats=torch.stack(conv_outs,0)#k,bs,channel,h,w

        ### fuse
        U=sum(conv_outs) #bs,c,h,w

        ### reduction channel
        S=U.mean(-1).mean(-1) #bs,c
        Z=self.fc(S) #bs,d

        ### calculate attention weight
        weights=[]
        for fc in self.fcs:
            weight=fc(Z)
            weights.append(weight.view(bs,c,1,1)) #bs,channel
        attention_weughts=torch.stack(weights,0)#k,bs,channel,1,1
        attention_weughts=self.softmax(attention_weughts)#k,bs,channel,1,1

        ### fuse
        V=(attention_weughts*feats).sum(0)
        return V

四、DA(Dual Attention)

DA融合了通道注意力和空间注意力机制。论文:论文地址

如下图:

代码(代码连接):

import numpy as np
import torch
from torch import nn
from torch.nn import init
from model.attention.SelfAttention import ScaledDotProductAttention
from model.attention.SimplifiedSelfAttention import SimplifiedScaledDotProductAttention

class PositionAttentionModule(nn.Module):

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)
        self.pa=ScaledDotProductAttention(d_model,d_k=d_model,d_v=d_model,h=1)
    
    def forward(self,x):
        bs,c,h,w=x.shape
        y=self.cnn(x)
        y=y.view(bs,c,-1).permute(0,2,1) #bs,h*w,c
        y=self.pa(y,y,y) #bs,h*w,c
        return y


class ChannelAttentionModule(nn.Module):
    
    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        self.cnn=nn.Conv2d(d_model,d_model,kernel_size=kernel_size,padding=(kernel_size-1)//2)
        self.pa=SimplifiedScaledDotProductAttention(H*W,h=1)
    
    def forward(self,x):
        bs,c,h,w=x.shape
        y=self.cnn(x)
        y=y.view(bs,c,-1) #bs,c,h*w
        y=self.pa(y,y,y) #bs,c,h*w
        return y


class DAModule(nn.Module):

    def __init__(self,d_model=512,kernel_size=3,H=7,W=7):
        super().__init__()
        self.position_attention_module=PositionAttentionModule(d_model=512,kernel_size=3,H=7,W=7)
        self.channel_attention_module=ChannelAttentionModule(d_model=512,kernel_size=3,H=7,W=7)
    
    def forward(self,input):
        bs,c,h,w=input.shape
        p_out=self.position_attention_module(input)
        c_out=self.channel_attention_module(input)
        p_out=p_out.permute(0,2,1).view(bs,c,h,w)
        c_out=c_out.view(bs,c,h,w)
        return p_out+c_out

五、EPSA(Efficient Pyramid Squeeze Attention)

论文:论文地址

如下图:

代码如下(代码连接):

import torch.nn as nn

class SEWeightModule(nn.Module):

    def __init__(self, channels, reduction=16):
        super(SEWeightModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels//reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels//reduction, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.avg_pool(x)
        out = self.fc1(out)
        out = self.relu(out)
        out = self.fc2(out)
        weight = self.sigmoid(out)

        return weight


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
    """standard convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                     padding=padding, dilation=dilation, groups=groups, bias=False)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

class PSAModule(nn.Module):

    def __init__(self, inplans, planes, conv_kernels=[3, 5, 7, 9], stride=1, conv_groups=[1, 4, 8, 16]):
        super(PSAModule, self).__init__()
        self.conv_1 = conv(inplans, planes//4, kernel_size=conv_kernels[0], padding=conv_kernels[0]//2,
                            stride=stride, groups=conv_groups[0])
        self.conv_2 = conv(inplans, planes//4, kernel_size=conv_kernels[1], padding=conv_kernels[1]//2,
                            stride=stride, groups=conv_groups[1])
        self.conv_3 = conv(inplans, planes//4, kernel_size=conv_kernels[2], padding=conv_kernels[2]//2,
                            stride=stride, groups=conv_groups[2])
        self.conv_4 = conv(inplans, planes//4, kernel_size=conv_kernels[3], padding=conv_kernels[3]//2,
                            stride=stride, groups=conv_groups[3])
        self.se = SEWeightModule(planes // 4)
        self.split_channel = planes // 4
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        batch_size = x.shape[0]
        x1 = self.conv_1(x)
        x2 = self.conv_2(x)
        x3 = self.conv_3(x)
        x4 = self.conv_4(x)

        feats = torch.cat((x1, x2, x3, x4), dim=1)
        feats = feats.view(batch_size, 4, self.split_channel, feats.shape[2], feats.shape[3])

        x1_se = self.se(x1)
        x2_se = self.se(x2)
        x3_se = self.se(x3)
        x4_se = self.se(x4)

        x_se = torch.cat((x1_se, x2_se, x3_se, x4_se), dim=1)
        attention_vectors = x_se.view(batch_size, 4, self.split_channel, 1, 1)
        attention_vectors = self.softmax(attention_vectors)
        feats_weight = feats * attention_vectors
        for i in range(4):
            x_se_weight_fp = feats_weight[:, i, :, :]
            if i == 0:
                out = x_se_weight_fp
            else:
                out = torch.cat((x_se_weight_fp, out), 1)

        return out

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1314989.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

26 redis 中 replication/cluster 集群中的主从复制

前言 我们这里首先来看 redis 这边实现比较复杂的 replication集群模式 我们这里主要关注的是 redis 这边的主从同步的相关实现 这边相对比较简单, 我们直接基于 cluster集群模式 进行调试 主从命令同步复制 比如这里 master 是 redis_7002, slave 是 redis_7005 然后 这…

打开软木塞,我们来谈谈葡萄酒泡泡吧

香槟是任何庆祝场合的最佳搭配。从婚礼和生日到单身派对和典型的周五晚上,这款气泡饮料是生活中特别聚会的受欢迎伴侣。 来自云仓酒庄品牌雷盛红酒分享你知道吗,你喜欢喝的那瓶香槟酒可能根本不是香槟,而是汽酒?你不是唯一一个认…

6个超好用的小众图片素材网站,高清、免费,值得收藏~

推荐几个超好用的图片素材网站,免费下载,还可以商用,建议收藏哦~ 1、菜鸟图库 https://www.sucai999.com/pic.html?vNTYwNDUx 我推荐过很多次的设计素材网站,除了设计类素材,还有很多自媒体可以用到的高清图片、背景…

最好的猫粮排行榜前十名有哪些品牌?质量好的主食冻干猫粮分享

为什么越来越多人推荐冻干猫粮喂养呢?主食冻干猫粮究竟是最适应猫饮食习惯的喂养方式还是消费陷阱? 作为一个6年的宠物营养师,我以前接触过很多不同品种的猫咪,一只健康又漂亮的猫咪从表面上就能看出来!体型匀称刚好、…

大模型落地,向量数据库到底能做什么?

▼最近直播超级多,预约保你有收获 今晚直播:《AI编程向量数据库架构设计案例实践》 —1— 大模型的“数据局限性” 数据局限对企业做 LLM 大模型带来的影响,可归结为以下三点: 第一点:对数据的管理和运维。随着文本、…

LeetCode(63)旋转链表【链表】【中等】

目录 1.题目2.答案3.提交结果截图 链接: 旋转链表 1.题目 给你一个链表的头节点 head ,旋转链表,将链表每个节点向右移动 k 个位置。 示例 1: 输入:head [1,2,3,4,5], k 2 输出:[4,5,1,2,3]示例 2&…

深入理解LightGBM

1. LightGBM简介 GBDT (Gradient Boosting Decision Tree) 是机器学习中一个长盛不衰的模型,其主要思想是利用弱分类器(决策树)迭代训练以得到最优模型,该模型具有训练效果好、不易过拟合等优点。GBDT不仅在工业界应用广泛&#…

python初试二

连接数据库 Django为多种数据库后台提供了统一的调用API。根据需求不同,Django可以选择不同的数据库后台。MySQL算是最常用的数据库。我们这里将Django和MySQL连接。 在Linux终端下启动mysql: $mysql -u root -p 在MySQL中创立Django项目的数据库: …

【数据结构和算法】判断子序列

其他系列文章导航 Java基础合集数据结构与算法合集 设计模式合集 多线程合集 分布式合集 ES合集 文章目录 其他系列文章导航 文章目录 前言 一、题目描述 二、题解 2.1 方法一:双指针 三、代码 3.1 方法一:双指针 3.1.1 Java易懂版:…

Cosmopolitan Libc:让 C 语言一次构建、随处运行 | 开源日报 No.109

jart/cosmopolitan Stars: 12.9k License: ISC Cosmopolitan Libc 使 C 成为一种构建一次运行在任何地方的语言,就像 Java 一样,但它不需要解释器或虚拟机。相反,它重新配置了标准 GCC 和 Clang 以输出符合 POSIX 标准的多语言格式&#xff…

VS Code连接远程Linux服务器调试C程序

1.在 VS Code 上安装扩展 C/C 2.通过 VS Code 连接远程 Linux 服务器 3.通过 VS Code 在远程 Linux 服务器上安装扩展 C/C 4.打开远程 Linux 服务器上的文件夹 【注】本文以 /root/ 为例。 5.创建项目文件夹,并在项目文件夹下创建C程序 6.按 F5,选…

数据挖掘任务一般流程

数据挖掘是从大量数据中提取有价值信息的过程。它涉及多个步骤,每一步都对整个数据挖掘过程至关重要。以下是数据挖掘任务的一般流程: 业务理解: 确定业务目标。评估当前情况。定义数据挖掘问题。制定一个初步计划来达到这些目标。 数据理…

JVM的类的生命周期

目录 前言 1. 加载(Loading): 2. 验证(Verification): 3. 准备(Preparation): 4. 解析(Resolution): 5. 初始化(Ini…

解决ES伪慢查询

一、问题现象 服务现象 服务接口的TP99性能降低 ES现象 YGC:耗时极其不正常, 峰值200次,耗时7sFULL GC:不正常,次数为1但是频繁,STW 5s慢查询:存在慢查询5 二 解决过程 1、去除干扰因素 从现象上看应用是由于某种…

从零开始:前端架构师的基础建设和架构设计之路

文章目录 一、引言二、前端架构师的职责三、基础建设四、架构设计思想五、总结《前端架构师:基础建设与架构设计思想》编辑推荐内容简介作者简介目录获取方式 一、引言 在现代软件开发中,前端开发已经成为了一个不可或缺的部分。随着互联网的普及和移动…

算法通关第十九关-青铜挑战理解动态规划

大家好我是苏麟 , 今天聊聊动态规划 . 动态规划是最热门、最重要的算法思想之一,在面试中大量出现,而且题目整体都偏难一些对于大部人来说,最大的问题是不知道动态规划到底是怎么回事。很多人看教程等,都被里面的状态子问题、状态…

文章解读与仿真程序复现思路——电力系统自动化EI\CSCD\北大核心《考虑电力-交通交互的配电网故障下电动汽车充电演化特性》

这个标题涉及到电力系统、交通系统和电动汽车充电的复杂主题。让我们逐步解读: 考虑电力-交通交互的配电网故障: 电力-交通交互: 指的是电力系统和交通系统之间相互影响、相互关联的关系。这可能涉及到电力需求对交通流量的影响,反…

windows wsl2 ubuntu上部署 redroid云手机

Redroid WSL2部署文档 下载wsl内核源码 #文档注明 5.15和5.10 版本内核可以部署成功,这里我当前最新的发布版本 #下载wsl 源码 wget --progressbar:force --output-documentlinux-msft-wsl-5.15.133.1.tar.gz https://codeload.github.com/microsoft/WSL2-Linux-Ker…

Nacos配置管理-微服务配置拉取

yaml已配置内容 目录 一、配置获取步骤 二、统一配置管理步骤 三、Nacos管理配置的步骤总结 一、配置获取步骤 二、统一配置管理步骤 1、引入Nacos的配置管理客户端依赖: <!--nacos配置管理依赖--> <dependency> <groupId>com.alibaba.cloud&l…

ABAP与HANA集成:HANA视图转换为ABAP字典视图

使用场景 最近项目在用HANA开发逻辑&#xff0c;形成了很多过程的计算视图&#xff0c;一般我们BW人员可能直接用计算视图出具前端报表&#xff0c;或者链接到cp使用&#xff0c;没有考虑转换成abap字典视图&#xff0c;也就是前台SE11能查到的视图&#xff0c;但是非开发人员…