LLM微调(一)| 单GPU使用QLoRA微调Llama 2.0实战

news2024/9/26 3:19:42

      最近LLaMA 2在LLaMA1 的基础上做了很多优化,比如上下文从2048扩展到4096,使用了Grouped-Query Attention(GQA)共享多头注意力的key 和value矩阵,具体可以参考:

关于LLaMA 2 的细节,可以参考如下文章:

Meta发布升级大模型LLaMA 2:开源可商用

揭秘最领先的Llama2中文大模型!

使用QLoRA微调LLaMA 2

安装环境

pip install transformers datasets peft accelerate bitsandbytes safetensors

导入库

import os, sysimport torchimport datasetsfrom transformers import (    AutoTokenizer,    AutoModelForCausalLM,    BitsAndBytesConfig,    DataCollatorForLanguageModeling,    DataCollatorForSeq2Seq,    Trainer,    TrainingArguments,    GenerationConfig)from peft import PeftModel, LoraConfig, prepare_model_for_kbit_training, get_peft_model

导入LLaMA 2模型

### config ###model_id = "NousResearch/Llama-2-7b-hf" # optional meta-llama/Llama-2–7b-chat-hfmax_length = 512device_map = "auto"batch_size = 128micro_batch_size = 32gradient_accumulation_steps = batch_size // micro_batch_size# nf4" use a symmetric quantization scheme with 4 bits precisionbnb_config = BitsAndBytesConfig(    load_in_4bit=True,    bnb_4bit_use_double_quant=True,    bnb_4bit_quant_type="nf4",    bnb_4bit_compute_dtype=torch.bfloat16)# load model from huggingfacemodel = AutoModelForCausalLM.from_pretrained(    model_id,    quantization_config=bnb_config,    use_cache=False,    device_map=device_map)# load tokenizer from huggingfacetokenizer = AutoTokenizer.from_pretrained(model_id)tokenizer.pad_token = tokenizer.eos_tokentokenizer.padding_side = "right"

输出模型的可训练参数量

def print_number_of_trainable_model_parameters(model):    trainable_model_params = 0    all_model_params = 0    for _, param in model.named_parameters():        all_model_params += param.numel()        if param.requires_grad:            trainable_model_params += param.numel()    print(f"trainable model parameters: {trainable_model_params}. All model parameters: {all_model_params} ")    return trainable_model_paramsori_p = print_number_of_trainable_model_parameters(model)# 输出# trainable model parameter: 262,410,240

配置LoRA参数

# LoRA configmodel = prepare_model_for_kbit_training(model)peft_config = LoraConfig(    r=8,    lora_alpha=32,    lora_dropout=0.1,    target_modules=["q_proj", "v_proj"],    bias="none",    task_type="CAUSAL_LM",)model = get_peft_model(model, peft_config)### compare trainable parameters #peft_p = print_number_of_trainable_model_parameters(model)print(f"# Trainable Parameter \nBefore: {ori_p} \nAfter: {peft_p} \nPercentage: {round(peft_p / ori_p * 100, 2)}")# 输出# trainable model parameter: 4,194,304

r:更新矩阵的秩,也称为Lora注意力维度。较低的秩导致具有较少可训练参数的较小更新矩阵。增加r(不超过32)将导致更健壮的模型,但同时会导致更高的内存消耗。

lora_lpha:控制lora比例因子

target_modules:是一个模块名称列表,如“q_proj”和“v_proj“,用作LoRA模型的目标。具体的模块名称可能因基础模型而异。

bias:指定是否应训练bias参数。可选参数为:“none”、“all”或“lora_only”。

输出LoRA Adapter的参数,发现只占原模型的不到2%。

在微调LLaMA 2之前,我们看一下LLaMA 2的生成效果

### generate ###prompt = "Write me a poem about Singapore."inputs = tokenizer(prompt, return_tensors="pt")generate_ids = model.generate(inputs.input_ids, max_length=64)print('\nAnswer: ', tokenizer.decode(generate_ids[0]))res = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]print(res)

       当要求模型写一首关于新加坡的诗时,产生的输出似乎相当模糊和重复,这表明模型很难提供连贯和有意义的回应。

微调数据加载

为了方便演示,我们使用开源的databricks/databricks-dolly-15k,数据格式如下:

