PyTorch Lightning:通过分布式训练扩展深度学习工作流

news2025/1/15 17:33:17

 

一、介绍

        欢迎来到我们关于 PyTorch Lightning 系列的第二篇文章!在上一篇文章中,我们向您介绍了 PyTorch Lightning,并探讨了它在简化深度学习模型开发方面的主要功能和优势。我们了解了 PyTorch Lightning 如何为组织和构建 PyTorch 代码提供高级抽象,使研究人员和从业者能够更多地关注模型设计和实验,而不是样板代码。

        在本文中,我们将深入研究 PyTorch Lightning,并探索它如何通过分布式训练实现深度学习工作流的扩展。分布式训练对于在海量数据集上训练大型模型至关重要,因为它允许我们利用多个 GPU 或机器的强大功能来加速训练过程。然而,分布式训练往往伴随着一系列挑战和复杂性。

二、安装 Pytorch Lightning & Torchvision

pip install torch torchvision pytorch-lightning 

三、实现

        首先,我们需要从 PyTorch 和 PyTorch Lightning 导入必要的模块:

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms

import pytorch_lightning as pl

        接下来,我们使用 PyTorch 的类定义我们的神经网络架构。在这个例子中,我们使用一个简单的卷积神经网络,其中包含两个卷积层和三个全连接层:nn.Module

class Net(pl.LightningModule):
    def __init__(self):
        super(Net, self).__init__()

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

        然后,我们为 .在该方法中,我们接收一批输入和标签,将它们通过我们的神经网络来获取 logits,计算交叉熵损失,并使用该方法记录训练损失。在该方法中,我们执行与 相同的操作,但不记录损失:LightningModuletraining_stepxyself.logvalidation_steptraining_step

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.log("val_loss", loss)
        return loss

        我们还在方法中定义了优化器和学习率调度器:configure_optimizers

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
        return [optimizer], [scheduler]

        接下来,我们使用 PyTorch 和 定义数据加载和预处理步骤:DataLoadertransforms

    def prepare_data(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        CIFAR10(root='./data', train=True, download=True, transform=transform)
        CIFAR10(root='./data', train=False, download=True, transform=transform)

    def train_dataloader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        train_dataset = CIFAR10(root='./data', train=True, download=False, transform=transform)
        return DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)

    def val_dataloader(self):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        val_dataset = CIFAR10(root='./data', train=False, download=False, transform=transform)
        return DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=8)
  1. prepare_data(self):此函数负责在训练模型之前准备数据。它首先使用该类定义一系列转换。转换包括将数据转换为张量并对其进行规范化。定义转换后,该函数将下载用于训练和测试拆分的 CIFAR10 数据集。数据集将下载到目录,并将指定的转换应用于数据。transforms.Compose'./data'
  2. train_dataloader(self):此函数为训练数据集创建数据加载器。它首先定义与函数中相同的转换。接下来,它为训练拆分创建 CIFAR10 数据集的实例。从目录中加载数据集,并应用指定的转换。最后,使用训练数据集创建一个对象。数据加载程序配置为 64 的批大小,对数据进行随机排序,并使用 8 个工作线程进行数据加载。它返回数据加载器。prepare_data'./data'DataLoader
  3. val_dataloader(self):此函数为验证数据集创建数据加载器。它遵循与函数类似的结构。它首先使用 定义转换,这些转换与前面的函数相同。然后,为验证拆分创建 CIFAR10 数据集的实例。从目录中加载数据集,并应用指定的转换。最后,使用验证数据集创建一个对象。数据加载器配置为 64 的批大小,无需随机处理数据,并使用 8 个工作线程进行数据加载。它返回数据加载器。train_dataloadertransforms.Compose'./data'DataLoader

        该函数将模型作为输入,并对测试数据集执行评估。它首先对测试数据应用转换,将其转换为张量并规范化。然后,它为测试数据集创建数据加载程序。模型将移动到相应的设备(GPU,如果可用)。评估标准设置为交叉熵损失。evaluate_model

def evaluate_model(model):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=8)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()

    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for data in test_loader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100.0 * correct / total
    average_loss = test_loss / len(test_loader)

    print(f"Test Loss: {average_loss:.4f}")
    print(f"Test Accuracy: {accuracy:.2f}%")

        将模型置于评估模式,并初始化测试损失、正确预测和总数据点的变量。在无梯度上下文中,该函数遍历测试数据加载器,通过模型转发成批的输入,计算损失并累积测试损失。它还计算正确预测的数量和数据点的总数。最后,它计算并打印平均测试损失和测试精度。

        最后,我们实例化我们的模型和来自 PyTorch Lightning,指定用于分布式训练的所需数量的 GPU 或机器:NetTrainer

