分布式执行引擎ray入门--(3)Ray Train

news2025/1/16 2:38:27

Ray Train中包含4个部分

  1. Training function: 包含训练模型逻辑的函数

  2. Worker: 用来跑训练的

  3. Scaling configuration: 配置

  4. Trainer: 协调以上三个部分

Ray Train+PyTorch

这一块比较建议直接去官网看diff,官网色块标注的比较清晰,非常直观。

import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray.train.torch

def train_func(config):
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # model.to("cuda")  # This is done by `prepare_model`
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        for images, labels in train_loader:
            # This is done by `prepare_data_loader`!
            # images, labels = images.to("cuda"), labels.to("cuda")
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.module.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)

# [4] Configure scaling and resource requirements.
scaling_config = ray.train.ScalingConfig(num_workers=2, use_gpu=True)

# [5] Launch distributed training job.
trainer = ray.train.torch.TorchTrainer(
    train_func,
    scaling_config=scaling_config,
    # [5a] If running in a multi-node cluster, this is where you
    # should configure the run's persistent storage that is accessible
    # across all worker nodes.
    # run_config=ray.train.RunConfig(storage_path="s3://..."),
)
result = trainer.fit()

# [6] Load the trained model.
with result.checkpoint.as_directory() as checkpoint_dir:
    model_state_dict = torch.load(os.path.join(checkpoint_dir, "model.pt"))
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    model.load_state_dict(model_state_dict)

模型 

  ray.train.torch.prepare_model() 

model = ray.train.torch.prepare_model(model)
相当于model.to(device_id or "cpu") +  DistributedDataParallel(model, device_ids=[device_id])

将model移动到合适的device上,同时实现分布式

数据

ray.train.torch.prepare_data_loader() 

报告 checkpoints 和 metrics

+import ray.train
+from ray.train import Checkpoint

 def train_func(config):

     ...
     torch.save(model.state_dict(), f"{checkpoint_dir}/model.pth"))
+    metrics = {"loss": loss.item()} # Training/validation metrics.
+    checkpoint = Checkpoint.from_directory(checkpoint_dir) # Build a Ray Train checkpoint from a directory
+    ray.train.report(metrics=metrics, checkpoint=checkpoint)

     ...
data_loader = ray.train.torch.prepare_data_loader(data_loader)

将batches移动到合适的device上,同时实现分布式sampler

配置 scale 和 GPUs

from ray.train import ScalingConfig
scaling_config = ScalingConfig(num_workers=2, use_gpu=True)

配置持久化存储

多节点分布式训练时必须指定,本地路径会有问题。

from ray.train import RunConfig

# Local path (/some/local/path/unique_run_name)
run_config = RunConfig(storage_path="/some/local/path", name="unique_run_name")

# Shared cloud storage URI (s3://bucket/unique_run_name)
run_config = RunConfig(storage_path="s3://bucket", name="unique_run_name")

# Shared NFS path (/mnt/nfs/unique_run_name)
run_config = RunConfig(storage_path="/mnt/nfs", name="unique_run_name")

启动训练任务

from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_func, scaling_config=scaling_config, run_config=run_config
)
result = trainer.fit()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1506959.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

APP2:android studio如何使用lombok

一、前言 不知道从哪个版本开始,android studio便无法在plugins中下载lombok了,有人说是内置了,好像有这么回事儿。我主要面临如下两个问题: 使用内置lombok,可以自动生成setter、setter、toString等。但是&#xff0…

政安晨:【深度学习处理实践】(五)—— 初识RNN-循环神经网络

RNN(循环神经网络)是一种在深度学习中常用的神经网络结构,用于处理序列数据。与传统的前馈神经网络不同,RNN通过引入循环连接在网络中保留了历史信息。 RNN中的每个神经元都有一个隐藏状态,它会根据当前输入和前一个时…

【QT+QGIS跨平台编译】之七十:【QGIS_Analysis跨平台编译】—【qgsrastercalcparser.cpp生成】

文章目录 一、Bison二、生成来源三、构建过程一、Bison GNU Bison 是一个通用的解析器生成器,它可以将注释的无上下文语法转换为使用 LALR (1) 解析表的确定性 LR 或广义 LR (GLR) 解析器。Bison 还可以生成 IELR (1) 或规范 LR (1) 解析表。一旦您熟练使用 Bison,您可以使用…

free pascal 调用 C#程序读 Freeplane.mm文件,生成测试用例.csv文件

C# 请参阅:C# 用 System.Xml 读 Freeplane.mm文件,生成测试用例.csv文件 Freeplane 是一款基于 Java 的开源软件,继承 Freemind 的思维导图工具软件,它扩展了知识管理功能,在 Freemind 上增加了一些额外的功能&#x…

构建LVS集群

一、集群的基本理论(一)什么是集群 人群或事物聚集:在日常用语中,群集指的是一大群人或事物密集地聚在一起。例如,“人们群集在广场上”,这里的“群集”是指大量人群聚集的现象。 计算机技术中的集群&…

