【AI基础】pytorch lightning 基础学习

news2025/1/10 1:40:07

传统pytorch工作流是首先定义模型框架,然后写训练和验证,测试循环代码。训练,验证,测试代码写起来比较繁琐。这里介绍使用pytorch lightning 部署模型,加速模型训练和验证,记录。

准备工作

1 安装pytorch lightning 检查版本

$ conda create -n lightning python=3.9 -y
$ conda activate lightning
import lightning as L
import torch

print("Lightning version:", L.__version__)
print("Torch version:", torch.__version__)
print("CUDA is available:", torch.cuda.is_available())

2 加载基本库函数

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import lightning as L
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers.tensorboard import TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

3 设置随机种子(可复现性)

L.seed_everything(1121218)

4 数据集下载和增强变换

这里以CIFAR10数据集为例子,该数据集包含 10 个类的 6 万张 32x32 彩色图像,每个类 6000 张图像。

from torchvision import datasets, transforms

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(
   root="./data", train=True, download=True, transform=transform_train
)
val_dataset = datasets.CIFAR10(
   root="./data", train=False, download=True, transform=transform_test
)
# Data augmentation and normalization for training
transform_train = transforms.Compose(
   [
       transforms.RandomCrop(32, padding=4),
       transforms.RandomHorizontalFlip(),
       transforms.ToTensor(),
       transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
   ],
)
transform_test = transforms.Compose(
   [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

上面的增强变换包括以下四种基本变换: 

  • 裁剪(需要指定图像大小,在本例中为 32x32)。
  • 水平翻转。
  • 转换为张量数据类型,这是 PyTorch 所必需的。
  • 对图像的每个颜色通道进行归一化处理。

传统pytorch模型训练流

定义一个CNN模型

class CIFAR10CNN(nn.Module):
   def __init__(self):
       super(CIFAR10CNN, self).__init__()
       self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
       self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
       self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
       self.pool = nn.MaxPool2d(2, 2)
       self.fc1 = nn.Linear(64 * 4 * 4, 512)
       self.fc2 = nn.Linear(512, 10)
   def forward(self, x):
       x = self.pool(torch.relu(self.conv1(x)))
       x = self.pool(torch.relu(self.conv2(x)))
       x = self.pool(torch.relu(self.conv3(x)))
       x = x.view(-1, 64 * 4 * 4)
       x = torch.relu(self.fc1(x))
       x = self.fc2(x)
       return x

编写训练、验证循环代码

  • 需要初始化模型,损失函数和优化器
  • 管理模型和数据在机器上的运行(CPU 与 GPU)
  • 训练步骤:前向传播、损失计算、反向传播和优化
  • 验证步骤:计算准确性和损失
  • tensorboard日志记录,训练损失,准确率,其他相关指标记录等
  • 模型保存
  • # Initialize the model, loss function, and optimizer
    model = CIFAR10CNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
    
    # TensorBoard setup
    writer = SummaryWriter('runs/cifar10_cnn_experiment')
    
    # Training loop
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)
    
            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)
    
            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            train_loss += loss.item()
    
            if (i+1) % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{total_step}], Loss: {loss.item():.4f}')
    
        # Calculate average training loss for the epoch
        avg_train_loss = train_loss / len(train_loader)
        writer.add_scalar('training loss', avg_train_loss, epoch)
    
        # Validation
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            val_loss = 0.0
            for images, labels in test_loader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    
            accuracy = 100 * correct / total
            avg_val_loss = val_loss / len(test_loader)
            print(f'Validation Accuracy: {accuracy:.2f}%')
            writer.add_scalar('validation loss', avg_val_loss, epoch)
            writer.add_scalar('validation accuracy', accuracy, epoch)
    
        # Learning rate scheduling
        scheduler.step(avg_val_loss)
    
    # Final test
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
        print(f'Test Accuracy: {100 * correct / total:.2f}%')
    
    writer.close()
    
    # Save the model
    torch.save(model.state_dict(), 'cifar10_cnn.pth')

     在上面的代码示例,有一些需要特别注意繁琐的细节:

    训练和验证模式之间可以手动切换。
    有梯度计算的手动规范。
    使用较差的 SummaryWriter 类进行日志记录。
    有一个学习率调度程序。

Pytorch lightning 工作流

1 使用LightningModule 类定义模型结构

class CIFAR10CNN(L.LightningModule):
   def __init__(self):
       super().__init__()
       self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
       self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
       self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
       self.pool = nn.MaxPool2d(2, 2)
       self.fc1 = nn.Linear(64 * 4 * 4, 512)
       self.fc2 = nn.Linear(512, 10)
   def forward(self, x):
       x = self.pool(F.relu(self.conv1(x)))
       x = self.pool(F.relu(self.conv2(x)))
       x = self.pool(F.relu(self.conv3(x)))
       x = x.view(-1, 64 * 4 * 4)
       x = F.relu(self.fc1(x))
       x = self.fc2(x)
       return x

