ChatGLM-6B模型微调实战(以 ADGEN (广告生成) 数据集为例)

news2024/11/18 7:48:49

1 介绍

对于 ChatGLM-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,差不多需要 7GB或则8GB 显存即可运行。
在这里插入图片描述
在这里插入图片描述

2 环境

2.1 python 环境
conda create -n py310_chat python=3.10       # 创建新环境    
source activate py310_chat                   # 激活环境

或者

# 创建虚拟环境
conda create -n xxx python=3.8
# 进入虚拟环境
conda activate xxx
# 退出当前虚拟环境
conda deactivate
# 查看本地虚拟环境
conda info --env
# 删除虚拟环境
conda remove -n xxx --all

2.2 下载代码
git clone https://github.com/THUDM/ChatGLM-6B.git    
cd ChatGLM-6B
2.3 安装依赖

运行微调需要4.27.1版本的transformers。除 ChatGLM-6B 的依赖之外,还需要按照以下依赖

# torch cuda 安装要匹配cuda 驱动版本:
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# 安装gradio用于启动图形化web界面
pip install gradio
pip install -r requirements.txt    
pip install rouge_chinese nltk jieba datasets

验证pytorch是否为GPU版本

import torch
torch.cuda.is_available()  ## 输出应该是True
2.4(选做)

在运行前,可以修改一些文件内容

# web_demo.py
# 1. 新增mirror='https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models,下载模型使用清华源
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, mirror='https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models')
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, mirror='https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models').half().cuda()
# 2. 增加server_name和server_port参数
demo.queue().launch(share=True,server_name="0.0.0.0",server_port=9234)

3 运行

#基于 Gradio 的网页版 Demo
python web_demo.py
#命令行 Demo
python cli_demo.py

值得注意的是: 显存够用下面这些不用管,当显存不够时(即GPU 显存有限低于13GB),尝试以量化方式加载模型的,需要添加代码.quantize(8) .quantize(4) :
int8精度加载,需要10G显存;
int4精度加载,需要6G显存;

#将句子对列表传给tokenizer,就可以对整个数据集进行分词处理
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) #将文本转换为模型能理解的数字# 自动加载该模型训练时所用的分词器

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(6).cuda()#从checkpoint实例化任何模型,下载预训练模型

4 微调

https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning

4.1 数据集

从 Google Drive或Tsinghua Cloud 下载处理好的 ADGEN 数据集,将解压后的 AdvertiseGen 目录放到本目录下(ptuning/AdvertiseGen)

4.2 模型下载

Huggingface 平台下载

git lfs install
git clone https://huggingface.co/THUDM/chatglm-6b
4.3 微调训练
cd ptuning/
bash train.sh

train.sh 脚本如下

PRE_SEQ_LEN=128                      # soft prompt 长度,P-tuning v2 参数
LR=1e-2                            # 训练的学习率,P-tuning v2 参数

CUDA_VISIBLE_DEVICES=0 python main.py \
--do_train \                                # 训练
    --train_file AdvertiseGen/train.json \      # 训练集地址
    --validation_file AdvertiseGen/dev.json \   # 验证集地址
--prompt_column content \              # 训练集中prompt 的key名称【可以理解为输入值的key】
--response_column summary \            # 训练集中response的key名称【可以理解为生成值的key】
--overwrite_cache \                    # 是否覆盖 缓存
--model_name_or_path THUDM/chatglm-6b \ # chatglm-6b 模型地址
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \    # 模型保存地址
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 8               # 模型 量化方式,P-tuning v2 参数

train.sh 中的 PRE_SEQ_LEN 和 LR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

在默认配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

4.4 模型推理

evaluate.sh 中的 CHECKPOINT 更改为训练时保存的 checkpoint 名称,运行以下指令进行模型推理和评测:

bash evaluate.sh

evaluate.sh 脚本如下

PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2
STEP=3000
CUDA_VISIBLE_DEVICES=0 python3 main.py \
--do_predict \
    --validation_file AdvertiseGen/dev.json \
    --test_file AdvertiseGen/dev.json \
--overwrite_cache \
--prompt_column content \
--response_column summary \
--model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP  \
--output_dir ./output/$CHECKPOINT \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_eval_batch_size 1 \
--predict_with_generate \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4
4.5 生成结果分析

评测指标为中文 Rouge scoreBLEU-4。生成的结果保存在 ./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt

  • 示例1
Input: 类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞

Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。

Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。

Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。


  • 示例2
Input: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领

Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。

Output[微调前]: 类型#裙版型#显瘦风格#文艺风格#简约图案#印花图案#撞色裙下摆#压褶裙长#连衣裙裙领型#圆领 1. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5. 风格:文艺风格,让连衣裙更加有内涵和品味。6. 图案:印花设计,在连衣裙上印有独特的图案。7. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。

Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
4.6 模型部署

将对应的demo或代码中的THUDM/chatglm-6b换成经过 P-Tuning 微调之后 checkpoint 的地址(在示例中为 ./output/adgen-chatglm-6b-pt-8-1e-2/checkpoint-3000)。注意,目前的微调还不支持多轮数据,所以只有对话第一轮的回复是经过微调的。

友情链接

以下是部分基于本仓库开发的开源项目:

  • SwissArmyTransformer: 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。
  • ChatGLM-MNN: 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU。
  • ChatGLM-Tuning: 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调
  • langchain-ChatGLM:基于本地知识的 ChatGLM 应用,基于LangChain
  • bibliothecarius:快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。
  • 闻达:大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能
  • JittorLLMs:最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署

5 遇到的问题

报错1
ERROR: Could not find a version that satisfies the requirement protobuf<3.20.1,>=3.19.5 (from versions: none)
ERROR: No matching distribution found for protobuf<3.20.1,>=3.19.5

可能换了国内的镜像源,所以只需要指定装包路径(源)即可

pip install -r requirements.txt -i https://pypi.Python.org/simple/

报错 2
ImportError: Using SOCKS proxy, but the 'socksio' package is not installed. Make sure to install httpx using `pip install httpx[socks]`.

因为在命令行设置了“科学上网”,关掉即可

# 因为我设置的是临时的,所以在命令行输入如下代码即可
unset http_proxy
unset https_proxy
报错 3
RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 7.93 GiB total capacity; 7.40 GiB already allocated; 53.19 MiB free; 7.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

默认情况下,模型以 FP16 精度加载,运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型,使用方法如下:

# int4精度加载,需要6G显存
# web_demo.py
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()

报错 4
RuntimeError: Library cudart is not initialized

用conda管理的环境,此时应该是cudatoolkit有问题,参考此issues

# 使用conda安装cudatoolkit
conda install cudatoolkit=11.3 -c nvidia

CPU 占用过高的或者GPU显存不够都可能被killed 掉

参考

THUDM/ChatGLM-6B
ChatGLM-Tuning
ptuning/README.md
LLMs入门实战篇(二)——清华大学开源中文版ChatGLM-6B模型微调实战
ChatGLM-6B (介绍相关概念、基础环境搭建及部署)
学习实践ChatGLM-6B(部署+运行+微调)
torch install

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/485580.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

go 打包文件夹成zip文件

go 打包文件夹成zip文件 代码有些乱&#xff0c;找不到合适的例子&#xff0c;和优雅的代码 当前代码打包文件是在 需要打包的目录下&#xff0c;测试的时候注意文件翻倍容量 writer, err : zzip.CreateHeader(header) //这里创建文件时注意不要用完整路径 zip中会生产完整路径…

【51单片机】蜂鸣器

&#x1f38a;专栏【51单片机】 &#x1f354;喜欢的诗句&#xff1a;更喜岷山千里雪 三军过后尽开颜。 &#x1f386;音乐分享【Love Story】 大一同学小吉&#xff0c;欢迎并且感谢大家指出我的问题&#x1f970; &#x1f354;效果 &#xff08;注意听声音哦&#xff09; 按…

