Knowledge-based-BERT(一)

news2025/1/9 1:56:57

多种预训练任务解决NLP处理SMILES的多种弊端,代码:Knowledge-based-BERT,原文:Knowledge-based BERT: a method to extract molecular features like computational chemists,代码解析从K_BERT_pretrain开始。模型框架如下:
在这里插入图片描述

文章目录

  • 1.K_BERT_pretrain
    • 1.1.load_data_for_contrastive_aug_pretrain
      • 1.1.1.build_contrastive_pretrain_selected_tasks
      • 1.1.2.build_maccs_pretrain_contrastive_data_and_save
      • 1.1.3.construct_input_from_smiles

1.K_BERT_pretrain

args['pretrain_data_path'] = '../data/pretrain_data/CHEMBL_maccs'
args['batch_size'] = 32
pretrain_set = build_data.load_data_for_contrastive_aug_pretrain(
                                        pretrain_data_path=args['pretrain_data_path'])
print("Pretrain data generation is complete !")

pretrain_loader = DataLoader(dataset=pretrain_set,
                             batch_size=args['batch_size'],
                             shuffle=True,
                             collate_fn=collate_pretrain_data)

1.1.load_data_for_contrastive_aug_pretrain

def load_data_for_contrastive_aug_pretrain(pretrain_data_path='./data/CHEMBL_wash_500_pretrain'):
    tokens_idx_list = []
    global_labels_list = []
    atom_labels_list = []
    atom_mask_list = []
    for i in range(80):
        pretrain_data = np.load(pretrain_data_path+'_contrastive_{}.npy'.format(i+1), allow_pickle=True)
        tokens_idx_list = tokens_idx_list + [x for x in pretrain_data[0]]
        global_labels_list = global_labels_list + [x for x in pretrain_data[1]]
        atom_labels_list = atom_labels_list + [x for x in pretrain_data[2]]
        atom_mask_list = atom_mask_list + [x for x in pretrain_data[3]]
        print(pretrain_data_path+'_contrastive_{}.npy'.format(i+1) + ' is loaded')
    pretrain_data_final = []
    for i in range(len(tokens_idx_list)):
        a_pretrain_data = [tokens_idx_list[i], global_labels_list[i], atom_labels_list[i], atom_mask_list[i]]
        pretrain_data_final.append(a_pretrain_data)
    return pretrain_data_final
  • CHEMBL_maccs_contrastive_{}.npy 是在 build_contrastive_pretrain_selected_tasks 文件中构造的
  • 通过下面的分析,最终 .npy 存储的内容应该是 tokens_idx_all_list, global_label_list, atom_labels_list, atom_mask_list,其中 tokens_idx_all_list 是某个分子的5个SMILES编码转化为token后的下标列表,shape应该是(n_smiles,5,201),其他几个的shape在下面有示例,应该只是多了 n_smiles 这个维度

1.1.1.build_contrastive_pretrain_selected_tasks

from experiment.build_data import build_maccs_pretrain_contrastive_data_and_save
import multiprocessing
import pandas as pd

