The Attention Class in AlphaFold3
Attention is a multi-head attention module that performs the standard multi-head attention computation while supporting AlphaFold's default layer initialization and additional features such as output gating and multiple attention bias vectors.
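Conceptually, the module performs the familiar scaled dot-product attention in parallel over several heads. As a reference point, here is a minimal, generic sketch of that computation in plain PyTorch; it is an illustration only, not the AlphaFold3 code, and the function name, weight-matrix arguments, and shapes are assumptions:

import math
import torch

def multi_head_attention(q_x, kv_x, w_q, w_k, w_v, w_o, no_heads):
    # q_x: [*, Q, c_q]; kv_x: [*, K, c_k]; w_*: projection matrices whose
    # output dimension is no_heads * c_hidden (w_o maps back to c_q).
    q = q_x @ w_q    # [*, Q, no_heads * c_hidden]
    k = kv_x @ w_k   # [*, K, no_heads * c_hidden]
    v = kv_x @ w_v   # [*, K, no_heads * c_hidden]

    # Split the channel dimension into (heads, per-head hidden dim) and move
    # the head axis in front of the sequence axis.
    c_hidden = q.shape[-1] // no_heads
    q = q.unflatten(-1, (no_heads, c_hidden)).transpose(-2, -3)  # [*, H, Q, c]
    k = k.unflatten(-1, (no_heads, c_hidden)).transpose(-2, -3)  # [*, H, K, c]
    v = v.unflatten(-1, (no_heads, c_hidden)).transpose(-2, -3)  # [*, H, K, c]

    # Scaled dot-product attention, computed independently per head.
    attn = torch.softmax(q @ k.transpose(-1, -2) / math.sqrt(c_hidden), dim=-1)
    o = attn @ v                                                 # [*, H, Q, c]

    # Merge the heads back and apply the output projection.
    o = o.transpose(-2, -3).flatten(start_dim=-2)                # [*, Q, H * c]
    return o @ w_o                                               # [*, Q, c_q]

In the Attention class below, these projections correspond to linear_q, linear_k, linear_v, and linear_o, and split_heads performs the reshape into per-head channels.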
Source code:
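# Assumed context for this excerpt: torch.nn is used as `nn`, and Linear /
# LinearNoBias are the project's custom linear layers implementing the
# AlphaFold-style initializations referenced below ("glorot", "final", "gating").
import torch.nn as nn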
class Attention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization. Allows multiple bias vectors.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
residual: bool = True,
proj_q_w_bias: bool = False,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
residual:
If the output is residual, then the final linear layer is initialized to
zeros so that the residual layer acts as the identity at initialization.
proj_q_w_bias:
Whether to project the Q vectors with a Linear layer that uses a bias
"""
super(Attention, self).__init__()
self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
split_heads = nn.Unflatten(dim=-1, unflattened_size=(self.no_heads, self.c_hidden))
        # The q/k/v linear layers project to no_heads * c_hidden channels, which split_heads then separates into per-head dimensions
linear_q_class = Linear if proj_q_w_bias else LinearNoBias
self.linear_q = nn.Sequential(
linear_q_class(self.c_q, self.c_hidden * self.no_heads, init="glorot"),
split_heads
)
self.linear_k = nn.Sequential(
LinearNoBias(self.c_k, self.c_hidden * self.no_heads, init="glorot"),
split_heads
)
self.linear_v = nn.Sequential(
LinearNoBias(self.c_v, self.c_hidden * self.no_heads, init="glorot"),
split_heads
)
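        # Output projection merges the per-head channels back to c_q. With
        # residual=True, init="final" zero-initializes this layer so the block
        # acts as the identity (through the residual connection) at initialization.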
self.linear_o = LinearNoBias(
self.c_hidden * self.no_heads, self.c_q, init="final" if residual else "default"
)
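        # Optional gating: when enabled, a projection of the query data is used
        # to gate the attention output (see the `gating` argument above).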
self.to_gamma = None
if self.gating:
self.to_gamma = nn.Sequential(
LinearNoBias(self.c_q, self.c_hidden * self.no_heads, init="gating"),