小白也能微调大模型:LLaMA-Factory使用心得

news2024/12/30 3:33:49

大模型火了之后,相信不少人都在尝试将预训练大模型应用到自己的场景上,希望得到一个垂类专家,而不是通用大模型。

目前的思路,一是RAG(retrieval augmented generation),在模型的输入prompt中加入尽可能多的“目标领域”的相关知识,引导模型在生成时尽量靠拢目标领域,运用prompt中给予的目标知识;二是有监督微调,用适量的专业领域的数据(或混通用语料)让模型更能生成目标场景的内容。本文主要讲的就是微调。

什么是LLaMA-Factory

当我们想要微调大模型的时候,一个粗略的实验过程无外乎以下几个环节:

  1. 准备好硬件(GPU)、数据;通过各方面的资讯选中你想要微调的基座模型
  2. 准备好代码:输入数据 + 模型 -> 在GPU上反复训练
  3. 训练结束以后,得到训练过程中的checkpoint + 一些log信息
  4. 根据log信息选一些比较有希望的checkpoint在自己的测试集上推理,获得相应的结果
  5. 分析结果,获得下一轮实验(数据、训练方案的迭代)思路

而LLaMA-Factory就是一个很好的负责step 2的工具(当然它能做到的远不止step2,我们后面也会提),你可以理解为,他是一份写好的代码,你只需要把你准备好的数据、硬件、模型,以传参的方式传入,运行代码,模型就开始训练了。等训练结束以后,你把训练好的模型、测试集、硬件又作为参数传入,它就会帮你推理。

LLaMA-Factory的优点

LLaMA-Factory非常适合实验阶段使用,因为:

  1. 支持很多种开源大语言模型:

    实验阶段我们肯定有好几个觉得靠谱的模型,它们往往有自己的标准输入模板(尤其是代码补全这类任务,涉及较多的special token),你想试试的模型LLaMA-Factory基本都支持,通过template参数可以很方便地指定prompt的模板

  2. 支持非常多种训练方法:

    全量调参 vs Lora vs … 或预训练模型 vs 有监督fine-tuning,以及DPO PPO的对齐方案。

    你想试试的基本也都有,也是通过指定训练模式参数即可

  3. Log:

    训练过程中记录的内容比较全,除了同步能够输出loss曲线图以外,还自带bleu等评测指标

  4. 测试环节也很方便:

    支持merge model(比如微调后的adapter合并到原模型以便作为一个模型导出推理);

    支持各种时下比较流行的量化加速方案;

    支持vllm等高并发要求的推理框架;

    需要的话还可以快速搭建一个Gradio UI用于demo展示或可视化分析

使用心得

我没有用过LLaMA-Factory的全部功能,本文暂且以基本的微调任务为阐述重点,会覆盖上面提到的输入:数据 + GPU + 模型,输出微调后模型的使用。看完以后,应该基本就能完成任意一个支持的开源模型的微调任务了。此外,本文也会涉及一小部分LLaMA-Factory的代码文件目录讲解,方便你更好地探索其他的功能相关的参数来实现你的目标任务。

环境准备

首先是需要git clone两个文件目录,一个是目标大模型的仓库(包含模型权重文件等),一个是llama-factory的仓库

然后,我们通常会在两个地方遇到相关依赖的版本要求:

  1. llama-factory的Github仓库主页下,README的Requirement部分(目前已经很贴心地标注了最低要求和推荐要求),以及代码结构目录中的requirements.txt
  2. 想要使用的目标大模型的Huggingface或Github主页下,同样README部分、代码结构目录中的requirements.txt两个部分都会有相关依赖的版本要求

一般,我们以尽快跑通我们的实验目标为目的。

