[NLP] LLM---<训练中文LLama2(四)方式一>对LLama2进行SFT微调

news2025/1/11 22:58:12

指令精调

指令精调阶段的任务形式基本与Stanford Alpaca相同。训练方案也采用了LoRA进行高效精调,并进一步增加了可训练参数数量。在prompt设计上,精调以及预测时采用的都是原版Stanford Alpaca不带input的模版。对于包含input字段的数据,采用f"{instruction}+\n+{input}"的形式进行拼接。

其中,Stanford Alpaca 格式如下所示:

[
  {"instruction" : ...,
   "input" : ...,
   "output" : ...},
  ...
]

首先,修改模型精调脚本run_sft.sh,需要修改的参数如下:

  • --model_name_or_path: 模型经过词表扩充并完成预训练进行权重合并之后所在的目录
  • --tokenizer_name_or_path: Chinese-Alpaca tokenizer 所在的目录
  • --dataset_dir: 指令精调数据的目录,包含一个或多个以json结尾的Stanford Alpaca格式的指令精调数据文件
  • --validation_file: 用作验证集的单个指令精调文件,以json结尾,同样遵循Stanford Alpaca格式
  • --output_dir: 模型权重输出路径
dataset_dir=./sft_dataset/train = Chinese-LLaMA-Alpaca/data

其他参数(如:per_device_train_batch_size、training_steps等)是否修改视自身情况而定。

# 运行脚本前请仔细阅读wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh)
# Read the wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh) carefully before running the script
lr=1e-4
lora_rank=64
lora_alpha=128
lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
modules_to_save="embed_tokens,lm_head"
lora_dropout=0.05

pretrained_model=./merged_output_dir
chinese_tokenizer_path=./merged_output_dir
dataset_dir=./sft_dataset/train
per_device_train_batch_size=1
per_device_eval_batch_size=1
gradient_accumulation_steps=8
max_seq_length=512
output_dir=./sft_output_dir
validation_file=./sft_dataset/test/test.json

deepspeed_config_file=ds_zero2_no_offload.json

torchrun --nnodes 1 --nproc_per_node 1 run_clm_sft_with_peft.py \
    --deepspeed ${deepspeed_config_file} \
    --model_name_or_path ${pretrained_model} \
    --tokenizer_name_or_path ${chinese_tokenizer_path} \
    --dataset_dir ${dataset_dir} \
    --per_device_train_batch_size ${per_device_train_batch_size} \
    --per_device_eval_batch_size ${per_device_eval_batch_size} \
    --do_train \
    --do_eval \
    --seed $RANDOM \
    --fp16 \
    --num_train_epochs 1 \
    --lr_scheduler_type cosine \
    --learning_rate ${lr} \
    --warmup_ratio 0.03 \
    --weight_decay 0 \
    --logging_strategy steps \
    --logging_steps 10 \
    --save_strategy steps \
    --save_total_limit 3 \
    --evaluation_strategy steps \
    --eval_steps 100 \
    --save_steps 200 \
    --gradient_accumulation_steps ${gradient_accumulation_steps} \
    --preprocessing_num_workers 8 \
    --max_seq_length ${max_seq_length} \
    --output_dir ${output_dir} \
    --overwrite_output_dir \
    --ddp_timeout 30000 \
    --logging_first_step True \
    --lora_rank ${lora_rank} \
    --lora_alpha ${lora_alpha} \
    --trainable ${lora_trainable} \
    --lora_dropout ${lora_dropout} \
    --modules_to_save ${modules_to_save} \
    --torch_dtype float16 \
    --validation_file ${validation_file} \
    --load_in_kbits 16 \
    --gradient_checkpointing \
    --ddp_find_unused_parameters False

run_clm_sft_with_peft.py  添加如下两行:

为了测试,对数据进行了sample

# coding=utf-8
import json

with open("alpaca_data_zh_51k.json", encoding="UTF-8") as f:
    data = json.load(f)
    print(len(data))
print(data[0])

import random

# 设置要划分的测试集大小
sample_size = int(0.1 * (len(data)))

# 随机选择测试集的元素
sample_set = random.sample(data, sample_size)

