ChatGLM-6B模型微调实战(以 ADGEN (广告生成) 数据集为例,序列长度达 2048)

news2025/1/15 11:39:37

kingglory/ChatGLM-6B 项目地址

1 介绍

对于 ChatGLM-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,差不多需要 7GB或则8GB 显存即可运行。
在这里插入图片描述
在这里插入图片描述

2 环境

2.1 python 环境
conda create -n py310_chat python=3.10       # 创建新环境    
source activate py310_chat                   # 激活环境

或者

# 创建虚拟环境
conda create -n xxx python=3.8
# 进入虚拟环境
conda activate xxx
# 退出当前虚拟环境
conda deactivate
# 查看本地虚拟环境
conda info --env
# 删除虚拟环境
conda remove -n xxx --all

2.2 下载代码
git clone https://github.com/THUDM/ChatGLM-6B.git    
cd ChatGLM-6B
2.3 安装依赖

运行微调需要4.27.1版本的transformers。除 ChatGLM-6B 的依赖之外,还需要按照以下依赖

# torch cuda 安装要匹配cuda 驱动版本:
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
# 安装gradio用于启动图形化web界面
pip install gradio
pip install -r requirements.txt    
pip install rouge_chinese nltk jieba datasets

验证pytorch是否为GPU版本

import torch
torch.cuda.is_available()  ## 输出应该是True
2.4(选做)

在运行前,可以修改一些文件内容

# web_demo.py
# 1. 新增mirror='https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models,下载模型使用清华源
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, mirror='https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models')
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, mirror='https://mirrors.tuna.tsinghua.edu.cn/hugging-face-models').half().cuda()
# 2. 增加server_name和server_port参数
demo.queue().launch(share=True,server_name="0.0.0.0",server_port=9234)

3 运行

#基于 Gradio 的网页版 Demo
python web_demo.py
#命令行 Demo
python cli_demo.py

值得注意的是: 显存够用下面这些不用管,当显存不够时(即GPU 显存有限低于13GB),尝试以量化方式加载模型的,需要添加代码.quantize(8) .quantize(4) :
int8精度加载,需要10G显存;
int4精度加载,需要6G显存;

#将句子对列表传给tokenizer,就可以对整个数据集进行分词处理
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) #将文本转换为模型能理解的数字# 自动加载该模型训练时所用的分词器

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(6).cuda()#从checkpoint实例化任何模型,下载预训练模型

4 微调

https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning

4.1 数据集

从 Google Drive或Tsinghua Cloud 下载处理好的 ADGEN 数据集,将解压后的 AdvertiseGen 目录放到本目录下(ptuning/AdvertiseGen)

4.2 模型下载

Huggingface 平台下载

git lfs install
git clone https://huggingface.co/THUDM/chatglm-6b
4.3 微调训练
cd ptuning/
bash train.sh

train.sh 脚本如下

PRE_SEQ_LEN=128                      # soft prompt 长度,P-tuning v2 参数
LR=1e-2                            # 训练的学习率,P-tuning v2 参数

CUDA_VISIBLE_DEVICES=0 python main.py \
--do_train \                                # 训练
    --train_file AdvertiseGen/train.json \      # 训练集地址
    --validation_file AdvertiseGen/dev.json \   # 验证集地址
--prompt_column content \              # 训练集中prompt 的key名称【可以理解为输入值的key】
--response_column summary \            # 训练集中response的key名称【可以理解为生成值的key】
--overwrite_cache \                    # 是否覆盖 缓存
--model_name_or_path THUDM/chatglm-6b \ # chatglm-6b 模型地址
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \    # 模型保存地址
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 8               # 模型 量化方式,P-tuning v2 参数

train.sh 中的 PRE_SEQ_LEN 和 LR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

在默认配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

4.4 模型推理

evaluate.sh 中的 CHECKPOINT 更改为训练时保存的 checkpoint 名称,运行以下指令进行模型推理和评测:

bash evaluate.sh

evaluate.sh 脚本如下

PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-8-1e-2
STEP=3000
CUDA_VISIBLE_DEVICES=0 python3 main.py \
--do_predict \
    --validation_file AdvertiseGen/dev.json \
    --test_file AdvertiseGen/dev.json \
--overwrite_cache \
--prompt_column content \
--response_column summary \
--model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP  \
--output_dir ./output/$CHECKPOINT \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_eval_batch_size 1 \
--predict_with_generate \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4
4.5 生成结果分析

评测指标为中文 Rouge scoreBLEU-4。生成的结果保存在 ./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt

  • 示例1
Input: 类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞

Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。

Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。

Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。


  • 示例2
Input: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领

Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。

