PyTorch | 加速模型训练的妙招

news2025/1/23 7:15:36

引言

alt

提升机器学习模型的训练速度是每位机器学习工程师的共同追求。训练速度的提升意味着实验周期的缩短,进而加速产品的迭代过程。同时,这也表示在进行单一模型训练时,所需的资源将会减少。简而言之,我们追求的是效率。

熟悉 PyTorch profiler

在进行任何优化之前,首先需要了解代码中各个部分的执行时长。Pytorch profiler 是一款功能全面的训练性能分析工具,能够捕捉以下信息:

  • CPU 操作的耗时
  • CUDA 核心的运行时间
  • 内存使用情况的历史记录

这些就是你需要关注的所有内容。而且,使用起来非常简单!记录这些事件的方法是,将训练过程封装在一个 profiler 的上下文环境中,操作方式如下:

import torch.autograd.profiler as profiler

with profiler.profile(
  activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
  on_trace_ready=torch.profiler.tensorboard_trace_handler('./logs'),
as prof:
  train(args)

之后,您可以启动张量板并查看分析跟踪。Profiler 提供了众多选项,但最关键的是 "activities" 和 "profile_memory" 这两个功能。尽管你可以探索其他功能,但请记住一个基本原则:启用的选项越少,性能开销也就越低。

例如,如果你的目的是分析 CUDA 内核的执行时间,那么最好的做法是关闭 CPU 分析和其他所有功能。这样,分析结果会更贴近实际的执行情况。

为了让分析结果更易于理解,建议添加一些描述代码关键部分的分析上下文。如果分析功能没有被激活,这些上下文就不会产生任何影响。

with profiler.record_function("forward_pass"):
  result = model(**batch)

with profiler.record_function("train_step"):
  step(**result)

这样,您使用的标签将在迹线中可见。因此,识别代码块会更容易。或者更精细的内部模式的前进:

with profiler.record_function("transformer_layer:self_attention"):
  data = self.self_attention(**data)

...

with profiler.record_function("transformer_layer:encoder_attention"):
  data = self.encoder_attention(**data, **encoder_data)

了解 PyTorch traces

收集traces后,在张量板中打开它们。 CPU + CUDA 配置文件如下所示:

alt

立刻识别出任何训练过程中的关键环节:

  • 数据加载
  • 前向传播
  • 反向传播

PyTorch 会在一个独立线程中处理反向传播(如上图所示的线程 16893),这使得它很容易被识别出来。

数据加载

在数据加载方面,我们追求极致的效率,即几乎不耗费时间。

原因在于,在数据加载的过程中,GPU 闲置不工作,这导致资源没有得到充分利用。但是,由于数据处理和 GPU 计算是两个独立的部分,它们可以同时进行。

你可以通过查看分析器跟踪中的 GPU 估计 SM 效率和 GPU 利用率来轻松识别 GPU 空闲的区域。那些活动量为零的区域就是我们需要注意的问题所在。在这些区域,GPU 并没有参与任何工作。

解决这个问题的一个简单方法是:

  • 在后台进程中进行数据处理,这样不会受到全局解释器锁(GIL)的限制。
  • 通过并行进程来同时执行数据增强和转换操作。 如果你使用的是 PyTorch 的 DataLoader,通过设置 num_workers 参数就可以轻松实现这一点。如果你使用的是 IterableDataset,情况会稍微复杂一些,因为数据可能会被重复处理。不过,通过使用 get_worker_info() 方法,你仍然可以解决这个问题——你需要调整迭代方式,确保每个工作进程处理的是互不重叠的不同数据行。

如果你需要更灵活的数据处理方式,你可以考虑使用 multiprocessing 模块来自己实现多进程转换功能。

内存分配器

使用 PyTorch 在 CUDA 设备上分配张量时,PyTorch 会利用缓存分配器来避免执行成本较高的 cudaMalloccudaFree 操作。PyTorch 的分配器会尝试复用之前通过 cudaMalloc 分配的内存块。如果分配器手头有合适的内存块,它将直接提供这块内存,而无需再次调用 cudaMalloc,这样 cudaMalloc 只在程序启动时调用一次。

但是,如果你处理的是长度不一的数据,不同前向传播过程可能需要不同大小的中间张量。这时,PyTorch 的分配器可能没有合适的内存块可用。在这种情况下,分配器会尝试通过调用 cudaFree 释放之前分配的内存块,以便为新的内存分配腾出空间。

释放内存后,分配器会重新开始构建其缓存,这将涉及到大量的 cudaMalloc 调用,这是一个资源消耗较大的操作。你可以通过观察 tensorboard profiler viewer 的内存分析部分来识别这个问题。

alt

请注意,代表分配器预留内存的红线持续波动。这表明 PyTorch 的内存分配器在处理内存请求时遇到了效率问题。

当内存分配在没有触发分配器紧急情况下顺利进行时,你会看到红线保持平稳。

alt

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1919415.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Rust代码优化的九大技巧

一.使用 Cargo 内置的性能分析工具 描述:Cargo 是 Rust 的包管理器,带有内置工具来分析代码性能,以识别性能瓶颈。 解释: 发布模式:在发布模式下编译启用优化,可以显著提高性能。 cargo build --release基…

解决vue多层弹框时存在遮挡问题

本文给大家介绍vue多层弹框时存在遮挡问题,解决思路首先想到的是找到对应的遮挡层的css标签,然后修改z-index值,但是本思路只能解决首次问题,再次打开还会存在相同的问题,故该思路错误,下面给大家带来一种正…

【鸿蒙学习笔记】文件管理

官方文档:Core File Kit简介 目录标题 文件分类什么是应用沙箱? 文件分类 应用文件,比如应用的安装包,自己的资源文件等。用户文件,比如用户自己的照片,录制的音视频等。 什么是应用沙箱? 应…

完美解决:MySQL8报错:Public Key Retrieval is not allowed

在配置数据源的时候直接将属性allowPublicKeyRetrieval设置为true即可 &AutoReconnecttrue

论文发表作图必备:训练结果对比,多结果绘在一个图片【Precision】【Recall】【mAP0.5】【mAP0.5-0.95】【loss】

前言:Hello大家好,我是小哥谈。YOLO(You Only Look Once)算法是一种目标检测算法,它可以在图像中实时地检测和定位目标物体。YOLO算法通过将图像划分为多个网格,并在每个网格中检测目标物体,从而实现快速的目标检测。本文所介绍的作图教程适用于所有YOLO系列版本算法,接…

Linux Ubuntu MySQL环境安装

1. 更新软件源 首先,确保你的Ubuntu系统已经更新了软件源列表,以便能够下载到最新的软件包。打开终端并输入以下命令: sudo apt update 2. 安装MySQL服务器 打开终端并输入以下命令来安装MySQL服务器 sudo apt install mysql-server 在…

代码随想录算法训练营第五十天| 739. 每日温度、496.下一个更大元素 I、503.下一个更大元素II

739. 每日温度 题目链接: 739. 每日温度 文档讲解:代码随想录 状态:不会 思路: 这道题需要找到下一个更大元素。 使用栈来存储未找到更高温度的下标,那么栈中的下标对应的温度从栈底到栈顶是递减的。这意味着&#xff…

鸿蒙语言基础类库:【@ohos.application.testRunner (TestRunner)】 测试

TestRunner TestRunner模块提供了框架测试的能力。包括准备单元测试环境、运行测试用例。 如果您想实现自己的单元测试框架,您必须继承这个类并覆盖它的所有方法。 说明: 开发前请熟悉鸿蒙开发指导文档:gitee.com/li-shizhen-skin/harmony-…

5款常用的漏洞扫描工具,网安人员不能错过!

漏洞扫描是指基于漏洞数据库,通过扫描等手段对指定的远程或者本地计算机系统的安全脆弱性进行检测,发现可利用漏洞的一种安全检测的行为。 在漏洞扫描过程中,我们经常会借助一些漏扫工具,市面上漏扫工具众多,其中有一…

windows环境下基于3DSlicer 源代码编译搭建工程开发环境详细操作过程和中间关键错误解决方法说明

说明: 该文档适用于  首次/重新 搭建3D-Slicer工程环境  Clean up(非增量) 编译生成 1. 3D-slicer 软件介绍 (1)3D Slicer为处理MRI\CT等图像数据软件,可以实行基于MRI图像数据的目标分割、标记测量、坐标变换及三维重建等功能,其源于3D slicer 4.13.0-2022-01-19开…

新火种AI|微软和苹果放弃OpenAI董事会观察员席位

作者:一号 编辑:美美 微软苹果双双不做OpenAI“观察员”,OpenAI能更自由吗? 7月10消息,微软当地时间周一宣布将放弃在OpenAI董事会的观察员席位,他们称,OpenAI在过去八个月中取得了“重大进展…

Java(十八)---单链表

文章目录 前言1.链表的概念及结构2.单链表的创建3.功能的实现3.1.创建链表(create)(需要自己创建)3.2.显示链表(display)3.3.获取链表的个数( size() )3.4.是否包含指定元素(contains)3.5.头插法(addFirst)3.6.尾插法(addLast)3.7.在指定位置进行插入(addIndex)3.8.删除出现在第…

基于Spring Boot的高校后勤餐饮管理系统

1 项目介绍 1.1 研究背景 “互联网”时代的到来,既给高校后勤管理发展带来了机遇,也带来了更大的挑战。信息化应用已经开始普及,传统的高校后勤餐饮管理模式往往存在着效率低下、信息不透明、资源浪费等问题,已经难以满足现代高…

CSS 【实用教程】(2024最新版)

CSS 简介 CSS 是层叠样式表( Cascading Style Sheets ) 的简写,用于精确控制 HTML 页面的样式,以便更好地展示图文信息或产生炫酷/友好的交互体验。 没有必要让所有浏览器都显示得一模一样的,好的浏览器有更好的显示,糟糕的浏览器…

Linux编程第三篇:Linux简介,开源软件简介(Linux是否安全?参考TESEC指标)

业精于勤荒于嬉,行成于思毁于随。 今天这篇算是Linux的正式学习,废话不多说,我们开始吧 第三篇 一、UNIX与Linux发展史1.1、UNIX发展历史和发行版本1.2、UNIX主要发行版本1.3、Linux发展历史1.4、Linux内核版本1.5、Linux主要发行版本 二、开…

多周期路径的约束与设置原则

本节将回顾工具检查建立保持时间的原则,接下来介绍设置多周期后的检查原则。多周期命令是设计约束中常用的一个命令,用来修改默认的建立or保持时间的关系。基本语法如下 默认的建立时间与保持时间的检查方式 DC工具计算默认的建立保持时间关系是基于时钟…

前台线程和后台线程(了解篇)

在多线程编程中,理解线程的不同类型对于编写高效、稳定的程序至关重要。特别地,前台线程(Foreground Threads)与后台线程(Background Threads)在行为上有着根本的区别,这些区别直接影响到程序的…

8、matlab彩色图和灰度图的二值化算法汇总

1、彩色图和灰度图的二值化算法汇总原理及流程 彩色图和灰度图的二值化算法的原理都是将图像中的像素值转化为二值(0或1),以便对图像进行简化或者特定的图像处理操作。下面分别介绍彩色图和灰度图的二值化算法的原理及流程: 1&a…

【java实现结果集转为树结构,树转为扁平结构】

list转为树,树拉平 业务需求oracle实现树结构1、**Controller.java层** :前端调此处请求2、**service层:** 逻辑结构 (zbjcpjService.java),重点:this.entityMapper.queryZbjcpjTree接口3、**ma…

opc ua设备数据 转MQTT项目案例

目录 1 案例说明 1 2 VFBOX网关工作原理 1 3 准备工作 2 4 配置VFBOX网关采集OPC UA的数据 2 5 用MQTT协议转发数据 4 6 配置参数说明 4 7 上报内容配置 5 8 其他说明 8 9 案例总结 8 1 案例说明 设置网关采集OPC UA设备数据把采集的数据转成MQTT协议转发给其他系统。 2 VFB…