{    'instruction': 'Why can camels survive for long without water?',    'context': '',    'response': 'Camels use the fat in their humps to keep them filled with energy and hydration for long periods of time.',    'category': 'open_qa',}

       要揭秘LLM能力,构建Prompt是至关重要,通常的Prompt形式有三个字段:Instruction、Input(optional)、Response。由于Input是可选的,因为这里设置了两种prompt_template,分别是有Input 的prompt_input和无Input 的prompt_no_input,代码如下:

max_length = 256dataset = datasets.load_dataset(    "databricks/databricks-dolly-15k", split='train')### generate prompt based on template ###prompt_template = {    "prompt_input": \    "Below is an instruction that describes a task, paired with an input that provides further context.\    Write a response that appropriately completes the request.\    \n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",    "prompt_no_input": \    "Below is an instruction that describes a task.\    Write a response that appropriately completes the request.\    \n\n### Instruction:\n{instruction}\n\n### Response:\n",    "response_split": "### Response:"}def generate_prompt(instruction, input=None, label=None, prompt_template=prompt_template):    if input:        res = prompt_template["prompt_input"].format(            instruction=instruction, input=input)    else:        res = prompt_template["prompt_no_input"].format(            instruction=instruction)    if label:        res = f"{res}{label}"    return res

      使用generate_prompt函数把instruction, context和response拼接起来;然后进行tokenize分词处理,转换为input_ids和attention_mask,为了让模型可以预测下一个token,设计了类似input_ids的labels便于右移操作;

def tokenize(tokenizer, prompt, max_length=max_length, add_eos_token=False):    result = tokenizer(        prompt,        truncation=True,        max_length=max_length,        padding=False,        return_tensors=None)    result["labels"] = result["input_ids"].copy()    return resultdef generate_and_tokenize_prompt(data_point):    full_prompt = generate_prompt(        data_point["instruction"],        data_point["context"],        data_point["response"],    )    tokenized_full_prompt = tokenize(tokenizer, full_prompt)    user_prompt = generate_prompt(data_point["instruction"], data_point["context"])    tokenized_user_prompt = tokenize(tokenizer, user_prompt)    user_prompt_len = len(tokenized_user_prompt["input_ids"])    mask_token = [-100] * user_prompt_len    tokenized_full_prompt["labels"] = mask_token + tokenized_full_prompt["labels"][user_prompt_len:]    return tokenized_full_promptdataset = dataset.train_test_split(test_size=1000, shuffle=True, seed=42)cols = ["instruction", "context", "response", "category"]train_data = dataset["train"].shuffle().map(generate_and_tokenize_prompt, remove_columns=cols)val_data = dataset["test"].shuffle().map(generate_and_tokenize_prompt, remove_columns=cols,)

模型训练

args = TrainingArguments(    output_dir="./llama-7b-int4-dolly",    num_train_epochs=20,    max_steps=200,    fp16=True,    optim="paged_adamw_32bit",    learning_rate=2e-4,    lr_scheduler_type="constant",    per_device_train_batch_size=micro_batch_size,    gradient_accumulation_steps=gradient_accumulation_steps,    gradient_checkpointing=True,    group_by_length=False,    logging_steps=10,    save_strategy="epoch",    save_total_limit=3,    disable_tqdm=False,)trainer = Trainer(    model=model,    train_dataset=train_data,    eval_dataset=val_data,    args=args,    data_collator=DataCollatorForSeq2Seq(      tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True),)# silence the warnings. re-enable for inference!model.config.use_cache = Falsetrainer.train()model.save_pretrained("llama-7b-int4-dolly")

模型测试

       模型训练几个小时结束后,我们合并预训练模型Llama-2–7b-hf和LoRA参数,我们还是以“Write me a poem about Singapore”测试效果,代码如下:

# model path and weightmodel_id = "NousResearch/Llama-2-7b-hf"peft_path = "./llama-7b-int4-dolly"# loading modelmodel = AutoModelForCausalLM.from_pretrained(    model_id,    quantization_config=bnb_config,    use_cache=False,    device_map="auto")# loading peft weightmodel = PeftModel.from_pretrained(    model,    peft_path,    torch_dtype=torch.float16,)model.eval()# generation configgeneration_config = GenerationConfig(    temperature=0.1,    top_p=0.75,    top_k=40,    num_beams=4, # beam search)# generating replywith torch.no_grad():    prompt = "Write me a poem about Singapore."    inputs = tokenizer(prompt, return_tensors="pt")    generation_output = model.generate(        input_ids=inputs.input_ids,        generation_config=generation_config,        return_dict_in_generate=True,        output_scores=True,        max_new_tokens=64,    )    print('\nAnswer: ', tokenizer.decode(generation_output.sequences[0]))

