【AWS系列】使用 Amazon SageMaker 微调和部署 ChatGLM 模型

news2024/10/6 10:40:06

前言

大语言模型是一种基于深度学习技术的人工智能模型,可以追溯到早期的语言模型和机器翻译系统。直到最近,随着深度学习技术的崛起,大型预训练语言模型才开始引起广泛的关注。

大语言模型使用大规模的文本数据集进行预训练,从而学习到丰富的语言知识和语境理解能力。通过预训练和微调的方式,大语言模型可以用于各种自然语言处理任务,例如文本生成、机器翻译、问答系统、对话系统等。它们在许多领域都展示出了令人印象深刻的性能,并成为推动人工智能技术发展的重要驱动力。

本篇文章主要介绍如何使用 Amazon SageMaker 进行 ChatGLM 模型部署和微调的示例。

这个示例主要包括:

  1. ChatGLM 总体介绍
  2. ChatGLM 微调介绍
  3. ChatGLM 环境设置
  4. ChatGLM 微调训练
  5. ChatGLM 部署测试

Amazon SageMaker 更多信息,可以点击下面链接进行了解:Amazon SageMaker

亚马逊云科技更多信息可以查看下方链接亚马逊云科技

一、ChatGLM 总体介绍

ChatGLM 模型是由清华大学开源的、支持中英双语问答的对话语言模型,并针对中文进行了优化。该模型基于 General Language Model(GLM)架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署。

ChatGLM 具备以下特点:

  • 充分的中英双语预训练:ChatGLM 在 1:1 比例的中英语料上训练了 1T 的 token 量,兼具双语能力。
  • 优化的模型架构和大小:修正了二维 RoPE 位置编码实现。6B(62 亿)的参数大小,也使得研究者和个人开发者自己微调和部署 ChatGLM 成为可能。
  • 较低的部署门槛:FP16 半精度下,ChatGLM 需要至少 13 GB 的显存进行推理,结合模型量化技术,这一需求可以进一步降低到 10GB(INT8) 和 6GB(INT4),使得 ChatGLM 可以部署在消费级显卡上。
  • 更长的序列长度:ChatGLM 序列长度达 2048,支持更长对话和应用。

二、ChatGLM 微调介绍

模型微调主要分为 Full Fine-Tune 和 PEFT (Performance-Efficient Fine-Tune),前者模型全部参数都会进行更新,训练时间较长,训练资源较大;而后者会冻结大部分参数、微调训练网络结构,常见的方式是 LoRA 和 P-Tuning v2。对于 ChatGLM 来说,选择 P-Tuning v2 进行模型微调,其网络结构如下:在Transformers 的所有层均增加 Prompt/Prefix。

三、ChatGLM 环境设置

备注:项目中的示例代码均保存于代码仓库,地址如下:代码仓库

1. 升级 Python SDK

pip install --upgrade boto3
pip install --upgrade sagemaker
pip install huggingface_hub

  1. 获取运行时资源

包括区域、角色、账号、S3 桶等

import boto3
import sagemaker
from sagemaker import get_execution_role


sess = sagemaker.Session()
role = get_execution_role()
sagemaker_default_bucket = sess.default_bucket()


account = sess.boto_session.client("sts").get_caller_identity()["Account"]
region = sess.boto_session.region_name

四、ChatGLM 微调训练

4.1准备微调

1.克隆代码

rm -rf ChatGLM-6B
git clone https://github.com/THUDM/ChatGLM-6B.git
cd ChatGLM-6B
git checkout 163f94e160f08751545e3722730f1832d73b92d1

2.下载数据集

此处采用示例的广告数据集。根据输入实现广告语的输出,格式如下:

{
 "content": "类型#上衣版型#宽松版型#显瘦图案#线条衣样式#衬衫衣袖型#泡泡袖衣款式#抽绳",
 "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}
# 下载 ADGEN 数据集
wget -O AdvertiseGen.tar.gz https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1


# 解压数据集
tar -xzvf AdvertiseGen.tar.gz

3.下载 ChatGLM 原始模型

from huggingface_hub import snapshot_download
from pathlib import Path




local_cache_path = Path("./model")
local_cache_path.mkdir(exist_ok=True)


model_name = "THUDM/chatglm-6b"


# Only download pytorch checkpoint files
allow_patterns = ["*.json", "*.pt", "*.bin", "*.model", "*.py"]


model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_cache_path,
    allow_patterns=allow_patterns,
)


# Get the model files path
import os
from glob import glob


local_model_path = None


