使用 Amazon SageMaker 微调和部署 ChatGLM 模型

news2025/1/11 2:36:46

63da301a0d8a94f5737fbb6318645ea1.gif

本篇文章主要介绍如何使用 Amazon SageMaker 进行 ChatGLM 模型部署和微调的示例。

这个示例主要包括:

  1. ChatGLM 总体介绍

  2. ChatGLM 微调介绍

  3. ChatGLM 环境设置

  4. ChatGLM 微调训练

  5. ChatGLM 部署测试

前言

大语言模型是一种基于深度学习技术的人工智能模型,可以追溯到早期的语言模型和机器翻译系统。直到最近,随着深度学习技术的崛起,大型预训练语言模型才开始引起广泛的关注。

大语言模型使用大规模的文本数据集进行预训练,从而学习到丰富的语言知识和语境理解能力。通过预训练和微调的方式,大语言模型可以用于各种自然语言处理任务,例如文本生成、机器翻译、问答系统、对话系统等。它们在许多领域都展示出了令人印象深刻的性能,并成为推动人工智能技术发展的重要驱动力。

ChatGLM 总体介绍

ChatGLM 模型是由清华大学开源的、支持中英双语问答的对话语言模型,并针对中文进行了优化。该模型基于 General Language Model(GLM)架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署。

ChatGLM 具备以下特点:

  • 充分的中英双语预训练:ChatGLM 在 1:1 比例的中英语料上训练了 1T 的 token 量,兼具双语能力。

  • 优化的模型架构和大小:修正了二维 RoPE 位置编码实现。6B(62 亿)的参数大小,也使得研究者和个人开发者自己微调和部署 ChatGLM 成为可能。

  • 较低的部署门槛:FP16 半精度下,ChatGLM 需要至少 13 GB 的显存进行推理,结合模型量化技术,这一需求可以进一步降低到 10GB(INT8) 和 6GB(INT4),使得 ChatGLM 可以部署在消费级显卡上。

  • 更长的序列长度:ChatGLM 序列长度达 2048,支持更长对话和应用。

ChatGLM 微调介绍

模型微调主要分为 Full Fine-Tune 和 PEFT (Performance-Efficient Fine-Tune),前者模型全部参数都会进行更新,训练时间较长,训练资源较大;而后者会冻结大部分参数、微调训练网络结构,常见的方式是 LoRA 和 P-Tuning v2。对于 ChatGLM 来说,选择 P-Tuning v2 进行模型微调,其网络结构如下:在Transformers 的所有层均增加 Prompt/Prefix。

959dcb0a647654f6aef6c0266ba389a9.jpeg

ChatGLM 环境设置

备注:项目中的示例代码均保存于代码仓库,地址如下:

https://github.com/GlockGao/aws-sagemaker-llm

1. 升级 Python SDK

pip install --upgrade boto3
pip install --upgrade sagemaker
pip install huggingface_hub

2. 获取运行时资源,包括区域、角色、账号、S3 桶等

import boto3
import sagemaker
from sagemaker import get_execution_role


sess = sagemaker.Session()
role = get_execution_role()
sagemaker_default_bucket = sess.default_bucket()


account = sess.boto_session.client("sts").get_caller_identity()["Account"]
region = sess.boto_session.region_name

左滑查看更多

ChatGLM 微调训练

准备微调

克隆代码

rm -rf ChatGLM-6B
git clone https://github.com/THUDM/ChatGLM-6B.git
cd ChatGLM-6B
git checkout 163f94e160f08751545e3722730f1832d73b92d1

左滑查看更多

下载数据集

此处采用示例的广告数据集。根据输入实现广告语的输出,格式如下:

{
 "content": "类型#上衣版型#宽松版型#显瘦图案#线条衣样式#衬衫衣袖型#泡泡袖衣款式#抽绳",
 "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
# 下载 ADGEN 数据集
wget -O AdvertiseGen.tar.gz https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1


# 解压数据集
tar -xzvf AdvertiseGen.tar.gz

左滑查看更多

下载 ChatGLM 原始模型

from huggingface_hub import snapshot_download
from pathlib import Path




local_cache_path = Path("./model")
local_cache_path.mkdir(exist_ok=True)


model_name = "THUDM/chatglm-6b"


# Only download pytorch checkpoint files
allow_patterns = ["*.json", "*.pt", "*.bin", "*.model", "*.py"]


model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_cache_path,
    allow_patterns=allow_patterns,
)


# Get the model files path
import os
from glob import glob


local_model_path = None


paths = os.walk(r'./model')
for root, dirs, files in paths:
    for file in files:
        if file == 'config.json':
            # print(os.path.join(root, file))
            local_model_path = str(os.path.join(root, file))[0:-11]
            print(local_model_path)
