【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

news2024/11/25 6:27:24

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)

【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)


文章目录

  • 【深度学习中的注意力机制10】11种主流注意力机制112个创新研究paper+代码——交叉注意力(Cross-Attention)
  • 1. 交叉注意力的起源与提出
  • 2. 交叉注意力的原理
  • 3. 交叉注意力的数学表示
  • 4. 交叉注意力的应用场景与发展
  • 5. 代码实现
  • 6. 代码解释
  • 7. 总结


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

1. 交叉注意力的起源与提出

交叉注意力(Cross-Attention)是在深度学习中提出的一种重要注意力机制,用于在多个输入之间建立关联,主要用于多模态任务中(如图像和文本、视频和音频的联合处理)。

与常规的自注意力机制不同,交叉注意力专注于从两个不同的输入特征空间中提取和结合关键信息。这种机制最初在自然语言处理和计算机视觉的融合任务中得到应用,例如在多模态Transformer、机器翻译和图像-文本任务(如CLIP、DALL·E、VQA等)中。

  • 提出背景:交叉注意力通常用于处理两种不同类型的数据,通过这种机制,一个输入可以对另一个输入进行查询,捕捉和增强跨模态之间的关联。相比自注意力(仅在同一个输入中找到相关性),交叉注意力能够有效地捕捉多模态数据的交互信息。

2. 交叉注意力的原理

交叉注意力的核心思想是将一个输入(例如图像)作为查询(Query),另一个输入(例如文本)作为键(Key)和值(Value),通过注意力机制让查询能够从键和值中选择和关注相关信息。

交叉注意力的步骤:

  • 查询、键、值的生成: 假设有两个不同的输入数据 X1 和 X2,分别生成对应的 Query、Key 和 Value 矩阵。对于 X1,我们可以生成 Query 矩阵,而对于 X2,则可以生成 Key 和 Value 矩阵。
  • 注意力计算: 与自注意力类似,交叉注意力通过计算 Query 和 Key 的相似性来获得注意力权重:
    在这里插入图片描述
    其中 Q 来自 X1,而 K 和 V 来自 X2 。通过这种计算,Query 可以从X2 中提取与其最相关的信息,这种机制实现了两个输入数据之间的特征融合和信息传递。
  • 权重与输出: 计算出的注意力权重应用到 X2的 Value 矩阵上,得到 X1在
    X2上的相关信息。这种机制实现了两个输入数据之间的特征融合和信息传递。

3. 交叉注意力的数学表示

假设有两个输入特征 X 1 ∈ R T 1 × d X_1∈R^{T_1×d} X1RT1×d X 2 ∈ R T 2 × d X_2∈R^{T_2×d} X2RT2×d,其中 T 1 T_1 T1 T 2 T_2 T2分别表示两个输入的长度(如序列长度或特征维度), d d d 表示特征维度。

Query、Key 和 Value 的生成:

  • 对于 X 1 X_1 X1:生成查询矩阵 Q = W q X 1 Q=W_qX_1 Q=WqX1
  • 对于 X 2 X_2 X2:生成键矩阵 K = W k X 2 K=W_kX_2 K=WkX2和值矩阵 V = W v X 2 V=W_vX_2 V=WvX2

注意力计算:
在这里插入图片描述
其中, W q W_q Wq W k W_k Wk W v W_v Wv ∈ R d × d ∈R^{d×d} Rd×d是线性变换矩阵, d d d 是键的维度。

结果输出: 注意力权重应用于 V V V 后的结果,即:
在这里插入图片描述

4. 交叉注意力的应用场景与发展

交叉注意力在以下场景中得到广泛应用:

  • 多模态学习:交叉注意力在视觉和语言任务中的多模态联合建模中尤为常见,如图像与文本的对齐(CLIP)、视觉问答(VQA)和跨模态生成任务(如DALL·E)。
  • 机器翻译:交叉注意力在Transformer中的"解码器"部分用于让生成的序列(目标语言)参考源语言的表示,这大大提高了翻译质量。
  • Transformer架构的扩展:在诸如BERT、GPT等基于Transformer的模型中,交叉注意力也被用于各种任务,例如文本生成、序列到序列任务等。

发展过程中,交叉注意力机制已经被改进和扩展。例如,层次化交叉注意力(Hierarchical Cross-Attention)通过在不同层次上融合多模态信息,进一步提升了模型在多模态任务上的性能。

5. 代码实现

下面是一个基于PyTorch的交叉注意力机制的简单实现,用于展示如何在两个不同的输入(例如图像和文本)之间计算交叉注意力。

