使用亚马逊针对 PyTorch 和 MinIO 的 S3 连接器进行模型检查点处理

news2025/2/13 9:16:24

2023 年 11 月,Amazon 宣布推出适用于 PyTorch 的 S3 连接器。适用于 PyTorch 的 Amazon S3 连接器提供了专为 S3 对象存储构建的 PyTorch 数据集基元(数据集和数据加载器)的实现。它支持用于随机数据访问模式的地图样式数据集和用于流式处理顺序数据访问模式的可迭代样式数据集。适用于 PyTorch 的 S3 连接器还包括一个检查点接口,用于将检查点直接保存和加载到 S3 存储桶,而无需先保存到本地存储。如果您还没有准备好采用正式的 MLOps 工具,而只需要一种简单的方法来保存模型,那么这是一个非常好的选择。这就是我将在这篇文章中介绍的内容。S3 连接器的文档仅展示了如何将其与 Amazon S3 一起使用 - 我将在此处向您展示如何将其用于 MinIO。让我们先执行此作 - 让我们设置 S3 连接器,以便它从 MinIO 写入和读取检查点。

将 S3 连接器连接到 MinIO

将 S3 连接器连接到 MinIO 就像设置环境变量一样简单。之后,一切都会顺利进行。诀窍是以正确的方式设置正确的环境变量。

本文的代码下载使用 .env 文件来设置环境变量,如下所示。此文件还显示了我用于使用 MinIO Python SDK 直接连接到 MinIO 的环境变量。请注意,AWS_ENDPOINT_URL 需要 protocol,而 MinIO 变量不需要。

AWS_ACCESS_KEY_ID=admin
AWS_ENDPOINT_URL=http://172.31.128.1:9000
AWS_REGION=us-east-1
AWS_SECRET_ACCESS_KEY=password
MINIO_ENDPOINT=172.31.128.1:9000
MINIO_ACCESS_KEY=admin
MINIO_SECRET_KEY=password
MINIO_SECURE=false

写入和读取 Checkpoint

我从一个简单的例子开始。下面的代码段创建了一个 S3Checkpointing 对象,并使用其 writer() 方法将模型的状态字典发送到 MinIO。我还使用 Torchvision 创建了一个 ResNet-18(18 层)模型,用于演示目的。

import os

from dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch

# Load the credentials and connection information.
load_dotenv()

model = torchvision.models.resnet18()
model_name = 'resnet18.pth'
bucket_name = 'checkpoints'

checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])

# Save checkpoint to S3
with s3_checkpoint.writer(checkpoint_uri) as writer:
   torch.save(model.state_dict(), writer)

请注意,该区域有一个强制参数。从技术上讲,访问 MinIO 时没有必要,但如果为此变量选择错误的值,内部检查可能会失败。此外,您的存储桶必须存在,上述代码才能正常工作。如果 writer() 方法不存在,它将引发错误。不幸的是,无论出了什么问题,writer() 方法都会引发相同的错误。例如,如果您的存储桶不存在,您将收到如下所示的错误。如果 writer() 方法不喜欢您指定的区域,您也会收到相同的错误。希望未来的版本将提供更具描述性的错误消息。

S3Exception: Client error: Request canceled

将以前保存的模型读取到内存中的代码类似于写入 MinIO。使用 reader() 方法,而不是 writer() 方法。下面的代码显示了如何执行此作。

import os

from dotenv import load_dotenv
from s3torchconnector import S3Checkpoint
import torchvision
import torch

# Load the credentials and connection information.
load_dotenv()

model_name = 'resnet18.pth'
bucket_name = 'checkpoints'

checkpoint_uri = f's3://{bucket_name}/{model_name}'
s3_checkpoint = S3Checkpoint(os.environ['AWS_REGION'])

# Load checkpoint from S3
with s3_checkpoint.reader(checkpoint_uri) as reader:
   state_dict = torch.load(reader, weights_only=True)

model.load_state_dict(state_dict)

接下来,让我们看看模型训练期间检查点的一些实际注意事项。

在模型训练期间编写检查点

