LLM - Model、Data、Training、Generate Agruments 超参解析

news2024/9/20 16:33:39

目录

一.引言

二.常用参数

◆ ModelArguments

◆ DataArguments

◆ TrainingArguments

◆ GeneratingArguments

三.代码实现

◆ Python 代码

◆ Shell 代码

四.总结


一.引言

LLM 相关训练框架都会引入 ModelArguments、DataArguments、TrainingArguments、GeneratingArguments 并通过 Transformer.HfArgumentParser 进行整合,实现了两行代码处理训练全程的参数问题。

ModelArguments - 模型参数

DataArguments - 数据集参数

TrainingArguments - 训练参数

GeneratingArguments - 生成参数

二.常用参数

◆ ModelArguments

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

ModelArguments 主要存储模型加载与配置的相关参数,一般还有以下参数,大家可以自定义:

参数名称默认类型含义
model_name_or_pathNonestr模型地址或名称
cache_dirNonestr缓存地址
use_fast_tokenizerFalsebool使用快速 tokenizer
padding_sideleftstr模型 pad 选择
quantization_bitNoneint量化 bit 选择
compute_typeNonetorch.dtype模型参数类型
checkpoint_dirNonestr微调参数地址
modeNonestrreward、lora
plot_lossFalsebool打印训练 Loss

◆ DataArguments

@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )

DataArguments 主要负责数据集相关参数,数据集通过 dataset 构成,通常包含下述参数:

参数名称默认类型含义
data_pathNonestr数据集地址
process_numNoneint并行处理
max_source_length512intsource 最大长度
max_target_length512inttarget 最大长度
max_samplesNoneint最大样本数
ignore_pad_tokenNoneintloss 计算是否忽略
prompt_templateNonestr样本生成 prompt 模板

◆ TrainingArguments

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)
    output_dir: str = field(default="")

TrainingArguments 主要存储模型微调,训练相关的参数:

参数名称默认类型含义
finetuning_typelorastr微调类型
lora_targetq_proj,v_projstr微调 Layer
lora_rank8intlora 降维维度
lora_alpha32.0floatlora 微调比例因子
lora_dropout0.1floatdropout 比例
num_hidden_layers32intDecode 数量
num_layer_trainable3intfreeze layer 数量
name_module_trainablemlpstrfreeze 训练层选择
output_dirNonestr模型输出地址

◆ GeneratingArguments

@dataclass
class GeneratingArguments:
    do_sample: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
    )

GeneratingArguments 主要负责 model generate 生成的配置:

参数名称默认类型含义
do_sampleTruebool采样或贪心
temperature0.95float调整下一个 token 的概率
top_p0.7floattoken 概率 top 区间
top_k50inttoken 词库数量
num_beams1intbeam search 数量
max_lengthNoneint最大生成 token 数
max_new_tokens512int最多新 toekn 生成数
repatition_penalty1.0float重复惩罚
length_penalty1.0float长度惩罚

之前单独整理了生成的参数和代码,可以参考: LLM - model batch generate 生成文本

三.代码实现

◆ Python 代码

from typing import Optional
from dataclasses import dataclass, field
import transformers


...

    添加上述的 Argument Class

...


if __name__ == '__main__':
    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments, GeneratingArguments))
    model_args, data_args, training_args, generate_args = parser.parse_args_into_dataclasses()

    print(model_args)
    print(data_args)
    print(training_args)
    print(generate_args)

两行搞定多类参数,参数对应属性使用 args.xxx 调用即可。

Shell 代码

#!/bin/bash

python GetConfigByArgs.py \
    --report_to "none" \
    --data_path "data/belle_chat_ramdon_10k.json" \
    --model_name_or_path "baichuan-inc/Baichuan2-7B-Base" \
    --output_dir "output" \
    --model_max_length 512 \
    --num_train_epochs 4 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --save_strategy epoch \
    --learning_rate 2e-5 \
    --lr_scheduler_type constant \
    --adam_beta1 0.9 \
    --adam_beta2 0.98 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --weight_decay 1e-4 \
    --warmup_ratio 0.0 \
    --logging_steps 1 \
    --gradient_checkpointing True \
    --deepspeed ds_config.json \
    --bf16 False \
    --tf32 False

