我的PyTorch模型比内存还大,怎么训练呀?

news2025/1/18 12:01:50

原文:我的PyTorch模型比内存还大,怎么训练呀? - 知乎

看了一篇比较老(21年4月文章)的不大可能训练优化方案,保存起来以后研究一下。

我的PyTorch模型比内存还大,怎么训练呀?

随着深度学习的飞速发展,模型越来越臃肿,哦不,先进,运行SOTA模型的主要困难之一就是怎么把它塞到 GPU 上,毕竟,你无法训练一个设备装不下的模型。改善这个问题的技术有很多种,例如,分布式训练和混合精度训练。

本文将介绍另一种技术: 梯度检查点(gradient checkpointing)简单的说,梯度检查点的工作原理是在反向时重新计算深层神经网络的中间值(而通常情况是在前向时存储的)。这个策略是用时间(重新计算这些值两次的时间成本)来换空间(提前存储这些值的内存成本)。

文末有一个示例基准测试,它显示了梯度检查点减少了模型 60% 的内存开销(以增加 25% 的训练时间为代价)。

详细代码请查看我的 GitHub 库: https://github.com/spellml/tweet-sentiment-extraction/blob/master/notebooks/5-checkpointing.ipynb

>>> 神经网络如何使用内存

为了理解梯度检查点是如何起作用的,我们首先需要了解一下模型内存分配是如何工作的。

神经网络使用的总内存基本上是两个部分的和。

第一部分是模型使用的静态内存。尽管 PyTorch 模型中内置了一些固定开销,但总的来说几乎完全由模型权重决定。当今生产中使用的现代深度学习模型的总参数在100万到10亿之间。作为参考,一个带 16GB GPU 内存的 NVIDIA T4 的实际限制大约在1-1.5亿个参数之间。

第二部分是模型的计算图所占用的动态内存。在训练模式下,每次通过神经网络的前向传播都为网络中的每个神经元计算一个激活值,这个值随后被存储在所谓的计算图中。必须为批中的每个单个训练样本存储一个值,因此数量会迅速的累积起来。总开销由模型大小和批次大小决定,一般设置最大批次大小限制来适配你的 GPU 内存。

要了解更多关于 PyTorch autograd 的信息,请查看我的 Kaggle 笔记本《PyTorch autograd 解释》: https://www.kaggle.com/residentmario/pytorch-autograd-explained

>>> 梯度检查点是如何起作用的

大型模型在静态和动态方面都很耗资源。首先,它们很难适配 GPU,而且哪怕你把它们放到了设备上,也很难训练,因为批次大小被迫限制的太小而无法收敛。

现有的各种技术可以改善这些问题中的一个或两个。梯度检查点就是这样一种技术; 分布式训练,是另一种技术。

梯度检查点(gradient checkpointing) 的工作原理是从计算图中省略一些激活值。这减少了计算图使用的内存,降低了总体内存压力(并允许在处理过程中使用更大的批次大小)。

但是,一开始存储激活的原因是,在反向传播期间计算梯度时需要用到激活。在计算图中忽略它们将迫使 PyTorch 在任何出现这些值的地方重新计算,从而降低了整体计算速度。

因此,梯度检查点是计算机科学中折衷的一个经典例子,即在内存和计算之间的权衡。

PyTorch 通过 torch.utils.checkpoint.checkpoint 和 torch.utils.checkpoint.checkpoint_sequential 提供梯度检查点,根据官方文档的 notes,它实现了如下功能,在前向传播时,PyTorch 将保存模型中的每个函数的输入元组。在反向传播过程中,对于每个函数,输入元组和函数的组合以实时的方式重新计算,插入到每个需要它的函数的梯度公式中,然后丢弃。网络计算开销大致相当于每个样本通过模型前向传播开销的两倍。

梯度检查点首次发表在2016年的论文 《Training Deep Nets With Sublinear Memory Cost》 中。论文声称提出的梯度检查点算法将模型的动态内存开销从 O(n)n 为模型中的层数)降低到 O(sqrt(n)),并通过实验展示了将 ImageNet 的一个变种从 48GB 压缩到了 7GB 内存占用。

>>> 测试 API

PyTorch API 中有两个不同的梯度检查点方法,都在 torch.utils.checkpoint 命名空间中。两者中比较简单的一个是 checkpoint_sequential,它被限制用于顺序模型(例如使用 torch.nn.Sequential wrapper 的模型)。另一个是更灵活的 checkpoint,可以用于任何模块。

下面是一个完整的代码示例,显示了 checkpoint_sequential 的实际用法:

