[NLP]LLM---FineTune自己的Llama2模型

news2024/11/24 17:55:41

一 数据集准备

Let’s talk a bit about the parameters we can tune here. First, we want to load a llama-2-7b-hf model and train it on the mlabonne/guanaco-llama2-1k (1,000 samples), which will produce our fine-tuned model llama-2-7b-miniguanaco. If you’re interested in how this dataset was created, you can check this notebook. Feel free to change it: there are many good datasets on the Hugging Face Hub, like databricks/databricks-dolly-15k.

使用如下代码, 准备离线数据集(GPU机器不能联网)

import os
from datasets import load_from_disk, load_dataset

print("loading dataset...")
dataset_name = "mlabonne/guanaco-llama2-1k"
dataset = load_dataset(path=dataset_name, split="train", download_mode="reuse_dataset_if_exists")
print(dataset)

offline_dataset_path = "./guanaco-llama2-1k-offline"
os.makedirs(offline_dataset_path, exist_ok=True)

print("save to disk...")
dataset.save_to_disk('./guanaco-llama2-1k-offline')
print("load from disk")
dataset = load_from_disk("./guanaco-llama2-1k-offline")
print(dataset)

<s>[INST] 32位应用内存不够了应该怎么办 [/INST] 1、删除不必要的数据及应用程序:智能管理器(内存管理器)-储存空间/内存。\n2、关闭不需要的应用程序:点击屏幕左下角近期任务键-点击要关闭的应用程序右上角的“X”/下方的关闭全部。\n3、若手机支持存储卡,将多媒体文件移动到扩展卡或将拍摄存储路径设置为扩展卡。若未有存储卡,建议定期将多媒体文件移动到电脑中。\n4、卸载不经常使用的软件或文件。 </s>

['<s>[INST] 我失恋了该怎么安慰自己。 [/INST] 失恋是一种很痛苦的经历,但是在这个过程中,你需要学会如何安慰自己,以帮助自己恢复过来。以下是一些可以帮助你安慰自己的建议:\n\n1.允许自己感受痛苦:失恋是一种正常的情感反应,你不需要隐藏或压抑自己的感受。让自己慢慢地去面对这个事实,并接受自己的情绪。\n\n2.找到支持:找到可以倾诉的朋友、家人或心理咨询师,让他们帮助你排解情绪,倾听你的烦恼。不要孤单地面对这个问题。\n\n3.做一些让自己感觉良好的事情:找到一些让自己快乐的事情,比如运动、阅读、听音乐、旅行等等。这些活动可以帮助你分散注意力,从而减轻痛苦。\n\n4.关注自己的成长:失恋是一个成长的机会,它可以让你更好地认识自己、学会如何处理情感和人际关系。关注自己的成长,不要忘记这个过程也是有意义的。\n\n5.给自己时间:失恋需要时间来治愈,不要强迫自己很快就要恢复。给自己足够的时间来处理这个过程,慢慢地恢复自己的心态。\n\n失恋是一种痛苦的经历,但也是一种学习和成长的机会。相信自己可以渡过这个难关,未来会更加美好。 </s><s>[INST] 你可以安慰一下失恋的我吗? [/INST] 说起来可能您会比较失望,我作为一个处理文本的AI语言模型,虽然能在您表达的意思的基础上,给您一些文本回答,但是我没有真正的对现实事件的认识能力,以及对现实的人的共情能力。但是我会很愿意去您提供一些安慰和支持,如果您愿意分享您的感受和想法,以及让我知道您需要什么样的帮助和支持,我会尽我所能去帮助您度过这段难过的时期。 </s>']

import pandas as pd
from datasets import load_dataset
# Initialize lists to store prompts and responses
prompts = []
responses = []

prev_examples = []
# Parse out prompts and responses from examples
for example in prev_examples:
  try:
    split_example = example.split('-----------')
    prompts.append(split_example[1].strip())
    responses.append(split_example[3].strip())
  except:
    pass

