让PyTorch训练速度更快,你需要掌握这17种方法

news2025/1/12 3:07:08
掌握这 17 种方法,用最省力的方式,加速你的 Pytorch 深度学习训练。

近日,Reddit 上一个帖子热度爆表。主题内容是关于怎样加速 PyTorch 训练。原文作者是来自苏黎世联邦理工学院的计算机科学硕士生 LORENZ KUHN,文章向我们介绍了在使用 PyTorch 训练深度模型时最省力、最有效的 17 种方法

该文所提方法,都是假设你在 GPU 环境下训练模型。具体内容如下。

17 种加速 PyTorch 训练的方法

1. 考虑换一种学习率 schedule

学习率 schedule 的选择对模型的收敛速度和泛化能力有很大的影响。Leslie N. Smith 等人在论文《Cyclical Learning Rates for Training Neural Networks》、《Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates 》中提出了周期性(Cyclical)学习率以及 1Cycle 学习率 schedule。之后,fast.ai 的 Jeremy Howard 和 Sylvain Gugger 对其进行了推广。下图是 1Cycle 学习率 schedule 的图示:

Sylvain 写到:1Cycle 包括两个等长的步幅,一个步幅是从较低的学习率到较高的学习率,另一个是回到最低水平。最大值来自学习率查找器选取的值,较小的值可以低十倍。然后,这个周期的长度应该略小于总的 epochs 数,并且,在训练的最后阶段,我们应该允许学习率比最小值小几个数量级。

与传统的学习率 schedule 相比,在最好的情况下,该 schedule 实现了巨大的加速(Smith 称之为超级收敛)。例如,使用 1Cycle 策略在 ImageNet 数据集上训练 ResNet-56,训练迭代次数减少为原来的 1/10,但模型性能仍能比肩原论文中的水平。在常见的体系架构和优化器中,这种 schedule 似乎表现得很好。

Pytorch 已经实现了这两种方法:「torch.optim.lr_scheduler.CyclicLR」和「torch.optim.lr_scheduler.OneCycleLR」。

参考文档:https://pytorch.org/docs/stable/optim.html

2. 在 DataLoader 中使用多个 worker 和页锁定内存

当使用 torch.utils.data.DataLoader 时,设置 num_workers > 0,而不是默认值 0,同时设置 pin_memory=True,而不是默认值 False。

参考文档:https://pytorch.org/docs/stable/data.html

来自 NVIDIA 的高级 CUDA 深度学习算法软件工程师 Szymon Micacz 就曾使用四个 worker 和页锁定内存(pinned memory)在单个 epoch 中实现了 2 倍的加速。人们选择 worker 数量的经验法则是将其设置为可用 GPU 数量的四倍,大于或小于这个数都会降低训练速度。请注意,增加 num_workers 将增加 CPU 内存消耗。

3. 把 batch 调到最大

把 batch 调到最大是一个颇有争议的观点。一般来说,如果在 GPU 内存允许的范围内将 batch 调到最大,你的训练速度会更快。但是,你也必须调整其他超参数,比如学习率。一个比较好用的经验是,batch 大小加倍时,学习率也要加倍。

OpenAI 的论文《An Empirical Model of Large-Batch Training》很好地论证了不同的 batch 大小需要多少步才能收敛。在《How to get 4x speedup and better generalization using the right batch size》一文中,作者 Daniel Huynh 使用不同的 batch 大小进行了一些实验(也使用上面讨论的 1Cycle 策略)。最终,他将 batch 大小由 64 增加到 512,实现了 4 倍的加速。

然而,使用大 batch 的不足是,这可能导致解决方案的泛化能力比使用小 batch 的差。

4. 使用自动混合精度(AMP)

PyTorch 1.6 版本包括对 PyTorch 的自动混合精度训练的本地实现。这里想说的是,与单精度 (FP32) 相比,某些运算在半精度 (FP16) 下运行更快,而不会损失准确率。AMP 会自动决定应该以哪种精度执行哪种运算。这样既可以加快训练速度,又可以减少内存占用。

在最好的情况下,AMP 的使用情况如下:

import torch
# Creates once at the beginning of training
scaler = torch.cuda.amp.GradScaler()


for data, label in data_iter:
   optimizer.zero_grad()
   # Casts operations to mixed precision
   with torch.cuda.amp.autocast():
      loss = model(data)

   # Scales the loss, and calls backward()
   # to create scaled gradients
   scaler.scale(loss).backward()

   # Unscales gradients and calls
   # or skips optimizer.step()
   scaler.step(optimizer)

   # Updates the scale for next iteration
   scaler.update()

