【大模型】基于ChatGLM进行微调及应用 [更新中......]

news2024/11/17 3:54:41

文章目录

  • 一、前言
  • 二、说明
    • 2.1 代码结构
    • 2.2 依赖包版本
  • 三、启动对话演示
    • 3.1 命令行交互 cli_demo.py
    • 3.2 网页交互 web_demo.py
  • 四、微调模型
    • 4.1 基于 P-Tuning v2 微调模型
      • 4.1.1 软件依赖
      • 4.1.2 下载数据集
      • 4.1.3 下载模型文件
      • 4.1.4 操作步骤
    • 4.2 基于 Full Parameter 微调模型
    • 4.3 基于LoRA微调模型
  • 参考资料

一、前言

ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。 ChatGLM-6B 使用了和 ChatGPT 相似的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。

  • 智谱AI官方内测网站:ChatGLM

  • 项目地址:https://github.com/THUDM/ChatGLM-6B/tree/main

  • 模型文件:https://huggingface.co/THUDM/chatglm-6b/tree/main

  • B站讲解视频:【官方教程】ChatGLM-6B 微调:P-Tuning,LoRA,Full parameter

  • 七月在线博客:ChatGLM两代的部署/微调/实现:从基座GLM、ChatGLM的LoRA/P-Tuning微调、6B源码解读到ChatGLM2的微调与实现

二、说明

2.1 代码结构

在这里插入图片描述

2.2 依赖包版本

由于大模型相关的各种依赖包版本更新较快,会导致各种报错,如:

'ChatGLMTokenizer' object has no attribute 'sp_tokenizer'

这里主要是由transformers 版本问题导致的,解决方案可以参考博客:https://blog.csdn.net/Tink_bell/article/details/137942170

三、启动对话演示

ChatGLM-6B下提供了cli_demo.py和web_demo.py两个文件来启动模型:

  • cli_demo.py:使用命令行进行交互。
  • web_demo.py:使用gradio库使用本机服务器进行网页交互。

这里,依赖包的版本会影响到代码的运行。经过多次报错与尝试,我这里使用的依赖包版本为:

transformers==4.33.0
gradio==3.39.0

3.1 命令行交互 cli_demo.py

[待补充]

3.2 网页交互 web_demo.py

web_demo.py中基于gradio库使用本机服务器进行网页交互,具体运行步骤如下:

(1)模型路径测试。

首先需要将模型地址配置为本地模型路径
原代码:

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

修改为本地模型路径后的代码:

model_path = './model/chatglm-6b'
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()

(2)模型量化。

这里,由于我的GPU内存不够,所以对模型做量化操作:

model_path = './model/chatglm-6b'

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, quantization_config=bnb_config, device_map={"":0}, trust_remote_code=True)
model = model.eval()

(3)运行 web_demo.py

可以看到控制台输出
在这里插入图片描述

说明网页交互服务已在本地 http://127.0.0.1:7860 运行起来。

查看GPU显存占用情况,可以看到使用模型量化,最后只占用了不到5GB的显存:
在这里插入图片描述

(4)端口映射。

这里,由于我们是在远程服务器上运行的服务,如果想在本地浏览器访问,需要做一个端口映射,具体命令如下:

ssh -L 1234:localhost:7860 root@172.xxx.yyy.zzz

基于上述命令,将远程服务器的 7860 端口映射至本地 1234 端口。

(5)本地访问服务。

然后我们在本地浏览器打开 http://localhost:1234/ 即可访问该页面,如下所示:
在这里插入图片描述

在输入窗口输入文本信息并提交即可实现调用ChatGLM的对话功能。

四、微调模型

ChatGLM模型的fine-tune有多种模式:

  • P-Tuning v2
  • LoRA
  • Full parameter

其中ChatGLM官方代码仓库中给出了基于 P-Tuning v2Full parameter 的方法,具体微调模型的方式可以参考B站视频:【官方教程】ChatGLM-6B 微调:P-Tuning,LoRA,Full parameter

LoRA的微调方法可以参考:https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/chatglm_v2_6b_lora

4.1 基于 P-Tuning v2 微调模型

官方给出的微调教程:https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/README.md

本仓库实现了对于 ChatGLM-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。

下面以 ADGEN (广告生成) 数据集为例介绍代码的使用方法。

4.1.1 软件依赖

这里,基于P-Tuning v2 微调模型对于transformers的版本有限制,需要4.27.1版本的transformers

pip install transformers==4.27.1

此外,还需要安装以下依赖

pip install rouge_chinese nltk jieba datasets

4.1.2 下载数据集

ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。

