微调llama2模型教程:创建自己的Python代码生成器

news2025/1/15 20:41:24

本文将演示如何使用PEFT、QLoRa和Huggingface对新的lama-2进行微调,生成自己的代码生成器。所以本文将重点展示如何定制自己的llama2,进行快速训练,以完成特定任务。

一些知识点

llama2相比于前一代,令牌数量增加了40%,达到2T,上下文长度增加了一倍,并应用分组查询注意(GQA)技术来加速在较重的70B模型上的推理。在标准的transformer 体系结构上,使用RMSNorm归一化、SwiGLU激活和旋转位置嵌入,上下文长度达到了4096个,并应用了具有余弦学习率调度、权重衰减0.1和梯度裁剪的Adam优化器。

有监督微调(SFT)阶段的特点是优先考虑质量样本而不是数量,因为许多报告表明,使用高质量数据可以提高最终模型的性能。

最后,通过带有人类反馈的强化学习(RLHF)步骤使模型与用户偏好保持一致。收集了大量示例,其中人类在比较中选择他们首选的模型输出。这些数据被用来训练奖励模型。

最主要的一点是,LLaMA 2-CHAT已经和OpenAI ChatGPT一样好了,所以我们可以使用它作为我们本地的一个替代了

数据集

对于的微调过程,我们将使用大约18,000个示例的数据集,其中要求模型构建解决给定任务的Python代码。这是原始数据集[2]的提取,其中只选择了Python语言示例。每行包含要解决的任务的描述,如果适用的话,任务的数据输入示例,并提供解决任务的生成代码片段[3]。

 # Load dataset from the hub
 dataset = load_dataset(dataset_name, split=dataset_split)
 # Show dataset size
 print(f"dataset size: {len(dataset)}")
 # Show an example
 print(dataset[randrange(len(dataset))])

创建提示

为了执行指令微调,我们必须将每个数据示例转换为指令,并将其主要部分概述如下:

 def format_instruction(sample):
  return f"""### Instruction:
 Use the Task below and the Input given to write the Response, which is a programming code that can solve the following Task:
 
 ### Task:
 {sample['instruction']}
 
 ### Input:
 {sample['input']}
 
 ### Response:
 {sample['output']}
 """

输出的结果是这样的:

 ### Instruction:
 Use the Task below and the Input given to write the Response, which is a programming code that can solve the following Task:
 
 ### Task:
 Develop a Python program that prints "Hello, World!" whenever it is run.
 
 ### Input:
 
 
 ### Response:
 #Python program to print "Hello World!"
 
 print("Hello, World!")

微调模型

为了方便演示,我们使用Google Colab环境,对于第一次测试运行,T4实例就足够了,但是当涉及到运行整个数据集训练,则需要使用A100。

除此以外,还可以登录Huggingface hub ,这样可以上传和共享模型,当然这个是可选项。

 from huggingface_hub import login
 from dotenv import load_dotenv
 import os
 
 # Load the enviroment variables
 load_dotenv()
 # Login to the Hugging Face Hub
 login(token=os.getenv("HF_HUB_TOKEN"))

PEFT、Lora和QLora

训练LLM的通常步骤包括:首先,对数十亿或数万亿个令牌进行预训练得到基础模型,然后对该模型进行微调,使其专门用于下游任务。

参数高效微调(PEFT)允许我们通过微调少量额外参数来大大减少RAM和存储需求,因为所有模型参数都保持冻结状态。并且PEFT还增强了模型的可重用性和可移植性,它很容易将小的检查点添加到基本模型中,通过添加PEFT参数让基础模型在多个场景中重用。最后由于没有调整基本模型,还可以保留在预训练阶段获得的所有知识,从而避免了灾难性遗忘。

PEFT保持预训练的基本模型不变,并在其上添加新的层或参数。这些层被称为“适配器”,我们将这些层添加到预训练的基本模型中,只训练这些新层的参数。但是这种方法的一个严重问题是,这些层会导致推理阶段的延迟增加,从而使流程在许多情况下效率低下。

