超大模型加载转换Trick

news2025/1/11 2:59:42

        在深度学习领域,大模型的训练和推理通常需要消耗大量的计算和内存。如何高效地加载和使用大模型是一个相当关键的问题。在这篇博客中,我将分享一些关于更快加载大模型和减少内存的技巧。

1.问题分析

        假设现在我们有一个236B 超大模型的原始权重的 checkpoint.pth 文件, 比如 DeepSeek Chat V2, 以BF16 格式存储, 一个标准的加载流程如下

import torch

state_dict = torch.load(checkpoint_file)
my_model = BigModelClass(...)
my_model.load_state_dict(state_dict)

        在这段代码的中, my_model = BigModelClass(...) 会初始化一个模型, torch.load(checkpoint_file)函数会将模型权重从磁盘加载到内存中。然后,my_model.load_state_dict(state_dict)函数会将权重从内存加载到模型的参数中。这两个步骤都可能会消耗大量的时间和内存。理想情况下, 一个236B BF16格式的模型需要占据 472GB 的内存, 上面的代码会有两个模型副本, 这意味着峰值需要944GB 内存, 接近1T ,这是非常夸张的也是不可接受的.

        我们用一段简单的代码来验证上面的推断, 首先初始化一个 1B size 的模型并存下来,

import torch

def count_parameters(model):
    total_params =  sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total_params / 1e9 

def model_memory_size_in_megabytes(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.numel() * param.element_size()  

    bytes_in_gb = 1024 * 1024 * 1024 
    return param_size / bytes_in_gb

class BigModel(torch.nn.Module):
    def __init__(self, size):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(size, size) for i in range(10)])

    def forward(self, x):
        return self.linears(x)

size = 10000
model = BigModel(size)

# 打印模型的参数量
print(f'The model has {count_parameters(model):,} B trainable parameters')
print(f"The model's memory size is approximately {model_memory_size_in_megabytes(model):.2f} GB.")
torch.save(model.state_dict(), 'checkpoint.pth')

        The model has 1.0001 B trainable parameters
        The model's memory size is approximately 3.73 GB.

        然后 按照上面的方式加载模型, 并统计cpu 内存占用, torch 默认是FP32 格式, 1B模型占用约 4GB 内存(实际为3.73GB左右), 下面代码验证后基本符合预期

def print_usage():
    pid = os.getpid()
    py = psutil.Process(pid)
    memory_use = py.memory_info()[0] / 2. ** 30  # memory use in GB...I think
    print(f'memory: {memory_use:.2f} GB')
    print('CPU percent:', psutil.cpu_percent())

print('Before Load the state_dict:')
print_usage()
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 8.5
start_time = time.time()
state_dict = torch.load('checkpoint.pth')
print(f'Loading the state_dict took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict:')
print_usage()
Loading the state_dict took 2.09 seconds
After Load the state_dict:
memory: 4.06 GB
CPU percent: 7.0

        4.06 - 0.34 = 3.72基本一致

start_time = time.time()
model = BigModel(size)
print(f'Init the model took {time.time() - start_time:.2f} seconds')
print('After Init the model:')
print_usage()
Init the model took 7.23 seconds
After Init the model:
memory: 7.79 GB
CPU percent: 7.6

7.79 - 4.06 = 3.73 基本一致

start_time = time.time()
model.load_state_dict(state_dict)
print(f'Loading the state_dict to model took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict to model:')
print_usage()
Loading the state_dict to model took 2.63 seconds
After Load the state_dict to model:
memory: 7.79 GB
CPU percent: 16.4

2.问题解决

        分析清楚在加载和初始化环节中各个流程的开销, 我们来看看可以如何加速每个过程.

2.1 使用torch.load(mmap=True)

        首先,让我们考虑一下当我们使用 加载检查点时会发生什么torch.load。当我们使用 保存检查点时torch.save,张量存储会使用保存它们的设备进行标记。使用torch.load,张量存储将加载到它们标记的设备(除非使用标志覆盖此行为 map_location)。为了便于解释,我们假设张量保存在 CPU 上。这意味着在第一行,所有张量存储都将加载到 CPU RAM 中,这在以下情况下可能是不可行的:

  • CPU RAM 小于检查点的大小。

  • 等待整个检查点加载到 RAM 中,然后再执行某些按张量处理等操作。