{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}

ADGEN 数据集可以从 Google Drive 或者 Tsinghua Cloud 来下载。

4.1.3 下载模型文件

模型文件:https://huggingface.co/THUDM/chatglm-6b/tree/main

可以选择 git clone 或者 手动 的方式来下载模型文件。

4.1.4 操作步骤

(1)文件及数据准备

使用 ptuning 文件夹下的代码进行微调,这里我们在当前目录下创建:

  • model目录存放下载的模型文件
  • data 目录存放ADGEN 数据文件
    在这里插入图片描述

(2)修改 train.sh 代码

根据本地模型及数据文件目录,修改train.sh 中的相应参数:

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file './data/AdvertiseGen/train.json' \
    --validation_file './data/AdvertiseGen/dev.json' \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path './model/chatglm-6b/' \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

【参数说明】:
train.sh 中的 PRE_SEQ_LENLR 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。

在默认配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。

(3)运行脚本

bash train.sh

在这里插入图片描述

这里,我们对模型做了量化处理,可以看到占用的显存占用情况:
在这里插入图片描述

迭代3000次:
在这里插入图片描述

可以通过wandb 查看运行中的参数变化情况:
在这里插入图片描述

wandb中可以看到 train_loss 的变化情况:
在这里插入图片描述

4.2 基于 Full Parameter 微调模型

如果需要进行全参数的 Finetune,需要安装 Deepspeed,然后运行以下指令:

bash ds_train_finetune.sh

4.3 基于LoRA微调模型

ChatGLM官方仓库中的部分微调使用的是基于 P Tuning v2的微调方式,并未给出基于LoRA的微调。

LoRA的微调方法可以参考:https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/chatglm_v2_6b_lora

参考资料

  • ChatGLM两代的部署/微调/实现:从基座GLM、ChatGLM的LoRA/P-Tuning微调、6B源码解读到ChatGLM2的微调与实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1884706.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大模型简介

大模型框架 大模型基于深度学习,利用大量数据和计算资源训练具有大量参数的神经网络模型。通过不断地调整模型参数,使得模型能够在各种任务中取得最佳表现。 通常说的大模型的“大”的特点体现在:参数数量庞大、训练数据量大、计算资源需求…

记一次EasyExcel的错误使用导致的频繁FullGC

记一次EasyExcel的错误使用导致的频繁FullGC 一、背景描述二、场景复现三、原因分析四、解决方案五、思考复盘 一、背景描述 繁忙的校招结束了,美好的大学四年也结束了,作者也有10个月没有更新了。拿到心仪的offer之后也开始了苦B的打工生活。 最近接到…

Python爬取豆瓣电影+数据可视化,爬虫教程!

1. 爬取数据 1.1 导入以下模块 import os import re import time import requests from bs4 import BeautifulSoup from fake_useragent import UserAgent from openpyxl import Workbook, load_workbook1.2 获取每页电影链接 def getonepagelist(url,headers):try:r reque…

JAVA里的BigDecimal用法

