使用GaLore在本地GPU进行高效的LLM调优

news2024/9/23 20:00:15

训练大型语言模型(llm),即使是那些“只有”70亿个参数的模型,也是一项计算密集型的任务。这种水平的训练需要的资源超出了大多数个人爱好者的能力范围。为了弥补这一差距,出现了低秩适应(LoRA)等参数高效方法,可以在消费级gpu上对大量模型进行微调。

GaLore是一种新的方法,它不是通过直接减少参数的数量,而是通过优化这些参数的训练方式来降低VRAM需求,也就是说GaLore是一种新的模型训练策略,可让模型使用全部参数进行学习,并且比LoRA更省内存。

GaLore将这些梯度投影到低秩空间上,显著减少了计算负荷,同时保留了训练所需的基本信息。与传统的优化器在反向传播后同时更新所有层的方法不同,GaLore在反向传播期间实现逐层更新。这种方法进一步减少了整个训练过程中的内存占用。

就像LoRA一样,GaLore可以让我们在具有24 GB VRAM的消费级GPU上微调7B模型。结果模型的性能与全参数微调相当,并且似乎优于LoRA。

优于目前Hugging Face还没有官方代码,我们就来手动使用论文的代码进行训练,并与LoRA进行对比

安装依赖

首先就要安装GaLore

 pip install galore-torch

然后我们还要一下这些库,并且请注意版本

 datasets==2.18.0
 transformers==4.39.1
 trl==0.8.1
 accelerate==0.28.0
 torch==2.2.1

调度器和优化器的类

Galore分层优化器是通过模型权重挂钩激活的。由于我们使用Hugging Face

Trainer

,还需要自己实现一个优化器和调度器的抽象类。这些类的结构不执行任何操作。

 from typing import Optional
 import torch
 
 # Approach taken from Hugging Face transformers https://github.com/huggingface/transformers/blob/main/src/transformers/optimization.py
 class LayerWiseDummyOptimizer(torch.optim.Optimizer):
     def __init__(self, optimizer_dict=None, *args, **kwargs):
         dummy_tensor = torch.randn(1, 1)
         self.optimizer_dict = optimizer_dict
         super().__init__([dummy_tensor], {"lr": 1e-03})
 
     def zero_grad(self, set_to_none: bool = True) -> None: 
       pass
 
     def step(self, closure=None) -> Optional[float]: 
       pass
 
 class LayerWiseDummyScheduler(torch.optim.lr_scheduler.LRScheduler):
     def __init__(self, *args, **kwargs):
         optimizer = LayerWiseDummyOptimizer()
         last_epoch = -1
         verbose = False
         super().__init__(optimizer, last_epoch, verbose)
 
     def get_lr(self): 
       return [group["lr"] for group in self.optimizer.param_groups]
 
     def _get_closed_form_lr(self): 
       return self.base_lrs

加载GaLore优化器

GaLore优化器的目标是特定的参数,主要是那些在线性层中以attn或mlp命名的参数。通过系统地将函数与这些目标参数挂钩,GaLore 8位优化器就会开始工作。

 from transformers import get_constant_schedule
 from functools import partial
 import torch.nn
 import bitsandbytes as bnb
 
 from galore_torch import GaLoreAdamW8bit
         
 def load_galore_optimizer(model, lr, galore_config):    
     # function to hook optimizer and scheduler to a given parameter 
     def optimizer_hook(p, optimizer, scheduler):
         if p.grad is not None: 
             optimizer.step()
             optimizer.zero_grad()
             scheduler.step()
 
     # Parameters to optimize with Galore
     galore_params = [
         (module.weight, module_name) for module_name, module in model.named_modules() 
         if isinstance(module, nn.Linear) and any(target_key in module_name for target_key in galore_config["target_modules_list"])
     ] 
     id_galore_params = {id(p) for p, _ in galore_params}
     
     # Hook Galore optim to all target params, Adam8bit to all others
     for p in model.parameters():
         if p.requires_grad:
             if id(p) in id_galore_params:
                 optimizer = GaLoreAdamW8bit([dict(params=[p], **galore_config)], lr=lr)
             else:
                 optimizer = bnb.optim.Adam8bit([p], lr = lr)
             scheduler = get_constant_schedule(optimizer)
             
             p.register_post_accumulate_grad_hook(partial(optimizer_hook, optimizer=optimizer, scheduler=scheduler))
             
     # return dummies, stepping is done with hooks 
     return LayerWiseDummyOptimizer(), LayerWiseDummyScheduler()

