ChatGLM系列五:Lora微调

news2024/9/29 3:35:32

目前主流对大模型进行微调方法有三种:Freeze方法、P-Tuning方法和Lora方法

LoRA: 在大型语言模型上对指定参数(权重矩阵)并行增加额外的低秩矩阵,并在模型训练过程中,仅训练额外增加的并行低秩矩阵的参数,冻结其他参数。 当“秩值”远小于原始参数维度时,新增的低秩矩阵参数量也就很小。在下游任务tuning时,仅须训练很小的参数,但能获取较好的表现结果。
在这里插入图片描述
LoRA: 在大型语言模型上对指定参数(权重矩阵)并行增加额外的低秩矩阵,并在模型训练过程中,仅训练额外增加的并行低秩矩阵的参数,冻结其他参数。 当“秩值”远小于原始参数维度时,新增的低秩矩阵参数量也就很小。在下游任务tuning时,仅须训练很小的参数,但能获取较好的表现结果。
在这里插入图片描述

下载代码

git clone https://github.com/liucongg/ChatGLM-Finetuning

环境配置

cpm_kernels==1.0.11
deepspeed==0.9.0
numpy==1.24.2
peft==0.3.0
sentencepiece==0.1.96
tensorboard==2.11.0
tensorflow==2.13.0
torch==1.13.1+cu116
tqdm==4.64.1
transformers==4.27.1

(1)、ChatGLM单卡训练

CUDA_VISIBLE_DEVICES=0 deepspeed --master_port 520 train.py \
              --train_path data/spo_0.json \
              --model_name_or_path ChatGLM-6B \
              --per_device_train_batch_size 1 \
              --max_len 1560 \
              --max_src_len 1024 \
              --learning_rate 1e-4 \
              --weight_decay 0.1 \
              --num_train_epochs 2 \
              --gradient_accumulation_steps 4 \
              --warmup_ratio 0.1 \
              --mode glm \
              --train_type lora \
              --lora_dim 16 \
              --lora_alpha 64 \
              --lora_dropout 0.1 \
              --lora_module_name "query_key_value" \
              --seed 1234 \
              --ds_file ds_zero2_no_offload.json \
              --gradient_checkpointing \
              --show_loss_step 10 \
              --output_dir ./output-glm

(2)、ChatGLM四卡训练

CUDA_VISIBLE_DEVICES=0,1,2,3 deepspeed --master_port 520 train.py \
              --train_path data/spo_0.json \
              --model_name_or_path ChatGLM-6B \
              --per_device_train_batch_size 1 \
              --max_len 1560 \
              --max_src_len 1024 \
              --learning_rate 1e-4 \
              --weight_decay 0.1 \
              --num_train_epochs 2 \
              --gradient_accumulation_steps 4 \
              --warmup_ratio 0.1 \
              --mode glm \
              --train_type lora \
              --lora_dim 16 \
              --lora_alpha 64 \
              --lora_dropout 0.1 \
              --lora_module_name "query_key_value" \
              --seed 1234 \
              --ds_file ds_zero2_no_offload.json \
              --gradient_checkpointing \
              --show_loss_step 10 \
              --output_dir ./output-glm

(3)、ChatGLM2单卡训练

CUDA_VISIBLE_DEVICES=0 deepspeed --master_port 520 train.py \
              --train_path data/spo_0.json \
              --model_name_or_path ChatGLM2-6B \
              --per_device_train_batch_size 1 \
              --max_len 1560 \
              --max_src_len 1024 \
              --learning_rate 1e-4 \
              --weight_decay 0.1 \
              --num_train_epochs 2 \
              --gradient_accumulation_steps 4 \
              --warmup_ratio 0.1 \
              --mode glm2 \
              --train_type lora \
              --lora_dim 16 \
              --lora_alpha 64 \
              --lora_dropout 0.1 \
              --lora_module_name "query_key_value,dense_h_to_4h,dense_4h_to_h,dense" \
              --seed 1234 \
              --ds_file ds_zero2_no_offload.json \
              --gradient_checkpointing \
              --show_loss_step 10 \
              --output_dir ./output-glm2

(4)、ChatGLM2四卡训练