task_name = 'CHEMBL'
if __name__ == "__main__":
    n_thread = 8
    data = pd.read_csv('../data/pretrain_data/'+task_name+'_5_contrastive_aug.csv')
    smiles_name_list = ['smiles', 'aug_smiles_0', 'aug_smiles_1', 'aug_smiles_2', 'aug_smiles_3']
    smiles_list = data[smiles_name_list].values.tolist()

    # 避免内存不足,将数据集分为10份来计算
    for i in range(10):
        n_split = int(len(smiles_list)/10)
        smiles_split = smiles_list[i*n_split:(i+1)*n_split]

        n_mol = int(len(smiles_split)/8)

        # creating processes
        p1 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[:n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+1)+'.npy'))
        p2 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[n_mol:2*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+2)+'.npy'))
        p3 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[2*n_mol:3*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+3)+'.npy'))
        p4 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[3*n_mol:4*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+4)+'.npy'))
        p5 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[4*n_mol:5*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+5)+'.npy'))
        p6 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[5*n_mol:6*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+6)+'.npy'))
        p7 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[6*n_mol:7*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+7)+'.npy'))
        p8 = multiprocessing.Process(target=build_maccs_pretrain_contrastive_data_and_save, args=(smiles_split[7*n_mol:],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_contrastive_'+str(i*8+8)+'.npy'))

        # starting my_scaffold_split 1&2
        p1.start()
        p2.start()
        p3.start()
        p4.start()
        p5.start()
        p6.start()
        p7.start()
        p8.start()

        # wait until my_scaffold_split 1&2 is finished
        p1.join()
        p2.join()
        p3.join()
        p4.join()
        p5.join()
        p6.join()
        p7.join()
        p8.join()


        # both processes finished
        print("Done!")
  • 在 CHEMBAL 收集分子后,经过数据增强存成SMILES,这里读入生成 .npy
  • 输入 smiles_list 的格式如下,每一行是一个分子的五个SMILES:
import pandas as pd
import numpy as np
smiles_name_list = ['smiles', 'aug_smiles_0', 'aug_smiles_1', 'aug_smiles_2', 'aug_smiles_3']
data=pd.DataFrame(np.arange(15).reshape(3,5),columns=smiles_name_list)
smiles_list = data[smiles_name_list].values.tolist()
smiles_list
#[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]

1.1.2.build_maccs_pretrain_contrastive_data_and_save

def build_maccs_pretrain_contrastive_data_and_save(smiles_list, output_smiles_path, global_feature='MACCS'):
    # all smiles list
    smiles_list = smiles_list
    tokens_idx_all_list = []
    global_label_list = []
    atom_labels_list = []
    atom_mask_list = []
    for i, smiles_one_mol in enumerate(smiles_list):
        tokens_idx_list = [construct_input_from_smiles(smiles, global_feature=global_feature)[0] for
                           smiles in smiles_one_mol]
        if 0 not in tokens_idx_list:
            _ , global_labels, atom_labels, atom_mask = construct_input_from_smiles(smiles_one_mol[0],
                                                                global_feature=global_feature)
            tokens_idx_all_list.append(tokens_idx_list)
            global_label_list.append(global_labels)
            atom_labels_list.append(atom_labels)
            atom_mask_list.append(atom_mask)
            print('{}/{} is transformed!'.format(i+1, len(smiles_list)))
        else:
            print('{} is transformed failed!'.format(smiles_one_mol[0]))
    pretrain_data_list = [tokens_idx_all_list, global_label_list, atom_labels_list, atom_mask_list]
    pretrain_data_np = np.array(pretrain_data_list, dtype=object)
    np.save(output_smiles_path, pretrain_data_np)

tokens_idx_list 取 construct_input_from_smiles 返回的第一个元素

1.1.3.construct_input_from_smiles

def construct_input_from_smiles(smiles, max_len=200, global_feature='MACCS'):
    try:
        # built a pretrain data from smiles
        atom_list = []
        atom_token_list = ['c', 'C', 'O', 'N', 'n', '[C@H]', 'F', '[C@@H]', 'S', 'Cl', '[nH]', 's', 'o', '[C@]',
                           '[C@@]', '[O-]', '[N+]', 'Br', 'P', '[n+]', 'I', '[S+]',  '[N-]', '[Si]', 'B', '[Se]', '[other_atom]']
        all_token_list = ['[PAD]', '[GLO]', 'c', 'C', '(', ')', 'O', '1', '2', '=', 'N', '3', 'n', '4', '[C@H]', 'F', '[C@@H]', '-', 'S', '/', 'Cl', '[nH]', 's', 'o', '5', '#', '[C@]', '[C@@]', '\\', '[O-]', '[N+]', 'Br', '6', 'P', '[n+]', '7', 'I', '[S+]', '8', '[N-]', '[Si]', 'B', '9', '[2H]', '[Se]', '[other_atom]', '[other_token]']

        # 构建token转化成idx的字典
        word2idx = {}
        for i, w in enumerate(all_token_list):
            word2idx[w] = i
        # 构建token_list 并加上padding和global
        token_list = smi_tokenizer(smiles)
        padding_list = ['[PAD]' for x in range(max_len-len(token_list))]
        tokens = ['[GLO]'] + token_list + padding_list
        mol = MolFromSmiles(smiles)
        atom_example = mol.GetAtomWithIdx(0)
        atom_labels_example = atom_labels(atom_example)
        atom_mask_labels = [2 for x in range(len(atom_labels_example))]
        atom_labels_list = []
        atom_mask_list = []

        index = 0
        tokens_idx = []
        for i, token in enumerate(tokens):
            if token in atom_token_list:
                atom = mol.GetAtomWithIdx(index)
                an_atom_labels = atom_labels(atom)
                atom_labels_list.append(an_atom_labels)
                atom_mask_list.append(1)
                index = index + 1
                tokens_idx.append(word2idx[token])
            else:
                if token in all_token_list:
                    atom_labels_list.append(atom_mask_labels)
                    tokens_idx.append(word2idx[token])
                    atom_mask_list.append(0)
                elif '[' in list(token):
                    atom = mol.GetAtomWithIdx(index)
                    tokens[i] = '[other_atom]'
                    an_atom_labels = atom_labels(atom)
                    atom_labels_list.append(an_atom_labels)
                    atom_mask_list.append(1)
                    index = index + 1
                    tokens_idx.append(word2idx['[other_atom]'])
                else:
                    tokens[i] = '[other_token]'
                    atom_labels_list.append(atom_mask_labels)
                    tokens_idx.append(word2idx['[other_token]'])
                    atom_mask_list.append(0)
        if global_feature == 'MACCS':
            global_label_list = global_maccs_data(smiles)
        elif global_feature == 'ECFP4':
            global_label_list = global_ecfp4_data(smiles)
        elif global_feature == 'RDKIT_des':
            global_label_list = global_rdkit_des_data(smiles)

        tokens_idx = [word2idx[x] for x in tokens]
        if len(tokens_idx) == max_len + 1:
            return tokens_idx, global_label_list, atom_labels_list, atom_mask_list
        else:
            return 0, 0, 0, 0
    except:
        return 0, 0, 0, 0
def smi_tokenizer(smi):
    """
    Tokenize a SMILES molecule or reaction
    """
    import re
    pattern =  "(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\\\|\/|:|~|@|\?|>|\*|\$|\%[0-9]{2}|[0-9])"
    regex = re.compile(pattern)
    tokens = [token for token in regex.findall(smi)]
    # assert smi == ''.join(tokens)
    # return ' '.join(tokens)
    return tokens
    """
    smi='C=CCC=CCO'
	smi_tokenizer(smi)
	#['C', '=', 'C', 'C', 'C', '=', 'C', 'C', 'O']
    """
def atom_labels(atom, use_chirality=True):
    results = one_of_k_encoding(atom.GetDegree(),
                                [0, 1, 2, 3, 4, 5, 6]) + \
              one_of_k_encoding_unk(atom.GetHybridization(), [
                  Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                  Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D,
                  Chem.rdchem.HybridizationType.SP3D2, 'other']) + [atom.GetIsAromatic()] \
              + one_of_k_encoding_unk(atom.GetTotalNumHs(),
                                                  [0, 1, 2, 3, 4])
    if use_chirality:
        try:
            results = results + one_of_k_encoding_unk(
                atom.GetProp('_CIPCode'),
                ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
        except:
            results = results + [False, False
                                 ] + [atom.HasProp('_ChiralityPossible')]
    atom_labels_list = np.array(results).tolist()
    atom_selected_index = [1, 2, 3, 4, 7, 8, 9, 13, 14, 15, 16, 17, 19, 20, 21]
    atom_labels_selected = [atom_labels_list[x] for x in atom_selected_index]
    return atom_labels_selected
    """
	from rdkit.Chem import *
	from build_data import atom_labels
	mol = MolFromSmiles(smi)
	atom_example = mol.GetAtomWithIdx(0)
	atom_labels_example = atom_labels(atom_example)
	atom_labels_example
	#[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0]
    """
  • tokens_idx 是 SMILES 转换为 tokens 后对应的下标列表,global_label_list 是根据 SMILES 算出的各种描述符,这里是 global_maccs_data,atom_labels_list 是分子中每个原子编码,如果 token 不是原子就设为全2,atom_mask_list 是 token 是否是原子的标记,构建失败返回全0,正确的话 tokens_idx 是一个列表,构建失败就是数值0
def global_maccs_data(smiles):
    mol = Chem.MolFromSmiles(smiles)
    maccs = MACCSkeys.GenMACCSKeys(mol)
    global_maccs_list = np.array(maccs).tolist()
    # 选择负/正样本比例小于1000且大于0.001的数据
    selected_index = [3, 8, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165]
    selected_global_list = [global_maccs_list[x] for x in selected_index]
    return selected_global_list
  • 使用示例如下,具体实现的if-else细节处理不再深入
from build_data import *
import numpy as np
smi1='C=CCC=CCO'
smi2='OCC=CCC=C'
res=construct_input_from_smiles(smi1)
#res=construct_input_from_smiles(smi2)
len(res),np.array(res[0]).shape,np.array(res[1]).shape,np.array(res[2]).shape,np.array(res[3]).shape
#(4, (201,), (154,), (201, 15), (201,)) smi1
#(4, (201,), (154,), (201, 15), (201,)) smi2
  • 201是pad到200再加glo,154 是 selected_index 的长度,每个 token 编码为长度为15的向量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/177727.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Tkinter的Listbox控件

Tkinter的Listbox控件是个选项框,主要是用来在给定的选项中选择一个 使用方法 创建选项框Listbox 和其他控件的创建方法一样,直接创建即可,命名为Lb Lbtk.Listbox(root) Lb.pack() 在选项框中加入选项 可以边创建边添加,即利…

【C#】WPF实现经典纸牌游戏,适合新手入门

文章目录1 纸牌类2 布局3 初始化4 事件点击牌堆拖动牌的去留源代码1 纸牌类 之所以产生这个无聊至极的念头,是因为发现Unicode中竟然有这种字符。。。 黑桃🂡 🂢 🂣 🂤 🂥 🂦 🂧 &…

【设计模式】结构型模式·外观模式

学习汇总入口【23种设计模式】学习汇总(数万字讲解体系思维导图) 写作不易,如果您觉得写的不错,欢迎给博主来一波点赞、收藏~让博主更有动力吧!> 学习汇总入口 一.概述 外观(Facade)模式是七大设计原则“迪米特法则…

谷粒商城-高级篇-Day12-性能压测和缓存

文章目录性能优化nginx动静分离优化三级分类的获取(优化业务)分布式缓存整合redis高并发下的缓存失效问题缓存穿透缓存雪崩缓存击穿解决这些问题分布式锁Redisson可重入锁(Reentrant Lock)指定过期时间读写锁闭锁信号量使用Redssi…

Python实现一个简易的CLI翻译程序

Python实现一个简易的CLI翻译程序Python百度翻译API实现一个简易的CLI翻译程序获取百度翻译API编写一个简单的Python程序Python百度翻译API实现一个简易的CLI翻译程序 之前翻译用的linux上的golddict,每次翻译都很慢。。。 所以想写一个简单快速的翻译命令行翻译软件 获取百度…

Allegro如何自动高亮不等长的网络操作指导

Allegro如何自动高亮不等长的网络操作指导 在做PCB设计的时候,时常需要要做等长,Allegro可以自动高亮一组内不等长的网络,可以直观的看到哪些网络长度是不满足的,类似下图 绿色的是通过的,红色是长度不足的,粉色是超长的 具体操作如下 选择Route-Timing Vision出现optio…

Springboot359的医院病历管理系统

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 2 第3章 系统分析 3 3.1 需求分析 3 3.2 系统可行性分析 4 3.2.1技术可行性:技术背景 4 3.2.2经济…

Ubiquiti MAC Address Changer 3.0 Crack

Ubiquiti MAC Address Changer,目前mac address changer的版本有很多,本次发布的是V3版本,这是一款功能非常强大的修改网卡mac地址软件,基本上所有的网卡MAC地址都支持修改,包括虚拟机和TeamViewer软件都是支持的。 Ea…

5、基本数据类型

目录 一、整数类型 二、浮点类型 三、字符类型 四、布尔类型 一、整数类型 整数类型用来存储整数数值,即没有小数部分的数值。可以是正数,也可以是负数。整 型数据在Java程序中有3种表示形式,分别为十进制、八进制和十六进制。 1.十进…

2.4.4 数值类型的转换

文章目录1.运算时的自转2.运算时的强转3.强转时的精度丢失问题1.运算时的自转 不同数字类型之间的大小关系如下:double > float > long > int > char, short,byte 自转:小类型的数据可以直接赋值给大类型的变量; byte short c…

Linux(五)创建一个miniShell

前情提要:掌握进程控制中的进程创建、进程终止、进程等待、进程替换。可以参考下方博文 LInux(四)进程控制(创建、终止、等待、替换) 了解strtok函数的使用 正文: 目录 Shell是什么? 如何…

蓝桥杯之二分与前缀和

蓝桥杯之二分二分板子?第一次和最后一次出现的位置机器人跳跃问题四平方和分巧克力?典型二分找大的(从右往左找)二分upper_bound(a1,an1,x)-a?递增三元组前缀和取余?K倍区间二维前缀和?激光炸弹…

17种编程语言实现排序算法-合并排序

开源地址 https://gitee.com/lblbc/simple-works/tree/master/sort/ 覆盖语言:C、C、C#、Java、Kotlin、Dart、Go、JavaScript(JS)、TypeScript(TS)、ArkTS、swift、PHP。 覆盖平台:安卓(Java、Kotlin)、iOS(SwiftUI)、Flutter(Dart)、Window桌面(C#)、…

分享139个ASP源码,总有一款适合您

ASP源码 分享139个ASP源码,总有一款适合您 下面是文件的名字,我放了一些图片,文章里不是所有的图主要是放不下..., 139个ASP源码下载链接:https://pan.baidu.com/s/1Vk4U4EXVCWZWPMWf9ax2dw?pwdif23 提取码&#x…

【C++】类和对象(上)---什么是类?

目录1.面向过程和面向对象初步认识2.类的引入2.1使用struct定义类3.类的定义3.1类的两种定义方式:3.2成员变量命名规则的建议3.3成员函数与成员变量定义的位置建议4.类的访问限定符及封装4.1访问限定符4.2封装5.类的作用域6.类的实例化7.类对象模型7.1如何计算类对象…

springboot静态资源目录访问,及自定义静态资源路径,index页面的访问

springboot静态资源目录访问,及自定义静态资源路径,index页面的访问静态资源目录的访问位置静态资源访问测试自定义静态资源路径和静态资源请求映射web首页的访问自定义静态资源请求映射影响index.html首页的访问的**解决方案**:1.取消自定义…

【JUC系列】CountDownLatch实现原理

简单示例 public class Main {private static final int NUM 3;public static void main(String[] args) throws InterruptedException {CountDownLatch latch new CountDownLatch(NUM);for (int i 0; i < NUM; i) {new Thread(() -> {try {Thread.sleep(2000);Syste…

梯度之上:Hessian 矩阵

原文链接&#xff1a;原文 文章目录梯度之上&#xff1a;Hessian 矩阵梯度、雅克比矩阵海森矩阵海森矩阵应用梯度之上&#xff1a;Hessian 矩阵 本文讨论研究梯度下降法的一个有力的数学工具&#xff1a;海森矩阵。在讨论海森矩阵之前&#xff0c;需要首先了解梯度和雅克比矩阵…

基础知识一览3

这里写目录标题1.Servlet1.1 入门1.2 什么是Servlet1.3 Servlet的作用1.4 Servlet生命周期1.5 Servler的体系结构1.6 Servler的两种配置方式2.Filter2.1 Filter拦截路径配置2.2 过滤器链2.2 入门2.3 过滤器链2.4 过滤器生命周期3.Listener3.1 监听器分类3.1.1 一类监听器4.Serv…

ESP32设备驱动-GA1A12S202光线传感器驱动

GA1A12S202光线传感器驱动 1、GA1A2S202介绍 GA1A1S202 对数刻度模拟光传感器使用起来非常简单,只需添加电源,然后监控模拟输出。大多数光传感器对光强度具有线性响应,这意味着它们对低光水平非常不敏感,然后在高光水平下达到最大值。另一方面,该传感器具有对数响应,这…