LLM - 大模型速递之 Yi-34B 入门与 LoRA 微调

news2024/11/26 21:38:22

一.引言

目前国内大部分开源模型都集中在 7B、13B,而国外开源模型则是集中在 7B、13B、70B 的尺寸范围,算法开发很需要一个介于 13B-70B 的大模型,弥补 13B 模型能力不足和 70B 模型显卡不够的空档。虽然 LLaMA-1-33B 有一些衍生的 Chinese 版本,但是 LLaMA2 后期并未更新维护该模型,作者在测试中发现 LLaMA-1-33B 能力与新版的 Baichuan-2-13B 相近,所以放弃了这款 33B 模型。11 月零一万物正式开源发布首款预训练大模型 Yi-34B,今天也顺便分享下 Yi-34B 模型以及其 LoRA 微调,有需要的同学欢迎评论区交流讨论~

二.零一万物

1.模型简介

模型地址: https://huggingface.co/01-ai/Yi-34B-Chat

此次发布包含两个基于先前发布的基本模型的聊天模型,两个由 GPTQ 量化的8位模型,两种由 AWQ 量化的 4 位模型:

大家可以在 Hugging-Face 官网下载模型,这里我们使用 Yi-34B-Chat 模型。

2.模型评估

◆ Base 模型表现

◆ Chat 模型表现

除此之外还有量化的模型对比,整体来说,国内开源的网站在 Model Performance 上一般都是 SOTA 的,不过表现好坏还是得实际下下来测测看,后面我们也会把模型拿下来看下怎么事。

3.模型测试

为了使用该模型,建议更新 Transformer 版本 >= 4.36.0: 

from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = '01-ai/Yi-34b-Chat'

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Since transformers 4.35.0, the GPT-Q/AWQ model can be loaded using AutoModelForCausalLM.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype='auto'
).eval()

# Prompt content: "hi"
messages = [
    {"role": "user", "content": "hi"}
]

input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
output_ids = model.generate(input_ids.to('cuda'))
response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

# Model response: "Hello! How can I assist you today?"
print(response)

其对应的 tmplate 模板如下,可以在 tokenizer_config.json 文件在找到:

<|im_start|>system
{system_message}<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant

三.LoRA 微调

1.微调工具

GitHub: https://github.com/hiyouga/LLaMA-Factory

◆ 支持场景

微调我们选择 LLaMA-Factory 框架,之前介绍的 Baichuan、ChatGLM 微调也是基于该框架实现 LoRA 微调。目前框架已支持 Full-Parameter、Partial-Parameter、LoRA 和 QLoRA 以及 PT、SFT、RM、PPO 、DPO 的全套流程:

◆ 硬件要求

这里我们 LoRA 微调 Yi-34B,需要 80 GB,正好对应单卡 A800,如果使用 P40-24G 需要 4 台,A100-32G 需要 3 台:

◆ 环境配置 

LLaMA-Factory 需要上述依赖,下载对应代码后,创建 Python 环境安装 requirements 即可。 

git clone https://github.com/hiyouga/LLaMA-Factory.git
conda create -n llama_factory python=3.10
conda activate llama_factory
cd LLaMA-Factory
pip install -r requirements.txt

如果 pip install 比较卡顿,可以尝试切换 pip 源提高安装速度:

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

 

2.微调代码

◆ 运行脚本

CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --output_dir path_to_sft_checkpoint \
    --overwrite_cache \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --lr_scheduler_type cosine \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate 5e-5 \
    --num_train_epochs 3.0 \
    --plot_loss \
    --fp16

这里是对应 train_sft.sh 的内容,具体执行时还需要和你的 GPU 环境所在机器匹配,有的是直接绑定 GPU 的实体机、也有 Docker,在对应环境执行上述脚本即可。

Tips:

训练过程中可能出现 torch.cuda.OutOfMemoryError: CUDA out of memory. 的错误,此时需要调小 batch_size,或者将 fp32 修改为 fp16;修改后依然报错 OOM 则需要开启 QLoRA 量化处理:

--quantization_bit 4/8 \

3.微调参数

◆ 参数解析 

model_name_or_path - 指向对应开源模型的地址

dataset - 指向训练数据标识,这里要求数据格式为 json,并且配置在 data/dataset_info.json 内

template - 指向模型对应模板

lora_target - 用于指定需要 LoRA 微调的 Layer Name

output_dir - 模型微调后的存储地址

per_device_train_batch_size - 每个设备的训练 batch_size

gradient_accumulation_steps - 梯度累计更新的 step

save_steps - 存储 checkpoint 的 step 数

num_train_epochs - 训练的 Epoch 数量

支持模型

Default module 对应 lora_target 参数,用来指定 LoRA 微调的模型 Layer,Template 对应 template 模板参数,主要适配模型原始模板,避免模型训练和输出异常。

Tips:

这里没有给出 Yi-34B 的信息,可以在源码中找到,这里直接给出:

lora_target='k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj'
template='yi'

从这个 lora_target 不难看出 Yi-34B 框架师承 LLaMA-2,LLaMA-2 框架有需要可以参考:

LLM - Transformer && LLaMA2 结构分析与 LoRA 详解

4.微调流程

构建数据集

从 dataset 参数对应的训练集地址读取对应地址的 json 训练数据,关于数据处理可以参考:

LLM - 数据处理之 Process Dataset For LLM With PT、SFT、RM

模型配置

从 architectures 也再次印证 Yi-34B 师承 LLaMA-2 了,剩下一些之前分析过的参数,例如 Silu 激活函数、7168 的 Hidden Size、4096 的 max_position_embeddings 以及 vocab_size 64000 的词库。按照 requirements.txt 的要求,这里需要 transformer 的版本为 >= 4.36.0。

模型读取

模型读取一共耗时 10min+,30B+ 的模型读起来还是比 13B 慢很多:

Tokenizer 数据

训练前需要将对应的训练数据使用 Tokenizer.model 进行 Token 化,转换为 TokenIds 传递给后续的 Transformers 使用。

模型训练

logging_steps 参数控制打印的频率,出现下述日志以及对应的训练信息且 Loss 正常降低代表训练正常,如果没训练多久 Loss 突降为 0.0 大概率为训练数据有问题,可以人工查看下有无异常。

save_steps 参数控制 checkpoint 的保存频率,在对应 output 目录下可以查看训练存储的多个 CKPT,下面为一个 CKPT 存储的信息:

Tips:

训练完毕后可以加载对应 LoRA Weights 进行后续的预测推理工作,可以参考:

LLM - LoRA 模型合并与保存

显存占用

训练我们使用单张 A-800 执行,使用 --fp16 精度,batch_size 取 4 时会出现 OOM,修改为 batch_size=1 后训练正常,此处显存占用大约为 72G+,如果使用多卡可以使用 accelerate 加载多卡配置进行训练。

四.总结

上面介绍了国产开源的 Yi-34B 以及其 LoRA 微调训练的流程,按照 Hugging Face 上的评测,其能力已经直逼 GPT-4,光说不练假把式,后期博主也会实际测试下相同问题二者的回复效果。除此之外,最近新出的 MOE Mistral-8x7B 也大放异彩,后续博主也会分享其训练流程,其参考了深度学习里 MOE 专家模型的特性,同时使用 8 个 7B 模型进行训练推理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1319178.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Golang清晰代码指南

发挥易读和易维护软件的好处 - 第一部分 嗨&#xff0c;开发者们&#xff0c;清晰的代码是指编写易于阅读、理解和维护的软件代码。它是遵循一组原则和实践&#xff0c;优先考虑清晰性、简单性和一致性的代码。清晰的代码旨在使代码库更易管理&#xff0c;减少引入错误的可能性…

