Knowledge-based-BERT(二)

news2024/11/27 0:23:21

多种预训练任务解决NLP处理SMILES的多种弊端,代码:Knowledge-based-BERT,原文:Knowledge-based BERT: a method to extract molecular features like computational chemists,代码解析继续K_BERT_WCL_pretrain。模型框架如下:
在这里插入图片描述

文章目录

  • 1.load_data_for_pretrain
    • 1.1.build_pretrain_selected_tasks
    • 1.2.build_maccs_pretrain_data_and_save
  • 2.loss
  • 3.K_BERT_WCL

args['pretrain_data_path'] = '../data/pretrain_data/CHEMBL_maccs'
pretrain_set = build_data.load_data_for_pretrain(
    pretrain_data_path=args['pretrain_data_path'])
print("Pretrain data generation is complete !")

pretrain_loader = DataLoader(dataset=pretrain_set,
                             batch_size=args['batch_size'],
                             shuffle=True,
                             collate_fn=collate_pretrain_data)

1.load_data_for_pretrain

def load_data_for_pretrain(pretrain_data_path='./data/CHEMBL_wash_500_pretrain'):
    tokens_idx_list = []
    global_labels_list = []
    atom_labels_list = []
    atom_mask_list = []
    for i in range(80):
        pretrain_data = np.load(pretrain_data_path+'_{}.npy'.format(i+1), allow_pickle=True)
        tokens_idx_list = tokens_idx_list + [x for x in pretrain_data[0]]
        global_labels_list = global_labels_list + [x for x in pretrain_data[1]]
        atom_labels_list = atom_labels_list + [x for x in pretrain_data[2]]
        atom_mask_list = atom_mask_list + [x for x in pretrain_data[3]]
        print(pretrain_data_path+'_{}.npy'.format(i+1) + ' is loaded')
    pretrain_data_final = []
    for i in range(len(tokens_idx_list)):
        a_pretrain_data = [tokens_idx_list[i], global_labels_list[i], atom_labels_list[i], atom_mask_list[i]]
        pretrain_data_final.append(a_pretrain_data)
    return pretrain_data_final
  • 和之前的load_data_for_contrastive_aug_pretrain一模一样,只是最后文件载入的不一样,没有_contrastive_,文件是在build_pretrain_selected_tasks中构造的

1.1.build_pretrain_selected_tasks

from experiment.build_data import build_maccs_pretrain_data_and_save
import multiprocessing
import pandas as pd