import torch
import torch.nn as nn

from torch.utils.checkpoint import checkpoint_sequential

# a trivial model
model = nn.Sequential(
    nn.Linear(100, 50),
    nn.ReLU(),
    nn.Linear(50, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
    nn.ReLU()
)

# model input
input_var = torch.randn(1, 100, requires_grad=True)

# the number of segments to divide the model into
segments = 2

# finally, apply checkpointing to the model
# note the code that this replaces:
# out = model(input_var)
out = checkpoint_sequential(modules, segments, input_var)

# backpropagate
out.sum().backwards()

如你所见,checkpoint_sequential 替换了 module 对象上的 forward 或 __call__ 方法。out 几乎和我们调用 model(input_var) 时得到的张量一样; 关键的区别在于它缺少了累积值,并且附加了一些额外的元数据,指示 PyTorch 在 out.backward() 期间需要这些值时重新计算。

值得注意的是,checkpoint_sequential 接受整数值的片段数作为输入。checkpoint_sequential 将模型分割成 n 个纵向片段,并对除了最后一个的每个片段应用检查点。

这工作很容易,但有一些主要的限制。你无法控制片段的边界在哪里,也无法对整个模块应用检查点(而是其中的一部分)。

替代方法是使用更灵活的 checkpoint API. 下面展示了一个简单的卷积模型:

class CIFAR10Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(*[
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25)
        ])
        self.cnn_block_2 = nn.Sequential(*[
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(0.25)
        ])
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.head = nn.Sequential(*[
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 10)
        ])

    def forward(self, X):
        X = self.cnn_block_1(X)
        X = self.cnn_block_2(X)
        X = self.flatten(X)
        X = self.head(X)
        return X

这种模型有两个卷积块,一些 dropout,和一个线性头(10个输出对应 CIFAR10 的10类)。

下面是这个模型使用梯度检查点的更新版本:

class CIFAR10Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_block_1 = nn.Sequential(*[
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        ])
        self.dropout_1 = nn.Dropout(0.25)
        self.cnn_block_2 = nn.Sequential(*[
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        ])
        self.dropout_2 = nn.Dropout(0.25)
        self.flatten = lambda inp: torch.flatten(inp, 1)
        self.linearize = nn.Sequential(*[
            nn.Linear(64 * 8 * 8, 512),
            nn.ReLU()
        ])
        self.dropout_3 = nn.Dropout(0.5)
        self.out = nn.Linear(512, 10)

    def forward(self, X):
        X = self.cnn_block_1(X)
        X = self.dropout_1(X)
        X = checkpoint(self.cnn_block_2, X)
        X = self.dropout_2(X)
        X = self.flatten(X)
        X = self.linearize(X)
        X = self.dropout_3(X)
        X = self.out(X)
        return X

在 forward 中显示的 checkpoint 接受一个模块(或任何可调用的模块,如函数)及其参数作为输入。参数将在前向时被保存,然后用于在反向时重新计算其输出值。

为了使其能够工作,我们必须对模型定义进行一些额外的更改。

首先,你会注意到我们从卷积块里删除了 nn.Dropout 层; 这是因为检查点与 dropout 不兼容(回想一下,样本有效地通过模型两次 —— dropout 会在每次通过时任意丢失不同的值,从而产生不同的输出)。基本上,任何在重新运行时表现出非幂等(non-idempotent )行为的层都不应该应用检查点(nn.BatchNorm 是另一个例子)。解决方案是重构模块,这样问题层就不会被排除在检查点片段之外,这正是我们在这里所做的。

其次,你会注意到我们在模型中的第二卷积块上使用了检查点,但是第一个卷积块上没有使用检查点。这是因为检查点简单地通过检查输入张量的 requires_grad 行为来决定它的输入函数是否需要梯度下降(例如,它是否处于 requires_grad=True 或 requires_grad=False模式)。模型的输入张量几乎总是处于 requires_grad=False 模式,因为我们感兴趣的是计算相对于网络权重而不是输入样本本身的梯度。因此,模型中的第一个子模块应用检查点没多少意义: 它反而会冻结现有的权重,阻止它们进行任何训练。更多细节请参考这个 PyTorch 论坛帖子:https://discuss.pytorch.org/t/use-of-torch-utils-checkpoint-checkpoint-causes-simple-model-to-diverge/116271