CUDA_VISIBLE_DEVICES=0,1,2,3 deepspeed --master_port 520 train.py \
              --train_path data/spo_0.json \
              --model_name_or_path ChatGLM2-6B \
              --per_device_train_batch_size 1 \
              --max_len 1560 \
              --max_src_len 1024 \
              --learning_rate 1e-4 \
              --weight_decay 0.1 \
              --num_train_epochs 2 \
              --gradient_accumulation_steps 4 \
              --warmup_ratio 0.1 \
              --mode glm2 \
              --train_type lora \
              --lora_dim 16 \
              --lora_alpha 64 \
              --lora_dropout 0.1 \
              --lora_module_name "query_key_value,dense_h_to_4h,dense_4h_to_h,dense" \
              --seed 1234 \
              --ds_file ds_zero2_no_offload.json \
              --gradient_checkpointing \
              --show_loss_step 10 \
              --output_dir ./output-glm2

(5)、耗费显存资源占用对比—LoRA方法:对比ChaGLM和ChaGLM2

在这里插入图片描述注意:Lora方法在模型保存时仅保存了Lora训练参数,因此在模型预测时需要将模型参数进行合并,具体参考merge_lora.py。

四种微调资源耗费比较

在这里插入图片描述

结果分析:

  • 效果为PT>Freeze>Lora>PT-Only-Embedding;
  • 速度为PT-Only-Embedding>Lora>Freeze>PT;
  • PT-Only-Embedding效果很不理想,发现在训练时,最后的loss仅能收敛到2.几,而其他机制可以收敛到0.几。分析原因为,输出内容形式与原有语言模型任务相差很大,仅增加额外Embedding参数,不足以改变复杂的下游任务;
  • 由于大模型微调都采用大量instruction进行模型训练,仅采用单一的指令进行微调时,对原来其他的指令影响不大,因此并没导致原来模型的能力丧失;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1140507.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java中级面试题记录(四)

一面面试题 1.Innodb的行数据存储模式 https://baijiahao.baidu.com/s?id1775090633458928876&wfrspider&forpc 2.行数据包含哪些信息? https://baijiahao.baidu.com/s?id1775090633458928876&wfrspider&forpc 3.MySQL在进行存储VARCHAR的时…

qq怎么发长视频?超级好用!

在平时的工作和生活中,我们会想分享一些比较长的内容。但是我们会发现视频文件过大,可能会超过腾讯规定的单次发送文件的大小限制,导致无法发送成功。这时候就需要借助一些视频压缩工具,下面介绍了四种方法,一起来看看…

浅谈信息化与数字化

一、信息化/数字化的概念 信息化、数字化按字面意思理解,这两个词的确代表了不同的含义。但是也不可否认,在目前我们可以接触到的信息平台来看。信息化、数字化很多时候都被混在一起了。 那么,既然今天要聊这个话题。我们得先把这两个词分清…

【C++】:拷贝构造函数与赋值运算符重载的实例应用之日期类的实现

C实现日期类 ├─属性: │ ├─年份 │ ├─月份 │ └─日期 ├─方法: │ ├─构造函数 │ ├─拷贝构造函数 │ ├─析构函数 │ ├─设置年份 │ ├─设置月份 │ ├─设置日期 │ ├─获取年份 │ ├─获取月份 │ ├─获取日期 │ ├…

mysql-linux归档版安装

什么是归档版安装?简单来说就是编译好的软件压缩打包版。 说明:我这里服务器之前已经装过一个不同版本的mysql,已经占用了3306端口,所以这里我用3307端口来演示,命令和官方的稍有不同,不过步骤都是差不多的…

next项目部署到云服务器上(手动)

准备环境: 云服务器 ECS,服务器安装好了docker 自己的next项目 开始: 1.在next项目根目录下创建Dockerfile文件 FROM node:18-alpine AS base# Install dependencies only when needed FROM base AS deps # Check https://github.com/nodejs/docker-node/tree/b4117f9333d…

SpringMVC Day 05 : Spring 中的 Model

前言 欢迎来到 SpringMVC 系列教程的第五天!在之前的教程中,我们已经学习了如何使用控制器处理请求和返回视图。今天,我们将深入探讨 Spring 中的 Model。 在 Web 应用程序开发中,数据的传递和展示是非常重要的。SpringMVC 提供…

