利用Accelerate()进行pytorch的多GPU加速

news2024/11/24 4:22:08

简介

官方Github:https://github.com/huggingface/accelerate

Accelerate 是为喜欢编写PyTorch模型的训练循环但不愿意编写和维护使用多GPU/TPU/fp16所需的样板代码的PyTorch用户创建的。

它可以仅加速与多 GPU/TPU/fp16 相关的样板代码,并保持其余代码不变。

示例如下

  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
+ from accelerate import Accelerator

+ accelerator = Accelerator()
- device = 'cpu'
+ device = accelerator.device

  model = torch.nn.Transformer().to(device)
  optimizer = torch.optim.Adam(model.parameters())

  dataset = load_dataset('my_dataset')
  data = torch.utils.data.DataLoader(dataset, shuffle=True)

+ model, optimizer, data = accelerator.prepare(model, optimizer, data)

  model.train()
  for epoch in range(10):
      for source, targets in data:
          source = source.to(device)
          targets = targets.to(device)

          optimizer.zero_grad()

          output = model(source)
          loss = F.cross_entropy(output, targets)

-         loss.backward()
+         accelerator.backward(loss)

          optimizer.step()

正如您在此示例中所看到的,通过向任何标准 PyTorch 训练脚本添加 5 行代码,您现在可以在任何类型的单节点或分布式节点设置(单 CPU、单 GPU、多 GPU 和 TPU)以及使用或不使用混合精度(fp8、fp16、bf16)上运行。

特别是,无需修改即可在本地计算机上运行相同的代码,以进行调试或训练环境。

Accelerate 甚至会为您处理设备放置(这需要对代码进行更多更改,但通常更安全),因此您甚至可以进一步简化训练循环:

  import torch
  import torch.nn.functional as F
  from datasets import load_dataset
+ from accelerate import Accelerator

- device = 'cpu'
+ accelerator = Accelerator()

- model = torch.nn.Transformer().to(device)
+ model = torch.nn.Transformer()
  optimizer = torch.optim.Adam(model.parameters())

  dataset = load_dataset('my_dataset')
  data = torch.utils.data.DataLoader(dataset, shuffle=True)

+ model, optimizer, data = accelerator.prepare(model, optimizer, data)

  model.train()
  for epoch in range(10):
      for source, targets in data:
-         source = source.to(device)
-         targets = targets.to(device)

          optimizer.zero_grad()

          output = model(source)
          loss = F.cross_entropy(output, targets)

-         loss.backward()
+         accelerator.backward(loss)

          optimizer.step()

启动脚本:

 accelerate launch my_script.py --args_to_my_script

此 CLI 工具是可选的,您仍然可以使用或在方便时使用。python my_script.pypython -m torchrun my_script.py

如果你不想运行,你也可以直接将你想要的参数作为参数传递给。torchrunaccelerate launch accelerate config

例如,以下是在两个 GPU 上启动的方法:

accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py

更多查看官方提供的CLI文档:https://huggingface.co/docs/accelerate/package_reference/cli

如果想指定GPU,可以通过在终端运行 accelerate config 命令进行配置,根据需求会最终生成一个 default_config.yaml 文件如下:

这里面就可以设置对应的gpu_ids。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2160958.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Pyspark dataframe基本内置方法(5)

