如何能基于prompt tuning v2训练好一个垂直领域的chatglm-6b

news2024/12/26 11:40:10

如何能基于prompt tuning v2训练好一个垂直领域的chatglm-6b

首先先抛出一个问题,是不是所有的文本生成数据集都适合用chatglm 6B的模型进行微调。那我们今天找到了三个数据集,分别为百科数据集、法律问答数据集、论文题目与摘要数据集、专利名称与专利摘要数据集。

官方的广告数据集是如下结构的

{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}

官方的多轮对话数据集是如下结构的

{
    "prompt": "是的。上下水管都好的",
    "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!",
    "history": [
        [
            "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线",
            "用电脑能读数据流吗?水温多少"
        ],
        [
            "95",
            "上下水管温差怎么样啊?空气是不是都排干净了呢?"
        ]
    ]
}

今天的所有实验都是探索单轮生成chatglm-6B上的适配性。

ptuning chatglm 6B中有两个数据集作为标准的官方微调数据集案例

我们看一下ptuning chatglm 6B中的启动参数有哪些。

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --do_eval \
    --train_file AdvertiseGen/patent_train.32.128.512.json \
    --validation_file AdvertiseGen/patent_dev.32.128.512.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir output/patent_dev-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 128 \
    --max_target_length 512 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --fp16 False\
    --pre_seq_len $PRE_SEQ_LEN

PRE_SEQ_LEN 预序列长度

LR 学习率

do_train 是否进行训练

do_eval 是否进行预测

train_file 训练文件相对地址

validation_file 验证文件相对地址

prompt_column prompt 提示信息字段

response_column 响应信息字段

overwrite_cache 重写数据集缓存。

model_name_or_path 模型名称或模型地址

output_dir 训练好的模型保存的地址

per_device_train_batch_size 每个设备上的训练批次大小 在实际的训练过程中3090显卡可以把这个参数开到4。

模型的指令输入应该如何拼接才可以让chatglm更好的服务。

train.sh 中的 PRE_SEQ_LENLR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。 在默认配置 quantization_bit=4per_device_train_batch_size=1gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

调整batch size后的学习率应该如何调整。

chatglm的工作流程

基于openbayes的3090单卡,prompt tuning v2 训练chatglm 6B模型。

训练专利prompt的数据的时候基础训练参数 修改了 per_device_train_batch_size 为 4。

 ***** Running training *****
Num examples = 3384
Num Epochs = 58
Instantaneous batch size per device = 4
Total train batch size (w. parallel, distributed & accumulation) = 64
Gradient Accumulation steps = 16
Total optimization steps = 3000
Number of trainable parameters = 29360128

其中每一个设备的batch size设定为4,总共训练的批次大小是64。这里的总批次是因为采用了梯度累计策略,所以总训练批次大小是64。那如果是两张卡的话这里是128。

训练专利prompt的数据集的时候的损失表现

中国大百科数据集

PyTorch DataParallel和DDP是PyTorch提供的两个数据并行扩展。 1. PyTorch Data Parallel PyTorch Data Parallel是PyTorch框架中的一个重要组成部分,它提供了一种高效的并行计算机制,使得在GPU上运行Torch模型变得更加容易。Data Parallel使用GPU上的多线程来并行计算多个输入特征,从而提高计算效率。 Data Parallel的实现方式包括: - Data parallel器:负责将输入特征按照一定的规则划分成一组数据 parallel,例如按照相似度、长度、形状等特征进行划分。 - 并行化操作:将数据 parallel划分为多个并行块,并执行相应的操作。 - 数据预处理:对数据parallel块进行一些预处理,例如合并、排序、归一化等操作。 使用Data Parallel可以大大简化GPU编程,并提高模型的训练效率。 2. DDP 官方建议用新的DDP,采用all-reduce算法,本来设计主要是为了多机多卡使用,但是单机上也能用,使用方法如下:

初始化使用nccl后端

torch.distributed.init_process_group(backend="nccl")

模型并行化

model=torch.nn.parallel.DistributedDataParallel(model)

需要注意的是:DDP并不会自动shard数据 1. 如果自己写数据流,得根据torch.distributed.get_rank()去shard数据,获取自己应用的一份 2. 如果用Dataset API,则需要在定义Dataloader的时候用DistributedSampler 去shard:

sampler = DistributedSampler(dataset) # 这个sampler会自动分配数据到各个gpu上
DataLoader(dataset, batch_size=batch_size, sampler=sampler)

