LLM优化:开源星火13B显卡及内存占用优化

news2024/11/26 23:54:40

1. 背景

本qiang~这两天接了一个任务,部署几个开源的模型,并且将本地经过全量微调的模型与开源模型做一个效果对比。

部署的开源模型包括:星火13B,Baichuan2-13B, ChatGLM6B等

其他两个模型基于transformers架构封装,因此推理服务启动还是十分丝滑,但星火13B是基于Megatron-DeepSpeed框架实现,地址是:https://gitee.com/iflytekopensource/iFlytekSpark-13B,启动推理服务的过程中发现启动13B的显卡占用71G-78G,有些反直觉。

此文就是整理开源星火13B的显存及内存排查并优化的整理过程,至于哪家开源模型效果好,不在此文的讨论范围内。

2. 原因分析

直观上来说,13B的模型,数据类型为bf16,显卡占用大概在26G左右,但星火13B直接占用70G+,不可思议,怪不得网上关于星火开源模型的讨论少之又少,原因显而易见,这么大的显存占用只能用多卡或者A800等80G显卡才能适配。穷人家的孩子,哪有这么多余粮。

排查原因的过程中,少不了源码的调试与分析。在排查的过程中,启动推理服务的文件run_iFlytekSpark_text_generation.py中,model_provider方法是初始化模型并加载模型文件的方法。

def model_provider(pre_process=True, post_process=True):

    """Build the model."""

    print_rank_0('building iFlytekSpark model ...')

    args = get_args()

    config = core_transformer_config_from_args(args)

        

         ### 初始化星火模型

    model = iFlytekSparkModel(

        config,

        num_tokentypes=0,

        parallel_output=False,

        pre_process=pre_process,

        post_process=post_process,

        return_moe_loss=False

    )





    if args.from_pretrained is not None:

        assert os.path.exists(args.from_pretrained)

        ckpt_path = get_checkpoint_name(args.from_pretrained)

        print_rank_0('Loading from {} '.format(

                args.from_pretrained))

                  # 模型加载权重文件

        state_dict = torch.load(ckpt_path, map_location=f"cuda:{torch.cuda.current_device()}")

        if 'module' in state_dict:

            state_dict = state_dict['module']

        model.load_state_dict(state_dict)

    return model

其中,加载权重文件可以看到,加载state_dict时,直接将权重文件加载到显卡中,而非加载至CPU,然后再执行to方法,转移到GPU。因此该处是一个潜在的优化点

再打入iFlytekSparkModel内部,词表Embedding层,线性转换层,等初始化weight时,也是直接将weight分配在GPU上运行。例如下例:

class RowParallelLinear(torch.nn.Module):
    def __init__(self, input_size: int, output_size: int, *,
                 config: ModelParallelConfig,
                 init_method: Callable,
                 bias: bool = True,
                 input_is_parallel: bool = False,
                 stride: int = 1,
                 keep_master_weight_for_test: bool = False,
                 skip_bias_add: bool = False,
                 moe=False, enable_expert_tensor_parallelism=False):
        super(RowParallelLinear, self).__init__()

        # .........
		
        if config.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=config.params_dtype))
            if config.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.input_size_per_partition, 1, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test,
                    params_dtype=config.params_dtype)
        else:
			# 默认按照启动sh命令,会走该分支
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=get_accelerator().current_device_name(), dtype=config.params_dtype))
            if config.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=1, stride=stride)
        if bias:
            if config.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=config.params_dtype))
            else:
				# 默认按照启动sh命令,会走该分支
                self.bias = Parameter(torch.empty(
                    self.output_size, device=get_accelerator().current_device_name(),
                    dtype=config.params_dtype))
            setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

            if config.perform_initialization:
                # Always initialize bias to zero.
                with torch.no_grad():
                    self.bias.zero_()
        else:
            self.register_parameter('bias', None)

3. 优化方案

1. 模型初始化时,模型的Embedding,线性层的权重weight均直接加载至GPU,因此可以优化为先将这些weight加载至CPU

改进的方式也很简单,从上面的源码层面,可以看到,当增加参数” use_cpu_initialization”,将使用CPU进行初始化权重,因此只需要在启动推理服务的脚本中增加” --use-cpu-initialization”参数即可。

2. 加载模型文件时,直接加载至GPU,然后run_iFlytekSpark_text_generation.py中的get_model方法中,当模型加载完成后,会进行分配至GPU以及FP16的转换的操作。如下代码所示。

def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap_with_ddp=True):
    """Build the model."""
    args = get_args()
    args.model_type = model_type

    # ..........

    # GPU allocation.
    for model_module in model:
        model_module.to(get_accelerator().current_device_name())
 

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    # .......

    return model