而在LoRa技术(大型语言模型的低秩适应)中不是添加新的层,而是以一种避免在推理阶段出现这种可怕的延迟问题的方式向模型各层参数添加值。LoRa训练并存储附加权重的变化,同时冻结预训练模型的所有权重。也就是说我们利用预训练模型矩阵的变化训练一个新的权重矩阵,并将这个新矩阵分解为2个低秩矩阵,如下所示:

LoRA[1]的作者提出权值变化矩阵∆W的变化可以分解为两个低秩矩阵A和b。LoRA不直接训练∆W中的参数,而是直接训练A和b中的参数,因此可训练参数的数量要少得多。假设A的维数为100 * 1,B的维数为1 * 100,则∆W中的参数个数为100 * 100 = 10000。在A和B中训练的人数只有100 + 100 = 200,而在∆W中训练的个数是10000

这些低秩矩阵的大小由r参数定义。这个值越小,需要训练的参数就越少,速度更快。但是参数过少可能会损失信息和性能,所以r参数的选择也是需要考虑的问题。

最后,QLoRa[6]则是将量化应用于LoRa方法,通过优化内存使用的技巧,以实现“更轻量”和更便宜的训练。

微调流程

我们的示例中使用QLoRa,所以要指定BitsAndBytes配置,下载4位量化的预训练模型,定义LoraConfig。

 # Get the type
 compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
 
 # BitsAndBytesConfig int-4 config
 bnb_config = BitsAndBytesConfig(
     load_in_4bit=use_4bit,
     bnb_4bit_use_double_quant=use_double_nested_quant,
     bnb_4bit_quant_type=bnb_4bit_quant_type,
     bnb_4bit_compute_dtype=compute_dtype
 )
 # Load model and tokenizer
 model = AutoModelForCausalLM.from_pretrained(model_id, 
   quantization_config=bnb_config, use_cache = False, device_map=device_map)
 model.config.pretraining_tp = 1
 # Load the tokenizer
 tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
 tokenizer.pad_token = tokenizer.eos_token
 tokenizer.padding_side = "right"

下面是参数定义,

 # Activate 4-bit precision base model loading
 use_4bit = True
 # Compute dtype for 4-bit base models
 bnb_4bit_compute_dtype = "float16"
 # Quantization type (fp4 or nf4)
 bnb_4bit_quant_type = "nf4"
 # Activate nested quantization for 4-bit base models (double quantization)
 use_double_nested_quant = False
 # LoRA attention dimension
 lora_r = 64
 # Alpha parameter for LoRA scaling
 lora_alpha = 16
 # Dropout probability for LoRA layers
 lora_dropout = 0.1

接下来的步骤对于所有的Hugging Face用户来说应该都很熟悉了,设置训练参数,创建Trainer。在执行指令微调时,我们调用封装PEFT模型定义和其他步骤的SFTTrainer方法。

 # Define the training arguments
 args = TrainingArguments(
     output_dir=output_dir,
     num_train_epochs=num_train_epochs,
     per_device_train_batch_size=per_device_train_batch_size, # 6 if use_flash_attention else 4,
     gradient_accumulation_steps=gradient_accumulation_steps,
     gradient_checkpointing=gradient_checkpointing,
     optim=optim,
     logging_steps=logging_steps,
     save_strategy="epoch",
     learning_rate=learning_rate,
     weight_decay=weight_decay,
     fp16=fp16,
     bf16=bf16,
     max_grad_norm=max_grad_norm,
     warmup_ratio=warmup_ratio,
     group_by_length=group_by_length,
     lr_scheduler_type=lr_scheduler_type,
     disable_tqdm=disable_tqdm,
     report_to="tensorboard",
     seed=42
 )
 # Create the trainer
 trainer = SFTTrainer(
     model=model,
     train_dataset=dataset,
     peft_config=peft_config,
     max_seq_length=max_seq_length,
     tokenizer=tokenizer,
     packing=packing,
     formatting_func=format_instruction,
     args=args,
 )
 # train the model
 trainer.train() # there will not be a progress bar since tqdm is disabled
 
 # save model in local
 trainer.save_model()

