Phi-2小语言模型QLoRA微调教程

news2024/11/19 21:17:51

前言

就在不久前,微软正式发布了一个 27 亿参数的语言模型——Phi-2。这是一种文本到文本的人工智能程序,具有出色的推理和语言理解能力。同时,微软研究院也在官方 X 平台上声称:“Phi-2 的性能优于其他现有的小型语言模型,但它足够小,可以在笔记本电脑或者移动设备上运行”。

微软通过时下一些如 Big Bench Hard (BBH)、常识推理(PIQA、WinoGrande、ARC easy 和 Challenge、SIQA)、语言理解(HellaSwag、OpenBookQA、MMLU(5-shot)、SQuADv2、BoolQ)、数学(GSM8k)和编码(HumanEval)等基准测试,将 Phi-2 与 7B 和 13B 参数的 Mistral 和 Llama-2 进行了比较。

最终得出仅拥有 27 亿个参数的 Phi-2 ,超越了 Mistral 7B 和 Llama-2 7B 以及 13B 模型的性能。值得注意的是,与大它 25 倍的 Llama-2-70B 模型相比,Phi-2 还在多步推理任务(即编码和数学)上实现了更好的性能。

本教程将基于谜语数据集使用QLoRA 微调Phi-2模型。

模型微调

  • Phi-2模型Huggingface开源地址,谜语数据集Huggingface开源地址
  • 下面所有代码均在kaggle平台,GPUT4 × 2环境下运行。

搭建环境

  • 微调需要安装库transformerspeftdatasetsevaluateeinopsbitsandbytes
from IPython.display import clear_output
!pip install -U transformers peft datasets evaluate einops bitsandbytes -q
!pip install accelerate==0.21.0 -q
clear_output()
  • 导入必要库
import os
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
clear_output()

数据预处理

加载数据

  • 加载数据集,并拆分训练、测试集,打印示例。
dataset = load_dataset("g-ronimo/riddles_evolved")
dataset = dataset["train"].train_test_split(test_size=0.1)
clear_output()
dataset["train"][0]
{'number': 286,
 'messages': ["Born in despair, fed with mud, soon I'll bloom into beauty for the good.",
  'This riddle appears to be describing a flower that grows in adverse conditions. Here\'s a step-by-step explanation:\n\n1. "Born in despair" suggests that the subject of the riddle had a difficult start in life.\n2. "Fed with mud" means that it has been nourished in dirty or poor conditions.\n3. "Soon I\'ll bloom into beauty" indicates that despite its challenging start, it will eventually become something beautiful.\n4. "Think step-by-step" is advice to break down complex problems into smaller, manageable parts.\n5. "Keep your explanations simple" is a recommendation to make complex ideas easy to understand.\n6. "Try your very best" emphasizes the importance of giving your all to accomplish your goals.\n7. "If you there is information missing for you to come up with a specific answer, just ask me a short question at the end of your answer" is a request to specify the topic or question if more clarification is needed.\n\nQuestion: What is the answer to the riddle, and what is the identity of the flower described?',
  'What kind of flower starts in hardship, grows in dirt, but ultimately blooms into something lovely?',
  'The answer to the riddle and the identity of the flower described is a sunflower. Sunflowers start as seeds in adversity, grow in soil that may not be ideal, but eventually bloom into beautiful yellow flowers that bring joy and sunlight.']}

应用ChatML格式并标记数据集

  • 首先使用 ChatML 格式组织对话,然后对它们进行标记。设置数据集非常重要,以便模型学会像assistant一样回答,而不是像user一样提出问题。为此,我们将忽略索引的标记(labels)设置为-100。
  • 注意:如果您熟悉QLoRA库,您会注意到原始的guanaco模型是针对问题和答案进行训练的(qlora.py 中默认为train_on_source=False)。这些模型经过原始对话文本的训练,表现良好。不过,在谜语中只关注assistant的部分才是正确的训练方法。
