交叉注意力(Cross-Attention)则是在两个不同序列上计算注意力,用于处理两个序列之间的语义关系。在两个不同的输入序列之间计算关联度和加权求和的机制。具体来说,给定两个输入序列,cross attention机制将一个序列中的每个元素与另一个序列中的所有元素计算关联度,并根据关联度对两个序列中的每个元素进行加权求和。这样的机制使模型能够建立不同序列之间的关联关系,并将两个序列的信息融合起来。例如,在翻译任务中,需要将源语言句子和目标语言句子进行对齐,就需要使用交叉注意力来计算两个句子之间的注意力权重。
交叉注意力机制是一种特殊形式的多头注意力,它将输入张量拆分成两个部分 和 ,然后将其中一个部分作为查询集合,另一个部分作为键值集合。它的输出是一个大小为 的张量,对于每个行向量,都给出了它对于所有行向量的注意力权重。
令 和 ,则交叉注意力的计算如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
class CrossAttention(nn.Module):
def __init__(self, embed_dim, hidden_dim, num_heads):
super(CrossAttention, self).__init__()
self.embed_dim = embed_dim
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.query_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
self.key_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
self.value_proj = nn.Linear(embed_dim, hidden_dim * num_heads)
self.out_proj = nn.Linear(hidden_dim * num_heads, embed_dim)
def forward(self, query, context):
"""
query: (batch_size, query_len, embed_dim)
context: (batch_size, context_len, embed_dim)
"""
batch_size, query_len, _ = query.size()
context_len = context.size(1)
# Project input embeddings
query_proj = self.query_proj(query).view(batch_size, query_len, self.num_heads, self.hidden_dim)
key_proj = self.key_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)
value_proj = self.value_proj(context).view(batch_size, context_len, self.num_heads, self.hidden_dim)
# Transpose to get dimensions (batch_size, num_heads, len, hidden_dim)
query_proj = query_proj.permute(0, 2, 1, 3)
key_proj = key_proj.permute(0, 2, 1, 3)
value_proj = value_proj.permute(0, 2, 1, 3)
# Compute attention scores
scores = torch.matmul(query_proj, key_proj.transpose(-2, -1)) / (self.hidden_dim ** 0.5)
attn_weights = F.softmax(scores, dim=-1)
# Compute weighted context
context = torch.matmul(attn_weights, value_proj)
# Concatenate heads and project output
context = context.permute(0, 2, 1, 3).contiguous().view(batch_size, query_len, -1)
output = self.out_proj(context)
return output, attn_weights
# Example usage:
embed_dim = 512
hidden_dim = 64
num_heads = 8
cross_attention = CrossAttention(embed_dim, hidden_dim, num_heads)
# Dummy data
batch_size = 2
query_len = 10
context_len = 20
query = torch.randn(batch_size, query_len, embed_dim)
context = torch.randn(batch_size, context_len, embed_dim)
output, attn_weights = cross_attention(query, context)
print(output.size()) # Should be (batch_size, query_len, embed_dim)
print(attn_weights.size()) # Should be (batch_size, num_heads, query_len, context_len)
- 类定义:
CrossAttention
类继承自nn.Module
,包含初始化函数__init__
和前向传播函数forward
。 - 初始化:
- 定义了一些线性变换层:
query_proj
,key_proj
, 和value_proj
,这些层将嵌入向量转换为多头注意力机制所需的维度。 - 最终的输出通过
out_proj
再投影回原始的嵌入维度。
- 定义了一些线性变换层:
- 前向传播:
- 输入的
query
和context
分别通过线性变换层,并重新整形以适应多头注意力机制。 - 计算注意力分数,并通过 softmax 得到注意力权重。
- 利用注意力权重加权上下文向量,得到新的上下文表示。
- 最后将多头的结果合并,并通过输出投影层得到最终的输出。
- 输入的