在chatglm 6B中训练的并行是基于transformers架构实现的

from transformers.trainer import Trainer

trainer默认是用torch.distributed的api来做多卡训练的,因此可以直接支持多机多卡,单机多卡,单机单卡。

目前autodl没有多卡资源,所以也没办法验证多卡这个如何可以更高效率的执行出来有效的结果。

不同的云计算平台

autodl 模型下载速度比较慢 可以通过在新建环境时候选择合适的thuglm镜像来减少模型下载上所需要的时间。实例初始化空间为20GB系统空间+50GB数据空间,数据空间可以扩容。

openbayes 模型下载速度比较快,环境每次重启的时候都要执行一遍pip install安装步骤。启动训练是需要在命令行中加上 --fp16 False,不然会报错。实例硬盘上限只有50GB,需要注意保存策略。存储空间费用如下。

用我的专用邀请链接,注册 OpenBayes,双方各获得 60 分钟 RTX 3090 使用时长,支持累积,永久有效:

注册 - OpenBayes

要不是autodl没有卡,我也不会来openbayes租用显卡。最吐槽的问题就是硬盘空间问题。作为个人研究者还要为硬盘按月付费实在是让人不舒服。目前我付费了100GB的硬盘。一个月80多块。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/421899.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

stable-diffusion-webui-colab部署记录

stable-diffusion-webui-colab 该模型可以在网上云端部署stable-diffusion,减少本地部署的繁琐步骤降低配置要求的依赖。 一、进入stable-diffusion-webui-colab 1.网址:https://github.com/camenduru/stable-diffusion-webui-colab 在分支中选择driv…

java 坐标体系与绘图

目录 一、坐标体系 1.像素 : 2.坐标系 : 二、绘图 1.机制 : 2.实例 : 3.原理 : 4.常用绘图方法 : 1 setColor(Color c) : 设置画笔颜色 2 drawLine(int x1, int y1, int x2, int y2) : 画直线 3 drawRect(int x, int y, int width, int height) : 画矩形边框 4 fillRec…

【密码学复习】第六讲 HASH函数和MAC(三)

H是一个Hash函数 K表示密钥 B表示计算消息摘要时消息分块的字节长度(对MD5和SHA-1是512比特,64字节) L表示消息摘要按字节计算的长度(对MD5是16字节) ipad表示0x36重复B次,opad表示0x5c重复B次。 K可以…

腾讯云轻量服务器价格表(2023版)

2023腾讯云轻量应用服务器2核2G4M带宽88元一年、2核4G6M带宽159元/年、4核8G10M优惠价425元、8核16G14M价格1249、16核32G20M服务器2499元一年,今天分享2023腾讯云服务器配置及精准报价。 腾讯云轻量应用服务器优惠价格表 腾讯云服务器分为轻量应用服务器和云服务器…

Games106学习记录第一课

本文地址:https://blog.csdn.net/t163361/article/details/130139998 最近准备申请新星创作者,需要2000个粉丝关注,觉得文章有用的,请点一下左侧边栏的关注,谢谢。 前段时间看到Games106课程,讲的是流水线…

【天梯赛—不想坑队友系列】L1-002 打印沙漏(java)

题目链接 PTA | 程序设计类实验辅助教学平台 本题要求你写个程序把给定的符号打印成沙漏的形状。例如给定17个“*”,要求按下列格式打印 ************ *****所谓“沙漏形状”,是指每行输出奇数个符号;各行符号中心对齐;相邻两行符…

c/c++:2进制、8进制、10进制、16进制和进制之间的转换,c语言输出匹配格式%

c/c:2进制、8进制、10进制、16进制和进制之间的转换,c语言输出匹配格式% 2022找工作是学历、能力和运气的超强结合体,遇到寒冬,大厂不招人,此时学会c的话, 我所知道的周边的会c的同学,可手握10…

Linux主机用WordPress搭建网站

文章目录一、搭建过程1.1、切换到超户1.2、更新1.3、安装一些包1.4、安装wordpress1.5、配置MariaDB1.6、创建WordPress数据库1.7、配置WordPress1.8、登录WordPress1.9、安装phpMyAdmin一、搭建过程 1.1、切换到超户 sudo su1.2、更新 apt-get update -y1.3、安装一些包 a…

我在windows10下,使用CMake gui 编译krita源码

