Differential Transformer: Differential Attention and the Introduction of Negative Attention
Related links
ai-algorithms/README.md at main · Jaykef/ai-algorithms (github.com)
unilm/Diff-Transformer at master · microsoft/unilm (github.com)
Introduction
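Softmax attention weights are always non-negative. On long sequences, the useful information therefore gets drowned out by weight spread over irrelevant tokens. Differential attention introduces negative attention so the model can concentrate on the relevant parts of the sequence.

Half of the attention heads serve as negative heads. The final attention weights are the first softmax map minus the second map, scaled by a learnable coefficient λ. The initial value of λ depends on the layer depth. λ itself is re-parameterized from four learnable vectors: lambda_q1, lambda_k1, lambda_q2 and lambda_k2. With d = head_dim:

$$\lambda = \exp(\lambda_{q1} \cdot \lambda_{k1}) - \exp(\lambda_{q2} \cdot \lambda_{k2}) + \lambda_{\text{init}}$$

$$\text{DiffAttn}(X) = \left(\operatorname{softmax}\!\left(\frac{Q_1 K_1^\top}{\sqrt{d}}\right) - \lambda\,\operatorname{softmax}\!\left(\frac{Q_2 K_2^\top}{\sqrt{d}}\right)\right) V$$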
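The following is a minimal sketch of the λ re-parameterization. It also shows why the difference of two softmax maps can put negative weight on a token. The sketch assumes PyTorch, depth 0 (so λ_init = 0.2) and head_dim = 4.

The four λ vectors are fixed illustrative values standing in for the random N(0, 0.1²) initialization used in `MultiheadDiffAttn.__init__`. The toy scores for one query over three keys are likewise assumptions, not taken from the source.

```python
import math

import torch
import torch.nn.functional as F

depth = 0
# Illustrative stand-ins for the four learnable vectors (normally drawn from N(0, 0.1^2))
lambda_q1 = torch.tensor([0.1, 0.1, 0.1, 0.1])
lambda_k1 = torch.tensor([0.1, 0.1, 0.1, 0.1])
lambda_q2 = torch.tensor([0.1, -0.1, 0.1, -0.1])
lambda_k2 = torch.tensor([0.1, 0.1, 0.1, 0.1])
lambda_init = 0.8 - 0.6 * math.exp(-0.3 * depth)  # 0.2 at depth 0

# Re-parameterization: lambda = exp(q1 . k1) - exp(q2 . k2) + lambda_init
lam = torch.exp(torch.dot(lambda_q1, lambda_k1)) - torch.exp(torch.dot(lambda_q2, lambda_k2)) + lambda_init

# Toy pre-softmax scores of one query over 3 keys (illustrative values)
s1 = torch.tensor([8.0, 0.0, 0.0])  # first map: concentrates on key 0
s2 = torch.tensor([0.0, 8.0, 0.0])  # second map: concentrates on key 1
a1, a2 = F.softmax(s1, dim=-1), F.softmax(s2, dim=-1)

diff = a1 - lam * a2  # differential attention weights for this query
print(f"lambda = {lam.item():.4f}")  # exp(0.04) - exp(0) + 0.2, about 0.2408
print(diff)  # weight on key 1 is negative (about -0.24); the weights sum to 1 - lambda, not 1
```

Because each softmax row sums to 1, the differential row always sums to 1 − λ. Any key that the second map favours more than the first can therefore end up with a negative weight.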
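Parameter shapes. N is the batch size, L the sequence length and C = embed_dim. The examples use N = 1024, L = 256, embed_dim = 32 and num_heads = 4: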
$$\text{dim\_head} \times \text{num\_heads} \times 2 = \text{embed\_dim}$$
Name | Definition / shape | Example |
---|---|---|
dim_head | embed_dim // num_heads // 2 | 32 // 4 // 2 = 4 |
q_proj | weight (embed_dim, embed_dim) | (32, 32) |
k_proj | weight (embed_dim, embed_dim) | (32, 32) |
v_proj | weight (embed_dim, embed_dim) | (32, 32) |
out_proj | weight (embed_dim, embed_dim) | (32, 32) |
Q | [N, L, C].view(N, L, 2 * num_heads, dim_head) | (1024, 256, 2 * 4, 4) |
K | [N, L, C].view(N, L, 2 * num_heads, dim_head) | (1024, 256, 2 * 4, 4) |
V | [N, L, C].view(N, L, num_heads, 2 * dim_head) | (1024, 256, 4, 2 * 4) |
attn_weights | [N, 2 * num_heads, L, L].view(N, num_heads, 2, L, L) -> [N, num_heads, L, L] | (1024, 4, 2, 256, 256) -> (1024, 4, 256, 256) |
attn | [N, num_heads, L, 2 * dim_head] -> [N, L, C] | (1024, 4, 256, 8) -> (1024, 256, 32) |
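The shapes in the table can be checked with a short PyTorch sketch. It keeps the table's embed_dim = 32 and num_heads = 4 but shrinks N and L to 2 and 16 to keep memory small. Purely for illustration, it also skips the projections (x is used directly as Q, K and V) and fixes λ at 0.5.

```python
import torch

N, L = 2, 16  # batch size and sequence length, shrunk from the table's 1024 and 256
embed_dim, num_heads = 32, 4
head_dim = embed_dim // num_heads // 2  # 32 // 4 // 2 = 4
assert head_dim * num_heads * 2 == embed_dim

x = torch.randn(N, L, embed_dim)  # stands in for the projected Q, K and V
q = x.view(N, L, 2 * num_heads, head_dim).transpose(1, 2)  # [N, 2 * num_heads, L, head_dim]
k = x.view(N, L, 2 * num_heads, head_dim).transpose(1, 2)  # [N, 2 * num_heads, L, head_dim]
v = x.view(N, L, num_heads, 2 * head_dim).transpose(1, 2)  # [N, num_heads, L, 2 * head_dim]

scores = q @ k.transpose(-1, -2) * head_dim ** -0.5           # [N, 2 * num_heads, L, L]
attn_weights = torch.softmax(scores, dim=-1).view(N, num_heads, 2, L, L)  # pair the two maps per head
diff = attn_weights[:, :, 0] - 0.5 * attn_weights[:, :, 1]   # [N, num_heads, L, L], lambda fixed at 0.5
attn = (diff @ v).transpose(1, 2).reshape(N, L, embed_dim)   # [N, num_heads, L, 2 * head_dim] -> [N, L, C]

for name, t in [("q", q), ("v", v), ("diff", diff), ("attn", attn)]:
    print(name, tuple(t.shape))
```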
Initialization function for λ
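```python
import math


def lambda_init_fn(depth):
    # Initial value of lambda as a function of layer depth:
    # 0.2 at depth 0, increasing toward 0.8 in deeper layers
    return 0.8 - 0.6 * math.exp(-0.3 * depth)
```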
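To see the depth schedule, evaluate the function at a few layer indices. The depth values below are arbitrary examples. λ_init is 0.2 at depth 0 and rises monotonically toward 0.8 as depth increases, without ever reaching it.

```python
import math


def lambda_init_fn(depth):
    # Same schedule as above: 0.8 - 0.6 * exp(-0.3 * depth)
    return 0.8 - 0.6 * math.exp(-0.3 * depth)


depths = [0, 1, 2, 4, 8, 16]
values = [lambda_init_fn(d) for d in depths]
for d, val in zip(depths, values):
    print(f"depth={d:2d}  lambda_init={val:.4f}")
```

Shallow layers therefore start with a small weight on the negative map, and deeper layers start with a larger one.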
Multi-head differential attention
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiheadDiffAttn(nn.Module):
    def __init__(
        self,
        embed_dim=32,
        depth=0,
        num_heads=8,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # num_heads is set to half of a standard Transformer's number of heads:
        # each differential head is made of two sub-heads (two softmax maps)
        self.num_heads = num_heads
        assert embed_dim % (2 * num_heads) == 0, "embed_dim must be divisible by 2 * num_heads"
        self.head_dim = embed_dim // num_heads // 2
        self.scaling = self.head_dim ** -0.5
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        # Depth-dependent initial value of lambda
        self.lambda_init = lambda_init_fn(depth)
        # Four learnable vectors that re-parameterize lambda, initialized from N(0, 0.1^2)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0, std=0.1))

    def forward(self, x):
        bsz, tgt_len, embed_dim = x.size()
        src_len = tgt_len  # self-attention: keys and values also come from x
        q = self.q_proj(x)  # [bsz, tgt_len, embed_dim]
        k = self.k_proj(x)  # [bsz, src_len, embed_dim]
        v = self.v_proj(x)  # [bsz, src_len, embed_dim]
        # embed_dim = 2 * num_heads * head_dim
        q = q.view(bsz, tgt_len, 2 * self.num_heads, self.head_dim)  # [bsz, tgt_len, 2 * num_heads, head_dim]
        k = k.view(bsz, src_len, 2 * self.num_heads, self.head_dim)  # [bsz, src_len, 2 * num_heads, head_dim]
        v = v.view(bsz, src_len, self.num_heads, 2 * self.head_dim)  # [bsz, src_len, num_heads, 2 * head_dim]
        q = q.transpose(1, 2)  # [bsz, 2 * num_heads, tgt_len, head_dim]
        q *= self.scaling
        k = k.transpose(1, 2)  # [bsz, 2 * num_heads, src_len, head_dim]
        v = v.transpose(1, 2)  # [bsz, num_heads, src_len, 2 * head_dim]
        # No causal mask is applied in this simplified version
        attn_weights = torch.matmul(q, k.transpose(-1, -2))  # [bsz, 2 * num_heads, tgt_len, src_len]
        attn_weights = torch.nan_to_num(attn_weights)
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).type_as(attn_weights)

        # lambda = exp(lambda_q1 . lambda_k1) - exp(lambda_q2 . lambda_k2) + lambda_init
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init

        # Pair up the 2 * num_heads maps: [bsz, num_heads, 2, tgt_len, src_len];
        # the size-2 axis holds the two maps of each differential head
        attn_weights = attn_weights.view(bsz, self.num_heads, 2, tgt_len, src_len)
        # First map minus lambda times the second map: [bsz, num_heads, tgt_len, src_len]
        attn_weights = attn_weights[:, :, 0] - lambda_full * attn_weights[:, :, 1]

        attn = torch.matmul(attn_weights, v)  # [bsz, num_heads, tgt_len, 2 * head_dim]
        # The official implementation applies a per-head RMSNorm (subln) before this scaling; omitted here
        attn = attn * (1 - self.lambda_init)
        attn = attn.transpose(1, 2).reshape(bsz, tgt_len, self.num_heads * 2 * self.head_dim)  # [bsz, tgt_len, embed_dim]
        attn = self.out_proj(attn)
        return attn
```
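The `.view(bsz, self.num_heads, 2, tgt_len, src_len)` step decides which two softmax maps are subtracted from each other. The small equivalence check below makes that pairing explicit. It assumes PyTorch and toy sizes, and compares two paths:

- the vectorized path, with projections, scaling and the (1 − λ_init) factor omitted and λ fixed at a hypothetical 0.3;
- an explicit per-head loop, in which differential head h uses the contiguous channel slice [2·h·d, 2·(h+1)·d) of Q, K and V.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, L, H, d = 2, 5, 4, 3  # batch, sequence length, differential heads, head_dim (toy sizes)
C = 2 * H * d
lam = 0.3  # hypothetical fixed lambda, for illustration only
q, k, v = (torch.randn(N, L, C) for _ in range(3))

# Vectorized path, mirroring the view/transpose logic of MultiheadDiffAttn.forward
qh = q.view(N, L, 2 * H, d).transpose(1, 2)
kh = k.view(N, L, 2 * H, d).transpose(1, 2)
vh = v.view(N, L, H, 2 * d).transpose(1, 2)
w = F.softmax(qh @ kh.transpose(-1, -2), dim=-1).view(N, H, 2, L, L)
vec = ((w[:, :, 0] - lam * w[:, :, 1]) @ vh).transpose(1, 2).reshape(N, L, C)

# Loop path: head h owns channels [2*h*d, 2*(h+1)*d); first half -> map 1, second half -> map 2
outs = []
for h in range(H):
    s = slice(2 * h * d, 2 * (h + 1) * d)
    q1, q2 = q[..., s].split(d, dim=-1)
    k1, k2 = k[..., s].split(d, dim=-1)
    a1 = F.softmax(q1 @ k1.transpose(-1, -2), dim=-1)
    a2 = F.softmax(q2 @ k2.transpose(-1, -2), dim=-1)
    outs.append((a1 - lam * a2) @ v[..., s])  # both maps share the same 2*d slice of V
loop = torch.cat(outs, dim=-1)

print(torch.allclose(vec, loop, atol=1e-6))
```

If the two paths agree, each differential head owns a contiguous block of 2 * head_dim channels. The first half of the block feeds its positive map, the second half feeds its negative map, and both maps read from the same block of V.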