net = Net()

trainer = pl.Trainer(
    
    num_nodes=1,  # Change to the number of machines in your distributed setup
    accelerator="auto",  # Distributed Data Parallel, Available names are: auto, cpu, cuda, hpu, ipu, mps, tpu.
    max_epochs=5, 
    devices=1 # Change to the desired number of GPUs or use `None` for CPU training
)

trainer.fit(net)

evaluate_model(net)
  • num_nodes:它指定分布式设置中的计算机数量。在这种情况下,它设置为 ,表示单台计算机设置。1
  • accelerator:它确定训练的加速器类型。该值允许 PyTorch Lightning 根据硬件和软件环境自动选择适当的加速器。其他可能的值包括 、 和 ,它们对应于特定的硬件加速器。"auto""cpu""cuda""hpu""ipu""mps""tpu"
  • max_epochs:它设置用于训练模型的最大周期数(通过训练数据集的完整遍历)。在本例中,它设置为 。5
  • devices:它指定用于训练的 GPU 数量。将其设置为 表示使用单个 GPU 进行训练。如果要在 CPU 上进行训练,可以将其设置为 。1None

        这些选项允许您控制训练过程的各个方面,例如分布式训练、加速器选择以及用于训练的周期数和设备数。

        设置好所有内容后,我们只需调用对象的方法,传入我们的模型、训练数据加载器和验证数据加载器。fitTrainerNet

四、输出

 

五、结论

        PyTorch Lightning 通过分布式训练简化了扩展深度学习工作流的过程。通过抽象化分布式训练的复杂性,PyTorch Lightning 使我们能够专注于设计和实现我们的深度学习模型,而不必担心低级细节。在本文中,我们演练了一个使用 PyTorch Lightning 进行分布式训练的示例代码实现。通过利用多个GPU或机器的强大功能,我们可以显著减少大型深度学习模型的训练时间。

六、引用

  • PyTorch Lightning: Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 2.1.0.rc0 documentation
  • PyTorch: PyTorch
  • torchvision.datasets.CIFAR10: Datasets — Torchvision 0.15 documentation
  • torch.utils.data.DataLoader: torch.utils.data — PyTorch 2.0 documentation
  • 火炬亚当:Adam — PyTorch 2.0 documentation
  • torch.optim.lr_scheduler。步长:StepLR — PyTorch 2.0 documentation
  • Torch.nn.CrossEntropyLoss: CrossEntropyLoss — PyTorch 2.0 documentation
  • torch.cuda.is_available:torch.cuda — PyTorch 2.0 documentation

阿奈·东格雷

皮托奇

分布式系统

深度学习
皮托奇闪电
计算机视觉

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/900348.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3种获取OpenStreetMap数据的方法【OSM】

OpenStreetMap 是每个人都可以编辑的世界地图。 这意味着你可以纠正错误、添加新地点,甚至自己为地图做出贡献! 这是一个社区驱动的项目,拥有数百万注册用户。 这是一个社区驱动的项目,旨在在开放许可下向每个人提供所有地理数据。…

基于YOLOv8模型的奶牛目标检测系统(PyTorch+Pyside6+YOLOv8模型)

摘要:基于YOLOv8模型的奶牛目标检测系统可用于日常生活中检测与定位奶牛目标,利用深度学习算法可实现图片、视频、摄像头等方式的目标检测,另外本系统还支持图片、视频等格式的结果可视化与结果导出。本系统采用YOLOv8目标检测算法训练数据集…

【小梦C嘎嘎——启航篇】vector 以及日常使用的接口介绍

【小梦C嘎嘎——启航篇】vector 日常使用的接口介绍😎 前言🙌vector 是什么?vector 比较常使用的接口 总结撒花💞 😎博客昵称:博客小梦 😊最喜欢的座右铭:全神贯注的上吧&#xff01…

Parking Steps

上面是老师傅说的停车步骤,说这样不会伤变速箱。 平时就是,脚踩刹车,直接从D档撸到P档,拉手刹,哈哈。 你的停车步骤是啥。。

redis 存储结构原理 2

咱们接着上一部分来进行分享,我们可以在如下地址下载 redis 的源码: https://redis.io/download 此处我下载的是 redis-6.2.5 版本的,xdm 可以直接下载上图中的 **redis-6.2.6 **版本, redis 中 hash 表的数据结构 redis hash …

php_mb_strlen指定扩展

1 中文在utf-字符集下占3个字节,所以计算出来长度为9。 2 可以引入php多字节字符的扩展,默认是没有的,需要自己配置这个函数 3 找到php.ini文件,去掉;extension mbstring的注释,接着重启apache服务 可以看到准确输出的中文的长度…