通过 -- 传递我们需要的参数即可。

四.总结

这个没啥总结的了,就是觉得写法比较优雅,后面自己的脚本也可以借用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/987701.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL 学习笔记

😀😀😀创作不易,各位看官点赞收藏. 文章目录 MySQL 学习笔记1、DQL 查询语句1.1、基本查询1.2、函数查询1.2.1、单行函数1.2.2、聚合函数 1.3、复杂查询1.3.1、连接查询1.3.2、子查询 1.4、SQL 语句 执行顺序 2、DDL 定义语句2.1、…

F. Selling a Menagerie Codeforces Round 895 (Div. 3)

Problem - F - Codeforces 题目大意:有n个动物,每个动物i有一个害怕的动物a[i],现要卖掉所有动物,每个动物都有价值c[i],如果i在a[i]之前卖掉,就会获得2*c[i]的价值,如果在a[i]之后被卖掉就会获…

垃圾回收 - 分代垃圾回收

分代垃圾回收在对象中导入了“年龄”的概念,通过优先回收容易成为垃圾的对象,提高垃圾回收的效率。 1、新生代对象和老年代对象 分代垃圾回收中把对象分类成几代,针对不同的代使用不同的 GC 算法,我们把刚生成的对象称为新生代对…

UI自动化测试详解

前言 随着智能化信息基础设施的推进,软件开发的进程也不断加快。软件测试工作也逐渐由传统的手工测试向软件自动化测试跨越。 对于很多企业来说,做好软件自动化测试工作已经不仅仅是通过测试工具进行“点点点”,要想找出软件测试过程中的缺…

python 小案例72

import requestsdef fetch_data_from_api(url):response requests.get(url)if response.status_code 200:data response.json()return dataelse:print("Failed to fetch data from API")return None# 使用NASA的API获取每日天文图片 url "https://api.nasa.…

【Springcloud】Actuator服务监控

【Springcloud】Actuator服务监控 【一】基本介绍【二】如何使用【三】端点分类【四】整合Admin-Ui【五】客户端配置【六】集成Nacos【七】登录认证【八】实时日志【九】动态日志【十】自定义通知 【一】基本介绍 (1)什么是服务监控 监视当前系统应用状…

情侣头像微信小程序源码 朋友圈背景小程序源码 动态壁纸微信小程序源码

壁纸和情侣头像,朋友圈素材都可以做,带视频教程。 搭建也不难,纯前端无后台。直接开发者工具调试前端,绑定合法域名,流量主功能也是在前端替换。 无需服务器域名直接上手!!!

飞行动力学 - 第17节-part3-垂尾和推进系统对航向的影响 之 基础点摘要

飞行动力学 - 第17节-part3-垂尾和推进系统对航向的影响 之 基础点摘要 1. 尾翼的贡献2. 垂尾是航向静稳定性的最大来源3. 推进系统对航向的贡献3.1 螺旋桨3.2 喷气式 4. 参考资料 1. 尾翼的贡献 平尾对航向静稳定性的影响机理与机翼相同,由于尺寸小,通…

AI教程 | 用Midjourney制作AI模特和换装的保姆级教程

Hi! 大家好,我是专注于AI项目实战的赤辰。 昨天电商朋友过来交流,聊到他最近新开了一家淘宝店,在没有请任何员工的情况下,他一个人用AI工具完成了店铺取名,商品文案,店铺logo,主图设计&#xf…

ASO优化之阅读并回复应用的评论