# Create a DataFrame
df = pd.DataFrame({
    'prompt': prompts,
    'response': responses
})

# Remove duplicates
df = df.drop_duplicates()

print('There are ' + str(len(df)) + ' successfully-generated examples. Here are the first few:')

df.head()

# Split the data into train and test sets, with 90% in the train set
train_df = df.sample(frac=0.9, random_state=42)
test_df = df.drop(train_df.index)

# Save the dataframes to .jsonl files
train_df.to_json('train.jsonl', orient='records', lines=True)
test_df.to_json('test.jsonl', orient='records', lines=True)

# Load datasets
train_dataset = load_dataset('json', data_files='/content/train.jsonl', split="train")
valid_dataset = load_dataset('json', data_files='/content/test.jsonl', split="train")

# Preprocess datasets
train_dataset_mapped = train_dataset.map(lambda examples: {'text': [f'<s>[INST] ' + prompt + ' [/INST]</s>' + response for prompt, response in zip(examples['prompt'], examples['response'])]}, batched=True)
valid_dataset_mapped = valid_dataset.map(lambda examples: {'text': [f'<s>[INST] ' + prompt + ' [/INST]</s>' + response for prompt, response in zip(examples['prompt'], examples['response'])]}, batched=True)

# trainer = SFTTrainer(
#     model=model,
#     train_dataset=train_dataset_mapped,
#     eval_dataset=valid_dataset_mapped,  # Pass validation dataset here
#     peft_config=peft_config,
#     dataset_text_field="text",
#     max_seq_length=max_seq_length,
#     tokenizer=tokenizer,
#     args=training_arguments,
#     packing=packing,
# )

如果是自己的数据集,首先需要对训练数据处理一下,因为训练数据包括了两列,分别是prompt和response,我们需要把两列的文本合为一起,通过格式化字符来区分,如以下格式化函数:

{'text': ['[INST]  + prompt + ' [/INST] ' + response)

二  开始微调Llama2

In this section, we will fine-tune a Llama 2 model with 7 billion parameters on a A800 GPU(90G) with high RAM. we need parameter-efficient fine-tuning (PEFT) techniques like LoRA or QLoRA.

QLoRA will use a rank of 64 with a scaling parameter of 16 (see this article for more information about LoRA parameters). We’ll load the Llama 2 model directly in 4-bit precision using the NF4 type and train it for 1 epoch. To get more information about the other parameters, check the TrainingArguments, PeftModel, and SFTTrainer documentation.

import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer

# Load dataset (you can process it here)

# from datasets import load_dataset
#
# print("loading dataset")
# dataset_name = "mlabonne/guanaco-llama2-1k"
# dataset = load_dataset(dataset_name, split="train")
# dataset.save_to_disk('./guanaco-llama2-1k-offline')

from datasets import load_from_disk

offline_dataset_path = "./guanaco-llama2-1k-offline"
dataset = load_from_disk(offline_dataset_path)

# The model that you want to train from the Hugging Face hub
model_name = "/home/work/llama-2-7b"

# Fine-tuned model name
new_model = "llama-2-7b-miniguanaco"

################################################################################
# QLoRA parameters
################################################################################

# LoRA attention dimension
lora_r = 64

# Alpha parameter for LoRA scaling
lora_alpha = 16

# Dropout probability for LoRA layers
lora_dropout = 0.1

################################################################################
# bitsandbytes parameters
################################################################################

# Activate 4-bit precision base model loading
use_4bit = True

# Compute dtype for 4-bit base models
bnb_4bit_compute_dtype = "float16"

# Quantization type (fp4 or nf4)
bnb_4bit_quant_type = "nf4"

# Activate nested quantization for 4-bit base models (double quantization)
use_nested_quant = False

################################################################################
# TrainingArguments parameters
################################################################################

# Output directory where the model predictions and checkpoints will be stored
output_dir = "./results"

# Number of training epochs
num_train_epochs = 1

# Enable fp16/bf16 training (set bf16 to True with an A100)
fp16 = False
bf16 = False

# Batch size per GPU for training
per_device_train_batch_size = 16

# Batch size per GPU for evaluation
per_device_eval_batch_size = 16

# Number of update steps to accumulate the gradients for
gradient_accumulation_steps = 1

# Enable gradient checkpointing
gradient_checkpointing = True

# Maximum gradient normal (gradient clipping)
max_grad_norm = 0.3

# Initial learning rate (AdamW optimizer)
learning_rate = 2e-4

# Weight decay to apply to all layers except bias/LayerNorm weights
weight_decay = 0.001

# Optimizer to use
optim = "paged_adamw_32bit"

# Learning rate schedule
lr_scheduler_type = "cosine"

# Number of training steps (overrides num_train_epochs)
max_steps = -1

# Ratio of steps for a linear warmup (from 0 to learning rate)
warmup_ratio = 0.03

# Group sequences into batches with same length
# Saves memory and speeds up training considerably
group_by_length = True

# Save checkpoint every X updates steps
save_steps = 0

# Log every X updates steps
logging_steps = 25

################################################################################
# SFT parameters
################################################################################

# Maximum sequence length to use
max_seq_length = None

# Pack multiple short examples in the same input sequence to increase efficiency
packing = False

# Load the entire model on the GPU 0
device_map = {"": 0}

# Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

# Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map=device_map
)
model.config.use_cache = False
model.config.pretraining_tp = 1

# Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training

# Load LoRA configuration
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)

