(即插即用模块-Attention部分) 二十、(2021) GAA 门控轴向注意力

news2024/11/28 17:49:46

在这里插入图片描述

文章目录

  • 1、Gated Axial-Attention
  • 2、代码实现

paper:Medical Transformer: Gated Axial-Attention for Medical Image Segmentation

Code:https://github.com/jeya-maria-jose/Medical-Transformer


1、Gated Axial-Attention

论文首先分析了 ViTs 在训练小规模数据集时的弊端以及指出了 ViTs 的计算复杂度偏高。为此,论文提出了一种门控轴向注意力(Gated Axial-Attention),其通过在自注意力模块中引入额外的门控机制来扩展现有的体系结构。在分析了位置偏差难以学习、相对位置编码不够准确等问题后,通过将可控制的影响位置偏差施加在编码的非本地上下文来实现改进。Gated Axial-Attention的 核心思想是Gate门控机制,通过引入 Gate 控制机制来控制位置编码对 Self-Attention 的影响程度。

对于一个输入特征 X,Gated Axial-Attention的实现过程:

  1. 输入特征图: 将输入图像提取特征图,并进行通道维度上的线性变换,得到 Query、Key 和 Value 向量。

  2. Axial-Attention

    在高度方向上进行 1D Self-Attention,计算像素之间的依赖关系。

    在宽度方向上进行 1D Self-Attention,计算像素之间的依赖关系。

  3. Positional Encoding:计算相对位置编码,将像素位置信息融入到 Query、Key 和 Value 向量中。

  4. Gate 控制机制:通过可学习的 Gate 参数,控制相对位置编码对 Self-Attention 的影响程度。

  5. 输出特征图: 将经过 Self-Attention 和 Gate 控制的特征图进行线性变换,得到最终输出特征图。


Gated Axial-Attention 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 卷积"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class qkv_transform(nn.Conv1d):
    """Conv1d for qkv_transform"""


