【教程】DGL单机多卡分布式GCN训练

news2024/11/29 10:48:13

转载请注明出处:小锋学长生活大爆炸[xfxuezhagn.cn]

如果本文帮助到了你,欢迎[点赞、收藏、关注]哦~

        PyTorch中的DDP会将模型复制到每个GPU中。

        梯度同步默认使用Ring-AllReduce进行,重叠了通信和计算。

        示例代码:

视频:https://youtu.be/Cvdhwx-OBBo

代码:multigpu.py

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import os

import dgl
from dgl.data import RedditDataset
from dgl.nn.pytorch import GraphConv


def ddp_setup(rank, world_size):
    """
    DDP初始化设置。
    
    参数:
        rank (int): 当前进程的唯一标识符。
        world_size (int): 总进程数。
    """
    os.environ["MASTER_ADDR"] = "localhost"  # 设置主节点地址
    os.environ["MASTER_PORT"] = "12355"      # 设置主节点端口
    init_process_group(backend="nccl", rank=rank, world_size=world_size)  # 初始化进程组
    torch.cuda.set_device(rank)  # 设置当前进程使用的GPU设备


class GCN(torch.nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        """
        初始化图卷积网络(GCN)。
        
        参数:
            in_feats (int): 输入特征的维度。
            h_feats (int): 隐藏层特征的维度。
            num_classes (int): 输出类别的数量。
        """
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)  # 第一层图卷积
        self.conv2 = GraphConv(h_feats, num_classes)  # 第二层图卷积

    def forward(self, g, in_feat):
        """
        前向传播。
        
        参数:
            g (DGLGraph): 输入的图。
            in_feat (Tensor): 输入特征。
        
        返回:
            Tensor: 输出的logits。
        """
        h = self.conv1(g, in_feat)  # 进行第一层图卷积
        h = F.relu(h)  # ReLU激活
        h = self.conv2(g, h)  # 进行第二层图卷积
        return h


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        train_data: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        save_every: int,
    ) -> None:
        """
        初始化训练器。
        
        参数:
            model (torch.nn.Module): 要训练的模型。
            train_data (DataLoader): 训练数据的DataLoader。
            optimizer (torch.optim.Optimizer): 优化器。
            gpu_id (int): GPU ID。
            save_every (int): 每隔多少个epoch保存一次检查点。
        """
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)  # 将模型移动到指定GPU
        self.train_data = train_data
        self.optimizer = optimizer
        self.save_every = save_every
        self.model = DDP(model, device_ids=[gpu_id])  # 使用DDP包装模型

    def _run_batch(self, batch):
        """
        运行单个批次。
        
        参数:
            batch: 单个批次的数据。
        """
        self.optimizer.zero_grad()  # 梯度清零
        graph, features, labels = batch
        graph = graph.to(self.gpu_id)  # 将图移动到GPU
        features = features.to(self.gpu_id)  # 将特征移动到GPU
        labels = labels.to(self.gpu_id)  # 将标签移动到GPU
        output = self.model(graph, features)  # 前向传播
        loss = F.cross_entropy(output, labels)  # 计算交叉熵损失
        loss.backward()  # 反向传播
        self.optimizer.step()  # 更新模型参数

    def _run_epoch(self, epoch):
        """
        运行单个epoch。
        
        参数:
            epoch (int): 当前epoch号。
        """
        print(f"[GPU{self.gpu_id}] Epoch {epoch} | Steps: {len(self.train_data)}")
        for batch in self.train_data:
            self._run_batch(batch)  # 运行每个批次

    def _save_checkpoint(self, epoch):
        """
        保存训练检查点。
        
        参数:
            epoch (int): 当前epoch号。
        """
        ckp = self.model.module.state_dict()  # 获取模型的状态字典
        PATH = "checkpoint.pt"  # 定义检查点路径
        torch.save(ckp, PATH)  # 保存检查点
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs: int):
        """
        训练模型。
        
        参数:
            max_epochs (int): 总训练epoch数。
        """
        for epoch in range(max_epochs):
            self._run_epoch(epoch)  # 运行当前epoch
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)  # 保存检查点


def load_train_objs():
    """
    加载训练所需的对象:数据集、模型和优化器。
    
    返回:
        tuple: 数据集、模型和优化器。
    """
    data = RedditDataset(self_loop=True)  # 加载Reddit数据集,并添加自环
    graph = data[0]  # 获取图
    train_mask = graph.ndata['train_mask']  # 获取训练掩码
    features = graph.ndata['feat']  # 获取特征
    labels = graph.ndata['label']  # 获取标签

    model = GCN(features.shape[1], 128, data.num_classes)  # 初始化GCN模型
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)  # 初始化优化器
    train_data = [(graph, features, labels)]  # 准备训练数据
    
    return train_data, model, optimizer