因此,优化的方式也很简单,可以优化为先加载至CPU,再运行get_model中的默认分配至GPU,加载完后,再使用垃圾回收机制清除CPU占用的内存即可

话不多说,优化后的代码如下:

def model_provider(pre_process=True, post_process=True):
    """Build the model."""
    print_rank_0('building iFlytekSpark model ...')
    args = get_args()
    config = core_transformer_config_from_args(args)
    model = iFlytekSparkModel(
        config,
        num_tokentypes=0,
        parallel_output=False,
        pre_process=pre_process,
        post_process=post_process,
        return_moe_loss=False
    )


    if args.from_pretrained is not None:
        print(args.from_pretrained)
        assert os.path.exists(args.from_pretrained)
        ckpt_path = get_checkpoint_name(args.from_pretrained)
        print_rank_0('Loading from {} '.format(
                args.from_pretrained))

        # state_dict = torch.load(ckpt_path, map_location=f"cuda:{torch.cuda.current_device()}")
		# CPU进行加载
        state_dict = torch.load(ckpt_path, map_location=f"cpu")
        if 'module' in state_dict:
            state_dict = state_dict['module']
        model.load_state_dict(state_dict)
		
		# 加载完成,删除state_dict,并垃圾回收
        del state_dict
        gc.collect()
        torch.cuda.empty_cache()

    return model

4. 效果对比

(1) 优化前的显卡占用: 71.5G

(2) 优化前的内存占用: 虚拟内存占用94.5G

(3) 优化后的显卡占用: 26G

(4) 优化后的内存占用: 43.1G

5. 总结

一句话足矣~

本文主要是针对开源星火13B的显存及内存占用过大的一个代码优化。核心思想是使用CPU预加载模型,再转换至GPU。

后期如有遇到此类问题,可以借鉴之~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1631902.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

创建基于时间的 UUID

概述 在本文中,我们将会 对 UUIDs 和基于时间的 UUIDs(time-based UUIDs) 进行一些探讨。 当我们在对基于时间的 UUIDs 进行选择的时候,总会遇到一些好的方面和不好的方面,如何进行选择,也是我们将要简要…

代码+视频,R语言绘制生存分析模型的时间依赖(相关)性roc曲线和时间依赖(相关)性cindex曲线

ROC曲线分析是用于评估一个因素预测能力的手段,是可以用于连续型变量分组的方法。在生存分析中,疾病状态和因素取值均会随时间发生变化。而标准的ROC曲线分析将个体的疾病状态和因素取值视作固定值,未将时间因素考虑在分析之中。在这种情况下…

一加Ace3/12/Ace2pro手机ColorOS14刷KernelSU内核ROOT-解决无限重启变砖

一加Ace3/一加12/一加11等手机升级了安卓14底层,并且ColorOS版本也更新到了14版本界面和功能都比之前的系统表现更加优秀,但刷机方面,相对之前存在一些差异,特别是KernelSU内核级别root权限,不再支持一键刷入KernelSU通…

【Linux网络】SSH--远程控制与访问

目录 一、SSH远程管理 1.SSH的定义 2.远程传输的种类 3.OpensSSH 4.SSH客户端与服务端 二、配置OpenSSH服务器 1.sshd_config配置文件的常用选项设置 2.sshd 服务支持两种验证方式 1)密码验证 2)密钥对验证 三、使用 SSH 客户端程序 1.ssh 远…

从 Sora 制作的短片看AI生成视频的优势与局限性解析

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

R语言的基本图形

一&#xff0c;条形图 安装包 install.packages("vcd") 绘制简单的条形图 barplot(c(1,2,4,5,6,3)) 水平条形图 barplot(c(1,2,4,5,6,3),horiz TRUE) 堆砌条形图 > d1<-c("Placebo","Treated") > d2<-c("None",&qu…

聚类分析:使用R语言对Iris数据集进行K均值聚类

引言 聚类分析是一种常用的无监督学习技术&#xff0c;旨在将数据集中的样本分成具有相似特征的组。K均值聚类是其中一种常见的方法&#xff0c;它通过将数据点划分为K个簇&#xff0c;并使每个数据点与其所属簇的中心点距离最小化来实现聚类。本文将介绍如何使用R语言执行K均…

matlab求时间序列的时间滞后相关性

matlab求时间序列的时间滞后相关性 自相关、互相关、加权相关、滞后相关等相关性分析&#xff0c;在时间序列分析中经常被用到&#xff0c;可以量化两个时间序列的相关程度&#xff0c;特别对于有季节性趋势的序列中这个分析尤为必要。下面介绍一个Matlab函数&#xff0c;用于进…

FPGA实现图像处理之【直方图均衡-寄存器版】

