Trl: llama2-7b-hf使用QLora 4bit量化后ds zero3加上flash atten v2单机多卡训练(笔记)

news2025/1/8 4:46:23

目录

一、环境

  1.1、环境安装

  1.2、安装flash atten

二、代码

  2.1、bash脚本 

  2.2、utils.py 注释与优化

  2.3、train.py 注释与优化

  2.4、模型/参数相关

    2.4.1、量化后的模型

      a) 量化后模型结构

      b) 量化后模型layers

    2.4.2、参数

      a) training args

      b) peft args

      c) model args

三、Trl

  3.1、SFTTrainer

  3.2、其他的代码

    3.2.1、datasets.map 使用 load_from_cache_file = False 方便调试​​​​​​​​​​​​​​

四、


  • 项目地址

peft/examples/sft at main · huggingface/peft · GitHub🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning. - peft/examples/sft at main · huggingface/pefticon-default.png?t=N7T8https://github.com/huggingface/peft/tree/main/examples/sft

  • 文档

https://huggingface.co/docs/peft/accelerate/deepspeedicon-default.png?t=N7T8https://huggingface.co/docs/peft/accelerate/deepspeed

一、环境

系统:ubuntu 
cuda版本:12.1
torch版本:2.2.0
python版本:3.10

conda 虚拟环境中 cuda版本
cuda:12.1  # 确保与"外界"cuda一致

  1.1、环境安装

pip install -r ...

    第一种

git+https://github.com/huggingface/transformers
git+https://github.com/huggingface/accelerate
git+https://github.com/huggingface/peft
git+https://github.com/huggingface/trl
git+https://github.com/huggingface/datatrove.git
unsloth[conda]@git+https://github.com/unslothai/unsloth.git
deepspeed
PyGithub
# flash-attn 单独安装
huggingface-hub
evaluate
datasets
bitsandbytes
einops
wandb
tensorboard
tiktoken
pandas
numpy
scipy
matplotlib
sentencepiece
nltk
xformers
hf_transfer

     第二种

absl-py==2.1.0
accelerate==0.30.0
aiohttp==3.9.4
aiosignal==1.3.1
annotated-types==0.6.0
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.2.0
bitsandbytes==0.43.1
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
contourpy==1.2.1
cryptography==42.0.5
cycler==0.12.1
datasets==2.18.0
datatrove==0.0.1
deepspeed==0.14.0
Deprecated==1.2.14
dill==0.3.8
docker-pycreds==0.4.0
docstring_parser==0.16
einops==0.7.0
evaluate==0.4.1
filelock==3.13.4
# flash-attn==2.5.7
# flash-attn 需要手动安装, 安装之前需要先保证:
# 第一 确保 linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致
# 第二 安装好 c++ g++ ninja
# 第三 参考官方命令: https://github.com/Dao-AILab/flash-attention
fonttools==4.51.0
frozenlist==1.4.1
fsspec==2024.2.0
gitdb==4.0.11
GitPython==3.1.43
grpcio==1.62.1
hf_transfer==0.1.6
hjson==3.1.0
huggingface-hub==0.22.2
humanize==4.9.0
idna==3.7
Jinja2==3.1.3
joblib==1.4.0
kiwisolver==1.4.5
loguru==0.7.2
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.8.4
mdurl==0.1.2
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.3
ninja==1.11.1.1
nltk==3.8.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
packaging==24.0
pandas==2.2.2
peft==0.10.1
pillow==10.3.0
pip==23.3.1
protobuf==3.20.3
psutil==5.9.8
py-cpuinfo==9.0.0
pyarrow==15.0.2
pyarrow-hotfix==0.6
pycparser==2.22
pydantic==2.7.0
pydantic_core==2.18.1
PyGithub==2.3.0
Pygments==2.17.2
PyJWT==2.8.0
PyNaCl==1.5.0
pynvml==11.5.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
pytz==2024.1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
responses==0.18.0
rich==13.7.1
safetensors==0.4.2
scipy==1.13.0
sentencepiece==0.2.0
sentry-sdk==1.45.0
setproctitle==1.3.3
setuptools==68.2.2
shtab==1.7.1
six==1.16.0
smmap==5.0.1
sympy==1.12
tensorboard==2.16.2
tensorboard-data-server==0.7.2
tiktoken==0.6.0
tokenizers==0.15.2
torch==2.2.2
tqdm==4.66.2
transformers==4.40.0
triton==2.2.0
trl==0.8.3
typing_extensions==4.11.0
tyro==0.8.3
tzdata==2024.1
unsloth==2024.4
urllib3==2.2.1
wandb==0.16.6
Werkzeug==3.0.2
wheel==0.43.0
wrapt==1.16.0
xformers==0.0.25.post1
xxhash==3.4.1
yarl==1.9.4

  1.2、安装flash atten

安装 flash atten 和 deepspeed 前,需要保证:

  • 第一 确保 linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致
  • 第二 安装好 c++ g++ ninja
  • 第三 参考官方命令: GitHub - Dao-AILab/flash-attention: Fast and memory-efficient exact attentionFast and memory-efficient exact attention. Contribute to Dao-AILab/flash-attention development by creating an account on GitHub.icon-default.png?t=N7T8https://github.com/Dao-AILab/flash-attention