paths = os.walk(r'./model')
for root, dirs, files in paths:
    for file in files:
        if file == 'config.json':
            # print(os.path.join(root, file))
            local_model_path = str(os.path.join(root, file))[0:-11]
            print(local_model_path)
if local_model_path == None:
    print("Model download may failed, please check prior step!")

4.拷贝模型和数据到 S3

chmod +x ./s5cmd
./s5cmd sync ${local_model_path} s3://${sagemaker_default_bucket}/llm/models/chatglm/original-6B/
./s5cmd sync ./AdvertiseGen/ s3://${sagemaker_default_bucket}/llm/datasets/chatglm/AdvertiseGen/


rm -rf model
rm -rf AdvertiseGen
rm -rf AdvertiseGen.tar.gz

4.2模型微调

模型的微调使用 P-Tuning v2,以实现成本和效果的平衡。模型微调更改的源代码较多,具体可以参考上述 git 仓库。

1.模型微调参数

模型微调设置的关键参数如下:

  1. 前缀词长度:128
  2. 学习率:2e-2,确保 loss 在训练过程中下降
  3. batch size:1
  4. gradient accumulation step:16
  5. 训练步长:50,步长仅设置为 50 步,已经可以看出比较明显的微调结果
import time
from sagemaker.huggingface import HuggingFace




PRE_SEQ_LEN=128
LR=2e-2
BATCH_SIZE=1
GRADIENT_ACCUMULATION_STEPS=16
TRAIN_STEPS=50


job_name = f'huggingface-chatglm-finetune-ptuning-{time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())}'


instance_type  = "ml.g4dn.2xlarge"
instance_count = 1


# 基础模型存放地址
model_name_or_path = 's3://{}/llm/models/chatglm/original-6B/'.format(sagemaker_default_bucket)


# 微调模型输出地址
output_dir         = '/opt/ml/model/adgen-chatglm-6b-ft'
model_s3_path      = 's3://{}/llm/models/chatglm/finetune-ptuning-adgen/'.format(sagemaker_default_bucket)


# 模型环境变量设置
environment = {
    'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:32',
    'TRAIN_DATASET'          : '/opt/ml/input/data/AdvertiseGen/train.json',
    'TEST_DATASET'           : '/opt/ml/input/data/AdvertiseGen/dev.json',
    'PROMPT_COLUMN'          : 'content',
    'RESPONSE_COLUMN'        : 'summary',
    'MODEL_NAME_OR_PATH'     : model_name_or_path,
    'OUTPUT_DIR'             : output_dir,
    'MODEL_OUTPUT_S3_PATH'   : model_s3_path,
    'TRAIN_STEPS'            : '50'
}


inputs = {
   'AdvertiseGen': f"s3://{sagemaker_default_bucket}/llm/datasets/chatglm/AdvertiseGen/"
}

2.开启模型微调

huggingface_estimator = HuggingFace(
    entry_point          = 'sm_ptune_train.py',
    source_dir           = './ChatGLM-6B/ptuning',
    instance_type        = instance_type,
    instance_count       = instance_count,
    base_job_name        = job_name,
    role                 = role,
    script_mode          = True,
    transformers_version = '4.26',
    pytorch_version      = '1.13',
    py_version           = 'py39',
    environment          = environment
)


huggingface_estimator.fit(inputs=inputs)

五、ChatGLM 部署测试

5.1模型部署

1. 准备 Dummy 模型

!touch dummy
!tar czvf model.tar.gz dummy
assets_dir = 's3://{0}/{1}/assets/'.format(sagemaker_default_bucket, 'chatglm')
model_data = 's3://{0}/{1}/assets/model.tar.gz'.format(sagemaker_default_bucket, 'chatglm')
!aws s3 cp model.tar.gz $assets_dir
!rm -f dummy model.tar.gz

2. 配置模型参数

from sagemaker.pytorch.model import PyTorchModel


model_name                  = None
entry_point                 = 'chatglm-inference-finetune.py'
framework_version           = '1.13.1'
py_version                  = 'py39'
base_model_name_or_path     = 's3://{}/llm/models/chatglm/original-6B/'.format(sagemaker_default_bucket)
finetune_model_name_or_path = 's3://{}/llm/models/chatglm/finetune-ptuning-adgen/adgen-chatglm-6b-ft/checkpoint-50/pytorch_model.bin'.format(sagemaker_default_bucket)