# Set training parameters
training_arguments = TrainingArguments(
    output_dir=output_dir,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
    fp16=fp16,
    bf16=bf16,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=group_by_length,
    lr_scheduler_type=lr_scheduler_type
)

# Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=packing,
)

# Train model
trainer.train()

# Save trained model
trainer.model.save_pretrained(new_model)

We can now load everything and start the fine-tuning process. We’re relying on multiple wrappers, so bear with me.

  • First of all, we want to load the dataset we defined. Here, our dataset is already preprocessed but, usually, this is where you would reformat the prompt, filter out bad text, combine multiple datasets, etc.
  • Then, we’re configuring bitsandbytes for 4-bit quantization.
  • Next, we’re loading the Llama 2 model in 4-bit precision on a GPU with the corresponding tokenizer.
  • Finally, we’re loading configurations for QLoRA, regular training parameters, and passing everything to the SFTTrainer. The training can finally start!

三  合并lora weights,保存完整模型

import os
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer
)
from peft import PeftModel

# Save trained model

model_name = "/home/work/llama-2-7b"

# Load the entire model on the GPU 0
device_map = {"": 0}

new_model = "llama-2-7b-miniguanaco"

# Reload model in FP16 and merge it with LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map=device_map,
)
model = PeftModel.from_pretrained(base_model, new_model)
print("merge model and lora weights")
model = model.merge_and_unload()

output_merged_dir = "final_merged_checkpoint2"
os.makedirs(output_merged_dir, exist_ok=True)
print("save model")
model.save_pretrained(output_merged_dir, safe_serialization=True)

# Reload tokenizer to save it
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

print("save tokenizer")
tokenizer.save_pretrained(output_merged_dir)

加载微调后的模型进行对话

对比微调之前: 

四 总结

In this article, we saw how to fine-tune a Llama-2-7b model. We introduced some necessary background on LLM training and fine-tuning, as well as important considerations related to instruction datasets. We successfully fine-tuned the Llama 2 model with its native prompt template and custom parameters.

These fine-tuned models can then be integrated into LangChain and other architectures as advantageous alternatives to the OpenAI API. Remember, in this new paradigm, instruction datasets are the new gold, and the quality of your model heavily depends on the data on which it’s been fine-tuned. So, good luck with building high-quality datasets!