if local_model_path == None:
    print("Model download may failed, please check prior step!")

左滑查看更多

拷贝模型和数据到 S3

chmod +x ./s5cmd
./s5cmd sync ${local_model_path} s3://${sagemaker_default_bucket}/llm/models/chatglm/original-6B/
./s5cmd sync ./AdvertiseGen/ s3://${sagemaker_default_bucket}/llm/datasets/chatglm/AdvertiseGen/


rm -rf model
rm -rf AdvertiseGen
rm -rf AdvertiseGen.tar.gz

左滑查看更多

模型微调

模型的微调使用 P-Tuning v2,以实现成本和效果的平衡。

模型微调更改的源代码较多,具体可以参考上述 git 仓库。

模型微调参数

模型微调设置的关键参数如下:

  1. 前缀词长度:128

  2. 学习率:2e-2,确保 loss 在训练过程中下降

  3. batch size:1

  4. gradient accumulation step:16

  5. 训练步长:50,步长仅设置为 50 步,已经可以看出比较明显的微调结果

import time
from sagemaker.huggingface import HuggingFace




PRE_SEQ_LEN=128
LR=2e-2
BATCH_SIZE=1
GRADIENT_ACCUMULATION_STEPS=16
TRAIN_STEPS=50


job_name = f'huggingface-chatglm-finetune-ptuning-{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}'


instance_type  = "ml.g4dn.2xlarge"
instance_count = 1


# 基础模型存放地址
model_name_or_path = 's3://{}/llm/models/chatglm/original-6B/'.format(sagemaker_default_bucket)


# 微调模型输出地址
output_dir         = '/opt/ml/model/adgen-chatglm-6b-ft'
model_s3_path      = 's3://{}/llm/models/chatglm/finetune-ptuning-adgen/'.format(sagemaker_default_bucket)


# 模型环境变量设置
environment = {
    'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:32',
    'TRAIN_DATASET'          : '/opt/ml/input/data/AdvertiseGen/train.json',
    'TEST_DATASET'           : '/opt/ml/input/data/AdvertiseGen/dev.json',
    'PROMPT_COLUMN'          : 'content',
    'RESPONSE_COLUMN'        : 'summary',
    'MODEL_NAME_OR_PATH'     : model_name_or_path,
    'OUTPUT_DIR'             : output_dir,
    'MODEL_OUTPUT_S3_PATH'   : model_s3_path,
    'TRAIN_STEPS'            : '50'
}


inputs = {
   'AdvertiseGen': f"s3://{sagemaker_default_bucket}/llm/datasets/chatglm/AdvertiseGen/"
}

左滑查看更多

开启模型微调

huggingface_estimator = HuggingFace(
    entry_point          = 'sm_ptune_train.py',
    source_dir           = './ChatGLM-6B/ptuning',
    instance_type        = instance_type,
    instance_count       = instance_count,
    base_job_name        = job_name,
    role                 = role,
    script_mode          = True,
    transformers_version = '4.26',
    pytorch_version      = '1.13',
    py_version           = 'py39',
    environment          = environment
)


huggingface_estimator.fit(inputs=inputs)

左滑查看更多

ChatGLM 部署测试

模型部署

1. 准备 Dummy 模型

!touch dummy
!tar czvf model.tar.gz dummy
assets_dir = 's3://{0}/{1}/assets/'.format(sagemaker_default_bucket, 'chatglm')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(sagemaker_default_bucket, 'chatglm')
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

左滑查看更多

2. 配置模型参数

from sagemaker.pytorch.model import PyTorchModel


model_name                  = None
entry_point                 = 'chatglm-inference-finetune.py'
framework_version           = '1.13.1'
py_version                  = 'py39'
base_model_name_or_path     = 's3://{}/llm/models/chatglm/original-6B/'.format(sagemaker_default_bucket)
finetune_model_name_or_path = 's3://{}/llm/models/chatglm/finetune-ptuning-adgen/adgen-chatglm-6b-ft/checkpoint-50/pytorch_model.bin'.format(sagemaker_default_bucket)


# 模型环境变量设置
model_environment  = {
    'SAGEMAKER_MODEL_SERVER_TIMEOUT': '600',
    'SAGEMAKER_MODEL_SERVER_WORKERS': '1',
    'MODEL_NAME_OR_PATH'            : base_model_name_or_path,
    'PRE_SEQ_LEN'                   : '128',
    'FINETUNE_MODEL_NAME_OR_PATH'   : finetune_model_name_or_path,
}


model = PyTorchModel(
    name              = model_name,
    model_data        = model_data,
    entry_point       = entry_point,
    source_dir        = './code',
    role              = role,
    framework_version = framework_version, 
    py_version        = py_version,
    env               = model_environment
)