task_name = 'CHEMBL'
if __name__ == "__main__":
    n_thread = 8
    data = pd.read_csv('../pretrain_data/'+task_name+'.csv')
    smiles_list = data['smiles'].values.tolist()
    # 避免内存不足,将数据集分为10份来计算
    for i in range(10):
        n_split = int(len(smiles_list)/10)
        smiles_split = smiles_list[i*n_split:(i+1)*n_split]

        n_mol = int(len(smiles_split)/8)

        # creating processes
        p1 = multiprocessing.Process(target=build_maccs_pretrain_data_and_save, args=(smiles_split[:n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_'+str(i*8+1)+'.npy'))
        p2 = multiprocessing.Process(target=build_maccs_pretrain_data_and_save, args=(smiles_split[n_mol:2*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_'+str(i*8+2)+'.npy'))
        p3 = multiprocessing.Process(target=build_maccs_pretrain_data_and_save, args=(smiles_split[2*n_mol:3*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_'+str(i*8+3)+'.npy'))
        p4 = multiprocessing.Process(target=build_maccs_pretrain_data_and_save, args=(smiles_split[3*n_mol:4*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_'+str(i*8+4)+'.npy'))
        p5 = multiprocessing.Process(target=build_maccs_pretrain_data_and_save, args=(smiles_split[4*n_mol:5*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_'+str(i*8+5)+'.npy'))
        p6 = multiprocessing.Process(target=build_maccs_pretrain_data_and_save, args=(smiles_split[5*n_mol:6*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_'+str(i*8+6)+'.npy'))
        p7 = multiprocessing.Process(target=build_maccs_pretrain_data_and_save, args=(smiles_split[6*n_mol:7*n_mol],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_'+str(i*8+7)+'.npy'))
        p8 = multiprocessing.Process(target=build_maccs_pretrain_data_and_save, args=(smiles_split[7*n_mol:],
                                                                                '../data/pretrain_data/'+task_name+'_maccs_'+str(i*8+8)+'.npy'))

        # starting my_scaffold_split 1&2
        p1.start()
        p2.start()
        p3.start()
        p4.start()
        p5.start()
        p6.start()
        p7.start()
        p8.start()

        # wait until my_scaffold_split 1&2 is finished
        p1.join()
        p2.join()
        p3.join()
        p4.join()
        p5.join()
        p6.join()
        p7.join()
        p8.join()


        # both processes finished
        print("Done!")

  • 和之前的一样,只是调用的函数不同,这里是build_maccs_pretrain_data_and_save,这里文件没有 _5_contrastive_aug,应该只有一个smiles列

1.2.build_maccs_pretrain_data_and_save

def build_maccs_pretrain_data_and_save(smiles_list, output_smiles_path, global_feature='MACCS'):
    smiles_list = smiles_list
    tokens_idx_list = []
    global_label_list = []
    atom_labels_list = []
    atom_mask_list = []
    for i, smiles in enumerate(smiles_list):
        tokens_idx, global_labels, atom_labels, atom_mask = construct_input_from_smiles(smiles,
                                                                                        global_feature=global_feature)
        if tokens_idx != 0:
            tokens_idx_list.append(tokens_idx)
            global_label_list.append(global_labels)
            atom_labels_list.append(atom_labels)
            atom_mask_list.append(atom_mask)
            print('{}/{} is transformed!'.format(i+1, len(smiles_list)))
        else:
            print('{} is transformed failed!'.format(smiles))
    pretrain_data_list = [tokens_idx_list, global_label_list, atom_labels_list, atom_mask_list]
    pretrain_data_np = np.array(pretrain_data_list)
    np.save(output_smiles_path, pretrain_data_np)
  • 读入的smiles只有一个,与之前调用的函数一样

2.loss

global_pos_weight = torch.tensor([884.17, 70.71, 43.32, 118.73, 428.67, 829.0, 192.84, 67.89, 533.86, 18.46, 707.55, 160.14, 23.19, 26.33, 13.38, 12.45, 44.91, 173.58, 40.14, 67.25, 171.12, 8.84, 8.36, 43.63, 5.87, 10.2, 3.06, 161.72, 101.75, 20.01, 4.35, 12.62, 331.79, 31.17, 23.19, 5.91, 53.58, 15.73, 10.75, 6.84, 3.92, 6.52, 6.33, 6.74, 24.7, 2.67, 6.64, 5.4, 6.71, 6.51, 1.35, 24.07, 5.2, 0.74, 4.78, 6.1, 62.43, 6.1, 12.57, 9.44, 3.33, 5.71, 4.67, 0.98, 8.2, 1.28, 9.13, 1.1, 1.03, 2.46, 2.95, 0.74, 6.24, 0.96, 1.72, 2.25, 2.16, 2.87, 1.8, 1.62, 0.76, 1.78, 1.74, 1.08, 0.65, 0.97, 0.71, 5.08, 0.75, 0.85, 3.3, 4.79, 1.72, 0.78, 1.46, 1.8, 2.97, 2.18, 0.61, 0.61, 1.83, 1.19, 4.68, 3.08, 2.83, 0.51, 0.77, 6.31, 0.47, 0.29, 0.58, 2.76, 1.48, 0.25, 1.33, 0.69, 1.03, 0.97, 3.27, 1.31, 1.22, 0.85, 1.75, 1.02, 1.13, 0.16, 1.02, 2.2, 1.72, 2.9, 0.26, 0.69, 0.6, 0.23, 0.76, 0.73, 0.47, 1.13, 0.48, 0.53, 0.72, 0.38, 0.35, 0.48, 0.12, 0.52, 0.15, 0.28, 0.36, 0.08, 0.06, 0.03, 0.07, 0.01])
atom_pos_weight = torch.tensor([4.81, 1.0, 2.23, 53.49, 211.94, 0.49, 2.1, 1.13, 1.22, 1.93, 5.74, 15.42, 70.09, 61.47, 23.2])
loss_criterion_global = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=global_pos_weight.to('cuda'))
loss_criterion_atom = torch.nn.BCEWithLogitsLoss(reduction='none', pos_weight=atom_pos_weight.to('cuda'))
  • loss与之前一致

3.K_BERT_WCL

model = K_BERT_WCL(d_model=args['d_model'], n_layers=args['n_layers'], vocab_size=args['vocab_size'],
                   maxlen=args['maxlen'], d_k=args['d_k'], d_v=args['d_v'], n_heads=args['n_heads'], d_ff=args['d_ff'],
                   global_label_dim=args['global_labels_dim'], atom_label_dim=args['atom_labels_dim'])
class K_BERT_WCL(nn.Module):
    def __init__(self, d_model, n_layers, vocab_size, maxlen, d_k, d_v, n_heads, d_ff, global_label_dim, atom_label_dim,
                 use_atom=False):
        super(K_BERT_WCL, self).__init__()
        self.maxlen = maxlen
        self.d_model = d_model
        self.use_atom = use_atom
        self.embedding = Embedding(vocab_size, self.d_model, maxlen)
        self.layers = nn.ModuleList([EncoderLayer(self.d_model, d_k, d_v, n_heads, d_ff) for _ in range(n_layers)])
        if self.use_atom:
            self.fc = nn.Sequential(
                nn.Dropout(0.),
                nn.Linear(self.d_model + self.d_model, self.d_model),
                nn.ReLU(),
                nn.BatchNorm1d(self.d_model))
            self.fc_weight = nn.Sequential(
                nn.Linear(self.d_model, 1),
                nn.Sigmoid())
        else:
            self.fc = nn.Sequential(
                nn.Dropout(0.),
                nn.Linear(self.d_model, self.d_model),
                nn.ReLU(),
                nn.BatchNorm1d(self.d_model))
        self.classifier_global = nn.Linear(self.d_model, global_label_dim)
        self.classifier_atom = nn.Linear(self.d_model, atom_label_dim)

    def forward(self, input_ids):
        output = self.embedding(input_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids)
        for layer in self.layers:
            output = layer(output, enc_self_attn_mask)
        h_global = output[:, 0]
        if self.use_atom:
            h_atom = output[:, 1:]
            h_atom_weight = self.fc_weight(h_atom)
            h_atom_weight_expand = h_atom_weight.expand(h_atom.size())
            h_atom_mean = (h_atom*h_atom_weight_expand).mean(dim=1)
            h_mol = torch.cat([h_global, h_atom_mean], dim=1)
        else:
            h_mol = h_global
        h_embedding = self.fc(h_mol)
        logits_global = self.classifier_global(h_embedding)
        return logits_global
  • 只有分子水平的global一个任务,没有对比学习任务,也没有原子特征水平的global任务

  • 其他部分差不多,这里不在详细分析

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/180825.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java多线程 - 线程安全和线程同步解决线程安全问题

文章目录线程安全问题线程同步方式一: 同步代码块方式二: 同步方法方式三: Lock锁线程安全问题 线程安全问题指的是: 多个线程同时操作同一个共享资源的时候可能会出现业务安全问题,称为线程安全问题。 举例: 取钱模型演示 需求:小明和小红是一对夫妻&am…

建单向链表-C语言实现

任务描述 本关需要你建立一个带头结点的单向链表。 相关知识 什么是链表?链表和二叉树是C语言数据结构的基础和核心。 链表有多种形式,它可以是单链接的或者双链接的,可以是已排序的或未排序的,可以是循环的或非循环的。 本关让我们来学习单链表。 单链表 单向链表(单…

XC-16 SpringSecurity Oauth2 JWT

SpringSecurityOauth2用户认证需求分析用户认证与授权单点登录需求第三方认证需求用户认证技术方案单点登录技术方案Oauth2认证Oauth2认证流程2.2.2Oauth2在本项目中的应用SpringSecurity Oauth2认证解决方案SpringSecurityOauth2研目标搭建认证服务器导入基础工程创建数据库Oa…

一起自学SLAM算法:9.2 LSD-SLAM算法

连载文章,长期更新,欢迎关注: 下面将从原理分析、源码解读和安装与运行这3个方面展开讲解LSD-SLAM算法。 9.2.1 LSD-SLAM原理分析 前面已经说过,LSD-SLAM算法是直接法的典型代表。因此在下面的分析中,首先介绍一下直…

学习笔记:Java 并发编程④

若文章内容或图片失效,请留言反馈。 部分素材来自网络,若不小心影响到您的利益,请联系博主删除。 视频链接:https://www.bilibili.com/video/av81461839配套资料:https://pan.baidu.com/s/1lSDty6-hzCWTXFYuqThRPw&am…

CSS语法格式与三种引入方式

文章目录第一章——CSS简介1.1 CSS语法格式1.2 CSS 位置1.3 CSS引入方式1.3.1.行内样式表(内联样式表)1.3.2 外部样式表1.3.3 内部样式表第一章——CSS简介 1.1 CSS语法格式 CSS 规则由两个主要的部分构成:选择器以及一条或多条声明。 选择…

C语言全局变量和局部变量

局部变量定义在函数内部的变量称为局部变量(Local Variable),它的作用域仅限于函数内部, 离开该函数后就是无效的,再使用就会报错。例如:intf1(int a){ int b,c;//a,b,c仅在函数f1()内有效 return abc; } i…

各种CV领域 Attention (原理+代码大全)

人类在处理信息时,天然会过滤掉不太关注的信息,着重于感兴趣信息,于是将这种处理信息的机制称为注意力机制。 注意力机制分类:软注意力机制(全局注意)、硬注意力机制(局部注意)、和…

打工人必知必会(三)——经济补偿金和赔偿金的那些事

目录 参考 一、经济补偿金&赔偿金-用人单位承担赔偿责任 1、月平均工资是税前还是税后工资? 3、经济补偿金是否要交个人所得税?如何交? 二、劳动者承担赔偿责任 三、劳动者需要特别注意 参考 《HR全程法律顾问:企业人力资…

Day12 XML配置AOP

1 前言前文我们已经介绍了AOP概念Day11 AOP介绍&#xff0c;并将其总结如下&#xff1a;2 AOP 标签和expression表达式学习<?xml version"1.0" encoding"UTF-8"?> <beans xmlns"http://www.springframework.org/schema/beans"xmlns:x…

3.4只读存储器ROM

文章目录一、引子二、介绍1.MROM2.PROM3.EPROM4.Flash Memory5.SSD三、运行过程四、回顾一、引子 这一小节&#xff0c;我们学习只读存储器ROM。 上一小节&#xff0c;学习了两种RAM芯片&#xff0c;分别是SRAM和DRAM。详情请戳&#xff1a;3.3Sram和Dram RAM芯片可以支持随…

Pygame创建界面

今天开始对Python的外置包pygame进行学习&#xff0c;pygame是Python的游戏包&#xff0c;使用该包可以设计一些简单的小游戏。 前言 利用Python外置包创建一个简单界面&#xff0c;首先需要下载Python外置包pygame 使用语句&#xff1a;pip install pygame Display模块 创建…

红黑树知识点回顾

Rudolf Bayer 于1978年发明红黑树&#xff0c;在当时被称为对称二叉 B 树(symmetric binary B-trees)。后来&#xff0c;在1978年被 Leo J. Guibas 和 Robert Sedgewick 修改为如今的红黑树。 红黑树具有良好的效率&#xff0c;它可在近似O(logN) 时间复杂度下完成插入、删除、…

实验五、任意N进制异步计数器设计

实验五 任意N进制异步计数器设计 实验目的 掌握任意N进制异步计数器设计的方法。 实验要求 一人一组&#xff0c;独立上机。在电脑上利用Multisim软件完成实验内容。 实验内容 说明任意N进制异步计数器的构成方法 设计过程 集成计数器一般都设有清零端和置数输入端&#xff…

3.7动态规划--图像压缩

3.6多边形游戏&#xff0c;多边形最优三角剖分类似&#xff0c;仅仅是最优子结构的性质不同&#xff0c;这个多边形游戏更加具有一般性。不想看了&#xff0c;跳过。 写在前面 明确数组含义&#xff1a; l: l[i]存放第i段长度, 表中各项均为8位长&#xff0c;限制了相同位数…

ElasticSearch - RestClient操作ES基本操作

目录 什么是RestClient hotel数据结构分析 初始化RestClient 创建索引库 删除索引库 判断索引库是否存在 小结 新增文档 查询文档 更新文档 删除文档 批量导入文档 小结 什么是RestClient ES官方提供了各种不同语言的客户端&#xff0c;用来操作ES这些客户端的本质…

Java基础语法——方法

目录 方法概述 方法定义及格式 方法重载 •方法重载概述 •方法重载特点 方法中基本数据类型和引用数据类型的传递 方法概述 ——假设有一个游戏程序&#xff0c;程序在运行过程中&#xff0c;要不断地发射炮弹(植物大战僵尸)。发射炮弹的动作需要编写100行的代码&…

五、在测试集上评估图像分类算法精度(Datawhale组队学习)

文章目录配置环境准备图像分类数据集和模型文件测试集图像分类预测结果表格A-测试集图像路径及标注表格B-测试集每张图像的图像分类预测结果&#xff0c;以及各类别置信度可视化测试集中被误判的图像测试集总体准确率评估指标常见评估指标混淆矩阵PR曲线绘制某一类别的PR曲线绘…

密码学的100个基本概念

密码学的100个基本概念一、密码学历史二、密码学基础三、分组密码四、序列密码五、哈希函数六、公钥密码七、数字签名八、密码协议九、密钥管理十、量子密码2022年主要完成了密码学专栏的编写&#xff0c;较为系统的介绍了从传统密码到现代密码&#xff0c;以及量子密码的相关概…

C语言函数声明以及函数原型

C语言代码由上到下依次执行&#xff0c;原则上函数定义要出现在函数调用之前&#xff0c;否则就会报错。但在实际开发中&#xff0c;经常会在函数定义之前使用它们&#xff0c;这个时候就需要提前声明。所谓声明&#xff08;Declaration&#xff09;&#xff0c;就是告诉编译器…