显存不够用?一种大模型加载时节约一半显存的方法

news2024/12/27 13:55:49

Loading huge PyTorch models with linear memory consumption

本文主要介绍了一种用于加载巨大模型权重时节约接近一半显存的方法

首先,创建一个模型:

import torch
from torch import nn

class BoringModel(nn.Sequential):
    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(2, 10)
        self.stages = nn.Sequential(
             nn.Linear(10, 10),
             nn.Linear(10, 10)
        )
        self.out_proj = nn.Linear(10, 2)
        

上述创建,模型占用 1x 显存, x是指模型的大小

model = BoringModel()
# model is now in memory

有些时候我们把模型保存到本地硬盘中

torch.save(model.state_dict(), "./checkpoint.pt")
# our models is now stored on disk

之后需要用到之前保存的模型(两倍显存消耗)

# we need to redefine the model
model = BoringModel()

# 1x memory used
state_dict = torch.load("./checkpoint.pt")

# 2x memory used -> both model and state_dict are in memory!!!
model.load_state_dict(state_dict)
# 1x memory used

我们需要两倍的显存来加载我们之前存储过的权重
如果我们有一个巨大的模型,这是有问题的,因为我们需要两倍的空闲RAM。例如,假设我们有16GB的RAM,而我们的模型使用10GB。加载它需要20GB,我们需要改变我们的策略。
Recently, PyTorch introduced the meta device. When you put a tensor to the meta device, only its metadata (e.g. shape) are stored, and its values are tossed away. Thus, no space is used.

meta例子

x = torch.tensor([1])
x

tensor([1])

x.to(torch.device("meta"))

tensor(…, device=‘meta’, size=(1,), dtype=torch.int64)

因此,我们可以通过这种方法使用一倍的显存消耗来加载我们的模型

  • 定义我们的模型 1x显存

  • 实例化到meta设备上 1x显存

  • 加载state_dict,1x显存

  • replace all empty parameters of our model with the values inside the state_dict 1x显存

我们首先需要弄清楚如何将所有模型的参数替换为加载的“state_dict”中的原始参数

Let’s create the load_state_dict_with_low_memory function.

from typing import Dict

def load_state_dict_with_low_memory(model: nn.Module, state_dict: Dict[str, torch.Tensor]):
    # 通过把模型放到meta设备上来释放一半的显存
    model.to(torch.device("meta"))
    # 我们需要将state_dict中的每个键关联到一个子模块# we need to associate each key in state_dict to a submodule
    # 然后,迭代地使用' state_dict '中的值重新创建所有子模块的参数then, iteratively, re-creat all submodules' parameters with the values in `state_dict`
    pass
load_state_dict_with_low_memory(model, {})

model.state_dict()
OrderedDict([('in_proj.weight', tensor(..., device='meta', size=(10, 2))),
             ('in_proj.bias', tensor(..., device='meta', size=(10,))),
             ('stages.0.weight', tensor(..., device='meta', size=(10, 10))),
             ('stages.0.bias', tensor(..., device='meta', size=(10,))),
             ('stages.1.weight', tensor(..., device='meta', size=(10, 10))),
             ('stages.1.bias', tensor(..., device='meta', size=(10,))),
             ('out_proj.weight', tensor(..., device='meta', size=(2, 10))),
             ('out_proj.bias', tensor(..., device='meta', size=(2,)))])

模型现在是空的。

现在我们必须计算出来自state_dict的每个参数必须放入模型的哪个submodule of model中。一种方法是使用[key_in_state_dict] -> [submodule_in_module]创建一个字典。Now we have to figure out in which submodule of model each parameter from state_dict has to go. One way to do it is to create a dictionary with [key_in_state_dict] -> [submodule_in_module].

因此,我们知道我们必须将加载的state_dict中的值放在哪里。记住,一旦模型被放置在元设备中,它的所有权重都将被丢弃。
So we know where we have to place the values from the loaded state_dict. Remember, as soon as the model is placed inside the meta device, all its weights are tossed away.)

from typing import Dict

def get_keys_to_submodule(model: nn.Module) -> Dict[str, nn.Module]:
    keys_to_submodule = {}
    # iterate all submodules
    for submodule_name, submodule in model.named_modules():
        # iterate all paramters in each submobule
        for param_name, param in submodule.named_parameters():
            # param_name is organized as <name>.<subname>.<subsubname> ...
            # the more we go deep in the model, the less "subname"s we have
            splitted_param_name = param_name.split('.')
            # if we have only one subname, then it means that we reach a "leaf" submodule, 
            # we cannot go inside it anymore. This is the actual parameter
            is_leaf_param = len(splitted_param_name) == 1
            if is_leaf_param:
                # we recreate the correct key
                key = f"{submodule_name}.{param_name}"
                # we associate this key with this submodule
                keys_to_submodule[key] = submodule
                
    return keys_to_submodule
get_keys_to_submodule(model)

