使用 Amazon SageMaker 微调 Llama 2 模型

news2024/9/20 13:57:33

1c3aea7b4c3c494fd29dcbc241c8a7fd.gif

本篇文章主要介绍如何使用 Amazon SageMaker 进行 Llama 2 模型微调的示例。

这个示例主要包括:

  1. Llama 2 总体介绍

  2. Llama 2 微调介绍

  3. Llama 2 环境设置

  4. Llama 2 微调训练

前言

随着生成式 AI 的热度逐渐升高,国内外各种基座大语言竞相出炉,在其基础上衍生出种类繁多的应用场景。训练优异的基座大语言模型在通用性方面表现较好,但模型可能并未涉及到特定领域的专业术语、领域内的特定用语或上下文等。采用微调技术可以通过在领域特定数据上进行训练,使模型更好地适应目标领域的特殊语言模式和结构;结合基座模型的通用性和领域特定性,使得模型更具实际应用价值。

Llama 2 总体介绍

Llama 2 是 META 最新开源的 LLM,包括 7B、13B 和 70B 三个版本,训练数据集超过了 Llama 2 的 40%,达到 2 万亿 token;上下文长度也提升到 4K,可以极大扩展多轮对话的轮数、提示词输入数据;与此同时,Llama 2 Chat 模型使用基于人类反馈的强化学习(Reinforcement Learning from Human Feedback,RLHF),针对对话场景进行了大幅优化,达到了非常出色的有用性和安全性基准。HuggingFace 的 TGI 和 vLLM 等框架均有针对 Llama 2 的推理优化,进一步强化了 Llama 2 的可用性。

Llama 2 被认为是开源界大语言模型的首选,众多的垂类大模型均采用 Llama 2 作为基座大模型,在此基础上添加行业数据进行模型的预训练或者微调,适配更多的行业场景。

Llama 2 微调介绍

模型微调主要分为 Full Fine-Tune 和 PEFT (Performance-Efficient Fine-Tune),前者模型全部参数都会进行更新,训练时间较长,训练资源较大;而后者会冻结大部分参数、微调训练网络结构,常见的方式是 LoRA 和 P-Tuning v2。

PEFT 微调方式由于参数更新较少,可能导致模型无法学习到全部领域知识,对于特定任务或领域来说会出现推理不稳定的情况,因此大多数生产系统均使用全参数方式进行模型的微调。基于上述原因,本文会以全参数微调方式介绍 Llama 2 在 Amazon SageMaker 上的微调。

Llama 2 环境设置

备注:项目中的示例代码均保存于代码仓库,地址如下: 

https://github.com/aws-samples/llm-workshop-on-amazon-sagemaker

1. 升级 Python SDK 

pip install -U sagemaker

2. 获取运行时资源,包括区域、角色、账号、S3 桶等 

import boto3
import sagemaker
from sagemaker import get_execution_role




sess                     = sagemaker.Session()
role                     = get_execution_role()
sagemaker_default_bucket = sess.default_bucket()


account                  = sess.boto_session.client("sts").get_caller_identity()["Account"]
region                   = sess.boto_session.region_name

Llama 2 微调训练

微调准备

克隆代码

  • 采用 lm-sys 团队发布的 FastChat 平台进行 Llama 2 的微调,FastChat 也用于训练了知名的 Vicuna 模型,具有良好的代码规范和性能优化。

git clone https://github.com/lm-sys/FastChat.git
cd FastChat
git reset --hard 974537efbd82093b45e64d07904efe7728193a52

下载 Llama 2 原始模型

from huggingface_hub import snapshot_download
from pathlib import Path




local_cache_path = Path("./model")
local_cache_path.mkdir(exist_ok=True)


model_name = "TheBloke/Llama-2-13B-fp16"


# Only download pytorch checkpoint files
allow_patterns = ["*.json", "*.pt", "*.bin", "*.model", "*.py"]


model_download_path = snapshot_download(
    repo_id=model_name,
    cache_dir=local_cache_path,
    allow_patterns=allow_patterns,
    revision='b2e65e8ad4bb35e5abaee0170ebd5fc2134a50bb'
)


# Get the model files path
import os
from glob import glob


local_model_path = None


paths = os.walk(r'./model')
for root, dirs, files in paths:
    for file in files:
        if file == 'config.json':
            print(os.path.join(root,file))
            local_model_path = str(os.path.join(root,file))[0:-11]
            print(local_model_path)
if local_model_path == None:
    print("Model download may failed, please check prior step!")

拷贝模型和数据到 Amazon S3

