目录
一.引言
二.常用参数
◆ ModelArguments
◆ DataArguments
◆ TrainingArguments
◆ GeneratingArguments
三.代码实现
◆ Python 代码
◆ Shell 代码
四.总结
一.引言
LLM 相关训练框架都会引入 ModelArguments、DataArguments、TrainingArguments、GeneratingArguments 并通过 Transformer.HfArgumentParser 进行整合,实现了两行代码处理训练全程的参数问题。
DataArguments - 数据集参数
TrainingArguments - 训练参数
GeneratingArguments - 生成参数
二.常用参数
◆ ModelArguments
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")
ModelArguments 主要存储模型加载与配置的相关参数,一般还有以下参数,大家可以自定义:
参数名称 | 默认 | 类型 | 含义 |
model_name_or_path | None | str | 模型地址或名称 |
cache_dir | None | str | 缓存地址 |
use_fast_tokenizer | False | bool | 使用快速 tokenizer |
padding_side | left | str | 模型 pad 选择 |
quantization_bit | None | int | 量化 bit 选择 |
compute_type | None | torch.dtype | 模型参数类型 |
checkpoint_dir | None | str | 微调参数地址 |
mode | None | str | reward、lora |
plot_loss | False | bool | 打印训练 Loss |
◆ DataArguments
@dataclass
class DataArguments:
data_path: str = field(
default=None, metadata={"help": "Path to the training data."}
)
DataArguments 主要负责数据集相关参数,数据集通过 dataset 构成,通常包含下述参数:
参数名称 | 默认 | 类型 | 含义 |
data_path | None | str | 数据集地址 |
process_num | None | int | 并行处理 |
max_source_length | 512 | int | source 最大长度 |
max_target_length | 512 | int | target 最大长度 |
max_samples | None | int | 最大样本数 |
ignore_pad_token | None | int | loss 计算是否忽略 |
prompt_template | None | str | 样本生成 prompt 模板 |
◆ TrainingArguments
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
use_lora: bool = field(default=False)
output_dir: str = field(default="")
TrainingArguments 主要存储模型微调,训练相关的参数:
参数名称 | 默认 | 类型 | 含义 |
finetuning_type | lora | str | 微调类型 |
lora_target | q_proj,v_proj | str | 微调 Layer |
lora_rank | 8 | int | lora 降维维度 |
lora_alpha | 32.0 | float | lora 微调比例因子 |
lora_dropout | 0.1 | float | dropout 比例 |
num_hidden_layers | 32 | int | Decode 数量 |
num_layer_trainable | 3 | int | freeze layer 数量 |
name_module_trainable | mlp | str | freeze 训练层选择 |
output_dir | None | str | 模型输出地址 |
◆ GeneratingArguments
@dataclass
class GeneratingArguments:
do_sample: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
)
GeneratingArguments 主要负责 model generate 生成的配置:
参数名称 | 默认 | 类型 | 含义 |
do_sample | True | bool | 采样或贪心 |
temperature | 0.95 | float | 调整下一个 token 的概率 |
top_p | 0.7 | float | token 概率 top 区间 |
top_k | 50 | int | token 词库数量 |
num_beams | 1 | int | beam search 数量 |
max_length | None | int | 最大生成 token 数 |
max_new_tokens | 512 | int | 最多新 toekn 生成数 |
repatition_penalty | 1.0 | float | 重复惩罚 |
length_penalty | 1.0 | float | 长度惩罚 |
之前单独整理了生成的参数和代码,可以参考: LLM - model batch generate 生成文本
三.代码实现
◆ Python 代码
from typing import Optional
from dataclasses import dataclass, field
import transformers
...
添加上述的 Argument Class
...
if __name__ == '__main__':
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GeneratingArguments))
model_args, data_args, training_args, generate_args = parser.parse_args_into_dataclasses()
print(model_args)
print(data_args)
print(training_args)
print(generate_args)
两行搞定多类参数,参数对应属性使用 args.xxx 调用即可。
◆ Shell 代码
#!/bin/bash
python GetConfigByArgs.py \
--report_to "none" \
--data_path "data/belle_chat_ramdon_10k.json" \
--model_name_or_path "baichuan-inc/Baichuan2-7B-Base" \
--output_dir "output" \
--model_max_length 512 \
--num_train_epochs 4 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 1 \
--save_strategy epoch \
--learning_rate 2e-5 \
--lr_scheduler_type constant \
--adam_beta1 0.9 \
--adam_beta2 0.98 \
--adam_epsilon 1e-8 \
--max_grad_norm 1.0 \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--logging_steps 1 \
--gradient_checkpointing True \
--deepspeed ds_config.json \
--bf16 False \
--tf32 False
通过 -- 传递我们需要的参数即可。
四.总结
这个没啥总结的了,就是觉得写法比较优雅,后面自己的脚本也可以借用。