请添加图片描述
现在我们有办法知道哪个键对应’ model 的哪个submodule of model。让我们回到我们的load_state_dict_with_low_memory函数并使用来自state_dict的正确值将每个子模块的参数具体化

def load_state_dict_with_low_memory(model: nn.Module, state_dict: Dict[str, torch.Tensor]):
    # free up memory by placing the model in the `meta` device
    model.to(torch.device("meta"))
    keys_to_submodule = get_keys_to_submodule(model)
    for key, submodule in keys_to_submodule.items():
        # get the valye from the state_dict
        val = state_dict[key]
        # we need to substitute the parameter inside submodule, 
        # remember key is composed of <name>.<subname>.<subsubname>
        # the actual submodule's parameter is stored inside the 
        # last subname. If key is `in_proj.weight`, the correct field if `weight`
        param_name = key.split('.')[-1]
        param_dtype = getattr(submodule, param_name).dtype
        val = val.to(param_dtype)
        # create a new parameter
        new_val = torch.nn.Parameter(val, requires_grad=False))
        setattr(submodule, param_name, new_val)

model.state_dict()

请添加图片描述

load_state_dict_with_low_memory(model, torch.load("checkpoint.pt"))
model.state_dict()

请添加图片描述
🎉 We have successfully loaded our checkpoint inside our model with linear memory consumption!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/429924.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NIFI大数据进阶_实时同步MySql的数据到Hive中去_可增量同步_实时监控MySql数据库变化_实际操作_03---大数据之Nifi工作笔记0035

然后我们来操作一下首先创建一个处理器组名字是: MysqlToHive_timely 创建组以后我们进入这个组 然后我们先去添加CaptureChangeMySql这个处理器 拖入处理器以后,我们配置这个处理器 可以看到很多属性,这里我们配置 首先看这个Distrbuted Map Cache Client 这个客户端,我们先来…

ChatGPT聊天机器人程序

ChatGPT聊天机器人程序是一种基于人工智能技术的智能对话程序&#xff0c;利用ChatGPT等自然语言处理模型和算法实现与用户的交互&#xff0c;回答问题、提供服务等。 ChatGPT聊天机器人程序通常包括以下模块&#xff1a; 输入模块&#xff1a;用于接收用户输入的信息&…

vmware虚拟机上网设置教程(vmware虚拟机设置网络)

安装vmware后&#xff0c;一般都会有虚拟机能连互联网的需求&#xff08;如虚拟机中Linux想访问百度&#xff09;&#xff0c;vmware为我们提供了几种连接网络的方式&#xff0c;它们分别是&#xff1a;Bridged&#xff08;桥接模式&#xff09;、NAT&#xff08;网络地址转换模…

SpringBootApplication最详细注解

SpringBootApplication最详细注解SpringBootApplication的注解分类1.Target2.Retention3.Document4.Inherited5.SpringBootConfiguration6.EnableAutoConfiguration6.1AutoConfigurationPackage这个注解6.1.1 Import6.1.2 AutoConfigurationpackages.Registrar.class6.2 AutoCo…

DeepSpeed Chat: 一键式RLHF训练,让你的类ChatGPT千亿大模型提速省钱15倍

DeepSpeed Chat: 一键式RLHF训练&#xff0c;让你的类ChatGPT千亿大模型提速省钱15倍 1. 概述 近日来&#xff0c;ChatGPT及类似模型引发了人工智能&#xff08;AI&#xff09;领域的一场风潮。 这场风潮对数字世界产生了革命性影响。ChatGPT类模型具有惊人的泛用性&#xff0…

硬盘未格式化如何处理(硬盘忽然未格式化如何处理)

将硬盘插入电脑的时候为什么会出现“未格式化”的提示框呢?遇到这个问题时又该怎么处理呢?别慌&#xff0c;下面小编就来给大家演示一下子解决未格式化这个问题的解决方法。 硬盘未格式化如何处理工具/软件&#xff1a;sayRecy 步骤1&#xff1a;先百度搜索并下载程序打开后&…

一文吃透Java线程池——基础篇

前言 本文分为两部分。 第一部分是基础章节。可以帮助我们了解线程池的概念&#xff0c;用法&#xff0c;以及他们之间的的关系和实际应用。 第二部分是实现机制篇。通过源码解析&#xff0c;更深刻理解线程池的工作原理&#xff0c;以及各个概念的准确含义。 原本是一篇文章&…

ping不通的几种故障

网络ping不通是网络中出现频率最高的故障之一&#xff0c;同时也是最让人抓狂的故障&#xff0c;基本上大部分人都遇到过了&#xff0c;如果在项目中出现网络ping不通&#xff0c;没有一个有序的方法去排除解决&#xff0c;那么很难入手&#xff0c;也是讨论最多的问题之一&…

DNS(UOS)

