ChatGLM-6B部署和微调实例

news2024/9/24 18:00:09

文章目录

  • 前言
  • 一、ChatGLM-6B安装
    • 1.1 下载
    • 1.2 环境安装
  • 二、ChatGLM-6B推理
  • 三、P-tuning 微调
    • 3.1微调数据集
    • 3.2微调训练
    • 3.3微调评估
    • 3.4 调用新的模型进行推理
  • 总结


前言

ChatGLM-6B ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。ChatGLM-6B是本人尝试使用和微调的第一个大语言模型,自我感觉该模型很适合作为大语言模型的入门级选手,无论是部署配置还是推理微调都十分方便。本文主要介绍如何配置部署ChatGLM-6B,以及ChatGLM-6B推理和P-tuning v2微调基本步骤,希望可以帮助大家使用ChatGLM-6B。


一、ChatGLM-6B安装

1.1 下载

ChatGLM-6B项目仓库地址为 GitHub,模型文件下载地址为Huggingface,将下载好的模型文件chatglm-6b文件放至项目仓库中的ptuning文件目录下(如下图所示)。整个下载时间的长短根据网速和是否使用远程服务器因人而异,本人因使用的是远程服务器,下载时间共约5个小时。
在这里插入图片描述

1.2 环境安装

服务器的版本为RTX 3090,内存为24GB。Python版本为3.8.16,ubuntu的版本为20.04,Cuda的版本为11.6

库名版本
transformers4.27.1
torch1.13.1

详情可见requirements.txt,其中gradio库有的时候会安装失败,如果后续不考虑前端交互的平台的构建,此库可以先不安装,并不影响模型推理和微调。环境配置步骤如下代码所示:

conda create -n test python=3.8.16 -y
source activate test
pip install -r requirements.txt
cd ChatGLM

二、ChatGLM-6B推理

ChatGLM-6B推理部分,只要找到cli_demo.py文件运行即可。

tokenizer = AutoTokenizer.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

以下是推理部分的展示:
在这里插入图片描述
当然我们也想要批量式询问ChatGLM-6B,这里我自己写了一个批量调用的py文件:

import torch
from transformers import AutoTokenizer, AutoModel
import torch
import sys
import pandas as pd
model_path="ptuning/chatglm-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained("ptuning/chatglm-6b",trust_remote_code=True).float()
#model =model.to("cpu")
model = model.eval()
data = pd.read_csv('Q1.csv')
MC = data['Question'].tolist()
j = -1
for i in MC:
    j = j+1
    input1 = f"{i}"
    print(input1)
    response,history = model.chat(tokenizer,input1,history=[],temperature=1)
    print(response)
    print("--------------------------------------------------")
    data['Answer'].loc[j] = response
    data.to_csv('Q1.csv',index = False,encoding='utf_8_sig')

最终Q1.csv的结果为:
在这里插入图片描述

三、P-tuning 微调

下图展示出ChatGLM-6B进行P-tuning v2微调的大致流程,首先需要构建好微调模型使用的数据集(包括训练集,验证集和测试集),接着是配置运行train.sh,进行数小时的训练之后将会得到模型参数权重文件Checkpoint,然后对evaluate.sh进行参数配置和运行,将会得到一系列的测试集结果,到此便是微调部分。为了检测微调后的模型在新数据上的效果,可以对cli_demo.py文件进行配置和运行。
在这里插入图片描述

3.1微调数据集

我的课题是研究法律判决预测任务,因此我的微调数据集的输入为案情陈述,输出为罪行判决。ChatGLM-6B的微调数据集有很多的格式可以选择,这里是经典的content+summary格式。以下是一个例子🌰:
{“content”: “经审理查明,2017年10月10日18时左右,被告人张某酒后驾驶牌号为川Q???二轮摩托车从宜宾市翠屏区牟坪镇牟坪村5组35号家中出发,前往宜宾市翠屏区牟坪镇派出所办事,被办案民警发现被告人张某饮酒驾驶机动车,即对张某进行了呼气式酒精检测,检出酒精含量268mg/100mL。”, “summary”: “根据中华人民共和国刑法第133条,判处张某危险驾驶罪。其中检出张某酒精含量268mg/100mL,根据中国的交通法规,血液中酒精含量超过80mg/100ml,即被认定为醉驾,因此张某符合危险驾驶罪中的醉酒驾驶机动车。”}
我们将标注好的数据分成训练集、验证集和测试集,一起存入Legal_data文件夹中,并放在ptuning目录下,如下图所示:
在这里插入图片描述
其中train.json中有44条数据, dev.json中有10条数据, test.json中有10条数据。数据量不大,只是为了方便走一遍微调流程,大家可以在创建自己的微调数据集的时候多标注些,这样会大大提高模型的性能。

