Follow my CSDN: https://spike.blog.csdn.net/
Article URL: https://spike.blog.csdn.net/article/details/132597659
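OpenFold Multimer is a deep-learning-based method for predicting the structures and interactions of protein multimers. It uses large-scale protein sequence and structure data, together with advanced neural network architectures, to learn protein representations and features. It can handle different types of multimers, including homo- and hetero-multimers, as well as complex protein-protein interaction networks. Its goal is to give biologists a fast, accurate, and easy-to-use tool for exploring the functions and mechanisms of protein multimers.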
Training command and arguments:
```shell
python3 train_openfold.py \
    --train_data_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
    --train_alignment_dir mydata/alignment_dir/ \
    --train_mmcif_data_cache_path mmcif_cache.json \
    --template_mmcif_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
    --output_dir mydata/output_dir/ \
    --max_template_date "2021-10-10" \
    --config_preset "model_1_multimer_v3" \
    --template_release_dates_cache_path mmcif_cache.json \
    --precision bf16 \
    --gpus 1 \
    --replace_sampler_ddp=True \
    --seed 42 \
    --deepspeed_config_path deepspeed_config.json \
    --checkpoint_every_epoch \
    --train_chain_data_cache_path chain_data_cache.json \
    --obsolete_pdbs_file_path [your folder]/af2-data-v230/pdb_mmcif/obsolete.dat
```
1. train_alignment_dir
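The key parameter here is `train_alignment_dir`, which points to the cached, preprocessed features (the alignments). It travels along the following call path:
- The `args` parsed in `train_openfold.py` are passed into the `OpenFoldMultimerDataModule` class.
- They are then received by the `dataset_gen()` method, which corresponds to the `OpenFoldSingleMultimerDataset` class.
- The parameter is passed as `alignment_dir=self.train_alignment_dir`, i.e. it is renamed to `alignment_dir`.
- Finally, `OpenFoldMultimerDataModule` instantiates `OpenFoldSingleMultimerDataset`.

That is: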
```python
# train_openfold.py
# ...
if "multimer" in args.config_preset:
    data_module = OpenFoldMultimerDataModule(
        config=config.data,
        batch_seed=args.seed,
        **vars(args))
# ...

# openfold/data/data_modules.py#OpenFoldMultimerDataModule
# ...
if self.training_mode:
    train_dataset = dataset_gen(
        data_dir=self.train_data_dir,
        mmcif_data_cache_path=self.train_mmcif_data_cache_path,
        alignment_dir=self.train_alignment_dir,
        filter_path=self.train_filter_path,
        max_template_hits=self.config.train.max_template_hits,
        shuffle_top_k_prefiltered=
            self.config.train.shuffle_top_k_prefiltered,
        treat_pdb_as_distillation=False,
        mode="train",
        alignment_index=self.alignment_index,
    )
# ...
```
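To make the hand-off concrete, here is a minimal, self-contained sketch of the same pattern: a CLI namespace is expanded with `**vars(args)` into a data module, which renames `train_alignment_dir` to `alignment_dir` and builds the dataset through a factory. The `Toy*` classes are hypothetical stand-ins, not OpenFold's real classes, and binding `dataset_gen` with `functools.partial` is an assumption made for illustration.

```python
import argparse
from functools import partial


class ToyMultimerDataset:
    """Hypothetical stand-in for OpenFoldSingleMultimerDataset: only records its inputs."""

    def __init__(self, data_dir, alignment_dir, mode, template_mmcif_dir=None):
        self.data_dir = data_dir
        self.alignment_dir = alignment_dir
        self.mode = mode
        self.template_mmcif_dir = template_mmcif_dir


class ToyMultimerDataModule:
    """Hypothetical stand-in for OpenFoldMultimerDataModule."""

    def __init__(self, train_data_dir, train_alignment_dir, template_mmcif_dir, **kwargs):
        # Flags this module does not use (seed, gpus, ...) arrive via **kwargs and are ignored.
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.template_mmcif_dir = template_mmcif_dir

    def setup(self):
        # Assumption: shared arguments are bound once with partial; split-specific ones per call.
        dataset_gen = partial(ToyMultimerDataset, template_mmcif_dir=self.template_mmcif_dir)
        # The rename happens here: train_alignment_dir -> alignment_dir.
        self.train_dataset = dataset_gen(
            data_dir=self.train_data_dir,
            alignment_dir=self.train_alignment_dir,
            mode="train",
        )
        return self.train_dataset


parser = argparse.ArgumentParser()
parser.add_argument("--train_data_dir")
parser.add_argument("--train_alignment_dir")
parser.add_argument("--template_mmcif_dir")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args([
    "--train_data_dir", "mmcif_files/",
    "--train_alignment_dir", "mydata/alignment_dir/",
    "--template_mmcif_dir", "mmcif_files/",
])

dataset = ToyMultimerDataModule(**vars(args)).setup()
print(dataset.alignment_dir)  # mydata/alignment_dir/
```

Passing the whole namespace through `**vars(args)` keeps the constructor signature decoupled from the CLI: the module picks the flags it needs by name and silently accepts the rest.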
In the `OpenFoldSingleMultimerDataset` class, `alignment_dir` is used to populate `_chain_ids`:
```python
if alignment_index is not None:
    self._chain_ids = list(alignment_index.keys())
else:
    self._chain_ids = list(os.listdir(alignment_dir))
```
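The two branches can be tried in isolation. The helper below (the name `load_chain_ids` is hypothetical) mirrors the snippet above: with an index, the chain IDs are the index keys; without one, they are the entry names inside `alignment_dir`. The `<mmcif_id>_<chain>` directory naming in the demo is an assumption for illustration.

```python
import os
import tempfile


def load_chain_ids(alignment_dir, alignment_index=None):
    """Hypothetical mirror of the _chain_ids assignment in OpenFoldSingleMultimerDataset."""
    if alignment_index is not None:
        # Consolidated alignments: every chain is a key of the index.
        return list(alignment_index.keys())
    # Per-chain layout: one entry per chain. os.listdir order is arbitrary.
    return list(os.listdir(alignment_dir))


with tempfile.TemporaryDirectory() as alignment_dir:
    for chain in ["5m9r_A", "5m9r_B", "4ewn_D"]:
        os.makedirs(os.path.join(alignment_dir, chain))
    from_dir = load_chain_ids(alignment_dir)

# With an index, alignment_dir is not listed at all.
from_index = load_chain_ids("unused/", alignment_index={"5m9r_A": {}, "4ewn_D": {}})
print(sorted(from_dir))  # ['4ewn_D', '5m9r_A', '5m9r_B']
print(from_index)        # ['5m9r_A', '4ewn_D']
```

Because `os.listdir` makes no ordering guarantee, the chain order (and thus which index maps to which chain) may differ between machines when no index is used.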
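`alignment_index_path` can also be passed as an argument and defaults to empty (`None`). The official description is quoted below; the core idea is to first compile the per-chain alignments into a single file and then read from it, which can improve efficiency: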
In cases where it may be burdensome to create separate files for each chain’s alignments, alignment directories can be consolidated using the scripts in scripts/alignment_db_scripts/. First, run create_alignment_db.py to consolidate an alignment directory into a pair of database and index files. Once all alignment directories (or shards of a single alignment directory) have been compiled, unify the indices with unify_alignment_db_indices.py. The resulting index, super.index, can be passed to the training script flags containing the phrase alignment_index. In this scenario, the alignment_dir flags instead represent the directory containing the compiled alignment databases. Both the training and distillation datasets can be compiled in this way. Anecdotally, this can speed up training in I/O-bottlenecked environments.
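The idea behind the consolidation can be sketched with the standard library: concatenate every chain's alignment files into one database file and keep an index of byte offsets, so a reader opens a single file and seeks, instead of opening many small files. This is a simplified illustration, not the output of `create_alignment_db.py`; the index layout (`{chain_id: {"db": ..., "files": [[name, start, size], ...]}}`), the helper names, and the toy alignment contents are all assumptions.

```python
import json
import os
import tempfile


def build_alignment_db(alignments, db_path, index_path):
    """Hypothetical helper: {chain_id: {file_name: text}} -> one .db file plus a JSON index."""
    index = {}
    offset = 0
    with open(db_path, "wb") as db:
        for chain_id, files in alignments.items():
            entries = []
            for name, text in files.items():
                data = text.encode("utf-8")
                db.write(data)
                entries.append([name, offset, len(data)])  # byte start and size
                offset += len(data)
            index[chain_id] = {"db": os.path.basename(db_path), "files": entries}
    with open(index_path, "w") as f:
        json.dump(index, f)
    return index


def read_alignment(db_dir, index, chain_id, file_name):
    """Seek directly to one alignment file inside the consolidated database."""
    entry = index[chain_id]
    for name, start, size in entry["files"]:
        if name == file_name:
            with open(os.path.join(db_dir, entry["db"]), "rb") as db:
                db.seek(start)
                return db.read(size).decode("utf-8")
    raise KeyError(f"{file_name} not found for chain {chain_id}")


with tempfile.TemporaryDirectory() as tmp:
    # Toy alignments, not real MSAs.
    alignments = {
        "5m9r_A": {"uniref90_hits.a3m": ">q\nACDE\n", "mgnify_hits.a3m": ">q\nACDE\n>h\nACDF\n"},
        "4ewn_D": {"uniref90_hits.a3m": ">q\nGHIK\n"},
    }
    index = build_alignment_db(
        alignments, os.path.join(tmp, "train.db"), os.path.join(tmp, "train.index")
    )
    mgnify = read_alignment(tmp, index, "5m9r_A", "mgnify_hits.a3m")

print(mgnify)
```

The index keys double as the chain list, which is why `_chain_ids` can be read from `alignment_index.keys()` when an index is supplied.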
Here, `self._chain_ids` covers the entire training set:
```python
def __len__(self):
    return len(self._chain_ids)
```
Set up the logger:
```python
import logging

logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
```
How the training data is iterated:
```python
def __getitem__(self, idx):
    mmcif_id = self.idx_to_mmcif_id(idx)
    chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
    # ...
```
Printing inside `__getitem__` shows how the training data is organized:
```text
mmcif_id is: 5ykn, idx: 8580 and has 1 chains
mmcif_id is: 2lna, idx: 3848 and has 1 chains
mmcif_id is: 7rrp, idx: 8447 and has 24 chains
mmcif_id is: 6k8h, idx: 7870 and has 2 chains
...
```
2. OpenFoldSingleMultimerDataset
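Now let's examine the `OpenFoldSingleMultimerDataset` class in detail. Its `__getitem__` method fetches one training sample per index; the key points are:
- `self.idx_to_mmcif_id()` looks up `self._mmcifs[idx]`.
- There are two key variables, `self._mmcifs` and `self.mmcif_data_cache`, and their keys must stay consistent.

That is: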
```python
def __getitem__(self, idx):
    mmcif_id = self.idx_to_mmcif_id(idx)
    chains = self.mmcif_data_cache[mmcif_id]['chain_ids']
    # Debug print: trace which mmCIF entry each index maps to
    print(f"mmcif_id is: {mmcif_id}, idx: {idx} and has {len(chains)} chains")
```
The `self._mmcifs` data is built along the chain `mmcif_data_cache_path` -> `self.mmcif_data_cache` -> `self._mmcifs`, and the file at `mmcif_data_cache_path` is generated during preprocessing.

That is:
```python
# ...
logger.info(f"[CL] mmcif_data_cache_path: {mmcif_data_cache_path}")
if mmcif_data_cache_path is not None:
    with open(mmcif_data_cache_path, "r") as infile:
        self.mmcif_data_cache = json.load(infile)
    assert isinstance(self.mmcif_data_cache, dict)
# ...
if self.mmcif_data_cache is not None:
    self._mmcifs = list(self.mmcif_data_cache.keys())
    self._mmcif_id_to_idx_dict = {mmcif: i for i, mmcif in enumerate(self._mmcifs)}
```
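A minimal stand-in shows why the keys of `self._mmcifs` and `self.mmcif_data_cache` must match: the integer index is resolved to an mmCIF ID through the list, and that ID is then used as a key into the cache. The class name `ToyMultimerIndex` is hypothetical; implementing `idx_to_mmcif_id` as a plain list lookup follows the description above rather than OpenFold's source, and the two cache entries are trimmed from the `mmcif_cache.json` sample in this post.

```python
import json
import os
import tempfile


class ToyMultimerIndex:
    """Hypothetical stand-in reproducing mmcif_data_cache -> _mmcifs -> idx_to_mmcif_id."""

    def __init__(self, mmcif_data_cache_path):
        self.mmcif_data_cache = None
        if mmcif_data_cache_path is not None:
            with open(mmcif_data_cache_path, "r") as infile:
                self.mmcif_data_cache = json.load(infile)
            assert isinstance(self.mmcif_data_cache, dict)
        if self.mmcif_data_cache is None:
            # Added for the sketch: without a cache there is nothing to build _mmcifs from.
            raise ValueError("mmcif_data_cache_path is required to build _mmcifs")
        # The list is derived from the dict's keys, so the two always agree.
        self._mmcifs = list(self.mmcif_data_cache.keys())
        self._mmcif_id_to_idx_dict = {mmcif: i for i, mmcif in enumerate(self._mmcifs)}

    def idx_to_mmcif_id(self, idx):
        return self._mmcifs[idx]

    def __getitem__(self, idx):
        mmcif_id = self.idx_to_mmcif_id(idx)
        chains = self.mmcif_data_cache[mmcif_id]["chain_ids"]
        return mmcif_id, chains


with tempfile.TemporaryDirectory() as tmp:
    cache_path = os.path.join(tmp, "mmcif_cache.json")
    cache = {
        "4ewn": {"release_date": "2012-12-05", "chain_ids": ["D"], "no_chains": 1},
        "5m9r": {"release_date": "2017-02-22", "chain_ids": ["A", "B"], "no_chains": 2},
    }
    with open(cache_path, "w") as f:
        json.dump(cache, f)
    ds = ToyMultimerIndex(cache_path)

for idx in range(len(ds._mmcifs)):
    mmcif_id, chains = ds[idx]
    print(f"mmcif_id is: {mmcif_id}, idx: {idx} and has {len(chains)} chains")
```

If `_mmcifs` were built from a different source than the cache, any ID missing from the cache would raise a `KeyError` inside `__getitem__`; deriving both from the same JSON avoids that.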
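The `mmcif_cache.json` file stores per-entry PDB metadata (release date, chain IDs, sequences, number of chains, and resolution), for example: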
```json
{
    "4ewn": {
        "release_date": "2012-12-05",
        "chain_ids": ["D"],
        "seqs": [
            "MLAKRI..."
        ],
        "no_chains": 1,
        "resolution": 1.9
    },
    "5m9r": {
        "release_date": "2017-02-22",
        "chain_ids": ["A", "B"],
        "seqs": [
            "MQDNS...",
            "MQDNS..."
        ],
        "no_chains": 2,
        "resolution": 1.44
    },
    # ...
```
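Because the dataset relies on these fields, a quick sanity check of the cache before training can save a failed run. The sketch below checks, per entry, that `no_chains`, `chain_ids`, and `seqs` agree in length and that `release_date` parses as `YYYY-MM-DD`. These invariants are assumptions inferred from the two sample entries above, not a documented schema, and the function name `check_mmcif_cache` is hypothetical. In the demo, the second entry is deliberately broken (one sequence removed).

```python
import json
from datetime import date


def check_mmcif_cache(cache):
    """Hypothetical validator: return human-readable problems found in an mmcif_cache dict."""
    problems = []
    for mmcif_id, entry in cache.items():
        chain_ids = entry.get("chain_ids", [])
        n = entry.get("no_chains")
        if n != len(chain_ids):
            problems.append(f"{mmcif_id}: no_chains={n} but {len(chain_ids)} chain_ids")
        # Assumption: one sequence per chain, as in the sample entries.
        if "seqs" in entry and len(entry["seqs"]) != len(chain_ids):
            problems.append(f"{mmcif_id}: seqs and chain_ids differ in length")
        try:
            date.fromisoformat(entry["release_date"])
        except (KeyError, TypeError, ValueError):
            problems.append(f"{mmcif_id}: missing or malformed release_date")
    return problems


cache = json.loads("""{
  "4ewn": {"release_date": "2012-12-05", "chain_ids": ["D"],
           "seqs": ["MLAKRI..."], "no_chains": 1, "resolution": 1.9},
  "5m9r": {"release_date": "2017-02-22", "chain_ids": ["A", "B"],
           "seqs": ["MQDNS..."], "no_chains": 2, "resolution": 1.44}
}""")
print(check_mmcif_cache(cache))  # ['5m9r: seqs and chain_ids differ in length']
```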
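BugFix: add the `train_mmcif_data_cache_path` argument to the training command. Without it, `mmcif_data_cache_path` is `None`, so `self.mmcif_data_cache` is never loaded and `self._mmcifs` is never built:
```shell
--train_mmcif_data_cache_path mmcif_cache.json
```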