Flag 验证器使用教程
Flag 验证器
是一种常用工具,用来验证命令行参数或配置文件中的标志(flag)是否符合预期规则。这些工具可以帮助开发者确保传入的参数满足一定的条件,避免因参数错误而导致程序运行失败。以下是对各个验证器功能的中文说明以及使用示例。
功能解释
1. register_validator
用于注册一个验证函数,该函数用来验证某个特定 flag 的值是否有效。
- 用法:
register_validator("learning_rate", lambda lr: lr > 0, message="学习率必须为正数。")
- 第一个参数是 flag 的名称,例如
"learning_rate"
。 - 第二个参数是一个验证函数,接收 flag 的值作为输入,返回
True
表示合法,抛出异常或返回False
表示非法。 message
参数是可选的,用于在验证失败时输出提示信息。
- 第一个参数是 flag 的名称,例如
2. validator
这是一个装饰器,用来定义并注册验证器函数。它和 register_validator
类似,但更简洁。
- 用法:
@validator def validate_positive_learning_rate(value): return value > 0 # 学习率必须为正数
3. register_multi_flags_validator
用于验证多个 flags 之间的关系。适用于当多个 flag 需要满足某种依赖关系或约束时。
- 用法:
register_multi_flags_validator( ["learning_rate", "batch_size"], lambda lr, bs: lr < 1 and bs > 0, message="学习率必须小于 1 且批量大小必须大于 0。" )
- 第一个参数是 flag 名称的列表。
- 第二个参数是验证函数,接收多个 flag 的值作为输入。
message
参数用于验证失败时的提示。
4. multi_flags_validator
这是 register_multi_flags_validator
的装饰器版本,用来简化验证器的定义。
- 用法:
@multi_flags_validator(["flag_a", "flag_b"]) def validate_flags(flag_a, flag_b): return flag_a != flag_b # 确保 flag_a 和 flag_b 的值不同
5. mark_flag_as_required
标记某个 flag 为必需。如果运行程序时未提供该 flag,则会报错。
- 用法:
mark_flag_as_required("model_path") # 模型路径是必需的
6. mark_flags_as_required
标记多个 flag 为必需。如果这些 flag 中的任意一个未提供,则会报错。
- 用法:
mark_flags_as_required(["input_path", "output_path"]) # 输入路径和输出路径都是必需的
7. mark_flags_as_mutual_exclusive
确保多个 flag 是互斥的,即只能设置其中一个。如果多个 flag 同时被设置,则会报错。
- 用法:
mark_flags_as_mutual_exclusive(["use_gpu", "use_tpu"]) # GPU 和 TPU 不能同时使用
8. mark_bool_flags_as_mutual_exclusive
这是 mark_flags_as_mutual_exclusive
的专门版本,用于布尔类型的 flag。确保多个布尔 flag 中最多只有一个为 True
。
- 用法:
mark_bool_flags_as_mutual_exclusive(["debug", "production"]) # debug 和 production 模式不能同时开启
这些工具如何协同使用
这些验证器通常用于框架(如 TensorFlow、PyTorch)或自定义的命令行工具中,用来确保传入的参数符合要求。以下是一个示例,展示如何结合使用这些验证器。
示例代码
以下代码展示了如何使用这些验证器来定义和验证命令行 flag。
from _validators import (
register_validator,
register_multi_flags_validator,
mark_flag_as_required,
mark_flags_as_mutual_exclusive,
mark_bool_flags_as_mutual_exclusive,
)
# 定义 flags
flags.DEFINE_float("learning_rate", 0.01, "优化器的学习率。")
flags.DEFINE_integer("batch_size", 32, "训练的批量大小。")
flags.DEFINE_boolean("use_gpu", False, "是否使用 GPU 进行训练。")
flags.DEFINE_boolean("use_tpu", False, "是否使用 TPU 进行训练。")
flags.DEFINE_string("output_dir", None, "保存训练结果的目录。")
# 注册验证器
# 确保学习率为正数
register_validator("learning_rate", lambda lr: lr > 0, message="学习率必须为正数!")
# 确保批量大小大于 0
register_validator("batch_size", lambda bs: bs > 0, message="批量大小必须大于 0!")
# 确保输出目录是必需的
mark_flag_as_required("output_dir")
# 确保 GPU 和 TPU 是互斥的
mark_bool_flags_as_mutual_exclusive(["use_gpu", "use_tpu"])
# 确保学习率和批量大小满足一定的关系
register_multi_flags_validator(
["learning_rate", "batch_size"],
lambda lr, bs: lr * bs < 1,
message="学习率和批量大小的乘积必须小于 1!"
)
运行结果
-
如果未提供
output_dir
:错误:output_dir 是必需的,请指定保存路径。
-
如果同时启用了
use_gpu
和use_tpu
:错误:use_gpu 和 use_tpu 是互斥的,请选择其中之一。
-
如果
learning_rate
为负数:错误:学习率必须为正数!
-
如果
learning_rate * batch_size >= 1
:错误:学习率和批量大小的乘积必须小于 1!
总结
通过以上的工具和方法,可以轻松实现以下功能:
- 验证单个 flag 的合法性,如检查参数范围。
- 验证多个 flag 的依赖关系,如互斥性或相关性。
- 确保必需的 flag 被提供,避免缺少关键参数导致程序失败。
因此在jaxpi的代码里:
import os
# Deterministic
# os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_reductions --xla_gpu_autotune_level=0"
os.environ["TF_CUDNN_DETERMINISTIC"] = "1" # DETERMINISTIC
from absl import app
from absl import flags
from absl import logging
from ml_collections import config_flags
import jax
jax.config.update("jax_default_matmul_precision", "highest")
import train
import eval
FLAGS = flags.FLAGS
flags.DEFINE_string("workdir", ".", "Directory to store model data.")
config_flags.DEFINE_config_file(
"config",
"./configs/default.py",
"File path to the training hyperparameter configuration.",
lock_config=True,
)
def main(argv):
if FLAGS.config.mode == "train":
train.train_and_evaluate(FLAGS.config, FLAGS.workdir)
elif FLAGS.config.mode == "eval":
eval.evaluate(FLAGS.config, FLAGS.workdir)
if __name__ == "__main__":
flags.mark_flags_as_required(["config", "workdir"])
app.run(main)
将 config 和 workdir 标记为必需的命令行参数。
如果运行程序时未提供这两个参数,会报错。
作用:
config:配置文件的路径,程序需要通过它加载配置。
workdir:工作目录,用于保存训练结果、模型检查点等。