从gradient_checkpointing_enable中学习

news2024/10/3 20:21:32

1.背景

最近在使用官网的教程训练chatGLM3,但是出现了“RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn”错误,查阅了官方的文档,目前这个问题还没什么解决方案
在这里插入图片描述

但是其中有人回复说:是注释掉503行的model.gradient_checkpointing_enable() 。个人验证确实是可以成功的,那么问题来了model.gradient_checkpointing_enable() 到底是干什么?为什么有它不行.

2.model.gradient_checkpointing_enable()的作用

这个函数调用启用了模型的梯度检查点。梯度检查点是一种优化技术,可用于减少训练时的内存消耗。通常,在反向传播期间,模型的中间激活值需要被保留以

计算梯度。启用梯度检查点后,系统只需在需要时计算和保留一部分中间激活值,从而减少内存需求。这对于处理大型模型或限制内存的环境中的训练任务非常有用。

这个函数的位置是在

/mnt/workspace/miniconda3/envs/chatglm3/lib/python3.10/site-packages/transformers/modeling_utils.py(2102)gradient_checkpointing_enable()

当然不同环境和版本会有差异,但是大体上不会根本上出入,关于其中的实现,大家可以去查看.

说到底这个函数只是为了节约内存的作用,所以去掉这个,对于上述问题解决是可行,但是具体为什么加了这个之后,chatGLM3微调会出现问题,这个会在另一篇文章进行深究, 在参阅资料发现, 还有很多节约显存的方法,接下来进行一一介绍:

3、Transformers的性能优化方法

算力依然是ai时代最重要的武器, 对于没有线买算力的伙伴而言,算力更是重中之重,节约使用GPU,人人有责, 因此引发学习到很多其他节约显存的方法,记录一下,方便自己和他人查阅学习.

(1)梯度累积(Gradient Accumulation)

(2)冻结(Freezing)

(3)自动混合精度(Automatic Mixed Precision)

(4)8位优化器(8-bit Optimizers)

(5)快速分词器(Fast Tokenizers)

(6)动态填充(Dynamic Padding)

(7)均匀动态填充(Uniform Dynamic Padding)

其中(1)~(4)包括上述说的gradient_checkpointing_enable()方法是适用于任何网络上, (5)~(7)一般是适用于自然语言的上的.

(1)梯度累积

我们都知道,最好是所有的样本的损失一起反向传播是最精确的,但是由于显存的限制,无法做到,所有样本计算和存储,又因为小批量,容易导致训练结果过分敏感样本,所以就有了中庸之道, 不大不小.

在这里插入图片描述

(2)冻结

冻结是一种非常有效的方法,通过取消计算模型某些层中的梯度计算(如embedding层,bert的前几层),可以大大加快训练速度并且降低了显存占用,而且几乎不会损失模型的性能, 特别是某种优化算法(如SGD、AdamW或RMSprop)执行优化步骤时,网络的底层的梯度就都很小,因此参数几乎保持不变,这也被称为梯度消失,因此,与其花费大量的时间和算力来计算底层这些“无用”梯度,并对此类梯度很小的参数进行优化,不如直接冻结它们,直接不计算梯度也不进行优化。

PyTorch为关闭梯度计算提供了一个舒适的API,可以通过torch.Tensor的属性requires_ grad设置。

def freeze(module):
    """
    Freezes module's parameters.
    """
    for parameter in module.parameters():
        parameter.requires_grad = False

(3)自动混合精度

在这里插入图片描述

关键思想是使用较低的精度将模型的梯度和参数保留在内存中,即不使用全精度(float32),而是使用半精度(例如float16)将张量保存在内存中。然而,当以较低精度计算梯度时,某些值可能太小,以至于被视为零,这种现象被称为“溢出”。为了防止“溢出”,原始论文的作者提出了一种梯度缩放方法。

PyTorch提供了一个包:torch.cuda.amp,具有使用自动混合精度所需的功能(从降低精度到梯度缩放),自动混合精度作为上下文管理器实现,因此可以随时随地插入到训练和推理脚本中。


from torch.cuda.amp import autocast, GradScaler


scaler = GradScaler()

for step, batch in enumerate(loader, 1):

    # prepare inputs and targets for the model and loss function respectively.

    # forward pass with `autocast` context manager
    with autocast(enabled=True):
        outputs = model(inputs)

    # computing loss
    loss = loss_fn(outputs, targets)

    # scale gradint and perform backward pass
    scaler.scale(loss).backward()

    # before gradient clipping the optimizer parameters must be unscaled.
    scaler.unscale_(optimizer)

    # perform optimization step
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    scaler.step(optimizer)
  

(4)8位优化器