3.2微调训练

查看train.sh,模型主要的训练参数有PRE_SEQ_LEN,max_target_length,max_source_length,learning_rate,per_device_train_batch_size,max_steps,per_device_train_batch_size,gradient_accumulation_stepsquantization_bit。下面将详细介绍这些训练参数的含义与作用。
PRE_SEQ_LEN是指自然语言指令的长度,而max_source_length是指整个输入序列的最大长度,max_target_length指整个输出序列的最大长度。 一般来说,PRE_SEQ_LEN应该小于或者等于max_source_length,因为输入序列除了包含指令之外,还可能包含其他内容,例如上下文信息或对话历史。根据微调标注数据的输入输出的文本长度,我们设置PRE_SEQ_LEN为128,max_target_lengthmax_source_length为300。
learning_rate是一个关键参数,它决定了每次更新模型权重时,根据梯度下降的方向应该迈出多大的步伐。我们使用ChatGLM的默认值1e-3。
per_device_train_batch_size设置为1,那么每个设备上将会有1个样本作为输入进行模型训练,并且基于这1个样本的损失值来进行一次模型参数的更新。
gradient_accumulation_steps指定了在执行一次模型权重更新(即一次反向传播步骤)之前,要累积多少个批次的梯度。gradient_accumulation_steps的值为16,这样模型会在每16个批次后进行一次权重更新,等价于使用大小为 16的批次进行训练。
max_steps 参数用于指定在结束训练前,模型应进行多少步的更新。每一“步”通常包括一个前向传播和一个反向传播,并且可能涉及到多个批次,在实验中max_steps设置为300。
quantization_bit 参数通常关联到模型权重的量化过程。量化是一种将模型权重从浮点数转换为低精度(如整数)表示形式的技术。这个过程可以显著减少模型的存储需求和计算复杂性,从而提高推理速度并减少内存使用。

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file Legal_data/train.json \
    --validation_file Legal_data/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 300 \
    --max_target_length 300 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 300 \
    --logging_steps 10 \
    --save_steps 100 \
    --learning_rate $LR \
    # --report_to tensorboard \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

在运行sh train.sh之前,我们需要额外安装datasetsjiebarouge_chinesenltk库:

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple datasets
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple jieba
pip install rouge_chinese
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple nltk

微调训练成功的截图如下:
在这里插入图片描述
由于logging_steps 设置为10,因此每运行十个step,便会保留一次loss,learning_rate和epoch值。最后也会输出一个train metrics
在这里插入图片描述
在这里插入图片描述

sh train.sh完成之后会发现ptuning路径下有新的文件夹output,文件夹中保存了微调训练后的checkpoint和一些评估指标,下图中有三个checkpoint是因为max_steps 为300,save_steps为100。在这里插入图片描述

3.3微调评估

执行sh evaluate.sh,注意evaluate.sh中的一些参数要和train.sh中的参数一致。如max_source_length,max_target_lengthPRE_SEQ_LEN

PRE_SEQ_LEN=128
CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
STEP=300

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_predict \
    --validation_file Legal_data/dev.json \
    --test_file Legal_data/test.json \
    --overwrite_cache \
    --prompt_column content \
    --response_column summary \
    --model_name_or_path chatglm-6b \
    --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
    --output_dir ./output/$CHECKPOINT \
    --overwrite_output_dir \
    --max_source_length 300 \
    --max_target_length 300 \
    --per_device_eval_batch_size 1 \
    --predict_with_generate \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

sh evaluate.sh执行成功的截图如下:
在这里插入图片描述
sh evaluate.sh完成之后会发现output文件夹中保存了评估后的generated_predictions.txt``predict_results.json文件。