文章目录 Pyspark sql DataFrame相关文章toDF 设置新列名toJSON row对象转换json字符串toLocallterator 获取迭代器toPandas 转换python dataframetransform dataframe转换union unionALL 并集不去重(按列顺序)unionByName 并集不去重(按列名…

jenkins声明式流水线语法详解

最基本的语法包含 pipeline:所有有效的声明式流水线必须包含在一个 pipeline 块中stages:包含一系列一个或多个stage指令stage:stage包含在stages中进行,比如某个阶段steps:在阶段中具体得执行操作,一个或…

提升工作效率神器

这五款软件让你事半功倍 在当今快节奏的社会中,提高工作效率成为了每个人追求的目标。而在这个数字化时代,选择对的软件工具无疑是提高效率的关键。今天,我为大家推荐五款优秀的工作效率软件,帮助你在工作中事半功倍。 1、亿可达…

15个 Jenkins 面试题

Jenkins 已成为持续集成和持续部署 (CI/CD) 流程中使用最广泛的自动化服务器之一。凭借其强大的功能和广泛的插件生态系统,Jenkins 已成为全球软件开发团队的首选工具。如果您正在准备 Jenkins 面试,那么精通其概念、架构和最佳实践至关重要。 为了帮助…

1.3 MySql的用户管理

一、下载Mysql客户端 下载navicat:Navicat 中国 | 支持 MySQL、Redis、MariaDB、MongoDB、SQL Server、SQLite、Oracle 和 PostgreSQL 的数据库管理 二、安装Navicat 三、创建数据库 创建一个数据库的连接吧,因为这个界面儿是图形界面儿,所以我们创建…

深入分析MySQL事务日志-Redo Log日志

文章目录 事务日志-Redo Log2.1 Redo Log2.1.1 Redo Log与持久性2.1.2 Redo Log的工作原理2.1.3 Redo Log的落盘策略2.1.4 Redo Log的系统参数 事务日志-Redo Log 事务的隔离性是通过锁实现,而事务的原子性、和持久性则是通过事务日志实现。在MySQL中,事…

【吉林大学编译原理题库】正则表达式的书写

1. 2. 选A 3. 没啥好说的,按意思写就行: 4. 5.设字母表S{0,1},写正则表达式表示所有偶数个0和偶数个1组成的字符串。 6. 设字母表S{0,1},写正则表达式表示所有偶数个0和奇数个1组成的字符串。(提示&am…

Token usage of Content Filtered messages in Azure OpenAI Services

题意:在Azure OpenAI服务中,内容过滤消息的令牌使用 问题背景: When sending a message to a chat via GetChatCompletions as a response, I get a RequestFailedException. In the exception, I get an answer for which category content…

2-101基于matlab的频带方差端点检测

基于matlab的频带方差端点检测,噪声频谱中,各频带之间变化很平缓,语音各频带之间变化较激烈。据此特征,语音和噪声就极易区分。计算短时频带方差,实质就是计算某一帧信号的各频带能量之间的方差。这种以短时频带方差作…

揭秘MySQL主从复制:打造高可用性与数据冗余的强效引擎

作者简介:我是团团儿,是一名专注于云计算领域的专业创作者,感谢大家的关注 座右铭: 云端筑梦,数据为翼,探索无限可能,引领云计算新纪元 个人主页:团儿.-CSDN博客 目录 前言&#…

从Web2到Web3:探索下一代互联网的无限可能性

互联网经历了从Web1到Web2的重大变革,现在正迈向Web3。Web2通过社交媒体、电子商务和内容平台改变了我们的数字生活,但同时也伴随着中心化平台的垄断和用户数据被广泛控制的问题。而Web3的出现,则试图通过去中心化技术解决这些挑战&#xff0…

人到中年,最清醒的活法—沉浸式做自己

生活中,你是不是常常被这样的事情所困扰? 工作的时候,每天被千头万绪的杂事缠身,看着一堆待完成事项,和工作群里一堆的消息在轰炸你,内心顿感烦躁甚至暴怒。 经常因为领导,同事或者熟人甚至陌生…

java 洛谷题单【算法1-7】搜索

P1219 [USACO1.5] 八皇后 Checker Challenge 解题思路 回溯法 递归与回溯: 从第0行开始,为每个行尝试放置棋子的位置,检查放置是否违反约束条件。如果放置合法,则继续递归处理下一行(即下一层递归)。如果当前行无法找…

【Go语言】深入解读Go语言中的指针,助你拨开迷雾见月明

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

浅谈提示工程之In-context learning技术

提示工程之In-context learning技术; 通过一张图片围绕下边几个方面进行简单说明 概念起因本质结构注意事项 日常总结

SQL语法学习与实战应用

第一章 引言 1.1 MySQL数据库概述 MySQL,作为一种广泛使用的关系型数据库管理系统,自其问世以来,便凭借开源、高性能及低成本等显著特点,迅速占据了广泛的市场份额。这一系统不仅支持大规模并发访问,更提供了多样化的…

【最新华为OD机试E卷-支持在线评测】绘图机器(100分)多语言题解-(Python/C/JavaScript/Java/Cpp)

🍭 大家好这里是春秋招笔试突围 ,一枚热爱算法的程序员 💻 ACM金牌🏅️团队 | 大厂实习经历 | 多年算法竞赛经历 ✨ 本系列打算持续跟新华为OD-E/D卷的多语言AC题解 🧩 大部分包含 Python / C / Javascript / Java / Cpp 多语言代码 👏 感谢大家的订阅➕ 和 喜欢�…

【ARM】MDK-当选择AC5时每次点击build都会全编译

1、 文档目标 解决MDK中选择AC5时每次点击build都会全编译 2、 问题场景 在MDK中点击build时,正常会只进行增量编译,但目前每次点击的时候都会全编译。 3、软硬件环境 1 软件版本:Keil MDK 5.38a 2 电脑环境:Window 10 4、解决…

新手操作指引:快速上手腾讯混元大模型

引言 腾讯混元大模型是一款功能强大的AI工具,适用于文本生成、图像创作和视频生成等多种应用场景。对于新手用户,快速上手并充分利用这一工具可能会有些挑战。本文将提供详细的新手操作指引,帮助您轻松开始使用腾讯混元大模型。 步骤一&…

kubernetes网络(二)之bird实现节点间BGP互联的实验

摘要 上一篇文章中我们学习了calico的原理,kubernetes中的node节点,利用 calico 的 bird 程序相互学习路由,为了加深对 bird 程序的认识,本文我们将使用bird进行实验,实验中实现了BGP FULL MESH模式让宿主相互学习到对…