如果是自己掌控度比较高的环境(自己的GPU),装包装cuda什么的都比较擅长:

  1. 检查llama-factory主页README中的要求,把几个依赖库的版本检查一下,保持在规定范围

    这主要是因为llama-factory的requirements.txt里面的相关依赖可能比较多,你不一定会用到llama-factory的所有功能

  2. 基于目标大模型文件目录中的requirements.txt,使用pip install -r requirements.txt

    这主要是因为,这个文件中基本包含这个模型要运行起来的所有依赖

  3. 尝试运行,缺啥补啥

    这里建议按照目标大模型主页的quick start,写一个简单的脚本就可以了

如果是实验室的服务器或者公司的服务器这种掌控度小的场景,记得要自己创建一个虚拟环境,或者起一个自己的docker容器,在虚拟环境或docker容器内操作,具体使用conda还是docker取决于你们公共服务器的权限管理,哪个方便用哪个,或者其他人平常用什么你就用什么。

数据集准备

模型能够跑通以后,我们准备用于微调的数据集。

这里需要理清楚几个概念

数据的内容组织方式,取决于训练场景的输入和输出。

通常一个样本由(输入,输出)的pair构成,场景上主要是下面3种(更多的可以参考readme里关于数据集准备的部分)

  1. 预训练场景:在一句话里并没有特别关注某个位置的内容,想要提升整个训练集语料上的general效果,此时对于GPT架构的模型,一般使用的输入和输出是相同的,所以如果我们指定了训练模式为预训练,那么llama-factory会自动copy输入内容作为输出label的
  2. sft(有监督微调supervised fine tuning):这个场景下,我们特别关注字符串里某个位置的内容,想要针对性地提升。比如NL2SQL专门去调SQL部分的风格或者内容,那就可以只把SQL部分作为输出,NL部分作为输入,而不是把NL+SQL一整句话作为输入和输出;再比如代码补齐场景,一般前文后文作为输入,补齐的中间部分作为输出,针对补齐部分做loss的计算
  3. 偏好对齐场景:主要是输出部分会有两个label,一个更好的,一个更差的,主要是适应DPO等热门的微调方法,模型不光可以从具体的label中学习,还可以通过两个label的差距来学习,目前后者带来的效果大体上更好,学习的目标更精细,有很多文章可以按兴趣去学习。

无论哪种场景,我们都可以按照llama-factory要求的标准格式组织数据集,保存成一个文件,比如下面这种.json文件:

[
  {
    "instruction": "user instruction (required)",
    "input": "user input (optional)",
    "output": "model response (required)",
    "system": "system prompt (optional)",
    "history": [
      ["user instruction in the first round (optional)", "model response in the first round (optional)"],
      ["user instruction in the second round (optional)", "model response in the second round (optional)"]
    ]
  }
]