这些参数大多数通常用于llm上的其他微调脚本,我们就不做过多的说明了:

 # Number of training epochs
 num_train_epochs = 1
 # Enable fp16/bf16 training (set bf16 to True with an A100)
 fp16 = False
 bf16 = True
 # Batch size per GPU for training
 per_device_train_batch_size = 4
 # Number of update steps to accumulate the gradients for
 gradient_accumulation_steps = 1
 # Enable gradient checkpointing
 gradient_checkpointing = True
 # Maximum gradient normal (gradient clipping)
 max_grad_norm = 0.3
 # Initial learning rate (AdamW optimizer)
 learning_rate = 2e-4
 # Weight decay to apply to all layers except bias/LayerNorm weights
 weight_decay = 0.001
 # Optimizer to use
 optim = "paged_adamw_32bit"
 # Learning rate schedule
 lr_scheduler_type = "cosine" #"constant"
 # Ratio of steps for a linear warmup (from 0 to learning rate)
 warmup_ratio = 0.03
 # Group sequences into batches with same length
 # Saves memory and speeds up training considerably
 group_by_length = False
 # Save checkpoint every X updates steps
 save_steps = 0
 # Log every X updates steps
 logging_steps = 25
 # Disable tqdm
 disable_tqdm= True

合并权重

正如上面我们提到的方法,LoRa在基本模型上训练了“修改权重”,所以最终模型需要将预训练的模型和适配器权重合并到一个模型中。

 from peft import AutoPeftModelForCausalLM
 
 model = AutoPeftModelForCausalLM.from_pretrained(
     args.output_dir,
     low_cpu_mem_usage=True,
     return_dict=True,
     torch_dtype=torch.float16,
     device_map=device_map,    
 )
 
 # Merge LoRA and base model
 merged_model = model.merge_and_unload()
 
 # Save the merged model
 merged_model.save_pretrained("merged_model",safe_serialization=True)
 tokenizer.save_pretrained("merged_model")
 # push merged model to the hub
 merged_model.push_to_hub(hf_model_repo)
 tokenizer.push_to_hub(hf_model_repo)

推理

最后就是推理的过程了

 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # Get the tokenizer
 tokenizer = AutoTokenizer.from_pretrained(hf_model_repo)
 # Load the model
 model = AutoModelForCausalLM.from_pretrained(hf_model_repo, load_in_4bit=True, 
                                              torch_dtype=torch.float16,
                                              device_map=device_map)
 # Create an instruction
 instruction="Optimize a code snippet written in Python. The code snippet should create a list of numbers from 0 to 10 that are divisible by 2."
 input=""
 
 prompt = f"""### Instruction:
 Use the Task below and the Input given to write the Response, which is a programming code that can solve the Task.
 
 ### Task:
 {instruction}
 
 ### Input:
 {input}
 
 ### Response:
 """
 # Tokenize the input
 input_ids = tokenizer(prompt, return_tensors="pt", truncation=True).input_ids.cuda()
 # Run the model to infere an output
 outputs = model.generate(input_ids=input_ids, max_new_tokens=100, do_sample=True, top_p=0.9,temperature=0.5)
 
 # Print the result
 print(f"Prompt:\n{prompt}\n")
 print(f"Generated instruction:\n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}")

结果如下:

 Prompt:
 ### Instruction:
 Use the Task below and the Input given to write the Response, which is a programming code that can solve the Task.
 
 ### Task:
 Optimize a code snippet written in Python. The code snippet should create a list of numbers from 0 to 10 that are divisible by 2.
 
 ### Input:
 arr = []
 for i in range(10):
  if i % 2 == 0:
  arr.append(i)
 
 ### Response:
 
 
 Generated instruction:
 arr = [i for i in range(10) if i % 2 == 0]
 
 Ground truth:
 arr = [i for i in range(11) if i % 2 == 0]

看样子还是很不错的

总结

以上就是我们微调llama2的完整过程,这里面的一个最重要的步骤其实是提示的生成,一个好的提示对于模型的性能也是非常有帮助的。

[1] Llama-2 paper https://arxiv.org/pdf/2307.09288.pdf

[2] python code dataset http://sahil2801/code_instructions_120k

[3] 本文使用的数据集 https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca

[4] LoRA: Low-Rank Adaptation of Large Language Models. arXiv:2106.09685