思想类似于自动混合精度(模型的参数和梯度使用较低的精度保存),但8-bit Optimizers还让优化器的状态使用低精度保存,作者为8位优化器提供了一个高级库,称为bitsandbytes。这种方式在大模型微调是非常常见的.

(5)快速分词器

HuggingFace Transformers提供两种类型的分词器:基本分词器和快速分词器。它们之间的主要区别在于,fast是在rust编写的,因为python在循环中非常慢,fast可以让我们在tokenize时获得额外的加速。下图是tokenize工作的原理示意,Tokenizer类型可以通过更改transformers.AutoTokenizerfrom_pretrained将use_fast属性设为True。

from transformers import AutoTokenizer

# initializing Base version of Tokenizer
model_path = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
print(f"Base version Tokenizer:\n\n{tokenizer}", end="\n"*3)

# initializing Fast version of Tokenizer
fast_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
print(f"Fast version Tokenizer:\n\n{fast_tokenizer}")

(6)动态填充

是为了解决固定长度填充的问题:

固定长度填充的过程: 批中的每个输入必须具有固定大小,所有批量数据的尺寸都一样。固定尺寸通常是根据数据集中的长度分布、特征数量和其他因素来选择的。在NLP任务中,输入大小称为文本长度,或者最大长度(max length)。然而,不同的文本具有不同的长度,为了处理这种情况,研究人员提出了填充标记和截断。当最大长度小于输入文本的长度时,会使用截断,因此会删除一些标记。当输入文本的长度小于最大长度时,会将填充标记,比如[PAD],添加到输入文本的末尾.

在这里插入图片描述

缺点也是非常明显:比如在输入文本相对于选定的最大长度非常短的情况下,效率就很低,需要更多的额外内存.

将批量的输入填充到这一批量的最大输入长度,如下图所示,这种方法可以将训练速度提高35%甚至50%,当然这种方法加速的效果取决于批量的大小以及文本长度的分布,批量越小,加速效果越明显,文本长度分布越不均,加速效果也越好。

在这里插入图片描述

(7)均匀动态填充

分batch时,先按文本的长度对文本进行排序,这样同一个batch里面的文本长度就都差不多。这种方法非常有效,在训练或推理期间的计算量都比动态填充要来得少,这种方式比较适用于推理阶段,因为在训练的时候,更需要shuffle训练数据集

在这里插入图片描述

参考文章

1、Transformers的性能优化方法

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1457549.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手持三防平板丨国产化加固平板丨国产三防平板发展的意义是什么?

随着现代科技的快速发展,平板电脑在我们的生活中扮演着越来越重要的角色。然而,传统的平板电脑只能在普通的环境中使用,而无法在恶劣的环境中使用,例如在高海拔、高温、高湿度、沙漠等环境中,传统平板电脑往往会出现故…

OpenHarmony—UIAbility组件间交互(设备内)

UIAbility是系统调度的最小单元。在设备内的功能模块之间跳转时,会涉及到启动特定的UIAbility,该UIAbility可以是应用内的其他UIAbility,也可以是其他应用的UIAbility(例如启动三方支付UIAbility)。 本章节将从如下场…

电脑数据丢失怎么办?5 种免费数据恢复软件能帮到您

我们存储在计算机中的个人和专业数据如果丢失,可能会给我们带来经济和精神上的困扰。有许多情况会导致此类数据丢失;其中一些包括意外删除、硬盘驱动器故障、软件崩溃、病毒攻击等。 5 种最佳免费数据恢复软件 为防止此类事故,建议定期备份计…

智慧城市驿站:智慧公厕升级版,打造现代化城市生活的便捷配套

随着城市化进程的加速,人们对城市生活质量的要求也越来越高。作为智慧城市建设的一项重要组成部分,多功能城市智慧驿站应运而生。它集合了信息技术、设计美学、结构工艺、系统集成、环保节能等多个亮点,将现代科技与城市生活相融合&#xff0…

RK3568平台 有线以太网接口之MAC芯片与PHY芯片

一.平台网络网络通路 平台有线以太网通路:有线以太网一般插入的是RJ45 座要与 PHY 芯片(RTL8306M)连接在一起,但是中间需要一个网络变压器,网络变压器经过模数转换后到达网卡(RTL8111)转换为帧数据后到达SOC。 二.网络…

【初始RabbitMQ】交换机的实现

交换机概念 RabbitMQ消息传递模型的核心思想就是:生产者生产的消息从不会直接发送到队列。实际上,通常生产者不知道这些消息会传递到那些队列中 相反,生产者只能将消息发送到交换机,交换机的工作内容也很简单,一方面…

Kubernetes基础(二十二)-k8s持久化存储详解

1 volume 1.1 介绍 在容器中的磁盘文件是短暂的,当容器崩溃时,Kubelet会重新启动容器,但容器运行时产生的数据文件都将会丢失,之后容器会以最干净的状态启动。另外,当一个Pod运行多个容器时,各个容器可能…