这里的instruction,input,output等键实际上不重要,你想叫什么都可以,关键在于你要把这个数据集的相关信息,注册到/data/dataset_info.json中,什么叫“注册”呢,就是说参照它里面已经有的数据集注册信息的格式,再添加一个键到其中,比如:

 "给这个数据集取一个名字(传参时使用)": {
    "file_name": "把你的数据集按照上面说的保存成一个文件,也放在这个目录下,这里填文件的名字,如xxx.json",
    "file_sha1": "可以用一些算sha1值的函数对文件算一下,也可以省略这个键",
    
    # 这里是关键,llama-factory实际使用的是prompt,query这些键,你要在这里完成映射关系的描述,这也是上面说 instruction,input,output这些键你想叫什么名字都可以的原因
    "columns": {
      "prompt": "instruction",
      "query": "input",
      "response": "output",
      "history": "history"
    }

这其中,query对应的列的内容会拼接在prompt列对应的内容后面,变成{prompt\nquery}一起作为模型的输入,response列表示这个样本的期望输出,会用于计算loss

在启动时,通过--stage参数告诉llama-factory用什么方式使用这个数据集,比如说--stage pt,那么llama-factory就只会使用prompt列对应的内容,response列的内容会忽略(我们说了pt模式的输入和输出一般是一样的)。因此,我们要做的是根据我们的场景,是预训练(prompt),有监督(query prompt response)还是强化学习(response)等,是多轮对话(history)还是单纯的补全,把它们会涉及的数据对应的键,都映射到正确的数据集里的键上,具体的参考LLama-factory/data/README.md即可。

可能遇到的问题

我们在构造数据集.json文件的时候,可能有的人会用一个这样的脚本,伪代码如下:

# 参考官方示例创建一个空list存样本
samples = []
# 遍历自己的原始数据源
for data in src:
	# 各种处理
	process...sample...
	
	# 处理完了以后变成一个字典结构
	sample = {
		'instruction': xxxxx,
		'input': ....,
		'output': xxxx,
	}
	
	samples.append(sample)

# 保存成.json文件
with open(.....) as wf:
	json.dumps(samples, wf)

这里面有两个问题:

  1. 如果你的数据集特别大,或者一个样本包含的信息特别多,内存里要维护一个超级大的list,可能会导致你处理过程中就内存溢出了
  2. json.dumps实际上会构造出一个超长string,llama-factory里面的读取函数可能是基于transformers.load_datasets,这个函数使用pyarrow去读取字符串,读特别长的json会卡住,我就遇到了load不报错但是也不运行的情况

实际上,我们并不一定非要构造.json数据集,构造一个.jsonl数据集也是完全可以运行的,并且pyarrow更喜欢,伪代码如下

with open(保存数据集,'w', encoding='utf-8') as wf:
	# 遍历自己的原始数据源
	for data in src:
		# 各种处理
		process....sample...
		
		# 处理完以后变成一个字典结构
		sample = {
			....
		}
		
		wf.write(json.dumps(sample, ensure_ascii=False) + '\n')

这样就行了,既不用维护一个巨大的list在内存中,也不用担心读取的时候出问题,使用方法上没有任何变化,还是一样把这个文件注册到dataset_info.json里面。

训练参数

template参数:决定数据集中的prompt和response如何连接

数据准备好以后,最重要的就是--template参数了,每个大模型都有一些自己的special tokens,比如对话大模型往往会有user,assistant这样的标识,把这些标识插入到prompt和response之间,才能构成一个完整的模型输入。而我们准备数据的时候,只需要准备prompt,response的内容,这些标识是不需要我们插入的。

不要想当然地把模型的名字填上去,不如亲自去源码里看看

我们前往/src/llmtuner/data/template.py能够看到所有的template参数都会对我们的数据做什么,比如下面是deepseek两个模板对应的处理
在这里插入图片描述

有的时候我们想要使用的标识和提供的template不同,比如我们使用的base模型,其预训练任务就是直白的输入+输出,没有什么user,assistant;再比如我们做代码补全任务,有自己的special token想要插入,prompt形如<fim-prefix>xxxx<fim-suffix>xxxx<fim-middle>

这种情况我比较推荐使用vanilla作为template的参数,没有做任何的添加,我们可以在构造数据集的阶段,就自由地把special tokens和我们的文本内容拼接好,一起放入prompt对应的键中。
在这里插入图片描述

注意不是default,default仍然会添加Human, Assistant
在这里插入图片描述

资源相关的参数

官方的readme里面给出的单卡、多卡训练示例已经很详细了,这里不多赘述。

多卡训练时,可能没接触过的人会有些疑惑ds_config.json写什么样子

deepspeed --num_gpus 8 src/train_bash.py \
    --deepspeed ds_config.json \
    --ddp_timeout 180000000 \
    ... # arguments (same as above)

这里对于初学者的实验阶段,我建议直接copy官方示例构造一个ds_config.json也可以,直接去掉这个参数也可以,先跑通,再根据实际需要回来调整,一步一步来,总会逐渐了解的。

比较常见的是使用实验室或者公司的公共服务器的场景,需要指定在哪些卡上面训练,添加include按如下形式指定即可:

deepspeed --num_gpus 2 --master_port=9901 --include localhost:2,3

其他的和多卡训练相关的参数,比如每张卡的batch size等,理解都比较直接,自行查阅。运行起来以后,根据每个batch差不多要的时间估算一下,再根据自己的耗时需求调整即可。

这里可能会遇到一个小问题:起训练任务起失败了以后,master_port显示被占用,不得不换一个port。这主要是因为任务挂了以后,程序清理不干净,比如网络方面还占着。建议你使用ps -ef | grep 你的用户名 去检查是不是还有和刚刚那个挂掉的任务相关的进程,kill掉即可

其他参数

实际上llama-factory有很多训练参数可以设置,并不局限于示例中给出的参数,你应该积极地去/src/llmtuner/hparams这里看看,以实现你的需求

这里举例几个我的需求:

  1. 我的数据集非常大,需要加速generating train split的过程。

    增加--num_workers 8

  2. 我希望从训练集里面切一个小验证集,在训练过程中每个epoch结束的时候,在验证集上eval一遍,并保存结果在log中

​ 增加--val_size 0.01 --evaluation_strategy epoch等,思路和transformers的trainer差不多

  1. 同时,训练完毕输出的loss曲线图片,我也需要这个验证集上的loss曲线

    设置--plot_loss参数

  2. greedy推理

    如果你去了我说的位置看,就会发现默认temperature=0.95, top_p=0.7,如果我希望unset,并设置do_sample=False,需要设置二者为1.0(transformers的default就是这样)
    在这里插入图片描述

另外一个最重要的事情是,学会在issue中搜索,会有很多同样的问题已经得到了解答

训练完毕后的推理

训练完毕后,我们要在自己的测试集或者公用测试集上测试模型的效果:

如果是公用测试集,比如MMLU之类的任务,你可以直接参考官方的readme,使用evaluation即可

如果是自己的测试集,我们往往需要测试这个测试集上的指标,同样的:

  1. 首先把你的测试集,像训练集一样构造成一个.json / .jsonl,并注册到dataset_info.json中
  2. 参考官方的demo predict的部分即可

如果不急于算指标,只是想要看看具体的case,官方也提供了命令行demo和浏览器的demo,照抄着改就可以了,应该不难。

我遇到的问题是,我的测试集特别大,使用predict非常慢,并且得到的结果只保存了label和结果,我希望快速推完测试集,并且保存好输入、输出、label,这样方便我后续自己的可视化。这种场景,最好是将模型部署成一个服务,自己写一个脚本去发请求,边发请求边把自己想要保存的样本保存到文件中。

关于模型部署,为了避免本文过长就暂时不多叙述,可以用以下方法:

  1. TGI部署:使用llama-factory的模型导出功能(如果你是lora微调的就会顺便merge weights),将导出的模型用TGI部署。

​ 比较推荐,TGI部署很简单,适配非常多种模型,所以很适合实验阶段使用

  1. vllm部署:vllm相比之下更适合生产环境,面向高并发的真实场景,如果涉及前后处理策略、时延等方面的测试,建议保持和线上一致
  2. llama-factory的api demo,其实应该差不多就是把上面的封装了,我没有使用过,但是思路是一样的,把模型变成一个服务,可以参考官方的demo学习使用。

总结

本文介绍了初学者如何使用llama-factory这个工具进行大模型的微调任务,包含用自己的数据构造训练、测试数据,起训练任务时候的相关参数,训练完毕后的测试集推理环节等,虽然可能不够全面,但多多少少以授人以渔的思路介绍了应该去源码的什么位置获得进一步的信息,以满足文中没有覆盖到的需求。

值得一提的是,因为作者的使用经验也有些时日,llama-factory也一直保持着更新,文中的一些内容可能有谬误,一切以源码为主。初学者应该养成看源码、看issue,自行找答案的能力。

如果有任何我理解有误的地方,还希望多多指正,感激不尽!下一个新技能点再见!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1640611.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux-管道通信

1. 管道概念 管道&#xff0c;是进程间通信的一种方式&#xff0c;在Linux命令中“ | ”就是一种管道&#xff0c;它可以&#xff0c;连接前一条命令&#xff0c;和后一条命令&#xff0c;把前面命令处理完的内容交给后面&#xff0c;例如 cat filename | grep hello …

富文本编辑器CKEditor4简单使用-07(处理浏览器不支持通过工具栏粘贴问题 和 首行缩进的问题)

富文本编辑器CKEditor4简单使用-07&#xff08;处理浏览器不支持通过工具栏粘贴问题 和 首行缩进的问题&#xff09; 1. 前言——CKEditor4快速入门2. 默认情况下的粘贴2.1 先看控制粘贴的3个按钮2.1.1 工具栏粘贴按钮2.1.2 存在的问题 2.2 不解决按钮问题的情况下2.2.1 使用ct…

三维图形学知识分享---求平面与模型相交线

在CGAL&#xff08;Computational Geometry Algorithms Library&#xff09;中&#xff0c;Polygon_mesh_processing模块提供了用于处理多边形网格数据结构的功能。其中&#xff0c;surface_intersection函数是用来计算模型的表面相交线的工具。 CGAL_Mesh mesh_orcl;std::vect…

C++ 函数 参数与返回值

#一 参数与返回值 回顾文件读数据功能 文件读数据 1函数参数传值调用过程 将函数调用语句中的实参的一份副本传给函数的型材。 简单的值的传递&#xff0c;实参的值没有发生变化。 2 函数参数传值调用过程 传地址调用 将变量的地址传递给函数的形参 形参和实参指向了同…

SpringBoot文件上传+拦截器

1、resource static下有个图片&#xff0c;希望浏览器可以查看这个图片 访问&#xff1a; 若yml设置路径&#xff0c;则可以定义在static下才可以访问 classpath代表类路径&#xff0c;都在target下 也就是项目在运行后的resource下的文件都会到classes下去 无需在target下创…

MES(制造执行系统)与PDCA循环,斩不断理还乱的关系。

MES系统算是B端系统中比较复杂的一种&#xff0c;这与我国制造业标准化程度较低有一定的关联&#xff0c;MES的存在就是要更好执行PDCA循环&#xff0c;二者关联是千丝万缕的&#xff0c;B系统提升专家借此为大家分享一下。 一、什么是PDCA PDCA&#xff08;Plan-Do-Check-Ac…

前端Web开发基础知识

HTML定义 超文本标记语言&#xff08;英语&#xff1a;HyperText Markup Language&#xff0c;简称&#xff1a;HTML&#xff09;是一种用于创建网页的标准标记语言。 什么是 HTML? HTML 是用来描述网页的一种语言。 HTML 指的是超文本标记语言: HyperText Markup LanguageH…

# IDEA 复制项目 Module 出现 不同模块下的 Product 类报错

IDEA 复制项目 Module 出现 不同模块下的 Product 类报错 我们 用 IDEA 复制项目 Module 出现 不同模块下的 Product 类报错&#xff0c;发现复制的 module 名称没有改变或者 java 文件夹后面还有原项目 source root 字样&#xff0c;maven 父子项目没有标识等问题。 解决方法…

QQ+微信聊天记录分析工具,allin~

QQ群 ... QQ个人 微信群 个人朋友圈 更多维度有待探索~ 工具下载 TencentRecordAnalysisV1.0.2.zip 蓝奏云&#xff1a;链接: lanzoub.com/b00rn0g47e 密码:9hww 百度云&#xff1a;链接: pan.baidu.com/s/1Gf5EpJ 提取码: hp2p

Stm32CubeMX 为 stm32mp135d 添加 adc

Stm32CubeMX 为 stm32mp135d 添加 adc 一、启用设备1. adc 设备添加2. adc 引脚配置2. adc 时钟配置 二、 生成代码1. optee 配置 adc 时钟和安全验证2. linux adc 设备 dts 配置 bringup 可参考&#xff1a; Stm32CubeMX 生成设备树 一、启用设备 1. adc 设备添加 启用adc设…

R语言学习—1—将数据框中某一列数据改成行名

将数据框中某一列数据改成行名 代码 结果

DHCPv4_CLIENT_ALLOCATING_03: 发送DHCPREQUEST - 必须包含‘服务器标识符‘

测试目的&#xff1a; 验证客户端发送的DHCPREQUEST消息中是否包含“服务器标识符”选项&#xff0c;以指示它选择的服务器。 描述&#xff1a; 本测试用例旨在确保DHCP客户端在广播DHCPREQUEST消息时&#xff0c;必须包含“服务器标识符”选项。该选项用于指明客户端选择了…

Universal Thresholdizer:将多种密码学原语门限化

参考文献&#xff1a; [LS90] Lapidot D, Shamir A. Publicly verifiable non-interactive zero-knowledge proofs[C]//Advances in Cryptology-CRYPTO’90: Proceedings 10. Springer Berlin Heidelberg, 1991: 353-365.[Shoup00] Shoup V. Practical threshold signatures[C…

[嵌入式系统-53]:嵌入式系统集成开发环境大全 ( IAR Embedded Workbench(通用)、MDK(ARM)比较 )

目录 一、嵌入式系统集成开发环境分类 二、由MCU芯片厂家提供的集成开发工具 三、由嵌入式操作提供的集成开发工具 四、由第三方工具厂家提供的集成开发工具 五、开发工具的整合 5.1 Keil MDK for ARM 5.2 IAR Embedded Workbench&#xff08;通用&#xff09;、MDK&…

240503-关于VisualStudio2022社区版的二三事

240503-关于VisualStudio2022社区版的二三事 1 常用快捷键 快捷键描述AltEnter选中代码片段以提取方法Alt上下箭头移动选中的代码片段F12转到方法定义CtrlR*2批量修改选中的变量名称 2 自动生成构造函数 3 快速重写父类方法 4 节约时间&#xff1a;写代码使用“头插法”&…

深度解析 Spring 源码:从BeanDefinition源码探索Bean的本质

文章目录 一、BeanDefinition 的概述1.1 BeanDefinition 的定位1.2 BeanDefition 的作用 二、BeanDefinition 源码解读2.1 BeanDefinition 接口的主要方法2.2 BeanDefinition 的实现类2.2.1 实现类的区别2.2.2 setBeanClassName()2.2.3 getDependsOn()2.2.4 setScope() 2.3 Bea…

用双目相机实现坐标标定

一&#xff1a;相机参数设置和计算 镜头参数&#xff1a;MF2808-10MP 靶面尺寸2/3 &#xff0c;视场角&#xff08;对角水平垂直&#xff09; 69.758.545.5 焦距&#xff1a;8mm&#xff0c;分辨率&#xff1a;16241240 1.1视场角的计算 图像分辨率越高&#xff0c;双目匹…

FP16、BF16、INT8、INT4精度模型加载所需显存以及硬件适配的分析

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

Arduino 推出带 Wi-Fi的 32 位 UNO 板

Arduino 推出了下一代 UNO 板&#xff0c;引入了 32 位 Renesas 微控制器和 Espressif ESP32-S3 模块、一键云连接和大量 I/O 以及 128 红色 LED 矩阵。新型 UNO R4 板有两个版本&#xff0c;带 Wi-Fi 连接和不带 Wi-Fi 连接&#xff0c;并保持了 UNO R3 的外形尺寸、屏蔽兼容性…

分布式事务—> seata

分布式事务之Seata 一、什么是分布式事务&#xff1f; 分布式事务是一种特殊类型的事务&#xff0c;它涉及多个分布式系统中的节点&#xff0c;包括事务的参与者、支持事务的服务器、资源服务器以及事务管理器。 在分布式事务中&#xff0c;一次大型操作通常由多个小操作组成…