在Linux系统下微调Llama2(MetaAI)大模型教程—Qlora

news2024/9/21 20:40:59

Llama2是Meta最新开源的语言大模型,训练数据集2万亿token,上下文长度是由Llama的2048扩展到4096,可以理解和生成更长的文本,包括7B、13B和70B三个模型,在各种基准集的测试上表现突出,最重要的是,该模型可用于研究和商业用途。

一、准备工作

1、本文选择微调的基础模型Llama2-chat-13B-Chinese-50W( 如何部署Llama2大模型,可以转到在Linux系统下部署Llama2(MetaAI)大模型教程-CSDN博客)

2、由于大部分笔记本电脑无法满足大模型Llama2的微调条件,因此可以选用autodl平台(算力云)作为部署平台。注:显存选择40GB以上的,否则微调过程会报错。

二、创建新实例(需要对数据盘进行扩容20GB)

基础的数据盘内存无法满足微调要求,因此需要对数据盘进行扩容。点击已经部署好Llama2大模型实例的“更多”中的“克隆实例”

勾选“数据盘”

选择可扩容的主机。

选择“需要扩容”,填写“20”GB。

填写完成后,点击“立即创建”。创建完成后,不要着急,等待一会儿。状态栏的“运行中”下面会出现“正在拷贝数据集”字样,等待数据集拷贝完成

“正在拷贝数据集”字样消失后,说明拷贝完成,点击JupyterLab。

三、下载、预处理微调数据集

cd到数据盘autodl-tep,并设置学术加速,然后运行以下代码下载数据集

如果你有自己的数据集,那么可以选择使用自己的数据集。

wget https://huggingface.co/datasets/BelleGroup/train_0.5M_CN/resolve/main/Belle_open_source_0.5M.json

原始数据集共有50万条数据,格式:{"instruction":"xxxx", "input":"", "output":"xxxx"}

数据集下载完毕之后,需要对数据集进行预处理新建一个文件:split_json.py. 右击,点击“新建文件”,然后将文件名改为split_json.py即可。

接下来,将以下代码复制粘贴至文件split_json.py中。这段程序的作用是对数据集进行拼接,只使用introduction和output,并仅选择1000条数据作为演示。但在正常生产环境中,我们就需要更大的数据量。

import random,json

def write_txt(file_path,datas):
    with open(file_path,"w",encoding="utf8") as f:
        for d in datas:
            f.write(json.dumps(d,ensure_ascii=False)+"\n")
        f.close()

with open("/root/autodl-tmp/Belle_open_source_0.5M.json","r",encoding="utf8") as f:
    lines=f.readlines()
    
    changed_data=[]
    for l in lines:
        l=json.loads(l)
        changed_data.append({"text":"### Human: "+l["instruction"]+" ### Assistant: "+l["output"]})

    r_changed_data=random.sample(changed_data, 1000)

    write_txt("/root/autodl-tmp/Belle_open_source_0.5M_changed_test.json",r_changed_data)

运行以下代码对split_json.py进行执行

python split_json.py

生成了一个新的json文件Belle_open_source_0.5M_changed_test.json,说明运行成功。

拼接好的数据格式:{"text":"### Human: xxxx ### Assistant: xxx"}

四、运行微调文件

1、返回启动页,新建一个notebook。

2、安装相关包

输入之后,按Shift+Enter运行。

!pip install -q huggingface_hub
!pip install -q -U trl transformers accelerate peft
!pip install -q -U datasets bitsandbytes einops wandb

左上角由 [*] 变为 [1] 后,说明安装成功。

3、设置学术加速

4、登录huggingface的notebook

这里需要到:https://huggingface.co/settings/tokens 中复制tokentoken获取方式可以参考:如何获取HuggingFace的Access Token;如何获取HuggingFace的API Key-CSDN博客

然后执行下列语句:

from huggingface_hub import notebook_login
notebook_login()

将token复制进去:

5、初始化wandb

首先需要先到:https://wandb.me/wandb-server 注册wandb。进入网址后,点击右上角进行登录注册。

注册完毕后在https://wandb.ai/authorize中复制Key

运行代码:

import wandb
wandb.init()

复制的Key粘贴进去,然后再Enter。如果左侧出现文件夹wandb说明运行成功。

6、导入相关包

from datasets import load_dataset
import torch,einops
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer, TrainingArguments
from peft import LoraConfig
from trl import SFTTrainer

7、加载上面拼接好之后的1000条数据

dataset = load_dataset("json",data_files="/root/autodl-tmp/Belle_open_source_0.5M_changed_test.json",split="train")

8、配置本地模型

base_model_name ="/root/autodl-tmp/Llama2-chat-13B-Chinese-50W"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,#在4bit上,进行量化
    bnb_4bit_use_double_quant=True,# 嵌套量化,每个参数可以多节省0.4位
    bnb_4bit_quant_type="nf4",#NF4(normalized float)或纯FP4量化 博客说推荐NF4
    bnb_4bit_compute_dtype=torch.float16,
)

9、配置GPU

device_map = {"": 0}
#有多个gpu时,为:device_map = {"": [0,1,2,3……]}

10、加载本地模型

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,#本地模型名称
    quantization_config=bnb_config,#上面本地模型的配置
    device_map=device_map,#使用GPU的编号
    trust_remote_code=True,
    use_auth_token=True
)
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1

11、配置QLora

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)

12、对本地模型,把长文本拆成最小的单元词(即token)

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

13、训练参数的配置

output_dir = "./results"
training_args = TrainingArguments(
    report_to="wandb",
    output_dir=output_dir,#训练后输出目录
    per_device_train_batch_size=4,#每个GPU的批处理数据量
    gradient_accumulation_steps=4,#在执行反向传播/更新过程之前,要累积其梯度的更新步骤数
    learning_rate=2e-4,#超参、初始学习率。太大模型不稳定,太小则模型不能收敛
    logging_steps=10,#两个日志记录之间的更新步骤数
    max_steps=100#要执行的训练步骤总数
)
max_seq_length = 512
#TrainingArguments 的参数详解:https://blog.csdn.net/qq_33293040/article/details/117376382

trainer = SFTTrainer(
    model=base_model,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_args,
)

14、开始进行微调训练

trainer.train()

可以看到,随着训练的进行,损失函数在下降:

15、把训练好的模型保存下来

import os
output_dir = os.path.join(output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)

五、执行代码合并

把训练好的模型与原始模型进行合并。

1、新建一个merge_model.py的文件,把下面的代码粘贴进去:

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

#设置原来本地模型的地址
model_name_or_path = '/root/autodl-tmp/Llama2-chat-13B-Chinese-50W'
#设置微调后模型的地址,就是上面的那个地址
adapter_name_or_path = '/root/autodl-tmp/results/final_checkpoint'
#设置合并后模型的导出地址
save_path = '/root/autodl-tmp/new_model'

tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    device_map='auto'
)
print("load model success")
model = PeftModel.from_pretrained(model, adapter_name_or_path)
print("load adapter success")
model = model.merge_and_unload()
print("merge success")

tokenizer.save_pretrained(save_path)
model.save_pretrained(save_path)
print("save done.")

2、新建终端,然后执行上述合并代码,进行合并

python merge_model.py

运行结果:

六、使用gradio运行模型

进入Llama2文件夹:cd Llama2

python gradio_demo.py --base_model /root/autodl-tmp/new_model --tokenizer_path /root/autodl-tmp/new_model --gpus 0

七、可能遇到的问题

1、执行代码notebook_login()时报错

报错显示:

(MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /api/whoami-v2 (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x7fdf07ee4940>: Failed to establish a new connection: [Errno 110] Connection timed out'))"), '(Request ID: 3557b723-1341-4c75-b72a-f8ecf6c6a070)')

解决办法:

这是一个Python的错误信息,表明在使用Hugging Face的连接池时出现了最大重试误。根该错误信息,我们可以推测可能的原因是连接到huggingface.co的连接池达到了最大重试次数,但仍无法建立连接。这可能是由于网络连接问题、服务器不可用或其他问题导致的。

2、执行代码trainer.train()时报错

报错显示:

OutOfMemoryError: CUDA out of memory. Tried to allocate 270.00 MiB (GPU 0; 31.74 GiB total capacity; 29.60 GiB already allocated; 36.88 MiB free; 30.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

解决办法一:

将训练参数的配置中的 per_device_train_batch_size 参数设置为2,再执行代码trainer.train(),即可解决。

解决办法二:

报错的主要原因为显存不足,可以更换显存更大的主机。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1202575.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《QT从基础到进阶·二十二》QGraphicsView显示大量图形项item导致界面卡顿的解决办法

有时候因业务需要&#xff0c;paint函数在界面上绘制了成百上千个图形项Items&#xff0c;导致操作界面的时候有明显的卡顿感&#xff0c;下文会提供一种比较好的解决办法&#xff0c;先来了解下QGraphicsItem的缓存方式。 &#xff08;1&#xff09;setCacheMode(QGraphicsIt…

0基础学习PyFlink——水位线(watermark)触发计算

在《0基础学习PyFlink——个数滚动窗口(Tumbling Count Windows)》和《0基础学习PyFlink——个数滑动窗口&#xff08;Sliding Count Windows&#xff09;》中&#xff0c;我们发现如果窗口中元素个数没有把窗口填满&#xff0c;则不会触发计算。 为了解决长期不计算的问题&a…

日志及其框架

日志技术的概述 日志 生活中的日志&#xff1a; 生活中的日志就好比日记&#xff0c;可以记录你生活的点点滴滴。 程序中的日志&#xff1a; 程序中的日志可以用来记录程序运行过程中的信息&#xff0c;并可以进行永久存储。 以前记录日志的方式&#xff08;输出语句&#…

设置专属链接的这些作用你知道吗?

专属链接作为一种个性化的链接&#xff0c;用于为特定的客户或群体提供定制化的体验或服务。对于企业来说&#xff0c;每个渠道或者每个客户都能拥有一个专属链接是无比便利的事情。企业可以将这个链接嵌入到各种宣传物料中&#xff0c;让客户通过输入链接即可进入与客服的交流…

thinkphp5 连接多个服务器数据库

如果你的database.php 是这样&#xff0c; 这是默认的db连接配置 如果还想连接其他服务器&#xff0c;或数据库 在config.php中追加数据库配置&#xff0c; 在使用的地方调用&#xff1a; use think\Db;public function test(){$db3Db::connect(config(db3));$result $db3…

使用Python的requests库模拟爬取地图商铺信息

目录 引言 一、了解目标网站 二、安装requests库 三、发送GET请求 四、解析响应内容 五、处理异常和数据清洗 六、数据存储和分析 七、数据分析和可视化 八、注意事项和最佳实践 总结 引言 随着互联网的快速发展&#xff0c;网络爬虫技术已经成为获取数据的重要手段…

Leetcode-104 二叉树的最大深度

递归实现 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode left, TreeNode right) {* …

谈谈steam游戏搬砖的收益与风险,以及注意事项

11月CSGO市场行情分析&#xff0c;是否到了该抄底的时候了&#xff1f; 今天&#xff0c;要跟大家分享的Steam平台——全球最大的游戏平台&#xff0c;现在给大家介绍下steam搬砖项目&#xff0c;这个项目既小众又稳定。 先了解一下 steam这个平台是做什么的&#xff0c;steam…

navicat创建MySql定时任务

navicat创建MySql定时任务 前提 需要root用户权限 需要开启定时任务 1、开启定时任务 1.1 查看定时任务是否开启 mysql> show variables like event_scheduler;1.2 临时开启定时任务(下次重启后失效) set global event_scheduler on;1.3 设置永久开启定时任务 查看my…

c语言-数据结构-带头双向循环链表

目录 1、双向循环链表的结构 2、双向循环链表的结构体创建 3、双向循环链表的初始化 3.1 双向链表的打印 4、双向循环链表的头插 5、双向循环链表的尾插 6、双向循环链表的删除 6.1 尾删 6.2 头删 6.3 小节结论 7、查找 8、在pos位置前插入数据 9、删除pos位…

Scala---介绍及安装使用

一、Scala介绍 1. 为什么学习Scala语言 Scala是基于JVM的语言&#xff0c;与java语言类似&#xff0c;Java语言是基于JVM的面向对象的语言。Scala也是基于JVM&#xff0c;同时支持面向对象和面向函数的编程语言。这里学习Scala语言的原因是后期我们会学习一个优秀的计算框架S…

单链表(7)

插入函数——插入数据&#xff0c;在链表plist的pos位置插入val数据元素 由图知&#xff0c;poslength时&#xff0c;是可以插入的 在大多数情况下&#xff0c;说位置的时候&#xff0c;从0开始计数&#xff1b;说第几个数据的时候&#xff0c;从1开始计数 现在来测试一下 这就…

CSDN的规范、检测文章质量、博客等级好处等等(我也是意外发现的,我相信很多人还不知道,使用分享给大家!)

前言 都是整理官方的文档&#xff0c;方便自己查看和检查使用&#xff0c;以前我也不知道。后来巧合下发现的&#xff0c;所以分享给大家&#xff01; 下面都有官方的链接&#xff0c;详情去看官方的文档。 大家严格按照官方的规范去记录自己工作生活中的文章&#xff0c;很快…

MacOS Ventura 13 优化配置(ARM架构新手向导)

一、系统配置 1、About My MacBook Pro 2、在当前标签打开新窗口 桌面上创建目录的文件夹&#xff0c;每次新打开一个目录&#xff0c;就会创建一个窗口&#xff0c;这就造成窗口太多&#xff0c;不太好查看和管理&#xff0c;我们可以改成在新标签处打开新目录。需要在&…

电动自动换刀高速电主轴的技术优势浅析

在制造业中&#xff0c;自动化技术的发展一直是一个重要的话题。其中&#xff0c;电动自动换刀被认为是一项高效、智能、先进的技术&#xff0c;在高速电主轴中使用电动自动换刀这一技术&#xff0c;不仅能够缩短换刀时间&#xff0c;还能减少换刀失误&#xff0c;本文将探讨Sy…

光计算1周2篇Nature,英伟达的时代彻底结束!

近期&#xff0c;光计算领域连续发出重量级文章&#xff0c;刊登在学术界的顶级期刊上。一时间&#xff0c;各大媒体纷纷转发&#xff0c;读者们也纷纷感叹&#xff1a;中国芯片取代英伟达的机会来了&#xff01;今天&#xff0c;光子盒用这篇万字长文为大家梳理光计算的背景、…

指标类型(一):北极星指标、虚荣指标

每个产品都有很多指标&#xff0c;每个指标都反映了对应业务的经营情况。但是在实际业务经营中&#xff0c;却要求我们在不同的产品阶段寻找到合适的指标&#xff0c;让这个指标可以代表当前产品阶段的方向和目标&#xff0c;让这个指标不仅对业务经营团队&#xff0c;而且对产…

双十一网络电视盒子哪个品牌好?内行分享权威电视盒子排行榜

双十一大促正如火如荼进行中&#xff0c;因为我从事的工作和电视盒子有关&#xff0c;身边的朋友们在选购电视盒子时不知道从何下手就会问我的意见&#xff0c;本期将盘点业内公认的电视盒子排行榜&#xff0c;给双十一想买电视盒子的朋友们做个参考。 排行一&#xff1a;泰捷W…

【C++】非类型模板参数 | array容器 | 模板特化 | 模板为什么不能分离编译

目录 一、非类型模板参数 二、array容器 三、模板特化 为什么要对模板进行特化 函数模板特化 补充一个问题 类模板特化 全特化与偏特化 全特化 偏特化 四、模板为什么不能分离编译 为什么 怎么办 五、总结模板的优缺点 一、非类型模板参数 模板参数分两类&#x…

MVVM框架:图片加载有问题

一、前言&#xff1a;在我使用ImageView加载图片的时候添加如下代码发现报错 app:imageUrl"{viewModel.observableField.assetImg}"报错如下错误 二、原因&#xff1a;是啥我不太清楚好像是没有imageView的适配器&#xff0c;后来我看了一下确实没有 public class I…