Blob File

文章目录 学习链接Blob创建演示 分片演示 Fileinput手动拖拽fetch 从后端获取流前端代码后端代码 window.showOpenFilePicker Filereader示例1示例2 ArrayBuffer创建bufferTypedArray读写bufferDataView读写buffer与Blob对比 Blob Url & DataUrl示例1示例2 学习链接 Blog …

WIFI密码hacking学习

1 wifite sudo -i 切换到root终端&#xff0c;执行wifite CtrlC 停止WIFI列表扫描 这里选择爆破目标1&#xff0c;他会先监听客户端和wifi的连接的数据包&#xff0c;然后用wordlist-probable.txt里面的字典进行爆破 区别 WPS 是 no 的和yes no会自动化爆破 这种模式默认下…

软件测试行业前景怎么样

软件测试工程师工作好找&#xff0c;但是企业有时却很难招聘到合适的软件测试工程师&#xff0c;这是因为现在企业需要的是自动化测试人才和测试开发人才&#xff01;前些年的手动测试员早已不能满足企业现在的发展需求。 随着人工智能时代的到来&#xff0c;IT行业受到了越来…

从‘孔乙己长衫’现象看社会不公

孔乙已是鲁迅笔下人物&#xff0c;穷困流倒还穿着象征读书人的长衫&#xff0c;迁腐、麻木。最近&#xff0c;大家自我调佩是“当代孔乙己”&#xff0c;学历成为思想负担&#xff0c;找工作时高不成低不就。 认识孔乙己 孔乙己是清朝末年的小贩&#xff0c;生活在贫苦的阶层…

ubuntu22.04安装与配置

目录 一、环境及下载 iso下载 VM配置 二、虚拟机与环境配置 虚拟机开始后的配置 一些工具配置 参考&#xff1a; VMware Workstation Pro 文档 一、环境及下载 iso下载 Download Ubuntu Desktop | Download | Ubuntu 新出了23但是偶数年份稳定支持&#xff0c;所以我…

openai接口调用-如何接入openai获取 api key

openai api key获取 获取 OpenAI API Key 非常简单&#xff0c;您只需要按照以下步骤进行操作: 创建 OpenAI 账户。如果您还没有 OpenAI 账户&#xff0c;请访问 OpenAI 官网&#xff08;http://openai.com/signup&#xff09;并创建一个帐户。登录您的 OpenAI 账户。使用您的…

docker安装ES、kibana和IK分词器

拉取镜像 docker pull elasticsearch:7.4.2 docker pull kibana:7.4.2创建存储数据的目录 mkdir -p /home/lab1018/docker_volume/elasticsearch/config mkdir -p /home/lab1018/docker_volume/elasticsearch/data mkdir -p /home/lab1018/docker_volume/elasticsearch/plugi…

C++-----动态规划

目录 一、动态规划的基本思想 二、设计动态规划法的步骤 三、动态规划问题的特征 4.1 矩阵连乘积问题 4.1.1 分析最优解的结构 4.1.2 建立递归关系 4.1.3 计算最优值 4.1.3 计算最优值 4.1.3 构造最优解 4.2 动态规划算法的基本要素 4.2.1 最优子结构 4.2.2 重叠子问题 …

gRPC结合vcpkg在x64-windows平台visual studio2019 cmake工程里面的应用

这里我们运用vcpkg去下载安装gRPC&#xff0c;进入vcpkg目录后&#xff0c;执行命令&#xff1a;.\vcpkg.exe install grpc:x64-windows grpc在vcpkg里面安装完成后&#xff0c;我们就来使用grpc做一个简单的例子。 gRPC顾名思义&#xff0c;就是google的RPC方案&#xff0c;…

ProtoBuf 编码原理

