cross attention交叉熵注意力机制

news2024/11/13 12:21:17

        交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。例如,在翻译任务中,需要将源语言句子和目标语言句子进行对齐,就需要使用交叉注意力来计算两个句子之间的注意力权重。

        交叉注意力机制是一种特殊形式的多头注意力,它将输入张量拆分成两个部分 X1\epsilon R^{n*d1}  和 X2\epsilon R^{n*d2},然后将其中一个部分作为查询集合,另一个部分作为键值集合。它的输出是一个大小为n*d2 的张量,对于每个行向量,都给出了它对于所有行向量的注意力权重。

Q=X_{1} W^{Q} 和 K=V=X_{2} W^{K},则交叉注意力的计算如下:

\operatorname{CrossAttention}\left(X_{1}, X_{2}\right)=\operatorname{Softmax}\left(\frac{Q K^{T}}{\sqrt{d_{2}}}\right) V

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossAttention(nn.Module):
    def __init__(self, embed_dim, hidden_dim, num_heads):
        super(CrossAttention, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        self.query_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
        self.key_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
        self.value_proj = nn.Linear(embed_dim, hidden_dim * num_heads)

        self.out_proj = nn.Linear(hidden_dim * num_heads, embed_dim)

    def forward(self, query, context):
        """
        query: (batch_size, query_len, embed_dim)
        context: (batch_size, context_len, embed_dim)
        """
        batch_size, query_len, _ = query.size()
        context_len = context.size(1)

        # Project input embeddings
        query_proj = self.query_proj(query).view(batch_size, query_len, self.num_heads, self.hidden_dim)
        key_proj = self.key_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)
        value_proj = self.value_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)

        # Transpose to get dimensions (batch_size, num_heads, len, hidden_dim)
        query_proj = query_proj.permute(0, 2, 1, 3)
        key_proj = key_proj.permute(0, 2, 1, 3)
        value_proj = value_proj.permute(0, 2, 1, 3)

        # Compute attention scores
        scores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) / (self.hidden_dim ** 0.5)
        attn_weights = F.softmax(scores, dim=-1)

        # Compute weighted context
        context = torch.matmul(attn_weights, value_proj)

        # Concatenate heads and project output
        context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, -1)
        output = self.out_proj(context)

        return output, attn_weights

# Example usage:
embed_dim = 512
hidden_dim = 64
num_heads = 8

cross_attention = CrossAttention(embed_dim, hidden_dim, num_heads)

# Dummy data
batch_size = 2
query_len = 10
context_len = 20

query = torch.randn(batch_size, query_len, embed_dim)
context = torch.randn(batch_size, context_len, embed_dim)

output, attn_weights = cross_attention(query, context)
print(output.size())  # Should be (batch_size, query_len, embed_dim)
print(attn_weights.size())  # Should be (batch_size, num_heads, query_len, context_len)
  1. 类定义CrossAttention 类继承自 nn.Module,包含初始化函数 __init__ 和前向传播函数 forward
  2. 初始化
    • 定义了一些线性变换层:query_proj, key_proj, 和 value_proj,这些层将嵌入向量转换为多头注意力机制所需的维度。
    • 最终的输出通过 out_proj 再投影回原始的嵌入维度。
  3. 前向传播
    • 输入的 querycontext 分别通过线性变换层,并重新整形以适应多头注意力机制。
    • 计算注意力分数,并通过 softmax 得到注意力权重。
    • 利用注意力权重加权上下文向量,得到新的上下文表示。
    • 最后将多头的结果合并,并通过输出投影层得到最终的输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1904491.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习与现代医疗设备的结合:革新医疗健康的未来

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 引言 随着技术的不断进步,机器学习(Machine Learning, ML)在现代医疗设备中的应用正在改变着…

基于B/S模式和Java技术的生鲜交易系统

你好呀,我是计算机学姐码农小野!如果有相关需求,可以私信联系我。 开发语言:Java 数据库:MySQL 技术:B/S模式、Java技术 工具:Visual Studio、MySQL数据库开发工具 系统展示 首页 用户注册…

如何在应用运行时定期监控内存使用情况

如何在应用运行时定期监控内存使用情况 在 iOS 应用开发中,实时监控内存使用情况对于优化性能和排查内存泄漏等问题非常重要。本文将介绍如何在应用运行时定期监控内存使用情况,使用 Swift 编写代码并结合必要的工具和库。 1. 创建桥接头文件 首先&…

线程安全的原因及解决方法

什么是线程安全问题 线程安全问题指的是在多线程编程环境中,由于多个线程共享数据或资源,并且这些线程对共享数据或资源的访问和操作没有正确地同步,导致数据的不一致、脏读、不可重复读、幻读等问题。线程安全问题的出现,通常是…

论文略读:Can Long-Context Language Models Subsume Retrieval, RAG, SQL, and More?

202406 arxiv 1 intro 传统上,复杂的AI任务需要多个专门系统协作完成。 这类系统通常需要独立的模块来进行信息检索、问答和数据库查询等任务大模型时代,尤其是上下文语言模型(LCLM)时代,上述问题可以“一体化”完成…