5. 考虑使用另一种优化器

AdamW 是由 fast.ai 推广的一种具有权重衰减(而不是 L2 正则化)的 Adam,在 PyTorch 中以 torch.optim.AdamW 实现。AdamW 似乎在误差和训练时间上都一直优于 Adam。

Adam 和 AdamW 都能与上面提到的 1Cycle 策略很好地搭配。

目前,还有一些非本地优化器也引起了很大的关注,最突出的是 LARS 和 LAMB。NVIDA 的 APEX 实现了一些常见优化器的融合版本,比如 Adam。与 PyTorch 中的 Adam 实现相比,这种实现避免了与 GPU 内存之间的多次传递,速度提高了 5%。

6. cudNN 基准

如果你的模型架构保持不变、输入大小保持不变,设置 torch.backends.cudnn.benchmark = True。

7. 小心 CPU 和 GPU 之间频繁的数据传输

当频繁地使用 tensor.cpu() 将张量从 GPU 转到 CPU(或使用 tensor.cuda() 将张量从 CPU 转到 GPU)时,代价是非常昂贵的。item() 和 .numpy() 也是一样可以使用. detach() 代替。

如果你创建了一个新的张量,可以使用关键字参数 device=torch.device('cuda:0') 将其分配给 GPU。

如果你需要传输数据,可以使用. to(non_blocking=True),只要在传输之后没有同步点。

8. 使用梯度 / 激活 checkpointing

Checkpointing 的工作原理是用计算换内存,并不存储整个计算图的所有中间激活用于 backward pass,而是重新计算这些激活。我们可以将其应用于模型的任何部分。

具体来说,在 forward pass 中,function 会以 torch.no_grad() 方式运行,不存储中间激活。相反的是, forward pass 中会保存输入元组以及 function 参数。在 backward pass 中,输入和 function 会被检索,并再次在 function 上计算 forward pass。然后跟踪中间激活,使用这些激活值计算梯度。

因此,虽然这可能会略微增加给定 batch 大小的运行时间,但会显著减少内存占用。这反过来又将允许进一步增加所使用的 batch 大小,从而提高 GPU 的利用率。

尽管 checkpointing 以 torch.utils.checkpoint 方式实现,但仍需要一些思考和努力来正确地实现。Priya Goyal 写了一个很好的教程来介绍 checkpointing 关键方面。

Priya Goyal 教程地址:

https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb

9. 使用梯度积累

增加 batch 大小的另一种方法是在调用 optimizer.step() 之前在多个. backward() 传递中累积梯度。

Hugging Face 的 Thomas Wolf 的文章《Training Neural Nets on Larger Batches: Practical Tips for 1-GPU, Multi-GPU & Distributed setups》介绍了如何使用梯度累积。梯度累积可以通过如下方式实现:


model.zero_grad()                                   # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradients tensors
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
            evaluate_model()                        # ...have no gradients accumulate

这个方法主要是为了规避 GPU 内存的限制而开发的。

10. 使用分布式数据并行进行多 GPU 训练

加速分布式训练可能有很多方法,但是简单的方法是使用 torch.nn.DistributedDataParallel 而不是 torch.nn.DataParallel。这样一来,每个 GPU 将由一个专用的 CPU 核心驱动,避免了 DataParallel 的 GIL 问题。

分布式训练文档地址:https://pytorch.org/tutorials/beginner/dist_overview.html

11. 设置梯度为 None 而不是 0

梯度设置为. zero_grad(set_to_none=True) 而不是 .zero_grad()。这样做可以让内存分配器处理梯度,而不是将它们设置为 0。正如文档中所说,将梯度设置为 None 会产生适度的加速,但不要期待奇迹出现。注意,这样做也有缺点,详细信息请查看文档。

文档地址:https://pytorch.org/docs/stable/optim.html

12. 使用. as_tensor() 而不是. tensor()

torch.tensor() 总是会复制数据。如果你要转换一个 numpy 数组,使用 torch.as_tensor() 或 torch.from_numpy() 来避免复制数据。

13. 必要时打开调试工具

PyTorch 提供了很多调试工具,例如 autograd.profiler、autograd.grad_check、autograd.anomaly_detection。请确保当你需要调试时再打开调试器,不需要时要及时关掉,因为调试器会降低你的训练速度。

14. 使用梯度裁剪