def prepare_dataloader(dataset, batch_size: int):
    """
    准备DataLoader。
    
    参数:
        dataset: 数据集。
        batch_size (int): 批次大小。
    
    返回:
        DataLoader: DataLoader对象。
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        collate_fn=lambda x: x[0]  # 自定义collate函数,解包数据集中的单个元素
    )


def main(rank: int, world_size: int, save_every: int, total_epochs: int, batch_size: int):
    """
    主训练函数。
    
    参数:
        rank (int): 当前进程的唯一标识符。
        world_size (int): 总进程数。
        save_every (int): 每隔多少个epoch保存一次检查点。
        total_epochs (int): 总训练epoch数。
        batch_size (int): 批次大小。
    """
    ddp_setup(rank, world_size)  # DDP初始化设置
    dataset, model, optimizer = load_train_objs()  # 加载训练对象
    train_data = prepare_dataloader(dataset, batch_size)  # 准备DataLoader
    trainer = Trainer(model, train_data, optimizer, rank, save_every)  # 初始化训练器
    trainer.train(total_epochs)  # 开始训练
    destroy_process_group()  # 销毁进程组


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Simple distributed training job')
    parser.add_argument('--total_epochs', default=50, type=int, help='Total epochs to train the model')
    parser.add_argument('--save_every', default=10, type=int, help='How often to save a snapshot')
    parser.add_argument('--batch_size', default=8, type=int, help='Input batch size on each device (default: 32)')
    args = parser.parse_args()
    
    world_size = torch.cuda.device_count()  # 获取可用GPU的数量
    mp.spawn(main, args=(world_size, args.save_every, args.total_epochs, args.batch_size), nprocs=world_size)  # 启动多个进程进行分布式训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1815631.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++怎么根据变量名称返回变量的值?

在开始前刚好我有一些资料,是我根据网友给的问题精心整理了一份「C的资料从专业入门到高级教程」, 点个关注在评论区回复“888”之后私信回复“888”,全部无偿共享给大家!!! 有点好奇你这么做是为了什么。…

[Redis] Redis Desktop Manager 安装包和连接和创建流程

1. 安装流程就是next,就可以。 2. 分别填写好: Name(自定义,redis这个库展示的名字), Host, Port, Auth(Redis 的连接password) 3. 要勾选上Use SSL Protocol 选项, 4. 连接到redis上,展示不同的database,…

cad标注尺寸很简单,这三个方法很好掌握!

在CAD(计算机辅助设计)的广阔领域中,标注尺寸是至关重要的一环。无论是初入CAD领域的新手,还是经验丰富的设计师,掌握标注尺寸的技巧都是提升工作效率和准确性的关键。今天,我们就来分享三个简单而实用的方…

AIGC数字人视频生成解决方案,赋能广电电视内容生产

AI数字人可以有效加大人工智能在内容生产的应用,推动广电电视节目创意生产,提高生产效率的同时,还能提升节目质量,增强互动呈现,为观众提供更加精彩的视听产品。 广州虚拟动力作为3D、AI数字人技术服务商及方案提供商…

5款非常好用的小众软件,你值得拥有

​ 今天为大家推荐五款不常见但好用的win10软件,它们都有着各自的特色和优势,相信你会喜欢的。 1. 文件夹查看——Folder Size View ​ Folder Size View是一款高效的文件夹大小查看工具,它能够快速扫描并展示文件夹及其子文件夹的占用空间…

np.array()按权重求平均值详解

代码如下: a np.array([[1, 4, 2, 6],[10, 41, 7, 3],[9, 1, 6, 2]]) v1 np.average(a, weights[3, 3, 4], axis0) print(v1) 运行结果 当执行这段代码时,np.average(a, weights[3, 3, 4], axis0)会根据指定的权重在列方向上计算加权平均值。 具体计…

告别枯燥:Python数据处理也可以很有趣

想要成为数据处理的超级英雄吗?阿佑将带你一探究竟!我们将深入数据村,学习如何使用Python的超能力处理各种复杂的数据格式。从解码错误和字符集问题的解决,到大数据量的性能优化,再到数据验证与清洗,每一个…

git clone 项目报“鉴权失败”的解决办法

#问题展示# git clone https://gitee.com/soaringsoft/.....git 正克隆到...... Username for https://gitee.com:...... Password for https://.....gitee.com:...... remote: [session-1440f183] Unauthorized fatal: git clone https://gitee.com/soaringsoft/.....gi…

SpringMVC框架学习笔记(七):处理 json 和 HttpMessageConverter 以及文件的下载和上传

1 处理 JSON-ResponseBody 说明: 项目开发中,我们往往需要服务器返回的数据格式是按照 json 来返回的 下面通过一个案例来演示SpringMVC 是如何处理的 (1) 在web/WEB-INF/lib 目录下引入处理 json 需要的 jar 包,注意 spring5.x…

推荐网站(22)GeoSpy,根据图片显示地理位置

今天推荐一款名为GeoSpy的AI工具。它利用人工智能技术,通过分析照片中的光线、植被、建筑风格等细节线索,实现对拍摄地点的精确定位。令人难以置信的是,它对位置的定位准确度非常高。 GeoSpy之所以智能如此,是因为它将输入的照片与大量的街景和地理图像…

夹层辊能否解决智能测径仪量程不足的问题?

关键字:智能测径仪,测径仪夹层辊,测径仪量程,夹层辊作用,测径仪量程不足, 智能测径仪是一种高精度的测量设备,主要用于检测线材、管材等圆柱形物体的直径尺寸。在测径仪中,夹层辊是测径仪的关键部件之一,它负责引导和支撑被测物体&#xff0c…

Astar路径规划算法复现-python实现

# -*- coding: utf-8 -*- """ Created on Fri May 24 09:04:23 2024"""import os import sys import math import heapq import matplotlib.pyplot as plt import time 传统A*算法 class Astar:AStar set the cost heuristics as the priorityA…

企业里面最常用的6大管理系统!附6个模板下载!

企业管理系统旨在帮助企业优化工作流程,提高工作效率的信息化系统。它通过对一些流程的规范,可以极大地减少企业存在的一些流程重复造成的浪费,并通过规范每个员工的动作来提高效率。企业在选择管理系统时,注重功能的全面性、流程…

CentOS7下快速升级至OpenSSH9.7p2安全版本

一、CentOS7服务器上编译生成OpenSSH9.3p2的RPM包 1、编译打包的shell脚本来源于该项目 https://github.com/boypt/openssh-rpms解压zip项目包 unzip openssh-rpms-main.zip -d /opt cd /opt/openssh-rpms-main/ vim pullsrc.sh 修改第23行为source ./version.env 2、sh pull…

山东大学软件学院项目实训-创新实训-基于大模型的旅游平台(三十一)- 微服务(11)

12.7 DSL查询语法 查询的基本语法 GET /indexName/_search{"query": {"查询类型": {"查询条件": "条件值"}}} 查询所有 GET /hotel/_search{"query": {"match_all": {}}} 12.7.1 全文检索查询 全文检索查询,会…

OZON云仓靠谱吗,OZON云仓垫资提货模式

在电商飞速发展的今天,物流仓储成为了支撑整个电商生态的重要基石。OZON云仓作为市场上新兴的仓储物流服务提供商,凭借其先进的技术和灵活的服务模式,受到了不少电商卖家和消费者的关注。但随之而来的是一系列疑问:OZON云仓靠谱吗…

【八股系列】react里组件通信有几种方式,分别怎样进行通信?

文章目录 1. props传递(父向子通信):2. 回调函数作为props(子向父通信):3. Context API:4. Redux或MobX等状态管理库:4.1 Redux使用示例 5. refs: 1. props传递(父向子通信&#xff…

社区新标准发布!龙蜥社区标准化 SIG MeetUp 圆满结束

5 月 31 日,「龙蜥社区“走进系列”」第 9 期之走进阿里云于北京圆满结束。来自阿里云、浪潮信息、红旗软件、中兴通讯|中兴新支点、中科曙光、中科方德、统信软件、麒麟软件、万里红、普华基础软件、飞腾信息、凝思、申威、新华三等公司的 30 余位专家出席会议。会…

C#开源软件:OneNote组件oneMore轻松打造自己的公众号编辑器

OneMore是一款为Microsoft OneNote设计的插件,它提供了许多扩展功能来增强OneNote的使用体验。 插件功能概述: OneMore插件拥有多达一百多个扩展功能,这些功能覆盖了笔记编辑、搜索、导出等多个方面,为用户提供了更加便捷和高效的…

人工智能超万卡集群的设计架构解读

超万卡集群的核心设计原则和总体架构 超万卡集群建设正起步,现主要依赖英伟达GPU及其配套设备。英伟达GPU在大模型训练中优势显著。国产AI芯片虽在政策与应用驱动下取得进步,但整体性能与生态建设仍有不足。构建一个基于国产生态、技术领先的超万卡集群&…