# 模型环境变量设置
model_environment  = {
    'SAGEMAKER_MODEL_SERVER_TIMEOUT': '600',
    'SAGEMAKER_MODEL_SERVER_WORKERS': '1',
    'MODEL_NAME_OR_PATH'            : base_model_name_or_path,
    'PRE_SEQ_LEN'                   : '128',
    'FINETUNE_MODEL_NAME_OR_PATH'   : finetune_model_name_or_path,
}


model = PyTorchModel(
    name              = model_name,
    model_data        = model_data,
    entry_point       = entry_point,
    source_dir        = './code',
    role              = role,
    framework_version = framework_version, 
    py_version        = py_version,
    env               = model_environment
)

3. 部署微调模型

from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer


endpoint_name         = None
instance_type         = 'ml.g4dn.2xlarge'
instance_count        = 1


predictor = model.deploy(
    endpoint_name          = endpoint_name,
    instance_type          = instance_type, 
    initial_instance_count = instance_count,
    serializer             = JSONSerializer(),
    deserializer           = JSONDeserializer()
)

4.其中关键的模型加载

代码如下:加载原始的 ChatGLM 模型、同时加载 FineTune 的 PrefixEncoder 参数共同进行推理

import torch
import os


from transformers import AutoConfig, AutoModel, AutoTokenizer


# 载入Tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)


# 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数):
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)


model = model.quantize(4)
model.half().cuda()

5.2模型微调前后对比

1. 模型测试

 inputs = {
    "ask": "类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞"


}


response = predictor.predict(inputs)
print(response["answer"])

  1. 对比原始 ChatGLM 模型,

对于相同的输入,输出更偏广告词,而不是单纯的语义提取

2. 清除资源

predictor.delete_endpoint()

六、总结

大语言模型方兴未艾,正在以各种方式改变和影响着整个世界。客户拥抱大语言模型,亚马逊云科技团队同样在深耕客户需求和大语言模型技术,可以在未来更好地协助客户实现需求、提升业务价值。

如果对大模型感兴趣,可以访问下面链接了解更多大模型信息

亚马逊云科技

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1184182.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【VSCode】VSCode自定义代码编辑区背景色

// A code block { "editor.fontSize": 16, "editor.mouseWheelZoom": true, "editor.tabSize": 2, "workbench.colorCustomizations": { // 写在 Atom One Light 里面则只对该主题有效 "[Atom One Light]"…

GreenPlum简介

简介 Greenplum是一家总部位于**美国加利福尼亚州,为全球大型企业用户提供新型企业级数据仓库(EDW)、企业级数据云(EDC)和商务智能(BI)提供解决方案和咨询服务的公司,在全球已有:纳斯达克,纽约证券交易所,Skype. FOX&…

第四章:java关键字super

系列文章目录 文章目录 系列文章目录前言一、super关键字二、super 和 this 的比较总结 前言 super关键字可以用于对象访问父类成员。 一、super关键字 super 代表父类的引用, 用于访问父类的属性、 方法、 构造器。 super.属性名 //访问父类的属性,不…

2003-2022年高铁数据高铁开通时间数据

