DeepSeek-7B-chat 4bits量化 Qlora 微调

news2024/12/23 23:28:01

在本文中我们将学习DeepSeek量化微调的方法,并且从微调结果体会大模型微调的重要性。

引言

在当前快速发展的自然语言处理领域,模型的精度和效率是关键。量化和微调技术可以有效提高模型性能。本文将探讨如何对DeepSeek-7B-chat模型进行4bits量化,并利用Qlora技术进行微调,以实现高效的模型部署。

什么是模型量化?

模型量化是将高精度的浮点数表示转换为低精度表示(如4bits),以减少模型的存储和计算资源。量化可以显著降低模型的内存占用和计算复杂度,同时保持较高的推理性能。

DeepSeek-7B-chat模型

DeepSeek-7B-chat是一个大规模的语言模型,设计用于对话生成和语言理解。其庞大的参数量使得直接部署在资源受限的环境中具有挑战性,因此量化技术尤为重要。

Qlora 技术简介

Qlora(Quantized Low-Rank Adapter)是一种优化微调技术,适用于量化后的模型。通过低秩近似和适应层的结合,Qlora在微调阶段保持高效,并在不显著增加计算成本的情况下提高模型性能。

环境配置
pip install transformers==4.35.2
pip install peft==0.4.0
pip install datasets==2.10.1
pip install accelerate==0.20.3
pip install tiktoken
pip install transformers_stream_generator
pip install bitsandbytes==0.41.1
指令集构建

LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:

{
    "instrution":"回答以下用户问题,仅输出答案。",
    "input":"1+1等于几?",
    "output":"2"
}

其中,instruction 是用户指令,告知模型其需要完成的任务;input 是用户输入,是完成用户指令所必须的输入内容;output 是模型应该给出的输出。即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。

数据格式化

Lora 训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch 模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:

def process_func(example):
    MAX_LENGTH = 384    # Llama分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性
    input_ids, attention_mask, labels = [], [], []
    instruction = tokenizer(f"User: {example['instruction']+example['input']}\n\n", add_special_tokens=False)  # add_special_tokens 不在开头加 special_tokens
    response = tokenizer(f"Assistant: {example['output']}<|end▁of▁sentence|>", add_special_tokens=False)
    input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]
    attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1]  # 因为eos token咱们也是要关注的所以 补充为1
    labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id]  
    if len(input_ids) > MAX_LENGTH:  # 做一个截断
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
加载tokenizer和半精度模型
tokenizer = AutoTokenizer.from_pretrained('./deepseek-ai/deepseek-llm-7b-chat/', use_fast=False, trust_remote_code=True)
tokenizer.padding_side = 'right' # padding在右边

model = AutoModelForCausalLM.from_pretrained(
        '/root/model/deepseek-ai/deepseek-llm-7b-chat/', 
        trust_remote_code=True, 
        torch_dtype=torch.half, 
        device_map="auto",
        low_cpu_mem_usage=True,   # 是否使用低CPU内存
        load_in_4bit=True,  # 是否在4位精度下加载模型。如果设置为True,则在4位精度下加载模型。
        bnb_4bit_compute_dtype=torch.half,  # 4位精度计算的数据类型。这里设置为torch.half,表示使用半精度浮点数。
        bnb_4bit_quant_type="nf4", # 4位精度量化的类型。这里设置为"nf4",表示使用nf4量化类型。
        bnb_4bit_use_double_quant=True  # 是否使用双精度量化。如果设置为True,则使用双精度量化。
    )
model.generation_config = GenerationConfig.from_pretrained('/root/model/deepseek-ai/deepseek-llm-7b-chat/')
model.generation_config.pad_token_id = model.generation_config.eos_token_id
定义LoraConfig
  • task_type:模型类型
  • target_modules:需要训练的模型层的名字,主要就是attention部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。
  • rlora的秩,具体可以看Lora原理
  • lora_alphaLora alaph,具体作用参见 Lora 原理
config = LoraConfig(
    task_type=TaskType.CAUSAL_LM, 
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    inference_mode=False, # 训练模式
    r=8, # Lora 秩
    lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理
    lora_dropout=0.1# Dropout 比例
)
自定义 TrainingArguments 参数
  • output_dir:模型的输出路径
  • per_device_train_batch_size:顾名思义 batch_size
  • gradient_accumulation_steps: 梯度累加,如果你的显存比较小,那可以把 batch_size 设置小一点,梯度累加增大一些。
  • logging_steps:多少步,输出一次log
  • num_train_epochs:顾名思义 epoch
  • gradient_checkpointing:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads(),这个原理大家可以自行探索,这里就不细说了。
  • optim="paged_adamw_32bit" 使用QLora的分页器加载优化器
