AlphaFold3 RelativePositionEncoding
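RelativePositionEncoding 类用于计算 AlphaFold3 中的相对位置编码(对应 AlphaFold3 补充材料 Algorithm 3)。这种编码结合了残基索引(residue_index)、token 索引(token_index)、链 ID(asym_id)、实体 ID(entity_id)以及同一序列各链内的对称 ID(sym_id)等特征,通过线性投影生成一个用于建模 token 对之间相对位置信息的 pair 表示。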
源代码:
class RelativePositionEncoding(nn.Module):
"""Relative position encoding."""
def __init__(
    self,
    c_pair: int,
    r_max: int = 32,
    s_max: int = 2
):
    """Initializes the relative position encoding.

    Builds a bias-free linear projection whose input width is the sum of
    the bin counts of the relative-position features (relpos, rel_token,
    rel_chain) plus one same-entity flag, and whose output width is
    ``c_pair``.

    Args:
        c_pair: Dimension of the pair representation (output width of the
            linear projection).
        r_max: Maximum residue distance encoded, plus or minus r_max.
            Determines the bin count of both the relpos and rel_token
            features.
        s_max: Maximum asym_id distance encoded, plus or minus s_max.
            Determines the bin count of the rel_chain feature.

    Raises:
        ValueError: If ``r_max`` or ``s_max`` is negative, since the bin
            counts derived from them would be meaningless.
    """
    super().__init__()
    if r_max < 0:
        raise ValueError(f"r_max must be non-negative, got {r_max}")
    if s_max < 0:
        raise ValueError(f"s_max must be non-negative, got {s_max}")

    self.c_pair = c_pair
    self.r_max = r_max
    self.s_max = s_max

    # Bin count per feature. An offset in [-max, max] spans 2 * max + 1
    # values; the extra (+1) bin presumably holds pairs for which the
    # offset is undefined (e.g. tokens on different chains in Algorithm 3)
    # -- NOTE(review): confirm against forward().
    n_relpos_bins = 2 * r_max + 2
    n_rel_token_bins = 2 * r_max + 2
    n_rel_chain_bins = 2 * s_max + 2
    n_same_entity = 1  # single flag feature
    input_dim = n_relpos_bins + n_rel_token_bins + n_rel_chain_bins + n_same_entity

    # Attribute name kept as ``linear_proj`` so existing checkpoints load.
    self.linear_proj = Linear(input_dim, c_pair, bias=False)
def forward(self, features: Dict[str, torch.Tensor], mask: torch.Tensor = None) -> torch.Tensor:
"""Computes relative position encoding. AlphaFold3 Supplement Algorithm 3.
Args:
features:
Input feature dictionary containing:
"residue_index":
[*, n_tokens] Residue number in the token's original chain.
"token_index":
[*, n_tokens] Token number. Increases monotonically; does not restart at 1 for new chains.
"asym_id":
[*, n_tokens] Unique integer for each distinct chain.
"entity_id":
[*, n_tokens] Unique integer for each distinct sequence.
"sym_id":
[*, n_tokens] Unique integer within chains of this sequence.
mask:
[*, n_tokens] mask tensor (optional)
Returns:
[*, n_tokens, n_tokens, c_pair] r