PSP - 蛋白质结构预测 OpenFold Multimer 重构训练模型的数据加载

news2024/10/7 6:51:07

欢迎关注我的CSDN:https://spike.blog.csdn.net/
本文地址:https://spike.blog.csdn.net/article/details/132602155

PDB

OpenFold Multimer 在训练过程的数据加载时,需要将 MSA 与 Template 信息转换成 Feature,再进行训练,这样速度较慢。通过修改数据集类 OpenFoldSingleMultimerDataset__getitem__ 方法,可以加速训练过程。


1. 准备训练数据

在训练过程中,需要读取 mmcif_cache.json 文件,数据结构如下:

{
    "4ewn": {
        "release_date": "2012-12-05",
        "chain_ids": [
            "D"
        ],
        "seqs": [
            "MLAKRI..."
        ],
        "no_chains": 1,
        "resolution": 1.9
    },
    "5m9r": {
        "release_date": "2017-02-22",
        "chain_ids": [
            "A",
            "B"
        ],
        "seqs": [
            "MQDNS...",
            "MQDNS..."
        ],
        "no_chains": 2,
        "resolution": 1.44
    },
#...
}  

当前的训练数据格式,例如 train_200_mini.csv,如下:

pdb_id,chain_id,resolution,release_date,seq,len,chain_type,filepath
7m5z,"A,B",3.06,2021-10-06,"LEDVV...,QNKLE...","263,264","protein,protein",[pdb_path]/structures/m5/pdb7m5z.ent.gz
7k05,"A,B",1.85,2021-10-06,"MSFPP...,MSFPP...","200,200","protein,protein",[pdb_path]/structures/k0/pdb7k05.ent.gz
# ...

同时需要将 feature 的路径,也加入到训练文件 mmcif_cache.json 中,进而,通过预读文件,进行特征抽取,即:

[your folder]/multimer_train/features

使用特征文件夹中,已经预处理之后的特征 features.pkl,进行训练即可:

# 单个文件夹内容
chain_id_map.json
features.pkl
sequences.fasta

训练文件的转换命令,如下:

python openfold_scripts/main_mmcif_cache_transfer.py -i data/train_200_mini.csv -f [your folder]/multimer_train/features -o mydata/openfold/mmcif_cache_mini.json

源码如下:

#!/usr/bin/env python
# -- coding: utf-8 --
"""
Copyright (c) 2022. All rights reserved.
Created by C. L. Wang on 2023/8/31
"""
import argparse
import json
import os
import sys
from pathlib import Path

import pandas as pd
from tqdm import tqdm

p = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if p not in sys.path:
    sys.path.append(p)


class MmcifCacheTransfer(object):
    """
    训练 CSV 转换成 OpenFold 的 mmcif_cache.json 格式
    """
    def __init__(self):
        pass

    @staticmethod
    def process(input_path, feature_dir, output_path):
        print(f"[Info] 输入文件: {input_path}")
        print(f"[Info] 特征文件夹: {feature_dir}")
        print(f"[Info] 输出文件: {output_path}")
        assert os.path.isfile(input_path)
        df = pd.read_csv(input_path)
        print(f"[Info] 输入样本: {len(df)}")
        mmcif_cache_dict = dict()
        # pdb_id,chain_id,resolution,release_date,seq,len,chain_type,filepath
        for _, row in tqdm(df.iterrows(), "[Info] pdb"):
            pdb_id = row["pdb_id"]
            release_date = row["release_date"]
            chain_ids = row["chain_id"].split(",")
            seqs = row["seq"].split(",")
            no_chains = len(chain_ids)
            resolution = float(row["resolution"])
            feature_folder = os.path.join(feature_dir, pdb_id[1:3], f"pdb{pdb_id}_{''.join(chain_ids)}")
            pdb_dict = {
                "release_date": str(release_date),
                "chain_ids": chain_ids,
                "seqs": seqs,
                "no_chains": no_chains,
                "resolution": resolution,
                "feature_folder": feature_folder
            }
            mmcif_cache_dict[pdb_id] = pdb_dict
        with open(output_path, "w") as fp:
            fp.write(json.dumps(mmcif_cache_dict, indent=4))
        print(f"[Info] 全部处理完成: {output_path}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-i",
        "--input-path",
        help="the input file path.",
        type=Path,
        required=True,
    )
    parser.add_argument(
        "-f",
        "--feature-dir",
        help="the preprocess feature dir.",
        type=Path,
        required=True
    )
    parser.add_argument(
        "-o",
        "--output-path",
        help="the output file path.",
        type=Path,
        required=True
    )

    args = parser.parse_args()

    input_path = str(args.input_path)
    feature_dir = str(args.feature_dir)
    output_path = str(args.output_path)
    assert os.path.isfile(input_path)

    # from root_dir import ROOT_DIR, DATA_DIR
    # input_path = os.path.join(ROOT_DIR, "data", "train_200_mini.csv")
    # output_path = os.path.join(DATA_DIR, "openfold", "mmcif_cache_mini.json")
    mct = MmcifCacheTransfer()
    mct.process(input_path, feature_dir, output_path)


