YAML 文件提供了一种清晰、简洁且易于理解的方式来描述配置信息,特别适用于机器学习模型的超参数调优和实验管理。
以 Latent Diffusion 官方代码仓库中的 https://github.com/CompVis/latent-diffusion/blob/main/configs/autoencoder/autoencoder_kl_32x32x4.yaml 为例(如下),该 YAML 配置文件,用于定义训练一个自编码器模型的设置,其中包含 3 个部分:
- model (AutoencoderKL的模型结构)
- data(DataModuleFromConfig中如何读入数据)
- lightning(设置回调函数和训练器)
model:
base_learning_rate: 4.5e-6
target: ldm.models.autoencoder.AutoencoderKL
params:
monitor: "val/rec_loss"
embed_dim: 4
lossconfig:
target: ldm.modules.losses.LPIPSWithDiscriminator
params:
disc_start: 50001
kl_weight: 0.000001
disc_weight: 0.5
ddconfig:
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
data:
target: main.DataModuleFromConfig
params:
batch_size: 12
wrap: True
train:
target: ldm.data.imagenet.ImageNetSRTrain
params:
size: 256
degradation: pil_nearest
validation:
target: ldm.data.imagenet.ImageNetSRValidation
params:
size: 256
degradation: pil_nearest
lightning:
callbacks:
image_logger:
target: main.ImageLogger
params:
batch_frequency: 1000
max_images: 8
increase_log_steps: True
trainer:
benchmark: True
accumulate_grad_batches: 2
Model
base_learning_rate: 4.5e-6: 这是基础学习率,用于优化器的初始化。学习率表示在每次参数更新时,参数被调整的程度。target: ldm.models.autoencoder.AutoencoderKL: 这是要训练的模型的类路径,即模型定义代码所在的位置。params: 这里是模型的参数设置。monitor: "val/rec_loss": 监控的指标,通常是验证集上的重构损失。embed_dim: 4: 嵌入维度,可能是自编码器中隐藏层的维度。lossconfig: 损失函数的配置。-
target: ldm.modules.losses.LPIPSWithDiscriminator: LPIPS损失所在位置。
-
params: 参数设置。disc_start: 50001: 鉴别器开始的步数。kl_weight: 0.000001: KL散度的权重。disc_weight: 0.5: 鉴别器权重。
-
ddconfig: 双向变换的配置。double_z: True: 是否使用双向Z变换。- 其他参数是有关双向变换网络结构的设置,包括通道数量、分辨率、残差块数量等。
Data
target: main.DataModuleFromConfig: 数据模块的类路径。params: 数据加载器的参数设置。batch_size: 12: 批量大小,即每次迭代训练时传递给模型的样本数量。wrap: True: 是否循环迭代数据。train: 训练数据的设置。target: ldm.data.imagenet.ImageNetSRTrain: 训练集加载器的类路径。params: 参数设置。size: 256: 数据的大小。degradation: pil_nearest: 图像降质方法。
validation: 验证集的设置。target: ldm.data.imagenet.ImageNetSRValidation: 验证数据加载器的类路径。params: 参数设置,与训练数据类似。
Lightning
callbacks: 回调函数的设置。image_logger: 图像记录器的设置。target: main.ImageLogger: 图像记录器的类路径。params: 参数设置。batch_frequency: 1000: 记录图像的频率。max_images: 8: 最大图像数量。increase_log_steps: True: 是否逐步增加日志步骤。
trainer: 训练器设置。benchmark: True: 是否启用性能测试。accumulate_grad_batches: 2: 梯度累积的步骤数量,用于处理较大的批次大小。



