public class BigDecimaldemo1 {public static void main(String[] args) {System.out.println(0.090.01);//为什么不是0.10呢?} }在使用float或者double类型的数据在进行数学运算的时候,很有可能会产生精度丢失问题。我们都知道计算机底层在进行运算的时候&#x…

SpringBoot中整合ONLYOFFICE在线编辑

SpringBoot整合OnlyOffice SpringBoot整合OnlyOffice实现在线编辑1. 搭建私有的OnlyOffice的服务2. SpringBoot进行交互2.1 环境2.2 我们的流程2.3 接口规划2.3.1 获取编辑器配置的接口2.3.2 文件下载地址2.3.3 文件下载地址 3. 总结4. 注意4.1 你的项目的地址一定一定要和only…

详细django框架+SIMPLEUI+import_export设计web管理后台(四)

目录 1.项目简介 2.搭建django框架 3.引入 SIMPLEUI插件 3.1安装simpleui 3.2 修改设置 3.3 克隆静态资源 3.4登陆测试 4.优化页面 4.1 修改后台名称显示 4.2 增加页面LOGO图标 4.3增加网址图标:目前主要的浏览器都支持favicon.ico图标 4.4 修改APP名称显…

用摄像头实现识别道路中的车道线、行人与车辆检测(级联分类器、HOG+SVM、行人检测)

基于树莓派的智能小车,用摄像头实现识别道路中的车道线识别、行人检测与车辆检测。 本项目旨在开发一套基于摄像头的智能道路环境感知系统,该系统能够实时识别道路中的车道线、行人与车辆,为自动驾驶汽车、智能交通管理以及辅助驾驶系统提供关…

Go语言数据类型--常量、iota枚举、数据类型分类

变量:程序运行期间,可以改变的量,变量声明需要var关键字。 常量:程序运行期间,不可以改变的量,变量声明需要const关键字。 自动推导 常量的自动推导不能加:; 不同类型数据的声明 可以使用…

华为OD机试 - 表演赛游戏分组 - 动态规划(Java 2024 D卷 200分)

华为OD机试 2024D卷题库疯狂收录中,刷题点这里 专栏导读 本专栏收录于《华为OD机试(JAVA)真题(D卷C卷A卷B卷)》。 刷的越多,抽中的概率越大,每一题都有详细的答题思路、详细的代码注释、样例测…

目标检测算法讲解:从传统方法到深度学习,全面解析检测技术的演进与应用!

在计算机视觉领域,目标检测是一个基本且关键的任务,它不仅涉及图像中对象的识别,还包括确定这些对象的具体位置。这一任务通常通过算法来实现,这些算法能够识别出图像中的一个或多个目标,并给出每个目标的类别和位置。…

【面试系列】产品经理高频面试题及详细解答

欢迎来到我的博客,很高兴能够在这里和您见面!欢迎订阅相关专栏: ⭐️ 全网最全IT互联网公司面试宝典:收集整理全网各大IT互联网公司技术、项目、HR面试真题. ⭐️ AIGC时代的创新与未来:详细讲解AIGC的概念、核心技术、…

4.BeanFactory

可以看出BeanFactory表面上只有getBean相关的方法。 实际上控制反转、基本的依赖注入、Bean的生命周期的各种功能,都是由BeanFactory的实现类来实现的。(DefaultListableBeanFactory) DefaultListableBeanFactory管理单例对象DefaultSinglet…

第11章 规划过程组(11.6规划进度管理)

第11章 规划过程组(二)11.6规划进度管理,在第三版教材第385页;#软考中级##中级系统集成项目管理师# 文字图片音频方式 第一个知识点:主要输出 1、进度管理计划 准确度 定义活动持续时间估算的可接受区间&#xff0…

springboot拦截器,ThreadLocal(每个线程的公共区域)

拦截器 配置信息(拦截所有请求) 其实这种可以作为springAOP作日志记录

flask数据连接池、定制命令

【 一 】数据库连接池 【 1 】flask操作mysql 基本的使用不使用连接池 from flask import Flask, jsonify import pymysqlapp Flask(__name__) app.debug Trueapp.route(/) def index():conn pymysql.connect(userroot,password"123123",host127.0.0.1,databas…

计算两个经纬度之间的球面距离(基于Mysql和PHP实现)

计算两个经纬度之间的球面距离 1、MySQL实现方式 - 基于空间函数(ST_Distance_Sphere)实现 前置条件:确保您使用的是 MySQL 8.0 或更高版本,因为较早的版本对地理空间的支持有限。 1.1 创建表和索引 说明:设置 location 为 point 类型 #…

Wireshark - tshark支持iptables提供数据包

tshark现在的数据包获取方式有两种,分别是读文件、网口监听(af-packet原始套接字)。两种方式在包获取上,都是通过读文件的形式;存在文件io操作,在专门处理大流量的情境下, 我们复用wireshark去做…

DNS访问百度

DNS,英文全称是 domain name system,域名解析系统,它的作用也很明确,就是域名和 IP 相互映射。 假设你要查询 baidu.com 的 IP 地址: 首先会查找浏览器的缓存,看看是否能找到 baidu.com 对应的IP地址,找到就直接返回&…

【NOI-题解】1326. 需要安排几位师傅加工零件1228. 排队打水问题1229. 拦截导弹的系统数量求解

文章目录 一、前言二、问题问题:1326. 需要安排几位师傅加工零件问题:1228. 排队打水问题问题:1229. 拦截导弹的系统数量求解 三、感谢 一、前言 本章节主要对贪心问题进行讲解,包括《1326. 需要安排几位师傅加工零件》《1228. 排…

【嵌入式】探索嵌入式世界:在ARM上构建俄罗斯方块游戏的奇妙之旅

文章目录 前言:1. 简介2. 总体设计思路及功能描述2.1 设计思路2.2 功能描述2.3 程序流程图 3. 各部分程序功能及详细说明3.1 游戏界面函数3.1.1 游戏界面中的图片显示3.1.2 游戏开始界面3.1.3 游戏主界面3.1.4 游戏结束广告界面3.1.5 游戏界面中的触摸反馈3.1.6 游戏…