唯一的区别是,我们是从LightningModule类继承,而不是从继承nn.Module。是类LightningModule的扩展nn.Module。它将 PyTorch 工作流的训练、验证、测试、预测和优化步骤组合到一个没有循环的单一界面中。 当你开始使用时LightningModule,它被组织成六个部分:

  • 初始化(__init__和setup()方法)
  • 训练循环(training_step()方法)
  • 验证循环(validation_step()方法)
  • 测试循环(test_step()方法)
  • 预测循环(prediction_step()方法)
  • 优化器和 LR 调度程序(configure_optimizers())

我们已经看到了初始化部分。让我们继续进行训练步骤。

2 编写训练过程代码

在模型类中,复写training_step()方法

# Add the method inside the class
def training_step(self, batch, batch_idx):
   x, y = batch
   y_hat = self(x)
   loss = F.cross_entropy(y_hat, y)
   self.log('train_loss', loss)
   return loss

此方法将整个训练循环压缩为几行代码。首先,从数据batch中读取模型输入和模型输出。然后,我们运行前向传递self(x)并计算损失。然后,我们只需使用内置的 Lightning 记录器函数记录训练损失即可self.log()。

还可以在此方法中记录其他指标,例如训练准确性:

def training_step(self, batch, batch_idx):
   x, y = batch
   y_hat = self(x)
  
   loss = F.cross_entropy(y_hat, y)
   acc = (y_hat.argmax(1) == y).float().mean()
  
   self.log("train_loss", loss)
   self.log("train_acc", acc)
   return loss

log()方法可以自动计算每个epoch的模型的各个指标,比如准确性,F1-score等等。该方法里面有一些参数是可以额外设置的,比如记录每个batch和epoch下的模型指标,模型训练和验证时创建进度条,还有将模型的各个指标输出到本地文件中。

# Log the loss at each training step and epoch, create a progress bar
self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

3 编写验证和测试步骤代码

def validation_step(self, batch, batch_idx):
   x, y = batch
   y_hat = self(x)
   loss = F.cross_entropy(y_hat, y)
   acc = (y_hat.argmax(1) == y).float().mean()
   self.log('val_loss', loss)
   self.log('val_acc', acc)
def test_step(self, batch, batch_idx):
   x, y = batch
   y_hat = self(x)
   loss = F.cross_entropy(y_hat, y)
   acc = (y_hat.argmax(1) == y).float().mean()
   self.log('test_loss', loss)
   self.log('test_acc', acc)

唯一的区别是不需要返回计算出的指标。Lightning模块会自动将正确的数据加载器分配给验证和测试步骤,并在后台创建循环。

尽管validation_step()和test_step()看起来相同,但它们有一个关键的区别:

  • validation_step()在训练期间,直接参与模型验证。
  • test_step()在测试期间,需要调用训练器对象的.test()方法,才能执行此操作。

4 配置优化器和优化器scheduler程序

为了定义优化器和学习率调度器,需要重写configure_optimizers()类的方法。

def configure_optimizers(self):
   optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
   scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
       optimizer, mode="min", factor=0.1, patience=5
   )
  
   return {
       "optimizer": optimizer,
       "lr_scheduler": {
           "scheduler": scheduler,
           "monitor": "val_loss",
       },
   }

上面,创建了一个Adam优化器,传入超参数和学习率。还定义了一个ReduceLROnPlateau调度函数,用于在验证损失稳定时降低学习率。返回对象字典是最灵活的选项,因为它允许定义需要额外参数的scheduler。

https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#configure-optimizers

5 定义callbacks和记录器

模型类和附带的训练,验证,优化器,学习率调度器和指标计算都已经完成,模型可以实现前向和反向传播,模型更新,验证,记录模型的各个指标。此时,还需要定义一系列的callbacks和记录器类型。这里定义一个checkpoint callback和记录器。

checkpoint_callback = ModelCheckpoint(
   dirpath="checkpoints",
   monitor="val_loss",
   filename="cifar10-{epoch:02d}-{val_loss:.2f}-{val_acc:.2f}",
   save_top_k=3,
   mode="min",
)

ModelCheckpoint是一个强大的回调,用于在监控给定指标的同时定期保存模型。每个模型检查点都记录到dirpath中。

定义一个tensorboardlogger() 记录方法

logger = TensorBoardLogger(save_dir="lightning_logs", name="cifar10_cnn")

定义一个early_stopping callback