关于避免 RNN 中的梯度爆炸的问题,已经有一些实验和理论证实,梯度裁剪(gradient = min(gradient, threshold))可以加速收敛。HuggingFace 的 Transformer 实现就是一个非常清晰的例子,说明了如何使用梯度裁剪。本文中提到的其他一些方法,如 AMP 也可以用。

在 PyTorch 中可以使用 torch.nn.utils.clip_grad_norm_来实现。

15. 在 BatchNorm 之前关闭 bias

在开始 BatchNormalization 层之前关闭 bias 层。对于一个 2-D 卷积层,可以将 bias 关键字设置为 False:torch.nn.Conv2d(..., bias=False, ...)。

16. 在验证期间关闭梯度计算

在验证期间关闭梯度计算,设置:torch.no_grad() 。

17. 使用输入和 batch 归一化

要再三检查一下输入是否归一化?是否使用了 batch 归一化?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/423923.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用Chat GPT建立一个To-Do应用程序--我们终于遇到了我们的替代者吗?

海外Udemy、Coursera、Skillshare、Cantrill等平台精品编码课程,请访问 https://www.postcode.vip 我们看到GitHub Copilot在2021年10月发布,整个开发社区都疯了。 有些人声称我们很快就会失去工作,而其他人,像我一样&#xff0…

首家完成并购并进行重新备案公示的企业征信牌照公司-湖南省征信

2023年4月13日,中国人民银行长沙中心支行发布《关于对湖南省征信有限公司企业征信机构变更备案的公示》。内容显示中国人民银行长沙中心支行根据《征信业管理条例》《征信机构管理办法》《企业征信机构备案管理办法》及有关规定,决定受理湖南省征信有限公…

小程序学习四--组件--样式、数据、方法、属性、数据监听、生命周期、插槽、behavior

一、自定义组件 1.创建组件 2.组件引用--局部引用 3.组件引用--全局引用 4.组件和页面的区别 5.修改组件胡样式隔离选项 stypelsolation的可选值 二、自定义组件数据、方法、属性和数据监听 1.data数据 2.methods方法 事件处理函数、自定义方法_ 3.properties属性 页面中调…

JVM 内存结构

文章目录1、程序计数器2、虚拟机栈2.1 、定义2.2、栈内存溢出2.3 、线程运行诊断3、本地方法栈4、堆4.1、定义4.2 、堆内存溢出4.3 、堆内存诊断5、方法区(Method Area)5.1 、定义5.2、方法区组成5.3 、方法区内存溢出5.4 、运行时常量池5.5 、StringTab…

【JavaEE】TCP网络原理

目录 1.TCP协议定义 2.TCP原理 2.1确认应答机制 2.2超时重传机制 2.3连接管理 2.3.1建立连接(三次握手) 2.3.2断开连接(四次挥手) 2.4滑动窗口 2.5流量控制 2.6拥塞控制 2.7延迟应答 2.8捎带应答 2.9面向字节流&…

【STC8A8K64D4开发板】——按键检测

第2-3讲:按键检测 学习目的学习轻触按键和触摸按键硬件电路原理。学习STC8A8K64D4用作输入时相关寄存器的配置。掌握如何读取GPIO状态。掌握编写轻触按键和触摸按键检测程序。 硬件电路设计 IK-64D4开发板上设计了4个轻触按键和一个触摸按键,提供给用户作…

企业级信息系统开发讲课笔记2.3 利用MyBatis实现关联查询

文章目录零、本节学习目标一、查询需求(一)针对三张表关联查询(二)按班级编号查询班级信息(三)查询全部班级信息二、创建数据库表(一)创建教师表(二)创建班级…

BUUCTF-WEB-INF/web.xml泄露-SSTI注入

第八周 目录 WEB [RoarCTF 2019]Easy Java WEB-INF/web.xml泄露 WEB-INF/web.xml泄露原因 WEB-INF/web.xml泄露利用方法 解决方法 [BJDCTF2020]The mystery of ip 什么是板块注入 SSTI 为什么会产生 什么是render_template render_template: 我们为什么…

背包问题-动态规划

背包问题 容量有限的背包&#xff0c;给很多商品&#xff0c;商品有相应的体积与价值 01背包问题 一个背包 每个物品只有一个 最终状态方程 dp[i][j]max(dp[i-1][j],dp[i-1][j-w[i]]v[i]) 推导过程 由上一层推导过来 选择拿不拿i 最终代码 #include<iostream> #…

第12届蓝桥杯省赛真题剖析-2020年12月20日Scratch编程中级组