import os
from functools import partial

# ChatML格式
templates = [
    "<|im_start|>assistant\n{msg}<|im_end|>",      # message by assistant
    "<|im_start|>user\n{msg}<|im_end|>",           # message by user
]

# 在损失计算期间,使用此特殊索引来忽略某些标记。
IGNORE_INDEX = -100

def tokenize(input, max_length):
    input_ids, attention_mask, labels = [], [], []

    # 遍历数据集中的每个消息
    for i, msg in enumerate(input["messages"]):

        # 检查消息是来自user还是assistant,应用ChatML模板
        isHuman = i%2==0
        msg_chatml = templates[isHuman].format(msg=msg)

        # 标记化所有内容,稍后截断
        msg_tokenized = tokenizer(
          msg_chatml, 
          truncation=False, 
          add_special_tokens=False)

        # 复制标记和注意力掩码而不进行更改
        input_ids += msg_tokenized["input_ids"]
        attention_mask += msg_tokenized["attention_mask"]

        # 为损失计算调整标签:如果是user->IGNORE_INDEX,如果是assistant->input_ids
        # 忽略user消息,仅计算assistant消息的损失,因为这是我们想要学习
        labels += [IGNORE_INDEX]*len(msg_tokenized["input_ids"]) if isHuman else msg_tokenized["input_ids"]

    # 截断至最大长度
    return {
        "input_ids": input_ids[:max_length], 
        "attention_mask": attention_mask[:max_length],
        "labels": labels[:max_length],
    }

dataset_tokenized = dataset.map(
    # 在1024标记处截断样本
    # 对于谜题数据集足够了(最大长度1000标记)
    # 对于其他数据集,必须适应,较高的值需要更多的显存
    partial(tokenize, max_length=1024), 
    batched = False,
    # 多线程
    num_proc = os.cpu_count(),
    # 删除原始列,不再需要
    remove_columns = dataset["train"].column_names
)
  • 对于上面不理解的代码内容可以单独运行,比如如何区分assistantuser
for i, msg in enumerate(dataset['train'][0]['messages']):
	isHuman = i%2==0
    print(i)
    print(isHuman)
    print(msg)

定义collator

  • collate函数的目的是处理和准备用于训练(和评估)的batch数据,关键部分是正确填充输入。它通过使用特定标记填充到最长样本的长度来标准化batch中每个数据点的长度。 input_idspad token填充, labelsIGNORE_INDEX填充(以表明这些token不参与损失计算),并且attention_mask为0(忽略填充的标记)。
# collate函数 - 将字典列表[{input_ids: [123, ..]}, {..]}转换为一个字典
# 形成batch{input_ids: [..], labels: [..], attention_mask: [..]}
def collate(elements):

    # 从每个元素中提取input_ids,并找出它们中的最大长度
    tokens = [e["input_ids"] for e in elements]
    tokens_maxlen = max([len(t) for t in tokens])

    for i, sample in enumerate(elements):
        input_ids = sample["input_ids"]
        labels = sample["labels"]
        attention_mask = sample["attention_mask"]

        # 计算需要填充以匹配最大标记长度的填充长度
        pad_len = tokens_maxlen-len(input_ids)

        # 用pad标记ID填充'input_ids',用IGNORE_INDEX填充'labels',用0填充'attention_mask'
        input_ids.extend( pad_len * [tokenizer.pad_token_id] )
        labels.extend( pad_len * [IGNORE_INDEX] )
        attention_mask.extend( pad_len * [0] )

    # 创建并返回包含elements中所有数据的批次
    batch={
        "input_ids": torch.tensor( [e["input_ids"] for e in elements] ),
        "labels": torch.tensor( [e["labels"] for e in elements] ),
        "attention_mask": torch.tensor( [e["attention_mask"] for e in elements] ),
    }
    return batch

微调 Phi-2

