采用LoRA方法微调llama3大语言模型

news2024/11/20 21:21:31

文章目录

  • 前言
  • 一、Llama3模型简介
    • 1.下载llama3源码到linux服务器
    • 2.安装依赖
    • 3.测试预训练模型Meta-Llama-3-8B
    • 4.测试指令微调模型Meta-Llama3-8B-Instruct
    • 5.小结
  • 二、LoRA微调Llama3
    • 1.引入库
    • 2.编写配置文件
    • 3.LoRA训练的产物
  • 三、测试新模型效果
    • 1.编写配置文件
    • 2.运行配置文件:
  • 总结


前言

因为上篇文章点赞数超过1,所以今天继续更新llama3微调方法。先介绍一下如何与本地llama3模型交互,再介绍如何使用torchtune和LoRA方式微调llama3,最后介绍一下如何用torchtune与llama3模型交互。


一、Llama3模型简介

目前llama3开源的模型有Meta-Llama-3-8B、Meta-Llama-3-8B-Instruct、Meta-Llama-3-70B和Meta-Llama-3-70B-Instruct。这里Meta-Llama-3-8B是参数规模为80亿的预训练模型(pretrained model),Meta-Llama-3-8B-Instruct是80亿参数经过指令微调的模型(instruct fine-tuned model);对应的,后两个模型就是对应700亿参数的预训练和指令微调模型。那么,预训练模型和指令微调模型有什么区别呢?我们来跟她们对话一下就明白了。

1.下载llama3源码到linux服务器

git clone https://github.com/meta-llama/llama3.git

2.安装依赖

最好先用anaconda创建一个专门为微调模型准备的python虚拟环境,然后运行命令:

cd llama3
pip install -e .

3.测试预训练模型Meta-Llama-3-8B

torchrun --nproc_per_node 1 example_text_completion.py
–ckpt_dir Meta-Llama-3-8B/
–tokenizer_path Meta-Llama-3-8B/tokenizer.model
–max_seq_len 128 --max_batch_size 4

参数解释:
–ckpt_dir 模型权重所在文件夹路径,一般后缀为.pt、.pth或.safetensors
–tokenizer_path 分词器路径,必须带上分词器名称,例如tokenizer.model
–max_seq_len 输出的最大序列长度,这个在预训练模型的调用中是必带参数
–max_batch_size 每个批次包含的最大样本数

下图是模型的输出结果,当我输入英文"I believe the meaning of life is"时,模型会输出"to love. It is to love others, to love ourselves, and to love God. Love is the meaning of life blablabla"。
llama3预训练模型的输出
很明显,预训练模型Meta-Llama3-8B是对用户输入的一个续写。

4.测试指令微调模型Meta-Llama3-8B-Instruct

torchrun --nproc_per_node 1 example_chat_completion.py
–ckpt_dir /data/jack/Meta-Llama-3-8B-Instruct/original/
–tokenizer_path /data/jack/Meta-Llama-3-8B-Instruct/original/tokenizer.model
–max_seq_len 512 --max_batch_size 4
参数解释:
–max_seq_len 输出的最大序列长度,这个对指令微调模型是可选参数,不设置指令微调模型也会在问答中自动停止输出
指令词微调的模型问答
如上图所示,Meta-Llama-3-8B-Instruct模型面对用户的提问,会给出合适的回答。

5.小结

Meta-Llama-3-8B是针对用户输入的一个续写,跟Transformer架构的模型在预训练过程中的下一词汇预测很相似;Meta-Llama-3-8B-Instruct是可以回答用户提问的模型。因此,在选择LoRA微调的基底模型时,大部分情况应当选择指令词微调模型。

二、LoRA微调Llama3

1.引入库

在切换到anaconda或venv的python环境后:

pip install torchtune

2.编写配置文件

如果下载了torchtune仓库的源码,可以从中拷贝出对应的recipe文件,文件夹的相对路径为:
torchtune\recipes\configs\llama3