import torch
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1):
        super(CrossAttention, self).__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads
        
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        # 线性变换,用于生成 Q, K, V 矩阵
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        # 输出的线性变换
        self.out_proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x1, x2):
        # x1 是 Query,x2 是 Key 和 Value
        B, T1, C = x1.shape  # x1 的形状: [batch_size, seq_len1, dim]
        _, T2, _ = x2.shape  # x2 的形状: [batch_size, seq_len2, dim]

        # 生成 Q, K, V 矩阵
        Q = self.q_proj(x1).view(B, T1, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x2).view(B, T2, self.num_heads, self.head_dim).transpose(1, 2)

        # 计算注意力得分
        attn_scores = (Q @ K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = self.softmax(attn_scores)  # 注意力权重
        attn_weights = self.dropout(attn_weights)  # dropout 防止过拟合

        # 使用注意力权重加权值矩阵
        attn_output = attn_weights @ V
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T1, C)

        # 输出线性变换
        output = self.out_proj(attn_output)
        return output

# 测试交叉注意力机制
if __name__ == "__main__":
    B, T1, T2, C = 2, 10, 20, 64  # batch_size, seq_len1, seq_len2, channels
    x1 = torch.randn(B, T1, C)  # Query 输入
    x2 = torch.randn(B, T2, C)  # Key 和 Value 输入

    cross_attn = CrossAttention(dim=C, num_heads=4)
    output = cross_attn(x1, x2)
    
    print("输出形状:", output.shape)  # 输出应该为 [batch_size, seq_len1, channels]

6. 代码解释

CrossAttention 类:该类实现了交叉注意力机制,允许将两个不同的输入(x1x2)进行交叉信息融合。

  • q_proj, k_proj, v_proj:三个线性层,用于将输入分别映射到 Query、Key 和 Value 空间。
  • num_headshead_dim:定义了多头注意力机制的头数和每个头的维度。

forward 函数:实现前向传播过程。

  • Q, K, V:分别从 x1x2 中生成 Query、Key 和 Value 矩阵,形状为 [batch_size, num_heads, seq_len, head_dim]
  • attn_scores:计算 Query 和 Key 的点积,得到注意力得分。
  • attn_weights:通过 softmax 对得分进行归一化,得到注意力权重。
  • attn_output:利用注意力权重对 Value 矩阵进行加权求和,得到最终的注意力输出。

测试部分:随机生成两个输入张量 x1x2,并测试交叉注意力的输出形状,确保与预期一致。

7. 总结

交叉注意力在多模态学习中起到了至关重要的作用,能够有效融合不同类型的数据,使得模型可以同时处理图像、文本等多种信息。通过捕捉模态之间的相关性,交叉注意力为多模态任务中的特征融合提供了强大的工具。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2224578.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

‌Spring MVC的主要组件有哪些?

前言 SpringMVC的核心组件包括DispatcherServlet、Controller、HandlerMapping、HandlerAdapter、ViewResolver、ModelAndView等,它们协同工作以支持基于MVC架构的Web应用程序开发。这些组件使得开发人员能够以一种声明式和模块化的方式构建Web应用程序&#xff0c…

小程序开发实战:PDF转换为图片工具开发

目录 一、开发思路 1.1 申请微信小程序 1.2 编写后端接口 1.3 后端接口部署 1.4 微信小程序前端页面开发 1.5 运行效果 1.6 小程序部署上线 今天给大家分享小程序开发系列,PDF转换为图片工具的开发实战,感兴趣的朋友可以一起来学习一下&#xff01…

ECharts饼图-基础南丁格尔玫瑰图,附视频讲解与代码下载

引言: 在数据可视化的世界里,ECharts凭借其丰富的图表类型和强大的配置能力,成为了众多开发者的首选。今天,我将带大家一起实现一个饼图图表,通过该图表我们可以直观地展示和分析数据。此外,我还将提供详…

一、在cubemx下RTC配置调试实例测试

一、rtc的时钟有lse提供。 二、选择rtc唤醒与闹钟功能 内部参数介绍 闹钟配置 在配置时间时,注意将时间信息存储起来,防止复位后时间重新配置。 if(HAL_RTCEx_BKUPRead(&hrtc, RTC_BKP_DR0)! 0x55AA)//判断标志位是否配置过,没有则进…

qt EventFilter用途详解

一、概述 EventFilter是QObject类的一个事件过滤器,当使用installEventFilter方法为某个对象安装事件过滤器时,该对象的eventFilter函数就会被调用。通过重写eventFilter方法,开发者可以在事件处理过程中进行拦截和处理,实现对事…

WSL2 Ubuntu22.04编译安装LLVM

前提 这两天因为工作需要,要编译一个Debug版本的llvm。这里对编译安装过程进行一个简单的记录,同时也记录下这个过程中遇到的几个问题。 下载源码并编译 有关llvm编译安装的官方文档在这里。 从git仓库clone llvm的源码。 git clone https://github.c…

FPGA搭建PCIE3.0通信架构简单读写测试,基于XDMA中断模式,提供3套工程源码和技术支持

目录 1、前言工程概述免责声明 2、相关方案推荐我已有的PCIE方案本博客方案的PCIE2.0版本 3、PCIE基础知识4、工程详细设计方案工程设计原理框图XDMA配置及使用XDMA中断模块数据缓存架构用户逻辑Windows版本XDMA驱动安装Linux版本XDMA驱动安装测试应用程序工程源码架构PCIE上板…