Output[微调前]: 类型#裙版型#显瘦风格#文艺风格#简约图案#印花图案#撞色裙下摆#压褶裙长#连衣裙裙领型#圆领 1. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5. 风格:文艺风格,让连衣裙更加有内涵和品味。6. 图案:印花设计,在连衣裙上印有独特的图案。7. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。

Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
4.6 模型部署

将对应的demo或代码中的THUDM/chatglm-6b换成经过 P-Tuning 微调之后 checkpoint 的地址(在示例中为 ./output/adgen-chatglm-6b-pt-8-1e-2/checkpoint-3000)。注意,目前的微调还不支持多轮数据,所以只有对话第一轮的回复是经过微调的。
默认情况下,模型以 FP16 精度加载(无量化),需要大概 13GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型,使用方法如下:

  • 模型量化
# 按需修改,目前只支持 4/8 bit 量化 
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()

8-bit 量化下 GPU 显存占用约为 10GB,4-bit 量化下仅需 6GB 占用

随着对话轮数的增多,对应显存消耗也随之增大
理论上 ChatGLM-6B 支持无限长的 context-length,但总长度超过 2048 后性能会逐渐下降
量化模型会带来一定的性能损失

量化模型加载方式

# INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8" 
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
  • CPU 部署(需要 32G 内存)
    在 32G 内存的机器上经过测试,推理速度很慢
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()

友情链接

以下是部分基于本仓库开发的开源项目:

  • SwissArmyTransformer: 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。
  • ChatGLM-MNN: 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU。
  • ChatGLM-Tuning: 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调
  • langchain-ChatGLM:基于本地知识的 ChatGLM 应用,基于LangChain
  • bibliothecarius:快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。
  • 闻达:大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能
  • JittorLLMs:最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署

5 遇到的问题

报错1
ERROR: Could not find a version that satisfies the requirement protobuf<3.20.1,>=3.19.5 (from versions: none)
ERROR: No matching distribution found for protobuf<3.20.1,>=3.19.5

可能换了国内的镜像源,所以只需要指定装包路径(源)即可

pip install -r requirements.txt -i https://pypi.Python.org/simple/

报错 2
ImportError: Using SOCKS proxy, but the 'socksio' package is not installed. Make sure to install httpx using `pip install httpx[socks]`.

因为在命令行设置了“科学上网”,关掉即可

# 因为我设置的是临时的,所以在命令行输入如下代码即可
unset http_proxy
unset https_proxy
报错 3
RuntimeError: CUDA out of memory. Tried to allocate 128.00 MiB (GPU 0; 7.93 GiB total capacity; 7.40 GiB already allocated; 53.19 MiB free; 7.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

默认情况下,模型以 FP16 精度加载,运行上述代码需要大概 13GB 显存。如果你的 GPU 显存有限,可以尝试以量化方式加载模型,使用方法如下:

# int4精度加载,需要6G显存
# web_demo.py
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()

报错 4
RuntimeError: Library cudart is not initialized

用conda管理的环境,此时应该是cudatoolkit有问题,参考此issues

# 使用conda安装cudatoolkit
conda install cudatoolkit=11.3 -c nvidia

报错 5 (windows 系统)
# ModuleNotFoundError: No module named 'chardet'

# ImportError: cannot import name 'COMMON_SAFE_ASCII_CHARACTERS' from 'charset_normalizer.constant' (C:\Users\123\miniconda3\envs\chatglm6b\lib\site-packages\charset_normalizer\constant.py)

pip install chardet

# 仍然报错

# AttributeError: partially initialized module 'charset_normalizer' has no attribute 'md__mypyc' (most likely due to a circular import)

pip install --force-reinstall charset-normalizer==3.1.0
CPU 占用过高的或者GPU显存不够都可能被killed 掉

量化模型加载方式

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).quantize(8).half().cuda()

应该改成

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(8).cuda()

更省显存

局限性

由于ChatGLM-6B的小规模,其能力仍然有许多局限性。以下是目前发现的一些问题:

  • 模型容量较小:6B的小容量,决定了其相对较弱的模型记忆和语言能力。在面对许多事实性知识任务时,ChatGLM-6B可能会生成不正确的信息;它也不擅长逻辑类问题(如数学、编程)的解答。
  • 产生有害说明或有偏见的内容:ChatGLM-6B只是一个初步与人类意图对齐的语言模型,可能会生成有害、有偏见的内容。(内容可能具有冒犯性,此处不展示)
  • 英文能力不足:ChatGLM-6B 训练时使用的指示/回答大部分都是中文的,仅有极小一部分英文内容。因此,如果输入英文指示,回复的质量远不如中文,甚至与中文指示下的内容矛盾,并且出现中英夹杂的情况。
  • 易被误导,对话能力较弱:ChatGLM-6B 对话能力还比较弱,而且 “自我认知” 存在问题,并很容易被误导并产生错误的言论。例如当前版本的模型在被误导的情况下,会在自我认知上发生偏差。

