ProteinMPNN的神经网络模型,主要用于处理蛋白质相关的数据。模型包括特征提取部分(ProteinFeatures)、编码器层(EncLayer)和译码器层(DecLayer)
ProteinMPNN forward函数的部分代码:
# Concatenate sequence embeddings for autoregressive decoder
h_S = self.W_s(S)
h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
# Build encoder embeddings
h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
# ...
h_EXV_encoder_fw = mask_fw * h_EXV_encoder
for layer in self.decoder_layers:
h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
h_V = torch.utils.checkpoint.checkpoint(layer, h_V, h_ESV, mask)
-
代码解读:
-
h_EXV_encoder
没有序列信息:- 没错,
h_EXV_encoder
是从编码器得到的图结构的上下文信息,
- 没错,