深度学习 GNN图神经网络(四)线性回归之ESOL数据集水溶性预测

news2025/1/16 1:59:20

线性回归之ESOL数据集水溶性预测

  • 一、前言
  • 二、ESOL数据集
  • 三、加载数据集
  • 四、数据拆分
  • 五、构造模型
  • 六、训练模型
  • 七、测试结果
  • 八、分类问题
  • 参考文献

一、前言

本文旨在使用化合物分子的SMILES字符串进行数据模型训练,对其水溶性的值进行预测。

之前的文章《深度学习 GNN图神经网络(三)模型思想及文献分类案例实战》引用的Cora数据集只有一张图,属于图神经网络的节点分类问题。本文介绍的是多图批量训练的线性回归问题,在文章最后也讨论了图分类问题。

二、ESOL数据集

本文使用的是ESOL数据集,在文章《如何将化学分子SMILES字符串转化为Pytorch图数据结构——ESOL分子水溶性数据集解析》中有详细介绍,在此不作详述。

三、加载数据集

from torch_geometric.datasets import MoleculeNet

dataset = MoleculeNet(root="data", name="ESOL")

print('num_features:',dataset.num_features)
print('num_classes:',dataset.num_classes)
print('num_node_features',dataset.num_node_features)
print("size:", len(dataset))

d=dataset[10]
print("Sample:", d)
print("Sample y:", d.y)
print("Sample num_nodes:",d.num_nodes)
print("Sample num_edges:",d.num_edges)

这里可以得到数据集的一些基本信息:

num_features: 9
num_classes: 734
num_node_features 9
size: 1128
Sample: Data(x=[6, 9], edge_index=[2, 12], edge_attr=[12, 3], smiles='O=C1CCCN1', y=[1, 1])
Sample y: tensor([[1.0700]])
Sample num_nodes: 6
Sample num_edges: 12

四、数据拆分

将数据集拆分为训练数据和测试数据:

from torch_geometric.loader import DataLoader
data_size = len(dataset)
batch_size = 128
train_data=dataset[:int(data_size*0.8)]
test_data=dataset[int(data_size*0.8):]

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=len(test_data))

五、构造模型

import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
import matplotlib.pyplot as plt
from torch_geometric.nn import global_mean_pool

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

hidden_channels = 64

class GNN(nn.Module):
    
    def __init__(self):
        # 初始化Pytorch父类
        super().__init__()
        
        self.conv1=GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2=GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.out = nn.Linear(hidden_channels, 1)
        
        # 创建损失函数,使用均方误差
        self.loss_function = nn.MSELoss()

        # 创建优化器,使用Adam梯度下降
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.005,weight_decay=5e-4)

        # 训练次数计数器
        self.counter = 0
        # 训练过程中损失值记录
        self.progress = []
    
    # 前向传播函数
    def forward(self, x, edge_index,batch):
        
        x=x.to(device)
        edge_index=edge_index.to(device)
        batch=batch.to(device)

        x=self.conv1(x, edge_index)
        x=x.relu()
        x=self.conv2(x, edge_index)
        x=x.relu()
        x=self.conv3(x, edge_index)
        x=x.relu()
        x=self.conv4(x, edge_index)
        x=x.relu()

        # 全局池化
        x = global_mean_pool(x, batch)  # [x, batch]

        out=self.out(x)
        return out
    
    # 训练函数
    def train(self, data):

        # 前向传播计算,获得网络输出
        outputs = self.forward(data.x.float(),data.edge_index,data.batch)
        
        # 计算损失值
        y=data.y.to(device)
        loss = self.loss_function(outputs, y)

        # 累加训练次数
        self.counter += 1

        # 每10次训练记录损失值
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        # 每1000次输出训练次数   
        if (self.counter % 1000 == 0):
            print(f"counter={self.counter}, loss={loss.item()}")
            
        # 梯度清零, 反向传播, 更新权重
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
    
    # 测试函数
    def test(self, data):
        # 前向传播计算,获得网络输出
        outputs = self.forward(data.x.float(),data.edge_index,data.batch)

        # 把绝对值误差小于1的视为正确,计算准确度
        y=data.y.to(device)
        acc=sum(torch.abs(y-outputs)<1)/len(data.y)
        return acc

    # 绘制损失变化图
    def plot_progress(self):
        plt.plot(range(len(self.progress)),self.progress)
      

六、训练模型

model = GNN()
model.to(device)

for i in range(1001):
    for data in train_loader:
        # print(data,'num_graphs:',data.num_graphs)
        model.train(data)