FPGA实现直方图统计 一、图像直方图统计原理 直方图的全称为灰度直方图&#xff0c;是对图像每一灰度间隔内像素个数的统计。即对一张图片中每隔二灰度值的像素数量做统计&#xff0c;然后以直方图的形式展现出来。图下的亮暗分布在直方图中就可以一目了然&#xff0c;直方图…

分布式系统事务一致性解决方案(基于事务消息)

参考&#xff1a;https://rocketmq.apache.org/zh/docs/featureBehavior/04transactionmessage/ 文章目录 概要错误的方案方案一&#xff1a;业务方自己实现方案二&#xff1a;RocketMQ 事务消息什么是事务消息事务消息处理流程事务消息生命周期使用限制使用示例使用建议 概要 …

WPF —— MVVM 指令执行不同的任务实例

标签页 设置两个按钮&#xff0c; <Button Content"修改状态" Width"100" Height"40" Background"red"Click"Button_Click"></Button><Button Content"测试"Width"100"Height"40&…

java案例-读取xml文件

需求 导入依赖 <dependencies><!-- dom4j --><dependency><groupId>dom4j</groupId><artifactId>dom4j</artifactId><version>1.6.1</version></dependency> </dependencies>代码 SAXReader saxReade…

BERT一个蛋白质-季军-英特尔创新大师杯冷冻电镜蛋白质结构建模大赛-paipai

关联比赛: “创新大师杯”冷冻电镜蛋白质结构建模大赛 解决方案 团队介绍 paipai队、取自 PAIN AI&#xff0c;核心成员如我本人IvanaXu(IvanaXu GitHub)&#xff0c;从事于金融科技业&#xff0c;面向银行信用贷款的风控、运营场景。但我们团队先后打过很多比赛&#xf…

Vue后台系统demo小计

创建项目 1.报错 Error: command failed: npm install --loglevel error --legacy-peer-deps 措施1&#xff1a;node.js文件夹属性 》高级 》选择第一个允许 Users(XXX\Users) &#xff08;对我无用&#xff09; 措施2&#xff1a;PowerShell(以管理员身份运行) 》 cd 想存…

C++ | Leetcode C++题解之第55题跳跃游戏

题目&#xff1a; 题解&#xff1a; class Solution { public:bool canJump(vector<int>& nums) {int n nums.size();int rightmost 0;for (int i 0; i < n; i) {if (i < rightmost) {rightmost max(rightmost, i nums[i]);if (rightmost > n - 1) {r…

【Leetcode每日一题】 动态规划 - 简单多状态 dp 问题 - 打家劫舍 II(难度⭐⭐)(67)

1. 题目解析 题目链接&#xff1a;213. 打家劫舍 II 这个问题的理解其实相当简单&#xff0c;只需看一下示例&#xff0c;基本就能明白其含义了。 2.算法原理 这个问题是经典的“打家劫舍”问题的变种&#xff0c;原问题是在单排房屋中进行偷窃&#xff0c;而这个问题则是在…

利用Triple U.Net结构对冷冻切片HE染色组织学图像进行核实例分割

利用Triple U.Net结构对冷冻切片H&E染色组织学图像进行核实例分割 摘要IntroductionRelated WorksDatasetProposed MethodologyDataset PreparationSegmentation BranchLoss FunctionWatershed Algorithm Nuclei Instance Segmentation of Cryosectioned H&E Stained H…

JavaScript全套检验系统(LIS)源码C# + MVC + SQLserver + Redis 云LIS系统源码 区域医疗云LIS系统源码

JavaScript全套检验系统&#xff08;LIS&#xff09;源码C# MVC SQLserver Redis 云LIS系统源码 区域医疗云LIS系统源码 实验室信息系统&#xff08;Laboratory Information System&#xff0c;缩写LIS&#xff09;是一类用来处理实验室过程信息的软件。这套系统通常与其他信…

ArcGIS基础:便捷分享图层包和地图包

1、分享图层包&#xff1a; 首先&#xff0c;选中要分享的数据&#xff0c;右键创建图层包&#xff0c;修改保存路径。 找到项目描述那一栏&#xff0c;将摘要、标签、描述都填写分享图层包的相关内容。 一切设置好之后&#xff0c;点击右上角的【分析】按钮。 点击分析之后…

vue2集成ElementUI编写登录页面

目录 1. 整理目录文件&#xff1a; a. app.vue文件如下&#xff1a; b. Login.vue文件如下&#xff1a; c. router/index.js文件如下&#xff1a; d. 删除components中的文件&#xff1a; e. 最终项目目录整理如下&#xff1a; 2. 集成ElementUI编写登录页面 a. 安装El…