data = sample_set
# 设置要划分的测试集大小
test_size = int(0.1 * (len(data)))

# 随机选择测试集的元素
test_set = random.sample(data, test_size)

# 构建训练集,即剩下的元素
train_set = [x for x in data if x not in test_set]

print("训练集:", len(train_set))
print("测试集:", len(test_set))

with open("train/train.json", "w", encoding="UTF-8") as f:
    json.dump(train_set, f, indent=2, ensure_ascii=False)

with open("valid/test.json", "w", encoding="UTF-8") as f:
    json.dump(test_set, f, indent=2, ensure_ascii=False)

运行后输出:

中文LLaMA&Alpaca大语言模型词表扩充+预训练+指令精调 - 知乎 (zhihu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1017153.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

竞赛 基于机器学习与大数据的糖尿病预测

文章目录 1 前言1 课题背景2 数据导入处理3 数据可视化分析4 特征选择4.1 通过相关性进行筛选4.2 多重共线性4.3 RFE(递归特征消除法)4.4 正则化 5 机器学习模型建立与评价5.1 评价方式的选择5.2 模型的建立与评价5.3 模型参数调优5.4 将调参过后的模型重…

yolov5在rk3588上加速

不采用fastdeploy等三方框架,使用rknn-lite2或者rknpu在rk3588上加速,测试加速的是rknn自带的yolov5模型。 备注: 1.测试视频:多人,帧:3000,时长:2min,分辨率:1920x1080,fps:25 2…

傅里叶变换应用 (02/2):频域和相位

一、说明 到目前为止,在我们的讨论中,我已经交替使用了“傅里叶变换”和“快速傅里叶变换(FFT)”。在这一点上,值得注意的是区别!FFT 是“离散”傅里叶变换 (DFT) 的有效算法实现。“…

Remix+Cloudflare Pages+D1 快速上手

我们最近听到越来越多的关于Cloudflare的服务。 我对Clouflare D1特别感兴趣,所以我决定研究一下。 与这次我想使用的 Remix 一起,我想介绍 Remix Cloudflare Pages D1 的第一步。 我只是稍微地了解一下,但我所做的在下面的仓库中&#…

【深度学习】 Python 和 NumPy 系列教程(十二):NumPy详解:4、数组广播;5、排序操作

目录 一、前言 二、实验环境 三、NumPy 0、多维数组对象(ndarray) 多维数组的属性 1、创建数组 2、数组操作 3、数组数学 4、数组广播 5、排序操作 1. np.sort() 函数 2. np.argsort() 函数 3. ndarray.sort() 方法 4. 按列或行排序 5. n…

VHDL菜鸟入门到精通之激励文件编写

目录 一、概览 二、激励文件结构 三、样例 3.1 组合逻辑 3.2 时序逻辑 四、常用编写 4.1 时钟信号 4.2 延时 4.3 循环 4.4 进程 一、概览 二、激励文件结构 VHDL激励文件结构和设计文件较为类似,下面以3-8译码器的激励文件对结构进行说明。 激励文件主要…

git clone报错Failed to connect to github.com port 443 after 21055 ms:

git 设置代理端口号 git config --global http.proxy http://127.0.0.1:10085 和 git config --global https.proxy http://127.0.0.1:10085 然后就可以成功git clone hugging face的数据集了 如果是https://huggingface.co/datasets/shibing624/medical/tree/main 那么…

logstash通过kafka通道采集日志信息

1.修改文件/opt/app/elk/logstash-7.5.1/config.d/config1.conf,在input下添加kafka采集配置 #192.168.128.130:9103:kafka地址 #topics:主题 kafka {bootstrap_servers > ["192.168.128.130:9103"]group_id > "logstash"topics > [&…

Optuna学习博客

介绍 简单来说,OPtuna就是一个能够进行调整超参数的框架,它能够将自动调整超参数以及能够将超参数优化过程可视化,方便保存,分析。可拓展性较强。 使用方法 optuna的优化程序具体有三个组成部分。 objective(目标函…

MySQL数据库管理及数据库基本操作

目录 1 MySQL数据库基本操作 1.1 SQL分类 1.2 SQL语言规范 1.3 数据库对象和命名 1.4 SQL语句分类 2 管理MySQL数据库 2.1 查看数据库结构 2.1.1 查看当前服务器中的数据库 2.1.2 查看数据库中包含的表 2.1.3 查看表的结构(字段) 2.2 数据类型…

【linux】进程创建,进程终止

进程创建,进程终止 1.进程创建1.1写时拷贝1.2fork常规用法1.3fork调用失败的原因 2.进程终止2.1退出码2.2进程退出场景2.3进程如何退出 1.进程创建 在前面创建子进程的时候就学过了fork函数,它能从已经存在进程中创建一个新进程,新进程为子进…

Python 图形化界面基础篇:打开和关闭新窗口

Python 图形化界面基础篇:打开和关闭新窗口 引言 Tkinter 库简介步骤1:导入 Tkinter 模块步骤2:创建 Tkinter 窗口步骤3:创建一个新窗口步骤4:关闭新窗口步骤5:启动 Tkinter 主事件循环 完整示例代码代码解…

C语言指针详解(4)———找工作必看指针笔试题汇总

指针对于编程工作的重要性 C语言指针在找工作中具有重要性。以下是几个原因: 1.高效的内存管理:C语言指针可以帮助程序员高效地管理内存,包括动态内存分配和释放,以及数据的访问和操作。这对于开发性能优化的应用程序非常重要&am…

7.代理模式

1.UML 2.代码 #include <iostream> using namespace std;class Subject{ public:virtual void Request() 0; };class RealSubject:public Subject { public:virtual void Request(){cout << "RealSubject" << endl;} }; class Proxy:public Subj…

VUE build:gulp打包:测试、正式环境

目录 项目结构 Gulp VUE使用Gulp Vue安装Gulp Vue定义Gulp.js package.json build文件夹 config文件夹 static-config文件夹 项目结构 Gulp Gulp是一个自动化构建工具&#xff0c;可以帮助前端开发者通过自动化任务来管理工作流程。Gulp使用Node.js的代码编写&#xff…

go初识iris框架(五) -MVC包的使用

在Iris框架中&#xff0c;封装了mvc包作为对mvc架构的支持&#xff0c;方便开发者遵循mvc的开发原则进行开发。 iis框架支持请求数据、模型、持久数据分层处理&#xff0c;并支持各层级模块代码绑定执行。 MVC即&#xff1a;model、view、controller三个部分&#xff0c;分别代…

【微信小程序】swiper的使用

1.swiper的基本使用 <jxz-header></jxz-header> <view class"banner"><swiperprevious-margin"30rpx"autoplayinterval"2000"indicator-dotsindicator-color"rgba(0,0,0,0.3)"indicator-active-color"#bda…

数字化管理平台建设实践

在勘察设计行业&#xff0c;各企业加速推进数字化转型。通过管理要素数字化&#xff0c;不断优化内部组织运营效率&#xff1b;通过生产手段数字化、技术产品数字化&#xff0c;提升服务质量&#xff0c;改善客户体验&#xff1b;通过数字化营销&#xff0c;精准对接市场需求&a…

Linux下的系统编程——信号(十一)

前言&#xff1a; 信号在我们的生活中随处可见&#xff0c; 如&#xff1a;古代战争中摔杯为号&#xff1b;现代战争中的信号弹&#xff1b;体育比赛中使用的信号枪...... 他们都有共性&#xff0c;信号是信息的载体&#xff0c;Linux/UNIX 环境下&#xff0c;古老、经典的通信…

基于Java+SpringBoot+Vue的图书借还小程序的设计与实现(亮点:多角色、点赞评论、借书还书、在线支付)

图书借还管理小程序 一、前言二、我的优势2.1 自己的网站2.2 自己的小程序&#xff08;小蔡coding&#xff09;2.3 有保障的售后2.4 福利 三、开发环境与技术3.1 MySQL数据库3.2 Vue前端技术3.3 Spring Boot框架3.4 微信小程序 四、功能设计4.1 主要功能描述 五、系统实现5.1 小…