3.4 调用新的模型进行推理

完成微调后的模型在测试集上的评估之后,我们如何使用微调好的模型进行推理呢?这里我们以cli_demo.py为例,之前的是这样的调用模型:

tokenizer = AutoTokenizer.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()

我们仅需要将上面的代码变为:

tokenizer = AutoTokenizer.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True)

config = AutoConfig.from_pretrained("ptuning/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("ptuning/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join('ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-300', "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
model = model.quantize(4)
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

主要注意⚠️的是以下一行代码,需要导入微调好的checkpoint:

prefix_state_dict = torch.load(os.path.join('ptuning/output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-300', "pytorch_model.bin"))

这样我们便可以调用自己微调好的ChatGLM-6B模型:
在这里插入图片描述


总结

ChatGLM-6B对于中文的问题回答能力优秀,希望大家可以通过我的分享来测试它❤️❤️❤️

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1397380.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

unity-shader笔记OLD

shader shader在面板中的位置相关代码代码切换shader shader在面板中的位置 选中物体属性面板中 相关代码 代码切换shader 挂载到怪物上的shader名字统一叫body,然后获取上面的SkinnedMeshRender SkinnedMeshRender smr; //恢复到原来的shader …

JavaScript DOM可以做什么?

1、通过id获取标签元素 DOM是文档对象模型&#xff0c;它提供了一些属性和方法来方便我们操作document对象&#xff0c;比如getElementById()方法可以通过某个标签元素的id来获取这个标签元素 // 用法 window.document.getElementById(id); // 例子 <!DOCTYPE html> &l…

LeetCode、374. 猜数字大小【简单,二分】

文章目录 前言LeetCode、374. 猜数字大小【简单&#xff0c;二分】题目及类型思路及代码实现 资料获取 前言 博主介绍&#xff1a;✌目前全网粉丝2W&#xff0c;csdn博客专家、Java领域优质创作者&#xff0c;博客之星、阿里云平台优质作者、专注于Java后端技术领域。 涵盖技…

移动云助力智慧交通数智化升级

智慧交通是在整个交通运输领域充分利用物联网、空间感知、云计算、移动互联网等新一代信息技术&#xff0c;综合运用交通科学、系统方法、人工智能、知识挖掘等理论与工具&#xff0c;以全面感知、深度融合、主动服务、科学决策为目标&#xff0c;推动交通运输更安全、更高效、…

联想拯救者冠名2024第二届OPENAIGC开发者大赛,开启AI落地新纪元

2024年1月17日&#xff0c;在联想拯救者及消费生态新品发布会上&#xff0c;AIGC开放社区携手联想拯救者&#xff0c;宣布将共同举办“AI生成未来第二届拯救者杯OPENAIGC开发者大赛”。此次大赛旨在集结所有开发者的智慧和创造力&#xff0c;推动人工智能技术的创新和应用实践。…

Promise的几道基础题

event loop它的执行顺序&#xff1a; 一开始整个脚本作为一个宏任务执行执行过程中同步代码直接执行&#xff0c;宏任务进入宏任务队列&#xff0c;微任务进入微任务队列当前宏任务执行完出队&#xff0c;检查微任务列表&#xff0c;有则依次执行&#xff0c;直到全部执行完执…

服务器自动拉取git代码运行脚本

# 1.场景分析 工作中常常会遇到本地编辑shell脚本或者python脚本完成后需要在服务器上运行的情况&#xff0c;每次进行拷贝费时费力。下面介绍下通过git管理器&#xff0c;实现本地与服务器代码同步的方式。选择公司搭建的gitlab为例&#xff1a; 2.gitlab配置服务器ssh密钥 …

免费的爬虫软件【2024最新】

在国际市场竞争日益激烈的背景下&#xff0c;国外网站的SEO排名直接关系到网站在搜索引擎中的曝光度和用户点击量。良好的SEO排名能够带来更多的有针对性的流量&#xff0c;提升网站的知名度和竞争力。 二、国外网站SEO排名的三种方法 关键词优化&#xff1a; 关键词优化是SEO…

【IAP】核心开发流程

最近做了IAP U盘升级模块开发&#xff0c;总结下IAP基本开发流程&#xff0c;不深入讨论原理。 详细原理参考 首先需要知道我们需要把之前的APP区域拆一块出来做BOOT升级程序区域。 以STM32F103为例&#xff0c;0x08000000到0x0807FFFF为FLASH空间&#xff0c;即上图代码区域…

Java基础面试题-2day

面向对象 创建一个对象用什么运算符&#xff0c;对象实体和对象引用有什么不同&#xff1f; 创建对象使用new String A new String(); A即为对象引用&#xff0c;通过new运算符&#xff0c;创建String()类型的对象实体。 对象引用的存储位置在栈内存 对象实体的存储位置在堆…

2024玩儿转TikTok之环境介绍及搭建

郑重申明&#xff1a;本文章只对合法合理做tiktok视频运营的用户做学习交流使用&#xff0c;有其他使用不当的违规违法行为后果自负&#xff01; 一、网络环境图介绍&#xff1a;我们只需要保证红色的环境通畅即可(手机阿里tiktok运营专用服务器) 二、服务器部分环境搭建 1、…

STM32F103标准外设库——SysTick系统定时器(八)

个人名片&#xff1a; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的在校大学生 &#x1f42f;个人主页&#xff1a;妄北y &#x1f427;个人QQ&#xff1a;2061314755 &#x1f43b;个人邮箱&#xff1a;2061314755qq.com &#x1f989;个人WeChat&#xff1a;V…

软件测试|sqlalchemy一对一关系详解

简介 SQLAlchemy 是一个强大的 Python ORM&#xff08;对象关系映射&#xff09;库&#xff0c;它允许我们将数据库表映射到 Python 对象&#xff0c;并提供了丰富的关系模型来处理不同类型的关系&#xff0c;包括一对一关系。在本文中&#xff0c;我们将深入探讨 SQLAlchemy …

大数据工作岗位需求分析

前言&#xff1a;随着大数据需求的增多&#xff0c;许多中小公司和团队也新增或扩展了大数据工作岗位&#xff1b;但是却对大数据要做什么和能做什么&#xff0c;没有深入的认识&#xff1b;往往是招了大数据岗位&#xff0c;搭建起基础能力后&#xff0c;就一直处于重复开发和…

基于springboot+vue的校园周边美食探索及分享平台系统(前后端分离)

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容&#xff1a;毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 项目背景…

网络知识梳理:TCP与UDP

TCP&#xff08;传输控制协议&#xff09;和UDP&#xff08;用户数据报协议&#xff09;是两种主要的互联网协议&#xff0c;它们都用于在网络层和传输层进行数据传输&#xff0c;但它们在数据传输的方式和特性上有显著的区别&#xff1a; 1.TCP (传输控制协议) 连接导向&…

贪心算法 ——硬币兑换、区间调度、

硬币兑换&#xff1a; from book&#xff1a;挑战程序设计竞赛 思路&#xff1a;优先使用大面额兑换即可 package mainimport "fmt"func main() {results : []int{}//记录每一种数额的张数A : 620B : A//备份cnts : 0 //记录至少需要多少张nums : []int{1, 5, 10, 5…

万户 ezOFFICE wf_printnum.jsp SQL注入漏洞复现

0x01 产品简介 万户OA ezoffice是万户网络协同办公产品多年来一直将主要精力致力于中高端市场的一款OA协同办公软件产品,统一的基础管理平台,实现用户数据统一管理、权限统一分配、身份统一认证。统一规划门户网站群和协同办公平台,将外网信息维护、客户服务、互动交流和日…

vue中父组件异步传值,渲染问题

vue中父组件异步传值&#xff0c;渲染问题 父组件异步传值&#xff0c;子组件渲染不出来。有如下两种解决方法&#xff1a; 1、用v-if解决&#xff0c;当父组件有数据才渲染 <Child v-if"dataList && dataList.length > 0" :data-list"dataLis…

深度学习记录--mini-batch gradient descent

batch vs mini-batch gradient descent batch&#xff1a;段&#xff0c;块 与传统的batch梯度下降不同&#xff0c;mini-batch gradient descent将数据分成多个子集&#xff0c;分别进行处理&#xff0c;在数据量非常巨大的情况下&#xff0c;这样处理可以及时进行梯度下降&…