[5]. QLoRa: Efficient Finetuning of QuantizedLLMs arXiv:2305.14314

https://avoid.overfit.cn/post/9794c9eef1df4e55adf514b3d727ee3b

作者:Eduardo Muñoz

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/926898.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CentOS7.9安装Java11

文章目录 Java11版本介绍安装步骤查看并卸载已有版本安装Java11最新版本配置生效 openjdk介绍 Java11版本介绍 Java 11是Java编程语言的一个重要版本,于2018年9月发布Java 11在语言特性、性能优化和安全性方面都有一些显著的改进,为Java开发者提供了更多…

.NetCore之log4net的使用

1.首先下载log4ne的包&#xff1a; 2.添加配置文件log4net.config <?xml version"1.0" encoding"utf-8" ?> <configuration><!-- This section contains the log4net configuration settings --><log4net><appender name&q…

初探函数式编程---以Map/Reduce/Filter为例

如函数式编程--酷壳[1] 总结&#xff0c; 函数式编程的三大特性; 数据不可变性 函数作为一等公民(函数可以像变量一样来创建/修改/传递 等) 尾递归优化(重用stack,减轻栈的压力) 函数式编程用到的几个技术&#xff1a; 函数式编程的理念&#xff1a;把函数当成变量来用&#xf…

亚马逊电咖啡壶UL1082测试标准

UL1082标准是适用于额定电压为120V,按照国家电气编码进行使用的便携式电咖啡壶&#xff0c;咖啡渗漏壶及其它酿造类器具&#xff0c;除了咖啡壶外&#xff0c;本标准也适用于荼壶、水煲、玻璃水煲、汤保温壶及其它类似器具。这些器具都具有以下特点&#xff1a; &#xff08;1…

Day06-Vue全家桶项目

Day01-Vue全家桶项目 一 全家桶项目介绍 Vue在使用脚手架创建项目的时候,提供前端工程化项目 目前主要学习了Vue基础:Vue指令、Vue组件开发、Vue样式、组件通信、生命周期 全家桶项目是很多技术结合的一种开发模式: 全家桶项目搭建路由搭建(前端路由)网络请求封装Elem…

【Java从0到1学习】11 Java集合框架

1. Collection 1.1 Java类中集合的关系图 1.2 集合类概述 在程序中可以通过数组来保存多个对象&#xff0c;但在某些情况下开发人员无法预先确定需要保存对象的个数&#xff0c;此时数组将不再适用&#xff0c;因为数组的长度不可变。例如&#xff0c;要保存一个学校的学生信…

华为数通方向HCIP-DataCom H12-821题库(单选题:81-100)

第81题 某公司新购入一台网络设备,作为网络管理员,初次配置该设备通常通过什么方式? A、FTP B、Telnet C、SNMP D、Console 口登录 答案: D 解析&#xff1a; 通常情况下&#xff0c;初次配置网络设备会通过Console口登录的方式进行。Console口是一种串口接口&#xff0c…

网络安全工程师岗位一览-徐庆臣(黑客洗白者)

安全服务工程师 安全运维工程师 渗透测试工程师 Web安全工程师 安全攻防工程师 等保测评工程师 …… 代码审计工程师 威胁分析工程师 无线安全工程师 安全研发工程师 移动安全工程师 云计算安全工程师 ……

Socket通信与WebSocket协议

文章目录 目录 文章目录 前言 一、Socket通信 1.1 BIO 1.2 NIO 1.3 AIO 二、WebSocket协议 总结 前言 一、Socket通信 Socket是一种用于网络通信的编程接口&#xff08;API&#xff09;&#xff0c;它提供了一种机制&#xff0c;使不同主机之间可以通过网络进行数据传输和通信…

算法通关村十三关 | 进制转换问题处理模板

1. 七进制数 题目&#xff1a;LeetCode504&#xff1a;504. 七进制数 - 力扣&#xff08;LeetCode&#xff09; 思路 进制转换&#xff0c;对几转换就是对几求余&#xff0c;最后将所有的余数反过来即可、如果num< 0&#xff0c;先取绝对值&#xff0c;再进行操作。 100转7…