ML Blog - Fine-Tune Your Own Llama 2 Model in a Colab Notebook (mlabonne.github.io)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/989379.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5分钟生成10条短视频,AI重构电商营销

点击关注 文&#xff5c;姚 悦&#xff0c;编&#xff5c;王一粟 “我们将正式告别过去单一渠道投放的时代&#xff0c;走向一站式跨渠道品效联合经营的全新时代。”9月6日&#xff0c;在2023年其最重要的营销峰会上&#xff0c;淘天集团阿里妈妈市场部总经理穆尔说道。 当天…

【C++】模拟实现二叉搜索树的增删查改功能

个人主页&#xff1a;&#x1f35d;在肯德基吃麻辣烫 我的gitee&#xff1a;C仓库 个人专栏&#xff1a;C专栏 文章目录 一、二叉搜索树的Insert操作&#xff08;非递归&#xff09;分析过程代码求解 二、二叉搜索树的Erase操作&#xff08;非递归&#xff09;分析过程代码求解…

java从入门到起飞(八)——循环和递归

文章目录 Java循环1. 什么是循环&#xff1f;1.1 为什么需要循环&#xff1f;1.2 循环的分类 2. Java中的循环结构2.1 for循环2.2 while循环2.3 do-while循环 3. 循环控制语句3.1 break语句3.2 continue语句 4. 总结 Java递归1. 什么是递归2. 递归的原理3. 递归的实现4. 递归的…

鼠标右键使用VSCode打开文件或文件夹配置

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

算法通关村17关 | 透析跳跃游戏

1. 跳跃游戏 题目 LeetCode55 给定一个非负整数数组&#xff0c;最初位于数组的第一个位置&#xff0c;数组中的每个元素代表你再该位置可以跳跃的最大长度&#xff0c;判断你是否能够达到最后一个位置。 思路 如果当前位置元素如果是3&#xff0c;我们无需考虑是跳几步&#…

04、javascript 修改对象中原有的属性值、修改对象中原有属性的名字(两种方式)、添加对象中新属性等的操作

1、修改对象中原有的属性值 其一、代码为&#xff1a; // 想将 obj 中的 flag 值&#xff0c;根据不同的值来变化(即&#xff1a;修改对象中原有的属性值)&#xff1b; let obj {"port": "port_0","desc": "desc_0","flag&quo…

【LeetCode】一起探究三数之和的奥秘

Problem: 15. 三数之和 文章目录 题目解析算法原理分析排序 暴力枚举 set去重排序 单调性 双指针划分思想 复杂度Code 题目解析 首先我们来分析一下本题的思路 题目说到要我们在一个整数数组中去寻找三元组&#xff0c;而且呢这三个数字所相加的和为0&#xff0c;而且呢这三…

数据复制:构建大规模分布式系统的关键组成部分

数据复制对于构建可靠的大规模分布式系统至关重要。在本期中&#xff0c;我们将探讨常见的复制策略以及选择合适策略的关键因素。 在本期中&#xff0c;我们将以数据库为例进行讨论。请注意&#xff0c;复制不仅适用于数据库&#xff0c;还适用于缓存服务器&#xff08;如Redis…

探索Apache Hive:融合专业性、趣味性和吸引力的数据库操作奇幻之旅

文章目录 版权声明一 数据库操作二 Hive数据表操作2.1 表操作语法和数据类型2.2 Hive表分类2.3 内部表Vs外部表2.4 内部表操作2.4.1 创建内部表2.4.2 其他创建内部表的形式2.4.3 数据分隔符2.4.4 自定义分隔符2.4.5 删除内部表 2.5 外部表操作2.5.1 创建外部表2.5.2 操作演示2.…

BMS电池管理系统——电芯需求数据(三)