if __name__ == '__main__':
    main()

2. 加载训练数据

OpenFold Multimer 的特征读取逻辑,在 openfold/data/data_modules.py#OpenFoldSingleMultimerDataset() 中,即:

if self.mode == 'train' or self.mode == 'eval':
    path = os.path.join(self.data_dir, f"{mmcif_id}")
    ext = None
    for e in self.supported_exts:
        if os.path.exists(path + e):
            ext = e
            break
    if ext is None:
        raise ValueError("Invalid file type")
    # TODO: Add pdb and core exts to data_pipeline for multimer
    path += ext
    if ext == ".cif":
        data = self._parse_mmcif(
            path, mmcif_id, self.alignment_dir, alignment_index)
    else:
        raise ValueError("Extension branch missing")
else:
    path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
    data = self.data_pipeline.process_fasta(
        fasta_path=path,
        alignment_dir=self.alignment_dir)

修改成直接加载 Feature 的形式,即:

if self.mode == 'train' or self.mode == 'eval':
    # 训练或评估时,使用预处理的特征
    feat_folder = self.mmcif_data_cache[mmcif_id]['feature_folder']
    feat_path = os.path.join(feat_folder, "features.pkl")
    # logger.info(f"[Info] feat_path: {feat_path}")
    data = {}
    with open(feat_path, "rb") as f:
        feat_dict = pickle.load(f)
    data.update(feat_dict)
    # logger.info(f"[Info] data: {data.keys()}")
else:
    path = os.path.join(self.data_dir, f"{mmcif_id}.fasta")
    data = self.data_pipeline.process_fasta(
        fasta_path=path,
        alignment_dir=self.alignment_dir)

同时,还需要修改训练数据总数:

def __len__(self):
    # 数据部分都由 mmcif_data_cache 提供
    # return len(self._chain_ids)
    return len(self.mmcif_data_cache.keys)

3. 配置模型训练

模型训练的参数,如下:

python3 train_openfold.py \
    --train_data_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
    --train_alignment_dir mydata/alignment_dir/ \
    --train_mmcif_data_cache_path [your folder]/multimer_train/openfold_cache/mmcif_cache_mini.json \
    --template_mmcif_dir [your folder]/af2-data-v230/pdb_mmcif/mmcif_files/ \
	--output_dir mydata/output_dir/ \
    --max_template_date "2021-10-10" \
    --config_preset "model_1_multimer_v3" \
    --template_release_dates_cache_path mmcif_cache.json \
    --precision bf16 \
    --gpus 1 \
    --replace_sampler_ddp=True \
    --seed 42 \
    --deepspeed_config_path deepspeed_config.json \
    --checkpoint_every_epoch \
    --obsolete_pdbs_file_path [your folder]/af2-data-v230/pdb_mmcif/obsolete.dat

模型训练占用显存较多,V100 目前无法支持,调低 crop_size 与 num_workers,降低资源占用,配置位于 openfold/config.py 中,即:

# crop_size
elif "multimer" in name:
    c.update(multimer_config_update.copy_and_resolve_references())
    c.data.train.crop_size = 64  # TODO: 用于测试

# num_workers
"data_module": {
    "use_small_bfd": False,
    "data_loaders": {
        "batch_size": 1,
        # "num_workers": 16,
        "num_workers": 2,  # TODO: 用于测试
        "pin_memory": True,
    },
},

其中,crop_size = 64 占用显存约是 5141MiB

训练日志,如下:

