深度学习中的并行策略概述:1 单GPU优化

news2024/12/26 11:54:04

深度学习中的并行策略概述:1 单GPU优化
在这里插入图片描述

1 Training Larger Models on a Single GPU

在讨论模型的“扩展”时,往往会想到在多个GPU或多台机器上进行模型训练。不过,即便是在单个GPU上,也存在多种方法来训练更大规模的模型并提升其效率。本文将探讨包括混合精度训练 Mixed Precision Training、激活函数检查点 Activation Recomputation 、梯度累积 Gradient Accumulation 在内的一些技术。这些技术主要致力于降低训练过程中的内存占用,因为在单设备训练中,内存常常是受限的资源。此外,当在多个GPU或TPU上进行训练时,这些技术同样适用。

混合精度训练(Mixed Precision Training)

混合精度训练是一种结合16位和32位浮点数以加速模型训练的技术。16位浮点数(如float16和bfloat16)用于加速计算并减少内存使用,而32位浮点数(float32)用于关键计算以保持数值稳定。bfloat16因其较宽的表示范围,可以在无需损失缩放的情况下使用,适合作为float32的替代品,以节省内存并接近其性能。
在这里插入图片描述
通过将模型的特征和激活值转换为bfloat16,而保持权重和优化器状态为float32,来实现混合精度训练。这种方法在保持高精度更新的同时,减少了内存占用并提高了训练速度。如果模型过大无法完全载入内存,还可以考虑对模型参数和优化器应用更低精度的处理。通过这种方式,可以在不牺牲模型性能的前提下,优化内存使用和训练效率。

PyTorch从1.6版本开始内置了torch.cuda.amp,用于自动混合精度训练。使用自动混合精度(AMP)可以减少显存占用并加速训练。

from torch.cuda.amp import autocast, GradScaler

# 初始化模型、优化器
model = TransformerModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 初始化梯度缩放器
scaler = GradScaler()

# 训练循环
for data, target in dataloader:
    optimizer.zero_grad()
    
    # 使用autocast上下文管理器实现混合精度
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    
    # 缩放损失,以避免在反向传播时出现梯度下溢
    scaler.scale(loss).backward()
    
    # 梯度缩放器会根据需要调整梯度
    scaler.step(optimizer)
    
    # 更新缩放器
    scaler.update()
激活检查点(Activation Checkpointing / Activation Recomputation)

梯度检查点技术通过在反向传播中重新计算部分激活值,以计算换取内存。这种方法在前向传播时仅保存部分激活值,其余在反向传播时重新计算。对于像Transformer这样的大内存占用模型,当激活值的存储成为瓶颈时,此技术尤其有效,因为重新计算激活值的成本通常低于存储它们。

以一个简化的仅包含MLP块的Transformer为例。每个MLP块由两个全连接层组成,中间是GELU激活函数,使用bfloat16格式的激活值(每个激活值2字节)。设批量大小为B,序列长度为S,隐藏层大小为H。前向传播中激活值的内存消耗总计为(2BSH+8BSH+8BSH+BSH)19BSH字节。
在这里插入图片描述
采用梯度检查点,可以选择仅保留大小为2BSH的输入张量,并在反向传播中重新计算其他激活值。这几乎可以减少90%的激活值内存消耗,代价是需要在反向传播中重新计算这些激活值。

激活检查点是一种节省显存的技术,它通过在反向传播时重新计算前向传播中的中间激活值,而不是保存它们。

from torch.utils.checkpoint import checkpoint

# 定义一个使用激活检查点的Transformer模型部分
def partial_forward(model, inputs):
    activation = model(inputs)
    return activation

# 在训练循环中使用激活检查点
for data, target in dataloader:
    optimizer.zero_grad()
    
    # 使用checkpoint函数实现激活检查点
    with autocast():
        activation = checkpoint(partial_forward, model, data)
    
    # 继续计算损失
    output = model(activation)
    loss = criterion(output, target)
    
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
梯度累积(Gradient Accumulation)

训练大型模型时,常在批量大小和内存需求间权衡。增加批量大小能够提高梯度估计的准确性,但也会增加内存消耗。面对硬件内存限制批量大小的情况,梯度累积技术通过累积多个小批量的梯度来模拟大批量训练的效果。每个小批量独立处理,所有小批量处理完毕后再统一更新模型。这种方法在内存受限但需要大批量训练时特别有用。然而,梯度累积的缺点在于其串行处理小批量数据,未能实现并行化,因此需要确保即使在小批量训练时也能最大化硬件的利用率。
在这里插入图片描述
在示意图中,展示了一个总批量大小为8的梯度累积过程,该过程通过4个大小为2的小批量(称为迷你批次)来实现。每处理完一个迷你批次,即可释放其前向传播和反向传播过程中产生的所有中间数组,随后继续处理下一个迷你批次。待所有迷你批次处理完毕,便执行优化器的更新步骤。这种方法使得我们能够以仅需相当于批量大小2的内存需求,模拟出批量大小为8的训练效果。