生成模型中参数temperaturetop-ktop-pnum_beam含义可以参考:https://github.com/ArronAI007/Awesome-AGI/blob/main/LLM%E4%B9%8BGenerate%E4%B8%AD%E5%8F%82%E6%95%B0%E8%A7%A3%E8%AF%BB.ipynb

参考文献:

[1] https://ai.plainenglish.io/fine-tuning-llama2-0-with-qloras-single-gpu-magic-1b6a6679d436

[2] https://github.com/ChanCheeKean/DataScience/blob/main/13%20-%20NLP/E04%20-%20Parameter%20Efficient%20Fine%20Tuning%20(PEFT).ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1026564.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DHorse v1.4.0 发布,基于 k8s 的发布平台

版本说明 新增特性 提供Fabric8客户端操作k8s(预览)的功能,可以通过指定-Dkubernetes-clientfabric8参数开启;Vue、React应用增加Pnpm、Yarn的构建方式;支持Go、Flask、Django、Nuxt应用部署; 优化特性 …

2000-2021年上市公司设立党组织数据

2000-2021年上市公司设立党组织数据 1、时间:2000-2021年 2、指标:时间、证券代码、公司名称、行业名称、所在城市、所在省份、是否建立党组织、建立党组织时间、党组织年龄; 董监高中党组织成员人数、高管中党组织成员人数、董事中党组织…

Java笔记二

学习资源来自哔哩哔哩——遇见狂神说——狂神说Java 目录 数据类型: 字符 字符串 布尔值 强制转换: 变量 常量: 运算符: 数据类型: long定义的必须在数字后面L float定义的要在数字后面加F 如 long num130…

计算机视觉与深度学习-图像分割-视觉识别任务02-目标检测-【北邮鲁鹏】

目录标题 参考目标检测定义深度学习对目标检测的作用单目标检测多任务框架多任务损失预训练模型姿态估计 多目标检测问题滑动窗口(Sliding Window)滑动窗口缺点 AdaBoost(Adaptive Boosting)参考 区域建议 selective search 思想慢…

ISP代理是什么?双ISP是什么意思?

代理是路由互联网流量的中间服务器,通常分为两类:数据中心、住宅ISP。根据定义,ISP 代理隶属于互联网服务提供商,但实际上,大家会将它们视为数据中心和住宅代理的组合。 让我们仔细研究一下ISP代理,看看它们…

verilog学习笔记(1)module实例化

兜兜转转又回来学硬件了,哎,命啊! 我的答案(有bug): module top_module ( input a, input b, output out );wire w1;wire w2;wire w3;mod_a mod_a_inst1(.in1(w1),.in2(w2),.out(w3) );assign w1 a…

【Vue】轻松理解数据代理

hello&#xff0c;我是小索奇&#xff0c;精心制作的Vue教程持续更新哈&#xff0c;想要学习&巩固&避坑就一起学习叭~ Object定义配置方法 代码 引出数据代理&#xff0c;先上代码&#xff0c;后加解释 <!DOCTYPE html> <html><head><meta cha…

助力工业智能化升级 复合移动机器人生态圈在沪启动

9月19日&#xff0c;由移动机器人&#xff08;AGV/AMR&#xff09;产业联盟组织&#xff0c;深圳优艾智合机器人科技有限公司&#xff08;以下简称“优艾智合”&#xff09;牵头&#xff0c;工业机器人产业上下游30家代表企业共同组成的复合移动机器人生态圈在上海国家会展中心…

SpringBoot 统一登录鉴权、异常处理、数据格式

本篇将要学习 Spring Boot 统一功能处理模块&#xff0c;这也是 AOP 的实战环节 用户登录权限的校验实现接口 HandlerInterceptor WebMvcConfigurer 异常处理使用注解 RestControllerAdvice ExceptionHandler 数据格式返回使用注解 ControllerAdvice 并且实现接口 Response…

Android 9 底部导航栏样式不正确

1.项目预制了GMS后&#xff0c;底部导航栏只剩下一个返回键和唤醒Assistant的按钮&#xff0c;需要回到原来的导航栏来 修改方式屏蔽掉 config_defaultAssistantAccessPackage&#xff0c;使用Android原始的config_defaultAssistantAccessPackage vendor/partner_gms/product…