early_stopping = EarlyStopping(monitor="val_loss", patience=5, mode="min", verbose=False)

6 创建一个trainer类

在将模型LightningModule类和callback, 记录器全部定义完以后,就可以定义一个Trainer 类来实现模型的数据读取,自动训练,验证,模型自动保存,比较简洁。可以定义最大epoch数,使用gpu训练和gpu个数,记录器,callback,训练精度,训练数据比例(默认100%),验证数据比例(默认100%),多少个epoch 模型做一次验证,多少个epoch后记录一次模型指标,记录和模型地址,单gpu训练还是分布式训练。

# Initialize the Trainer
trainer = L.Trainer(
   max_epochs=50,
   callbacks=[checkpoint_callback, early_stopping],
   logger=logger,
   accelerator="gpu" if torch.cuda.is_available() else "cpu",
   devices="auto",
)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

7 训练和测试模型

# Train and test the model

trainer.fit(model, train_loader, test_loader)

trainer.test(model, test_loader)

8 pytorch lightning 训练模型的基本流程总结

  •   创建应用转换的训练、验证和测试数据加载器。
  • 将代码组织到一个LightningModule类中:
  • 定义初始化。
  • 定义训练、验证和(可选)测试步骤。
  • 定义优化器和学习率调度器。
  • 定义回调和记录器。
  • 创建一个训练类trainer
  • 初始化模型类。
  • 拟合并测试模型。  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2168976.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

铨顺宏科技携RTLS+RFID技术亮相工博会!

中国国际工业博览会盛大开幕! 铨顺宏科技展亮点速递 铨顺宏科技展位号:F117 中国国际博览会今日开幕,铨顺宏科技携创新产品亮相,吸引众多参观者。 我们珍视此次国际盛会,将全力以赴确保最佳体验。 工作人员热情解答…

实时数字人DH_live使用案例

参看: https://github.com/kleinlee/DH_live ubuntu 测试 apt install ffmpeg 下载安装: git clone https://github.com/kleinlee/DH_live.git cd DH_liveconda create -n dh_live python=3.12 conda activate dh_live pip install -r requirements.txt pip install torch -…

E. Alternating String

E. Alternating String 这道题就是前缀和的变化, 现在做起来比较简单, 打这场的时候差了点时间就做出来了 代码 #include <bits/stdc.h> #define int long long using namespace std;const int N 200010;int od[N][30], ev[N][30]; int n;void init() {for(int i 0; …

【Linux篇】常用命令及操作技巧(进阶篇 - 上)

&#x1f30f;个人博客主页&#xff1a;意疏-CSDN博客 文章目录&#xff1a; Linux常用命令以及操作技巧&#xff08;进阶&#xff09;前言一、远程管理常用命令1、关机/重启shutdown命令 二、查看或配置网卡信息2、网卡和IP地址网卡IP地址ifconfig命令ping命令 三、SSH基础1.…

Dart中FFI学习

Flutter中FFI学习 Dart FFI编程概述NativeType&#xff08;类型映射&#xff09;Window安装GCCDart调用C的函数数组字符串结构体 Dart FFI编程 概述 dart:ffi库可以使用Dart语言调用本地C语言API ,并读取、写入、分配和删除本地内存。FFI是指外部函数接口&#xff08;Foregin…

JS设计模式之组合模式:打造灵活高效的对象层次结构

引言 当我们构建复杂的应用程序时&#xff0c;经常会遇到处理对象层次结构的情况。这些层次结构通常是树形结构&#xff0c;由组合节点和叶子节点组成。在这样的情况下&#xff0c;JavaScript 设计模式之一的组合模式就能派上用场。 组合模式是一种结构型设计模式&#xff0c…

Gitlab学习(006 gitlab操作)

尚硅谷2024最新Git企业实战教程&#xff0c;全方位学习git与gitlab 总时长 5:42:00 共40P 此文章包含第21p-第24p的内容 文章目录 git登录修改root密码 设置修改语言取消相对时间勾选 团队管理创建用户创建一个管理员登录管理员账号创建一个普通用户登录普通用户账号 群组管理…

工业交换机一键重启的好处

在当今高度自动化和智能化的工业环境中&#xff0c;工业交换机作为网络系统中至关重要的一环&#xff0c;其稳定性和可靠性直接影响到整个生产过程的顺利进行。为了更好地维护这些设备的健康运行&#xff0c;一键重启功能应运而生&#xff0c;并呈现出诸多显著的好处。 首先&am…

助力降本增效,ByteHouse打造新一代云原生数据仓库