在 PyTorch 文档(https://pytorch.org/docs/stable/checkpoint.html#)中还讨论了 RNG 状态以及与分离张量不兼容的一些其他细节。

完整的训练代码示例可以看这里: https://gist.github.com/ResidentMario/e3254172b4706191089bb63ecd610e21

和这里: https://gist.github.com/ResidentMario/9c3a90504d1a027aab926fd65ae08139

>>> 基准测试

作为一个快速的基准测试,我在 tweet-sentiment-extraction 上启用了模型检查点,这是一个基于 Twitter 数据的带有 BERT 主干的情感分类器模型。你可以在这里看到代码:https://github.com/spellml/tweet-sentiment-extraction。transformers 已经将模型检查点作为 API 的一个可选部分来实现; 为我们的模型启用它就像翻转一个布尔值标记一样简单:

# code from model_5.py

cfg = transformers.PretrainedConfig.get_config_dict("bert-base-uncased")[0]
cfg["output_hidden_states"] = True
cfg["gradient_checkpointing"] = True  # NEW!
cfg = transformers.BertConfig.from_dict(cfg)
self.bert = transformers.BertModel.from_pretrained(
    "bert-base-uncased", config=cfg
)

我对这个模型进行了四次训练: 分别在 NVIDIA T4和 NVIDIA V100 GPU 上,包括检查点和无检查点模式。所有运行的批次大小为 64。以下是结果:

第一行是在模型检查点关闭的情况下进行的训练,第二行是在模型检查点开启的情况下进行的训练。

模型检查点降低了峰值模型内存使用量 60% ,同时增加了模型训练时间 25% 。

当然,你想要使用检查点的主要原因可能是,这样你就可以在 GPU 上使用更大的批次大小。在另一篇博文:https://qywu.github.io/2019/05/22/explore-gradient-checkpointing.html 中演示了这个很好的例子: 在他们的例子中,每批次样本从 24 个提高到惊人的 132 个!

要处理大型神经网络,模型检查点显然是一个非常强大和有用的工具。

原文: https://spell.ml/blog/gradient-checkpointing-pytorch-YGypLBAAACEAefHs

发布于 2021-04-27 22:39

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1439026.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2. Maven 继承与聚合

目录 2. 2.1 继承 2.2继承关系 2.2.1 思路分析 2.2.2 实现 2.1.2 版本锁定 2.1.2.1 场景 2.1.2.2 介绍 2.1.2.3 实现 2.1.2.4 属性配置 2.2 聚合 2.2.1 介绍 2.2.2 实现 2.3 继承与聚合对比 maven1:分模块设计开发 2. 在项目分模块开发之后啊&#x…

io和File的综合练习:

先来说说字节流和字符流的应用场景 练习一: /*拷贝一个文件夹考虑子文件夹*///源文件夹路径File src new File("E:\\aaa-FIle学习测试\\bbb");//目的文件夹路径File dest new File("E:\\aaa-FIle学习测试\\ccc");copy(src,dest);}public stati…

next项目页面性能调优

next项目页面性能调优 一般来说性能优化可以分为加载时、运行时两部分的优化。 扩展参考链接: 前端性能优化 24 条建议 Webpack 4进阶–从前的日色变得慢 ,一下午只够打一次包 Webpack 分包优化首屏加载 参考指标 FCP(First Contentful P…

《MySQL 简易速速上手小册》第3章:性能优化策略(2024 最新版)

文章目录 3.1 查询优化技巧3.1.1 基础知识3.1.2 重点案例3.1.3 拓展案例 3.2 索引和查询性能3.2.1 基础知识3.2.2 重点案例3.2.3 拓展案例 3.3 优化数据库结构和存储引擎3.3.1 基础知识3.3.2 重点案例3.3.3 拓展案例 3.1 查询优化技巧 让我们来聊聊如何让你的 MySQL 查询跑得像…

【Linux】vim的基本操作与配置(上)

Hello everybody!今天我们要进入vim的讲解了。学会了vim,咱们就可以在Linux系统上做一些简单的编程啦! 那么废话不多说,咱们直接进入正题! 1.初识vim vim是一款多模式的文本编辑器,可以对一个文件进行编辑操作。 它一共有三个模…

【射影几何13 】梅氏定理和塞瓦定理探讨

梅氏定理和塞瓦定理 目录 一、说明二、梅涅劳斯(Menelaus)定理三、塞瓦(Giovanni Ceva)定理四、塞瓦点的推广 一、说明 在射影几何中,梅涅劳斯(Menelaus)定理和塞瓦定理是非常重要的基本定理。通过这两个定…

09 AB 10串口通信发送原理

通用异步收发传输器( Universal Asynchronous Receiver/Transmitter, UART)是一种异步收发传输器,其在数据发送时将并行数据转换成串行数据来传输, 在数据接收时将接收到的串行数据转换成并行数据, 可以实现…

【数据分享】1929-2023年全球站点的逐年平均降水量(Shp\Excel\免费获取)

气象数据是在各项研究中都经常使用的数据,气象指标包括气温、风速、降水、湿度等指标,说到常用的降水数据,最详细的降水数据是具体到气象监测站点的降水数据! 有关气象指标的监测站点数据,之前我们分享过1929-2023年全…

训练集,验证集,测试集比例

三者的区别 训练集(train set) —— 用于模型拟合的数据样本。验证集(validation set)—— 是模型训练过程中单独留出的样本集,它可以用于调整模型的超参数和用于对模型的能力进行初步评估。 通常用来在模型迭代训练时…

DevOps落地笔记-17|度量指标:寻找真正的好指标?

前面几个课时端到端地介绍了软件开发全生命周期中涉及的最佳实践,经过上面几个步骤,企业在进行 DevOps 转型时技术方面的问题解决了,这个时候我们还缺些什么呢?事实上很多团队和组织在实施 DevOps 时都专注于技术,而忽…

【力扣】查找总价格为目标值的两个商品,双指针法

查找总价格为目标值的两个商品原题地址 方法一:双指针 这道题和力扣第一题“两数之和”非常像,区别是这道题已经把数组排好序了,所以不考虑暴力枚举和哈希集合的方法,而是利用单调性,使用双指针求解。 考虑数组pric…

零代码3D可视化快速开发平台

老子云平台 老子云3D可视化快速开发平台,集云压缩、云烘焙、云存储云展示于一体,使3D模型资源自动输出至移动端PC端、Web端,能在多设备、全平台进行展示和交互,是全球领先、自主可控的自动化3D云引擎。此技术已经在全球申请了专利…

力扣优选算法100道——【模板】前缀和(一维)

【模板】前缀和_牛客题霸_牛客网 (nowcoder.com) 目录 🚩了解题意 🚩算法原理 🎈设定下标为1开始 🎈取值的范围 🚩实现代码 🚩了解题意 第一行的3和2,3代表行数,2代表q次查询(…

【Java数据结构】ArrayList和LinkedList的遍历

一&#xff1a;ArrayList的遍历 import java.util.ArrayList; import java.util.Iterator; import java.util.List;/*** ArrayList的遍历*/ public class Test {public static void main(String[] args) {List<Integer> list new ArrayList<>();list.add(5);list…

MATLAB环境下生成对抗网络系列(11种)

为了构建有效的图像深度学习模型&#xff0c;数据增强是一个非常行之有效的方法。图像的数据增强是一套使用有限数据来提高训练数据集质量和规模的数据空间解决方案。广义的图像数据增强算法包括&#xff1a;几何变换、颜色空间增强、核滤波器、混合图像、随机擦除、特征空间增…

寒假作业2024.2.6

1.现有无序序列数组为23,24,12,5,33,5347&#xff0c;请使用以下排序实现编程 函数1:请使用冒泡排序实现升序排序 函数2:请使用简单选择排序实现升序排序 函数3:请使用直接插入排序实现升序排序 函数4:请使用插入排序实现升序排序 #include <stdio.h> #include <stdl…

一个坐标系查询网站python获取所有坐标系

技术路线选择 我是使用的vue 3开发的网页界面&#xff0c;element-plus构建网页组件&#xff0c;openlayer展示地图&#xff0c;express提供后端API&#xff0c;vercel进行在线部署。 python获取所有坐标系 想要展示所有坐标系&#xff0c;那需要先获取坐标系&#xff0c;怎么…

【开源】基于JAVA+Vue+SpringBoot的贫困地区人口信息管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 人口信息管理模块2.2 精准扶贫管理模块2.3 特殊群体管理模块2.4 案件信息管理模块2.5 物资补助模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 人口表3.2.2 扶贫表3.2.3 特殊群体表3.2.4 案件表3.2.5 物资补助表 四…

机器人学、机器视觉与控制 上机笔记(2.1章节)

机器人学、机器视觉与控制 上机笔记&#xff08;2.1章节&#xff09; 1、前言2、本篇内容3、代码记录3.1、新建se23.2、生成坐标系3.3、将T1表示的变换绘制3.4、完整绘制代码3.5、获取点*在坐标系1下的表示3.6、相对坐标获取完整代码 4、结语 1、前言 工作需要&#xff0c;想同…