args = TrainingArguments(
    output_dir="./output/DeepSeek",
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    logging_steps=10,
    num_train_epochs=3,
    save_steps=100,
    learning_rate=1e-4,
    save_on_each_node=True,
    gradient_checkpointing=True,
    optim="paged_adamw_32bit"  # 优化器类型
)
使用 Trainer 训练
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_id,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()

在利用量化微调之后,我们能明显看出大模型的推理能力增强,结果如下:

效果很明显,当没有进行微调时,模型的理解能力还停留在信息检索层面,当微调之后,我们能进入角色对话状态,很好地带入语境。

以此模型微调为基础,我们可以深入探讨模型的量化与微调。

模型量化是将模型权重从高精度的浮点数表示(如32位浮点数)转换为低精度表示(如8位、4位甚至1位),以减少模型的内存占用和计算资源。常见的量化方法包括:

  • 静态量化:在训练后对模型进行量化。
  • 动态量化:在推理过程中进行量化。
  • 量化感知训练(QAT):在训练过程中模拟量化的影响。

量化的好处

  • 减少内存占用:显著减少模型的存储空间。
  • 加速推理:降低计算复杂度,提高推理速度。
  • 节省能耗:在移动设备或嵌入式系统中,量化模型能显著降低能耗。

量化过程 以4bits量化为例,主要步骤包括:

  • 缩放和零点计算:将浮点数范围映射到整数范围。
  • 量化权重和激活:使用缩放因子和零点进行量化。
  • 反量化:在推理时将量化值恢复为浮点数进行计算。

QLoRA 简介 QLoRA(Quantized Low-Rank Adapter)是一种优化微调技术,结合了量化和低秩适配层,通过降低模型参数的秩和引入适配层,实现高效微调。

QLoRA 的优势

  • 高效训练:减少训练参数,降低计算资源需求。
  • 保持性能:在量化后的模型上进行微调,保持模型的高性能。
  • 灵活性强:适用于多种模型和任务,特别是大规模语言模型。

QLoRA 微调步骤

  • 数据准备:整理微调所需的数据集。
  • 模型量化:先对模型进行量化,如前文的4bits量化。
  • 低秩适配:使用低秩近似技术降低模型参数的秩。
  • 适配层训练:引入适配层,结合数据进行微调。

学习原文链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1807423.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux | buildrootfs 添加mkfs.ext3/mkfs.ext4 支持

因个人需要&#xff0c;mkfs.ext3 但是项目中还没有这个命令 所以琢磨了半天 这里将其小记一下 在buildrootfsz中&#xff0c;需要将e2fsprogs 勾选上然后重新编译就好了 make menuconfig Target packages-> Filesystem and flash utilities-> e2fsprogs

stm32编写Modbus步骤

1. modbus协议简介&#xff1a; modbus协议基于rs485总线&#xff0c;采取一主多从的形式&#xff0c;主设备轮询各从设备信息&#xff0c;从设备不主动上报。 日常使用都是RTU模式&#xff0c;协议帧格式如下所示&#xff1a; 地址 功能码 寄存器地址 读取寄存器…

springCloudAlibaba之分布式事务组件---seata

Seata Sea学习分布式事务Seata二阶段提交协议AT模式 Sea学习 事务&#xff1a;事务是访问数据库并更新数据库中各项数据的一个程序执行单元。在关系数据库中&#xff0c;一个事务由一组或多组SQL语句组成。事务应该具有4个属性&#xff1a;原子性、一致性、隔离性、持久性。例如…

数据交换平台_10_activatemq 中间件容错性测试

目录概要 3. 容错测试: - 模拟ActiveMQ在异常情况下的表现,如网络中断、节点故障等。 - 观察ActiveMQ的容错机制是否能够正确处理异常情况,保证消息的可靠传输。 - 根据容错测试结果,优化ActiveMQ的容错机制,确保系统在面对异常情况时能够正确处理并恢复。 设计: 容错测…

WPF-UI布局

WPF布局元素有如下几个&#xff1a; Grid&#xff1a;网格。可以自定义行和列并通过行列的数量、行高和列宽来调整控件的布局。StackPanel&#xff1a;栈式面板。可将包含的元素在竖直或水平方向上排成一条直线&#xff0c;当移除一个元素后&#xff0c;后面的元素会自动向前移…

鸿蒙HarmonyOS中的ohpm相关知识点总结

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 OHPM是什么&#xff1f;一、OpenHarmony三方库中心仓网站&#xff08;website&#xff09;&#xff1a;用于检索、查看所需 OpenHarmony 三方库信息&#xff0c;也可…

二维鱼游CFD代码

最近学了会Julia&#xff0c;参考了原作者的shark&#xff0c;做一下基于airfoils 2D的鱼游&#xff0c;暂时没想好有什么需要深入研究的&#xff0c;代码公开如下&#xff1a; 鱼身是naca0016&#xff0c;然后一些参数可以参考我以前发的论文。 using WaterLily, StaticArra…

如何解决网络问题?

