QLoRa:在消费级GPU上微调大型语言模型

news2025/1/22 12:26:27

大多数大型语言模型(LLM)都无法在消费者硬件上进行微调。例如,650亿个参数模型需要超过780 Gb的GPU内存。这相当于10个A100 80gb的gpu。就算我们使用云服务器,花费的开销也不是所有人都能够承担的。

而QLoRa (Dettmers et al., 2023),只需使用一个A100即可完成此操作。

在这篇文章中将介绍QLoRa。包括描述它是如何工作的,以及如何使用它在GPU上微调具有200亿个参数的GPT模型。

为了进行演示,本文使用nVidia RTX 3060 12 GB来运行本文中的所有命令。这样可以保证小显存的要求,并且也保证可以使用免费的Google Colab实例来实现相同的结果。但是,如果你只有较小内存的GPU,则必须使用较小的LLM。

QLoRa: Quantized LLMs with Low-Rank Adapters

2021年6月,发布的LoRa让我们的微调变得简单,我也在以前的文章中也有过介绍。

LoRa为LLM的每一层添加了少量的可训练参数(适配器),并冻结了所有原始参数。这样对于微调,只需要更新适配器权重,这可以显著减少内存占用。

而QLoRa更进一步,引入了4位量化、双量化和利用nVidia统一内存进行分页。

简而言之,QLoRa工作原理如下:

  • 4位NormalFloat量化:这是一种改进量化的方法。它确保每个量化仓中有相同数量的值。这避免了计算问题和异常值的错误。
  • 双量化:QLoRa的作者将其定义如下“对量化常量再次量化以节省额外内存的过程。”
  • 统一内存分页:它依赖于NVIDIA统一内存管理,自动处理CPU和GPU之间的页到页传输。它可以保证GPU处理无错,特别是在GPU可能耗尽内存的情况下。

所有这些步骤都大大减少了微调所需的内存,同时性能几乎与标准微调相当。

使用QLoRa对GPT模型进行微调

硬件要求:

下面的演示工作在具有12gb VRAM的GPU上,用于参数少于200亿个模型,例如GPT-J。

如果你有一个更大的卡,比如24gb的VRAM,则可以用一个200亿个参数的模型,例如GPT-NeoX-20b。

内存建议至少6 Gb,这个条件现在都能满足对吧

GPT-J和GPT-NeoX-20b都是非常大的模型。所以硬盘议至少有100gb的可用空间。

如果你的机器不满足这些要求,可以使用Google Colab的免费实例,因为它就足够使用了。

软件要求:

必须要CUDA。这是肯定的。然后还需要一些依赖:

  • bitsandbytes:包含量化LLM所需的所有库。
  • Hugging Face的Transformers和Accelerate:这些是标准库,用于训练模型。
  • PEFT:提供了各种微调方法的实现,我们只需要里面的LoRa。
  • 数据集:自己的数据集,这里安装了Hugging Face的datasets,这个是备选,装不装无所谓,因为这玩意挺难用的

PIP安装命令如下:

 pip install -q -U bitsandbytes
 pip install -q -U git+https://github.com/huggingface/transformers.git 
 pip install -q -U git+https://github.com/huggingface/peft.git
 pip install -q -U git+https://github.com/huggingface/accelerate.git
 pip install -q datasets

下面就是Python代码

1、GPT模型的加载与量化

我们需要以下导入来加载和量化LLM。

 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

我们将对EleutherAI预训练的GPT NeoX模型进行微调。这是一个有200亿个参数的模型。注意:GPT NeoX具有允许商业使用的宽松许可证(Apache 2.0)。

可以从hug Face Hub获得这个模型和相关的标记器:

 model_name = "EleutherAI/gpt-neox-20b"
 
 #Tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_name)

然后配置量化器,如下所示:

 quant_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=True,
     bnb_4bit_quant_type="nf4",
     bnb_4bit_compute_dtype=torch.bfloat16
 )
  • load_in_4bit:模型将以4位精度加载到内存中。
  • bnb_4bit_use_double_quant:QLoRa提出的双量化。
  • bnb_4bit_quant_type:这是量化的类型。“nf4”代表4位的NormalFloat。
  • bnb_4bit_compute_dtype:当以4位加载和存储模型时,在需要时对其进行部分量化,并以16位精度(bfloat16)进行所有计算。

然后就可以加载4位模型:

 model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=quant_config, device_map={"":0})

下一步启用梯度检查点,这样可以减少内存占用,但是速度会稍微降低一些:

 model.gradient_checkpointing_enable()

2、LoRa的GPT模型预处理