1. 安装 c++ g++
sudo apt-get update
sudo apt-get install build-essential

2. 安装 Ninja
sudo apt-get install ninja-build

3. 安装flash atten
    参考上面官方命令:
    pip install packaging
    pip install flash-attn --no-build-isolation          ----- flash atten 编译过程需要一定的时间,需要等待

二、代码

peft/examples/sft at main · huggingface/peft · GitHub🤗 PEFT: State-of-the-art Parameter-Efficient Fine-Tuning. - peft/examples/sft at main · huggingface/pefticon-default.png?t=N7T8https://github.com/huggingface/peft/tree/main/examples/sft

  2.1、bash脚本 

PYTHONPATH=$PWD
export PYTHONPATH
echo "当前bash执行目录: $PWD, 已经将PYTHONPATH设置为: $PYTHONPATH"


# --resume_from_checkpoint dir   表示trainer从dir恢复ckpt
# 注释掉: 与wandb 不能共存
# 2>&1 | tee -a examples/sft/qlora_ds_zero3_log.out
accelerate launch --config_file "examples/sft/configs/deepspeed_config_z3_qlora.yaml"  examples/sft/train.py \
    --seed 100 \
    --model_name_or_path "/workspace/Llama-2-7b-chat-hf" \
    --dataset_name "smangrul/ultrachat-10k-chatml" \
    --chat_template_format "chatml" \
    --add_special_tokens False \
    --append_concat_token False \
    --splits "train,test" \
    --max_seq_len 2048 \
    --num_train_epochs 2 \
    --logging_steps 5 \
    --log_level "info" \
    --logging_strategy "steps" \
    --evaluation_strategy "epoch" \
    --save_strategy "steps" \
    --save_steps 100 \
    --save_total_limit 10 \
    --bf16 True \
    --packing True \
    --learning_rate 1e-4 \
    --lr_scheduler_type "cosine" \
    --weight_decay 1e-4 \
    --warmup_ratio 0.0 \
    --max_grad_norm 1.0 \
    --output_dir "/workspace/output/llama-sft-qlora-dsz3" \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 2 \
    --gradient_accumulation_steps 4 \
    --use_flash_attn True \
    --gradient_checkpointing True \
    --use_reentrant True \
    --dataset_text_field "content" \
    --use_peft_lora True \
    --lora_r 8 \
    --lora_alpha 16 \
    --lora_dropout 0.1 \
    --lora_target_modules "all-linear" \
    --use_4bit_quantization True \
    --use_nested_quant True \
    --bnb_4bit_compute_dtype "bfloat16" \
    --bnb_4bit_quant_storage_dtype "bfloat16" \
    --resume_from_checkpoint /workspace/output/llama-sft-qlora-dsz3/checkpoint-100 \
    2>&1 | tee -a examples/sft/qlora_ds_zero3_log.out

    # 上传至 hub 的参数
    # --push_to_hub \
    # --hub_private_repo True \
    # --hub_strategy "every_save" \

  2.2、utils.py 注释与优化

import os
from enum import Enum

import torch
from datasets import DatasetDict, load_dataset, load_from_disk
from datasets.builder import DatasetGenerationError
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

from peft import LoraConfig

# DEFAULT_CHATML_CHAT_TEMPLATE是一个用于格式化聊天消息的jinja2模板字符串
# jinja2是一种流行的Python模板引擎,它允许在模板中嵌入Python代码,使模板更加动态和可编程
# 在这个模板中,{% for message in messages %} 是一个jinja2的for循环语句,用于遍历messages列表中的每个消息
# {
   {'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
# 这一部分定义了每条消息的格式化方式,包括:
#   1. <|im_start|>: 一个特殊标记,表示消息角色(如user、system或assistant)的开始
#   2. message['role']: 当前消息的角色,如user、system或assistant
#   3. \n: 换行符,用于在角色和消息内容之间添加新行
#   4. message['content']: 当前消息的实际内容
#   5. <|im_end|>: 一个特殊标记,表示消息内容的结束
#   6. \n: 换行符,用于在每条消息之后添加新行
# {% if loop.last and add_generation_prompt %}{
   {'<|im_start|>assistant\n' }}{% endif %}
# 这一部分是一个jinja2的条件语句,当循环遍历到最后一条消息时,如果add_generation_prompt为True,
# 则会在最后一条消息后添加'<|im_start|>assistant\n'作为提示,表示需要模型生成助手的回复
# 这种模板格式化方式的目的是将原始的聊天记录转换为适合语言模型输入的格式,以便进行对话生成任务
DEFAULT_CHATML_CHAT_TEMPLATE = "{% for message in messages %}\n{
   {'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% if loop.last and add_generation_prompt %}{
   {'<|im_start|>assistant\n' }}{% endif %}{% endfor %}"


# DEFAULT_ZEPHYR_CHAT_TEMPLATE与DEFAULT_CHATML_CHAT_TEMPLATE类似,也是一个用于格式化聊天消息的jinja2模板
# 不同之处在于格式化方式和使用的特殊标记
# {% for message in messages %} 同样是一个用于遍历消息列表的for循环
# {% if message['role'] == 'user' %} 是一个条件语句,用于判断当前消息的角色是否为user
# 如果是user,则使用{
   { '<|user|>\n' + message['content'] + eos_token }}将消息格式化为:
#   1. <|user|>: 用户角色的特殊标记
#   2. \n: 换行符
#   3. message['content']: 消息内容
#   4. eos_token: 句尾标记,如</s>
# {% elif message['role'] == 'system' %} 是另一个条件分支,用于判断当前消息的角色是否为system
# 如果是system,则使用{
   { '<|system|>\n' + message['content'] + eos_token }}进行格式化
# {% elif message['role'] == 'assistant' %} 是第三个条件分支,用于判断当前消息的角色是否为assistant
# 如果是assistant,则使用{
   { '<|assistant|>\n'  + message['content'] + eos_token }}进行格式化
# {% if loop.last and add_generation_prompt %}\n{
   { '<|assistant|>' }}\n{% endif %}
# 这一部分与DEFAULT_CHATML_CHAT_TEMPLATE类似,当遍历到最后一条消息时,如果add_generation_prompt为True,
# 则会添加'<|assistant|>\n'作为提示,表示需要模型生成助手的回复
# 总的来说,这种格式化方式将原始聊天记录转换为适合语言模型输入的形式,但使用了不同的特殊标记
DEFAULT_ZEPHYR_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{
   { '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{
   { '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{
   { '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{
   { '<|assistant|>' }}\n{% endif %}\n{% endfor %}"

# ZephyrSpecialTokens是一个继承自str和Enum的枚举类
# 它定义了Zephyr聊天格式中使用的各种特殊标记,如用户标记、助手标记、系统标记等
# 枚举类的好处是可以将一组相关的常量组织在一起,并提供更好的可读性和类型安全性
# 每个特殊标记都被定义为一个类属性,其值为对应的字符串形式
# 例如,user = "<|user|>"表示用户标记的字符串形式为"<|user|>"
class ZephyrSpecialTokens(str, Enum):
    user = "<|user|>"
    assistant = "<|assistant|>"
    system = "<|system|>"
    eos_token = "</s>"      # 句尾标记,表示一个句子或序列的结束
    bos_token = "<s>"       # 句首标记,表示一个句子或序列的开始
    pad_token = "<pad>"     # 填充标记,用于将序列填充至指定长度

    # list方法是一个类方法,它返回一个列表,包含了该枚举类中所有特殊标记的字符串形式
    # 这个方法常用于初始化分词器(tokenizer)时,将这些特殊标记添加到词表中
    @classmethod
    def list(cls):
        return [c.value for c in cls]

# ChatmlSpecialTokens与ZephyrSpecialTokens类似,也是一个定义了Chatml聊天格式中使用的特殊标记的枚举类
# 不同之处在于具体的特殊标记字符串形式
# 例如,user标记在Chatml格式中为"<|im_start|>user",而在Zephyr格式中为"<|user|>"
class ChatmlSpecialTokens(str, Enum):
    user = "<|im_start|>user"
    assistant = "<|im_start|>assistant"
    system = "<|im_start|>system"
    eos_token = "<|im_end|>"
    bos_token = "<s>"
    pad_token = "<pad>"

    @classmethod
    def list(cls):
        return [c.value for c in cls]

# create_datasets函数用于创建训练和测试数据集
# 参数包括:
#   tokenizer: 用于对文本进行分词(tokenization)和编码(encoding)的分词器对象
#   data_args: 包含数据相关配置的参数对象,如数据集名称、切分方式等
#   training_args: 包含训练相关配置的参数对象
#   apply_chat_template (bool): 是否应用聊天模板对数据进行预处理,默认为False
def create_datasets(tokenizer, data_args, training_args, apply_chat_template=False):
    # preprocess是一个内部函数,用于对数据样本进行预处理
    # 它接受一个字典样本作为输入,其中"messages"键对应一个列表,列表中的每个元素都是一个对话(conversation)
    def preprocess(samples):
        batch = []     # 初始化一个空列表,用于存储预处理后的对话
        # 遍历样本中的每个对话
        for conversation in samples["messages"]:
            # 对每个对话应用tokenizer.apply_chat_template方法进行预处理
            # tokenize=False表示不执行分词操作,只进行格式化
            batch.append(tokenizer.apply_chat_template(conversation, tokenize=False))
        # 返回一个字典,其中"content"键对应预处理后的对话列表
        return {"content": batch}

    raw_datasets = DatasetDict()   # 初始化一个空的DatasetDict对象,用于存储数据集
    # 遍历data_args.splits指定的数据集切分(如train、test等)
    for split in data_args.splits.split(","):
        try:
            # Try first if dataset on a Hub repo, 首先尝试从Hugging Face Hub上加载指定的数据集
            dataset = load_dataset(data_args.dataset_name, split=split)
        except DatasetGenerationError:
            # If not, check local dataset, 如果从Hub上加载失败,则尝试从本地磁盘加载数据集
            dataset = load_from_disk(os.path.join(data_args.dataset_name, split))

        # 根据切分类型,将数据集存入raw_datasets的对应键值中
        if "train" in split:
            raw_datasets["train"] = dataset
        elif "test" in split:
            raw_datasets["test"] = dataset
        else:
            raise ValueError(f"Split type {split} not recognized as one of test or train.")

    # 如果apply_chat_template为True,则对数据集应用preprocess函数进行预处理
    if apply_chat_template:
        raw_datasets = raw_datasets.map(
            preprocess,
            batched=True,         # 表示对样本进行批处理,提高效率
            remove_columns=raw_datasets["train"].column_names,
        )

    train_data = raw_datasets["train"]  # 获取训练数据集
    valid_data = raw_datasets["test"]   # 获取测试数据集
    print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")  # 打印数据集大小
    print(f"A sample of train dataset: {train_data[0]}")  # 打印训练数据集的第一个样本

    return train_data, valid_data


# create_and_prepare_model函数用于创建和准备模型
# 参数包括:
#   args: 包含模型相关配置的参数对象,如模型名称、是否使用量化等
#   data_args: 包含数据相关配置的参数对象,如最大序列长度等
#   training_args: 包含训练相关配置的参数对象,如是否使用梯度检查点等
def create_and_prepare_model(args, data_args, training_args):
    if args.use_unsloth:
        # 如果使用Unsloth库(一种用于加速语言模型的库),则导入FastLanguageModel类
        from unsloth import FastLanguageModel
    bnb_config = None    # 初始化BitsAndBytesConfig为None,用于量化配置
    quant_storage_dtype = None   # 初始化量化存储数据类型为None

    # 检查是否为分布式训练且使用Unsloth库,如果是则抛出NotImplementedError
    # 因为当前版本的Unsloth不支持分布式训练
    if (
        torch.distributed.is_available()
        and torch.distributed.is_initialized()
        and torch.distributed.get_world_size() > 1
        and args.use_unsloth
    ):
        raise NotImplementedError("Unsloth is not supported in distributed training")

    # 如果使用4位量化,则设置计算数据类型和量化存储数据类型
    if args.use_4bit_quantization:
        # 获取指定的计算数据类型,如torch.float16或torch.bfloat16
        compute_dtype = getattr(torch, args.bnb_4bit_compute_dtype)
        # 获取指定的量化存储数据类型,如torch.float16或torch.float32
        quant_storage_dtype = getattr(torch, args.bnb_4bit_quant_storage_dtype)

        # 创建BitsAndBytesConfig对象,用于配置量化相关参数
        # BitsAndBytesConfig是一个用于管理量化配置的类,可以指定量化类型、计算数据类型、存储数据类型等
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=args.use_4bit_quantization,          # 是否使用4位量化
            bnb_4bit_quant_type=args.bnb_4bit_quant_type,     # 4位量化的类型, 如 nf4
            bnb_4bit_compute_dtype=compute_dtype,             # 计算数据类型
            bnb_4bit_use_double_quant=args.use_nested_quant,  # 是否使用双量化
            # TODO Qlora + zero3 修改的代码
            bnb_4bit_quant_storage=quant_storage_dtype,       # 量化存储数据类型
        )

        # 如果计算数据类型为float16且使用4位量化,则打印GPU是否支持bfloat16的提示
        if compute_dtype == torch.float16 and args.use_4bit_quantization:
            major, _ = torch.cuda.get_device_capability()
            if major >= 8:
                print("=" * 80)
                print("Your GPU supports bfloat16, you can accelerate training with the argument --bf16")
                print("=" * 80)
        # 如果使用8位量化,则创建相应的BitsAndBytesConfig对象
        elif args.use_8bit_quantization:
            bnb_config = BitsAndBytesConfig(load_in_8bit=args.use_8bit_quantization)

    # 如果使用Unsloth库
    if args.use_unsloth:
        # Load model, 使用FastLanguageModel.from_pretrained方法加载模型, 传入模型名称路径、最大序列长度、是否使用4位量化等参数
        model, _ = FastLanguageModel.from_pretrained(
            model_name=args.model_name_or_path,
            max_seq_length=data_args.max_seq_length,
            dtype=None,
            load_in_4bit=args.use_4bit_quantization,
        )
    else: # 如果不使用Unsloth库,则使用AutoModelForCausalLM.from_pretrained方法加载模型
        # TODO Qlora + zero3 修改的代码
        # 如果指定了quant_storage_dtype且是浮点数类型,则使用quant_storage_dtype, 否则使用默认的torch.float32
        torch_dtype = (
            quant_storage_dtype if quant_storage_dtype and quant_storage_dtype.is_floating_point else torch.float32
        )
        # 使用AutoModelForCausalLM.from_pretrained方法加载语言模型, 传入模型路径、量化配置、是否信任远程代码、注意力实现方式和数据类型等参数
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name_or_path,
            quantization_config=bnb_config,
            trust_remote_code=True,
            # 注意力实现方式,flash_attention_2或eager
            attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
            # TODO Qlora + zero3 修改的代码
                # 注意 torch_dtype 对于 AutoModelForCausalLM 与 bnb_4bit_quant_storage 数据类型相同。就是这样。其他所有事情都由 Trainer 和 TRL 处理。
            torch_dtype=torch_dtype,
        )

    peft_config = None      # 初始化PEFT配置为None
    chat_template = None    # 初始化聊天模板为None
    # 如果使用PEFT LoRA且不使用Unsloth库,则创建LoraConfig对象
    # PEFT (Parameter-Efficient Fine-Tuning)是一种模型微调技术,可以在保持大部分模型参数不变的情况下,只微调一小部分参数
    # LoRA (Low-Rank Adaptation)是PEFT的一种实现,通过添加低秩矩阵来适应新任务
    if args.use_peft_lora and not args.use_unsloth:
        peft_config = LoraConfig(
            lora_alpha=args.lora_alpha,         # LoRA的alpha参数,控制LoRA层的重要性
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            bias="none",                       # 是否对偏置项应用LoRA
            task_type="CAUSAL_LM",             # 任务类型,这里是因果语言模型
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
        )

    special_tokens = None   # 初始化特殊标记为None
    chat_template = None    # 初始化聊天模板为None
    # 根据args.chat_template_format参数,设置特殊标记和聊天模板
    if args.chat_template_format == "chatml":
        special_tokens = ChatmlSpecialTokens              # 使用Chatml格式的特殊标记
        chat_template = DEFAULT_CHATML_CHAT_TEMPLATE      # 使用Chatml聊天模板
    elif args.chat_template_format == "zephyr":
        special_tokens = ZephyrSpecialTokens            # 使用Zephyr格式的特殊标记
        chat_template = DEFAULT_ZEPHYR_CHAT_TEMPLATE    # 使用Zephyr聊天模板

    # 如果特殊标记不为None
    if special_tokens is not None:
        # 使用AutoTokenizer.from_pretrained方法加载分词器
        # 设置填充标记、句首标记、句尾标记和其他特殊标记
        tokenizer = AutoTokenizer.from_pretrained(
            args.model_name_or_path,
            pad_token=special_tokens.pad_token.value,     # 填充标记
            bos_token=special_tokens.bos_token.value,     # 句首标记
            eos_token=special_tokens.eos_token.value,     # 句尾标记
            additional_special_tokens=special_tokens.list(),  # 其他特殊标记
            trust_remote_code=True,
        )
        tokenizer.chat_template = chat_template           # 设置聊天模板
        # make embedding resizing configurable?
        # 调整tokenizer的嵌入大小,使其能够容纳新增的特殊标记
        # pad_to_multiple_of=8用于对齐,提高GPU计算效率
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    else:
        # 如果特殊标记为None,则直接加载分词器
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token     # 设置填充标记为句尾标记


    # 如果使用Unsloth库
    if args.use_unsloth:
        # Do model patching and add fast LoRA weights
        # 使用FastLanguageModel.get_peft_model方法对模型进行修补,并添加快速LoRA权重
        # 传入LoRA相关参数,如alpha、dropout、rank等,以及是否使用梯度检查点、随机种子和最大序列长度
        model = FastLanguageModel.get_peft_model(
            model,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            r=args.lora_r,
            target_modules=args.lora_target_modules.split(",")
            if args.lora_target_modules != "all-linear"
            else args.lora_target_modules,
            use_gradient_checkpointing=training_args.gradient_checkpointing,
            random_state=training_args.seed,
            max_seq_length=data_args.max_seq_length,
        )

    return model, peft_config, tokenizer       # 返回模型、PEFT配置和分词器

  2.3、train.py 注释与优化

import os
import sys
import torch
from dataclasses import dataclass, field
from typing import Optional

import torch.distributed
from transformers import HfArgumentParser, TrainingArguments, set_seed, Seq2SeqTrainingArguments
from trl import SFTTrainer    # SFTTrainer用于序列到序列(Sequence-to-Sequence)的语言模型微调训练
from utils import create_and_prepare_model, create_datasets  # 自定义的实用函数,用于创建和准备模型、数据集

# TODO 新增代码, wandb 与 bash 重定向 log.out 冲突, 关闭掉
os.environ["WANDB_DISABLED"] = "true" # 关闭 wandb

# Define and parse arguments. 定义ModelArguments数据类,用于指定模型相关参数
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """
    # 指定预训练语言模型的路径或在Hugging Face模型库中的标识符, 这允许您使用您选择的任何预训练模型,如GPT-2、GPT-3、BERT等
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    # 指定聊天数据的格式,有以下选项:
    # 1) chatml: 使用Anthropic的chatml格式,例如: <human>: 你好 \n<assistant>: 你好,很高兴与你交谈。
    # 2) zephyr: 使用Pretrained.AI的zephyr格式,例如: Human: 你好 \nAssistant: 你好,很高兴与你交谈。 
    # 3) none: 如果数据集已经格式化为聊天模板,则设置为none
    # 这个参数可以帮助您灵活地处理不同格式的聊天数据
    chat_template_format: Optional[str] = field(
        default="none",
        metadata={
            "help": "chatml|zephyr|none. Pass `none` if the dataset is already formatted with the chat template."
        },
    )
    lora_alpha: Optional[int] = field(default=16)    # lora_alpha控制LoRA层的重要性,典型值为16或32
    lora_dropout: Optional[float] = field(default=0.1)  # lora_dropout设置LoRA层的dropout率,用于防止过拟合
    # lora_r指定LoRA低秩矩阵的秩(rank),较低的秩可以进一步减少参数量,但可能会影响性能, 秩越低,模型越压缩,但可能会导致性能下降
    lora_r: Optional[int] = field(default=64)
    # lora_target_modules指定应用LoRA的模块列表
    # 默认值包括注意力层的线性投影(q_proj, k_proj, v_proj, o_proj)和前馈神经网络层(down_proj, up_proj, gate_proj)
    # 也可以设置为"all-linear"以应用LoRA到所有线性层
    # 通过选择性地应用LoRA,可以在性能和参数量之间进行权衡
    lora_target_modules: Optional[str] = field(
        default="q_proj,k_proj,v_proj,o_proj,down_proj,up_proj,gate_proj",
        metadata={"help": "comma separated list of target modules to apply LoRA layers to"},
    )
    # use_nested_quant指定是否启用嵌套量化(nested quantization), 嵌套量化可以将4位量化模型进一步量化为2位或更低,从而进一步减小模型大小和内存占用,但可能会影响精度
    # 即 双量化
    use_nested_quant: Optional[bool] = field(
        default=False,
        metadata={"help": "Activate nested quantization for 4bit base models"},
    )
    # bnb_4bit_compute_dtype指定4位量化模型的计算数据类型,例如float16或bfloat16, 使用较低的计算精度可以提高计算速度,但可能会影响模型精度
    bnb_4bit_compute_dtype: Optional[str] = field(
        default="float16",
        metadata={"help": "Compute dtype for 4bit base models"},
    )
    # bnb_4bit_quant_storage_dtype指定4位量化模型的量化存储数据类型,如uint8或float16或bf16, 使用较低的存储精度可以减小模型大小,但可能会影响模型精度
    # 您需要权衡模型大小和精度的平衡
    bnb_4bit_quant_storage_dtype: Optional[str] = field(
        default="uint8",
        metadata={"help": "Quantization storage dtype for 4bit base models"},
    )
    # bnb_4bit_quant_type指定4位量化类型,包括fp4(浮点4位量化)或nf4(整数4位量化)
    # 不同的量化类型会影响模型精度和计算效率的权衡
    # fp4可能会保留更多精度,但nf4可能会更快
    bnb_4bit_quant_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization type fp4 or nf4"},
    )
    # use_flash_attn指定是否启用Flash注意力(Flash attention)
    # Flash注意力是一种高效的注意力实现,可以通过内存优化和并行计算提高训练速度,但可能会增加一些开销
    # 这个参数可以帮助您在训练速度和内存占用之间进行权衡
    use_flash_attn: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables Flash attention for training."},
    )
    # use_peft_lora指定是否启用PEFT (Parameter-Efficient Fine-Tuning) LoRA
    use_peft_lora: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables PEFT LoRA for training."},
    )
    # use_8bit_quantization指定是否将模型加载为8位量化版本
    use_8bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 8bit."},
    )
    # use_4bit_quantization指定是否将模型加载为4位量化版本, 4位量化可以将模型大小减小到原始大小的1/4,从而进一步节省内存和加快计算,但可能会显著影响精度
    use_4bit_quantization: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables loading model in 4bit."},
    )
    # use_reentrant是梯度检查点(Gradient Checkpointing)的一个参数, 梯度检查点可以通过重新计算激活值来节省内存,但会增加一些计算开销
    # use_reentrant指定是否使用可重入(reentrant)的梯度检查点实现,可能会进一步节省内存, 这个参数可以帮助您在内存占用和计算开销之间进行权衡
    use_reentrant: Optional[bool] = field(
        default=False,
        metadata={"help": "Gradient Checkpointing param. Refer the related docs"},
    )
    # use_unsloth指定是否使用Unsloth库进行训练
    # Unsloth是一个优化库,可以通过内存优化和并行计算加速PEFT LoRA的训练过程
    # 这个参数可以帮助您进一步提高训练效率
    use_unsloth: Optional[bool] = field(
        default=False,
        metadata={"help": "Enables UnSloth for training."},
    )


# 定义DataTrainingArguments数据类,用于指定数据集和数据处理相关参数
@dataclass
class DataTrainingArguments:
    # 指定要使用的数据集名称或路径,默认为OpenAssistant Guanaco数据集
    # 您可以根据需要替换为其他数据集
    dataset_name: Optional[str] = field(
        default="timdettmers/openassistant-guanaco",
        metadata={"help": "The preference dataset to use."},
    )

    # packing指定是否使用数据集打包(packing)
    # 数据集打包可以将多个样本打包为一个更长的序列,从而提高训练效率, 这个参数可以帮助您在训练速度和内存占用之间进行权衡
    packing: Optional[bool] = field(
        default=False,
        metadata={"help": "Use packing dataset creating."},
    )

    # dataset_text_field指定数据集中作为input文本的字段名, 这个参数可以帮助您灵活地处理不同格式的数据集
    dataset_text_field: str = field(default="text", metadata={"help": "Dataset field to use as input text."})
    # max_seq_length指定输入序列的最大长度,超出部分将被截断, 这个参数可以帮助您在训练速度、内存占用和模型性能之间进行权衡
    max_seq_length: Optional[int] = field(default=512)

    # append_concat_token指定在打包数据集时,是否在每个样本的末尾追加一个连接标记(如<eos>),这个参数可以帮助您控制数据集的格式,从而影响模型的输出
    append_concat_token: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, appends `eos_token_id` at the end of each sample being packed."},
    )

    # add_special_tokens指定在打包数据集时,是否由分词器(tokenizer)添加特殊标记(如<bos>和<eos>), 这个参数可以帮助您控制数据集的格式,从而影响模型的输出
    add_special_tokens: Optional[bool] = field(
        default=False,
        metadata={"help": "If True, tokenizers adds special tokens to each sample being packed."},
    )
    # splits指定要从数据集中使用的数据分割,如train、test或val,多个分割用逗号分隔, 这个参数可以帮助您灵活地使用数据集的不同部分进行训练和评估
    splits: Optional[str] = field(
        default="train,test",
        metadata={"help": "Comma separate list of the splits to use from the dataset."},
    )

# TODO 新增代码, 打印模型的是否参与训练的参数名和数据类型
def print_model_allarguments_name_dtype(model):
    for n,v in model.named_parameters():
        if v.requires_grad:
            print(f"trainable model arguments: {n} - {v.dtype} - {v.shape}")
        else:
            print(f"not trainable model arguments: {n} - {v.dtype} - {v.shape}")


def main(model_args, data_args, training_args):
    # Set seed for reproducibility
    set_seed(training_args.seed) # 设置随机种子,以确保实验可重复性

    # model ,调用create_and_prepare_model函数,根据参数创建并准备模型、PEFT配置和分词器
    model, peft_config, tokenizer = create_and_prepare_model(model_args, data_args, training_args)

    # gradient ckpt
    # 配置是否使用模型缓存和梯度检查点, 模型缓存可以加速注意力计算,但会占用更多内存, 梯度检查点可以节省内存,但会增加一些计算开销
    # 如果使用了Unsloth,则不需要梯度检查点
    model.config.use_cache = not training_args.gradient_checkpointing
    training_args.gradient_checkpointing = training_args.gradient_checkpointing and not model_args.use_uns

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1591494.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安装一个在线VS Code 随时随地在线编辑代码 code server搭建教程

code-server是一款在线的 VS Code&#xff0c;只需将其部署到服务端&#xff0c;就可以在浏览器上使用 VS Code&#xff0c;本文将介绍 code-server 安装和使用方法。 首先我们需要准备一台Linux服务器&#xff0c;这里我推荐伍六七云&#xff1a;https://www.vps567.com/ 香港…

卫星图像10个开源数据集资源汇总

文章目录 1、UC Merced Land-Use 2、Indian Pines 3、KSC 4、Washington DC 5、BigEarthNet 6、水体卫星图像的图像 7、城市航拍图像分割数据集 8、游泳池和汽车卫星图像检测 9、人工月球景观数据集 10、马萨诸塞州道路数据集 1、UC Merced Land-Use 数据集下载地址&am…

android11 如何修改状态栏的背景

修改status_bar.xml &#xff1a; <LinearLayout android:id"id/status_bar_contents"android:background"#1ABC9C"android:layout_width"match_parent"android:layout_height"match_parent"android:paddingStart"dimen/statu…

数字IC/FPGA——亚稳态及跨时钟域

什么是亚稳态亚稳态会造成什么平均故障间隔时间如何解决亚稳态同步时钟和异步时钟单bit电平信号如何跨时钟域单bit脉冲信号如何跨时钟域多bit信号如何跨时钟域 目录 一、亚稳态1.基本概念2.危害3.平均故障时间4.解决亚稳态的方法 二、跨时钟域1.同步电路和异步电路&#xff08;…

c语言例题,求数组中最大值,99乘法口诀表

例题1&#xff1a;求出数组中最大的值 根据题意&#xff0c;我们知道的是需要从一个数组中找到一个最大的元素并且输出。那首先我们先建立一个数组&#xff0c;然后将一些不有序的整型元素放到数组中&#xff0c;然后再建立一个变量来存放数组中的第一个元素&#xff0c;通过一…

使用avx2 指令集加速向量算法运算

使用cpu-z 查看cpu指令集 2 向量加&#xff0c;乘法&#xff0c;除法 我们使用向量加&#xff0c;为什么函数是0 到 8 的计算&#xff0c;因为avx2 寄存器为256位&#xff0c;同时设置启动增强指令集 #include <immintrin.h> // 引入包含AVX2指令集的头文件void vecto…

AI识别技术详解 --在windows环境中部署基于YOLO v8模型的目标检测

首先 YOLO是一个端到端的目标检测算法&#xff0c;一次前向传播计算&#xff0c;实现图像的多目标检测任务&#xff0c;我么可以在ultralytics官网上查看YOLO的各个版本&#xff08;v1-v8&#xff09;以及源码 使用YOLO v8提供的python接口&#xff0c;训练一个佩戴安全帽的目标…

vue 百度地图 使用 vue-baidu-map 进行当前位置定位和范围展示

vue 百度地图 使用 vue-baidu-map 进行当前位置定位和范围展示&#xff08;考勤打卡&#xff09; 一、创建百度地图账号&#xff0c;获取秘钥二、 引入插件1、安装vue-baidu-map2、在main.js中引入 三、 简单使用 最近写项目的时候&#xff0c;做到了考勤打卡的模块内容&#x…

Qt控件---容器类

文章目录 QGroupBox&#xff08;有标题的分组框&#xff09;QTabWidget&#xff08;带有标签页控件&#xff09; QGroupBox&#xff08;有标题的分组框&#xff09; 属性说明title分组框的标题alignment分组框内部内容的对齐方式flat是否为 扁平 模式checkable是否可选&#x…

僵尸进程和孤儿进程

目录 引言僵尸进程僵尸进程的状态僵尸进程周边知识 孤儿进程孤儿进程的状态 进程中的其他状态①.R---表示进程运行状态。②.S---表示进程的休眠状态。(进程什么都没做)③T 和 t 进程的运行、阻塞和挂起运行阻塞挂起状态&#xff1a; 引言 今天我们来将僵尸进程和孤儿进程以及其…

两数之和-第12届蓝桥杯选拔赛Python真题精选

[导读]&#xff1a;超平老师的Scratch蓝桥杯真题解读系列在推出之后&#xff0c;受到了广大老师和家长的好评&#xff0c;非常感谢各位的认可和厚爱。作为回馈&#xff0c;超平老师计划推出《Python蓝桥杯真题解析100讲》&#xff0c;这是解读系列的第51讲。 两数之和&#xf…

深入解析API技术:原理、实现与应用

在现代软件开发中&#xff0c;API&#xff08;应用程序接口&#xff09;扮演着至关重要的角色。API 允许不同的软件应用程序和系统之间进行通信和数据交换&#xff0c;从而构建出更加高效、灵活和可扩展的软件解决方案。本文将深入解析API技术的原理、实现方法&#xff0c;并附…

FANUC机器人通过ROBOGUIDE实现与实际的机器人进行程序导入导出的具体方法示例

FANUC机器人通过ROBOGUIDE实现与实际的机器人进行程序导入导出的具体方法示例 如下图所示,在电脑的开始菜单中找到”Robot Neiborhood”,点击进入, 如下图所示,设置要连接的机器人名称和主机IP地址(要确保自己的电脑和机器人IP地址在同一网段内),点击Add添加, 添加在线…

TCP 三次握手与四次挥手面试题(计算机网络)

TCP 基本认识 TCP 头格式有哪些&#xff1f; 序列号&#xff1a;在建立连接时由计算机生成的随机数作为其初始值&#xff0c;通过 SYN 包传给接收端主机&#xff0c;每发送一次数据&#xff0c;就「累加」一次该「数据字节数」的大小。用来解决网络包乱序问题。 确认应答号&a…

成都百洲文化传媒有限公司电商领域的新锐力量

在电商服务领域&#xff0c;成都百洲文化传媒有限公司凭借其专业的服务理念和创新的策略&#xff0c;正逐渐成为行业内的翘楚。这家公司不仅拥有资深的电商团队&#xff0c;还以其精准的市场定位和高效的服务模式&#xff0c;赢得了众多客户的信赖和好评。 一、专业团队&#…

基于 FPGA 的 DE1-SoC 功率估算器

Introduction 功耗是当今许多技术都要考虑的重要因素。例如&#xff0c;手机生产商总是谈论他们在电源管理方面的改进&#xff0c;以及如何延长电池的使用寿命。功能与功耗之间的平衡是许多人都在研究的有趣课题。然而&#xff0c;当我们做实验时&#xff0c;我们很少会考虑我…

centos7上docker搭建vulhub靶场

1 vulhub靶场概述 VulHub是一个在线靶场平台&#xff0c;提供了丰富的漏洞环境供安全爱好者学习和实践。 该平台主要面向网络安全初学者和进阶者&#xff0c;通过模拟真实的漏洞环境&#xff0c;帮助用户深入了解漏洞的成因、利用方式以及防范措施。 此外&#xff0c;VulHub还…

信号完整性的常见术语概念(面试常用)

目录 术语 概念一览 1&#xff0e;信号完整性&#xff08;Signal Integrity&#xff09; 2&#xff0e;传输线&#xff08;Transmission Line&#xff09; 3&#xff0e;特性阻抗&#xff08;Characteristic Impedance&#xff09; 4&#xff0e;反射&#xff08;Reflecti…

Stable Video Diffusion(SV3D)安装和测试(windows10)

SVD 安装教程 1.安装miniconda https://docs.anaconda.com/free/miniconda/index.html 2.创建python环境 conda create --name sv3d python3.10conda activate sv3d3.安装triton2.0.0 下载地址&#xff1a;https://huggingface.co/r4ziel/xformers_pre_built/resolve/mai…

AndroidAutomotive模块介绍(一)整体介绍

前言 Android Automotive 是一个基本 Android 平台&#xff0c;可运行 IVI 系统中预安装的 Android 应用以及可选的第二方和第三方 Android 应用。 本系列文档将会系统的介绍 Android Automotive 的功能、架构、逻辑等。模块逻辑将从 应用api接口、系统服务、底层服务&#x…