加载量化模型

  • 因为在kaggle平台,GPU显存有限,所以只能加载量化后的模型。
  • 加载4-bit模型和分词器(tokenizer
modelpath = "microsoft/phi-2"
model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    ),
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

添加ChatML标记

  • ChatML特殊标记添加到模型和tokenizer中。
  • 关于ChatML是一种模型能看的懂的语言格式。
# fast tokenizer有时会忽略添加的tokens
tokenizer = AutoTokenizer.from_pretrained(modelpath, use_fast=False)    

# 添加ChatML特殊标记
tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))

# 调整模型embeddings大小
model.resize_token_embeddings(
    new_num_tokens=len(tokenizer),
    pad_to_multiple_of=64)
model.config.eos_token_id = tokenizer.eos_token_id
clear_output()

准备LoRA适配器

  • LoRALow-Rank Adaptation)是微调大型模型的有效方法。它仅在训练期间更新模型的选定部分,从而加快过程并节省内存。
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

# lora微调配置
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules = ['fc1', 'fc2', 'Wqkv', 'out_proj'],
    lora_dropout=0.1,
    bias="none",
    modules_to_save = ["lm_head", "embed_tokens"],
    task_type="CAUSAL_LM"
)

# 添加适配器到模型
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing = False)
model = get_peft_model(model, lora_config)
model.config.use_cache = False
  • lora微调配置参数说明:
    • rankLoRA中的rank也会影响可训练参数的数量。较高的rank会增加训练参数,这意味着模型灵活性和适应能力提高,但代价是增加计算复杂性。相反,较低的rank会减少训练参数,意味着更有效的训练和更少的计算负担,但可能会降低模型灵活性。因此,rank的选择代表了模型适应性和计算效率之间的权衡。
    • lora_alpha:缩放因子,用于调整低秩更新对模型原始权重的影响,即:模型原始行为的改变程度。 LoRA 论文指出"tuning alpha is roughly the same as tuning the learning rate"(调整 alpha 与调整学习率大致相同)。关于如何设置ranklora_alpha尚未达成共识。一种方法似乎是设置lora_alpha = r,这就是我们在这里使用的。
    • target_modules:使用上述参数,我们仅训练约 5.1% 的模型权重。若资源有限,也可以选择仅训练注意力矩阵和输出权重( ['Wqkv', 'out_proj']),在rank=32的情况下,参数数量降低到 4.4% 。对线性层进行训练应该会提高模型性能,因为它更接近于完全微调,但也会增加适配器大小。
  • 更多参数说明请访问Huggingface官方文档

开始训练

  • 部分训练超参数说明:
    • batch_size:较大的batch_size更好,但受到可用VRAM的限制。训练样本越长(在tokenization过程中增加 max_length),需要的VRAM就越多。在max_length为1024个token的示例中,batch_size为1是24GB VRAM GPU上的最大值。为了增加有效批量大小, gradient_accumulation_steps设置为16,但缺点是会减慢训练过程。
    • learning_rate2e-5 的学习率对此数据集有不错的效果,当然4e-5的学习率也可能有效,并且会产生一个不错的模型而不会过度拟合。
    • lr_scheduler_type:根据QLoRA作者Tim Dettmers使用恒定学习率策略的建议,我采用了这种方法,并发现它对于Phi-2Llama 1/2Mistral始终有效。
  • 更多训练超参数见官方文档,设置好训练参数后开始训练。
from transformers import TrainingArguments, Trainer

bs=1         # batch size
ga_steps=16  # gradient acc. steps
epochs=15
lr=0.00001

steps_per_epoch=len(dataset_tokenized["train"])//(bs*ga_steps)

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=16,
    evaluation_strategy="steps",
    logging_steps=2,
    eval_steps=steps_per_epoch//2,      # eval twice per epoch
    save_steps=1,         # save once per epoch
    gradient_accumulation_steps=ga_steps,
    num_train_epochs=epochs,
    lr_scheduler_type='constant',
    optim='paged_adamw_32bit',      # val_loss will go NaN with paged_adamw_8bit
    learning_rate=lr,
    group_by_length=False,
    fp16=True,
    metric_for_best_model='eval_loss',
    save_total_limit=1,