因为涉及到分布式集群之间的通信&#xff0c;所以来学习了下 ProtoBuf&#xff0c;为什么选择 ProtoBuf 呢&#xff1f;主要还是因为相对于 json , xml 来说&#xff0c;ProtoBuf 传输效率更快&#xff0c;故需要了解下它的编码设计。 首先&#xff0c;每一个 message 进行编码…

科研人必看入门攻略(收藏版)

来源&#xff1a;投稿 作者&#xff1a;小灰灰 编辑&#xff1a;学姐 本文主要以如何做科研&#xff0c;日常内功修炼&#xff0c;常见科研误区&#xff0c;整理日常‘好论文’四个部分做以介绍&#xff0c;方便刚入门的科研者进行很好的规划。 1.如何做科研 1.1 选方向 当我…

【2023年第十一届泰迪杯数据挖掘挑战赛】A题:新冠疫情防控数据的分析 32页和40页论文及实现代码

【2023年第十一届泰迪杯数据挖掘挑战赛】A题&#xff1a;新冠疫情防控数据的分析 32页和40页论文及实现代码 相关链接 &#xff08;1&#xff09;建模方案 【2023年第十一届泰迪杯数据挖掘挑战赛】A题&#xff1a;新冠疫情防控数据的分析 建模方案及python代码详解 &#x…

【初学人工智能原理】【4】梯度下降和反向传播:能改(下)

前言 本文教程均来自b站【小白也能听懂的人工智能原理】&#xff0c;感兴趣的可自行到b站观看。 本文【原文】章节来自课程的对白&#xff0c;由于缺少图片可能无法理解&#xff0c;故放到了最后&#xff0c;建议直接看代码&#xff08;代码放到了前面&#xff09;。 代码实…

《LearnUE——基础指南:开篇—2》——准备工作

目录 0.2.1 UE的获取与创建项目 0.2.2 UE4编辑器界面布局 1. 编辑器介绍 2. 新建蓝图与蓝图拖动 3. 菜单介绍 4. 工具栏介绍 0.2.3 学习资料 0.2.1 UE的获取与创建项目 登录UE官网&#xff1a;www.unrealengine.com 点击“登录”&#xff0c;如果没有账号&#xff0…

自动驾驶——离散系统LQR的黎卡提方程Riccati公式推导与LQR工程化

1.LQR Question Background 之前写过连续系统的黎卡提方程Riccati推导,但是考虑到实际工程落地使用的是离散系统,于是又进行了离散黎卡提方程Riccati的公式推导。 2.Proof of Riccati Equation Formula for Discrete Systems 工程化落地,就是使用公式(2-14)实时计算控制率…

Windows编译安装AzerothCore魔兽世界开源服务端Lua脚本引擎Eluna和防作弊anticheat模块教程

Windows编译安装AzerothCore魔兽世界开源服务端Lua脚本引擎Eluna和防作弊anticheat模块教程 大家好&#xff0c;我是艾西今天和大家聊聊魔兽世界游戏内的脚步以及防作弊模块 Eluna是azerothcore服务端的Lua脚本引擎&#xff0c;可以在原有azerothcore的基础上实现很多拓展以及…

Linux操作系统指令(1)

目录 一、什么是Linux操作系统&#xff1f;二、Linux基本指令12.1 ls指令2.2 pwd指令2.3 cd指令2.4 touch指令2.5 mkdir指令&#xff08;非常重要&#xff09;2.6 rmdir指令 && rm 指令&#xff08;十分重要&#xff09;2.7 man指令&#xff08;非常重要&#xff09;2.…

基于松鼠算法的极限学习机(ELM)回归预测-附代码

基于松鼠算法的极限学习机(ELM)回归预测 文章目录 基于松鼠算法的极限学习机(ELM)回归预测1.极限学习机原理概述2.ELM学习算法3.回归问题数据处理4.基于松鼠算法优化的ELM5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;本文利用松鼠算法对极限学习机进行优化&#xff0c;并…