Epoch 0:   0%|                                 | 0/199 [00:00<?, ?it/s]INFO:openfold/data/data_modules.py:mmcif_id is: 7poc, idx: 148 and has 4 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7u49, idx: 97 and has 3 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7z7h, idx: 114 and has 6 chains
INFO:openfold/data/data_modules.py:mmcif_id is: 7nup, idx: 111 and has 4 chains
cum_loss: tensor([84.1698], device='cuda:0', dtype=torch.float64, grad_fn=<MulBackward0>) losses: {'distogram': tensor(4.1562, device='cuda:0', dtype=torch.float64), 'experimentally_resolved': tensor(0.6914, device='cuda:0'), 'fape': tensor(1.6598, device='cuda:0', dtype=torch.float64), 'plddt_loss': tensor(3.9062, device='cuda:0', dtype=torch.float64), 'masked_msa': tensor(3.0938, device='cuda:0'), 'supervised_chi': tensor(0.7941, device='cuda:0', dtype=torch.float64), 'violation': tensor(3.6495, device='cuda:0'), 'tm': tensor(4.1562, device='cuda:0', dtype=torch.float64), 'chain_center_of_mass': tensor([1.3754], device='cuda:0', dtype=torch.float64), 'unscaled_loss': tensor([10.5212], device='cuda:0', dtype=torch.float64), 'loss': tensor([84.1698], device='cuda:0', dtype=torch.float64)}
Epoch 0:   1%|| 1/199 [02:55<9:38:06, 175.18s/it, loss=84.2, v_num=]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/956903.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【源码】智能导诊系统:医疗行业的变革者

智能导诊系统源码&#xff0c;3D人体导诊系统源码 随着科技的迅速发展&#xff0c;人工智能已经逐渐渗透到我们生活的各个领域。在医疗行业中&#xff0c;智能导诊系统成为了一个备受关注的应用。本文将详细介绍智能导诊系统的概念、技术原理以及在医疗领域中的应用&#xff0c…

计算机毕业设计之基于Python+MySQL的健身房管理系统(文档+源码+部署教程)

系统主要采用python技术和MySQL数据库技术以及Django框架进行开发。系统主要包括个人中心、用户管理、教练管理、健身课程管理、健身器材管理、健身记录管理、身体数据管理、在线留言、系统管理、订单管理等功能&#xff0c;从而实现智能化的健身房管理方式&#xff0c;提高健身…

DatenLord前沿技术分享No.34

达坦科技专注于打造新一代开源跨云存储平台DatenLord&#xff0c;通过软硬件深度融合的方式打通云云壁垒&#xff0c;致力于解决多云架构、多数据中心场景下异构存储、数据统一管理需求等问题&#xff0c;以满足不同行业客户对海量数据跨云、跨数据中心高性能访问的需求。在本周…

Jenkins测试报告样式优化

方式一&#xff1a;修改Content Security Policy&#xff08;临时解决&#xff0c;Jenkins重启后失效) 1、jenkins首页—>ManageJenkins—>Tools and Actions标题下—>Script Console 2、粘贴脚本输入框中&#xff1a;System.setProperty("hudson.model.Directo…

Java中转换流(InputStreamReader,OutputStreamWriter),打印流(PrintStream,PrintWriter)

转换流 InputStreamReader 和 OutputStreamWriter 是 Java 中用于字符流和字节流之间进行转换的转换流类。它们主要用于解决字符编码的问题&#xff0c;在字节流和字符流之间提供了桥梁&#xff0c;可以将字节流转换为字符流或将字符流转换为字节流。 InputStreamReader&#…

提高Python并发性能 - asyncio/aiohttp介绍

在进行大规模数据采集时&#xff0c;如何提高Python爬虫的并发性能是一个关键问题。本文将向您介绍使用asyncio和aiohttp库实现异步网络请求的方法&#xff0c;并通过具体结果和结论展示它们对于优化爬虫效率所带来的效果。 1. 什么是异步编程&#xff1f; 异步编程是一种非阻…

实战教学:农产品小程序商城的搭建与运营

随着移动设备的普及和互联网技术的发展&#xff0c;小程序商城已经成为农产品销售的一种新兴渠道。本文将以乔拓云网为平台&#xff0c;详细介绍如何搭建和运营农产品小程序商城。 步骤一&#xff1a;登录乔拓云网后台 首先&#xff0c;进入乔拓云网站后台&#xff0c;找到并点…

Centos7 使用docker安装oracle数据库(超详细)

在linux中采用解压安装包的方式安装oracle非常麻烦&#xff0c;并且稍微不注意就会出现问题&#xff0c;因此采用docker来安装&#xff0c;下面为详细的步骤&#xff1a; 若不知道是否安装docker可查看这篇文章&#xff1a;docker安装 1、拉取oracle镜像 docker pull registr…

计算机网络 | TCP 三次握手四次挥手 |半关闭连接

本来是不愿意写的&#xff0c;可是在实际场景&#xff0c;对具体的描述标志还是模糊不清&#xff0c;基础不扎实&#xff0c;就得承认&#xff01;&#xff01;&#xff01; TCP 连接建立需要解决三大问题&#xff1a; 知道双方存在约定一些参数&#xff0c;如最大滑动窗口值、…