Ceres Solver 入门

1. Ceres Solver 是什么 Ceres 可以解决以下形式的边界约束鲁棒化非线性最小二乘问题&#xff1a; 给定初始值&#xff0c;通过优化算法&#xff0c;得到最优解。 其中&#xff0c; f i f_i fi​是CostFunction&#xff0c;也叫误差函数&#xff0c;或者代价函数。 ρ i \rho…

第十七课:利用 Setup Factory 制作 Qt 软件安装包

功能描述&#xff1a;详细介绍如何利用 Setup Factory 制作 Qt 软件安装包&#xff0c;从 Setup Factory 软件下载、安装&#xff0c;到如何利用 Setup Factory 制作软件安装包&#xff0c;手把手教你将 Qt 应用程序制作成具有安装向导的安装包。 一、Setup Factory 简介 Setu…

C语言程序结构、基本语法与数据类型

文章目录 1. 程序结构1.1 Hello World示例1.2 编译并执行C程序 2. 基本语法2.1 C 标记2.2 分号2.3 注释2.4 标识符2.5 关键字2.6 C中的空格 3. 数据类型3.1 整数类型3.2 浮点类型3.3 void类型 1. 程序结构 1.1 Hello World示例 #include <stdio.h>int main() {/* my fi…

DirectExchange直连交换机

目录 一、简介 二、使用步骤 三、demo 父pom文件 pom文件 配置文件 config 消费者 生产者 测试 一、简介 直连型交换机&#xff0c;根据消息携带的路由键将消息投递给对应队列。 大致流程&#xff0c;有一个队列绑定到一个直连交换机上&#xff0c;同时赋予一个路由…

AMEYA360代理品牌:纳芯微芯片解决方案为光伏市场赋能

近年来&#xff0c;光伏市场进入了一个新的增长维度。SolarPower Europe数据显示&#xff0c;2022年全球光伏新增装机量达239GW&#xff0c;占所有可再生能源新增容量的三分之二。国家能源局也宣称&#xff0c;2022年我国工商业光伏新增装机达25.87GW&#xff0c;同比增长236.7…

淘宝商品数据采集(如何快速获取淘宝商品信息),淘宝API接口申请指南

淘宝作为国内的电商平台&#xff0c;拥有海量的商品信息。对于想要进行淘宝商品数据采集的人来说&#xff0c;如何快速获取淘宝商品信息是一个重要的问题。本文将介绍一些快速获取淘宝商品信息的方法。 1. 使用淘宝开放平台PI 淘宝开放平台提供了多种PI接口&#xff0c;可以通…

如何选择合适的开源许可证?

&#x1f337;&#x1f341; 博主猫头虎 带您 Go to New World.✨&#x1f341; &#x1f984; 博客首页——猫头虎的博客&#x1f390; &#x1f433;《面试题大全专栏》 文章图文并茂&#x1f995;生动形象&#x1f996;简单易学&#xff01;欢迎大家来踩踩~&#x1f33a; &a…

java八股文面试[java基础]——异常

自定义异常&#xff1a; 异常Exception 是指程序运行时&#xff0c; 由于输入错误、网络、程序逻辑等原因导致运行时出现的问题。出现异常时&#xff0c;程序会暂时中断执行&#xff0c;并根据产生异常的原因&#xff0c;创建对应异常类型的异常对象&#xff0c;并抛出给JVM捕…

高速收费站的智慧之选,工控机助力顺畅通行!

2020年初取消高速公路省界收费站后&#xff0c;全国高速公路进入“一张网运行、一体化服务”的新阶段。随着ETC用户量快速增长、驾乘人员对收费站高效通行需求不断提升&#xff0c;收费数据在线化运营及精准化、智能化、人性化的收费服务将成为主流。如何提高收费系统集成度、降…

Day4:前端路由(进阶篇)

目标: 持续输出&#xff01;每日分享关于web前端常见知识、面试题、性能优化、新技术等方面的内容。 主要面向群体&#xff1a;前端开发工程师&#xff08;初、中、高级&#xff09;、应届、转行、培训等同学 Day4-今日话题 今天分享的是前端路由的进阶篇&#xff0c;将从路由的…