科技守护大唐遗宝,预防保护传承千年

​ 一、“大唐遗宝——何家村窖藏出土文物展” 陕西历史博物馆的“唐朝遗宝——何家村窖藏出土文物展”算得上是博物馆展览的典范。展览不仅在于展现了数量之多、等级之高、种类之全,更在于对唐朝历史文化的深入揭露。 走入大唐财产展厅,好像穿越千年前…

【Azure 架构师学习笔记】- Azure Databricks (7) --Unity Catalog(UC) 基本概念和组件

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Databricks】系列。 接上文 【Azure 架构师学习笔记】- Azure Databricks (6) - 配置Unity Catalog 前言 在以前的Databricks中,主要由Workspace和集群、SQL Warehouse组成, 这两年Databricks公…

Bert基础(一)--transformer概览

1、简介 当下最先进的深度学习架构之一,Transformer被广泛应用于自然语言处理领域。它不单替代了以前流行的循环神经网络(recurrent neural network, RNN)和长短期记忆(long short-term memory, LSTM)网络,并且以它为基础衍生出了诸如BERT、GPT-3、T5等…

安全架构设计理论与实践

一、考点分布 安全架构概述(※※)安全模型(※※※)信息安全整体架构设计网络安全体系架构设计区块链技术(※※) 二、安全架构概述 被动攻击:收集信息为主,破坏保密性 主动攻击&#…

深度学习发展的艺术

将人类直觉和相关数学见解结合后,经过大量研究试错后的结晶,产生了一些成功的深度学习模型。 深度学习模型的进展是理论研究与实践经验相结合的产物。科学家和工程师们借鉴了人类大脑神经元工作原理的基本直觉,并将这种生物学灵感转化为数学模…

Mac环境Obsidian的ExcaliDraw添加中文字体

Mac环境Obsidian的ExcaliDraw添加中文字体 ExcaliDraw画图工具直接看图 ExcaliDraw画图工具 顾名思义,这是画图用的,但是系统不支持中文字体,所以需要下载中文字体自己放进去。 直接看图

HCIA-HarmonyOS设备开发认证V2.0-IOT硬件子系统-SPI

目录 一、 SPI 概述二、SPI 模块相关API三、接口调用实例四、SPI HDF驱动开发4.1、开发步骤(待续...) 坚持就有收获 一、 SPI 概述 SPI 是串行外设接口(Serial Peripheral Interface)是一种高速的全双工同步的通信总线。 SPI 是由 Motorola 公司开发&a…

VUE3 中导入Visio 图形

微软的Visio是一个功能强大的图形设计工具,它能够绘制流程图,P&ID,UML 类图等工程设计中常用的图形。它要比其它图形设计软件要简单许多。以后我的博文中将更多地使用VISO 来绘制图形。之前我一直使用的是corelDraw。 Visio 已经在工程设…

新增长100人研讨会:快消零售专场探讨招商加盟数字化转型实战

2024年2月2日下午,一场由纷享销客与杨国福集团联合主办的招商加盟数字化转型研讨会在上海成功举办。本次研讨会汇聚了众多快消零售业界的领军人物,共同探讨行业未来的新增长点。 会议伊始,杨国福集团数字化中心负责人王林林发表了主题演讲&a…

php伪协议之phar

一.phar协议 用于将多个 PHP 文件、类、库、资源(如图像、样式表)等打包成一个单独的文件。这个归档文件可以像其他 PHP 文件一样被包含(include)或执行。PHAR 归档提供了一种方便的方式来分发和安装 PHP 应用程序和库&#xff0c…

【unity实战】使用unity制作一个类似Rust的3D生存建造建筑系统(附项目源码)

配置连接点 材质 连接器控制 using System.Collections; using System.Collections.Generic; using UnityEngine;public class Connector : MonoBehaviour {[Header("连接器位置")]public ConnectorPosition connectorPosition;[Header("连接器所属建筑类型&qu…

以太坊 Dencun 升级与潜在机会

撰文:Biteye 核心贡献者 Fishery Isla 文章来源Techub News专栏作者,搜Tehub News下载查看更多Web3资讯。 以太坊网络升级 Dencun 测试网版本在 2024 年 1 月 17 日上线了 Goerli 测试网,1 月 30 日成功上线了 Sepolia 测试网,D…

RocketMQ—RocketMQ消息重复消费问题

RocketMQ—RocketMQ消息重复消费问题 重复消费问题的描述 什么情况下会发生重复消费的问题: 生产者多次投递消息:如果生产者发送消息时,连接有延迟,MQ还没收到消息,生产者又发送了一次消息; 消费者方扩容…