为LoRa准备模型,为每一层添加可训练的适配器。

 from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
 
 model = prepare_model_for_kbit_training(model)
 
 config = LoraConfig(
     r=8, 
     lora_alpha=32, 
     target_modules=["query_key_value"], 
     lora_dropout=0.05, 
     bias="none", 
     task_type="CAUSAL_LM"
 )
 
 model = get_peft_model(model, config)

在LoraConfig中,可以使用r、alpha和dropout来获得更好的任务结果。具体内容可以在PEFT文档中找到更多选项和详细信息。

使用LoRa,我们只添加了800万个参数。并且只训练这些参数,这样使得微调很快。

3、数据集

对于这个演示,我们使用“english_quotes”数据集。这是一个由名言组成的数据集,在CC BY 4.0许可下发布。我们为了方便使用datasets直接加载。

 from datasets import load_dataset
 data = load_dataset("Abirate/english_quotes")
 data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)

4、微调

微调的代码非常标准

 import transformers
 
 tokenizer.pad_token = tokenizer.eos_token
 
 trainer = transformers.Trainer(
     model=model,
     train_dataset=data["train"],
     args=transformers.TrainingArguments(
         per_device_train_batch_size=1,
         gradient_accumulation_steps=8,
         warmup_steps=2,
         max_steps=20,
         learning_rate=2e-4,
         fp16=True,
         logging_steps=1,
         output_dir="outputs",
         optim="paged_adamw_8bit"
     ),
     data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
 )
 trainer.train()

要记住optim=”paged_adamw_8bit”。它将使用分页实现更好的内存管理。没有它可能会出现内存不足错误。

在Google Colab上运行这个微调只需要5分钟。VRAM消耗的峰值是15gb。

它有用吗?让我们试试推理。

基于QLoRa推理

微调的QLoRa模型可以直接与标准的Transformers的推理一起使用,如下所示:

 text = "Ask not what your country"
 device = "cuda:0"
 inputs = tokenizer(text, return_tensors="pt").to(device)
 
 outputs = model.generate(**inputs, max_new_tokens=20)
 print(tokenizer.decode(outputs[0], skip_special_tokens=True))

你应该得到这样的输出:

 Ask not what your country can do for you, ask what you can do for your country.”
 
 – John F.

5分钟的微调效果还可以吧。

总结

LoRa让我们的微调变得简单,而QLoRa可以让我们使用消费级的GPU对具有10亿个参数的模型进行微调,并且根据QLoRa论文,性能不会显著下降。

如果你对QLoRa感兴趣,看看他的代码吧:

https://avoid.overfit.cn/post/4c4c86e3f7974157a7a8e81c57a0f8a4

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/600712.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

公司裁员不给赔偿怎么办?

阅读本文大概需要 1.61 分钟。 最近在星球回答球友问题的时候,发现不少人都提到裁员这个话题。 有球友说他们公司在裁员,但不想给赔偿。 领导给他的方案是把年假调休休了,然后再给三周找工作时间,这三周不用打卡,三周后…

茅塞顿开的C#代码——通用型科学计算器

计算器是经常遇到的编程作业。 一般都是实现加、减、乘、除四则运算的普通计算器。 这里介绍用几十行C#代码实现的复杂的《科学计算器》,可以计算各种函数。 不知道其他语言实现同样的功能需要编写多少行代码?20000行? using System; usin…

SpringBoot接口如何正确地接收时间参数

唠嗑部分 在做Java开发时,肯定会碰到传递时间参数的情况吧,比如用户的出生日期、活动的开始,结束日期等等,这些参数往往是由前端传递过来的,那么在SpringBoot项目中,该如何正确的接收日期参数呢&#xff0…

如果不小心上了电信黑名单,应该怎么妥善处理呢?

有些小伙伴们在处理不用的手机卡时,可能会粗心大意,认为不用了就用不了呗,存在欠费停机的情况下也没有及时的去补交欠费,然后销户,导致了自己不小心上了电信黑名单,那遇到这种情况,应该怎么妥善…

论文解读 | 利用图形卷积核在距离图像中实现高效的3D目标检测

原创 | 文 BFT机器人 01 摘要 该论文提出了一种基于范围图像的高效3D物体检测方法,通过利用图卷积核来提取每个像素周围的局部几何信息。 作者设计了一种新颖的2D卷积网络架构,并提出了四种替代内积核心的卷积核,以注入所需的三维信息。该方法…

GPT最常用的应用场景有哪些?

生成式预训练转换器(GPT)是一种深度学习模型,它能够根据给定的提示生成类似人类的文本,彻底改变了自然语言处理(NLP)领域。 聊天机器人和虚拟助手 GPT最受欢迎的应用程序之一是开发聊天机器人和虚拟助手。凭…

【Python 自然语言处理(NLP)】零基础也能轻松掌握的学习路线与参考资料