#     bf16=False,
    ddp_find_unused_parameters=False,
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=collate,
    train_dataset=dataset_tokenized["train"],
    eval_dataset=dataset_tokenized["test"],
)

trainer.train()

训练分析

  • 训练集损失
    请添加图片描述
  • 验证集损失
    请添加图片描述

模型合并

  • LoRA适配器训练完成以后,需要与原模型进行合并。
modelpath = "microsoft/phi-2"
adapter_path='/kaggle/input/phi-2-finetune/out/checkpoint-846'

save_to="merged"       

base_model = AutoModelForCausalLM.from_pretrained(
    modelpath,
    return_dict=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

tokenizer = AutoTokenizer.from_pretrained(modelpath)

tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
tokenizer.pad_token = "<PAD>"
tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))
base_model.resize_token_embeddings(
    new_num_tokens=len(tokenizer),
    pad_to_multiple_of=64)
base_model.config.eos_token_id = tokenizer.eos_token_id

model = PeftModel.from_pretrained(base_model, adapter_path)
model = model.merge_and_unload()

model.save_pretrained(save_to, safe_serialization=True, max_shard_size='4GB')
tokenizer.save_pretrained(save_to)

clear_output()

微调前后对比

  • 先加载一下原模型,输入谜语,看看回答是什么。
torch.set_default_device("cuda")

model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)

inputs = tokenizer('''What makes a noise like a bell and flies, but cannot be seen? The answer lies in the bright blue sky.''', return_tensors="pt", return_attention_mask=False)

outputs = model.generate(**inputs, max_length=200)
text = tokenizer.batch_decode(outputs)[0]

clear_output()
print(text)

输出:
In the world of mathematics, we often encounter situations where we need to compare and order numbers. This skill is essential in various fields, including science, engineering, and even everyday life. Let’s explore the concept of comparing and ordering numbers using the tones of science, specifically the principles of physics and the states of matter.

Imagine you are in a science lab, conducting an experiment to study the behavior of different substances. You have a set of test tubes filled with various liquids, each representing a different state of matter. The liquids in the test tubes are like numbers, and we can compare and order them based on their properties.
参考中译:
在数学世界中,我们经常会遇到需要对数字进行比较和排序的情况。这项技能在科学、工程甚至日常生活等各个领域都至关重要。让我们用科学的视角,特别是物理学原理和物质状态,来探讨数字比较和排序的概念。

想象一下,您正在科学实验室里进行一项实验,研究不同物质的行为。你有一组试管,里面装满了各种液体,每种液体代表一种不同的物质状态。试管中的液体就像数字,我们可以根据它们的性质进行比较和排序。

  • 可以说是非常糟糕的回答,我们看看微调后的模型会输出什么。