2003-2022年高铁数据高铁开通时间数据 1、时间:2003-2022年 2、指标:高铁站名称、开通时间、所在省份、所在城市、所属线路名称、以及相关备注 3、指标说明: Hsrwsnm[高铁站名称]-高铁站名称 Optm[开通时间]-高铁站开通的时间 Prvn[所在…

java传base64返回给数据报404踩坑

一、问题复现 1.可能因为base64字符太长,导致后端处理时出错,表现为前端请求报400错误; 这一步debug进去发现base64数据是正常传值的 所以排除掉不是后端问题,但是看了下前端请求,猜测可能是转换base64时间太长数据过大导致的404 2.前端传…

【C++】——基础编程

🎃个人专栏: 🐬 算法设计与分析:算法设计与分析_IT闫的博客-CSDN博客 🐳Java基础:Java基础_IT闫的博客-CSDN博客 🐋c语言:c语言_IT闫的博客-CSDN博客 🐟MySQL&#xff1a…

牵手世界顶尖科学家论坛,五粮液扩大国际高端平台布局

执笔 | 尼 奥 编辑 | 扬 灵 11月6日,以“科学引领变革 重塑世界韧性”为主题的第六届世界顶尖科学家论坛(以下简称“顶科论坛”)在上海召开。来自25个国家和地区,包括27位诺奖得主在内的100余位海外顶尖科学家、40余位中国两…

OpenCV(opencv_apps)在ROS中的视频图像的应用(重点讲解哈里斯角点的检测)

1、引言 通过opencv_apps,你可以在ROS中以最简单的方式运行OpenCV提供的许多功能,也就是说,运行一个与功能相对应的launch启动文件,就可以跳过为OpenCV的许多功能编写OpenCV应用程序代码,非常的方便。 对于想熟悉每个…

融云出海:从全球最多 MAU 的 10 款社交 App,看设计细节的重要性

近期,微信又悄悄进行了一次消息弹窗的更新,再次引发网友热议。在最新版本中,用户在聊天时,也能看到新消息的内容,让不少用户大呼方便。实际上,在过去几年,微信的每一次细小更新都会引发“用户到…

如何提高企业竞争力?CRM管理系统告诉你

随着竞争形势和商业环境的加剧,企业需要迅速适应不断变化的消费需求。不少企业使用CRM客户管理系统来优化业务流程,管理客户信息,实现更多的业绩增长。那么我们来说说,CRM系统如何提高企业竞争力? 强大的数据管理&…

一次性搞懂长轮询、短轮询、SSE、websocket区别

[[toc]] http的4种推送技术 客户端轮询:传统意义上的短轮询(Short Polling)服务器端轮询:长轮询(Long Polling)单向服务器推送:Server-Sent Events(SSE)全双工通信:WebSocket图中 每个箭头代表的是 http 连接 tcp的长连接和短连接 http keep-alive 是什么? 本质:…

打包 广告

小米广告 Type android.support.v4.app.INotificationSideChannel is defined multiple times d8clsPath: Error in D:\ChannelFolder\JJChannelPackageForTest\ToolConfigPath\channels-ad\ATemp-100057\xiaomi\lib\xiaomi_ad_merge_20231104.jar:android/support/v4/app/IN…

【中国知名企业高管团队】系列61:海尔Haier

今明两天,华研荟为您介绍中国的另外两个家电巨头,这两个巨头的发展历程都高度相似,都有赖于第一代创业者敏锐和坚持,而且同处一地。他们是海尔和海信,今天先介绍海尔。 一、认识海尔集团 根据海尔集团官网介绍&#…

innovus:解决报告复制时一行拆成两行的问题

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? innovus复制报告时一行的东西出现在两行上,解决只需要一条命令: set_table_style -no_frame_width

代码随想录算法训练营第四十七天 | LeetCode 198. 打家劫舍、213. 打家劫舍 II、337. 打家劫舍 III

代码随想录算法训练营第四十七天 | LeetCode 198. 打家劫舍、213. 打家劫舍 II、337. 打家劫舍 III 文章链接:打家劫舍 打家劫舍 II 打家劫舍 III 视频链接:打家劫舍 打家劫舍 II 打家劫舍 III 1. LeetCode 198. 打家劫舍 1.1 思路 我们要去偷钱&#…

python使用memory_profiler分析代码运行内存占用

memory_profile memory_profiler源码仓库 安装 pip install memory_profiler 使用 请参考以下文章,写的很详细 【精选】Python代码优化工具——memory_profiler_被Python玩的Kenny的博客-CSDN博客 本文要增加介绍的是API使用 目录结构 |--my.py |--tests | |-- test_m…

设计模式之保护性暂停

文章目录 1. 定义2. 实现保护性暂停模式3. Join原理4. 保护性暂停模式的扩展 1. 定义 即Guarded Suspension,用在一个线程等待另一个线程的执行结果。 有一个结果需要从一个线程传递给另一个线程,让他们关联到同一个GuarderObject(这就是保…

快速教程|如何在 AWS EC2上使用 Walrus 部署 GitLab

Walrus 是一款基于平台工程理念的开源应用管理平台,致力于解决应用交付领域的深切痛点。借助 Walrus 将云原生的能力和最佳实践扩展到非容器化环境,并支持任意应用形态统一编排部署,降低使用基础设施的复杂度,为研发和运维团队提供…

汽车生产RFID智能制造设计解决方案与思路

汽车行业需求 汽车行业正面临着快速变革,传统的汽车制造方式正在向柔性化、数字化、自动化和数据化的智能制造体系转变,在这个变革的背景下,汽车制造企业面临着物流、生产、配送和资产管理等方面的挑战,为了应对这些挑战&#xf…

为什么亚马逊的轻量应用服务器这么受欢迎 | 个人体验 | 优势所在

文章目录 🌺前言⭐什么是轻量应用服务器🛸特点 🎄亚马逊轻量应用服务器体验如何🌹亚马逊轻量应用服务器的优势 🌺前言 作为一为开发者,我们要开发部署一个自己的网站,要选择一个性能好的服务器…