左滑查看更多

3. 部署微调模型

from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer


endpoint_name         = None
instance_type         = 'ml.g4dn.2xlarge'
instance_count        = 1


predictor = model.deploy(
    endpoint_name          = endpoint_name,
    instance_type          = instance_type, 
    initial_instance_count = instance_count,
    serializer             = JSONSerializer(),
    deserializer           = JSONDeserializer()
)

左滑查看更多

4. 其中关键的模型加载代码如下:加载原始的 ChatGLM 模型、同时加载 FineTune 的 PrefixEncoder 参数共同进行推理

import torch
import os


from transformers import AutoConfig, AutoModel, AutoTokenizer


# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)


# 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数):
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)


model = model.quantize(4)
model.half().cuda()

左滑查看更多

模型微调前后对比

1. 模型测试

inputs = {
    "ask": "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞"


}


response = predictor.predict(inputs)
print(response["answer"])

左滑查看更多

2. 对比原始 ChatGLM 模型,对于相同的输入,输出更偏广告词,而不是单纯的语义提取

8956e4b5782c23fc48b9a3388a8d43eb.jpeg

3. 清除资源

predictor.delete_endpoint()

总结

大语言模型方兴未艾,正在以各种方式改变和影响着整个世界。客户拥抱大语言模型,亚马逊云科技团队同样在深耕客户需求和大语言模型技术,可以在未来更好地协助客户实现需求、提升业务价值。

本篇作者

6f79d75ab486b7ea0f2fab7918a2f90c.jpeg

高郁

亚马逊云科技解决方案架构师,主要负责企业客户上云,帮助客户进行云架构设计和技术咨询,专注于智能湖仓、AI/ML 等技术方向。

6ef09b601d749aca120e4a99f614edb7.gif

c11c749e0e473685a0f573d0e7810f92.gif

听说,点完下面4个按钮

就不会碰到bug了!

58feebdfc8ea1f23d6d413656219328c.gif

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/972089.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python使用pymysql三方库操作 mysql数据库

为什么要使用pymysql 在使用Python工作与学习中难免会使用到mysql数据库,使用pymysql三方库可以让我们轻松的对数据库的记录进行操作,如创建、修改,删除表,如增加、删除、修改、查询数据表中的记录,下边记录一下pymysq…

Hive 表注释乱码解决

文章目录 出现原因MySQL 字符集修改调整元数据库字符集测试 出现原因 一般 Hive 的元数据信息都存储在 MySQL 中,但 MySQL 数据库中的 character_set_server 和 character_set_database 参数,默认都为 latin1 字符集,这两个参数决定了服务器…

如何利用客户旅程打造好的用户体验?

在当今竞争激烈的市场中,提供卓越的用户体验已经成为企业脱颖而出的关键因素之一。客户旅程是实现出色用户体验的有力工具之一,而HubSpot的客户旅程规划功能为企业提供了强大的支持,帮助他们更好地理解、管理和改善客户的互动过程。今天运营坛…

【ubuntu22.04 文件管理器nautilus配置默认终端为alacritty】

前言 ubuntu默认的终端不能通过设置里的默认应用程序配置nautilus是ubuntu自带的文件管理器,包管理器里面只有nautilus-extension-gnome-terminal而没有提供大多终端update-alternatives工具可以修改系统的默认终端(ctrl-alt-t),但对nautilus文件管理器…

Slint学习文档

Slint学习文档 Slint Learn如何学习本文档学习顺序标志说明 Slint With VSCodeSlint With Rust依赖👎定义宏 Slint与Rust分离1.添加编译依赖(slint-build)2.编写slint文件3.编写build.rs4.编写main.rs 普通组件主窗体Windowexample 文本Texte…

MySQL - Left Join和Inner Join的效率对比,以及优化

最近在写代码的时候,遇到了需要多表连接的一个问题,初始sql类似于: select * from a left join b on a.id b.aid left join c on c.bid b.id left join d on d.cid c.id 这样的多个left join组合,总觉得这种写法是有问题…

借助AI分析哥斯拉木马原理与Tomcat回显链路挖掘

前言 本次分析使用了ChatGPT进行辅助分析&#xff0c;大大提升了工作效率&#xff0c;很快就分析出木马的工作流程和构造出利用方式。 分析 首先对该木马进行格式化,以增强代码的可读性。得到如下代码 <jsp:root xmlns:jsp"http://java.sun.com/JSP/Page" vers…

mac下配置JDK环境

一、下载安装 下载地址&#xff1a;Java Downloads | Oracle&#xff0c;选择适用于Mac OS的JDK版本&#xff0c;点击下载即可。 下载完之后&#xff0c;直接安装&#xff1a; 安装过程非常简单&#xff0c;按“继续”按钮一直下一步即可。 二、配置环境变量 上一步骤&#x…