HF Trainer

准备好优化器后,我们开始使用Trainer进行训练。下面是一个简单的例子,使用TRL的SFTTrainer (Trainer的子类)在Open Assistant数据集上微调llama2-7b,并在RTX 3090/4090等24 GB VRAM GPU上运行。

 from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, set_seed, get_constant_schedule
 from trl import SFTTrainer, setup_chat_format, DataCollatorForCompletionOnlyLM
 from datasets import load_dataset
 import torch, torch.nn as nn, uuid, wandb
 
 lr = 1e-5
 
 # GaLore optimizer hyperparameters
 galore_config = dict(
     target_modules_list = ["attn", "mlp"], 
     rank = 1024, 
     update_proj_gap = 200, 
     scale = 2, 
     proj_type="std"
 )
 
 modelpath = "meta-llama/Llama-2-7b"
 model = AutoModelForCausalLM.from_pretrained(
     modelpath,    
     torch_dtype=torch.bfloat16,
     attn_implementation = "flash_attention_2",  
     device_map = "auto",
     use_cache = False,
 )
 tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast = False)
 
 # Setup for ChatML
 model, tokenizer = setup_chat_format(model, tokenizer)
 if tokenizer.pad_token in [None, tokenizer.eos_token]: 
     tokenizer.pad_token = tokenizer.unk_token
 
 # subset of the Open Assistant 2 dataset, 4000 of the top ranking conversations
 dataset = load_dataset("g-ronimo/oasst2_top4k_en")
 
 training_arguments = TrainingArguments(
     output_dir = f"out_{run_id}",
     evaluation_strategy = "steps",
     label_names = ["labels"],
     per_device_train_batch_size = 16,
     gradient_accumulation_steps = 1,
     save_steps = 250,
     eval_steps = 250,
     logging_steps = 1, 
     learning_rate = lr,
     num_train_epochs = 3,
     lr_scheduler_type = "constant",
     gradient_checkpointing = True,
     group_by_length = False,
 )
 
 optimizers = load_galore_optimizer(model, lr, galore_config)
 
 trainer = SFTTrainer(
     model = model,
     tokenizer = tokenizer,
     train_dataset = dataset["train"],
     eval_dataset = dataset['test'],
     data_collator = DataCollatorForCompletionOnlyLM(
         instruction_template = "<|im_start|>user", 
         response_template = "<|im_start|>assistant", 
         tokenizer = tokenizer, 
         mlm = False),
     max_seq_length = 256,
     dataset_kwargs = dict(add_special_tokens = False),
     optimizers = optimizers,
     args = training_arguments,
 )
 
 trainer.train()

GaLore优化器带有一些需要设置的超参数如下:

target_modules_list:指定GaLore针对的层

rank:投影矩阵的秩。与LoRA类似,秩越高,微调就越接近全参数微调。GaLore的作者建议7B使用1024

update_proj_gap:更新投影的步骤数。这是一个昂贵的步骤,对于7B来说大约需要15分钟。定义更新投影的间隔,建议范围在50到1000步之间。

scale:类似于LoRA的alpha的比例因子,用于调整更新强度。在尝试了几个值之后,我发现scale=2最接近于经典的全参数微调。

微调效果对比

给定超参数的训练损失与全参数调优的轨迹非常相似,表明GaLore分层方法确实是等效的。

用GaLore训练的模型得分与全参数微调非常相似。

GaLore可以节省大约15 GB的VRAM,但由于定期投影更新,它需要更长的训练时间。

上图为2个3090的内存占用对比

训练事件对比,微调:~58分钟。GaLore:约130分钟

最后我们再看看GaLore和LoRA的对比

上图为LoRA微调所有线性层,rank64,alpha 16的损失图