MybatisX插件的简单使用教程

搜索mybatis 开始生成 module path:当前项目 base package:生成的包名,建议先独立生成一个,和你原本的项目分开 encoding:编码,建议UTF-8 class name strategy:命名选择 推荐选择camel:驼峰命…

ROS——多个海龟追踪一个海龟实验

目标 通过键盘控制一个海龟(领航龟)的移动,其余生成的海龟通过监听实现追踪定期获取领航龟和其余龟的坐标信息,通过广播告知其余龟,进行相应移动其余龟负责监听 疑惑点(已解决) int main(int…

【网络安全】实验四(网络扫描工具的使用)

一、本次实验的实验目的 (1)掌握使用端口扫描器的技术,了解端口扫描器的原理 (2)会用Wireshark捕获数据包,并对捕获的数据包进行简单的分析 二、搭配环境 打开两台虚拟机,并参照下图&#xff…

k8s+docker集群整合搭建(完整版)

一、Kubernetes系列之介绍篇 1、背景介绍 云计算飞速发展 IaaS PaaS SaaS Docker技术突飞猛进 一次构建,到处运行 容器的快速轻量 完整的生态环境 2、什么是kubernetes 首先,他是一个全新的基于容器技术的分布式架构领先方案。Kubernetes(k8s)是Goog…

磐维2.0数据库日常维护

磐维数据库简介 “中国移动磐维数据库”(ChinaMobileDB),简称“磐维数据库”(PanWeiDB)。是中国移动信息技术中心首个基于中国本土开源数据库打造的面向ICT基础设施的自研数据库产品。 其产品内核能力基于华为 OpenG…

001uboot体验

1.uboot的作用: 上电->uboot启动->关闭看门狗、初始化时钟、sdram、uart等外设->把内核文件从flash读取到SDRAM->引导内核启动->挂载根文件系统->启动根文件系统的应用程序 2.uboot编译 uboot是一个通用的裸机程序,为了适应各种芯片&…

注意力机制 attention Transformer 笔记

动手学深度学习 这里写自定义目录标题 注意力加性注意力缩放点积注意力多头注意力自注意力自注意力缩放点积注意力:案例Transformer 注意力 注意力汇聚的输出为值的加权和 查询的长度为q,键的长度为k,值的长度为v。 q ∈ 1 q , k ∈ 1 k …

现场Live震撼!OmAgent框架强势开源!行业应用已全面开花

第一个提出自动驾驶并进行研发的公司是Google,巧的是,它发布的Transformer模型也为今天的大模型发展奠定了基础。 自动驾驶已经完成从概念到现实的华丽转变,彻底重塑了传统驾车方式,而大模型行业正在经历的,恰如自动驾…

Mac安装AndroidStudio连接手机 客户端测试

参考文档:https://www.cnblogs.com/andy0816/p/17097760.html 环境依赖 需要java 1.8 java安装 略 下载Android Studio 地址 下载 Android Studio 和应用工具 - Android 开发者 | Android Developers 本机对应的包进行下载 安装过程 https://www.cnblogs.c…

STM32实现硬件IIC通信(HAL库)

文章目录 一. 前言二. 关于IIC通信三. IIC通信过程四. STM32实现硬件IIC通信五. 关于硬件IIC的Bug 一. 前言 最近正在DIY一款智能电池,需要使用STM32F030F4P6和TI的电池管理芯片BQ40Z50进行SMBUS通信。SMBUS本质上就是IIC通信,项目用到STM32CubeMXHAL库…

2025中国郑州门窗业博览会暨整屋定制家居展

2025中国郑州门窗业博览会 2025中国郑州整屋定制家居及家具产业博览会 2025中国家居行业开年第1展 邀请函 展览时间:第一期 2025年2月15日-17日 第二期 2025年2月22日-24日 展览地址:郑州国际会展中心 组委会:【I 3 3】【937O】【7897】…

软件工程(上)

目录 软件过程模型(软件开发模型) 瀑布模型 原型模型 V模型 构件组装模型 螺旋模型(原型瀑布) 基于构件的软件工程(CBSE) 快速应用开发模型(RAD) 统一过程(UP&a…

HTTP模块(一)

HTTP服务 本小节主要讲解HTTP服务如何创建服务,查看HTTP请求&响应报文,还有注意事项说明,另外讲解本地环境&Node环境&浏览器之间的链路图示,如何提取HTTP报文字符串,及报错信息查询。 创建HTTP服务端 c…

【TB作品】51单片机 Proteus仿真00016 乒乓球游戏机

课题任务 本课题任务 (联机乒乓球游戏)如下图所示: 同步显示 oo 8个LED ooooo oo ooooo 8个LED 单片机 单片机 按键 主机 从机 按键 设计题目:两机联机乒乓球游戏 图1课题任务示意图 具体说明: 共有两个单片机,每个单片机接8个LED和1 个按键,两个单片机使用串口连接。 (2)单片机…