0基础学会在亚马逊云科技AWS上利用SageMaker、PEFT和LoRA高效微调AI大语言模型(含具体教程和代码)

news2025/1/12 15:51:12

项目简介:

小李哥今天将继续介绍亚马逊云科技AWS云计算平台上的前沿前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS上的AI软甲开发最佳实践,并应用到自己的日常工作里。本次介绍的是如何在Amazon SageMaker上微调(Fine-tune)大语言模型dolly-v2-3b,满足日常生活中不同的场景需求,并将介分享如何在SageMaker上优化模型性能并节省计算资源实现成本控制,最后将部署后的大语言模型URL集成到自己云上的软件应用中。

本方案包括通过Amazon Cloudfront和S3托管前端页面,并通过Amazon API Gateway和AWS Lambda将应用程序与AI模型集成,调用大模型实现推理。本方案的解决方案架构图如下:

利用微调模型创建的对话机器人前端UI

利用本方案小李哥用微调后的模型搭建了一个Q&A对话机器人助手,可以生成代码、文字总结、回答问题。

在开始分享案例之前,我们来了解一下本方案的技术背景,帮助大家更好的理解方案架构。

什么是Amazon SageMaker?

Amazon SageMaker 是一个完全托管的机器学习服务(大家可以理解为Serverless的Jupyter Notebook),专为应用开发和数据科学家设计,帮助他们快速构建、训练和部署机器学习模型。使用 SageMaker,您无需担心底层基础设施的管理,可以专注于模型的开发和优化。它提供了一整套工具和功能,包括数据准备、模型训练、超参数调优、模型部署和监控,简化了整个机器学习工作流程。

本方案将介绍以下内容:

1. 使用 SageMaker Jupyter Notebook进行dolly-v2-3b模型开发和微调

2. 在SageMaker部署微调后的大语言模型LLM并基于数据进行推理

3. 使用多场景的测试案例验证推理结果表现,并将部署的模型API节点集成进云端应用

项目搭建具体步骤:

下面跟着小李哥手把手微调一个亚马逊云科技AWS上的生成式AI模型(dolly-v2-3b)的软件应用,并将AI大模型部署与应用集成。

1. 在控制台进入Amazon SageMaker, 点击Notebook

2. 打开Jupyter Notebook

3. 创建一个新的Notebook:“lab-notebook.ipynb”并打开

4. 接下来我们在单元格内一步一步运行代码,检查CUDA的内存状态

!nvidia-smi

5.接下来,我们安装必要依赖并导入

%%capture

!pip3 install -r requirements.txt --quiet
!pip install sagemaker --quiet --upgrade --force-reinstall
%%capture

import os
import numpy as np
import pandas as pd
from typing import Any, Dict, List, Tuple, Union
from datasets import Dataset, load_dataset, disable_caching
disable_caching() ## disable huggingface cache

from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import TextDataset

import torch
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer
import accelerate
import bitsandbytes

from IPython.display import Markdown

6. 导入提前准备好的FAQs数据集

sagemaker_faqs_dataset = load_dataset("csv", 
                                      data_files='data/amazon_sagemaker_faqs.csv')['train']
sagemaker_faqs_dataset
sagemaker_faqs_dataset[0]

7. 我们定义用于模型推理的提示词格式

from utils.helpers import INTRO_BLURB, INSTRUCTION_KEY, RESPONSE_KEY, END_KEY, RESPONSE_KEY_NL, DEFAULT_SEED, PROMPT
'''
PROMPT = """{intro}
            {instruction_key}
            {instruction}
            {response_key}
            {response}
            {end_key}"""
'''
Markdown(PROMPT)

8. 下面我们进入重头戏,导入一个提前预训练好的LLM大语言模型“databricks/dolly-v2-3b”。

tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", 
                                          padding_side="left")

tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": 
                              [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})

model = AutoModelForCausalLM.from_pretrained(
    "databricks/dolly-v2-3b",
    # use_cache=False,
    device_map="auto", #"balanced",
    load_in_8bit=True,
)