counter=1000, loss=1.4304862022399902
counter=2000, loss=0.9842458963394165
counter=3000, loss=0.27240827679634094
counter=4000, loss=0.23295772075653076
counter=5000, loss=0.38499030470848083
counter=6000, loss=1.470423698425293
counter=7000, loss=0.845589816570282
counter=8000, loss=0.15707021951675415

绘制损失值变化图::

model.plot_progress()

在这里插入图片描述

七、测试结果

#torch.set_printoptions(precision=4,sci_mode=False) #pytorch不使用科学计数法显示

for data in test_loader:
    acc=model.test(data)
    print(acc)
tensor([0.8186], device='cuda:0')

可以看到,预测值误差小于1的占了81.86%,效果还行。

八、分类问题

对于图分类问题,其实也差不多。只需要修改下Linear网络层:

self.out = Linear(hidden_channels, dataset.num_classes)

这样预测结果就会有num_classes个,取最大值的下标索引即可。
伪代码为:

pred=outputs.argmax(dim=1)
correct += int((pred == data.y).sum())

参考文献

[1] https://pytorch-geometric.readthedocs.io/en/latest/get_started/colabs.html
[2] https://zhuanlan.zhihu.com/p/504978470

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/470987.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端开发中获取各种高度宽度

一、前言 前端开发中经常需要获取页面还有屏幕的高度和宽度进行计算,此文即介绍如何用 JavaScript 获取这些尺寸 二、屏幕尺寸 screen.height&#xff1a;屏幕高度screen.width&#xff1a;屏幕宽度screen.availHeight&#xff1a;屏幕可用高度。即屏幕高度减去上下任务栏后的…

2.2 定点加法 减法运算

学习前的建议 以下是一些学习定点加法和减法运算的建议&#xff1a; 掌握定点数的表示方法&#xff1a;在进行定点加法和减法运算之前&#xff0c;需要先了解定点数的表示方法&#xff0c;包括定点数的位数、小数点位置以及符号位等信息。 理解定点加法和减法的原理&#xf…

nginx配置sh脚本远程执行一键安装

背景 本地多机重复操作某些shell指令&#xff0c;分步执行&#xff0c;很耗费时间&#xff0c; 需要远程一键部署&#xff0c;傻瓜化运维&#xff0c;更为通用安装。 即参考docker通用安装 sudo curl https://get.docker.com | sh - # sudo python3 -m pip install docker-co…

SignalR实现简单的Web端实时通讯,跳过WebSocket验证,Swagger加锁后不能访问接口,Script setup不支持动态绑定

版本.Net6Vue3Element-Plus 问题 Swagger加锁后不能访问接口 &#xff08;看第三步&#xff09;跳过WebSocket验证 &#xff08;看第四步里面&#xff09;添加自定义接受方法 &#xff08;看第四步&#xff09;不能使用 第一步、下载包 后端&#xff1a; 前端&#xff1a;…

Android内存优化场景

1、集合类 内存泄露原因 集合类 添加元素后&#xff0c;仍引用着 集合元素对象&#xff0c;导致该集合元素对象不可被回收&#xff0c;从而 导致内存泄漏实例演示 // 通过循环申请Object 对象 & 将申请的对象逐个放入到集合List List<Object> objectList new Arra…

VBA-自定义面板,使用SQL查询Excel数据

需求 定制插件&#xff0c;实现用户打开任意一个工作簿&#xff0c;写sql对Excel中的数据进行查询 案例sql需求场景&#xff1a; 需求 筛选日期小于’2023-4-24’&#xff0c;按group分区&#xff0c;求和各分组下的销售额&#xff0c;返回结果集新建工作表写入 数据源 现…

Docker-compose 启动 lnmp 开发环境

GitHub传送阵 docker-lnmp 项目帮助开发者快速构建本地开发环境&#xff0c;包括Nginx、PHP、MySQL、Redis 服务镜像&#xff0c;支持配置文件和日志文件映射&#xff0c;不限操作系统&#xff1b;此项目适合个人开发者本机部署&#xff0c;可以快速切换服务版本满足学习服务新…

国产开源项目管理软件ZenTao

本文应网友 ukiyoec 要求而写&#xff1b; 什么是禅道 &#xff1f; 禅道 (ZenTao)是国产开源项目管理软件。它集产品管理、项目管理、质量管理、文档管理、组织管理和事务管理于一体&#xff0c;是一款专业的研发项目管理软件&#xff0c;完整覆盖了研发项目管理的核心流程。禅…

2023-Hive性能企业级调优

