大语言模型的工程技巧(二)——混合精度训练

news2024/10/7 17:31:28

相关说明

这篇文章的大部分内容参考自我的新书《解构大语言模型:从线性回归到通用人工智能》,欢迎有兴趣的读者多多支持。
混合精度训练的示例请参考如下链接:regression2chatgpt/ch11_llm/gpt2_lora_optimum.ipynb

本文将讨论如何利用混合精度训练(Mixed Precision Training)来减少内存的开销,特别是GPU内存的开销。这在大语言模型的训练当中是非常重要的。关于GPU的计算可以参考

  • 大语言模型的工程技巧(一)——GPU计算

关于大语言模型的讨论请参考:

  • 理解大语言模型(二)——从零开始实现GPT-2

内容大纲

  • 相关说明
  • 一、概述
  • 二、什么是混合精度训练?
  • 三、算法细节
  • 四、代码实现

一、概述

在人工智能领域,反向传播算法(计算参数梯度的算法)是非常重要的。而在进行反向传播计算时,必须将经过膨胀的计算图存储在内存中(如果使用GPU运算,那么将存储在GPU的专用内存中)。然而,这种存储量相当庞大,在整个计算图的存储结构中,数值存储占据了最大的比例。这些数值包括各个节点的计算结果(来自向前传播的输出),以及相应的梯度(这些梯度是来自反向传播的结果)。虽然梯度累积技术可以通过分解计算图来限制计算图的膨胀,从而降低内存的使用,但面对庞大的模型时,即便是单个数据点的计算图,其所需的内存都是巨大的。例如,大语言模型的参数数量可能高达数十亿甚至上百亿。

二、什么是混合精度训练?

为了解决这个具有挑战性的问题,需要采取额外的优化策略来降低内存的使用。在深入探讨这些策略之前,我们需要更详细地了解数字在计算机中的存储方式。一般而言,数值计算结果使用32位浮点数(需要4字节来存储,使用32位的二进制的方式表示)存储。这种存储方式被称为单精度浮点数。那么,如果使用16位二进制数表示一个数值,会产生什么影响呢?

这种方法的好处之一是能够立即减少所需的存储空间,同时提升计算速度。然而,这种方法也存在一个明显的缺陷,即能够表示的数值范围受限。为了便于讨论,下面以能够表示的最小正数为例。使用16位浮点数,能够表示的最小正数是 2 − 24 2^{-24} 224(相比之下,32位浮点数能够表示的最小正数为 2 − 149 2^{-149} 2149)。当实际的数值小于这个阈值时,计算机会错误地将其视作0,这就是浮点数下溢(Underflow)。

为了尽可能地减少这类错误的发生,可以混合精度训练(Mixed Precision Training)算法,顾名思义,它是指在模型训练过程中使用不同的数值精度来处理不同部分的计算。

三、算法细节

这一算法包含两个主要部分。

  1. 精度分层处理:在这种训练中,模型本身(模型参数)依然使用32位浮点数进行存储,参数更新过程也使用32位浮点数。在模型的向前传播和反向传播过程中,转而使用16位浮点数进行计算。具体情况如图1所示。

图1

图1

  1. 引入比例因子(Scale Factor):在数学上,要防止浮点数下溢是相当容易的,只需要将模型损失乘以一个较大的常数n,该常数也被称为比例因子。根据链式法则,这将导致所有节点的梯度都增大n倍。这种方法确保了梯度落入16位浮点数表示的范围,从而解决浮点数下溢问题。在使用这些梯度进行参数更新时,需要将引入的缩放移除,也就是将梯度除以n。将这个过程与精度分层处理相结合,如图2所示。

图2

图2

混合精度训练方法的优势在于,在保持适当的模型表示能力的同时,显著降低了内存开锁。通过将高精度的32位浮点数与16位浮点数的计算相结合,在不牺牲模型性能的前提下,显著减少内存需求,使计算机能够处理更大规模的模型和数据集。

四、代码实现

在实际应用中,PyTorch已经提供了相应的封装函数,分别是torch.cuda.amp.autocast和torch.cuda.amp.GradScaler。其中autocast实现的是第一部分——精度分层处理;GradScaler实现的是第二部分——引入比例因子。借助这两个工具,在优化算法中使用混合精度训练就变得很容易了。示意代码如下:

# 常规的模型训练实现
for epoch in range(0): 
    for input, target in zip(data, targets):
        # 启动混合精度训练
        with torch.autocast(device_type=device, dtype=torch.float16):
            output = net(input)
            loss = loss_fn(output, target)

        # 在触发反向传播之前,启动缩放因子
        scaler.scale(loss).backward()

        # 更新模型参数
        scaler.step(opt)
        scaler.update()

        opt.zero_grad()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1684367.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows ssh客户端mobaxterm密码登录到debian12 openssh服务器

1,在debian12生成公钥、秘钥 ssh-keygen -t rsa ~/.ssh/id_rsa 是秘钥,要放到windows的(这里先不要放,等转换一下再放); ~/.ssh/id_rsa.pub 是公钥,放在debian12本地就好了, 顺…

初识DataX3.0

目前接到任务,让同步表数据。市面很多同步工具不一一尝试了,信赖阿里,所以调研了一下阿里的dataX,一点点来吧,学习为主 环境准备:linux6.8 python自带的2.7 MySQL 5.7.1 1.先下载: wget http://datax-o…

TG5032CKN是一种高稳定性晶体振荡器

TG5032CKN的输出频率范围为10 MHz至24 MHz,能够在-40C至105C的温度范围内工作,其频率/温度特性为0.110^-6 Max。这表明该设备具有很好的温度稳定性,适合在极端温度条件下使用。TG5032CKN的尺寸为5.03.21.65 mm,可以选择10针或4针封…