chmod +x ./s5cmd
./s5cmd sync ${local_model_path} s3://${sagemaker_default_bucket}/llm/models/llama2/TheBloke/Llama-2-13B-fp16/ 
rm -rf model

模型微调

  • 模型的微调使用全参数模型,以实现微调后模型的稳定性。

  • 模型的微调使用开源框架 DeepSpeed 进行加速。

准备基础镜像

使用 Amazon SageMaker 定制的深度学习训练镜像作为基础镜像,再安装 Llama 2 训练所需的依赖包。Dockerfile 如下:

%%writefile Dockerfile
## You should change below region code to the region you used, here sample is use us-west-2
From 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04 




ENV LANG=C.UTF-8
ENV PYTHONUNBUFFERED=TRUE
ENV PYTHONDONTWRITEBYTECODE=TRUE


RUN pip3 uninstall -y deepspeed \
    && pip3 install deepspeed==0.10.0 \
    && pip3 install transformers==4.30.2


## Make all local GPUs visible
ENV NVIDIA_VISIBLE_DEVICES="all"

模型微调代码

模型微调源代码较多,细节可以参考上述 git 仓库。

微调参数

  • 为了节省显存,采用 DeepSpeed Stage-3

  • 训练过程开启 bf16,实现整数范围和精度的平衡

  • 训练数据集采用官方提供的 dummy_conversation.json,也就是典型的 {"instruction"、"input"、"output"} 的格式,同时可以支持多轮对话

DEEPSPEED_OPTS="""
    FastChat/fastchat/train/train_mem.py 
    --deepspeed ds.json 
    --model_name_or_path "/tmp/llama_pretrain/" 
    --data_path FastChat/data/dummy_conversation.json 
    --output_dir "/tmp/llama_out" 
    --num_train_epochs 1 
    --per_device_train_batch_size 1 
    --per_device_eval_batch_size  1 
    --gradient_accumulation_steps 4 
    --evaluation_strategy "no" 
    --save_strategy "no" 
    --save_steps 2000 
    --save_total_limit 1 
    --learning_rate 2e-5 
    --weight_decay 0. 
    --warmup_ratio 0.03 
    --lr_scheduler_type "cosine" 
    --logging_steps 1 
    --cache_dir '/tmp' 
    --model_max_length 2048 
    --gradient_checkpointing True 
    --lazy_preprocess True 
    --bf16 True 
    --tf32 True 
    --report_to "none"
"""

微调脚本

  • 微调使用 torchrun + DeepSpeed 进行分布式训练

%%writefile ./src/ds-train-dist.sh
#!/bin/bash
CURRENT_HOST="${SM_CURRENT_HOST}"