Go语言并发编程:死锁预防的性能优化之旅

文章目录 引言:Go并发编程的挑战与机遇Go并发的特点并发编程的挑战死锁对性能的影响文章概览死锁基础:原因、类型和识别死锁的定义死锁产生的原因死锁的类型识别死锁的方法代码示例:简单的死锁3. 预防策略:编写无死锁的Go代码理解并正确使用锁合理使用通道和goroutines侦测…

C# 命令行参数解析库示例

写在前面 在日常开发中&#xff0c;我们经常会用到命令行参数&#xff0c;比如cmd下的各种指令&#xff1b;还有C#的控制台类型的项目&#xff0c;在默认入口Main函数中&#xff0c;那个args参数&#xff0c;就是有系统传入到程序进程的命令行参数&#xff1b;在传入的参数相对…

✺ch3——数学基础

目录 3D坐标系和点矩阵单位矩阵转置矩阵逆矩阵逆转置矩阵矩阵的运算矩阵加法()矩阵乘法() 常用的变换矩阵平移矩阵缩放矩阵旋转矩阵透视矩阵正射投影矩阵LookAt矩阵 向量加法和减法点积叉积 局部空间和世界空间——模型矩阵M视觉空间和合成相机——模型-视图矩阵MV用GLSL函数构…

机器学习算法---异常检测

类别内容导航机器学习机器学习算法应用场景与评价指标机器学习算法—分类机器学习算法—回归机器学习算法—聚类机器学习算法—异常检测机器学习算法—时间序列数据可视化数据可视化—折线图数据可视化—箱线图数据可视化—柱状图数据可视化—饼图、环形图、雷达图统计学检验箱…

RTOS队列的写入与读出

我们在stm32f103c8t6单片机上验证RTOS队列的写入与读出&#xff0c;利用stm32cube进行RTOS的配置。在选择TIM2当做RTOS的时钟&#xff0c;裸机的时钟源默认是 SysTick&#xff0c;但是开启 FreeRTOS 后&#xff0c;FreeRTOS会占用 SysTick &#xff08;用来生成1ms 定时&#x…

flask简单应用-1

目标&#xff1a; 做一个搜索网页&#xff0c;搜索当前路径下是否含有指定关键字的文件&#xff0c;如果有就列出来&#xff0c;没有返回消息 第一步&#xff1a;我们需要先显示一个搜索页面&#xff0c;页面上需要有一个可以输入的对话框&#xff0c;一个按钮执行搜索 建立ht…

Vue3-05-计算属性使用详解

计算属性简介 计算属性的函数是 computed()。计算属性可以帮助我们处理有复杂逻辑的响应式数据的渲染&#xff0c; 从而代替 模板表达式 的写法。比如 &#xff1a; 一个数值类型的数组对象&#xff0c;我们希望页面展示的只有 偶数。 此时&#xff0c;就可以通过 计算属性 来…

02.Git常用基本操作

一、基本配置 &#xff08;1&#xff09;打开Git Bash &#xff08;2&#xff09;配置姓名和邮箱 git config --global user.name "Your Name" git config --global user.email "Your email" 因为Git是分布式版本控制工具&#xff0c;所以每个用户都需要…

手拉手EasyExcel极简实现web上传下载(全栈)

环境介绍 技术栈 springbootmybatis-plusmysqleasyexcel 软件 版本 mysql 8 IDEA IntelliJ IDEA 2022.2.1 JDK 1.8 Spring Boot 2.7.13 mybatis-plus 3.5.3.2 EasyExcel是一个基于Java的、快速、简洁、解决大文件内存溢出的Excel处理工具。 他能让你在不用考虑性…

【PostgreSQL】从零开始:(十三)PostgreSQL-SQL语句操作架构(模式) Schema