不过 GLM 团队也坦言,整体来说 ChatGLM 距离国际顶尖大模型研究和产品(比如 OpenAI 的 ChatGPT 及下一代 GPT 模型)还存在一定的差距。该团队表示,将持续研发并开源更新版本的 ChatGLM 和相关模型。“欢迎大家下载 ChatGLM-6B,基于它进行研究和(非商用)应用开发。GLM 团队希望能和开源社区研究者和开发者一起,推动大模型研究和应用在中国的发展。”

参考

THUDM/ChatGLM-6B
ChatGLM-Tuning
ptuning/README.md
LLMs入门实战篇(二)——清华大学开源中文版ChatGLM-6B模型微调实战
ChatGLM-6B (介绍相关概念、基础环境搭建及部署)
学习实践ChatGLM-6B(部署+运行+微调)
LLMs九层妖塔(第一层 ChatGLM-6B)——ChatGLM-6B模型初体验
LLMs九层妖塔——第一层 ChatGLM学习实战-闯关笔记
torch install
试用宝典-阿里云开发者社区-云计算-阿里云 (aliyun.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/497925.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【软考-中级】系统集成项目管理工程师-计算题

系统集成项目管理工程师 二、计算题题型一&#xff1a;EMV&#xff08;预期货币价值&#xff09;题型二&#xff1a;加权系统题型三&#xff1a;自制和外贸决策——采购管理题型四&#xff1a;沟通渠道——沟通管理题型五&#xff1a;投资回收期、回收率题型六&#xff1a;进度…

metinfo_6.0.0 任意文件读取漏洞复现

一.漏洞简介 MetInfo是一套使用PHP和Mysql开发的内容管理系统。 MetInfo 6.0.0~6.1.0版本中的 old_thumb.class.php文件存在任意文件读取漏洞。攻击者可利用漏洞读取网站上的敏感文件。 二.漏洞影响 MetInfo 6.0.0 MetInfo 6.1.0 三.漏洞分析 在\MetInfo6.0.0\app\system\i…

推开“任意门”,华为全屋智能正在实现一代科幻迷的童年梦想

科幻作家亚瑟查理斯克拉克有句名言&#xff0c;“任何足够先进的科技&#xff0c;都和魔法无异”。 提到空间魔法&#xff0c;很多科技爱好者或科幻迷会想到哆啦A梦的“任意门”。通过那扇门&#xff0c;可以进入全新的世界&#xff0c;去任何想去的地方&#xff0c;是不少人在…

最新研究,GPT-4暴露了缺点!无法完全理解语言歧义!

夕小瑶科技说 原创作者 |智商掉了一地、Python自然语言推理&#xff08;Natural Language Inference&#xff0c;NLI&#xff09;是自然语言处理中一项重要任务&#xff0c;其目标是根据给定的前提和假设&#xff0c;来判断假设是否可以从前提中推断出来。然而&#xff0c;由于…

远程连接阿里云mysql数据库教程(SSH方式和宝塔面板方式)

一、SSH方式 1.首先登录mysql数据库 mysql -u root -p 输入密码后&#xff1a; 第一次连接的话该密码可以通过宝塔面板来重置&#xff1a; &#xff08;1&#xff09;输入 bt&#xff1a; &#xff08;2&#xff09;输入 7 即可重置mysql密码 2.查询mysql数据库中的user表 …

如何快速给pdf添加水印?

如何快速给pdf添加水印&#xff1f;在当今数字化时代&#xff0c;任何工作都是在电脑上完成了&#xff0c;PDF文档已成为人们日常工作中必不可少的一部分。虽然pdf文件具有较强的稳定性&#xff0c;不能被别人轻易的编辑修改&#xff0c;有时我们还需要提高pdf文件的安全性&…

AlphaFold的极限:高中生揭示人工智能在生物信息学挑战中的缺陷

人工智能程序AlphaFold (AlphaFold2开源了&#xff0c;不是土豪也不会编程的你怎么蹭一波&#xff1f;)&#xff0c;通过预测蛋白质结构解决了结构生物信息学的核心问题。部分AlphaFold迷们声称“该程序已经掌握了终极蛋白质物理学&#xff0c;其工作能力已超越了最初的设计”。…

Doc2Bot: 达摩院推出多类型文档对话数据集

一、概述 title&#xff1a;Doc2Bot: Accessing Heterogeneous Documents via Conversational Bots 论文地址&#xff1a;Doc2Bot: Accessing Heterogeneous Documents via Conversational Bots - ACL Anthology 数据地址&#xff08;大概5千多轮开源数据&#xff09;&#…

用 Spark 预测回头客

访问【WRITE-BUG数字空间】_[内附完整源码和文档] 至此“淘宝双 11 数据分析与预测课程案例”所需要的环境配置完成。另外实际操作中发现在案例教程中存在一些小问题&#xff0c;比如教程中 Eclipse 版本为 3.8&#xff0c;但是在配置 Tomcat Server 时又要求配置 v8.0 版本&a…

【分布式技术专题】「授权认证体系」OAuth2.0协议的入门到精通系列之授权码模式

这里写目录标题 OAuth2.0是什么OAuth2.0协议体系的Roles角色OAuth定义了四个角色资源所有者资源服务器客户端授权服务器 传统的客户机-服务器身份验证模型的问题 协议流程认证授权授权码 OAuth2.0是什么 OAuth 2.0是用于授权的行业标准协议。OAuth 2.0专注于简化客户端开发人员…

从【连接受限】看Android网络

从连接受限看Android网络 现象摸索从通知开始是Handler发的通知看看NetworkStateTrackerHandler NetworkMonitor做了什么NetworkMonitor是一个状态机CaptivePortalProbeResult从何而来连接受限的直接原因 嗅探是怎样进行的ProbeThread 回过头看看InternalHanderregisterNetwork…

GRE 隧道协议

1.GRE协议简介 GRE&#xff08;General Routing Encapsulation &#xff0c;通用路由封装&#xff09;是对某些网络层协议(如IP和IPX)的数据报文进行封装&#xff0c;使这些被封装的报文能够在另一网络层协议(如IP)中传输。此外 GRE协议也可以作为VPN的第三层隧道协议连接两个…

ES6之迭代器

文章目录 前言迭代器1.原生具备Iterator接口的数据&#xff08;可用for...of遍历&#xff09;2.工作原理3.自定义遍历数据 总结 前言 迭代器&#xff08;Iterator&#xff09; for…of遍历 迭代器 迭代器是一种接口&#xff0c;为各种不同数据结构提供统一的访问机制。任何数…

c++ 11标准模板(STL) std::vector (八)

定义于头文件 <vector> template< class T, class Allocator std::allocator<T> > class vector;(1)namespace pmr { template <class T> using vector std::vector<T, std::pmr::polymorphic_allocator<T>>; }(2)(C17…

智慧工地烟火识别算法 opencv

智慧工地烟火识别系统应用pythonopencv深度学习算法模型技术分析前端视频信息&#xff0c;智慧工地烟火识别算法模型主动发现工地或者厂区现场区域内的烟雾和火灾苗头及时进行告警。OpenCV的全称是Open Source Computer Vision Library&#xff0c;是一个跨平台的计算机视觉处理…

前端三剑客 - HTML

前言 前面都是一些基础的铺垫&#xff0c;现在就正式进入到web开发环节了。 我们的目标就是通过学习 JavaEE初阶&#xff0c;搭建出一个网站出来。 一个网站分成两个部分&#xff1a; 前端&#xff08;客户端&#xff09; 后端&#xff08;服务器&#xff09; 通常这里的客户端…

ASP.NET Core Web API用户身份验证

一、JWT介绍 ASP.NET Core Web API用户身份验证的方法有很多&#xff0c;本文只介绍JWT方法。JWT实现了服务端无状态&#xff0c;在分布式服务、会话一致性、单点登录等方面凸显优势&#xff0c;不占用服务端资源。简单来说&#xff0c;JWT的验证过程如下所示&#xff1a; &a…

基于微服务架构的水果销售系统的设计与实现

访问【WRITE-BUG数字空间】_[内附完整源码和文档] 整体上为微服务架构&#xff0c;使用 SpringCloud 技术&#xff0c;每个独立的服务为一个单独的 SpringBoot 工程&#xff1b;数据库使用 MySQL 数据库&#xff1b;分布式缓存使用 Redis&#xff0c;消息队列使用 Kafka。包括…

基于matlab的相控阵系统仿真场景可视化

一、前言 此示例演示如何使用方案查看器可视化系统级仿真。 二、介绍 相控阵系统仿真通常包括许多移动物体。例如&#xff0c;阵列和目标都可以处于运动状态。此外&#xff0c;每个移动物体可能都有自己的方向&#xff0c;因此当模拟中出现更多玩家时&#xff0c;簿记变得越来越…

是人就能学会的Spring源码教学-Spring的简单使用

是人就能学会的Spring源码教学-Spring的简单使用 Spring的最简单入门使用第一步 创建项目第二步 配置项目第三步 启动项目 Spring的最简单入门使用 各位道友且跟我一道来学习Spring的最简单的入门使用&#xff0c;为了方便和简单&#xff0c;我使用了Spring Boot项目&#xff…