[导读]&#xff1a;超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成&#xff0c;后续会不定期解读蓝桥杯真题&#xff0c;这是Scratch蓝桥杯真题解析第124讲。 第12届蓝桥杯省赛举办了两次&#xff0c;这是2020年10月20日举行的第一次省赛中级组试题&#xff0c;比赛仍…

【Java实战篇】Day6.在线教育网课平台

文章目录一、需求&#xff1a;绑定媒资1、需求分析2、库表设计与模型类3、接口定义4、Mapper层开发5、Service层开发6、完善controller层二、需求&#xff1a;课程预览1、需求分析2、实现技术3、模板引擎4、Freemarker入门5、部署网站门户6、接口定义7、接口开发8、编写模板9、…

放弃 console.log 吧!用 Debugger 你能读懂各种源码

很多同学不知道为什么要用 debugger 来调试&#xff0c;console.log 不行么&#xff1f; 还有&#xff0c;会用 debugger 了&#xff0c;还是有很多代码看不懂&#xff0c;如何调试复杂源码呢&#xff1f; 这篇文章就来讲一下为什么要用这些调试工具&#xff1a; console.lo…

PostgreSQL技术内幕(七)索引扫描

索引概述 数据库索引&#xff0c;是将一个表的某些字段的数据进行重新组织的数据库对象。通过使用索引&#xff0c;可以大大加速数据库的一些操作&#xff0c;其背后的思想也很简单朴素&#xff1a;空间换时间。 数据库中的索引&#xff0c;可以类比为一本书的目录&#xff0…

linux java中使用POI将word转为PDF时无法显示文字

背景: 在windos上本地调试时使用POI将word转为PDF时, PDF无法显示文字的原因以及解决方案: 我的是在linux7.9上&#xff0c;原因是生成world时候汉字正常&#xff0c;转pdf时没有汉字&#xff0c;多次调查后发现没有 宋体: 原因1:字体不存在问题, word中使用的字体在系统(wind…

udp 版本的 echo server 和 echo client

文章目录前言UDP数据报套接字编程什么是套接字套接字的api示例&#xff1a;一发一收&#xff08;无响应&#xff09;客户端服务端前言 基于udp socket写一个最简单的客户端服务器程序. UDP数据报套接字编程 什么是套接字 我们先来解释一下什么是套接字吧! 套接字&#xff0…

流浪地球2:AI人工智能+数字生命+元宇宙

推荐&#xff1a;将 NSDT场景编辑器 加入你的3D开发工具链剧情介绍 太阳危机 太阳即将老化膨胀&#xff0c;吞没太阳系&#xff0c;地球上的人类构思了各种生存计划&#xff1a;其一是“数字生命计划”&#xff0c;该计划制造强大的量子计算机&#xff0c;希望让人类在数字世界…

D. Omkar and Circle(非常有意思的一道题)

Problem - D - Codeforces 丹尼是当地的数学狂人&#xff0c;他对圆形很着迷&#xff0c;这是奥姆卡最近的发明。帮他解决这个圆的问题!已知n个非负整数a1, a2&#xff0c;&#xff0c; an&#xff0c;它们排成一个圆&#xff0c;其中n必须是奇数。n -1能被2整除)。形式上&…

基于Tensorflow搭建卷积神经网络CNN(人脸识别)保姆及级教程

项目介绍 TensorFlow2.X 搭建卷积神经网络&#xff08;CNN&#xff09;&#xff0c;实现人脸识别&#xff08;可以识别自己的人脸哦&#xff01;&#xff09;。搭建的卷积神经网络是类似VGG的结构(卷积层与池化层反复堆叠&#xff0c;然后经过全连接层&#xff0c;最后用softm…

KD-2125地下电缆测试仪

一、产品概述 管线探测仪是一套高性能地下金属管线探测系统&#xff0c;由信号发射机和接收机组成&#xff0c;可用于金属管线、地下电缆的路径探测、管线普查和深度测量&#xff0c;配合多种选配附件&#xff0c;可以进行唯一性鉴别&#xff0c;以及管道绝缘破损和部分类型电缆…

HTML—javaEE

文章目录1.认识HTML2.HTML标签的使用2.1注释2.2标题2.3段落2.4换行2.5字体加粗、斜体字、删除线、下划线2.6图片2.7超链接2.8表格2.9列表2.10表单标签2.11div2.12span3.HTML特殊符号1.认识HTML &#xff08;1&#xff09;HTML是网页的编程语言&#xff0c;文件的内容主要由“标…