iOS App上架全流程及审核避坑指南

App Store作为苹果官方的应用商店,审核严格周期长一直让用户头疼不已,很多app都“死”在了审核这一关,那我们就要放弃iOS用户了吗?当然不是!本期我们从iOS app上架流程开始梳理,详细了解下iOS app上架的那些…

番外篇 | YOLOv5更换主干网络之Conformer:首个CNN + Transformer的backbone模型

前言:Hello大家好,我是小哥谈。Transformer和CNN在处理视觉表征方面都有着各自的优势以及一些不可避免的问题。因此,国科大、鹏城实验室和华为研究人员首次将二者进行了融合并提出全新的Conformer模型,其可以在不显著增加计算量的前提下显著提升了基网表征能力。论文已被IC…

简析网络风险量化的价值与应用实践,如何构建网络风险预防架构

网络风险量化能够让公司董事会和高管层看清当前的网络安全风险格局;它还将使安全团队能够在业务需求的背景下做出网络安全决策,帮助组织确定哪些风险对业务构成最大的威胁,以及预期的经济损失将是什么。 随着网络攻击手段的日益多样化和复杂…

解锁创意新境界:StartAI插件让Photoshop飞起来!

Photoshop AI插件的革命性突破:StartAI插件的全面体验 作为一名AIGC测评博主,我一直在寻找能够提升设计效率和创意表现的工具。今天,我将带大家深入了解一款令人兴奋的Photoshop AI插件——StartAI,它不仅为设计师带来了前所未有…

3---版本库和工作区、使用.git管理工作区的文件、HEAD指针和master的关系

一、本地仓库和工作区的概念: 1.1本地仓库——版本库: 本地仓库又称为版本库。版本库是隐藏目录.git,并不是.git所在的目录。版本库不属于工作区。我们不能手动操作.git目录及其中的文件,这样可能会直接破坏版本库。stage(暂存区…

vue的组件化

vue的组件化 vue的组件化,就是根据功能、业务逻辑、数据流向等因素进行划分把页面拆分成多个组件。组件是资源独立的,组件也可以相互嵌套。目的是提高代码的可读性、可维护性和可复用性。 组件化思想体现 ​ 组件封装步骤 1.公共组件 公共组件全局注…

【easyx】快速入门——弹球小游戏(第一代)

目录 1.需求 2.运动的小球 3.碰到边缘反弹 4.圆周撞击或越过边界反弹 5.绘制和移动挡板 6.小球碰到挡板反弹 7.游戏失败时该如何处理 8.随机初始条件 9.完整代码 我们这一节将结合动画和键盘交互的知识来做一个小游戏 1.需求 我们先看需求:小球在窗体内运动,撞到除…

HCIP【VRRP、MSTP、VLAN综合实验】

目录 一、实验拓扑图: ​编辑二、实验要求 三、实验思路 四、实验步骤 (1) eth-trunk技术配置 (2)vlan 技术配置 (3)配置SW1、SW2、AR1、ISP的IP地址 (4)在交换机…

作物水文模型AquaCrop---用于评估作物对水的需求、灌溉计划和管理策略

AquaCrop是由世界粮食及农业组织(FAO)开发的一个先进模型,旨在研究和优化农作物的水分生产效率。这个模型在全球范围内被广泛应用于农业水管理,特别是在制定农作物灌溉计划和应对水资源限制方面显示出其强大的实用性。AquaCrop 不…

openmldb install log

下载/源码编译# 如果你的操作系统可以直接运行预编译包,则可以从以下地址下载: GitHub release 页面:Releases 4paradigm/OpenMLDB GitHubOpenMLDB is an open-source machine learning database that provides a feature platform comput…

普源DHO924示波器OFFSET设置

一、简介 示波器是电子工程师常用的测量工具之一,能够直观地显示电路信号的波形和参数。普源DHO924是一款优秀的数字示波器,具有优异的性能和易用性。其中OFFSET功能可以帮助用户调整信号的垂直位置,使波形更清晰易读。本文将详细介绍DHO924…

基于python实现搜索的目标站点内容监测系统

基于python实现搜索的目标站点内容监测系统 开发语言:Python 数据库:MySQL所用到的知识:Django框架工具:pycharm、Navicat、Maven 系统功能实现 登录页面 后台的登录一般是为了管理员的管理方便进行一个用户权限的验证。也是为管理员提供的唯…

Android正向开发实现客户端证书认证

前言 如果第三方模块被混淆,那hook方式均不能生效。这时就需要根据系统包去定位校验的函数,因此需要对安卓开发者是如何实现客户端证书校验的有一定了解,接下来就介绍这部分内容。 开发者实现客户端证书校验的本质是:证书/密钥 + 代码。 在形式上有:证书校验、公钥校验和…

C++之单链表与双链表逆序实例(二百七十九)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏:多媒…

某跨国集团文件跨域安全交换解决方案

在全球化和数字化浪潮的推动下,大型企业越来越依赖于跨域文件交换,以实现跨地区、跨部门的高效协作。然而,随之而来的数据安全和管理挑战也变得愈加严峻。FileLink跨域文件交换安全管控系统应运而生,为大型企业提供了一站式解决方…

Casper Blockchain:基于 CSPR.build 套件,实现闪电般的 dApp 部署

对于许多工程师而言,即使作为对于区块链较为了解的终端用户,与区块链的整合仍然是一个谜团。虽然很多技术文章通常将注意力和报道重点放在智能合约开发上,但当涉及到如何将区块链技术与其应用程序的其余部分集成时,开发者往往只能…