Python 分子图分类,GNN Model for HIV Molecules Classification,HIV 分子图分类模型;整图分类问题,代码实战

news2025/1/11 14:46:13

一、分子图

分子图(molecular graph)是一种用来表示分子结构的图形方式,其中原子被表示为节点(vertices),化学键被表示为边(edges)。对于HIV(人类免疫缺陷病毒),分子图可以用来详细描述其复杂的化学结构和相互作用,这对于理解HIV的生物学特性和开发治疗药物至关重要。

 二、分子图分类

直接撸代码,以HIV为例

HIV分子图一共有8个问题,现在取第一个问题:

"The human immunodeficiency viruses (HIV) are a type of retrovirus, which induce acquired immune deficiency syndrome (AIDs).  Now there are six main classes of antiretroviral drugs to treating AIDs patients approved by FDA, which are the nucleoside reverse transcriptase inhibitors (NRTIs), the non-nucleoside reverse transcriptase inhibitors (NNRTIs), the protease inhibitors, the integrase inhibitor, the fusion inhibitor, and the chemokine receptor CCR5 antagonist. Due to the missing 3\u2019hydroxyl group, NRTIs prevent the formation of a 3\u2019-5\u2019-phosphodiester bond in growing DNA chains. The hydroxyl group of the inhibitor interacts with the carboxyl group of the protease active site residues, Asp 25 and Asp 25\u2032, by hydrogen bonds. The inhibitor-contacting residues of HIV protease are relatively conserved, including Gly 27, Asp 29, Asp 30, and Gly 48. Is this molecule effective to this assay?",
Answer is yes or no.

针对上面问题,进行二分类

代码环节

oversample_data.py 数据正负样本均衡处理,进行上采样 

在机器学习和统计学中,上采样(Upsampling)和下采样(Downsampling)是两种常用的数据预处理技术,它们用于处理数据集中的类别不平衡问题。类别不平衡指的是数据集中某些类别的样本数量远多于其他类别,这可能会导致模型偏向于多数类别,从而影响模型的泛化能力和公平性。以下是上采样和下采样的基本概念:

上采样(Upsampling)

上采样是指增加少数类别(underrepresented classes)的样本数量,使其与多数类别(overrepresented classes)的样本数量相匹配或接近。这样做的目的是减少模型对多数类别的偏好,提高对少数类别的识别能力。上采样可以通过以下几种方式实现:

  1. 简单复制:直接复制少数类别的样本,直到其数量与多数类别相等。
  2. 随机采样:从少数类别中随机抽取样本,直到达到所需的数量。
  3. SMOTE(Synthetic Minority Over-sampling Technique):这是一种更复杂的技术,它通过在少数类别的样本之间插入新的、合成的样本来增加样本数量。这些合成样本是通过在少数类别样本的k-最近邻之间进行插值来生成的。

下采样(Downsampling)

下采样是指减少多数类别的样本数量,使其与少数类别的样本数量相匹配或接近。这种方法的目的是减少训练集中的样本数量,以避免模型过度拟合到多数类别。下采样可以通过以下几种方式实现:

  1. 简单随机采样:从多数类别中随机选择一定数量的样本,丢弃其余的样本。
  2. 聚类中心采样:使用聚类算法找到多数类别的中心点,然后只保留这些中心点附近的样本。
  3. ** Tomek Links**:这是一种特殊的下采样技术,它识别并删除那些与少数类别样本非常接近的多数类别样本,因为这些样本可能会导致模型学习到错误的边界。

应用场景

  • 上采样:当数据集中存在严重的类别不平衡,且我们希望模型能够对少数类别有较好的识别能力时,上采样是一个合适的选择。
  • 下采样:当数据集中的类别不平衡不是特别严重,或者我们更关注模型的整体性能而不是对少数类别的识别能力时,下采样可能是一个更好的选择。

注意事项

  • 上采样和下采样都有其局限性。上采样可能会导致过拟合,因为它增加了数据集中的样本数量,尤其是合成样本。下采样可能会导致信息丢失,因为它丢弃了一部分数据。
  • 在实际应用中,可能需要结合上采样和下采样,或者使用其他技术(如调整类别权重、集成学习等)来处理类别不平衡问题。
import pandas as pd

data = pd.read_csv("data/raw/HIV_train.csv")
data.index = data["index"]
data["HIV_active"].value_counts()
start_index = data.iloc[0]["index"]

# %% Apply oversampling

# Check how many additional samples we need
neg_class = data["HIV_active"].value_counts()[0]
pos_class = data["HIV_active"].value_counts()[1]
multiplier = int(neg_class/pos_class) - 1

# Replicate the dataset for the positive class
replicated_pos = [data[data["HIV_active"] == 1]]*multiplier

# Append replicated data
data = data.append(replicated_pos,
                    ignore_index=True)
print(data.shape)

# Shuffle dataset
data = data.sample(frac=1).reset_index(drop=True)

# Re-assign index (This is our ID later)
index = range(start_index, start_index + data.shape[0])
data.index = index
data["index"] = data.index
data.head()

#  Save
data.to_csv("data/raw/HIV_train_oversampled.csv")

特征转化 dataset_featurizer.py ,生成pyg格式的Data对象

构图主要包含:

node_feature,从9个维度描述分子结构的信息

edge_feature,描述边的信息

edge_index

label

import pandas as pd
import torch
import torch_geometric
from torch_geometric.data import Dataset, Data
import os
from tqdm import tqdm
import deepchem as dc
from rdkit import Chem

print(f"Torch version: {torch.__version__}")
print(f"Cuda available: {torch.cuda.is_available()}")
print(f"Torch geometric version: {torch_geometric.__version__}")