安装DNS apt install bind9 nfsutils -y 切换目录 cd /etc/bind vim named.conf.defaults.zones 复制备份 cp -a db.local skills.net.zone cp -a db.127 146.16.172.in-addr.arpa vim skills.net.zone vim 146.16.172.in-addr.arpa vim /named.conf.options 重启bind9 …

javascript学习笔记

本笔记来源于B站尚硅谷javascript教程10.尚硅谷_JS基础_Null和Undefined_哔哩哔哩_bilibili 1、Null和None Null类型的值只一个&#xff0c;就是null; null这个值专门用来表示一个为空的对象; 使用typeof 检查一个null值时&#xff0c;会返回object; Undefined类型的值只有一个…

C++ 特性简化STM32 风格固件库的GPIO 操作,使用HK32F030M

所谓的STM32 风格就是指下面这种&#xff1a; // 开启时钟 RCC_AHBPeriphClockCmd( LED1_GPIO_CLK | LED2_GPIO_CLK, ENABLE);//定义初始化结构体 GPIO_InitTypeDef GPIO_InitStructure; GPIO_InitStructure.GPIO_Mode GPIO_Mode_OUT; GPIO_InitStructure.GPIO_OType GPIO_O…

迭代器与仿函数

迭代器与仿函数一般分类功能方式分类STL迭代器的类型迭代器辅助函数流型迭代器仿函数仿函数的编写标准库中的仿函数一般分类 正向迭代器 容器名&#xff1a;iterator it begin() end() 2.反向迭代器 容器名&#xff1a;reverse_iterator it rbegin() rend() 3.常正向迭代器 容器…

MQTT 安全解析:构建可靠的物联网系统

物联网逐渐渗透到医疗保健、智能家居、智慧城市、自动驾驶等我们生活中的各个领域。这其中所涉及到的物联设备的安全也因此变得愈发重要。一旦物联网系统遭到恶意入侵&#xff0c;不仅海量设备数据将面临丢失、被窃取和篡改等安全风险&#xff0c;使用这些设备和物联网应用的终…

Githubs的使用方法(创建仓库\分支\提交【增删改查】\拉取与合并\管理与clone代码\修改分支等操作)

Githubs的使用方法 一、github基本使用 这一小节主要介绍github的基本使用方法以及每一步的流程和作用。 1. 创建仓库 2. 创建分支 此时有两个分支&#xff1a;main 和 readme-edits。 现在&#xff0c;它们看起来完全相同。 接下来&#xff0c;将向新分支添加更改。 3. 创…

Vue3 项目实例(一)ElementPlus+ pinia+vite创建

项目搭建 热重载&#xff1a;将一个项目切分成多个JS&#xff0c;同时利用浏览器的协商缓存。 etag: 文件唯一标识 如果某一片代码没有改变&#xff0c;devServer返回304&#xff0c;浏览器继续使用原来的文件&#xff0c;否则&#xff0c;返回200&#xff0c;响应新的js文件…

RK3568平台开发系列讲解(调试篇)IS_ERR函数的使用

🚀返回专栏总目录 文章目录 一、IS_ERR函数用法二、IS_ERR函数三、内核错误码沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇将介绍 IS_ERR 函数的使用。 一、IS_ERR函数用法 先看下用法: 二、IS_ERR函数 对于任何一个指针来说,必然存在三种情况: 一种是合…

知识图谱:Neo4j数据库的基本使用——创建张学良的关系谱

一、知识图谱及Neo4j数据库介绍 知识图谱&#xff08;Knowledge Graph&#xff09;是人工智能的重要分支技术&#xff0c;它在2012年由谷歌提出&#xff0c;是结构化的语义知识库&#xff0c;用于以符号形式描述物理世界中的概念及其相互关系&#xff0c;其基本组成单位是“实体…

4.1派生类的概念

&#xff1a;为什么使用继承 所谓继承就是从先辈处得到属性和行为特征。类的继承就是新的类从已有类那里得到已有特征。这样做的目的是&#xff1a;减少代码的重复。 派生类的声明 声明派生类的一般公式 &#xff1a; class 派生类名:[继承方式] 基类名 { 派生类新增的数据成…

Java并发基石_CAS原理实战02_CAS实现原理

文章目录什么是CAS&#xff1f;CAS的实现原理是什么&#xff1f;cmpxchg指令怎么保证多核心下的线程安全&#xff1f;什么是ABA问题&#xff1f;如何解决ABA问题呢&#xff1f;什么是CAS&#xff1f; CAS&#xff0c;全称CompareAndSwap&#xff0c;比较并替换。 CAS包含了三个…

MyBatis --- 缓存、逆向工程、分页插件

一、MyBatis的缓存 1.1、MyBatis的一级缓存 一级缓存是SqlSession级别的&#xff0c;通过同一个SqlSession查询的数据会被缓存&#xff0c;下次查询相同的数据&#xff0c;就会从缓存中直接获取&#xff0c;不会从数据库重新访问 使一级缓存失效的四种情况&#xff1a; 1、…