从数值上可以看到GaLore是一种近似全参数训练的新方法,性能与微调相当,比LoRA要好得多。

总结

GaLore可以节省VRAM,允许在消费级GPU上训练7B模型,但是速度较慢,比微调和LoRA的时间要长差不多两倍的时间。

GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection.

https://avoid.overfit.cn/post/0b15de8db27040f0abcaa7e554b0b993

作者:Geronimo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1542650.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Canvas与艺术】暗蓝网格汽车速度仪表盘

【关键点】 采用线性渐变色&#xff0c;使上深下浅的圆有凹下效果&#xff0c;使上浅下深的圆有凸起效果&#xff0c;两者结合就有立体圆钮的感觉。 【图例】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type&quo…

【Python机器学习系列】机器学习中的模型微调---随机搜索(案例+源码)

这是我的第245篇原创文章。 一、引言 如果探索的组合数量较少时&#xff0c;网格搜索是一种不错的方法&#xff0c;但当超参数的搜索范围较大时&#xff0c;通常会优先选择使用 RandomizedSearchCV 。它与 GridSearchCV 用法相似&#xff0c;但它不会尝试所有可能的组合&…

华为升级FIT AP示例(通过AC的命令行)

升级FIT AP示例&#xff08;通过AC的命令行&#xff09; 前提条件 从官网下载升级目标版本对应的系统软件包&#xff0c;保存在PC本地。如果下载的文件是压缩文件&#xff0c;则需要解压缩出系统软件包。 AP已在WAC上线。 背景信息 升级的过程是先将系统软件包传到设备上&…

微信小程序button动态跳转到页面

微信小程序中如何动态的跳转到某个页面。 目录 1、首先在js文件中定义事件函数 2、在页面中进行传参调用 3、其它跳转方法简单说明 1、首先在js文件中定义事件函数 goto(e){const urle.currentTarget.dataset.url;wx.navigateTo({url: url})}, 2、在页面中进行传参调用 &l…

C++之char16_t*与char*类型相互转换(二百六十)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a;多媒…

国内ip修改用什么软件下载?

在特定情况下&#xff0c;可能需要修改你的国内IP地址以实现网络访问需求或绕过地域限制。许多软件和工具可以帮助你实现这一目标&#xff0c;无论是为了隐私保护还是访问特定内容。虎观代理小二将介绍一些推荐的软件下载途径&#xff0c;以便你修改国内IP地址。 1. IP代理软件…

深度解析:Elasticsearch写入请求处理流程

版本 Elasticsearch 8.x 原文链接&#xff1a;https://mp.weixin.qq.com/s/hZ_ZOLFUoRuWyqp47hqCgQ 今天来看下 Elasticsearch 中的写入流程。 不想看过程可以直接跳转文章末尾查看总结部分。最后附上个人理解的一个图。 从我们发出写入请求&#xff0c;到 Elasticsearch 接收请…

5个适用于 Windows/PC 的水印去除软件(视频/图像)

水印是文本、徽标、印记、图像或签名&#xff0c;通常叠加在视频、其他图像或具有较高透明度的 PDF 文档上。当您免费使用某些产品&#xff08;例如视频编辑器&#xff09;时&#xff0c;最终输出通常带有代表您使用的编辑器的水印。您可能需要出于您的目的从此类媒体文件中删除…

Django之Celery篇(三)

一、任务交给Celery Django任务交给Celery的方法和普通使用Celery任务的调用基本无区别,只是将执行代码的放到到View视图中 而获取结果,往往并不能把结果和第1次请求一起响应,若想获取结果是通过第2次请求获取结果 代码如下: from django.http import HttpResponsefrom …

我们是如何在 IDE 中设计 AutoDev 的 AI 编程开发智能体语言与框架?

上周微软发布了自家的 AI 编程和软件开发智能体框架&#xff1a;AutoDev&#xff0c;其与我们开发的 IDE 插件 AutoDev 有颇多的相似之处&#xff0c;特别是一些设计思路&#xff0c;以及在对于辅助软件开发任务的智能体以及一些基础设施上。 稍有不同的是&#xff1a; 交互介质…

