ChatGLM-6B的P-Tuning微调详细步骤及结果验证

news2024/12/23 14:50:07

文章目录

    • 1. ChatGLM-6B
      • 1.1 P-Tuning v2简介
    • 2. 运行环境
      • 2.1 项目准备
    • 3.数据准备
    • 4.使用P-Tuning v2对ChatGLM-6B微调
    • 5. 模型评估
    • 6. 利用微调后的模型进行验证
      • 6.1 微调后的模型
      • 6.2 原始ChatGLM-6B模型
      • 6.3 结果对比

1. ChatGLM-6B

ChatGLM-6B仓库地址:https://github.com/THUDM/ChatGLM-6B

ChatGLM-6B/P-Tuning仓库地址:https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning

1.1 P-Tuning v2简介

P-Tuning是一种较新的模型微调方法,它采用了参数剪枝的技术,可以将微调的参数量减少到原来的0.1%。具体来说,P-Tuning v2是基于P-Tuning v1的升级版,主要的改进在于采用了更加高效的剪枝方法,可以进一步减少模型微调的参数量。

P-Tuning v2的原理是通过对已训练好的大型语言模型进行参数剪枝,得到一个更加小巧、效率更高的轻量级模型。具体地,P-Tuning v2首先使用一种自适应的剪枝策略,对大型语言模型中的参数进行裁剪,去除其中不必要的冗余参数。然后,对于被剪枝的参数,P-Tuning v2使用了一种特殊的压缩方法,能够更加有效地压缩参数大小,并显著减少模型微调的总参数量。

总的来说,P-Tuning v2的核心思想是让模型变得更加轻便、更加高效,同时尽可能地保持模型的性能不受影响。这不仅可以加快模型的训练和推理速度,还可以减少模型在使用过程中的内存和计算资源消耗,让模型更适用于各种实际应用场景中。

2. 运行环境

本项目租借autoDL GPU机器,具体配置如下:

在这里插入图片描述

在这里插入图片描述

2.1 项目准备

1.创建conda环境

conda create -n tuning-chatglm python=3.8
conda activate tuning-chatglm

2.拉取ChatGLM-6B项目代码

# 拉取代码
git clone https://github.com/THUDM/ChatGLM-6B.git

# 安装依赖库
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/

3.进入ptuning目录

运行微调需要4.27.1版本的transformers。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖

cd ptuning
# 再次安装依赖,ptuning文档里有说明
pip install rouge_chinese nltk jieba datasets  -i https://pypi.tuna.tsinghua.edu.cn/simple/

4.补充

对于需要pip安装失败的依赖,可以采用源码安装的方式,具体步骤如下

git clone https://github.com/huggingface/peft.git
cd peft
pip install -e .

3.数据准备

官方微调样例是以 ADGEN (广告生成) 数据集为例来介绍微调的具体使用。

ADGEN 数据集为根据输入(content)生成一段广告词(summary),具体格式如下所示:

{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}

请从官网下载 ADGEN 数据集,放到ptuning目录下并将其解压到 AdvertiseGen 目录。

下载地址:https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view

tar -zxvf AdvertiseGen.tar.gz

在这里插入图片描述

查看数据集大小:

> wc -l AdvertiseGen/*
> 1070 AdvertiseGen/dev.json
> 114599 AdvertiseGen/train.json
> 115669 total

4.使用P-Tuning v2对ChatGLM-6B微调

对于 ChatGLM-6B 模型基于 P-Tuning v2 进行微调。可将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

进入到ptuning目录,首先,修改train.sh脚本,主要是修改其中的train_filevalidation_filemodel_name_or_pathoutput_dir参数:

  • train_file:训练数据文件位置
  • validation_file:验证数据文件位置
  • model_name_or_path:原始ChatGLM-6B模型文件路径
  • output_dir:输出模型文件路径
PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path model/chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

执行bash train.sh脚本,运行过程如下:

  0%|                  | 0/3000 [00:00<?, ?it/s]
...  
{'loss': 4.2962, 'learning_rate': 0.0196, 'epoch': 0.01}
{'loss': 4.3112, 'learning_rate': 0.019533333333333333, 'epoch': 0.01}
  2%|███▊             | 70/3000 [03:20<4:17:06,  2.81s/it]

即使用了P-Tuning v2进行参数高效微调,但训练的速度还是很慢。

V100 32G显存的机器,训练花了4个多小时,显存占用率在85%左右

可以修改train.sh增大batch_size继续训练,由于时间及机器性能问题,本人没有进行操作过。

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path model/chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 128 \
    --per_device_eval_batch_size 8 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

输出文件:

> ls -al /root/autodl-tmp/tuning-chatglm/ChatGLM-6B/ptuning/output/adgen-chatglm-6b-pt-128-2e-2/
total 12
drwxrwxr-x 2 root root   98 Apr 24 21:12 .
drwxrwxr-x 8 root root  177 Apr 24 17:12 ..
-rw-rw-r-- 1 root root  195 Apr 24 21:12 all_results.json
-rw-rw-r-- 1 root root 1185 Apr 24 21:12 trainer_state.json
-rw-rw-r-- 1 root root  195 Apr 24 21:12 train_results.json

5. 模型评估

修改evaluate.sh文件,修改model_name_or_path(模型路径),ptuning_checkpointP-Tuning v2微调之后的权重路径)等参数:

  • model_name_or_path:原始ChatGLM-6B模型文件路径
  • ptuning_checkpoint:训练完成后,生成的文件目录

运行:bash evaluate.sh

PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=3000

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_predict \
    --validation_file AdvertiseGen/dev.json \
    --test_file AdvertiseGen/dev.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path model/chatglm-6b \
    --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

再次查看output输出文件:

在这里插入图片描述

模型评估花了3个多小时

6. 利用微调后的模型进行验证

6.1 微调后的模型

新建infer.py文件

import os
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer


MODEL_PATH = "./model/chatglm-6b"
CHECKPOINT_PATH = "./output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-1000"

# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

config = AutoConfig.from_pretrained(MODEL_PATH, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(MODEL_PATH, config=config, trust_remote_code=True).cuda()

prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}

for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

print(f"Quantized to 4 bit")
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()


print("用户:你好\n")
response, history = model.chat(tokenizer, "你好", history=[])
print("ChatGLM-6B:\n",response)
print("\n------------------------------------------------\n用户:")

line = input()
while line:
    response, history = model.chat(tokenizer, line, history=history)
    print("ChatGLM-6B:\n", response)
    print("\n------------------------------------------------\n用户:")
    line = input()

6.2 原始ChatGLM-6B模型

新建infer_base.py文件

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("./model/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("./model/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

while True:
    a = input("请输入您的问题:(输入q以退出)")
    if a.strip() == 'q':
        exit()
    response, history = model.chat(tokenizer, "问题:" + a.strip() + '\n答案:', max_length=256, history=[])
    print("回答:", response)

6.3 结果对比

相同输入,上面窗口为原始ChatGLM-6B模型回答,下方为微调后模型回答。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/463218.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

<<c和指针>>温故及问题研讨(第三章)

第三章-数据 1. 前言2. 基本数据类型2.1 整型家族2.2 字面值常量 3. 基本声明3.1 数组的声明以及引用3.2 指针的声明注意事项3.3 隐式声明 4. 常量5. 作用域5.1 代码块作用域5.2 文件作用域5.3 原型作用域 6. 链接属性6.1 链接属性分类以及作用范围6.2 关键字:extern和static6.…

数据库基础篇 《18.MySQL8其它新特性》

第18章_MySQL8其它新特性 1. MySQL8新特性概述 MySQL从5.7版本直接跳跃发布了8.0版本&#xff0c;可见这是一个令人兴奋的里程碑版本。MySQL 8版本在功能上做了显著的改进与增强&#xff0c;开发者对MySQL的源代码进行了重构&#xff0c;最突出的一点是多MySQL Optimizer优化…

GPT详细安装教程-GPT软件国内也能使用

GPT (Generative Pre-trained Transformer) 是一种基于 Transformer 模型的自然语言处理模型&#xff0c;由 OpenAI 提出&#xff0c;可以应用于各种任务&#xff0c;如对话系统、文本生成、机器翻译等。GPT-3 是目前最大的语言模型之一&#xff0c;其预训练参数超过了 13 亿个…

LeetCode:206. 反转链表

&#x1f34e;道阻且长&#xff0c;行则将至。&#x1f353; &#x1f33b;算法&#xff0c;不如说它是一种思考方式&#x1f340; 算法专栏&#xff1a; &#x1f449;&#x1f3fb;123 一、&#x1f331;206. 反转链表 题目描述&#xff1a;给你单链表的头节点 head &#x…

Sharding-JDBC之垂直分库水平分表

目录 一、简介二、maven依赖三、数据库3.1、创建数据库3.2、订单表3.3、用户表 四、配置&#xff08;二选一&#xff09;4.1、properties配置4.2、yml配置 五、实现5.1、实体5.2、持久层5.3、服务层5.4、测试类5.4.1、保存订单数据5.4.2、查询订单数据5.4.3、保存用户数据5.4.4…

Android SeekBar控制视频播放进度(二)——seekTo()不准确

Android SeekBar控制视频播放进度二——seekTo不准确 简介seekTo()视频帧 和 视频关键帧解决办法方法一方法二 简介 上一篇文章中&#xff0c;我们介绍了使用SeekBar控制视频播放&#xff0c;使用过程中发现&#xff0c;对于一些视频&#xff0c;我们拖动SeekBar进度条调节播放…

喜报 | ScanA内容安全云监测获评“新一代信息技术创新产品”

4月20日&#xff0c;在赛迪主办的2023 IT市场年会上&#xff0c;“年度IT市场权威榜单”正式发布。 知道创宇的ScanA内容安全云监测产品荣获“新一代信息技术创新产品”奖项。作为中国IT业界延续时间最长的年度盛会之一&#xff0c;历届IT市场年会公布的IT市场权威榜单已成为市…

备份数据看这里,免费教你苹果手机怎么备份所有数据!

案例&#xff1a;苹果手机怎么算备份成功&#xff1f; 【友友们&#xff0c;手机恢复出厂设置前&#xff0c;怎么样可以备份苹果手机里面的所有数据&#xff1f;】 苹果手机备份数据对于用户来说是非常重要的。在备份数据的同时&#xff0c;还需要学会如何恢复误删的数据。那么…

【微服务笔记22】微服务组件之Sentinel控制台的使用(Sentinel Dashboard)

这篇文章&#xff0c;主要介绍微服务组件之Sentinel控制台的使用&#xff08;Sentinel Dashboard&#xff09;。 目录 一、Sentinel控制台 1.1、下载Dashboard控制台 1.2、搭建测试工程 &#xff08;1&#xff09;引入依赖 &#xff08;2&#xff09;添加配置信息 &#…

微服务生态 -- dubbo -- dubbo3应用级别服务发现(阅读官方文档)

服务发现概述 从 Internet 刚开始兴起&#xff0c;如何动态感知后端服务的地址变化就是一个必须要面对的问题&#xff0c;为此人们定义了 DNS 协议&#xff0c;基于此协议&#xff0c;调用方只需要记住由固定字符串组成的域名&#xff0c;就能轻松完成对后端服务的访问&#x…

236. 二叉树的最近公共祖先【190】

难度等级&#xff1a;中等 上一篇算法&#xff1a; 103. 二叉树的锯齿形层序遍历【191】 力扣此题地址&#xff1a; 236. 二叉树的最近公共祖先 - 力扣&#xff08;Leetcode&#xff09; 1.题目&#xff1a;236. 二叉树的最近公共祖先 给定一个二叉树, 找到该树中两个指定节点…

【MySQL】数据表的增删查改

1、CRUD的解释 C&#xff1a;Create增加 R&#xff1a;Retrieve查询 U&#xff1a;Update更新 D&#xff1a;Deleta删除 2、添加数据 2.1 添加一条记录 添加数据是对表进行添加数据的&#xff0c;表在数据库中&#xff0c;所以还是得先选中数据库&#xff0c;选中数据库还在进行…

STM32F429移植microPython笔记

目录 一、microPython下载。二、安装开发环境。三、编译开发板源码。四、下载验证。 一、microPython下载。 https://micropython.org/download/官网 下载后放在linux中。 解压命令&#xff1a; tar -xvf micropython-1.19.1.tar.xz 二、安装开发环境。 sudo apt-get inst…

MUSIC算法仿真

DOA波达方向估计 DOA&#xff08;Direction Of Arrival&#xff09;波达方向是指通过阵列信号处理来估计来波的方向&#xff0c;这里的信源可能是多个&#xff0c;角度也有多个。DOA技术主要有ARMA谱分析、最大似然法、熵谱分析法和特征分解法&#xff0c;特征分解法主要有MUS…

HTML+CSS+JS 学习笔记(四)———jQuery

&#x1f331;博客主页&#xff1a;大寄一场. &#x1f331;系列专栏&#xff1a;前端 &#x1f331;往期回顾&#xff1a; &#x1f618;博客制作不易欢迎各位&#x1f44d;点赞⭐收藏➕关注​​ 目录 jQuery 基础 jQuery 概述 下载与配置jQuery 2. 配置jQuery jQuery 选…

数据库管理-第七十期 自己?自己(20230425)

数据库管理 2023-04-25 第七十期 自己&#xff1f;自己1 自己吓自己2 自己坑自己3 自己挺自己4 自己懵自己总结 第七十期 自己&#xff1f;自己 来到70了&#xff0c;最近有点卷&#xff0c;写的稍微多了些。 吐槽一下五一调休&#xff0c;周末砍一天&#xff0c;连6天&#x…

重学Java第一篇——数组

本片博客主要讲述了以下内容&#xff1a; 1、 一维数组和二维数组的创建和初始化方式&#xff1b; 2、数组的遍历和赋值 3、java.util.Arrays的常用方法 4、数组在内存中的分布&#xff08;图示&#xff09; 创建数组和初始化 type[] arr_name;//方式一 type arr_name[];//方式…

一家传统制造企业的上云之旅,怎样成为了数字化转型典范?

众所周知&#xff0c;中国是一个制造业大国。在想要上云以及正在上云的企业当中&#xff0c;传统制造企业也占据了相当大的比例。 那么这类企业在实施数字化转型的时候&#xff0c;应该如何着手&#xff1f;我们不妨来看看一家传统制造企业的现身说法。 国茂股份的数字化转型诉…

云原生-如何部署k8s集群与部署sms集群

阿里云开通三台云服务器实例&#xff0c;&#xff08;同一个vpc下&#xff09;&#xff0c;配置安全组入规则&#xff0c;加入80端口 ssh登录三台云服务器 在三台云服务器上部署容器环境&#xff08;安装docker&#xff09;&#xff08;https://www.yuque.com/leifengyang/oncl…

Springboot Mybatis使用pageHelper实现分页查询

以下介绍实战中数据库框架使用的是mybatis&#xff0c;对整合mybatis此处不做介绍。 使用pageHelper实现分页查询其实非常简单&#xff0c;共两步&#xff1a; 一、导入依赖&#xff1b; 二、添加配置&#xff1b; 那么开始&#xff0c; 第一步&#xff1a; pom.xml添加依…