系列文章目录 文章目录系列文章目录前言一、krita编译说明二、使用步骤前言 我在windows10下,使用CMake gui 编译krita源码 where is the source code:E:/krita-dev/krita where to build the binaries:E:/krita-dev/krita_camke current generator:MinGW Makefile…

成为程序员后才知道的6件事,第5点看完很心酸!

曾几时,总觉得IT精英外表光鲜亮丽,尤其是程序员咔咔咔打代码,月入几个w,不光挣得多,上班期间还能玩电脑游戏。但是,真正当了程序员之后,OMG!我再也不这样想了!好多事都是当了程序员才…

【C语言深入】带你了解C语言中的可变参数列表

【C语言深入】带你了解C语言中的可变参数列表一、可变参数函数的使用方式1、使用方式2、自定义可变把参数函数2.1、三个宏一个类型2.2、实现方式二、可变参数列表的原理1、va_start1.1、_ADDRESSOF1.2、关于临时拷贝的一个小知识点1.3、_INTSIZEOF2、va_arg3、va_end一、可变参…

23种设计模式总结(大白话,适合小白)

文章目录什么是设计模式?设计模式的分类创建型模式创建型类类型工厂方法模式创建型对象型抽象工厂模式生成器模式原型模式单例模式结构型模式结构型类类型适配器模式结构型对象型桥接模式组合模式装饰器模式外观模式享元模式代理模式行为型模式行为型对象型命令模式…

【C++PrimerPlus】第五章 循环和关系表达式

文章目录5.1 for循环5.1.1 for循环的组成部分5.1.2 回到for循环5.1.3 修改步长5.1.4 使用for循环访问字符串5.1.5 递增运算符 ()和递减运算符(--)5.1.6 副作用和顺序点5.1.7 前缀格式与后缀格式5.1.8 递增/递减和指针5.1.9 组合赋值运算符5.1.10 复合语句![](https://img-blog.…

Qt Quick - ToolTip

Qt Quick - ToolTip使用总结一、概述二、附带的ToolTip三、延迟和超时四、自定义ToolTip五、定制化一、概述 ToolTip 其实就是ToolTip,所谓ToolTip其实就是一段简短的文本,告知用户控件的功能。它通常置于父控件之上或之下。提示文本可以是任何富文本格…

常用异常检测模型的应用

常用异常检测模型的应用 描述 异常数据检测不仅仅可以帮助我们提高数据质量,同时在一些实际业务中,异常数据往往包含有价值的信息,如异常交易、网络攻击、工业品缺陷等,因此异常检测也是数据挖掘的重要手段。常用的异常检测模型…

【通过Cpython3.9源码看看python字符串拼接:“+”为什么比join低效】

基本说明 Python字符串拼接中,使用join()方法比运算符更高效,主要原因在于字符串对象的不可变性和内存分配策略。 首先,我们要知道Python字符串是不可变的对象。这意味着,每次使用运算符进行字符串拼接时,Python需要…

Vue2-黑马(四)

目录: (1)axios-响应格式 (2)axios-拦截器 (3)vue2-条件渲染 (4)vue2-列表渲染 (1)axios-响应格式 下面看axios的返回响应对象的内部组成 后…

【grpc02】安装protobuf和protoc

目录 Windows环境 下载通用编译器 配置环境变量 安装go专用的protoc的生成器 GoLang中安装插件 如何使用protobuf呢? Mac环境 Protoc安装 Protoc-gen-go的安装 Windows环境 下载通用编译器 下载地址:v3.20.1 Releases protocolbuffers/pr…

【优化算法】使用遗传算法优化MLP神经网络参数(TensorFlow2)

文章目录任务查看当前的准确率情况使用遗传算法进行优化完整代码任务 使用启发式优化算法遗传算法对多层感知机中中间层神经个数进行优化,以提高模型的准确率。 待优化的模型: 基于TensorFlow2实现的Mnist手写数字识别多层感知机MLP # MLP手写数字识别…

Java支付SDK接口远程调试 - 支付宝沙箱环境【公网地址调试】

文章目录1.测试环境2.本地配置3. 内网穿透3.1 下载安装cpolar内网穿透3.2 创建隧道4. 测试公网访问5. 配置固定二级子域名5.1 保留一个二级子域名5.2 配置二级子域名6. 使用固定二级子域名进行访问转发自CSDN远程穿透的文章:Java支付宝沙箱环境支付,SDK接…