建筑安全运行监测,预防建筑潜在风险

建筑物是人们生活和工作的场所&#xff0c;其安全性直接关系到人们的生命财产安全。建筑安全运行监测旨在及时发现和识别潜在的安全隐患&#xff0c;以确保建筑物的稳定运行&#xff0c;其重要性不可低估。 建筑安全运行监测可以帮助及早发现结构问题。随着时间的推移&#xff…

【模方ModelFun】实景三维建模和修模4.0.7最新版安装包以及图文安装教程

模方ModelFun 具有多种功能&#xff0c;旨在帮助用户进行实景三维建模和修模。以下是一些主要功能的简要介绍&#xff1a; 实景三维建模&#xff1a;【模方ModelFun】提供了自动化的实景三维重建功能&#xff0c;可以从实景图像中提取几何形状和纹理信息&#xff0c;生成高质量…

51单片机DHT11温湿度控制系统仿真设计( proteus仿真+程序+原理图+报告+讲解视频)

51单片机DHT11温湿度控制系统仿真设计 1.主要功能&#xff1a;2.仿真3. 程序代码4. 原理图元器件清单5. 设计报告6. 设计资料内容清单&下载链接 51单片机DHT11温湿度控制系统仿真设计( proteus仿真程序原理图报告讲解视频&#xff09; 仿真图proteus8.9及以上 程序编译器&…

缓存案例-架构真题(二十二)

试题一 某大型电商平台建立一个B2B商店系统&#xff0c;并在全国建设了仓储中心。但是在运营过程中&#xff0c;发现很多跨仓储中心调货&#xff0c;延误运送。为此建立全国仓储系统&#xff0c;通过对订单的分析和挖掘&#xff0c;并通过大数据分析预测各类配置&#xff0c;降…

机器学习---预剪枝、后剪枝(REP、CCP、PEP、)

1. 为什么要进行剪枝 横轴表示在决策树创建过程中树的结点总数&#xff0c;纵轴表示决策树的预测精度。 实线显示的是决策树 在训练集上的精度&#xff0c;虚线显示的则是在⼀个独⽴的测试集上测量出来的精度。 随着树的增⻓&#xff0c;在 训练样集上的精度是单调上升的&…

VSCode 配置 C 语言编程环境

目录 一、下载 mingw64 二、配置环境变量 三、三个配置文件 四、格式化代码 1、安装插件 2、保存时自动格式化 3、左 { 不换行 上了两年大学&#xff0c;都还没花心思去搭建 C 语言编程环境&#xff0c;惭愧&#xff0c;惭愧。 一、下载 mingw64 mingw64 是著名的 C/C…

【AI理论学习】语言模型:掌握BERT和GPT模型

语言模型&#xff1a;掌握BERT和GPT模型 BERT模型BERT的基本原理BERT的整体架构BERT的输入BERT的输出 BERT的预训练掩码语言模型预测下一个句子 BERT的微调BERT的特征提取使用PyTorch实现BERT GPT模型GPT模型的整体架构GPT的模型结构GPT-2的Multi-Head与BERT的Multi-Head之间的…

【高性能计算】opencl语法及相关概念(五):图像的仿射变换缩放

目录 简介宿主机程序设备端函数缩放效果图 简介 要使用仿射变换完成图像等宽高比缩放&#xff0c;可以按照以下步骤进行操作&#xff1a; 定义仿射变换矩阵&#xff1a;首先&#xff0c;定义一个仿射变换矩阵&#xff0c;用于描述缩放操作。该矩阵是一个2x3的矩阵&#xff0c;…

实时操作系统Freertos开坑学习笔记:(三):任务的挂起与恢复、中断管理

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、任务挂起与恢复的API函数1.具体函数描述①vTaskSuspend()任务挂起&#xff08;暂停&#xff09;函数②vTaskResume()任务恢复函数③xTaskResumeFromISR()在…

PAT 1167 Cartesian Tree

个人学习记录&#xff0c;代码难免不尽人意。 A Cartesian tree is a binary tree constructed from a sequence of distinct numbers. The tree is heap-ordered, and an inorder traversal returns the original sequence. For example, given the sequence { 8, 15, 3, 4, 1…

机器学习-波士顿房价预测

目录 一.数据处理 读入数据 数据形状变换 数据集划分 数据归一化处理 将上面封装成load data函数 二. 模型设计 完整封装运行代码&#xff1a; 根据loss值进行梯度计算 控制部分变量的变化图像&#xff1a; 一.数据处理 读入数据 # 导入需要用到的package import numpy as np…