9. 对模型训练进行预准备, 处理数据集、优化模型训练(PEFT)效率

model.resize_token_embeddings(len(tokenizer))

from functools import partial
from utils.helpers import mlu_preprocess_batch

MAX_LENGTH = 256
_preprocessing_function = partial(mlu_preprocess_batch, max_length=MAX_LENGTH, tokenizer=tokenizer)

encoded_sagemaker_faqs_dataset = sagemaker_faqs_dataset.map(
        _preprocessing_function,
        batched=True,
        remove_columns=["instruction", "response", "text"],
)

processed_dataset = encoded_sagemaker_faqs_dataset.filter(lambda rec: len(rec["input_ids"]) < MAX_LENGTH)

split_dataset = processed_dataset.train_test_split(test_size=14, seed=0)
split_dataset

10. 同时我们使用LoRA(Low-Rank Adaptation)模型加速我们的模型微调

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskType

MICRO_BATCH_SIZE = 8  
BATCH_SIZE = 64
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LORA_R = 256 # 512
LORA_ALPHA = 512 # 1024
LORA_DROPOUT = 0.05

# Define LoRA Config
lora_config = LoraConfig(
                 r=LORA_R,
                 lora_alpha=LORA_ALPHA,
                 lora_dropout=LORA_DROPOUT,
                 bias="none",
                 task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

from utils.helpers import MLUDataCollatorForCompletionOnlyLM

data_collator = MLUDataCollatorForCompletionOnlyLM(
        tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)

11. 接下来我们定义模型训练参数并开始训练。其中Batch=1,Step=20000,epoch为10.

EPOCHS = 10
LEARNING_RATE = 1e-4  
MODEL_SAVE_FOLDER_NAME = "dolly-3b-lora"

training_args = TrainingArguments(
                    output_dir=MODEL_SAVE_FOLDER_NAME,
                    fp16=True,
                    per_device_train_batch_size=1,
                    per_device_eval_batch_size=1,
                    learning_rate=LEARNING_RATE,
                    num_train_epochs=EPOCHS,
                    logging_strategy="steps",
                    logging_steps=100,
                    evaluation_strategy="steps",
                    eval_steps=100, 
                    save_strategy="steps",
                    save_steps=20000,
                    save_total_limit=10,
)

trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=split_dataset['train'],
        eval_dataset=split_dataset["test"],
        data_collator=data_collator,
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

12. 接下来我们将微调后的模型保存在本地

trainer.model.save_pretrained(MODEL_SAVE_FOLDER_NAME)

trainer.save_model()

trainer.model.config.save_pretrained(MODEL_SAVE_FOLDER_NAME)

tokenizer.save_pretrained(MODEL_SAVE_FOLDER_NAME)

13. 接下来,我们将保存到本地的模型进行部署,生成公开访问的API节点Endpoint

对部署所需要的参数进行定义和初始化

import boto3
import json
import sagemaker.djl_inference
from sagemaker.session import Session
from sagemaker import image_uris
from sagemaker import Model

sagemaker_session = Session()
print("sagemaker_session: ", sagemaker_session)

aws_role = sagemaker_session.get_caller_identity_arn()
print("aws_role: ", aws_role)

aws_region = boto3.Session().region_name
print("aws_region: ", aws_region)

image_uri = image_uris.retrieve(framework="djl-deepspeed",
                                version="0.22.1",
                                region=sagemaker_session._region_name)
print("image_uri: ", image_uri)

进行模型部署

model_data="s3://{}/lora_model.tar.gz".format(mybucket)

model = Model(image_uri=image_uri,
              model_data=model_data,
              predictor_cls=sagemaker.djl_inference.DJLPredictor,
              role=aws_role)

14.最后我们写入提示词,对大语言模型进行测试, 得到推理

outputs = predictor.predict({"inputs": "What solutions come pre-built with Amazon SageMaker JumpStart?"})

from IPython.display import Markdown
Markdown(outputs)

15. 我们下面进入SageMaker Endpoint页面,得到刚部署的模型API端点的URL,通过这种方式我们就可以在应用中调用我们的微调后的大语言模型了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1926819.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习概览

引言 深度学习的定义与背景 深度学习是机器学习的一个子领域&#xff0c;涉及使用多层神经网络分析和学习复杂的数据模式。深度学习的基础可以追溯到20世纪80年代&#xff0c;但真正的发展和广泛应用是在21世纪初。计算能力的提升和大数据的可用性使得深度学习在许多领域取得…

ARM功耗管理标准接口之SCMI

安全之安全(security)博客目录导读 思考&#xff1a;功耗管理有哪些标准接口&#xff1f;ACPI&PSCI&SCMI&#xff1f; Advanced Configuration and Power Interface Power State Coordination Interface System Control and Management Interface 下图示例说明了实现…

【算法笔记自学】第 10 章 提高篇(4)——图算法专题

10.1图的定义和相关术语 #include <cstdio>const int MAXN 100; int degree[MAXN] {0};int main() {int n, m, u, v;scanf("%d%d", &n, &m);for (int j 0; j < m; j) {scanf("%d%d", &u, &v);degree[u];degree[v];}for (int i…

怎么减少pdf的MB,怎么减少pdf的大小

在数字化时代&#xff0c;pdf文件因其格式稳定、跨平台兼容性强等特点而广受欢迎。然而&#xff0c;随着内容的丰富&#xff0c;pdf文件的大小也日益增大&#xff0c;给文件传输和存储带来了不少困扰。本文将为你介绍多种减小pdf文件大小的方法&#xff0c;帮助你轻松应对这一问…

中间件——Kafka

两个系统各自都有各自要去做的事&#xff0c;所以只能将消息放到一个中间平台&#xff08;中间件&#xff09; Kafka 分布式流媒体平台 程序发消息&#xff0c;程序接收消息 Producer&#xff1a;Producer即生产者&#xff0c;消息的产生者&#xff0c;是消息的入口。 Brok…

Hadoop3:HDFS-存储优化之纠删码

一、集群环境 集群一共5个节点&#xff0c;102/103/104/105/106 二、纠删码原理 1、简介 HDFS默认情况下&#xff0c;一个文件有3个副本&#xff0c;这样提高了数据的可靠性&#xff0c;但也带来了2倍的冗余开销。Hadoop3.x引入了纠删码&#xff0c;采用计算的方式&#x…

Python爬虫速成之路(3):下载图片

hello hello~ &#xff0c;这里是绝命Coding——老白~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f4a5;个人主页&#xff1a;绝命Coding-CSDN博客 &a…

[GXYCTF2019]BabySQli

原题目描述&#xff1a;刚学完sqli&#xff0c;我才知道万能口令这么危险&#xff0c;还好我进行了防护&#xff0c;还用md5哈希了密码&#xff01; 我看到是个黑盒先想着搞一份源码 我dirsearch明明扫到了.git&#xff0c;算了直接注入试试看 随便输入了两个东西&#xff0c…

pico+unity3d项目配置

重点&#xff1a;unity编辑器版本要和pico的sdk要求一致、比如&#xff1a; 对于 Unity 2022.1.14 及以上版本&#xff0c;若同时在项目中使用 URP、Linear 色彩空间、四倍抗锯齿和OpenGL&#xff0c;会出现崩溃。该问题待 Unity 引擎解决。对于 Unity 2022&#xff0c;若同时…

植物神经紊乱?别怕!这些维生素来帮你!

最近收到好多小伙伴的私信&#xff0c;说自己有植物神经紊乱的困扰&#xff0c;不知道该肿么办&#xff1f;别急别急&#xff0c;我给你们来一波大福利&#x1f381;&#xff01;今天就来聊聊植物神经紊乱患者应该补充哪些维生素&#xff0c;让你的身体重回最佳状态&#xff01…

SC.Pandas 03 | 如何使用Pandas分析时间序列数据?

Introduction 前两期我们对Pandas的数据结构和常用的计算方法进行了介绍&#xff0c;在地球科学领域时间序列分析是很重要的一种数据处理方式。 通过理解数据的时间变化特征&#xff0c;我们可以更深入地理解研究对象的演化规律与物理机制。 因此&#xff0c;本期我们就从Pa…

安装cnpm失败

20240714 安装cnpm失败&#xff0c;报错如下&#xff1a; 失败原因&#xff1a; 请求 https://registry.npm.taobao.org 失败&#xff0c;原因&#xff1a;证书已过期 使用以下命令&#xff1a; npm install -g cnpm --registryhttps://registry.npmmirror.com另外&#xff0…

1.33、激活可视化卷积神经网络(matalb)

1、激活可视化卷积神经网络原理及流程 激活可视化&#xff08;Activation Visualization&#xff09;指的是通过可视化神经网络中激活函数的输出&#xff0c;来理解神经网络是如何学习并提取特征的过程。在卷积神经网络&#xff08;CNN&#xff09;中&#xff0c;我们可以通过…

【ROS2】中级:RViz-构建自定义 RViz 显示

背景 在 RViz 中有许多类型的数据已经有现有的可视化。然而&#xff0c;如果有一种消息类型尚未有插件来显示它&#xff0c;那么有两种选择可以在 RViz 中查看它。 将消息转换为另一种类型&#xff0c;例如 visualization_msgs/Marker 。编写自定义 RViz 显示。 使用第一个选项…

成为CMake砖家(5): VSCode CMake Tools 插件基本使用

大家好&#xff0c;我是白鱼。 之前提到过&#xff0c;白鱼的主力 编辑器/IDE 是 VSCode&#xff0c; 也提到过使用 CMake Language Support 搭配 dotnet 执行 CMakeLists.txt 语法高亮。 对于阅读 CMakeLists.txt 脚本&#xff0c; 这足够了。 而在 C/C 开发过程中&#xff…

数据结构(单链表(2))

单链表的实现 SList.h 由于代码中已有大量注释&#xff0c;所以该文章主要起到补充说明作用。 &#xfeff;#pragma once #include<stdio.h> #include<stdlib.h> #include<assert.h>//定义链表&#xff08;结点&#xff09;的结构typedef int SLTDataType…

MySQL in 太多过慢的 3 种解决方案

文章目录 解决方案一&#xff1a;使用 JOIN 替代 IN示例&#xff1a; 解决方案二&#xff1a;分批处理 IN 子句示例&#xff1a; 解决方案三&#xff1a;使用临时表示例&#xff1a; 总结 &#x1f389;欢迎来到Java学习路线专栏~探索Java中的静态变量与实例变量 ☆* o(≧▽≦)…

力扣刷题之978.最长湍流子数组

题干要求&#xff1a; 给定一个整数数组 arr &#xff0c;返回 arr 的 最大湍流子数组的长度 。 如果比较符号在子数组中的每个相邻元素对之间翻转&#xff0c;则该子数组是 湍流子数组 。 更正式地来说&#xff0c;当 arr 的子数组 A[i], A[i1], ..., A[j] 满足仅满足下列条…

基于用户鼠标移动的规律可以对用户身份进行连续验证的方法

概述 论文地址&#xff1a;https://arxiv.org/abs/2403.03828 本文重点论述了高效可靠的用户身份验证方法在计算机安全领域的重要性。它研究了使用鼠标移动动态作为连续身份验证新方法的可能性。具体来说&#xff0c;本文分析了用户在两个不同游戏场景–团队要塞和聚能桥–中…

关于Kafka Topic分区和Replication分配的策略

文章目录 1. Topic多分区2. 理想的策略3. 实际的策略4. 如何自定义策略 1. Topic多分区 如图&#xff0c;是一个多分区Topic在Kafka集群中可能得分配情况。 P0-RL代表分区0&#xff0c;Leader副本。 这个Topic是3分区2副本的配置。分区尽量均匀分在不同的Broker上&#xff0c…