组织和 IT 管理员尽其所能完善他们的网络&#xff0c;但是&#xff0c;不同程度的网络问题仍然可能出现&#xff0c;这些网络问题需要立即响应和解决&#xff0c;如果这些问题在不合理的时间内得不到解决&#xff0c;网络和组织的损害可能会付出高昂的代价。这就是为什么 IT 管…

【C#线程设计】3:threadpool

实现&#xff1a; &#xff08;1&#xff09;.控件&#xff1a;group Box&#xff0c;text Box&#xff0c;check Box&#xff0c;label&#xff0c;botton&#xff0c;richtextbox 控件拉取见&#xff1a;https://blog.csdn.net/m0_74749240/article/details/139409510?spm1…

IDEA配置mybatis-config.xml模板文件

IDEA配置mybatis-config.xml模板文件 File>>Settings>>File and Code Templates 创建mybatis-config.xml模板 模板内容取自mybatis官网 mybatis官网 <?xml version"1.0" encoding"UTF-8" ?> <!DOCTYPE configurationPUBLIC &qu…

怎么使用手机远程访问电脑文件?(3种方法)

手机远程访问电脑文件 “有时&#xff0c;当我离开电脑时&#xff0c;仍然需要访问和使用桌面上的文件。是否有一种工具可以通过WiFi而不是USB连接&#xff0c;让我的手机远程访问电脑上的文件&#xff1f;如果有任何建议&#xff0c;我将非常感激&#xff01;” 除了希望手机…

c++引用的本质(反汇编角度分析)

目录 一、引用基础理论 二、 引用的本质 三、从反汇编角度进行分析 1.变量赋值 2.引用和指针初始化 3.通过引用和指针赋值 4.eaxd的作用 一、引用基础理论 在c中我们都知道&#xff0c;引用&#xff08;&&#xff09;就是变量的一个别名&#xff0c;它允许我们为已存…

上市公司-市场竞争程度(1999-2023年)赫份达尔、勒纳指数数据集

数据年份&#xff1a;1999-2023年 有效样本&#xff1a;64505条 数据来源&#xff1a;上市公司年报 数据说明&#xff1a; ① 在行业层面&#xff0c;赫芬达尔指数可衡量一个公司在市场中的相对份额或集中度。它是由每家公司在市场中份额的平方和得到的。指数值越高&#x…

idea打开hierarchy面板

hierarchy&#xff1a;查看类层级关系图 不同版本的IDEA的快捷键不一样&#xff0c;同时如果修改了IDEA快捷键&#xff0c;也可能会不一样&#xff0c;具体查看可通过IDEA上方的Navigate来查看navigate--Type Hierarchy&#xff0c;就可以看见其快捷键了&#xff0c;我的快捷键…

数据结构--递归和数组

个人介绍 hello hello~ &#xff0c;这里是 code袁~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的…

企业网站策划

企业网站策划是企业推广和宣传的重要组成部分&#xff0c;它不仅是企业对外传达形象和信息的平台&#xff0c;更是企业与客户、供应商、合作伙伴进行交流和互动的重要工具。好的企业网站策划不仅能够展示企业形象和产品信息&#xff0c;还能够为用户提供更好的体验&#xff0c;…

双指针数组问题

删除有序数组中的重复项 重点在于p1 class Solution {public int removeDuplicates(int[] nums) {if(nums.length0) return 0;int p10,p21;while(p2<nums.length){if(nums[p2]!nums[p1]){nums[p1]nums[p2];}else p2;}return p11;} } class Solution {public void moveZeroe…

数据结构(C):二叉树前中后序和层序详解及代码实现及深度刨析

目录 &#x1f31e;0.前言 &#x1f688;1.二叉树链式结构的代码是实现 &#x1f688;2.二叉树的遍历及代码实现和深度刨析代码 &#x1f69d;2.1前序遍历 ✈️2.1.1前序遍历的理解 ✈️2.1.2前序代码的实现 ✈️2.1.3前序代码的深度解剖 &#x1f69d;2.2中序遍历 ✈…

简单记录玩4399游戏flash插件问题

一、因谷歌浏览器默认禁止flash插件自动运行,所以玩家在使用谷歌浏览器,访问www.4399.com平台页面或者4399小游戏(flash资源)时,可能会出现加载异常的情况。今天教大家如何开启flash插件 二、下载falsh官方插件 地址:Flash Player官方下载中心-Flash中国官网 三、如果您…

【Modelground】个人AI产品MVP迭代平台(5)——神投手(实时投篮检测游戏)

文章目录 介绍篮框识别进球算法离屏渲染总结 介绍 神投手是我开发的一款移动端web实时投篮检测游戏&#xff0c;基于Mediapipe对象检测模型&#xff0c;提供数据集&#xff0c;训练出可识别篮框的模型。利用图像处理算法&#xff0c;检测篮球进框的场景。提供了两种模式&#…