class AxialAttention(nn.Module):
    def __init__(self, in_planes, out_planes, groups=8, kernel_size=56,
                 stride=1, bias=False, width=False):
        assert (in_planes % groups == 0) and (out_planes % groups == 0)
        super(AxialAttention, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.groups = groups
        self.group_planes = out_planes // groups
        self.kernel_size = kernel_size
        self.stride = stride
        self.bias = bias
        self.width = width

        # Multi-head self attention
        self.qkv_transform = qkv_transform(in_planes, out_planes * 2, kernel_size=1, stride=1,
                                           padding=0, bias=False)
        self.bn_qkv = nn.BatchNorm1d(out_planes * 2)
        self.bn_similarity = nn.BatchNorm2d(groups * 3)

        self.bn_output = nn.BatchNorm1d(out_planes * 2)

        # Position embedding
        self.relative = nn.Parameter(torch.randn(self.group_planes * 2, kernel_size * 2 - 1), requires_grad=True)
        query_index = torch.arange(kernel_size).unsqueeze(0)
        key_index = torch.arange(kernel_size).unsqueeze(1)
        relative_index = key_index - query_index + kernel_size - 1
        self.register_buffer('flatten_index', relative_index.view(-1))
        if stride > 1:
            self.pooling = nn.AvgPool2d(stride, stride=stride)

        self.reset_parameters()

    def forward(self, x):
        # pdb.set_trace()
        if self.width:
            x = x.permute(0, 2, 1, 3)
        else:
            x = x.permute(0, 3, 1, 2)  # N, W, C, H
        N, W, C, H = x.shape
        x = x.contiguous().view(N * W, C, H)

        # Transformations
        qkv = self.bn_qkv(self.qkv_transform(x))
        q, k, v = torch.split(qkv.reshape(N * W, self.groups, self.group_planes * 2, H),
                              [self.group_planes // 2, self.group_planes // 2, self.group_planes], dim=2)

        # Calculate position embedding
        all_embeddings = torch.index_select(self.relative, 1, self.flatten_index).view(self.group_planes * 2,
                                                                                       self.kernel_size,
                                                                                       self.kernel_size)
        q_embedding, k_embedding, v_embedding = torch.split(all_embeddings,
                                                            [self.group_planes // 2, self.group_planes // 2,
                                                             self.group_planes], dim=0)

        qr = torch.einsum('bgci,cij->bgij', q, q_embedding)
        kr = torch.einsum('bgci,cij->bgij', k, k_embedding).transpose(2, 3)

        qk = torch.einsum('bgci, bgcj->bgij', q, k)

        stacked_similarity = torch.cat([qk, qr, kr], dim=1)
        stacked_similarity = self.bn_similarity(stacked_similarity).view(N * W, 3, self.groups, H, H).sum(dim=1)
        # stacked_similarity = self.bn_qr(qr) + self.bn_kr(kr) + self.bn_qk(qk)
        # (N, groups, H, H, W)
        similarity = F.softmax(stacked_similarity, dim=3)
        sv = torch.einsum('bgij,bgcj->bgci', similarity, v)
        sve = torch.einsum('bgij,cij->bgci', similarity, v_embedding)
        stacked_output = torch.cat([sv, sve], dim=-1).view(N * W, self.out_planes * 2, H)
        output = self.bn_output(stacked_output).view(N, W, self.out_planes, 2, H).sum(dim=-2)

        if self.width:
            output = output.permute(0, 2, 1, 3)
        else:
            output = output.permute(0, 2, 3, 1)

        if self.stride > 1:
            output = self.pooling(output)

        return output

    def reset_parameters(self):
        self.qkv_transform.weight.data.normal_(0, math.sqrt(1. / self.in_planes))
        # nn.init.uniform_(self.relative, -0.1, 0.1)
        nn.init.normal_(self.relative, 0., math.sqrt(1. / self.group_planes))


if __name__ == '__main__':
    x = torch.randn(4, 512, 7, 7).cuda()
    # kernel_size 要跟 h,w 相同
    model = AxialAttention(512, 512, kernel_size=7).cuda()
    out = model(x)
    print(out.shape)

本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249222.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Git 进程占用报错-解决方案

背景 大仓库,由于开发者分支较多,我们在使用 git pull 或 git push 等命令时(与远端仓库交互的命令),不知之前配置了什么,我的电脑会必现以下报错(有非常长一大串报错-不同分支的git进程占用报…

【FPGA-MicroBlaze】串口收发以及相关函数讲解

前言 工具:Vivado2018.3及其所对应的SDK版本 目前网上有许多MicroBlaze 的入门教程,比如下面的这个参考文章,用串口打印一个hello world。 【FPGA】Xilinx MicroBlaze软核使用第一节:Hello World!_fpga软核microblaze-CSDN博客 个…

代码美学2:MATLAB制作渐变色

效果: %代码美学:MATLAB制作渐变色 % 创建一个10x10的矩阵来表示热力图的数据 data reshape(1:100, [10, 10]);% 创建热力图 figure; imagesc(data);% 设置颜色映射为“cool” colormap(cool);% 在热力图上添加边框 axis on; grid on;% 设置热力图的颜色…

Android下载出现open failed: EPERM (Operation not permitted)

今天帮忙给同事调一下apk,发现android 自动更新apk,下载apk的时候总是失败,总是卡在 输出流这一步了 于是第一步分析,立马想到权限 但是下载之前的读写内存的权限也都有了 什么android 10高版本的不开启分区存储也用了 android…

使用爬虫时,如何确保数据的准确性?

在数字化时代,数据的准确性对于决策和分析至关重要。本文将探讨如何在使用Python爬虫时确保数据的准确性,并提供代码示例。 1. 数据清洗 数据清洗是确保数据准确性的首要步骤。在爬取数据后,需要对数据进行清洗,去除重复、无效和…

uniapp中使用Mescroll实现下拉刷新与上拉加载项目实战

如何在UniApp中使用Mescroll实现下拉刷新与上拉加载 前言 下拉刷新和上拉加载更多成为了提升用户体验不可或缺的功能。UniApp作为一个跨平台的应用开发框架,支持使用Vue.js语法编写多端(iOS、Android、H5等)应用。Mescroll作为一款专为Vue设…

【接口自动化测试】一文从0到1详解接口测试协议!

接口自动化测试是软件开发过程中重要的环节之一。通过对接口进行测试,可以验证接口的功能和性能,确保系统正常运行。本文将从零开始详细介绍接口测试的协议和规范。 定义接口测试协议 接口测试协议是指用于描述接口测试的规范和约定。它包含了接口的请求…

RAG数据拆分之PDF

引言RAG数据简介PDF解析方法及工具代码实现总结 二、正文内容 引言 本文将介绍如何将RAG数据拆分至PDF格式,并探讨PDF解析的方法和工具,最后提供代码示例。 RAG数据简介 RAG(关系型属性图)是一种用于表示实体及其关系的图数据…

【开源项目】2024最新PHP在线客服系统源码/带预知消息/带搭建教程

简介 随着人工智能技术的飞速发展,AI驱动的在线客服系统已经成为企业提升客户服务质量和效率的重要工具。本文将探讨AI在线客服系统的理论基础,并展示如何使用PHP语言实现一个简单的AI客服系统。源码仓库地址:ym.fzapp.top 在线客服系统的…

WEB攻防-通用漏洞XSS跨站MXSSUXSSFlashXSSPDFXSS

演示案例: UXSS-Edge&CVE-2021-34506 FlashXSS-PHPWind&SWF反编译 PDFXSS-PDF动作添加&文件上传 使用jpexs反编译swf文件 上传后,发给别人带漏洞的分享链接

QSqlTableModel的使用

实例功能 这边使用一个实例显示数据库 demodb 中 employee 数据表的内容,实现编辑、插入、删除的操作,实现数据的排序和记录过滤,还实现 BLOB 类型字段 Photo 中存储照片的显示、导入等操作,运行界面如下图: 在上图中…

适用于学校、医院等低压用电场所的智能安全配电装置

引言 电力,作为一种清洁且高效的能源,极大地促进了现代生活的便捷与舒适。然而,与此同时,因使用不当或维护缺失等问题,漏电、触电事件以及电气火灾频发,对人们的生命安全和财产安全构成了严重威胁&#xf…

LabVIEW实现UDP通信

目录 1、UDP通信原理 2、硬件环境部署 3、云端环境部署 4、UDP通信函数 5、程序架构 6、前面板设计 7、程序框图设计 8、测试验证 本专栏以LabVIEW为开发平台,讲解物联网通信组网原理与开发方法,覆盖RS232、TCP、MQTT、蓝牙、Wi-Fi、NB-IoT等协议。 结合…

java——Spring MVC的工作流程

Spring MVC的工作流程是基于模型-视图-控制器(MVC)设计模式的一个典型实现,以下是其主要工作流程步骤: 客户端请求提交: 用户通过浏览器向服务器发送请求,该请求首先到达Spring MVC的前端控制器DispatcherS…

带有悬浮窗功能的Android应用

android api29 gradle 8.9 要求 布局文件 (floating_window_layout.xml): 增加、删除、关闭按钮默认隐藏。使用“开始”按钮来控制这些按钮的显示和隐藏。 服务类 (FloatingWindowService.kt): 实现“开始”按钮的功能,点击时切换增加、删除、关闭按钮的可见性。处…

【编译原理】词法、语法、语义实验流程内容梳理

编译原理实验有点难,但是加上ai的辅助就会很简单,下面梳理一下代码流程。 全代码在github仓库中,链接:NeiFeiTiii/CompilerOriginTest at Version2.0,感谢star一下 一、项目结构 关键内容就是里面的那几个.c和.h文件。…

Linux介绍与安装指南:从入门到精通

1. Linux简介 1.1 什么是Linux? Linux是一种基于Unix的操作系统,由Linus Torvalds于1991年首次发布。Linux的核心(Kernel)是开源的,允许任何人自由使用、修改和分发。Linux操作系统通常包括Linux内核、GNU工具集、图…

MFC图形函数学习12——位图操作函数

位图即后缀为bmp的图形文件,MFC中有专门的函数处理这种格式的图形文件。这些函数只能处理作为MFC资源的bmp图,没有操作文件的功能,受限较多,一般常作为程序窗口界面图片、显示背景图片等用途。有关位图操作的步骤、相关函数等介绍…

12.Three.js纹理动画与动效墙案例

12.Three.js纹理动画与动效墙案例 在Three.js的数字孪生场景应用中,我们通常会使用到一些动画渲染效果,如动效墙,飞线、雷达等等,今天主要了解一下其中一种动画渲染效果:纹理动画。下面实现以下动效墙效果&#xff08…

SJYP 24冬季系列 FROZEN CHARISMA发布

近日,女装品牌SJYP 2024年冬季系列——FROZEN CHARISMA已正式发布,展现了更加干练的法式风格。此次新品发布不仅延续了SJYP一贯的强烈设计风格和个性时尚,更融入了法式风情的干练元素,为消费者带来了一场视觉与穿着的双重盛宴。  …