Schema概述 PostgreSQL 数据库集群包含一个或多个命名数据库。角色和一些其他对象类型在整个集群中共享。与服务器的客户端连接只能访问单个数据库中的数据&#xff0c;该数据库在连接请求中指定。 用户不一定有权访问集群中的每个数据库。共享角色名称意味着不能在同一集群中…

IDEA2023 + spring cloud 工程热部署设置方法

基于spring cloud 工程进行热部署 &#xff0c;实现每次修改工程源文件&#xff0c;后台自动启动&#xff0c;方便开发测试工作。具体分为5步骤即可&#xff1a; 1、修改工程的pom文件&#xff0c;增加adding devtools 工具包。 <dependency> <groupId>org.s…

1264. 动态求连续区间和(树状数组---某个位置加上一个数/求在线(动态)前缀和/蓝桥杯)

题目&#xff1a; 输入样例&#xff1a; 10 5 1 2 3 4 5 6 7 8 9 10 1 1 5 0 1 3 0 4 8 1 7 5 0 4 8输出样例&#xff1a; 11 30 35 树状数组&#xff1a; 代码&#xff1a; #include<cstdio> #include<iostream> using namespace std;const int N100010; int n,…

【elementui笔记:el-table表格的输入校验】

之前做得比较多的校验是在el-form表单里做的&#xff0c;但有时也遇到&#xff0c;需要在table内输入数据&#xff0c;然后校验输入的数据是否符合要求的情况。因此记录一下。 思路&#xff1a; 1.需要借助el-form的校验&#xff0c;el-table外层嵌套一层el-form&#xff0c;使…

数据分析为何要学统计学(4)——何为置信区间?它有什么作用?

置信区间是统计学中的一个重要工具&#xff0c;是用样本参数()估计出来的总体均值在某置信水平下的范围。通俗一点讲&#xff0c;如果置信度为95%&#xff08;等价于显著水平a0.05&#xff09;&#xff0c;置信区间为[a,b]&#xff0c;这就意味着总体均值落入该区间的概率为95%…

宏基因组学Metagenome-磷循环Pcycle功能基因分析-从分析过程到代码及结果演示-超详细保姆级流程

大背景介绍 生信分析,凡事先看论文,有了论文就有了参考,后续分析就有底了,直接上硬菜开干: PCycDB: a comprehensive and accurate database for fast analysis of phosphorus cycling genes - PubMed 数据库及部分分析代码github库: GitHub - ZengJiaxiong/Phospho…

7.实现任务的rebalance

1.设计 1.1 背景 系统启动后&#xff0c;所有任务都在被执行&#xff0c;如果这时某个节点宕机&#xff0c;那它负责的任务就不能执行了&#xff0c;这对有稳定性要求的任务是不能接受的&#xff0c;所以系统要实现rebalance的功能。 1.2 设计 下面是Job分配与执行的业务点…

深度学习中的潜在空间

1 潜在空间定义 Latent Space 潜在空间&#xff1a;Latent &#xff0c;这个词的语义是“隐藏”的意思。“Latent Space 潜在空间”也可以理解为“隐藏的空间”。Latent Space 这一概念是十分重要的&#xff0c;它在“深度学习”领域中处于核心地位&#xff0c;即它是用来学习…

【每日一题】寻找峰值

文章目录 Tag题目来源解题思路方法一&#xff1a;二分查找 写在最后 Tag 【二分查找】【数组】【2023-12-18】 题目来源 162. 寻找峰值 解题思路 方法一&#xff1a;二分查找 思路 进行二分查找&#xff0c;记当前的二分中点为 mid&#xff1a; 如果 nums[mid] < nums…

UE4 去除重复纹理

如果直接连的话&#xff0c;效果如下&#xff1a; 就存在很多重复的纹理&#xff0c;如何解决这个问题呢&#xff1f; 将同一个纹理&#xff0c;用不同的Tilling&#xff0c;将Noise进行Lerp两者之间&#xff0c;为什么要这么做呢&#xff1f;因为用一个做清晰纹理&#xff0c;…