Axure RP 9 for mac中文版密钥激活版下载

Axure RP 9是一款专业的快速原型设计工具&#xff0c;它可以帮助产品设计师、交互设计师和用户体验设计师等创建高保真度、交互性强的原型&#xff0c;以便在产品开发之前进行测试和用户验证。 软件下载&#xff1a;Axure RP 9 for mac中文版密钥激活版下载 该工具具有丰富的功…

微信小程序实现多张照片上传

hello hello~ &#xff0c;这里是 code袁~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f4a5;个人主页&#xff1a;code袁 &#x1f4a5; 所属专栏&…

python绘图matplotlib——使用记录1

本博文来自于网络收集&#xff0c;如有侵权请联系删除 使用matplotlib绘图 1 常用函数汇总1.1 plot1.2 legend1.3 scatter1.4 xlim1.5 xlabel1.6 grid1.7 axhline1.7 axvspan1.8 annotate1.9 text1.10 title 2 常见图形绘制2.1 bar——柱状图2.2 barh——条形图2.3 hist——直…

浏览器工作原理与实践--渲染流程(上):HTML、CSS和JavaScript,是如何变成页面的

在上一篇文章中我们介绍了导航相关的流程&#xff0c;那导航被提交后又会怎么样呢&#xff1f;就进入了渲染阶段。这个阶段很重要&#xff0c;了解其相关流程能让你“看透”页面是如何工作的&#xff0c;有了这些知识&#xff0c;你可以解决一系列相关的问题&#xff0c;比如能…

服务端高并发分布式结构

前言 本文以⼀个 “电子商务” 应用为例&#xff0c;介绍从⼀百个到千万级并发情况下服务端的架构的演进过程&#xff0c;同时列举出每个演进阶段会遇到的相关技术&#xff0c;让大家对架构的演进有⼀个整体的认知&#xff0c;方便⼤家对后续知识做深⼊学习时有⼀定的整体视野…

二次开发Flink-coGroup算子支持迟到数据通过测输出流提取

目录 1.背景 2.coGroup算子源码分析 2.1完整的coGroup算子调用流程 2.2coGroup方法入口 2.3 CoGroupedStreams对象分析 2.4WithWindow内部类分析 2.5CoGroupWindowFunction函数分析 3.修改源码支持获取迟到数据测输出流 3.1复制CoGroupedStreams 3.2新增WithWindow.si…

YiYi-Web项目介绍

YiYi-Web项目介绍 1. 简介2. 使用2.1 后端开发环境2.2 前端开发环境 3. 测试环境&#xff1a;4. 更新日志5. 打包情况6.项目截图 本项目前端是html、css、js、jQuery基础技术。 后端都是最新的SpringBoot技术&#xff0c;不分离版本&#xff0c; 是最基础的项目开发教程&#x…

Spark Map 和 FlatMap 的比较

Spark Map 和 FlatMap 的比较 本节将介绍Spark中map(func)和flatMap(func)两个函数的区别和基本使用。 函数原型 map(func) 将原数据的每个元素传给函数func进行格式化&#xff0c;返回一个新的分布式数据集。 flatMap(func) 跟map(func)类似&#xff0c;但是每个输入项和…

Flink GateWay、HiveServer2 和 hive on spark

Flink SQL Gateway简介 从官网的资料可以知道Flink SQL Gateway是一个服务&#xff0c;这个服务支持多个客户端并发的从远程提交任务。Flink SQL Gateway使任务的提交、元数据的查询、在线数据分析变得更简单。 Flink SQL Gateway的架构如下图&#xff0c;它由插件化的Endpoi…

孩子看书用白光还是暖白光好呢?五大护眼台灯推荐,必选!

孩子看书时&#xff0c;选择合适的光线是至关重要的。白光与暖白光各有优劣&#xff0c;白光亮度高、色彩还原性好&#xff0c;适合需要高度集中注意力和精细操作的场合&#xff1b;而暖白光则更加柔和、不刺眼&#xff0c;有助于营造轻松的阅读氛围。为了帮助家长更好地为孩子…