# Model Arguments
model:
  _component_: torchtune.models.llama3.lora_llama3_8b
  lora_attn_modules: ['q_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 8
  lora_alpha: 16

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /data/jack/Meta-Llama-3-8B-Instruct/original/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /data/jack/Meta-Llama-3-8B-Instruct/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /data/jack/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 64
compile: False

# Logging
output_dir: /data/jack/torchtune_test/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True

# Profiler (disabled)
profiler:
  _component_: torchtune.utils.profiler
  enabled: False

这里解释一下各个参数:
tokenizer下的path 这个一定要填写你下载的基底模型的分词器路径
checkpointer下的checkpoint_dir 填写你要微调的基底模型的存放文件夹路径
checkpointer下的checkpoint_files 基底模型文件名(一个或多个)
dataset下的_component_ 填写现有的或自定义的数据集类名
loss下的_component_ 填写训练过程中的损失函数,默认是交叉熵损失
epochs(发音:一剖克斯) 训练轮次,可以先设置为1试一下

下命令进行微调:

tune run lora_finetune_single_device --config ./8B_lora_single_device_custom.yaml

3.LoRA训练的产物

结束训练后,在输出文件夹中会产出一个合并了lora和基底模型权重的新模型meta_model_0.pt,也会产出一个独立的lora权重文件adapter_0.pt。

三、测试新模型效果

1.编写配置文件

# Model arguments
model:
  _component_: torchtune.models.llama3.llama3_8b

checkpointer:
  _component_: torchtune.utils.FullModelMetaCheckpointer
  checkpoint_dir: /data/feipan3/Meta-Llama-3-8B-Instruct/
  checkpoint_files: [meta_model_0.pt]
  output_dir: /data/feipan3/Meta-Llama-3-8B-Instruct/
  model_type: LLAMA3

device: cuda
dtype: bf16

seed: 5678

# Tokenizer arguments
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /data/feipan3/Meta-Llama-3-8B-Instruct/original/tokenizer.model

# Generation arguments; defaults taken from gpt-fast
prompt: "Give three tips for staying healthy."
max_new_tokens: 512
temperature: 0.6 # 0.8 and 0.6 are popular values to try
top_k: 300

quantizer: null

2.运行配置文件:

tune run generate --config generation_custom.yaml
我的问题是“给出保持健康的三条建议”,得到输出结果:
微调后的问答效果


总结

本文主要介绍了如何用torchtune工具微调Llama3,详细介绍了如何设置配方文件,如何运行以及最后如何进行效果验证。本文点赞超过2个,下一期继续更新如何准备微调数据集。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1682876.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(Qt) 默认QtWidget应用包含什么?

文章目录 ⭐前言⭐创建🛠️选择一个模板🛠️Location🛠️构建系统🛠️Details🛠️Translation🛠️构建套件(Kit)🛠️汇总 ⭐项目⚒️概要⚒️构建步骤⚒️清除步骤 ⭐Code🔦untitled…

Arcpy安装和环境配置

一、前言 ArcPy 是一个以成功的arcgisscripting 模块为基础并继承了arcgisscripting 功能进而构建而成的站点包。目的是为以实用高效的方式通过 Python 执行地理数据分析、数据转换、数据管理和地图自动化创建基础。该包提供了丰富纯正的 Python 体验,具有代码自动…

思维导图-VPN

浏览器集成了受信任的机构的证书

解决word里加入mathtype公式后行间距变大

1.布局>页面设置>文档网格,网格栏选为无网格 2.固定间距

数据库|基于T-SQL创建数据库

哈喽,你好啊,我是雷工! SQL Server用于操作数据库的编程语言为Transaction-SQL,简称T-SQL。 本节学习基于T-SQL创建数据库。以下为学习笔记。 01 打开新建查询 首先连接上数据库,点击【新建查询】打开新建查询窗口, …

Linux基础命令[27]-gpasswd

文章目录 1. gpasswd 命令说明2. gpasswd 命令语法3. gpasswd 命令示例3.1 不加参数3.2 -a(将用户加入组)3.3 -d(从组中删除用户)3.4 -r(删除组密码)3.5 -M(多个用户一起加入组)3.6 …

23种设计模式(持续输出中)

一.设计模式的作用 设计模式是软件从业人员长期总结出来用于解决特定问题的通用性框架,它提高了代码的可维护性、可扩展性、可读性以及复用性。 二.设计模式 1.工厂模式 工厂模式提供了创建对象的接口,而无需制定创建对象的具体类,工厂类…

kafka集群跨区域跨集群同步方案MirrorMaker1 —— 筑梦之路

MirrorMaker原理架构 数据流向 上图也是一种比较常见的用法,这里作为记录。下面介绍一则实战案例。 网络架构 配置日志采集器filebeat 配置从哪里采集日志 输出到kafka集群 配置MirrorMaker消费者 参数说明: bootstrap.servers 指定消费哪个kafka的数…

【HarmonyOS4学习笔记】《HarmonyOS4+NEXT星河版入门到企业级实战教程》课程学习笔记(八)

课程地址: 黑马程序员HarmonyOS4NEXT星河版入门到企业级实战教程,一套精通鸿蒙应用开发 (本篇笔记对应课程第 15 节) P15《14.ArkUI组件-状态管理state装饰器》 回到最初的 Hello World 案例,首先验证 如果删掉 State…

Day22:Leetcode:654.最大二叉树 + 617.合并二叉树 + 700.二叉搜索树中的搜索 + 98.验证二叉搜索树

LeetCode:654.最大二叉树 1.思路 解决方案: 单调栈是本题的最优解,这里将单调栈题解本题的一个小视频放在这里 单调栈求解最大二叉树的过程当然这里还有leetcode大佬给的解释,大家可以参考一下: 思路很清晰&#xf…

软件开源协议与QT的开源协议介绍

一.常见的六种开源协议 1.BSD协议 BSD协议全称为“Berkely Software Distribution”,中文译为“伯克利软件发行版”。其最早用于伯克利UNIX操作系统上的开源贡献。 主要特点: 允许修改源码 允许源码再发布 允许商业软件发布和销售 约束&#xff1…

JVM学习-垃圾回收(二)

标记-清除(Mark-Sweep)算法 当堆中的有效内存空间被耗尽的时候,就会停止整个程序(stop the world),然后进行两项工作,第一项则是标记,第二项是清除 标记:Collector从引用根节点开始遍历,标记所有被引用的…

fork 与 vfork 的区别

关键区别一: vfork 直接使用父进程存储空间,不拷贝。 关键区别二: vfork保证子进程先运行,当子进程调用exit退出后,父进程才执行。 我们可以定义一个cnt,在子进程中让它变成3, 如果使用fork,那…

java 8--Lambda表达式,Stream流

目录 Lambda表达式 Lambda表达式的由来 Lambda表达式简介 Lambda表达式的结构 Stream流 什么是Stream流? 什么是流呢? Stream流操作 中间操作 终端操作 Lambda表达式 Lambda表达式的由来 Java是面向对象语言,除了部分简单数据类型…

【机器学习】前沿探索,如何让前端开发更加搞笑

在当今数字化时代,机器学习的崛起为前端开发带来了巨大的机遇和挑战。随着人工智能和数据科学的不断进步,前端工程师不再局限于传统的界面设计和交互体验,而是开始探索如何将机器学习技术融入到他们的工作中,以创造更加智能、个性…

适合暑期做的赚钱副业兼职有哪些?盘点6个适合暑期做的赚钱副业

随着夏日的热浪滚滚而来,学生们纷纷结束了忙碌的学期,迎来了盼望已久的暑假。你是否想过在这个长假里,除了享受阳光、海滩和美食外,还能顺便赚点零用钱呢? 今天,就为大家揭秘六大神秘副业,让你在…

【漏洞复现】海康威视综合安防管理平台 iSecure Center applyCT fastjson 远程代码执行

0x01 漏洞名称 海康威视综合安防管理平台 iSecure Center applyCT fastjson 远程代码执行 0x02 漏洞影响 0x03 搜索引擎 app"HIKVISION-综合安防管理平台"0x04 漏洞详情 POST /bic/ssoService/v1/applyCT HTTP/1.1 User-Agent: Mozilla/5.0 (Windows NT 10.0; Wi…

程序语言基础知识

文章目录 1.程序设计语言2. 程序设计语言的特点和分类3. 编译程序(编译器)的工作原理4. 程序语言的数据成分4.1 数据成分4.2 运算成分4.3 控制成分4.4 传输成分 1.程序设计语言 低级语言:机器语言和汇编语言。 机器语言:二进制代…

正确可用--Notepad++批量转换文件编码为UTF8

参考了:Notepad批量转换文件编码为UTF8_怎么批量把ansi转成utf8-CSDN博客​​​​​​https://blog.csdn.net/wangmy1988/article/details/118698647我参考了它的教程,但是py脚本写的不对. 只能改一个.不能实现批量更改. 他的操作步骤没问题,就是把脚本代码换成我这个. #-*-…

浅谈后端整合Springboot框架后操作基础配置

boot基础配置 现在不访问端口8080 可以吗 我们在默认启动的时候访问的是端口号8080 基于属性配置的 现在boot整合导致Tomcat服务器的配置文件没了 我们怎么去修改Tomcat服务器的配置信息呢 配置文件中的配置信息是很多很多的... 复制工程 保留工程的基础结构 抹掉原始…