Hive作为大数据平台举足轻重的框架&#xff0c;以其稳定性和简单易用性也成为当前构建企业级数据仓库时使用最多的框架之一。 但是如果我们只局限于会使用Hive&#xff0c;而不考虑性能问题&#xff0c;就难搭建出一个完美的数仓&#xff0c;所以Hive性能调优是我们大数据从业…

前端周总结

在vue里面引入ts文件报错&#xff1a; An import path cannot end with a .ts extension. Consider importing xx.js instead. 方法一&#xff08;最快&#xff09; 把引入的xx.ts后缀删除 方法二 # 在tsconfig.json中加入以下配置 "baseUrl": ".", &quo…

Oracle LiveLabs实验:DB Security - Data Masking and Subsetting (DMS)

概述 本实验介绍了适用于 Enterprise Manager 的 Oracle 数据屏蔽和子集 (DMS) 包的各种特性和功能。 它使用户有机会学习如何配置这些功能&#xff0c;以便在非生产环境中保护他们的敏感数据。 此实验申请地址在这里&#xff0c;时间为60分钟。 本实验也是DB Security Adva…

String AOP

AOP AOP(Aspect Object programmar) 面向切面编程&#xff0c;它是对某一类问题的统一处理&#xff0c;而StringAOP就是AOP思想的一种具体实现就像Ioc和DI。 AOP组成 切面(Aspect) 切⾯&#xff08;Aspect&#xff09;由切点&#xff08;Pointcut&#xff09;和通知&#x…

论文阅读笔记《Grounded Action Transformation for Robot Learning in Simulation》

Grounded Action Transformation for Robot Learning in Simulation 发表于AAAI 2017 仿真机器人学习中的接地动作变换 Hanna J, Stone P. Grounded action transformation for robot learning in simulation[C]//Proceedings of the AAAI Conference on Artificial Intellig…

Linux中的阻塞机制

我们知道在字符设备驱动中&#xff0c;应用层调用read、write等系统调用终会调到驱动中对应的接口。 可以当应用层调用read要去读硬件的数据时&#xff0c;硬件的数据未准备好&#xff0c;那我们该怎么做&#xff1f; 一种办法是直接返回并报错&#xff0c;但是这样应用层要获得…

linux通配符和正则表达式深层解析...

目录&#xff1a; (一)了解通配符和正则的作用 (二)通配符的使用 (三)正则表达式的使用 (四)扩展正则表达式的使用 (一)了解通配符和正则的作用 (1.1)在我们日常的工作中&#xff0c;我们都会使用到通配符或者正则表达式。通配符是一种特殊语句&#xff0c;主要有星号(*)和问号…

交换机和路由器到底有什么区别???

我&#xff1a;度娘度娘&#xff0c;交换机和路由器的区别是什么呢&#xff1f; 度娘&#xff1a;一个工作在第二层数据链路层&#xff0c;一个工作在第三层网络层。 我&#xff1a;哈&#xff1f;那工作在不同层会有什么区别&#xff1f;为什么要工作在不同层&#xff1f; …

2023五一数学建模A题完整思路

已更新五一数学建模A题思路&#xff0c;文章末尾获取&#xff01; A题完整思路&#xff1a; A题是一个动力学问题&#xff0c;需要我们将物理学概念运用到实际生活中&#xff0c;我们可以先看题目 问题1&#xff1a; 假设无人机以平行于水平面的方式飞行&#xff0c;在空中投…

Windows11安装sqlserver2012失败后解决方案

首先卸载 WinR打开运行输入services.msc查看所有服务/或者我的电脑管理找到服务列表/任务管理器进入服务列表&#xff0c;停止所有与Sql Server有关的服务&#xff0c;如下&#xff1a; 打开控制面板-卸载sqlserver所有相关软件&#xff1b; 删除SQL Server相关注册表&#…

【观察】中国软件行业进入“重构期”,看浪潮海岳如何“开新局”

众所周知&#xff0c;改开四十多来年&#xff0c;中国软件产业在经历了萌芽与低谷、摸索与转型后&#xff0c;逐步进入了快速发展期。特别是过去几年&#xff0c;在新的发展格局&#xff0c;信创替代的进程中&#xff0c;整个中国软件业更是加速进入了全新的“重构期”。 在此过…

Unity API详解——Quaternion类

Quaternion类又称四元数&#xff0c;由x、y、z和w这4个分量组成&#xff0c;属于struct类型。在Unity中&#xff0c;用Quaternion来存储和表示对象的旋转角度。Quaternion的变换比较复杂&#xff0c;对于GameObject一般的旋转及移动&#xff0c;可以用Transform中的相关方法实现…