梯度累积允许模拟更大的批量大小,通过累积多个小批量的梯度。

# 设置梯度累积参数
accum_iter = 4

# 训练循环
for batch_idx, (data, target) in enumerate(dataloader):
    optimizer.zero_grad()
    
    with autocast():
        output = model(data)
        loss = criterion(output, target)
    
    # 累积梯度
    loss = loss / accum_iter
    loss.backward()
    
    # 每累积一定步数后更新参数
    if ((batch_idx + 1) % accum_iter == 0) or (batch_idx + 1 == len(dataloader)):
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

2 Profiling and Scaling Single-GPU Transformer Models

下面是一个完整的示例代码,展示了如何在PyTorch中使用混合精度训练(AMP)、激活检查点(Activation Checkpointing)和梯度累积(Gradient Accumulation)来优化Transformer模型的训练:

import torch
from torch.cuda.amp import autocast, GradScaler
from torch.utils.checkpoint import checkpoint
from torch import nn, optim

# 假设TransformerModel是你要训练的模型
class TransformerModel(nn.Module):
    # 这里只是一个示例结构,你需要根据实际情况定义你的Transformer模型
    def __init__(self):
        super(TransformerModel, self).__init__()
        self.encoder = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.decoder = nn.TransformerDecoderLayer(d_model=512, nhead=8)

    def forward(self, src, tgt):
        src = self.encoder(src)
        tgt = self.decoder(tgt, src)
        return tgt

# 初始化模型、优化器和损失函数
model = TransformerModel().cuda()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

# 设置梯度累积参数
accum_num = 4
accum_iter = 0

# 定义一个使用激活检查点的Transformer模型部分
def partial_forward(model, inputs):
    activation = model.encoder(inputs)
    return activation

# 训练循环
for data, target in dataloader:
    optimizer.zero_grad()
    
    # 使用checkpoint函数实现激活检查点
    with autocast():
        # 假设data和target已经是适当的tensor并且已经移到了GPU上
        src = checkpoint(partial_forward, model, data)
        
        # 继续计算损失
        output = model.decoder(target, src)
        loss = criterion(output, target)
    
    # 累积梯度
    scaler.scale(loss).backward()
    accum_iter += 1
    
    # 每累积一定步数后更新参数
    if accum_iter % accum_num == 0:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        accum_iter = 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265826.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构(哈希表(中)纯概念版)

前言 哈希表(Hash Table)是计算机科学中的一个基础而重要的数据结构,它广泛评估各种算法和系统中,尤其是在需要快速查找、插入和删除操作的场景中。由于其O( 1)的平均时间复杂度,存储表在性能要求较高的应用中表现得非…

centos7的磁盘扩容

1、首先,确认你的磁盘是否已经正确识别并添加了新的空间。你可以使用lsblk或fdisk -l命令来查看 lsblk fdisk /dev/vda 2、我的情况是这样的,误操作将盘扩展为物理卷轴了,所以说是这样呈现的,如果有我的那种情况请先删除物理卷轴…

uniapp 微信小程序 页面部分截图实现

uniapp 微信小程序 页面部分截图实现 ​ 原理都是将页面元素画成canvas 然后将canvas转化为图片,问题是我页面里边本来就有一个canvas,ucharts图画的canvas我无法画出这块。 ​ 想了一晚上,既然canvas最后能转化为图片,那我直接…

ubuntu笔记

1.系统下载与虚拟机设置 系统下载https://cn.ubuntu.comhttps://releases.ubuntu.com 虚拟机设置: 桥接模式 在桥接模式下, 虚拟出来的操作系统就像是局域网中的一台独立的主机, 它可以访问网内任何一台机器主机网卡和虚拟网卡的IP地址处于同一个网段, 子网掩码、网关、DNS等…

音视频入门基础:AAC专题(13)——FFmpeg源码中,获取ADTS格式的AAC裸流音频信息的实现

音视频入门基础:AAC专题系列文章: 音视频入门基础:AAC专题(1)——AAC官方文档下载 音视频入门基础:AAC专题(2)——使用FFmpeg命令生成AAC裸流文件 音视频入门基础:AAC…

开发高效实时美颜工具:从美颜SDK到直播APP插件的全流程解析

今天,小编将以美颜SDK为核心,从开发、集成到优化的全流程,深入解析高效实时美颜工具的实现路径。 一、美颜SDK的核心功能与技术构成 美颜SDK是实时美颜技术的核心模块,承担着图像处理和效果呈现的重任。其主要功能包括&#xff…

用 gdbserver 调试 arm-linux 上的 AWTK 应用程序