超硬核的Move Dev Meetup上海线下交流会圆满结束

北京时间9月16日下午2–6点&#xff0c;由MoveFuns DAO联合其他组织举办的Move开发者线下交流会在上海悦达国际大厦圆满完成。此次活动也是上海区块链周的周边活动&#xff0c;受到了Web3从业者的广泛关注。 本场交流会邀请了OpenBuild技术社区主理人Ian主持&#xff0c;50余位…

百度测开面试题分享

1、java常用的异常处理机制&#xff1f; Java常用的异常处理机制有以下几种&#xff1a; 1&#xff09;try-catch-finally 语句&#xff1a;用于捕获和处理异常。将可能抛出异常的代码放在try块中&#xff0c;然后在catch块中处理异常。无论是否发生异常&#xff0c;finally块…

MySQL什么情况下会死锁,发生了死锁怎么处理呢?

&#x1f3c6;作者简介&#xff0c;黑夜开发者&#xff0c;CSDN领军人物&#xff0c;全栈领域优质创作者✌&#xff0c;CSDN博客专家&#xff0c;阿里云社区专家博主&#xff0c;2023年6月CSDN上海赛道top4。 &#x1f3c6;数年电商行业从业经验&#xff0c;历任核心研发工程师…

Zabbix5.0_介绍_组成架构_以及和prometheus的对比_大数据环境下的监控_网络_软件_设备监控_Zabbix工作笔记

z 这里Zabbix可以实现采集 存储 展示 报警 但是 zabbix自带的,展示 和报警 没那么好看,我们可以用 grafana进行展示,然后我们用一个叫睿象云的来做告警展示, 会更丰富一点. 可以看到 看一下zabbix的介绍. 对zabbix的介绍,这个zabbix比较适合对服务器进行监控 这个是zabbix的…

华为云云耀云服务器L实例评测|轻量级应用服务器对决:基于 STREAM 深度测评华为云云耀云服务器L实例的内存性能

本文收录在专栏&#xff1a;#云计算入门与实践 - 华为云 专栏中&#xff0c;本系列博文还在更新中 相关华为云云耀云服务器L实例评测文章列表如下&#xff1a; 华为云云耀云服务器L实例评测 | 从零开始&#xff1a;云耀云服务器L实例的全面使用解析指南华为云云耀云服务器L实…

C++虚函数表

一、虚函数和纯虚函数 1.1 虚函数 在类成员方法的声明 (不是定义) 语句前加 “virtual”&#xff0c;如 virtual void func() class ISpeaker { public:virtual void func(); }; 1.2 纯虚函数 在虚函数后加 “0”&#xff0c;如 virtual void func()0 class ISpeaker { pu…

【JAVA-Day29】 多维数组和一维数组的区别:简明对照

多维数组和一维数组的区别&#xff1a;简明对照 多维数组和一维数组的区别&#xff1a;简明对照摘要&#xff08;博主语气&#xff09;&#xff1a;多维数组和一维数组是编程中常用的数据结构&#xff0c;它们在定义和使用上有很大的不同。本文将详细介绍它们的区别&#xff0c…

优麒麟下载、安装、体验

下载 官网 优麒麟 点击增强版、或者基础版进行下载 虚拟机安装 选择镜像 修改名称和存储路径 设置为50G 下一步&#xff0c;点击完成 开启安装 设置语言 去掉下载更新选项 继续 点击restart now 输入密码 出现下图说明安装成功&#xff0c;可以畅快的使用了

React中插槽处理机制

React中插槽处理机制 需求&#xff1a;假如底部可能有按钮&#xff0c;根据需求判断需要展示或不展示&#xff0c;或者需要展示不同的按钮或者其他DOM 解决1&#xff1a;需要的按钮可以在组件中写死&#xff0c;后期基于传递进来的属性来进行判断 解决2&#xff1a;我们也可以…

Nacos安装指南(Windows环境)

Windows安装 开发阶段采用单机安装即可。 1.下载安装包 在Nacos的GitHub页面&#xff0c;提供有下载链接&#xff0c;可以下载编译好的Nacos服务端或者源代码&#xff1a; GitHub主页&#xff1a;https://github.com/alibaba/nacos GitHub的Release下载页&#xff1a;https:…