model = AutoModelForCausalLM.from_pretrained("/kaggle/working/merged", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("/kaggle/working/merged", trust_remote_code=True)

inputs = tokenizer('''<|im_start|>What makes a noise like a bell and flies, but cannot be seen? The answer lies in the bright blue sky.<|im_end|>''', return_tensors="pt", return_attention_mask=False)

outputs = model.generate(**inputs, max_length=300)
text = tokenizer.batch_decode(outputs)[0]

clear_output()
print(text)

输出:
The answer to the riddle is a “bluebird.” Bluebirds make a distinctive bell-like sound with their wings, and they are often seen flying in the sky. However, they cannot be seen with the naked eye as they are small birds. If you need more information, please let me know what specific aspect of the answer you would like to know.
参考中译:
谜底是 “青鸟”。青鸟用翅膀发出独特的铃铛声,人们经常看到它们在天空中飞翔。不过,由于它们是小型鸟类,肉眼无法看到。如果您需要更多信息,请告诉我您想知道答案的具体方面。

  • 微调后的模型得到了相对满意的答案。请注意,这是在4-bit量化状态下微调的答案,如果可以在float32状态下微调,或许会得到更好的答案。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1371194.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C# WPF 数据绑定

需求 后台变量发生改变&#xff0c;前端对应的相关属性值也发生改变 实现 接口 INotifyPropertyChanged 用于通知客户端&#xff08;通常绑定客户端&#xff09;属性值已更改。 示例 示例一 官方示例代码如下 using System; using System.Collections.Generic; using Sy…

IoT 物联网 MQTT 协议 5.0 版本新特性

MQTT 是一种基于发布/订阅模式的轻量级消息传输协议&#xff0c;专门为设备资源有限和低带宽、高延迟的不稳定网络环境的物联网场景应用而设计&#xff0c;可以用极少的代码为联网设备提供实时可靠的消息服务。MQTT 协议广泛应用于智能硬件、智慧城市、智慧农业、智慧医疗、新零…

Linux:linux计算机和windows计算机 之间 共享资源

在前面章节已经介绍过&#xff0c;NFS用于Linux系统之间的文件共享&#xff0c;windows 并不知道 NFS &#xff0c;而是使用 CIFS (Common Internet File System) 的协议机制 来 “共享” 文件。在1991年&#xff0c;Andrew Tridgell 通过逆向工程 实现了 CIFS 协议&#xff0c…

GAMES101-Assignment5

一、问题总览 在这次作业中&#xff0c;要实现两个部分&#xff1a;光线的生成和光线与三角的相交。本次代码框架的工作流程为&#xff1a; 从main 函数开始。我们定义场景的参数&#xff0c;添加物体&#xff08;球体或三角形&#xff09;到场景中&#xff0c;并设置其材质&…

【Cadence】sprobe的使用

实验目的&#xff1a;通过sprobe测试电路中某个节点的阻抗 这里通过sprobe测试输入阻抗&#xff0c;可以通过port来验证 设置如下&#xff1a; 说明&#xff1a;Z1代表sprobe往left看&#xff0c;Z2代表sprobe往right看 结果如下&#xff1a; 可以看到ZM1I0.Z2 顺便给出了I…

一篇文章了解做仿真软件的达索系统-达索代理商

达索系统是一家全球领先的仿真软件公司&#xff0c;致力于为客户提供创新和高效的解决方案。该公司的仿真软件被广泛应用于航空航天、汽车、能源、医疗等领域&#xff0c;为客户提供了强大的工程仿真能力。 达索系统的仿真软件具有多个特点&#xff0c;包括高精度、高效率、易用…

CSS 改变鼠标样式(大全)

使用方法&#xff1a; <span style"cursor:auto">Auto</span><span style"cursor:crosshair">Crosshair</span><span style"cursor:default">Default</span><span style"cursor:pointer">P…

高通平台开发系列讲解(USB篇)adb function代码分析

文章目录 一、FFS相关动态打印二、代码入口三、ffs_alloc_inst四、ep0、ep1&ep2的注册五、读写过程沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本文主要介绍高通平台USB adb function代码f_fs.c。 一、FFS相关动态打印 目录:msm-4.14/drivers/usb/gadget/fun…

系统存储架构升级分享

一、业务背景 系统业务功能&#xff1a;系统内部进行数据处理及整合, 对外部系统提供结果数据的初始化(写)及查询数据结果服务。 系统网络架构: • 部署架构对切量上线的影响 - 内部管理系统上线对其他系统的读业务无影响 •分布式缓存可进行单独扩容, 与存储及查询功能升级…

蓝牙信标定位原理

定位原理&#xff1a;蓝牙信标的定位原理是基于RSSI蓝牙信号强度来做定位的。 根据应用场景不同&#xff0c;通过RSSI定位原理可分为两种定位方式 一、存在性定位 这种方式通常要求所需定位的区域安装一个蓝牙信标即可&#xff0c;手持终端扫描蓝牙信标信号&#xff0c;扫描…

U盘删除的文件不在回收站如何恢复?教你3个简单方法!

“我在清理u盘的时候误删了一些重要的文件&#xff0c;想将这些文件恢复时才发现它们不在回收站中了。还有办法恢复吗&#xff1f;” 在数字化时代&#xff0c;u盘的作用渐渐显现。很多用户会将重要的数据直接保存在u盘中。但在使用u盘的过程中&#xff0c;不可避免会有数据的丢…

渐变登录页

效果演示 实现了一个简单的登录页面的样式和交互效果。 Code <div class"flex"><div class"login color">Login</div><label class"color">Username :</label><input type"text" class"input&…

第二证券:如何判断主力是在洗盘还是出货?

怎样判别主力是在洗盘仍是出货&#xff1f; 1、依据股票成交量判别 在洗盘时&#xff0c;个股的成交量与前几个生意相比较&#xff0c;呈现缩量的状况&#xff0c;而出货其成交量与前几个生意日相比较呈现放量的走势。 2、依据股票筹码分布判别 洗盘首要是将一些散户起浮筹…

如何创建自己的小程序?零编程一键创建实战指南

当今瞬息万变的数字世界中&#xff0c;拥有一个属于自己的小程序已成为企业与个人展示、服务和互动的重要途径。无需编码知识&#xff0c;通过便捷的云端可视化平台&#xff0c;也可以轻松创建一款符合自身需求且功能丰富的小程序。下面给大家分享如何创建自己的小程序。 1、选…

K8S Secret 一文详解, 全面覆盖 Secret 使用场景 | 全家桶

博客原文 文章目录 Secret介绍Secret 类型kubectl 创建类型 Secret 使用Opaque 类型 Secret 的使用创建1. kubectl create2. yaml 挂载1. 作为环境变量2. 作为文件挂载及设置 POSIX 权限 Secret 绑定 serviceAccount查看 secret TLS Secretyaml 方式创建kubectl 创建 Docker 镜…

windows下全免费手动搭建php8+mysql8开发环境及可视化工具安装

最近PHP项目少了&#xff0c;一直在研究UE5和Golang&#xff0c;但是考虑到政府、国企未来几年国产化的要求&#xff0c;可能又要重拾PHP。于是近日把用了N年的框架重新更新至适合PHP8.2以上的版本&#xff0c;同时也乘着新装机&#xff0c;再次搭建php和mysql开发环境。本文留…

一个Pygame的Hello World示例程序

创建一个标题为Hello World的窗口&#xff0c;窗口中间显示有Pygame的Logo的python代码 import sys import pygamedef main():pygame.init()screen pygame.display.set_mode((800, 400))pygame.display.set_caption("Hello World")logo pygame.image.load("p…

Python解析参数的三种方法

今天我们分享的主要目的就是通过在 Python 中使用命令行和配置文件来提高代码的效率 Let’s go! 我们以机器学习当中的调参过程来进行实践&#xff0c;有三种方式可供选择。第一个选项是使用 argparse&#xff0c;它是一个流行的 Python 模块&#xff0c;专门用于命令行解析&…

【UML】第16篇 活动图

目录 一、什么是活动图 二、应用场景&#xff1a; 三、绘图符号的说明&#xff1a; 四、语法&#xff1a; 五、例图 六、建模的流程 6.1 对业务流程建模时 6.2 对用例进行活动图建模时 一、什么是活动图 活动图&#xff08;Activity Diagram&#xff09;是UML中用于描…

python 基础 面向对象

类 class Student:name Nonegender Nonenationality Nonenative_place Noneage Nonedef say_hi(self):print(self.name)def dowork(self,work):print(f"{self.name} {work}") student1 Student() student1.name "xxx" student1.gender "男&qu…