如果您使用大型数据集训练大型模型,请考虑在每个 epoch 后设置检查点。这些训练运行可能需要数小时甚至数天才能完成,因此在发生故障时能够从上次中断的地方继续非常重要。此外,我们假设您必须使用共享存储桶来保存来自多个团队的多个模型的模型检查点。MLOps 约定是按试验组织训练运行。例如,如果您正在研究具有四个隐藏层的架构,那么在寻找各种超参数的最佳值时,您将使用此架构进行多次运行。如果同事使用五层体系结构运行实验,则需要一种方法来防止名称冲突。这可以通过模拟如下所示的层次结构的对象路径来解决。

最后,为了确保您在每个 epoch 中获得新版本的模型,请确保在用于保存检查点的存储桶上启用版本控制。下面的训练函数使用上述路径结构在每个 epoch 后对模型进行检查点作。(可以在本文的代码下载中找到此训练函数的更强大版本。

def train_model(model: nn.Module, loader: DataLoader, 
                training_parameters: Dict[str, Any]) -> List[float]:

   if training_parameters['checkpoint']:
       checkpoint_uri = f's3://{training_parameters["checkpoint_bucket"]} \
                          /{training_parameters["project_name"]} \
                          /{training_parameters["experiment_name"]} \
                          /{training_parameters["run_id"]} \
                          /{training_parameters["model_name"]}'
       s3_checkpoint = S3Checkpoint(region=os.environ['AWS_REGION'])

   loss_func = nn.NLLLoss()
   optimizer = optim.SGD(model.parameters(), lr=training_parameters['lr'], 
                         momentum=training_parameters['momentum'])

   # Epoch loop
   compute_time_by_epoch = []
   for epoch in range(training_parameters['epochs']):
       # Batch loop
       for images, labels in loader:

           # Flatten MNIST images into a 784 long vector.
           # shape = [32, 784]
           images = images.view(images.shape[0], -1)

           # Training pass
           optimizer.zero_grad()
           output = model(images)
           loss = loss_func(output, labels)
           loss.backward()
           optimizer.step()

       # Save checkpoint to S3
       if training_parameters['checkpoint']:
           with s3_checkpoint.writer(checkpoint_uri) as writer:
               torch.save(model.state_dict(), writer)

请注意,模型名称不包含指示纪元的子字符串。如前所述,我使用了启用了版本控制的存储桶 - 换句话说,版本号表示纪元。这种方法的优点在于,您无需知道引用最新模型的 epoch 数。在上述训练代码运行了 10 个 epoch 后,我的检查点存储桶如下面的屏幕截图所示。

此培训演示可被视为 DIY MLOps 解决方案的开始。

结论

适用于 PyTorch 的 S3 连接器易于使用,工程师在使用时编写的数据访问代码行数会更少。在本文中,我展示了如何将其配置为使用环境变量连接到 MinIO。配置完成后,工程师可以分别使用 writer() 和 reader() 方法将检查点写入和读取 MinIO。在本文中,我展示了如何配置 S3 Connect 以连接到 MinIO。我还演示了 S3Checkpoint 类及其 reader() 和 writer() 方法的基本用法。最后,我展示了一种在实际训练函数中针对启用了版本的检查点存储桶使用这些检查点功能的方法。在这篇文章中,我没有介绍在分布式训练期间检查点所需的技术和工具,这可能有点棘手。分布式训练期间的检查点设置会有所不同,具体取决于您使用的框架(PyTorch、Ray 或 DeepSpeed 等)和您正在进行的分布式训练类型:数据并行(每个工作程序都有模型的完整副本)或模型并行(每个工作程序只有一个模型分片)。在以后的文章中,我将介绍其中的一些技术。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2297269.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

408-数据结构

数据结构在学什么? 1.用代码把问题信息化 2.用计算机处理信息 ch1 数据:数据是信息的载体,是描述客观事物属性的数、字符及所有能输入到计算机中并被计算机程序识别和处理的符号的集合。数据是计算机程序加工的原料。 ch2 //假设线性表…

spring cloud 使用 webSocket

1.引入依赖,(在微服务模块中) <!-- Spring WebSocket --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-websocket</artifactId></dependency> 2.新建文件 package com.ruoyi.founda…

安科瑞 Acrel-2000ES:解锁储能管理新高度,引领能源未来!

安科瑞 崔丽洁 在能源转型的关键时期&#xff0c;高效的储能管理成为众多企业和项目的核心需求。今天&#xff0c;就给大家介绍一款储能管理的 “神器”—— 安科瑞 Acrel-2000ES 储能能量管理系统。 安科瑞电气可是行业内的 “明星企业”&#xff0c;2003 年成立&#xff0c;2…

基于Django以及vue的电子商城系统设计与实现

基于Django以及vue的电子商城系统设计与实现 引言 随着电子商务的快速发展&#xff0c;越来越多的企业和个人选择搭建线上商城&#xff0c;以提供更加便捷的购物体验。本文基于Python开发了一套电子商城系统&#xff0c;后端采用Django框架&#xff0c;前端使用Vue.js&#x…

电脑变慢、游戏卡顿,你的SSD固态可能快坏了!

电脑用久了&#xff0c;很多人都会感觉速度变慢&#xff0c;开机变慢、文件复制时间变长&#xff0c;甚至莫名其妙的卡顿。你可能怀疑是系统问题&#xff0c;或者内存不够&#xff0c;但往往被忽略的一个关键因素——你的硬盘&#xff0c;可能正在悄悄老化。 硬盘寿命不是永久的…

AI使用场景简单测试

前言 今天来分享下AI的2个实用场景&#xff0c;我这里是使用的博主&#xff1a;小虚竹&#xff0c;搭建的AI服务&#xff0c;用的ChatGPT 4O模型&#xff0c;主要是试了3个场景&#xff0c;服装设计、直播带货话术、检验报告分析。 一、服装设计 对于最后需要的裁片设计上的尺寸…

【并发控制、更新、版本控制】.NET开源ORM框架 SqlSugar 系列

系列文章目录 &#x1f380;&#x1f380;&#x1f380; .NET开源 ORM 框架 SqlSugar 系列 &#x1f380;&#x1f380;&#x1f380; 文章目录 系列文章目录一、并发累计&#xff08;累加&#xff09;1.1 单条批量累计1.2 批量更新并且字段11.3 批量更新并且字段list中对应的…

DeepSeek-R1本地搭建

1. 前言 现在deepseek火上天了&#xff0c;因为各种应用场景,加上DeepSeek一直网络异常&#xff0c;所以本地部署Deepseek成为大家的另一种选择。 目前网络上面关于DeepSeek的部署方式有很多&#xff0c;但是太麻烦了&#xff0c;本文是一篇极为简单的DeepSeek本地部署方式&…

查出 product 表中所有 detail 字段包含 xxx 的完整记录

您可以使用以下 SQL 查询语句来查出 product 表中所有 detail 字段包含 oss.kxlist.com 的完整记录&#xff1a; SELECT * FROM product WHERE INSTR(detail, oss.kxlist.com) > 0;下面是detail字段包含的完整内容 <p><img style"max-width:100%;" src…

Redis存储⑥Redis五大数据类型之 Zset

目录 1. Zset 有序集合 1.1 Zset 有序集合常见命令 zadd zcard zcount zrange zrevrange zrangebyscore&#xff08;弃用&#xff09; zpopmax bzpopmax zpopmin bzpopmin zrank zrevrank zscore zrem zremrangebyrank zremrangebyscore zincrby 1.2 Zset有…

将Excel中的图片保存下载并导出

目录 效果演示 注意事项 核心代码 有需要将excel中的图片解析出来保存到本地的小伙子们看过来&#xff01;&#xff01;&#xff01; 效果演示 注意事项 仅支持xlsx格式&#xff1a;此方法适用于Office 2007及以上版本的.xlsx文件&#xff0c;旧版.xls格式无法使用。 图片名…

SQL注入之布尔和时间盲注,sqli-labs

实验环境&#xff1a; sqli-labs&#xff0c;小皮面板搭建&#xff0c;edge浏览器 apache&#xff1a;2.4.39&#xff0c;MySQL&#xff1a;5.7 PHP&#xff1a;5.39 Python&#xff08;pycharm2023&#xff09;:3 less-8 布尔盲注&#xff1a; 1.我这里是采用最简单的直接采…

基于云计算、大数据与YOLO设计的火灾/火焰目标检测

摘要&#xff1a;本研究针对火灾早期预警检测需求&#xff0c;采用在Kaggle平台获取数据、采用云计算部署的方式&#xff0c;以YOLO11构建模型&#xff0c;使用云计算服务器训练模型。经训练&#xff0c;box loss从约3.5降至1.0&#xff0c;cls loss从约4.0降至1.0&#xff0c;…

YOLO自定义数据集实现K折交叉验证——K-Fold Cross Validation

实现K折交叉验证&#xff08;K-Fold Cross Validation&#xff09;对于YOLO&#xff08;You Only Look Once&#xff09;自定义数据集的目标检测任务可以显著提升模型的可靠性和泛化能力。 1. 数据集准备 首先&#xff0c;你需要确保你的数据集符合YOLO的格式&#xff0c;具体…

go语言简单快速的按顺序遍历kv结构(map)

文章目录 需求描述用map实现按照map的key排序用二维切片实现用结构体实现 需求描述 在go语言中&#xff0c;如果需要对map遍历&#xff0c;每次输出的顺序是不固定的&#xff0c;可以考虑存储为二维切片或结构体。 假设现在需要在页面的下拉菜单中展示一些基础的选项&#xff…

【竞技宝】LOL-LPL:EDG3-0零封LNG

北京时间2月12日,英雄联盟LPL2025正在如火如荼的进行之中,昨日迎来LNG对阵EDG,以下是本场比赛的详细战报。 第一局: EDG:杰斯、赵信、维克托、女枪、芮尔 LNG:猴子、猪妹、飞机、韦鲁斯、布隆 首局比赛,EDG在蓝色方,LNG在红色方。阵容方面,EDG点出了杰斯、赵信、维克托、女枪…

在fedora41中安装钉钉dingtalk_7.6.25.4122001_amd64

在Fedora-Workstation-Live-x86_64-41-1.4中安装钉钉dingtalk_7.6.25.4122001_amd64.deb 到官网下载钉钉Linux客户端com.alibabainc.dingtalk_7.6.25.4122001_amd64.deb https://page.dingtalk.com/wow/z/dingtalk/simple/ddhomedownload#/ 一、直接使用dpkg命令安装deb包报错…

看期货用的指标,可以提示买卖点和K线转折变颜色的主图指标源码下载

A:MA(CLOSE,17)ABS(MA(CLOSE,17)-REF(MA(CLOSE,17),1)); B:MA(CLOSE,17)MA(CLOSE,17)-REF(MA(CLOSE,17),1); 分界线:IF(MA(CLOSE,17)<B,B,MA(CLOSE,17)),COLORFF00FF,LINETHICK2; 操作线:分界线-(EMA(C,3)-分界线),COLOR00FFFF,LINETHICK2; GUP:MA(C,5),COLORWHITE,LINE…

【PS 2022】Adobe Genuine Service Alert 弹出

电脑总是弹出Adobe Genuine Service Alert弹窗 1. 不关掉弹窗并打开任务管理器&#xff0c;找到Adobe Genuine Service Alert&#xff0c;并右键进入文件所在位置 2 在任务管理器中结束进程并将文件夹中的 .exe 文件都使用空文档替换掉 3. 打开PS不弹出弹窗&#xff0c;解决&a…

30天开发操作系统 第 20 天 -- API

前言 大家早上好&#xff0c;今天我们继续努力哦。 昨天我们已经实现了应用程序的运行, 今天我们来实现由应用程序对操作系统功能的调用(即API, 也叫系统调用)。 为什么这样的功能称为“系统调用”(system call)呢&#xff1f;因为它是由应用程序来调用(操作)系统中的功能来完…