javascript期末作业【三维房屋设计】

1、引入three.js库 官网下载three.js 库 放置目录并引用 引入js文件: 设置场景(scene) (1)创建场景对象 (2)设置透明相机 1,透明相机的优点 透明相机机制更符合于人的视角,在场景预览和游戏场景多有使用…

视频怎么转gif高清动图?分享一款视频转gif工具

许多小伙伴都不知道如何将拍摄的短视频转gif图片,本文将分享一款专业的视频转gif工具,打来浏览器即可将视频在线转gif(https://www.gif.cn),操作简单,使用方便,下面是详细的步骤。 打开网站&am…

SpringBoot案例-员工管理-新增员工

查看页面原型,明确需求 页面原型 需求 阅读接口文档 接口文档链接如下: 【腾讯文档】SpringBoot案例所需文档 https://docs.qq.com/doc/DUkRiTWVaUmFVck9N 思路分析 阅读需求文档后可知,前端发送请求的同时,将前端请求参数以…

centos8 使用phpstudy安装tomcat部署web项目

系统配置 1、安装Tomcat 2、问题 正常安装完Tomcat应该有个配置选项,用来配置server.xml web.xml 还有映射webapps路径选项,但是我用的这个版本并没有。所以只能曲线救国。 3、解决 既然没有配置项,那就只能按最基本的方法配置&#xff0c…

算法之排序总结

排序算法 最近,一直在学习业务上的知识,对基础没有怎么重视,因此,这篇文章想对于排序算法进行一个大致的总结🤓🤓🤓。 首先来说一下,关于排序一些相关的基础知识。 排序概述 原地…

代码随想录第25天|216.组合总和III ​​​​​​​,17. 电话号码的字母组合

216.组合总和III 回溯三部曲 确定递归函数参数 targetSum(int)目标和,也就是题目中的n。k(int)就是题目中要求k个数的集合。sum(int)为已经收集的元素的总和,也就是path里元素的…

(学习笔记-进程管理)什么是悲观锁、乐观锁?

互斥锁与自旋锁 最底层的两种就是 [互斥锁和自旋锁],有很多高级的锁都是基于它们实现的。可以认为它们是各种锁的地基,所以我们必须清楚它们之间的区别和应用。 加锁的目的就是保证共享资源在任意时间内,只有一个线程访问,这样就…

LabVIEW模拟化学反应器的工作

LabVIEW模拟化学反应器的工作 近年来,化学反应器在化学和工业过程领域有许多应用。高价值产品是通过混合产品,化学反应,蒸馏和结晶等多种工业过程转换原材料制成的。化学反应器通常用于大型加工行业,例如酿酒厂公司饮料产品的发酵…

C 基础拾遗

C基础拾遗 预处理器 预处理器 14.1 预定义符号 14.2 #define

5种常见的3D游戏艺术风格及工具栈

在游戏开发领域,3D 艺术风格已成为为玩家创造身临其境、引人入胜的体验的重要组成部分。 随着技术的进步,创造令人惊叹的 3D 视觉效果的可能性已经大大扩展,为游戏开发人员提供了广泛的选择。 在本文中,我们将探讨当今游戏开发中…

Seaborn数据可视化(一)

目录 1.seaborn简介 2.Seaborn绘图风格设置 21.参数说明: 2.2 示例: 1.seaborn简介 Seaborn是一个用于数据可视化的Python库,它是建立在Matplotlib之上的高级绘图库。Seaborn的目标是使绘图任务变得简单,同时产生美观且具有信…

micropython SSD1306/SSD1315驱动

目录 简介 代码 功能 显示ASCII字符 ​编辑 画任意直线 画横线 画竖线 画矩形 画椭圆 画立方体 画点阵图 翻转 反相 滚动 横向滚动 纵向滚动 奇葩滚动 简介 我重新写了一个驱动,增加了一些功能,由于我的硬件是128*64oled单色I2C,我只…

lvs-DR模式:

lvs-DR数据包流向分析 客户端发送请求到 Director Server(负载均衡器),请求的数据报文(源 IP 是 CIP,目标 IP 是 VIP)到达内核空间。 Director Server 和 Real Server 在同一个网络中,数据通过二层数据链路…

08.异常处理与异常Hook(软件断点Hook,硬件断点Hook)

文章目录 异常处理异常Hook&#xff1a;VEH软件断点HOOKVEH硬件断点HOOK 异常处理 1.结构化异常SEH #include <iostream>int main() {goto Exit;__try {//受保护节int a 0;int b 0;int c a / b;std::cout << "触发异常" << std::endl;}/*EXCE…