很多嵌入式 linux 开发者都能熟练的使用 gdb/lldb 调试应用程序,但是还有不少朋友在调试开发板上的程序时,仍然在使用原始的 printf。本文介绍一下使用 gdbserver 通过网络调试开发板上的 AWTK 应用程序的方法,供有需要的朋友参考。 1. 下载 …

20241225在ubuntu22.04.5下使用smartmontools命令查看ssd的寿命

20241225在ubuntu22.04.5下使用smartmontools命令查看ssd的寿命 2024/12/25 15:10 rootrootrootroot-ThinkBook-16-G5-IRH:~$ sudo apt install smartmontools rootrootrootroot-ThinkBook-16-G5-IRH:~$ sudo fdisk -l Disk /dev/nvme0n1: 3.73 TiB, 4096805658624 bytes, 800…

ASP.NET |日常开发中定时任务详解

ASP.NET |日常开发中定时任务详解 前言一、定时任务的概念与用途1.1 定义1.2 应用场景 二、在ASP.NET中实现定时任务的方式2.1 使用System.Timers.Timer2.2 使用Quartz.NET 三、定时任务的部署与管理3.1 部署考虑因素3.2 管理与监控 结束语优质源码分享 ASP.NET &am…

整车厂如何规划构建汽车集成安全团队的软件研发能力

(一)、汽车集成安全团队职责 汽车集成安全团队肩负着保障汽车整体安全性的重任,从多个维度守护驾乘人员安全与车辆稳定运行,其主要职责如下: 功能安全管理 标准遵循与流程制定:严格依据ISO 26262等功能安…

使用 Python 创建多栏 Word 文档 – 详解

目录 引言 一、工具与安装 二、Python 在 Word 中创建简单的多栏布局 三、Python 在 Word 文档的栏间添加分隔线 四、Python 从Word文档的指定位置开启多栏设置 五、Python 为多栏 Word 文档的各栏添加页码 引言 在文档设计中,排版不仅决定了内容的呈现方式&…

使用强化学习与遗传算法优化3D低空物流路径_版本2

在快速发展的物流与自主系统领域,优化无人机在三维空间中的飞行路径至关重要。无论是在城市环境中导航还是在复杂地形中穿行,确保高效、安全且节能的航线规划能够显著提升运营效率。本文将深入探讨一种创新方法,结合强化学习(Rein…

[手机Linux] 七,NextCloud优化设置

安装完成后在个人设置里发现很多警告,一一消除。 只能一条一条解决了。 关于您的设置有一些错误。 1,PHP 内存限制低于建议值 512 MB。 设置php配置文件: /usr/local/php/etc/php.ini 把里面的: memory_limit 128M 根据你自…

【设备 磁盘】重要备份存放U盘的风险 + winhex 磁盘清零(清理windows无法格式化的磁盘)

简述 清理用设备管理器和DiskGenious无法打开的磁盘 winhex安装 官网https://www.x-ways.net/winhex/下载,解压后以管理员身份运行 注意:非完全版不能像磁盘写入编辑后的数据 使用 解压后直接点击打开即可 打开磁盘 “全选”后,选择…

从LockSupport开始带来的思考

LockSupport是什么 LockSupport是JUC下的一个线程同步工具类,实现了线程的阻塞和唤醒操作。相比其他同步机制,如Synchronized、ReentrantLock等,LockSupport的性能更高、更灵活,同时也可以避免线程操作不当引起的死锁问题。Java中…

树莓集团:以产教融合助力人才培养

在当今快速发展的数字时代,人才是推动产业进步和创新的核心驱动力。树莓集团作为数字产业生态链建设者,深刻认识到人才培养的关键意义,积极探索并大力践行产教融合模式,为数字产业源源不断地输送高素质专业人才,在助力…

基于ISO 21434的汽车网络安全实践

商业领域的IT系统和嵌入式产品的IT系统正在融合为一种多功能系统。相应地,关注汽车网络安全的ISO 21434标准应运而生。该标准的意义在于提供了一个指南,可用于降低产品、项目和组织中存在的安全风险。为了有效实施ISO 21434标准,本文介绍了遵…

3.银河麒麟V10 离线安装Nginx

1. 下载nginx离线安装包 前往官网下载离线压缩包 2. 下载3个依赖 openssl依赖,前往 官网下载 pcre2依赖下载,前往Git下载 zlib依赖下载,前往Git下载 下载完成后完整的包如下: 如果网速下载不到请使用网盘下载 通过网盘分享的文件…

视频监控平台:Liveweb视频汇聚融合平台智慧安防视频监控应用方案

Liveweb是一款功能强大、灵活部署的安防视频监控平台,支持多种主流标准协议,包括GB28181、RTSP/Onvif、RTMP等,同时兼容海康Ehome、海大宇等厂家的私有协议和SDK接入。该平台不仅提供传统安防监控功能,还支持接入AI智能分析&#…