从代码直观理解Self-Attention和Cross-Attention的本质区别

news2024/11/15 0:14:41

Transformer的模型架构实际上非常简单,Self-Attention 和 Cross-Attention 仅仅是在 k, v上有所不同(这里不讨论 mask)。

论文原文:Attention Is All You Need

我们可以使用同一个 Attention 类来实现 Self-Attention 和 Cross-Attention。实际上,在 Transformer 的源代码中就是如此。抛开花哨的可视化所赋予的意义,下面是一个 Attention 的实现:

import torch
import torch.nn as nn
import math

# 单头,无 mask 的 Attention 实现(如果你不知道这里说的是什么,就不用在意)
class Attention(nn.Module):
    def __init__(self, d_model):
        super(Attention, self).__init__()
        # 定义查询、键和值的线性变换
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        # 计算查询、键和值的投影
        q = self.w_q(q)
        k = self.w_k(k)
        v = self.w_v(v)

        # 计算注意力得分
        attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        attention_weights = self.softmax(attention_scores)
        
        # 加权求和得到输出
        attention_output = torch.matmul(attention_weights, v)
        
        return attention_output

这里 d_model 指的是输入的维度,可以看到 W q W_q Wq, W k W_k Wk, W v W_v Wv 实际上就是一个简单的线性层nn.Linear(d_model, d_model),没有任何多余的操作。而这个 Attention() 所接受的输入q, k, v,就是我们将讨论的 self-Attention 和 cross-attention 的主要区别所在。

什么是Self-Attention?

Self-attention,也叫intra-attention,是指序列中的每个元素通过与该序列的其他元素进行关联,来捕捉上下文信息。其关键在于,查询(query)、键(key)和值(value)都来自同一序列(看图可以发现,线条一分为三),通过这种机制,模型能够在不同位置上计算特征的相关性,获取全局信息。

image-20240912192501911

Self-Attention的计算过程:
  1. 对输入进行线性变换生成q, k, v(即查询、键和值)。
  2. 计算查询和键之间的相似性(注意力得分)。
  3. 用softmax归一化注意力得分,生成注意力权重。
  4. 使用注意力权重对值v进行加权求和,生成最终的输出。
# 示例输入数据:q, k, v 来自相同序列
q = k = v = torch.randn(2, 10, 64)  # batch_size=2, 序列长度=10, d_model=64

# 初始化Attention层
attention = Attention(d_model=64)

# 执行前向传播
output = attention(q, k, v)

print(output.shape)  # 输出形状为 (2, 10, 64)

代码说明:

  • Self-Attention 中q, k, v均来自同一个输入序列,所以在这里直接将它们设置为相同的张量。在其他仓库的实现中,你可能会看到 attention(x, x, x)

什么是Cross-Attention?

Cross-attention的查询(query)来自一个序列,而键(key)和值(value)来自另一个序列。

image-20240912192745241

它的目的是让模型能够结合来自两个不同输入的信息,在跨模态任务或翻译任务中,cross-attention非常有用,例如在解码阶段将目标语言与源语言关联。

Cross-Attention的计算过程:

与self-attention一致,主要区别在于 qk, v 来自不同的输入序列。

# 示例输入数据:q与k, v来自不同序列
q = torch.randn(2, 10, 64)  # batch_size=2, 序列长度=10, d_model=64 (Query序列)
k = v = torch.randn(2, 15, 64)  # batch_size=2, 序列长度=15, d_model=64 (Key/Value序列)

# 执行前向传播
output = attention(q, k, v)

print(output.shape)  # 输出形状为 (2, 10, 64)

代码说明:

  • 通常 q 是解码器(decoder)的输入,kv 来自编码器(encoder)。

总结

  • 输入来源:Self-Attention中,q, k, v 都来自同一序列;Cross-Attention中,q 来自一个序列,kv 来自另一个序列。
  • 应用场景:Self-Attention 通常用于理解同一序列中的上下文关系,如文本分析、机器翻译的编码阶段;Cross-Attention 用于两个不同序列间的关联,如机器翻译的解码阶段。