start_time = time.time()
state_dict = torch.load('checkpoint.pth')
end_time = time.time()
print(f"loading time without mmap={end_time - start_time}")
print_usage()
loading time without mmap=2.0737619400024414
memory: 4.06 GB
CPU percent: 8.7

        torch.load中的mmap参数图解决上述两个问题。顾名思义,mmap关键字参数 totorch.load 使用mmap 调用 ,将磁盘上的文件映射到虚拟内存,并让操作系统自动处理到物理内存的加载和卸载。当这个标志被传递时,张量存储将被内存映射。

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True)
end_time = time.time()
print(f"loading time with mmap={end_time - start_time}")
print_usage()
loading time with mmap=0.003424406051635742
memory: 0.34 GB

        通过上面对比,我们可以发现 使用mmap可以加速模型加载并减少内存占用, 对于236B的模型, 我们实际上并不需要 1TB的 CPU内存来完成转换

2.2使用 torch.device('meta')

        当模型size 巨大时, 模型初始化也需要巨大时间, 我们扩大一下模型size到25B, 初始化一个模型就需要接近3分钟.

size = 50000
start_time = time.time()
model = BigModel(size)
end_time = time.time()
print(f"init time={end_time - start_time}")
print(f'The model has {count_parameters(model):,} B trainable parameters')
print(f"The model's memory size is approximately {model_memory_size_in_megabytes(model):.2f} GB.")
init time=184.56671452522278
The model has 25.0005 B trainable parameters
The model's memory size is approximately 93.13 GB.

        但在load 模型时, 初始化这一步是多余的, 我们实际上只需要知道模型的所有 key 和 对应的 shape,这个时候, torch.device('meta') 这个 上下文就可以发挥作用了, torch.device() 上下文管理器确保工厂调用将像它们被传递了指定的"device"作为参数一样执行。        在 torch.device('meta') 上的张量不携带数据。然而,它们具有张量所具有的所有其他元数据,例如.size().stride().requires_grad等。

with torch.device('meta'):
   model = BigModel(size)
model.load_state_dict(state_dict, assign=True)

for n, p in model.named_parameters():
    assert p.device.type != "meta", f"{n} has not been loaded!"

        注意, 在使用 torch.device('meta')后, 我们需要加上 assign=True参数来让参数被加载. 最后一段代码可以check 所有参数被正确加载了, 加载后的参数的 device应该不再是 meta 了.

2.3实验结果

        最后, 我们直接上一个100B size大小的大模型来对比, 是否使用 torch.load(mmap=True) 和torch.device('meta') 速度差别.

size = 100000
model = BigModel(size)

# 打印模型的参数量
print(f'The model has {count_parameters(model):,} B trainable parameters')
print(f"The model's memory size is approximately {model_memory_size_in_megabytes(model):.2f} GB.")
torch.save(model.state_dict(), 'checkpoint.pth')
The model has 100.001 B trainable parameters
The model's memory size is approximately 186.27 GB.

加速前

start_time = time.time()
state_dict = torch.load('checkpoint.pth')
print(f'Loading the state_dict took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict:')
print_usage()

start_time = time.time()
model = BigModel(size)
print(f'Init the model took {time.time() - start_time:.2f} seconds')
print('After Init the model:')
print_usage()

start_time = time.time()
model.load_state_dict(state_dict)
print(f'Loading the state_dict to model took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict to model:')
print_usage()

start_time = time.time()
input = torch.randn(1, size)
output = model(input)
print(output)
print(f'One time forward {time.time() - start_time:.2f} seconds')
print_usage()
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 9.1
Loading the state_dict took 852.06 seconds
After Load the state_dict:
memory: 372.87 GB
CPU percent: 5.0
Init the model took 518.15 seconds
After Init the model:
memory: 745.41 GB
CPU percent: 4.9
Loading the state_dict to model took 125.63 seconds
After Load the state_dict to model:
memory: 745.41 GB
CPU percent: 11.7
tensor([[-0.0015, 0.0017, -0.0009, ..., -0.0036, 0.0041, 0.0052]],
grad_fn=\)
One time forward 6.95 seconds
memory: 745.42 GB
CPU percent: 11.4

加速后

start_time = time.time()
state_dict = torch.load('checkpoint.pth', mmap=True)
print(f'Loading the state_dict took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict:')
print_usage()

start_time = time.time()
with torch.device('meta'):
  model = BigModel(size)
print(f'Init the model took {time.time() - start_time:.2f} seconds')
print('After Init the model:')
print_usage()