Python 自然语言处理(NLP)是目前人工智能(AI)发展中的重要领域。随着科技的不断进步,NLP已经被应用于文本自动摘要、机器翻译、语音识别、情感分析、问答系统等各项实际任务中。 要学习 Python 自然语言处理&#xff…

“河南省数字化转型与信息技术应用创新专家研讨会-政府数字化转型推动信创产业发展”专场活动成功召开

由《中国信息化》杂志社主办的“2023河南省数字化转型与信息技术应用创新专家研讨会——政府数字化转型推动信创产业发展”专场活动于5月27日,在郑州成功举办。本次活动由深圳竹云科技股份有限公司协办,由河南省测绘学会、中国信息主管平台支持。中国交通…

Windows 安装部署 MinIo

1、下载地址 安装包下载地址:https://min.io/download#/windows 2、安装目录 下载的是一个可执行文件 minio.exe 将其放到一个方便寻找的目录,我这里放在 D:\develop\minio 同时新建一个 data 文件夹,用来存储上传的文件 3、启动 MinIo 服…

Godot引擎 4.0 文档 - 第一个 3D 游戏

本文为Google Translate英译中结果,DrGraph在此基础上加了一些校正。英文原版页面: Your first 3D game — Godot Engine (stable) documentation in English 你的第一个 3D 游戏 在这个循序渐进的教程系列中,您将使用 Godot 创建您的第一…

Java程序设计入门教程 -- 二维数组

二维数组创建 定义数组 声明二维数组 Java中二维数组的声明格式: 数据类型名[ ][ ] 数组名; 或 数据类型名 数组名[ ][ ]; 分配数组内存 常用格式: new 数组名[ M][N ] ; //M,N为数组行号和列号 分配数组内存 1)规…

chatgpt赋能python:Python的几次方符号介绍

Python的几次方符号介绍 当我们在使用Python编程时,经常需要进行数学计算。其中最常见的计算就是幂运算。Python使用幂运算符号来表示一个数的几次方。这个符号既可以用在数字之间,也可以用在变量之间。在本文中,我们将介绍Python中的几次方…

堆基础1_小白垃圾笔记

小白垃圾笔记,不建议阅读。 目录 1.什么是堆? 2.堆从哪里来? 3.堆管理器是什么 4.堆申请的实现方式 1.brk:brk仅仅主线程申请小空间的时候用,子线程不可调用brk。 2.mmap:主线程申请大的内存的时候和…

Mysql小知识 delete 清空表之后,磁盘空间未发生变化?

1. 删除空洞 1.1 案例展示 首先我们先来看这样一个例子。 我现在有一个名为 sakila 的数据库,该库中有一个 film 表,这个表中有 1000 条记录,我么先来看下这 1000 条记录占用了多少存储空间: 小伙伴们可以看到,这个…

Go-FastDFS 本地对象存储,Windows环境搭建(下载安装教程)!

文章目录 Go-FastDFS简介与地址下载安装服务与管理端台可视化测试 Go-FastDFS简介与地址 go-fastdfs 是一个基于 http 协议的分布式文件系统! 它基于大道至简的设计理念,一切从简设计,使得它的运维及扩展变得更加简单,它具有高性…

Elasticsearch:使用带有 X-Opaque-Id 的慢速查询功能在 Elasticsearch 中调试慢速查询

如果你在软件堆栈中使用 Elasticsearch,你可能已经意识到 Elasticsearch 管理大量数据和提供实时搜索功能的强大能力。 了解 Elasticsearch 中的慢速查询 Slow Log 是 Elasticsearch 的内置功能,可用于识别慢速搜索。 任何花费时间超过预期的请求都会记…

Paper reading: Conditional Diffusion for Interactive Segmentation ICCV2021

交互式语义分割 We propose Conditional Diffusion Network (CDNet), which propagates labeled representations from clicks to conditioned destinations with two levels of affinities: Feature Diffusion Module (FDM) spreads features from clicks to potential targ…

单例模式-图文详解

概念 全世界就只要一个---在整个java程序中,只有这个类的一个实例 比如Student a new Student(); 就是Student类只创建这一个实例,只能有这一个对象存在 主要解决:一个全局使用的类频繁地创建与销毁。在内存里只有一个实例,减…

【白话机器学习系列】白话张量

白话张量 张量(Tensor)是向量和矩阵向 n n n 维的推广。向量是一维张量,矩阵是二维张量。张量作为数值容器,是机器学习,尤其是深度学习中最基础的操作对象,以至于 Google 的机器学习框架都已 TensorFlow …

ffmpeg在windows环境下的详细安装教程

这两天整理好用的录屏软件,发现了Captura这个软件,软件本身的安装很简单,但由于Captura需要依赖ffmpeg(一套可以用来记录、转换数字音频、视频,并能将其转化为流的开源计算机程序),而ffmpeg在安…