下面是 Transformer 完整的模型架构图:

image-20240912194002666

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2129553.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

day11-多线程

一、线程安全问题 线程安全问题出现的原因?存在多个线程在同时执行同时访问一个共享资源存在修改该共享资源 线程安全:多个线程同时修改同一个资源 取钱案例 小明和小红是一对夫妻,他们有一个共同的账户,余额是10万元 如果小明和小红同时来取…

速看!6款可以写论文的ai写作网站,这才是真正的论文神器!(含教程)

在当今信息爆炸的时代,AI写作工具的出现极大地提高了写作效率和质量。特别是对于需要撰写论文的学生和研究人员来说,这些工具提供了极大的便利。本文将重点介绍一款备受推荐的AI写作平台——千笔-AIPassPaper,并结合相关教程帮助用户更好地使…

【北京迅为】《STM32MP157开发板使用手册》- 第二十四章 STM32CubeIDE的初步使用

iTOP-STM32MP157开发板采用ST推出的双核cortex-A7单核cortex-M4异构处理器,既可用Linux、又可以用于STM32单片机开发。开发板采用核心板底板结构,主频650M、1G内存、8G存储,核心板采用工业级板对板连接器,高可靠,牢固耐…

校园水电费管理微信小程序的设计与实现+ssm(lw+演示+源码+运行)

校园水电费管理小程序 摘 要 随着社会的发展,社会的方方面面都在利用信息化时代的优势。互联网的优势和普及使得各种系统的开发成为必需。 本文以实际运用为开发背景,运用软件工程原理和开发方法,它主要是采用java语言技术和mysql数据库来…

基于SSM的学生信息管理系统(选课管理系统)的设计与实现 (含源码+sql+视频导入教程)

👉文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1 、功能描述 基于SSM的学生信息管理系统(选课管理系统)13拥有三种角色 管理员:学生管理、教师管理、专业管理、课程管理、审批管理、课程表管理、开课管理、教室管…

高德地图JS API加载行政区边界AMap.Polygon

🤖 作者简介:水煮白菜王 ,一位资深前端劝退师 👻 👀 文章专栏: 高德AMap专栏 ,记录一下平时在博客写作中,总结出的一些开发技巧✍。 感谢支持💕💕&#x1f49…

大模型LLM之SpringAI:Web+AI(二)

2.2.2、ChatModel API(聊天模型API) 聊天模型太多了,这里只写OpenAI和Ollama ChatModel和ChatClient区别:ChatClient针对的是所有模型,共用一个客户端。而ChatModel是针对各个模型实现的。 (1)OpenAI 自动配置 <dependency><groupId>org.springframework…

vue3 内置组件 <Suspense>

官方文档&#xff1a; <Suspense> 指南-Suspense 官方提示&#xff1a; <Suspense> 是一项实验性功能。它不一定会最终成为稳定功能&#xff0c;并且在稳定之前相关 API 也可能会发生变化。 <Suspense>是一个内置组件&#xff0c;用来在组件树中协调对异步依…

git删除本地分支报错:error: the branch ‘xxx‘ is not fully merged

git删除本地分支报错&#xff1a;error: the branch xxx is not fully merged error: the branch xxx is not fully merged 直接&#xff1a; git branch -D xxx 就可以。 如果删除远程分支&#xff1a; git push origin --delete origin/xxx git强制删除本地分支 git branc…

如何将Git本地代码推送到Gitee云端仓库

如何将Git本地代码推送到Gitee云端仓库 在使用Git进行版本控制时&#xff0c;将本地代码推送到远程仓库是一个基本且重要的操作。本文将详细介绍如何将你的Git本地代码推送到Gitee&#xff08;码云&#xff09;云端仓库。Gitee是一个国内非常流行的代码托管平台&#xff0c;类…

NX—UI界面生成的文件在VS上的设置

UI界面保存生成的三个文件 打开VS创建项目&#xff0c;删除自动生成的cpp文件&#xff0c;将生成的hpp和cpp文件拷贝到项目的目录下&#xff0c;并且在VS项目中添加现有项目。 修改VS的输出路径&#xff0c;项目右键选择属性&#xff0c;链接器中的常规&#xff0c;文件路径D:…