随着数据量的爆炸式增长、企业上云速度加快以及数据实时性需求加强&#xff0c;云原生数仓市场迎来了快速发展机遇。 据 IDC、Gartner 研究机构数据显示&#xff0c;到 2025 年&#xff0c;企业 50% 数据预计为云存储&#xff0c;75% 数据库都将运行在云上&#xff0c;全球数据…

Swagger配置且添加小锁(asp.net)(笔记)

此博客是基于 asp.net core web api(.net core3.1)框架进行操作的。 一、安装Swagger包 在 NuGet程序包管理中安装下面的两个包&#xff1a; swagger包&#xff1a;Swashbuckle.AspNetCore swagger包过滤器&#xff1a;Swashbuckle.AspNetCore.Filters 二、swagger注册 在…

数据结构——初始树和二叉树

线性结构是一对一的关系&#xff0c;意思就是只有唯一的前驱和唯一的后继&#xff1b; 非线性结构&#xff0c;如树形结构&#xff0c;它可以有多个后继&#xff0c;但只有一个前驱&#xff1b;图形结构&#xff0c;它可以有多个前驱&#xff0c;也可以有多个后继。 树的定义…

进阶:反转二叉树的奇数层

目录标题 题目描述示例解题思路代码实现详细步骤解释复杂度分析 题目描述 给定一棵完美二叉树的根节点 root&#xff0c;请反转这棵树中每个奇数层的节点值。完美二叉树是指所有叶子节点都在同一层&#xff0c;并且每个非叶子节点都有两个子节点。 示例 示例 1&#xff1a; …

Harmony商城项目

目录&#xff1a; 1、启动项目看效果图2、代码分析 1、启动项目看效果图 2、代码分析 import CommonConstants from ../constants/CommonConstants; import WomanPage from ./components/WomanPage import ManPage from ./components/ManPage import HomePage from ./component…

Teams集成-会议侧边栏应用开发-实时转写

Teams虽然提供了转写的接口&#xff0c;但是不是实时的&#xff0c;即便使用订阅事件也不是实时的&#xff0c;为了达到实时转写的效果&#xff0c;使用recall.ai的转录和assembly_ai的转写实现。 前提&#xff1a;除Teams会议侧边栏应用开发-会议转写-CSDN博客的基本要求外&a…

实战教程!Zabbix 监控 Spark 中间件配置教程

本文将介绍以JMX方式监控Spark中间件。JMX具有跨平台、灵活性强、监控能力强、易于集成与扩展、图形化界面支持以及安全性与可配置性等多方面的优势&#xff0c;是监控Spark等复杂Java应用程序的重要工具之一。 Apache Spark 是一个开源的大数据处理框架&#xff0c;它提供了快…

【深度学习】ubuntu系统下docker部署cvat的自动标注功能(yolov8 segmentation)

cvat部署自动标注教程 前言step1. 拷贝yolov8项目step2. 创建yolov8的本地镜像step3. 在cvat中构建我们的工作空间 前言 安装docker和cvat的流程我这里就不赘述了&#xff0c;这样的教程还是挺多的&#xff0c;但是对于使用docker在cvat上部署自动标注算法的整个详细流程&#…

【MySQL】MVCC及其实现原理

目录 1. 概念介绍 什么是MVCC 什么是当前读和快照读 MVCC的好处 2. MVCC实现原理 隐藏字段 Read View undo-log 数据可见性算法 3. RC和RR隔离级别下MVCC的差异 4. MVCC&#xff0b;Next-key-Lock 防止幻读 1. 概念介绍 什么是MVCC Multi-Version Concurrency Cont…

通信工程学习:什么是FDD频分双工

FDD:频分双工 FDD(频分双工,Frequency Division Duplexing)是一种无线通信技术,它通过将频谱划分为上行和下行两个不重叠的频段来实现同时双向通信。以下是FDD频分双工的详细解释: 一、定义与原理 定义: FDD是一种无线通信系统的工作模式,其中上行链路(从移动…

以Flask为基础的虾皮Shopee“曲线滑块验证码”识别系统部署

以Flask为基础的虾皮Shopee“曲线滑块验证码”识别系统部署 一、验证码类型二、简介三、Flask应用 一、验证码类型 验证码类型&#xff1a;此类验证码存在两个难点&#xff0c;一是有右侧有两个凹槽&#xff0c;二是滑块的运动轨迹不是直线的&#xff0c;而是沿着曲线走的&…

您的业​​务端点是否完全安全?

根据 2023 年数据泄露调查报告&#xff0c;52% 的数据泄露涉及凭证泄露。这令人担忧&#xff0c;不是吗&#xff1f; 在当今的数字世界中&#xff0c;企业严重依赖技术&#xff0c;保护您的设备&#xff08;端点&#xff09;至关重要。这些设备&#xff08;包括计算机、笔记本…