回复评论对于与用户保持牢固的关系非常重要。如果时间有限,优先回复负面评论,可以向其他用户保证,我们正在积极解决应用的问题,从而提高转化率。 1、逻辑与沟通要清晰。 首先,无论他们的反馈是正面还是负面&#xff0…

【c++】如何有效地利用命名空间?

​ 🌱博客主页:青竹雾色间 😘博客制作不易欢迎各位👍点赞⭐收藏➕关注 ​✨人生如寄,多忧何为 ✨ 目录 前言什么是命名空间?命名空间的语法命名空间的使用避免命名冲突命名空间的嵌套总结 前言 当谈到C编…

51单片机-直流电机学习

简介 51单片机采用的是5V的直流电机 轴长:8mm 轴径:2mm 电压:1-6V 参考电流:0.35-0.4A 3V 转速:17000-18000 转每分钟 他的组成: 直流电机的结构应由 定子 和 转子 两大部分组成。 直流电机运行时静止…

【Spring】aop的底层原理

🎄欢迎来到边境矢梦的csdn博文🎄 🎄本文主要梳理 Spring 中的切面编程aop的底层原理和重点注意的地方 🎄 🌈我是边境矢梦,一个正在为秋招和算法竞赛做准备的学生🌈 🎆喜欢的朋友可以…

工作和生活中,如何用项目管理思维解决复杂的事情?

在工作和生活中,许多事情都可以采用项目思维方式来解决。当我们逐渐将工作和生活中的各种事务以项目的方式来处理和推进时,我们可能并没有意识到,实际上我们正在运用项目管理思维。 项目管理思维能帮助我们在面对繁杂事务时,理清…

DevOps到底是什么意思?

前言: 当我们谈到 DevOps 时,可能讨论的是:流程和管理,运维和自动化,架构和服务,以及文化和组织等等概念。那么,到底什么是"DevOps"呢? 那么,DevOps是什么呢? 有人说它是一种方法,也有人说它是一种工具,还有人说它是一种思想。更有甚者,说它是一种哲学…

【echarts】如何修改折线图X轴每个刻度的间隔宽度,让拥挤的空间变大,所有坐标点的文案可以显示得下,Echarts x轴文本内容太长的几种解决方案

Echarts 如何修改折线图X轴每个刻度的间隔宽度,让拥挤的空间变大,所有坐标点的文案可以显示得下,Echarts x轴文本内容太长的几种解决方案 有以下几种方案,堪称最全方案: 1、dataZoom进行坐标的比例缩放 通过调整dataZ…

生态第五篇-调度的多维空间技术

生态第五篇-调度的多维空间技术 文章目录 生态第五篇-调度的多维空间技术前言一、什么是多维空间?二、实现原理1.先看效果2.如何实现 预告 前言 调度已经结束更新了本不想再更新调度技术,因为生态的更新计划里面有这一条所以就写一篇把 一、什么是多维…

Java“牵手”ebay商品详情数据,ebay商品详情API接口,ebayAPI接口申请指南

天猫平台商品详情接口是开放平台提供的一种API接口,通过调用API接口,开发者可以获取天猫商品的标题、价格、库存、月销量、总销量、库存、详情描述、图片等详细信息 。 获取商品详情接口API是一种用于获取电商平台上商品详情数据的接口,通过…

了解Armv8.x和Armv9.x扩展

概述 Arm架构新增的功能以扩展的形式提供,这样Arm能够定期发布新功能,以响应合作伙伴的需求,而无需对主架构进行重大更改。 Arm 每年都会发布新的扩展。Cortex CPU 是该架构的 Arm 实现,其会根据发布时间使用相应的扩展。 本指…

扫描mapper包

文章目录 第一种-配置在resource目录下第二种- 直接配置java代码目录&#xff0c;在Maven中配置相关路径 第一种-配置在resource目录下 第二种- 直接配置java代码目录&#xff0c;在Maven中配置相关路径 不配置不会把mapper的xml文件编译到target文件中 <build><res…