start_time = time.time()
model.load_state_dict(state_dict, assign=True)
print(f'Loading the state_dict to model took {time.time() - start_time:.2f} seconds')
print('After Load the state_dict to model:')
print_usage()

for i in range(2):
    start_time = time.time()
    input = torch.randn(1, size)
    output = model(input)
    print(output)
    print(f'One time forward {time.time() - start_time:.2f} seconds')
    print_usage()
Before Load the state_dict:
memory: 0.34 GB
CPU percent: 9.1
Loading the state_dict took 0.11 seconds
After Load the state_dict:
memory: 0.34 GB
CPU percent: 6.1
Init the model took 0.00 seconds
After Init the model:
memory: 0.34 GB
CPU percent: 4.3
Loading the state_dict to model took 0.00 seconds
After Load the state_dict to model:
memory: 0.34 GB
CPU percent: 10.0
tensor([[ 0.0080, -0.0017, -0.0027, ..., -0.0011, 0.0097, -0.0048]],
grad_fn=\)
One time forward 48.37 seconds
memory: 372.85 GB
CPU percent: 5.2
tensor([[ 0.0038, 0.0014, -0.0076, ..., -0.0016, 0.0004, -0.0018]],
grad_fn=\)
One time forward 3.28 seconds
memory: 372.86 GB
CPU percent: 13.4

通过上面的对比, 加速前100B模型加载时间为

852.06 + 518.15 + 125.63 = 1495(s) = 25 (min)

而使用 mmap + meta device 加载几乎没有时间开销, 只有模型真正运行时才会从硬盘拷贝权重到CPU RAM。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1927297.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jmeter-beanshell学习9-放弃beanshell

写这篇时候道心不稳了,前面写了好几篇benashell元件,突然发现应该放弃。想回去改前面的文章,看了看无从下手,反正已经这样了,我淋了雨,那就希望别人也没有伞吧,哈哈哈哈,放在第九篇送…

DHCP原理及配置

目录 一、DHCP原理 DHCP介绍 DHCP工作原理 DHCP分配方式 工作原理 DHCP重新登录 DHCP优点 二、DHCP配置 一、DHCP原理 1 DHCP介绍 大家都知道,现在出门很多地方基本上都有WIFI,那么有没有想过这样一个问题,平时在家里都是“固定”的…

互联网十万个为什么之什么是专有网络VPC?

专有网络VPC有什么优势? 专有网络VPC具有安全可靠、灵活可控、简单易用的特性和较强的可扩展性。 安全可靠 每个VPC都有一个独立的隧道号,一个隧道号对应着一个虚拟化网络。VPC之间通过隧道号进行隔离: 由于VPC内部存在交换机和路由器&#…

PyTorch人脸识别