IFS=',' read -ra hosts_array <<< "${SM_HOSTS}"
NNODES=${#hosts_array[@]}
NODE_RANK=0


for i in "${!hosts_array[@]}"; do
    if [[ "${hosts_array[$i]}" == *${CURRENT_HOST}* ]]; then
        echo "host index:$i"
        NODE_RANK="$i" 
    fi
done
   
    
MASTER_PORT="13579"
export NCCL_SOCKET_IFNAME="eth0"


#Configure the distributed arguments for torch.distributed.launch.
GPUS_PER_NODE="$SM_NUM_GPUS"
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
                  --nnodes $NNODES \
                  --node_rank $NODE_RANK \
                  --master_addr $MASTER_ADDR \
                  --master_port $MASTER_PORT"


chmod +x ./s5cmd
./s5cmd sync s3://$MODEL_S3_BUCKET/llm/models/llama2/TheBloke/Llama-2-13B-fp16/* /tmp/llama_pretrain/


CMD="torchrun ${DISTRIBUTED_ARGS} ${DEEPSPEED_OPTS}"
echo ${CMD}
${CMD} 2>&1 


if [[ "${CURRENT_HOST}" == "${MASTER_ADDR}" ]]; then  
     ./s5cmd sync /tmp/llama_out s3://$MODEL_S3_BUCKET/llm/models/llama2/output/TheBloke/Llama-2-13B-fp16/$(date +%Y-%m-%d-%H-%M-%S)/
fi

启动微调

  • 全参数微调,需要使用至少一台 p4de.12xlarge(8 卡 A100 40GB)作为训练机器。

  • 当微调完成后,训练好的模型自动存储于指定的 S3 桶内,可用于后续的模型部署推理。

import time
from sagemaker.estimator import Estimator




environment = {
    'MODEL_S3_BUCKET': sagemaker_default_bucket # The bucket to store pretrained model and fine-tune model
}


base_job_name = 'llama2-13b-finetune'


instance_type = 'ml.p4d.24xlarge'


estimator = Estimator(role=role,
                      entry_point='ds-train-dist.sh',
                      source_dir='./src',
                      base_job_name=base_job_name,
                      instance_count=1,
                      instance_type=instance_type,
                      image_uri=image_uri,
                      environment=environment,
                      disable_profiler=True,
                      debugger_hook_config=False)




estimator.fit()

总结

大语言模型方兴未艾,正在以各种方式改变和影响着整个世界。客户拥抱大语言模型,亚马逊云科技团队同样在深耕客户需求和大语言模型技术,可以在未来更好地协助客户实现需求,提升业务价值。

本篇作者

6ce443c21a564a6109595741d2da8d7c.jpeg

高郁

亚马逊云科技解决方案架构师,主要负责企业客户上云,帮助客户进行云架构设计和技术咨询,专注于智能湖仓、AI/ML 等技术方向。

3c7b3572aea34ef2ee8fc45077824270.gif

星标不迷路,开发更极速!

关注后记得星标「亚马逊云开发者」

952455b80801542c984978e00d9d8e6e.gif

听说,点完下面4个按钮

就不会碰到bug了!

c330e0b3208f52226bf51d8e36d2c4f1.gif

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1537353.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

I2C芯片24C02/4/8/16(EEPROM)解读

一.原理图 24C01的硬件连接图如下&#xff1a; 二.24C0x系列芯片规格 三.24C0x芯片结构 下面简述EEPROM内部存储结构。 3.1 内部存储结构 根据24C02芯片的Datasheet描述&#xff0c;其内部存储结构应该如下图所示。 其它容量的EEPROM内部结构依此类推。 3.2 地址 3.2.1 器件…

BitMap介绍与应用

文章目录 BitMapBitMap介绍BitMap 结构RoaringBitmap 常见BitMapJava中的BitSetRedis中的BitMapClickHouse中的BitMap BitMap应用案例人群圈选 BitMap 场景一&#xff1a;(大部分开发面试都会遇到的一个问题&#xff09; 有10亿个用户id (int类型)&#xff0c;判断用户是否登…

Vue el-table 合并单元格

一般常见的就是下图这种的单列&#xff0c;上下重复进行合并。 有时候可能也会需要多行多列的合并。 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content&qu…

【LeetCode】--- 动态规划 集训(一)

目录 一、1137. 第 N 个泰波那契数1.1 题目解析1.2 状态转移方程1.3 解题代码 二、面试题 08.01. 三步问题2.1 题目解析2.2 状态转移方程2.3 解题代码 三、746. 使用最小花费爬楼梯3.1 题目解析3.2 状态转移方程3.3 解题代码 一、1137. 第 N 个泰波那契数 题目地址&#xff1a…

FloodFill算法——岛屿数量

文章目录 题目解析算法解析代码解析 题目解析 岛屿数量 题目依旧是熟悉的配方&#xff0c;熟悉的味道&#xff0c;还是那个0还是那个1还是那个二维矩阵&#xff0c;这时候BFS和DFS闻着味就来了&#xff0c;我们来看一下这个题目&#xff0c;这个题目也很容易理解如下图有一个…

阿里云2核4G服务器租用价格和性能测评

阿里云2核4G服务器租用优惠价格&#xff0c;轻量2核4G服务器165元一年、u1服务器2核4G5M带宽199元一年、云服务器e实例30元3个月&#xff0c;活动链接 aliyunfuwuqi.com/go/aliyun 活动链接如下图&#xff1a; 阿里云2核4G服务器优惠价格 轻量应用服务器2核2G4M带宽、60GB高效…

市场复盘总结 20240322

仅用于记录当天的市场情况&#xff0c;用于统计交易策略的适用情况&#xff0c;以便程序回测 短线核心&#xff1a;不参与任何级别的调整&#xff0c;采用龙空龙模式 一支股票 10%的时候可以操作&#xff0c; 90%的时间适合空仓等待 二进三&#xff1a; 进级率中 36% 最常用…

力扣题库27题移除元素(c语言)

解法&#xff1a; int removeElement(int* nums, int numsSize, int val) {int src0,dst0;while(src<numsSize){if(nums[src]val){src;}else{nums[dst]nums[src];src;dst;}}return dst; }

SCI一区 | Matlab实现PSO-TCN-BiGRU-Attention粒子群算法优化时间卷积双向门控循环单元融合注意力机制多变量时间序列预测

SCI一区 | Matlab实现PSO-TCN-BiGRU-Attention粒子群算法优化时间卷积双向门控循环单元融合注意力机制多变量时间序列预测 目录 SCI一区 | Matlab实现PSO-TCN-BiGRU-Attention粒子群算法优化时间卷积双向门控循环单元融合注意力机制多变量时间序列预测预测效果基本介绍模型描述…

visual studio卸载几种方法

1、控制面板卸载&#xff1b; 2、有时候会发现控制面板卸载会失败&#xff0c;无法卸载&#xff0c;这时候要先把下面目录的关于visual studio的都删除&#xff0c;然后重启电脑后&#xff0c;重新安装vs即可。

C语言预编译#pragma宏的作用

在嵌入式编程中&#xff0c;#pragma 指令具有非常重要的作用&#xff0c;因为它允许开发者在不同的编译器之间传达特定的编译指令。由于嵌入式编程通常与硬件紧密相关&#xff0c;且资源有限&#xff0c;这些指令可以帮助开发者更有效地利用可用资源&#xff0c;优化程序&#…

基于python+vue的stone音乐播放器的设计与实现flask-django-php-nodejs

随着我国经济的高速发展与人们生活水平的日益提高&#xff0c;人们对生活质量的追求也多种多样。尤其在人们生活节奏不断加快的当下&#xff0c;人们更趋向于足不出户解决生活上的问题&#xff0c;stone音乐播放器展现了其蓬勃生命力和广阔的前景。与此同时&#xff0c;为解决用…

docker快速安装达梦数据库

docker快速安装达梦数据库 文章目录 docker快速安装达梦数据库前言环境准备下载镜像运行、配置容器 前言 因为公司需要将自己的底代码平台与客户的需求做适配&#xff0c;客户要求必须满足信创要求&#xff0c;使用达梦数据库。所以需要将原有的MySQL数据库与达梦数据库适配&a…

每日五道java面试题之springboot篇(一)

目录&#xff1a; 第一题. 什么是 Spring Boot&#xff1f;第二题. Spring Boot 有哪些优点&#xff1f;第三题. Spring Boot 的核心注解是哪个&#xff1f;它主要由哪几个注解组成的&#xff1f;第四题. 什么是 JavaConfig&#xff1f;第五题. Spring Boot 自动配置原理是什么…

来了,工业5.0

什么是工业5.0 “工业5.0”一词是由欧盟委员会引入和推广的&#xff0c;用于描述其对欧洲工业的愿景。 工业5.0的强调的不仅是技术&#xff0c;更注重是人性。提倡“以人为本”的思想。工业 5.0 不是专注于创造经济价值&#xff0c;而是激励企业探索如何通过提供更健康的工作…

排序算法记录(冒泡+快排+归并)

文章目录 前言冒泡排序快速排序归并排序 前言 冒泡 快排 归并&#xff0c;这三种排序算法太过经典&#xff0c;但又很容易忘了。虽然一开始接触雀氏这些算法雀氏有些头大&#xff0c;但时间长了也还好。主要是回忆这些算法干了啥很耗时间。 如果在笔试时要写一个o(nlogn)的…

java学习——集合

目录 一、集合框架介绍 1、集合与集合框架说明 2、使用集合框架原因 3、集合框架接口体系 二、Collection接口 1、Collection常用方法 2、AbstractCollection 三、迭代器 1、迭代器说明 2、自定义Collection集合 四、泛型 1、泛型说明 2、使用泛型方法 3、泛型通配…

哲♂学家带你深♂入了♂解结构体及结构体内存大小问题

目录 概要 一、结构体的声明 二、结构体变量的创建和初始化 三、结构体的特殊声明 四、结构体内存对齐 1、对齐原则 2、例一 对齐数 计算方法 3、例二 总结 概要 结构体是我们日常编程中经常要用到的一种自定义类型&#xff0c;使用起来也是十分的方便。接下来就由…

ts js vue 验证文件 MD5 值 spark-md5

ts js vue 验证文件 MD5 值 spark-md5 如何在前端中验证要上传的文件的 md5 值 一、安装 spark-md5 插件 需要用到 spark-md5 这个插件 官方 github&#xff1a;https://github.com/satazor/js-spark-md5/tree/master yarn add spark-md5 // 或 npm i spark-md5使用的时候引…

TCP | TCP协议格式 | 三次握手

1.TCP协议 为什么需要 TCP 协议 &#xff1f;TCP 工作在哪一层&#xff1f; IP网络层是不可靠的&#xff0c;TCP工作在传输层&#xff0c;保证数据传输的可靠性。 TCP全称为 “传输控制协议&#xff08;Transmission Control Protocol”&#xff09;。 TCP 是面向连接的、可靠…