class MoleculeDataset(Dataset):
    def __init__(self, root, filename, test=False, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data). 
        """
        self.test = test
        self.filename = filename
        super(MoleculeDataset, self).__init__(root, transform, pre_transform)
        
    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.
            (The download func. is not implemented here)  
        """
        return self.filename

    @property
    def processed_file_names(self):
        """ If these files are found in raw_dir, processing is skipped"""
        self.data = pd.read_csv(os.path.join(self.root,self.filename)).reset_index()

        if self.test:
            return [f'data_test_{i}.pt' for i in list(self.data.index)]
        else:
            return [f'data_{i}.pt' for i in list(self.data.index)]

    def download(self):
        pass

    def process(self):
        self.data = pd.read_csv(os.path.join(self.root,self.filename))
        for index, mol in tqdm(self.data.iterrows(), total=self.data.shape[0]):
            mol_obj = Chem.MolFromSmiles(mol["smiles"])
            # Get node features
            node_feats = self._get_node_features(mol_obj)
            # Get edge features
            edge_feats = self._get_edge_features(mol_obj)
            # Get adjacency info
            edge_index = self._get_adjacency_info(mol_obj)
            # Get labels info
            label = self._get_labels(mol["HIV_active"])

            # Create data object
            data = Data(x=node_feats, 
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=label,
                        smiles=mol["smiles"]
                        ) 
            if self.test:
                torch.save(data, 
                    os.path.join(self.processed_dir, 
                                 f'data_test_{index}.pt'))
            else:
                torch.save(data, 
                    os.path.join(self.processed_dir, 
                                 f'data_{index}.pt'))

    def _get_node_features(self, mol):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]
        """
        all_node_feats = []

        for atom in mol.GetAtoms():
            node_feats = []
            # Feature 1: Atomic number        
            node_feats.append(atom.GetAtomicNum())
            # Feature 2: Atom degree
            node_feats.append(atom.GetDegree())
            # Feature 3: Formal charge
            node_feats.append(atom.GetFormalCharge())
            # Feature 4: Hybridization
            node_feats.append(atom.GetHybridization())
            # Feature 5: Aromaticity
            node_feats.append(atom.GetIsAromatic())
            # Feature 6: Total Num Hs
            node_feats.append(atom.GetTotalNumHs())
            # Feature 7: Radical Electrons
            node_feats.append(atom.GetNumRadicalElectrons())
            # Feature 8: In Ring
            node_feats.append(atom.IsInRing())
            # Feature 9: Chirality
            node_feats.append(atom.GetChiralTag())

            # Append node features to matrix
            all_node_feats.append(node_feats)

        all_node_feats = np.asarray(all_node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)

    def _get_edge_features(self, mol):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        all_edge_feats = []

        for bond in mol.GetBonds():
            edge_feats = []
            # Feature 1: Bond type (as double)
            edge_feats.append(bond.GetBondTypeAsDouble())
            # Feature 2: Rings
            edge_feats.append(bond.IsInRing())
            # Append node features to matrix (twice, per direction)
            all_edge_feats += [edge_feats, edge_feats]

        all_edge_feats = np.asarray(all_edge_feats)
        return torch.tensor(all_edge_feats, dtype=torch.float)

    def _get_adjacency_info(self, mol):
        """
        We could also use rdmolops.GetAdjacencyMatrix(mol)
        but we want to be sure that the order of the indices
        matches the order of the edge features
        """
        edge_indices = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            edge_indices += [[i, j], [j, i]]

        edge_indices = torch.tensor(edge_indices)
        edge_indices = edge_indices.t().to(torch.long).view(2, -1)
        return edge_indices

    def _get_labels(self, label):
        label = np.asarray([label])
        return torch.tensor(label, dtype=torch.int64)

    def len(self):
        return self.data.shape[0]

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        if self.test:
            data = torch.load(os.path.join('../',self.processed_dir, 
                                 f'data_test_{idx}.pt'))
        else:
            data = torch.load(os.path.join('../', self.processed_dir, 
                                 f'data_{idx}.pt'))   
        return data

模型结构 :GNN1,常见的GAT结构

import os
import rdkit
import torch
import cairosvg
import numpy as np
import torch_geometric
import pandas as pd
from tqdm import tqdm
import deepchem as dc
from PIL import Image
from rdkit import Chem
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import Draw
from rdkit.Chem import Draw
from rdkit.Chem import rdBase
import matplotlib.pyplot as plt
import IPython.display as display
from rdkit.Chem.Draw import IPythonConsole
from sklearn.model_selection import train_test_split
from torch_geometric.data import Dataset, Data
from torch_geometric.data import DataLoader
from dataset_featurizer import MoleculeDataset
import torch.nn.functional as F 
import seaborn as sns
from torch.nn import Sequential, Linear, BatchNorm1d, ModuleList, ReLU
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)


class GNN1(torch.nn.Module):
    def __init__(self, feature_size):
        super(GNN1, self).__init__()
        num_classes =2
        embedding_size = 1024
        
        # GNN1 layers
        self.conv1 = GATConv(feature_size, embedding_size, heads = 3, dropout = 0.3)
        self.head_transform1 = Linear(embedding_size*3, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)
        self.conv2 = GATConv(feature_size, embedding_size, heads = 3, dropout = 0.3)
        self.head_transform2 = Linear(embedding_size*3, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)
        self.conv3 = GATConv(feature_size, embedding_size, heads = 3, dropout = 0.3)
        self.head_transform3 = Linear(embedding_size*3, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.3)
        
        # Linear layers
        self.linear1 = Linear(embedding_size*2, 1024)
        self.linear2 = Linear(1024, num_classes)
        
    def forward(self, x, edge_attr, edge_index, batch_index):
        # First block
        x = self.conv1(x,edge_index)
        x = self.head_transform1(x)
        
        x, edge_index, edge_attr, batch_index, _,_ = self.pool1(x,
                                                                edge_index,
                                                                None,
                                                                batch_index)
        x1 = torch.cat([gap(x,batch_index), gap(x,batch_index)], dim=1)
        
        # Second block
        x = self.conv2(x,edge_index)
        x = self.head_transform2(x)
        
        x, edge_index, edge_attr, batch_index, _,_ = self.pool2(x,
                                                                edge_index,
                                                                None,
                                                                batch_index)
        
        x2 = torch.cat([gap(x,batch_index), gap(x,batch_index)], dim=1)

        # Third block
        x = self.conv3(x,edge_index)
        x = self.head_transform3(x)
        
        x, edge_index, edge_attr, batch_index, _,_ = self.pool3(x,
                                                                edge_index,
                                                                None,
                                                                batch_index)
        
        x3 = torch.cat([gap(x,batch_index), gap(x,batch_index)], dim=1)
        
        # Concat pooled vector
        x = x1+x2+x3
        
        # Output block
        x = self.linear1(x).relu()
        x = F.dropout(x,p=0.5,training=self.training)
        x = self.linear2(x)
        
        return x

GNN2

import os
import rdkit
import torch
import cairosvg
import numpy as np
import torch_geometric
import pandas as pd
from tqdm import tqdm
import deepchem as dc
from PIL import Image
from rdkit import Chem
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import Draw
from rdkit.Chem import Draw
from rdkit.Chem import rdBase
import matplotlib.pyplot as plt
import IPython.display as display
from rdkit.Chem.Draw import IPythonConsole
from sklearn.model_selection import train_test_split
from torch_geometric.data import Dataset, Data
from torch_geometric.data import DataLoader
from dataset_featurizer import MoleculeDataset
import torch.nn.functional as F 
import seaborn as sns
from torch.nn import Sequential, Linear, BatchNorm1d, ModuleList, ReLU
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)


class GNN2(torch.nn.Module):
    def __init__(self, feature_size):
        super(GNN2, self).__init__()
        num_classes = 2
        embedding_size = 1024
        
        # GNN2 layers
        self.conv1 = GATConv(feature_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform1 = Linear(embedding_size*3, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)
        self.conv2 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform2 = Linear(embedding_size*3, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)
        self.conv3 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform3 = Linear(embedding_size*3, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.3)
        
        # Transformer layer
        self.transformer = TransformerConv(embedding_size, heads=4, num_layers=3)
        
        # Isomorphism layer
        self.isomorphism = XConv(Linear(feature_size, embedding_size), K=3)
        
        # Linear layers
        self.linear1 = Linear(embedding_size*2, 1024)
        self.linear2 = Linear(1024, num_classes)
        
    def forward(self, x, edge_attr, edge_index, batch_index):
        # Isomorphism layer
        x = self.isomorphism(x, edge_index, edge_attr)
        
        # Transformer layer
        x = self.transformer(x, edge_index)
        
        # First block
        x = self.conv1(x, edge_index)
        x = self.head_transform1(x)
        
        x, edge_index, edge_attr, batch_index, _, _ = self.pool1(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        x1 = torch.cat([gap(x, batch_index), gap(x, batch_index)], dim=1)
        
        # Second block
        x = self.conv2(x, edge_index)
        x = self.head_transform2(x)
        
        x, edge_index, edge_attr, batch_index, _, _ = self.pool2(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        
        x2 = torch.cat([gap(x, batch_index), gap(x, batch_index)], dim=1)

        # Third block
        x = self.conv3(x, edge_index)
        x = self.head_transform3(x)
        
        x, edge_index, edge_attr, batch_index, _, _ = self.pool3(x,
                                                                 edge_index,
                                                                 None,
                                                                 batch_index)
        
        x3 = torch.cat([gap(x, batch_index), gap(x, batch_index)], dim=1)
        
        # Concat pooled vector
        x = x1 + x2 + x3
        
        # Output block
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear2(x)
        
        return x

GNN3

import os
import rdkit
import torch
import cairosvg
import numpy as np
import torch_geometric
import pandas as pd
from tqdm import tqdm
import deepchem as dc
from PIL import Image
from rdkit import Chem
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import Draw
from rdkit.Chem import Draw
from rdkit.Chem import rdBase
import matplotlib.pyplot as plt
import IPython.display as display
from rdkit.Chem.Draw import IPythonConsole
from sklearn.model_selection import train_test_split
from torch_geometric.data import Dataset, Data
from torch_geometric.data import DataLoader
from dataset_featurizer import MoleculeDataset
import torch.nn.functional as F 
import seaborn as sns
from torch.nn import Sequential, Linear, BatchNorm1d, ModuleList, ReLU
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)


class GNN3(torch.nn.Module):
    def __init__(self, feature_size, edge_feature_size):
        super(GNN3, self).__init__()
        num_classes = 2
        embedding_size = 1024
        
        # GNN3 layers
        self.conv1 = GATConv(feature_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform1 = Linear(embedding_size*3 + edge_feature_size, embedding_size)
        self.pool1 = TopKPooling(embedding_size, ratio=0.8)
        self.conv2 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform2 = Linear(embedding_size*3 + edge_feature_size, embedding_size)
        self.pool2 = TopKPooling(embedding_size, ratio=0.5)
        self.conv3 = GATConv(embedding_size, embedding_size, heads=3, dropout=0.3)
        self.head_transform3 = Linear(embedding_size*3 + edge_feature_size, embedding_size)
        self.pool3 = TopKPooling(embedding_size, ratio=0.3)
        
        # Transformer layer
        self.transformer = TransformerConv(embedding_size, heads=4, num_layers=3)
        
        # Isomorphism layer
        self.isomorphism = XConv(Linear(feature_size + edge_feature_size, embedding_size), K=3)
        
        # Linear layers
        self.linear1 = Linear(embedding_size*2, 1024)
        self.linear2 = Linear(1024, num_classes)
        
    def forward(self, x, edge_attr, edge_index, batch_index):
        # Isomorphism layer
        x = self.isomorphism(x, edge_index, edge_attr)
        
        # Transformer layer
        x = self.transformer(x, edge_index)
        
        # First block
        x_with_edge = torch.cat([x, edge_attr], dim=1)
        x = self.conv1(x_with_edge, edge_index)
        x = self.head_transform1(x)
        
        x, edge_index, edge_attr, batch_index, _, _ = self.pool1(x,
                                                                 edge_index,
                                                                 edge_attr,
                                                                 batch_index)
        x1 = torch.cat([gap(x, batch_index), gap(x, batch_index)], dim=1)
        
        # Second block
        x_with_edge = torch.cat([x, edge_attr], dim=1)
        x = self.conv2(x_with_edge, edge_index)
        x = self.head_transform2(x)
        
        x, edge_index, edge_attr, batch_index, _, _ = self.pool2(x,
                                                                 edge_index,
                                                                 edge_attr,
                                                                 batch_index)
        
        x2 = torch.cat([gap(x, batch_index), gap(x, batch_index)], dim=1)

        # Third block
        x_with_edge = torch.cat([x, edge_attr], dim=1)
        x = self.conv3(x_with_edge, edge_index)
        x = self.head_transform3(x)
        
        x, edge_index, edge_attr, batch_index, _, _ = self.pool3(x,
                                                                 edge_index,
                                                                 edge_attr,
                                                                 batch_index)
        
        x3 = torch.cat([gap(x, batch_index), gap(x, batch_index)], dim=1)
        
        # Concat pooled vector
        x = x1 + x2 + x3
        
        # Output block
        x = self.linear1(x).relu()
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.linear2(x)
        
        return x

train.py 训练流程代码

#%%
import os
import rdkit
import torch
import cairosvg
import numpy as np
import torch_geometric
import pandas as pd
from tqdm import tqdm
import deepchem as dc
from PIL import Image
from rdkit import Chem
from rdkit import Chem
from rdkit import RDConfig
from rdkit.Chem import Draw
from rdkit.Chem import Draw
from rdkit.Chem import rdBase
import matplotlib.pyplot as plt
import IPython.display as display
from rdkit.Chem.Draw import IPythonConsole
from sklearn.model_selection import train_test_split
from torch_geometric.data import Dataset, Data
from torch_geometric.data import DataLoader
from dataset_featurizer import MoleculeDataset
from gnn_project.model.GNN1 import GNN1
from gnn_project.model.GNN2 import GNN2
from gnn_project.model.GNN3 import GNN3
import torch.nn.functional as F 
import seaborn as sns
from torch.nn import Sequential, Linear, BatchNorm1d, ModuleList, ReLU
from torch_geometric.nn import TransformerConv, TopKPooling, GATConv, BatchNorm
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from sklearn.metrics import confusion_matrix, f1_score, \
    accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.nn.conv.x_conv import XConv
torch.manual_seed(42)

data_path = 'data/raw/HIV_data.csv'
data = pd.read_csv(data_path,index_col=[0])
print("###### Raw Data Shape - ")
print(data.shape)
print("0 : ",data["HIV_active"].value_counts()[0])
print("1 : ",data["HIV_active"].value_counts()[1])

#%%
# Split the data, Due to imbalance datasets we need to set train ration high
print("###### After Data split Shape - ")
train_data = pd.read_csv('data/split_data/HIV_train.csv')
test_data = pd.read_csv('data/split_data/HIV_test.csv')
train_data_oversampled = pd.read_csv('data/split_data/HIV_train_oversampled.csv')
print("Train Data ", train_data.shape)
print("0 : ",train_data["HIV_active"].value_counts()[0])
print("1 : ",train_data["HIV_active"].value_counts()[1])

print("Test Data ", test_data.shape)
print("0 : ",test_data["HIV_active"].value_counts()[0])
print("1 : ",test_data["HIV_active"].value_counts()[1])
 
print("Train Data Oversampled ", train_data_oversampled.shape)
print("0 : ",train_data_oversampled["HIV_active"].value_counts()[0])
print("1 : ",train_data_oversampled["HIV_active"].value_counts()[1])


# %%
# Define the folder path to save the images
output_folder = "visualization"
os.makedirs(output_folder, exist_ok=True)

sample_smiles = train_data["smiles"][4:30].values
sdf = Chem.SDMolSupplier(output_folder+'/cdk2.sdf')
mols = [m for m in sdf]

for i, smiles in enumerate(sample_smiles):
    core = Chem.MolFromSmiles(smiles)
    img = Draw.MolsToGridImage(mols, molsPerRow=3, highlightAtomLists=[mol.GetSubstructMatch(core) for mol in mols], useSVG=True)

    ## Save the image in the output folder
    image_path = os.path.join(output_folder, f"image_{i}.png")
    cairosvg.svg2png(bytestring=img.data, write_to=image_path)
    break

print(f"Image saved: {image_path}")


#%%
print("######## Loading dataset...")
train_dataset = MoleculeDataset(root="data/split_data", filename="HIV_train_oversampled.csv")
test_dataset = MoleculeDataset(root="data/split_data", filename="HIV_test.csv", test=True)       

# %%
print("######## Loading GNN1 Model...")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

model = GNN1(feature_size=train_dataset[0].x.shape[1])
model = model.to(device)
print(f"######### Number of parameters: {count_parameters(model)}")
print(model)
# %%
# Loss and Optimizer
# Due to imbalance postive and negative label so apply weight in the +ve side < 1 increases precision, > 1 recall
weight = torch.tensor([1,10], dtype=torch.float32).to(device)
loss_fn = torch.nn.CrossEntropyLoss(weight==weight)
optimizer = torch.optim.SGD(model.parameters(), 
                            lr=0.1,
                            momentum=0.9)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

# %%
NUM_GRAPHS_PER_BATCH = 256
train_loader = DataLoader(train_dataset,
                          batch_size = NUM_GRAPHS_PER_BATCH, shuffle=True)
test_loader = DataLoader(test_dataset,
                          batch_size = NUM_GRAPHS_PER_BATCH, shuffle=True)

#%%


#%%

def train(epoch, model, train_loader, optimizer, loss_fn):
    #Enumerate over the data
    all_preds = []
    all_labels = []
    running_loss = 0.0
    step = 0
    for _,batch in enumerate(tqdm(train_loader)):
        #Using GPU
        batch.to(device)
        #Reset gradient
        optimizer.zero_grad()
        #passing the node feature and concat the info
        pred = model(batch.x.float(),
                     batch.edge_attr.float,
                     batch.edge_index,
                     batch.batch)
        
        # Calculating the loss and gradients
        loss = loss_fn(torch.squeeze(pred), batch.y.float())
        loss.backward()  
        optimizer.step()  
        # Update tracking
        running_loss += loss.item()
        step += 1
        all_preds.append(np.rint(torch.sigmoid(pred).cpu().detach().numpy()))
        all_labels.append(batch.y.cpu().detach().numpy())
    
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    calculate_metrics(all_preds, all_labels, epoch, "train")
    return running_loss/step
#%%   
def test(epoch, model, test_loader, loss_fn):
    #Enumerate over the data
    all_preds = []
    all_preds_raw = []
    all_labels = []
    running_loss = 0.0
    step = 0
    for batch in test_loader:
        #Using GPU
        batch.to(device)
        pred = model(batch.x.float(),
                    batch.edge_attr.float,
                    batch.edge_index,
                    batch.batch)
        loss = loss_fn(torch.squeeze(pred), batch.y.float())            
        # Update tracking
        running_loss += loss.item()
        step += 1
        all_preds.append(np.rint(torch.sigmoid(pred).cpu().detach().numpy()))
        all_preds_raw.append(torch.sigmoid(pred).cpu().detach().numpy())
        all_labels.append(batch.y.cpu().detach().numpy())
    
    all_preds = np.concatenate(all_preds).ravel()
    all_labels = np.concatenate(all_labels).ravel()
    print(all_preds_raw[0][:10])
    print(all_preds[:10])
    print(all_labels[:10])
    calculate_metrics(all_preds, all_labels, epoch, "test")
    log_conf_matrix(all_preds, all_labels, epoch)
    return running_loss/step

#%%

def log_conf_matrix(y_pred, y_true, epoch):
    # Log confusion matrix as image
    cm = confusion_matrix(y_pred, y_true)
    classes = ["0", "1"]
    df_cfm = pd.DataFrame(cm, index = classes, columns = classes)
    plt.figure(figsize = (10,7))
    cfm_plot = sns.heatmap(df_cfm, annot=True, cmap='Blues', fmt='g')
    cfm_plot.figure.savefig(f'data/images/cm_{epoch}.png')
   
    
def calculate_metrics(y_pred, y_true, epoch, type):
    print(f"\n Confusion matrix: \n {confusion_matrix(y_pred, y_true)}")
    print(f"F1 Score: {f1_score(y_true, y_pred)}")
    print(f"Accuracy: {accuracy_score(y_true, y_pred)}")
    prec = precision_score(y_true, y_pred)
    rec = recall_score(y_true, y_pred)
    print(f"Precision: {prec}")
    print(f"Recall: {rec}")
    
    try:
        roc = roc_auc_score(y_pred,y_true)
        print(f"ROC AUC: {roc}")
    except:
        print(f"ROC AUC: not definded")
# %%
# Start training
print("###### Start GNN Model training")
best_loss = 1000
early_stopping_counter = 0
for epoch in range(300): 
    if early_stopping_counter <= 10: # = x * 5 
        # Training
        model.train()
        loss = train(epoch, model, train_loader, optimizer, loss_fn)
        print(f"Epoch {epoch} | Train Loss {loss}")
        #mlflow.log_metric(key="Train loss", value=float(loss), step=epoch)

        # Testing
        model.eval()
        if epoch % 5 == 0:
            loss = test(epoch, model, test_loader, loss_fn)
            print(f"Epoch {epoch} | Test Loss {loss}")
           # mlflow.log_metric(key="Test loss", value=float(loss), step=epoch)
            
            # Update best loss
            if float(loss) < best_loss:
                best_loss = loss
                # Save the currently best model 
                #mlflow.pytorch.log_model(model, "model", signature=SIGNATURE)
                early_stopping_counter = 0
            else:
                early_stopping_counter += 1

        scheduler.step()
    else:
        print("Early stopping due to no improvement.")
        print([best_loss])
print(f"Finishing training with best test loss: {best_loss}")
print([best_loss])

output_folder = "model_weight"
os.makedirs(output_folder, exist_ok=True)
model_path = os.join.path(output_folder,"model.pth") # Replace with the desired path to save the model
torch.save(model, model_path)

# %%

train_optimization.py

import argparse
import os
import pandas as pd
import torch
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score, precision_score, recall_score, roc_auc_score
from torch_geometric.data import DataLoader
from torch.nn import BCELoss
from torch.optim import Adam
from dataset_featurizer import MoleculeDataset
from model.GNN1 import GNN1
from model.GNN2 import GNN2
from model.GNN3 import GNN3
import optuna

torch.manual_seed(42)

# Create a parser object
parser = argparse.ArgumentParser(description='GNN Model Training')

# Add arguments for data paths
parser.add_argument('--test_data_path', type=str, required=True, help='Path to the test data file')
parser.add_argument('--train_oversampled', type=str, required=True, help='Path to the train oversampled data file')

# Add an argument for the GNN model selection
parser.add_argument('--model', type=str, choices=['GNN1', 'GNN2', 'GNN3'], default='GNN1', help='Choose the GNN model (GNN1, GNN2, GNN3)')

# Add an argument for the number of epochs
parser.add_argument('--epochs', type=int, default=100, help='Number of training epochs')

# Parse the command-line arguments
args = parser.parse_args()

# Get the data paths from the command-line arguments
test_data_path = args.test_data_path
train_oversampled_ = args.train_oversampled

# Get the selected model from the command-line arguments
selected_model = args.model

# Get the number of epochs from the command-line arguments
num_epochs = args.epochs

# Load the data
test_data = pd.read_csv(test_data_path)
train_data = pd.read_csv(train_oversampled)

model_folder = "model_weights"
os.makedirs(model_folder, exist_ok=True)

# Define the GNN model based on the selected model
if selected_model == 'GNN1':
    model = GNN1(feature_size=train_data[0].x.shape[1])
  
elif selected_model == 'GNN2':
    model = GNN2(feature_size=train_data[0].x.shape[1])
    
elif selected_model == 'GNN3':
    model = GNN3(feature_size=train_data[0].x.shape[1])
    
else:
    raise ValueError('Invalid model selected')

# Define the loss function
loss_fn = BCELoss()

# Define the optimizer
optimizer = Adam(model.parameters(), lr=0.001)

# Define the objective function for Optuna optimization
def objective(trial):
    # Sample the hyperparameters to be tuned
    hyperparameters = {
        "batch_size": trial.suggest_categorical("batch_size", [32, 128, 64]),
        "learning_rate": trial.suggest_loguniform("learning_rate", 1e-4, 1e-1),
        "weight_decay": trial.suggest_loguniform("weight_decay", 1e-5, 1e-3),
        "sgd_momentum": trial.suggest_uniform("sgd_momentum", 0.5, 0.9),
        "scheduler_gamma": trial.suggest_categorical("scheduler_gamma", [0.995, 0.9, 0.8, 0.5, 1]),
        "pos_weight": trial.suggest_categorical("pos_weight", [1.0]),
        "model_embedding_size": trial.suggest_categorical("model_embedding_size", [8, 16, 32, 64, 128]),
        "model_attention_heads": trial.suggest_int("model_attention_heads", 1, 4),
        "model_layers": trial.suggest_categorical("model_layers", [3]),
        "model_dropout_rate": trial.suggest_uniform("model_dropout_rate", 0.2, 0.9),
        "model_top_k_ratio": trial.suggest_categorical("model_top_k_ratio", [0.2, 0.5, 0.8, 0.9]),
        "model_top_k_every_n": trial.suggest_categorical("model_top_k_every_n", [0]),
        "model_dense_neurons": trial.suggest_categorical("model_dense_neurons", [16, 128, 64, 256, 32]),
    }

    # Set the hyperparameters in the model
    model.set_hyperparameters(**hyperparameters)

    # Create the data loaders
    train_loader = DataLoader(train_dataset, batch_size=hyperparameters["batch_size"], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=hyperparameters["batch_size"], shuffle=False)

    # Train the model
    for epoch in range(num_epochs):
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            out = model(batch)
            loss = loss_fn(out, batch.y)
            loss.backward()
            optimizer.step()

    # Evaluate the model
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in test_loader:
            out = model(batch)
            pred = (out >= 0.5).float()
            y_pred.extend(pred.tolist())
            y_true.extend(batch.y.tolist())

    # Compute evaluation metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    auc_roc = roc_auc_score(y_true, y_pred)

    return f1

# Create the dataset
test_dataset = MoleculeDataset(test_data)
train_dataset = MoleculeDataset(train_data)

# Create the Optuna study and run the optimization
study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=100)

# Get the best hyperparameters and metric
best_hyperparameters = study.best_params
best_metric = study.best_value

# Set the best hyperparameters in the model
model.set_hyperparameters(**best_hyperparameters)

# Create the best data loaders
best_train_loader = DataLoader(train_dataset, batch_size=best_hyperparameters["batch_size"], shuffle=True)
best_test_loader = DataLoader(test_dataset, batch_size=best_hyperparameters["batch_size"], shuffle=False)

# Train the model using the best hyperparameters
for epoch in range(num_epochs):
    model.train()
    for batch in best_train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = loss_fn(out, batch.y)
        loss.backward()
        optimizer.step()

    # Evaluate the model
    model.eval()
    y_true = []
    y_pred = []
    with torch.no_grad():
        for batch in best_test_loader:
            out = model(batch)
            pred = (out >= 0.5).float()
            y_pred.extend(pred.tolist())
            y_true.extend(batch.y.tolist())

    # Compute evaluation metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    auc_roc = roc_auc_score(y_true, y_pred)

    print(f'Epoch {epoch + 1}: Accuracy={accuracy:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}, AUC-ROC={auc_roc:.4f}')

print('Best Hyperparameters:', best_hyperparameters)
print('Best Metric:', best_metric)

inference.py  推理代码

import torch
import pandas as pd
from dataset_featurizer import MoleculeDataset
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score

# Load the test dataset
test_dataset = MoleculeDataset(root="data/split_data", filename="HIV_test.csv", test=True)
test_loader = DataLoader(test_dataset, batch_size=NUM_GRAPHS_PER_BATCH, shuffle=True)

# Load the trained model
model = torch.load(os.join.path(output_folder,"model.pth"))
model.eval()

# Create lists to store the predicted and true labels
all_preds = []
all_labels = []
all_preds_raw = []

# Perform inference on the test dataset
with torch.no_grad():
    for batch in test_loader:
        # Move the batch to the device
        batch = batch.to(device)

        # Perform forward pass
        pred = model(batch.x.float(), batch.edge_attr.float(), batch.edge_index, batch.batch)

        # Convert the predictions to class labels
        preds = torch.argmax(pred, dim=1)

        # Append the predicted and true labels to the lists
        all_preds.extend(preds.cpu().detach().numpy())
        all_labels.extend(batch.y.cpu().detach().numpy())
        all_preds_raw.extend(pred.cpu().detach().numpy())

# Calculate the confusion matrix
cm = confusion_matrix(all_labels, all_preds)

# Calculate the accuracy score
accuracy = accuracy_score(all_labels, all_preds)

# Calculate the ROC AUC score
roc_auc = roc_auc_score(all_labels, all_preds_raw)

# Print the confusion matrix, accuracy score, and ROC AUC score
print("Confusion Matrix:")
print(cm)
print("Accuracy Score:", accuracy)
print("ROC AUC Score:", roc_auc)

config.py

import numpy as np



HYPERPARAMETERS = {
    "batch_size": [32, 128, 64],
    "learning_rate": [0.1, 0.05, 0.01, 0.001],
    "weight_decay": [0.0001, 0.00001, 0.001],
    "sgd_momentum": [0.9, 0.8, 0.5],
    "scheduler_gamma": [0.995, 0.9, 0.8, 0.5, 1],
    "pos_weight" : [1.0],  
    "model_embedding_size": [8, 16, 32, 64, 128],
    "model_attention_heads": [1, 2, 3, 4],
    "model_layers": [3],
    "model_dropout_rate": [0.2, 0.5, 0.9],
    "model_top_k_ratio": [0.2, 0.5, 0.8, 0.9],
    "model_top_k_every_n": [0],
    "model_dense_neurons": [16, 128, 64, 256, 32]
}


'''
BEST_PARAMETERS = {
    "batch_size": [128],
    "learning_rate": [0.01],
    "weight_decay": [0.0001],
    "sgd_momentum": [0.8],
    "scheduler_gamma": [0.8],
    "pos_weight": [1.3],
    "model_embedding_size": [64],
    "model_attention_heads": [3],
    "model_layers": [4],
    "model_dropout_rate": [0.2],
    "model_top_k_ratio": [0.5],
    "model_top_k_every_n": [1],
    "model_dense_neurons": [256]
}

'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2237636.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何调整pdf的页面尺寸

用福昕阅读器打开pdf&#xff0c;进入打印页面&#xff0c;选择“属性”&#xff0c;在弹出的页面选择“高级” 选择你想调成的纸张尺寸&#xff0c;然后打印&#xff0c;打印出来的pdf就是调整尺寸后的pdf

查缺补漏----用户上网过程(HTTP,DNS与ARP)

&#xff08;1&#xff09;HTTP 来自湖科大计算机网络微课堂&#xff1a; ① HTTP/1.0采用非持续连接方式。在该方式下&#xff0c;每次浏览器要请求一个文件都要与服务器建立TCP连接当收到响应后就立即关闭连接。 每请求一个文档就要有两倍的RTT的开销。若一个网页上有很多引…

koa、vue安装与使用

koa官网&#xff1a;https://koajs.com/ 首选创建一个文件夹&#xff1a;mkdir koaDemo (cmd即可) 文件夹初始化&#xff1a;npm init (cmd即可) 初始化完成后就会产生一个package.json的文件。 安装&#xff1a; npm install koa --save (vscode的控制台中安装&a…

Linux:版本控制器git的简单使用+gdb/cgdb调试器的使用

一&#xff0c;版本控制器git 1.1概念 为了能够更方便我们管理不同版本的文件&#xff0c;便有了版本控制器。所谓的版本控制器&#xff0c;就是能让你 了解到⼀个文件的历史&#xff0c;以及它的发展过程的系统。通俗的讲就是⼀个可以记录工程的每⼀次改动和版本迭代的⼀个…

ML 系列:第 21 节 — 离散概率分布(二项分布)

一、说明 二项分布描述了在固定数量的独立伯努利试验中一定数量的成功的概率&#xff0c;其中每个试验只有两种可能的结果&#xff08;通常标记为成功和失败&#xff09;。 二、探讨伯努利模型 例如&#xff0c;假设您正在抛一枚公平的硬币 &#xff08;其中正面成功&#xff…

【优选算法篇】微位至简,数之恢宏——解构 C++ 位运算中的理与美

文章目录 C 位运算详解&#xff1a;基础题解与思维分析前言第一章&#xff1a;位运算基础应用1.1 判断字符是否唯一&#xff08;easy&#xff09;解法&#xff08;位图的思想&#xff09;C 代码实现易错点提示时间复杂度和空间复杂度 1.2 丢失的数字&#xff08;easy&#xff0…

存算分离与计算向数据移动:深度解析与Java实现

背景 随着大数据时代的到来&#xff0c;数据量的激增给传统的数据处理架构带来了巨大的挑战。传统的“存算一体”架构&#xff0c;即计算资源与存储资源紧密耦合&#xff0c;在处理海量数据时逐渐显露出其局限性。为了应对这些挑战&#xff0c;存算分离&#xff08;Disaggrega…

WPS单元格重复值提示设置

选中要检查的所有的单元格 设置提示效果 当出现单元格值重复时&#xff0c;重复的单元格就会自动变化 要修改或删除&#xff0c;点击

Linux笔记之pandoc实现各种文档格式间的相互转换

Linux笔记之pandoc实现各种文档格式间的相互转换 code review! 文章目录 Linux笔记之pandoc实现各种文档格式间的相互转换1.安装 Pandoc2.Word转Markdown3.markdown转html4.Pandoc 支持的一些常见格式4.1.输入格式4.2.输出格式 1.安装 Pandoc sudo apt-get install pandoc # …

MySQL重难点(一)索引

目录 一、引子&#xff1a;MySQL与磁盘间的交互基本单元&#xff1a;Page 1、重要问题&#xff1a;为什么 MySQL 每次与磁盘交互&#xff0c;都要以 16KB 为基本单元&#xff1f;为什么不用多少加载多少&#xff1f; 2、有关MySQL的一些共识 3、如何管理 Page 3.1 单个 P…

solo博客使用非docker方式进行https部署

solo博客使用非docker方式进行https部署 数据库配置启动命令讲解设置自定义访问端口&#xff1a;9168 配置https访问部署效果 服务器上请通过 Docker 部署。但是我服务器资源有限&#xff0c;不想安装docker&#xff0c;直接以编译包的形式运行&#xff0c;节省资源。 如果不会…

【Steam登录】protobuf协议逆向 | 续

登录接口&#xff1a; ‘https://api.steampowered.com/IAuthenticationService/BeginAuthSessionViaCredentials/v1’ 精准定位&#xff0c;打上条件断点 this.CreateWebAPIURL(t) ‘https://api.steampowered.com/IAuthenticationService/BeginAuthSessionViaCredentials/v1…

环形链表问题(图 + 证明 + 题)

文章目录 判断链表是否有环返回链表开始入环的第一个结点 判断链表是否有环 题目链接 思路&#xff1a; 可以明确的是&#xff1a;若一个链表带环&#xff0c;那么用指针一直顺着链表遍历&#xff0c;最终会回到某个地方。 我们可以定义两个指针&#xff08;快慢指针&#xf…

Linux Centos7 如何安装图形化界面

如果系统是以最小安装的话,一般是不带有图形化界面的,如果需要图形话界面,需要单独安装。本篇教程,主要介绍如何在CentOS7中安装图形化界面。 1、更新系统 首先,保证系统依赖版本处于最新。 sudo yum update -y2、安装 GNOME 桌面环境 sudo yum groupinstall "GNOME…

Spark的学习-02

Spark Standalone集群的安装 架构&#xff1a;普通分布式主从架构 主&#xff1a;Master&#xff1a;管理节点&#xff1a;管理从节点、接客、资源管理和任务 调度&#xff0c;等同于YARN中的ResourceManager 从&#xff1a;Worker&#xff1a;计算节点&#xff1a;负责利用自己…

Linux相关概念和易错知识点(20)(dentry、分区、挂载)

目录 1.dentry &#xff08;1&#xff09;路径缓存的原因 &#xff08;2&#xff09;dentry的结构 ①多叉树结构 ②file和dentry之间的联系 ③路径概念存在的意义 2.分区 &#xff08;1&#xff09;为什么要确认分区 &#xff08;2&#xff09;挂载 ①进入分区 ②被挂…

《Linux运维总结:基于银河麒麟V10+ARM64架构CPU部署redis 6.2.14 TLS/SSL哨兵集群》

总结:整理不易,如果对你有帮助,可否点赞关注一下? 更多详细内容请参考:《Linux运维篇:Linux系统运维指南》 一、简介 Redis 哨兵模式是一种高可用性解决方案,它通过监控 Redis 主从架构,自动执行故障转移,从而确保服务的连续性。哨兵模式的核心组件包括哨兵(Sentine…

vue3实现一个无缝衔接、滚动平滑的列表自动滚屏效果,支持鼠标移入停止移出滚动

文章目录 前言一、滚动元素相关属性回顾一、实现分析二、代码实现示例&#xff1a;2、继续添加功能&#xff0c;增加鼠标移入停止滚动、移出继续滚动效果2、继续完善 前言 列表自动滚屏效果常见于大屏开发场景中&#xff0c;本文将讲解用vue3实现一个无缝衔接、滚动平滑的列表自…

腾讯云nginx SSL证书配置

本章教程,记录在使用腾讯云域名nginx证书配置SSL配置过程。 一、nginx配置 域名和证书,替换成自己的即可。证书文件可以自定义路径位置。服务器安全组或者防火墙需要开放80和443端口。 server {#SSL 默认访问端口号为 443listen 443 ssl; #请填写绑定证书的域名server_name c…

RabbitMQ的DLX(Dead-Letter-Exchange 死信交换机,死信交换器,死信邮箱)(重要)

RabbitMQ的DLX 1、RabbitMQ死信队列2、代码示例2.1、队列过期2.1.1、配置类RabbitConfig&#xff08;关键代码&#xff09;2.1.2、业务类MessageService2.1.3、配置文件application.yml2.1.4、启动类2.1.5、配置文件2.1.6、测试 2.2、消息过期2.2.1、配置类RabbitConfig2.2.2、…