新书速览|PyTorch深度学习与企业级项目实战-CSDN博客 一套基本的人脸识别系统主要包含三部分:检测器、识别器和分类器,流程架构如图11-3所示: 图11-5 检测器负责检测图片中的人脸,再将检测出来的人脸感兴趣区域(Reg…

如何在单片机外部Flash存储器上部署高效文件系统:从原理到实现

目录 1.Littlefs文件系统 1.1文件系统简介 2 Littlefs文件系统移植到单片机上 2.1 添加源代码 2.2 编辑接口函数 2.3 测试代码 1.Littlefs文件系统 1.1文件系统简介 littlefs文件系统源码下载地址:littlefs-project/littlefs: A little fail-safe filesystem…

Unity Shader学习笔记

Shader类型 类型详情Standard Surface Shader标准表面着色器,基于物理的着色系统,用于模拟各种材质效果,如石头、木材、玻璃、塑料和金属等。Unlit Shader最简单的着色器,不包含光照但包含雾效,只由最基础的Vertex Sh…

Pytorch使用Dataset加载数据

1、前言: 在阅读之前,需要配置好对应pytorch版本。 对于一般学习,使用cpu版本的即可。参考教程点我 导入pytorch包,使用如下命令即可。 import torch # 注意虽然叫pytorch,但是在引用时是引用torch2、神经网络获取…

【C++】—— 初识C++

【C】—— 初识C 一、什么是 C二、C 的发展历史三、C 版本更新四、C 的重要性五、C 在工作领域中的运用六、C 书籍推荐: 一、什么是 C C语言 是结构化和模块化的语言,适合处理较小规模的程序。对于复杂的问题,规模较大的程序,需要…

六、STM32F4+标准库+LWIP2.1.2移植+无操作系统

最快最简单的移植LWIP协议栈,可改可不改的东西统一不修改。后期学会了有能力了再回过头来修改,操作复杂理论复杂,同时讲解对新手不是很友好,故此此文档只讲操作无任何理论讲解。 零、所需文件及环境 1、第四章建立好的串…

51单片机11(蜂鸣器硬件设计和软件设计)

一、蜂鸣器硬件设计 1、 2、上面两张图,是针对不同产品的电路图。像左边这一块,是我们的A2,A3,A4的一个产品对应的一个封闭器的硬件电路。而右边的这一块是对应的A5到A7的一个硬件电路。因为A5到A7的一个产品,它的各…

排序算法3_冒泡排序、快速排序

一、冒泡排序 1.1 冒泡排序定义和思路 冒泡排序的基本思想是:通过相邻两个元素之间的比较和交换,使较大的元素逐渐从前面移向后面(升序),就像水底下的气泡一样逐渐向上冒泡,所以被称为“冒泡”排序。  在…

【YOLOv8】 用YOLOv8实现数字式工业仪表智能读数(二)

上一篇圆形表盘指针式仪表的项目受到很多人的关注,咱们一鼓作气,把数字式工业仪表的智能读数也研究一下。本篇主要讲如何用YOLOV8实现数字式工业仪表的自动读数,并将读数结果进行输出,若需要完整数据集和源代码可以私信。 目录 &…

王牌站士Ⅹ---人工智能中的数据隐私:PII 与个人信息

前言 今天,我将讨论如何区分美国和全球范围内不断涌现的数据隐私法所涵盖和不涵盖的数据类型。不同类型的数据受到更严格的保护,具体取决于司法管辖区,因此,如果您使用个人数据进行分析或机器学习,了解这一点很重要。…

痛心!不会用ChatGPT,差点错失一个亿

ChatGPT爆火这么久,今天我们也来聊聊GPT的玩法。等下,什么?你没听说过?没用过? 没听过没用过的朋友们,你们知道当我听到这回答的时候是多么痛心疾首吗? 为了让你们更直观的感受到,举个栗子,如果你用了GPT,就不需要抓耳挠腮的想方案了;如果你用了GPT,或许工作学习效…

MySQL 数据库 - 事务

MySQL 数据库(基础)- 事务 事务简介 事务 是一组操作集合,他是一个不可分割的工作单位,事务会把所有的操作看作是一个整体一起向系统发送请求,即这些操作要么同时成功,要么同时失败。 比如:张…

《Python数据科学之三:探索性数据分析与可视化》

《Python数据科学之三:探索性数据分析与可视化》 在数据科学项目中,探索性数据分析(EDA)和数据可视化是至关重要的步骤。它们帮助数据科学家理解数据的特征、发现数据中的模式和异常值,从而为后续的数据分析和机器学习…

python-29-零基础自学python-json、函数等存取用户数据+验证用户信息

学习内容:《python编程:从入门到实践》第二版 知识点: 如何验证用户、try-except-else处理异常 if判断、def方法及拆解方法 json引入、存储、读取 return none和return变量返回值很重要 answer 1 和answer “1”在使用后的区别 练习内容…

IDEA创建项目模块右边缺少Maven的解决

一、问题描述 我们在创建项目模块时,创建为Maven工程,创建后只是普通工程,idea右边缺少Mavenue标识管理 如图 二、问题的解决方法 在模块的pom.xml文件,点击选项,添加为Maven工程 如图 至此,创建maven工程…

2-34 小波神经网络采用传统 BP 算法

小波神经网络采用传统 BP 算法,存在收敛速度慢和易陷入局部极小值两个突出弱点。建立了基于遗传算法的小波神经网络股票预测模型 GA-WNN。该模型结合了遗传算法的全局优化搜索能力以及小波神经网络良好的时频局部特性。运用 MATLAB 对拟合和预测过程进行仿真。结果表…

COLING 2024 | AlphaFin:基于LLM的股票预测大模型,显著提高预测能力

COLING 2024 | AlphaFin:基于LLM的股票预测大模型,显著提高预测能力 发布于 2024-06-13 18:31:49 目前,机器学习和深度学习算法(ML&DL)已被广泛应用于股票趋势预测,并取得了显著进展。然而&#xff0c…