线性代数 第七讲 二次型_标准型_规范型_坐标变换_合同_正定二次型详细讲解_重难点题型总结

文章目录 1.二次型1.1 二次型、标准型、规范型、正负惯性指数、二次型的秩1.2 坐标变换1.3 合同1.4 正交变换化为标准型 2.二次型的主要定理3.正定二次型与正定矩阵4.重难点题型总结4.1 配方法将二次型化为标准型4.2 正交变换法将二次型化为标准型4.3 规范型确定取值范围问题4.…

《中国制药设备行业市场现状分析与发展前景预测研究报告》

报告导读&#xff1a;本报告从国际制药设备发展、国内制药设备政策环境及发展、研发动态、供需情况、重点生产企业、存在的问题及对策等多方面多角度阐述了制药设备市场的发展&#xff0c;并在此基础上对制药设备的发展前景做出了科学的预测&#xff0c;最后对制药设备投资潜力…

​​操作系统 ---- 进程调度的时机、切换与过程

目录 一、进程调度的时机 1.1 什么时候需要进行进程调度与切换&#xff1f; 1.2 什么情况下不能进行进程调度与切换&#xff1f; 二、进程调度的方式 2.1 非抢占方式(Nonpreemptive Mode) 2.2 抢占方式(Preemptive Mode) 三、总结 一、进程调度的时机 进程调度&am…

FreeRTOS内部机制学习04(任务通知和软件定时器)

文章目录 何为任务通知&#xff1f;任务通知使用例子任务通知的优势以及劣势优势劣势 深入源码看看API函数内部干了什么函数的种类函数都做了啥&#xff1f; 软件定时器软件定时器的作用软件定时器内部到底做了什么实现了“闹钟”功能引入守护任务&#xff0c;守护任务做了啥&a…

SprinBoot+Vue网上购物商城的设计与实现

目录 1 项目介绍2 项目截图3 核心代码3.1 Controller3.2 Service3.3 Dao3.4 application.yml3.5 SpringbootApplication3.5 Vue 4 数据库表设计5 文档参考6 计算机毕设选题推荐7 源码获取 1 项目介绍 博主个人介绍&#xff1a;CSDN认证博客专家&#xff0c;CSDN平台Java领域优质…

我们怎么把自动化测试落地到一个项目上呢?

现在的软件测试行业已经不是原先的点点点的功能测试&#xff0c;要想在软件测试这一行中扎根稳住&#xff0c;就需要你会的很多&#xff0c;不局限于功能测试&#xff0c;还要会自动化测试、接口测试、性能测试等。 今天就来说一下自动化测试&#xff0c;首先什么是自动化测试…

简单分享-获取.txt文件内数据 文件内数据逗号分隔 分隔符 C语言

简单分享-获取.txt文件内数据 文件内数据逗号分隔 分隔符 C语言 数据存储到文件中&#xff0c;把文件数据读取到数组&#xff0c;方便数据处理。 # include <stdio.h> # include <stdlib.h> # include <string.h>#define DATANUM 307200 //数组个数 int ma…

Linux之MySQL定时备份

#!/bin/bash #author: zking #MySQL定义备份并发送邮件 #定义变量 DATE$(date %F"_"%H:%M:%S) HOST127.0.0.1 DBdb1 USERNAMEroot PASSWORDun1xR00t MAILdonkeevip.qq.com BACKUP_DIR/data/db_backup SQL_FILE${DB}_sql_$DATE.sql#判断备份目录是否存在 if [ ! -d $B…

Visual Studio提示:无法安装CPpython.Exe.x64

如果你需要在Visual Studio中使用python环境&#xff0c;而且你本身已经有一个python环境&#xff0c;则只需要将你自己的python环境配置到Visual Studio中即可&#xff0c;可以无视如题报错&#xff0c;将不会产生实质性的问题或影响。 解决办法&#xff1a; 工具->获取工…