YOLOv5算法改进(10)— 替换主干网络之GhostNet

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。GhostNet是一种针对计算机视觉任务的深度神经网络架构&#xff0c;它于2020年由中国科学院大学的研究人员提出。GhostNet的设计目标是在保持高精度的同时&#xff0c;减少模型的计算和存储成本。GhostNet通过引入Ghost模块…

浅谈多人游戏原理和简单实现。

&#x1f61c;作 者&#xff1a;是江迪呀✒️本文关键词&#xff1a;websocket、网络、原理、多人游戏☀️每日 一言&#xff1a;这世上有两种东西无法直视&#xff0c;一是太阳&#xff0c;二是人心&#xff01; 一、我的游戏史 我最开始接触游戏要从一盘300游戏…

RSA算法与错误敏感攻击

参见《RSA 算法的错误敏感攻击研究与实践》 RSA 算法简介 RSA 算法原理&#xff1a; 1&#xff09; RSA 算法密钥产生过程 &#xff08;1&#xff09;系统随机产生两个大素数 p p p 和 q q q&#xff0c;对这两个数据保密&#xff1b; &#xff08;2&#xff09;计算 n p …

Java类的声明周期、对象的创建过程

一、类的生命周期 使用类时&#xff0c;要先使用类加载器将类的字节码从磁盘加载到内存的方法区中&#xff0c;用Class对象表示加载到内存中的类&#xff0c;Class类是JDK中提供的类创建对象时&#xff0c;是根据内存中的Class对象&#xff0c;在堆中分配内存&#xff0c;完成…

c语言之指针的学习

1.指针是什么 &#xff08;指针是内存中一个最小单元的编号,也就是地址&#xff09; int main() {int a10;//当我们取出地址a的时候,取出的其实是a占4个字节中的第一个字节的地址int *pa&a;//pa是一个指针变量,用于存放地址//pa在口头语上常说为指针//指针本质上就是地址,…

C++学习|CUFFT计算一维傅里叶变换

CUFFT计算一维傅里叶变换 CUFFT库介绍CUFFTW计算一维傅里叶变换CUFFT计算一维傅里叶变换 前言&#xff1a;之前实现了CPU运行一维傅里叶变换&#xff0c;最近要改成GPU加速一维傅里叶变换&#xff0c;于是有了此篇作为记录&#xff0c;方便以后查阅。 CUFFT库介绍 CUFFT&#…

Protein - ECD (ExtraCellular Domain) 膜蛋白胞外区的 UniProt 与 PDB 数据分析

欢迎关注我的CSDN&#xff1a;https://spike.blog.csdn.net/ 本文地址&#xff1a;https://spike.blog.csdn.net/article/details/132597158 ECD 是 Extracellular Domain 的缩写&#xff0c;指的是跨膜蛋白质的细胞外部分 (膜蛋白的胞外区)&#xff0c;通常包含一些功能性的结…

JVM的故事——类文件结构

类文件结构 文章目录 类文件结构一、概述二、无关性基石三、Class类文件的结构 一、概述 计算机是只认由0、1组成的二进制码的&#xff0c;不过随着发展&#xff0c;我们编写的程序可以被编译成与指令集无关、平台中立的一种格式。 二、无关性基石 对于不同平台和不同平台的…

77GHz线性调频连续波雷达

文章目录 前言 一、背景 二、优缺点 三、工作原理 四、电路模块设计 4.1.LFMCW信号源 4.2.发射电路 4.3.接收电路 4.4.信号处理器 五、应用 5.1.汽车测距 5.2.军事方面 5.3.气象方面 总结 前言 这篇文章是博主本科期间整理的关于77GHz线性调频连续波雷达的相关资料&#xff0c;…

【Java】文件操作和IO

文件操作和IO 文件树形结构组织和目录文件路径 Java中操作文件File 文件内容的读写(数据流)Reader和Writer字符输入流 Reader字符输出流 WriterFileReader 和 FileWriterFileReaderFileWriter InputStream和OutputStreamInputStreamFileInputStreamFileOutputStream 小程序扫描…

Vue3实现24小时倒计时

方法一:时间戳(24小时以内,毫秒为单位)转成时间,并且倒计时 效果预览: <script> // 剩余时间的时间戳,24小时的时间戳是86400000 const exTime = ref(86400000) // 支付时间期限 const payTime = ref() const maxtime = ref(0) //倒计时(时间戳,毫秒单位)转换成秒…