原理讲解
【Transformer系列(2)】注意力机制、自注意力机制、多头注意力机制、通道注意力机制、空间注意力机制超详细讲解
自注意力机制
import torch
import torch.nn as nn
# Self-attention mechanism
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The input is projected to queries, keys and values by three independent
    ``nn.Linear(input_dim, input_dim)`` layers; the output is
    ``softmax(q @ k^T / sqrt(input_dim)) @ v``. No output projection is applied.

    Args:
        input_dim: Size of the last (feature) dimension of the input.
    """

    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, input_dim)``, typically
                ``(batch_size, seq_len, input_dim)``.
            mask: Optional key-padding mask of shape ``(..., seq_len)`` (e.g.
                ``(batch_size, seq_len)``). Key positions where ``mask == 0``
                get zero attention weight. A query row whose keys are all
                masked produces NaN (softmax over all ``-inf``).

        Returns:
            Tensor with the same shape as ``x``.
        """
        input_dim = x.shape[-1]
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # Scale by sqrt(d) so the logits' magnitude does not grow with input_dim.
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(input_dim, dtype=torch.float)
        )
        if mask is not None:
            # (..., seq_len) -> (..., 1, seq_len): broadcast over query rows so
            # only the key (last) axis is masked.
            mask = mask.unsqueeze(-2)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
        attn_scores = torch.softmax(attn_weights, dim=-1)
        attended_values = torch.matmul(attn_scores, v)
        return attended_values
# Automatic padding helper
def pad_sequences(sequences, max_len=None):
    """Right-pad variable-length sequences with zeros into one batch tensor.

    Args:
        sequences: Non-empty sequence of tensors, each of shape
            ``(seq_len_i, input_dim)`` with a common ``input_dim``.
        max_len: Padded length. If falsy (``None`` or ``0``), the length of
            the longest sequence is used.

    Returns:
        ``(padded, mask)``. ``padded`` is a float32 tensor of shape
        ``(batch_size, max_len, input_dim)`` with zeros after each sequence's
        end. ``mask`` is a long tensor of shape ``(batch_size, max_len)``
        holding 1 at real positions and 0 at padding.

    Raises:
        ValueError: if ``sequences`` is empty or ``max_len`` is shorter than
            the longest sequence.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequences() requires at least one sequence")
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    longest = int(lengths.max().item())
    max_len = max_len or longest
    if max_len < longest:
        raise ValueError(
            f"max_len ({max_len}) is shorter than the longest sequence ({longest})"
        )
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        padded[i, :seq.shape[0], :] = seq
    # Position index < sequence length marks real tokens; broadcasting
    # (1, max_len) against (batch_size, 1) yields (batch_size, max_len).
    mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask.long()
if __name__ == '__main__':
    # Demo: pad two sequences of different lengths to a common length and run
    # single-head self-attention over the padded batch.
    input_dim = 128
    seq_len_1 = 3
    seq_len_2 = 5
    target_seq_len = 10
    x1 = torch.randn(seq_len_1, input_dim)
    x2 = torch.randn(seq_len_2, input_dim)
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)
    selfattention = SelfAttention(input_dim)
    # NOTE(review): `mask` is computed but not passed to the module here, so
    # the zero-padded positions are attended to like real tokens.
    attention = selfattention(padded_x)
    print(attention)
多头自注意力机制
import torch
import torch.nn as nn
# Multi-head self-attention module
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values come from three ``nn.Linear(input_dim, input_dim)``
    projections. Each projection is split into ``num_heads`` heads of size
    ``input_dim // num_heads``. Attention is computed independently per head,
    and the heads are concatenated back to ``input_dim`` features. No output
    projection is applied after the concatenation.

    Args:
        input_dim: Size of the last (feature) dimension of the input; must be
            divisible by ``num_heads``.
        num_heads: Number of attention heads (must be positive).

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``input_dim``.
    """

    def __init__(self, input_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        # Fail fast: an uneven split would otherwise surface later as an
        # opaque view() shape error inside forward().
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq, input_dim) -> (batch, num_heads, seq, head_dim)."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, input_dim)``.
            mask: Optional key-padding mask of shape ``(batch_size, seq_len)``.
                Key positions where ``mask == 0`` get zero attention weight in
                every head.

        Returns:
            Tensor of shape ``(batch_size, seq_len, input_dim)``.
        """
        batch_size, seq_len, input_dim = x.shape
        # Split each projection into heads: (batch, num_heads, seq, head_dim).
        q = self._split_heads(self.query(x), batch_size, seq_len)
        k = self._split_heads(self.key(x), batch_size, seq_len)
        v = self._split_heads(self.value(x), batch_size, seq_len)
        # Per-head logits scaled by sqrt(head_dim): (batch, num_heads, seq, seq).
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32)
        )
        if mask is not None:
            # (batch, seq) -> (batch, 1, 1, seq): broadcast over heads and query
            # rows so only the key axis is masked.
            mask = mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))
        attn_scores = torch.softmax(attn_weights, dim=-1)
        # Weighted sum per head, then merge heads:
        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim) -> (batch, seq, input_dim)
        attended_values = (
            torch.matmul(attn_scores, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, input_dim)
        )
        return attended_values
# Automatic padding helper
def pad_sequences(sequences, max_len=None):
    """Right-pad variable-length sequences with zeros into one batch tensor.

    Args:
        sequences: Non-empty sequence of tensors, each of shape
            ``(seq_len_i, input_dim)`` with a common ``input_dim``.
        max_len: Padded length. If falsy (``None`` or ``0``), the length of
            the longest sequence is used.

    Returns:
        ``(padded, mask)``. ``padded`` is a float32 tensor of shape
        ``(batch_size, max_len, input_dim)`` with zeros after each sequence's
        end. ``mask`` is a long tensor of shape ``(batch_size, max_len)``
        holding 1 at real positions and 0 at padding.

    Raises:
        ValueError: if ``sequences`` is empty or ``max_len`` is shorter than
            the longest sequence.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequences() requires at least one sequence")
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    longest = int(lengths.max().item())
    max_len = max_len or longest
    if max_len < longest:
        raise ValueError(
            f"max_len ({max_len}) is shorter than the longest sequence ({longest})"
        )
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        padded[i, :seq.shape[0], :] = seq
    # Position index < sequence length marks real tokens; broadcasting
    # (1, max_len) against (batch_size, 1) yields (batch_size, max_len).
    mask = torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)
    return padded, mask.long()
if __name__ == '__main__':
    # Demo: pad two variable-length sequences to a common length and run
    # masked multi-head self-attention over the padded batch.
    heads = 2
    input_dim = 128
    seq_len_1 = 3
    seq_len_2 = 5
    target_seq_len = 10
    x1 = torch.randn(seq_len_1, input_dim)
    x2 = torch.randn(seq_len_2, input_dim)
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)
    multiheadattention = MultiHeadSelfAttention(input_dim, heads)
    # The padding mask (0 = padded position) is passed so padded keys are ignored.
    attention = multiheadattention(padded_x, mask)
    print(attention)