flutter版本选择

使用命令dart --version查看dart版本 使用命令flutter doctor查看flutter版本 Flutter 有 3 个发布渠道,分别是 stable、beta 和 master。我们推荐使用 stable 渠道除非你需要体验最新更新的 Flutter 特性。 要查看你当前使用的哪个渠道,使用下面的命令&…

山西电力市场日前价格预测【2023-10-28】

日前价格预测 预测说明: 如上图所示,预测明日(2023-10-28)山西电力市场全天平均日前电价为324.42元/MWh。其中,最高日前电价为601.09元/MWh,预计出现在18:15。最低日前电价为0.00元/MWh,预计出…

微信native支付对接

简介 Native支付是指商户系统按微信支付协议生成支付二维码,用户再用微信“扫一扫”完成支付的模式。 应用场景 Native支付适用于PC网站、实体店单品或订单、媒体广告支付等场景,用户扫描商户展示在各种场景的二维码进行支付。聚体步骤如下: 1.商户根据微信支付的规则,…

【EI会议投稿】第三届电气、控制与信息技术国际学术会议(ECITech 2024)

第三届电气、控制与信息技术国际学术会议(ECITech 2024) 2024 3rd International Conference on Electrical, Control and Information Technology 继往届ECITech年度系列会议的成功举办,第三届电气、控制与信息技术国际学术会议&#xff08…

JTAG 详解

10.1 JTAG简介 JTAG接口的基本工作原理是:在芯片内部定义一个TAP(Test Access Port,测试访问端口),开发人员使用连接到芯片的JTAG外部接口上的JTAG调试器,通过访问芯片内部的TAP端口来扫描芯片内部各个扫…

数据驱动决策:大数据分析如何塑造业务成功

文章目录 大数据分析的定义大数据分析如何影响业务1. 洞察业务趋势2. 提高决策质量3. 优化运营效率4. 个性化客户体验5. 发现新商机 如何利用大数据分析实现业务成功1. 收集和整合数据2. 选择适当的工具和技术3. 制定数据策略4. 建立数据分析团队5. 进行实验和反馈 大数据分析的…

激活函数作用以及 sigmoid和softmax

激活函数 激活函数在神经网络中起着非常重要的作用,它的主要功能是引入非线性性质,使得神经网络可以学习和表示更加复杂的模式和关系。下面是激活函数的几个主要作用: 引入非线性:激活函数通过引入非线性变换,打破了…

java.sql.SQLException: ORA-28000: the account is locked

1.遇到的问题 Oracle执行报下面的错误 java.sql.SQLException: ORA-28000: the account is locked 2.解决办法 登录sysdba管理账号,执行下面命令。 alter user demo account unlock;

NPDP产品经理证书值得考吗?

NPDP(New Product Development Professional)证书是由新产品开发专业协会(PDMA)提供的一项专业认证。对于那些在产品开发领域寻求进一步发展的人来说,考取这个证书可能是一个值得考虑的选择。 首先,NPDP证…

以“降本增效”为目标,智能视频监控能为企业带来哪些经济价值?

随着经济的发展和科技的进步,企业需要不断提升自身的品质和效率,以保持竞争优势。而智能视频监控技术正是一项值得考虑的工具,其对企业带来的降本增效效益可以通过以下几个方面来体现。 1、降低运行成本 EasyCVR智能视频监控平台可以实现远程…

Mysql数据库 5.SQL语言聚合函数 语言日期-字符串函数

一、聚合函数 SQL中提供了一些可以对查询的记录的列进行计算的函数——聚合函数 1.count() 统计函数,统计满足条件的指定字符的值的个数 统计表中rebirth_mood个数 select count(列名) from 表名; #统计表中rebirth_namelcl的个数 select …

亚马逊发布Q3财报,营收利润强劲,云业务增长缓慢

KlipC报道:10月26日,亚马逊发布财报显示,该公司2023年第三季度每股收益0.94美元,营收同比增13%至1431亿美元,营业利润率7.8%远超预期的5.46%,均高于预期。 KlipC的合伙人Andi D表示:“三季度盈利…