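In the AlphaFold3 codebase, the `ProteinDataModule` class in the `protein_datamodule` module inherits from PyTorch Lightning's data module base class (`LightningDataModule`). It keeps the logic for preparing, loading, splitting, and transforming ProteinFlow data in one place, so data handling is managed consistently and stays reproducible across training runs.
This class handles four core tasks in AlphaFold3 training and evaluation: data preparation, splitting, transformation, and loading: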
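To make this division of labour concrete, here is a minimal pure-Python mock of the order in which the data-module hooks are used while fitting. It is a simplified sketch, not Lightning itself: `ToyDataModule` and `simulate_fit` are hypothetical names, it uses a toy 10-item dataset, it runs in a single process, and it ignores details such as Lightning's sanity-check validation pass. The point it illustrates is the one the Lightning docstring makes: `prepare_data` is meant to run once on a single process, while `setup` runs on every process and is where the splits are built.

```python
from typing import List


class ToyDataModule:
    """Hypothetical stand-in that records which hook was called, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []
        self.train_set: List[int] = []
        self.val_set: List[int] = []

    def prepare_data(self) -> None:
        # Download / preprocess on one process; state set here is not shared.
        self.calls.append("prepare_data")

    def setup(self, stage: str) -> None:
        # Build the splits on every process (toy data: 8 train, 2 val items).
        self.calls.append(f"setup:{stage}")
        data = list(range(10))
        self.train_set, self.val_set = data[:8], data[8:]

    def train_dataloader(self) -> List[List[int]]:
        # Batches of 4, standing in for a torch DataLoader.
        self.calls.append("train_dataloader")
        return [self.train_set[i:i + 4] for i in range(0, len(self.train_set), 4)]

    def val_dataloader(self) -> List[List[int]]:
        self.calls.append("val_dataloader")
        return [self.val_set]

    def teardown(self, stage: str) -> None:
        # Clean up after fit or test.
        self.calls.append(f"teardown:{stage}")


def simulate_fit(dm: ToyDataModule) -> int:
    """Single-process sketch of the hook order; returns the number of train batches."""
    dm.prepare_data()
    dm.setup("fit")
    n_batches = 0
    for _batch in dm.train_dataloader():
        n_batches += 1  # a real training_step would consume the batch here
    for _batch in dm.val_dataloader():
        pass  # a real validation_step would consume the batch here
    dm.teardown("fit")
    return n_batches
```

Calling `simulate_fit(ToyDataModule())` walks through `prepare_data`, `setup("fit")`, the two dataloader hooks, and `teardown("fit")`, consuming two training batches of four items each.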
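| Task | Description |
|---|---|
| Data preparation (`prepare_data`) | Download/prepare the ProteinFlow dataset (structures and annotations) |
| Dataset construction (`setup`) | Build the training, validation, and test sets and apply transforms |
| DataLoader provision (`*_dataloader`) | Return PyTorch `DataLoader` objects for model training, validation, and testing |
| Data augmentation and feature extraction | Apply `Cropper`, `Reorder`, and `AF3Featurizer` to generate the model's input features |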
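The table names `Cropper`, `Reorder`, and `AF3Featurizer` but not their internals, so the sketch below only illustrates the general pattern of chaining per-sample transforms, using a random contiguous crop to `crop_size` residues followed by a featurizer. `compose`, `make_contiguous_crop`, and `toy_featurize` are hypothetical stand-ins written for illustration, not the repository's classes. The `Reorder` step is omitted because its behaviour is not described here, and the real featurizer produces far richer features than this one-hot encoding.

```python
import random
from typing import Callable, Dict, List, Sequence

Sample = Dict[str, object]
Transform = Callable[[Sample], Sample]

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"  # the 20 standard residues


def compose(transforms: Sequence[Transform]) -> Transform:
    """Apply the transforms left to right (hypothetical helper)."""
    def apply(sample: Sample) -> Sample:
        for t in transforms:
            sample = t(sample)
        return sample
    return apply


def make_contiguous_crop(crop_size: int, rng: random.Random) -> Transform:
    """Keep a random contiguous window of at most `crop_size` residues."""
    def crop(sample: Sample) -> Sample:
        seq = str(sample["sequence"])
        if len(seq) <= crop_size:
            return dict(sample)  # short proteins are left uncropped
        start = rng.randint(0, len(seq) - crop_size)  # inclusive bounds
        out = dict(sample)  # copy so the input sample is not mutated
        out["sequence"] = seq[start:start + crop_size]
        out["crop_start"] = start
        return out
    return crop


def toy_featurize(sample: Sample) -> Sample:
    """Stand-in featurizer: one-hot encode each residue over AMINO_ACIDS.

    Non-standard residue letters become all-zero rows.
    """
    seq = str(sample["sequence"])
    one_hot: List[List[int]] = [
        [1 if aa == res else 0 for aa in AMINO_ACIDS] for res in seq
    ]
    out = dict(sample)
    out["residue_one_hot"] = one_hot
    return out
```

For example, `compose([make_contiguous_crop(384, random.Random(0)), toy_featurize])` turns a 1,000-residue sample into a 384-residue window with a 384 × 20 one-hot matrix. Building the pipeline once and applying it per sample mirrors how a data module typically attaches transforms to its datasets in `setup`.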
Source code:
from typing import List, Optional

from lightning.pytorch import LightningDataModule


class ProteinDataModule(LightningDataModule):
    """`LightningDataModule` for the Protein Data Bank.

    A `LightningDataModule` implements 7 key methods:

    ```python
        def prepare_data(self):
            # Things to do on 1 GPU/TPU (not on every GPU/TPU in DDP).
            # Download data, pre-process, split, save to disk, etc...

        def setup(self, stage):
            # Things to do on every process in DDP.
            # Load data, set variables, etc...

        def train_dataloader(self):
            # return train dataloader

        def val_dataloader(self):
            # return validation dataloader

        def test_dataloader(self):
            # return test dataloader

        def predict_dataloader(self):
            # return predict dataloader

        def teardown(self, stage):
            # Called on every process in DDP.
            # Clean up after fit or test.
    ```

    This allows you to share a full dataset without explaining how to download,
    split, transform and process the data.

    Read the docs:
        https://lightning.ai/docs/pytorch/latest/data/datamodule.html
    """

    def __init__(
        self,
        data_dir: str = "./data/",
        resolution_thr: float = 3.5,
        min_seq_id: float = 0.3,
        crop_size: int = 384,
        max_length: int = 10_000,
        use_fraction: float = 1.0,
        entry_type: str = "chain",
        classes_to_exclude: Optional[List[str]] = None,
        mask_residues: bool = False,
        lower_limit: int = 15,
        upper_limit: int = 100,
        mask_frac: Optional[float] = None,
        mask_sequential: bool = False,
        mask_whole_chains: bool = False,
        force_binding_sites_frac: float = 0.15,
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
        debug: bool = False
    ) -> None:
        """Initialize a `ProteinDataModule`.

        :param resolution_thr: Resolution threshold for PDB structures
        :param min_seq_id: Minimum sequence identity for MMSeq2 clustering
        :param crop_size: The number of residues to crop the proteins to.
        :param max_length: Entries with total length of chains larger than max_length will be disregarded.
        :param use_fraction: the fraction of the clusters to use (first N in alphabetic order)
        :param entry_type: {"biounit", "chain", "pair"} the type of entries to generate ("biounit" for biounit-level
            complexes, "chain" for chain-level, "pair" for chain-chain pairs (all pairs that are seen
            in the same biounit and have intersecting coordinate clouds))
        :param classes_to_exclude: a list of classes to exclude from the dataset (select from "single_chains",
            "heteromers", "homomers")
        :param mask_residues: if True, the masked residues will be added to the output
        :param lower_limit: the lower limit of the number of residues to mask
        :param upper_limit: the upper limit of the number of residues to mask