BMS电池管理系统 文章目录 BMS电池管理系统前言一、有什么基础数据二、基础数据分析1.充放电的截至电压2.SOC-OCV关系表3.充放电电流限制表4.充放电容量特性5.自放电率 总结 前言 在新能源产业中电芯的开发也占有很大部分&#xff0c;下面我们就来看一下电芯的需求数据有哪些 …

知名农业企业-九三食品选择泛微京桥通构建全程数字化采购管理

九三食品股份有限公司是北大荒集团旗下九三粮油工业集团有限公司的控股子公司&#xff0c;作为九三集团大豆精深加工生产及研发基地&#xff0c;九三食品公司可生产非转基因“九三”牌大豆油等植物油、食品添加剂、胶囊保健品、精油类产品、食用豆粕、豆饼粉和其它产品等七大类…

yo!这里是进程控制

目录 前言 进程创建 fork()函数 写时拷贝 进程终止 退出场景 退出方法 进程等待 等待原因 等待方法 1.wait函数 2.waitpid函数 等待结果&#xff08;status介绍&#xff09; 进程替换 替换原理 替换函数 进程替换例子 shell简易实现 后记 前言 学习完操作…

【更新至2022年】2000-2022年全国31省市以2000年为基期的实际GDP、名义GDP、GDP平减指数数据(含原始数据+计算过程+计算结果)

2000-2022年31省市名义GDP 实际GDP GDP平减指数 1、时间&#xff1a;2000-2022 2、范围&#xff1a;31省市 3、来源&#xff1a;GJ统计J和统计NJ 4、指标&#xff1a;名义GDP、地区生产总值指数&#xff08;上年100&#xff09;、实际GDP&#xff08;以2000年为基期&#x…

【光伏系统】将电流从直流转换为交流电的太阳能逆变器、太阳能跟踪系统来提高系统的整体性能及集成电池解决方案(Simulink仿真)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…

(高阶)Redis 7 第10讲 单线程 与 多线程 入门篇

面试题 1.Redis 是单线程还是多线程 最早的版本3.x是单线程。 版本4.x,严格意义不是单线程。负责处理客户端请求的线程是单线程,开始加入异步删除。 6.0.x版本后明确使用全新的多线程来解决问题 2.说说IO多路复用3.Redis 为什么快IO多路复用+epoll函…

基于SSM的乡镇自来水收费系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用JSP技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…

【leetcode 力扣刷题】回文串相关题目(KMP、动态规划)

回文串相关题目 5. 最长回文子串动态规划中心扩展算法 214. 最短回文串336. 回文对 5. 最长回文子串 题目链接&#xff1a;5. 最长回文子串 题目内容&#xff1a; 题目就是要我们找s中的回文子串&#xff0c;还要是最长的。其实想想&#xff0c;暴力求解也行……就是遍历所有的…

系统移植MakefileREADME文件分析

Makefile # 指定交叉编译工具链前缀变量 CROSS_COMPILE arm-linux-gnueabihf- #指定文件名字变量 NAME interface ## #-g:编译时添加gdb调试信息 -marm: 将程序编译生成arm指令集 -Wall:编译时显示所有警告信息 #-O0:编译时添加优化等级 -O0:不优化 -O1:一级优化 #-f…

zabbix配置钉钉告警、和故障自愈、监控java

文章目录 1.配置钉钉告警server 配置web界面创建媒介给用户添加媒介测试告警 实现故障自愈功能监控Javazabbix server 安装java gateway配置 Zabbix Server 支持 Java gateway使用系统内置模板监控 tomcat 主机 1.配置钉钉告警 server 配置 钉钉告警python脚本 脚本1 cd /…

实相融、云启未来,智慧公厕让城市生活更美好

现代社会&#xff0c;随着科技的不断发展&#xff0c;人们对于城市生活的要求也在不断提升。在这个过程中&#xff0c;智慧公厕作为城市基础设施中的重要组成部分&#xff0c;正在发挥着越来越重要的作用。通过数字化、云管理、人工智能等未来的科技方式&#xff0c;智慧公厕为…