吴恩达机器学习-可选实验室:逻辑回归(Logistic Regression))

在这个不评分的实验中,你会探索sigmoid函数(也称为逻辑函数)探索逻辑回归;哪个用到了sigmoid函数 import numpy as np %matplotlib widget import matplotlib.pyplot as plt from plt_one_addpt_onclick import plt_one_addpt_onclick from lab_utils_common impor…

批量提取PDF指定区域内容到 Excel 以及根据PDF里面第一页的标题来批量重命名-附思路和代码实现

首先说明下,PDF需要是电子版本的,不能是图片或者无法选中的那种。 需求1:假如我有一批数量比较多的同样格式的PDF电子文档,需要把特定多个区域的数字或者文字提取出来 需求2:我有一批PDF文档,但是文件的名…

【CSP试题回顾】202006-1-线性分类器

CSP-202006-1-线性分类器 解题思路 线性分类问题,即根据给定的数据点和分类界限,判断是否存在一条线能够将属于不同类别的点完全分开。具体来说,数据点被分为两类,标记为A和B,我们要找出是否存在一个线性决策边界&…

神经网络实战前言

应用广泛 从人脸识别到网约车,在生活中无处不在 未来可期 无人驾驶技术便利出行医疗健康改善民生 产业革命 第四次工业革命——人工智能 机器学习概念 机器学习不等价与人工智能20世纪50年代,人工智能是说机器模仿人类行为的能力 符号人工智能 …

官方安装配置要求服务器最低2核4G

官方安装配置要求服务器至少2核、4G。 如果服务器低于这个要求,就没有必要安装,因为用户体验超级差。 对于服务器CPU来说,建议2到4核就完全足够了,太多就浪费了,但是内存越大越好,最好是4G以上。 如果服务器…

XSS攻击场景分析

XSS攻击场景分析 在目前这个时间节点还是属于一个排位比较高的漏洞,在OWASP TOP10 2021中隶属于注入型漏洞,高居TOP3的排位,可见这个漏洞的普遍性。跨站脚本攻击的学习中我们主要需要明白的是跨站的含义,以及XSS的核心。XSS主流分…

CentOS 7安装MySQL及常见问题与解决方案(含JDBC示例与错误处理)

引言 MySQL是一个流行的开源关系型数据库管理系统,广泛应用于各种业务场景。在CentOS 7上安装MySQL后,我们通常需要使用JDBC(Java Database Connectivity)连接MySQL进行后端操作。 目录 引言 CentOS 7安装MySQL 使用JDBC连接My…

LLM Drift(漂移), Prompt Drift Cascading(级联)

原文地址:LLM Drift, Prompt Drift & Cascading 提示链接可以手动或自动执行;手动需要通过 GUI 链构建工具手工制作链。自治代理在执行时利用可用的工具动态创建链。这两种方法都容易受到级联、LLM 和即时漂移的影响。 2024 年 2 月 23 日 在讨论大型…

Java对接(BSC)币安链 | BNB与BEP20的开发实践(二)BNB转账、BEP20转账、链上交易监控

上一节我们主要是环境搭建,主要是为了能够快速得去开发,有些地方只是简单的介绍,比如ETH 、web3j等等这些。 这一节我们来用代码来实现BNB转账、BEP20转账、链上交易监控 话不多说,我们直接用代码实现吧 1. BNB转账 /*** BNB转…

Python判断语句+循环语句

一、Python判断语句 1.1 布尔类型和比较运算符 # 定义变量存储布尔类型的数据 bool_1 True bool_2 False print( f"bool_1变量的内容是:{ bool_1 },类型为:{ type( bool_1 ) }" ) print( f"bool_2变量的内容是:{…

打卡--MySQL8.0 一(单机部署)

一路走来,所有遇到的人,帮助过我的、伤害过我的都是朋友,没有一个是敌人。如有侵权,请留言,我及时删除! MySQL 8.0 简介 MySQL 8.0与5.7的区别主要体现在:1、性能提升;2、新的默认…

ELFK 分布式日志收集系统

ELFK的组成: Elasticsearch: 它是一个分布式的搜索和分析引擎,它可以用来存储和索引大量的日志数据,并提供强大的搜索和分析功能。 (java语言开发,)logstash: 是一个用于日志收集,处理和传输的…

04hive数仓内外部表复杂数据类型与分区分桶

hive内部表和外部表 默认为内部表,外部表的关键字 :external内部表:对应的文件夹就在默认路径下 /user/hive/warehouse/库名.db/外部表:数据文件在哪里都行,无须移动数据 # students.txt 1,Lucy,girl,23 2,Tom,boy,2…

2023年终总结——跌跌撞撞不断修正

目录 一、回顾1.一月,鼓足信心的开始2.二月,焦躁不安3.三月,路还是要一步一步的走4.四月,平平淡淡的前行5.五月,轰轰烈烈的前行6.六月,看事情更底层透彻了7.七月,设计模式升华月8.八月&#xff…

加速 Webpack 构建:提升效率的秘诀

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…