电磁场-Laplace算子与冲激函数的关系

csdn重新打一遍公式太麻烦了。欢迎转到我的知乎账号上查阅原版文章,也可后台私信我发送原版PDF或者markdown。 电磁场-Laplace算子与冲激函数的关系 - 知乎 下面的文章是一张超大的图片。

论1+2+3+4+... = -1/12 的不同算法

我们熟知自然数全加和, 推导过程如下, 这个解法并不难,非常容易看懂,但是并不容易真正理解。正负交错和无穷项计算,只需要保持方程的形态,就可以“预知”结果。但是这到底说的是什么意思?比如和…

C++扑克牌(poker)2024年CSP-J认证第二轮第一题 CCF信息学奥赛C++ 中小学初级组 第二轮真题解析

目录 C扑克牌(poker) 一、题目要求 1、编程实现 2、输入输出 二、算法分析 三、程序编写 四、运行结果 五、考点分析 六、推荐资料 C扑克牌(poker) 2024年CSP-J认证第二轮第一题 一、题目要求 1、编程实现 小 P 从同学…

HarmonyOS 组件样式@Style 、 @Extend、自定义扩展(AttributeModifier、AttributeUpdater)

1. HarmonyOS Style 、 Extend、自定义扩展(AttributeModifier、AttributeUpdater) Styles装饰器:定义组件重用样式   ;Extend装饰器:定义扩展组件样式   自定义扩展:AttributeModifier、AttributeUpdater 1.1. 区…

HarmonyOS 5.0应用开发——应用打包HAP、HAR、HSP

【高心星出品】 目录 应用打包HAP、HAR、HSPModule类型HAPHAR创建HAR建立依赖HAR共享内容 HSP创建HSP建立依赖同上HSP共享内容同上 HAR VS HSP 应用打包HAP、HAR、HSP 一个应用通常会包含多种功能,将不同的功能特性按模块来划分和管理是一种良好的设计方式。在开发…

【哈工大_操作系统实验】Lab9 proc文件系统的实现

本节将更新哈工大《操作系统》课程第九个 Lab 实验 proc文件系统的实现。按照实验书要求,介绍了非常详细的实验操作流程,并提供了超级无敌详细的代码注释。 实验目的: 掌握虚拟文件系统的实现原理;实践文件、目录、文件系统等概念…

【C++开篇】

首先初阶的数据结构相信大家已经学习的差不多了,关于初阶数据结构排序的相关内容的总结随后我也会给大家分享出来。C语言和C有许多相同的地方,但也有许多不相同的地方。接下来的C部分,我们主要是针对C与C语言不同的地方来与大家进行分享。其中…

量子变分算法 (python qiskit)

背景 变分量子算法是用于观察嘈杂的近期设备上的量子计算效用的有前途的候选混合算法。变分算法的特点是使用经典优化算法迭代更新参数化试验解决方案或“拟设”。这些方法中最重要的是变分量子特征求解器 (VQE),它旨在求解给定汉密尔顿量的基态,该汉密尔…

这是一篇vue3 的详细教程

Vue 3 详细教程 一、Vue 3 简介 Vue.js 是一款流行的 JavaScript 前端框架,用于构建用户界面。Vue 3 是其最新版本,带来了许多新特性和性能优化,使开发更加高效和灵活。 二、环境搭建 安装 Node.js 前往Node.js 官方网站下载并安装适合你…

WPF+MVVM案例实战(六)- 自定义分页控件实现

文章目录 1、项目准备2、功能实现1、分页控件 DataPager 实现2、分页控件数据模型与查询行为3、数据界面实现 3、运行效果4、源代码获取 1、项目准备 打开项目 Wpf_Examples,新建 PageBarWindow.xaml 界面、PageBarViewModel.cs ,在用户控件库 UserControlLib中创建…

WASM 使用说明23事(RUST实现)

文章目录 1. wasm是什么1.1 chatgpt定义如下:1.2 wasm关键特性: 2. wasm demo2.1 cargo 创建项目2.2 编写code2.3 安装wasm-pack2.4 编译 3.1 html页面引用wasm代码(js引用)3.2 访问页面4 导入js function4.1 编写lib.rs文件,内容…

UML 总结(基于《标准建模语言UML教程》)

定义 UML 又称为统一建模语言或标准建模语言,是一种标准的图形化建模语言,它是面向对象分析与设计的一种标准表示。尽管UML 本身没有对过程有任何定义,但UML 对任何使用它的方法(或过程)提出的要求是:支持用…

【含开题报告+文档+PPT+源码】基于vue框架的东升餐饮点餐管理平台的设计与实现

开题报告 在当前信息化社会背景下,餐饮行业正经历着由传统线下服务模式向线上线下深度融合的转变。随着移动互联网技术及大数据应用的飞速发展,用户对于餐饮服务平台的需求也日益多元化和个性化。他们期望能在一个集便捷、高效、个性化于一体的平台上完…