Muon优化器在小规模语言模型训练中表现出色,但在大规模模型训练中的可扩展性尚未得到证实。月之暗面通过系统分析和改进,成功将 Muon 应用于 3B/16B 参数的 MoE 模型训练,累计训练 5.7 万亿 token。结果表明,Muon 可以替代 AdamW 成为大规模 LLM 训练的标准优化器,在训练效率和模型性能方面具有显著优势。通过开源实现、Moonlight 模型和训练中间检查点,本文旨在推动可扩展优化技术研究,加速 LLM 训练方法发展。
🔗 资源链接
- 代码 & 实现
https://github.com/MoonshotAI/Moonlight - 全系列模型(预训练/指令微调/中间检查点)
https://huggingface.co/moonshotai - 技术报告
https://github.com/MoonshotAI/Moonlight/blob/master/Moonlight.pdf
🛠 核心改进
1. 权重衰减机制
通过在 Muon 中引入标准 AdamW 式权重衰减,有效解决了模型参数和层输出 RMS 过大的问题。
参数更新公式
使用如下数学公式描述 Muon 优化器的权重更新规则:
W t = W t − 1 − η t ( O t + λ W t − 1 ) W_{t} = W_{t-1} - \eta_{t} \left( O_{t} + \lambda W_{t-1} \right) Wt=Wt−1−ηt(Ot+λWt−1)
公式解析:
- W t W_t Wt:第 t 步的权重矩阵
- η t \eta_t ηt:动态学习率
- O t O_t Ot:当前梯度估计量
- λ \lambda λ:权重衰减系数
- ( ⋅ ) \left( \cdot \right) (⋅):自适应缩放运算符
该公式实现了:
- 权重衰减项 λ W t − 1 \lambda W_{t-1} λWt−1 的显式分离
- 梯度方向与正则化项的联合优化
- 通过动态学习率
η
t
\eta_t
ηt 实现训练稳定性控制
2. 参数更新尺度调整:一致的更新 RMS
改进参数更新规则,确保不同形状矩阵间的更新 RMS 一致性,显著提升训练稳定性。,提出基于矩阵维度特性的调整策略:
缩放规则
对每个矩阵按其最大维度进行归一化处理:
max
(
A
,
B
)
\sqrt{\max(A,B)}
max(A,B)
更新公式
调整后的参数更新规则实现为:
W
t
=
W
t
−
1
−
η
t
(
0.2
⋅
O
t
⋅
max
(
A
,
B
)
+
λ
W
t
−
1
)
W_{t} = W_{t-1} - \eta_{t} \left( 0.2 \cdot O_{t} \cdot \sqrt{\max(A,B)} + \lambda W_{t-1} \right)
Wt=Wt−1−ηt(0.2⋅Ot⋅max(A,B)+λWt−1)
关键改进:
- 维度感知缩放:通过 max ( A , B ) \sqrt{\max(A,B)} max(A,B) 补偿不同形状矩阵的梯度尺度差异
- 经验系数:Muon 的更新均方根误差控制在 0.2
3. 分布式实现优化
开发基于 ZeRO-1 风格的分布式版本,实现:
- 内存占用优化
- 通信效率提升(梯度同步频率降低 50%)
🧪 实验设计
模型架构
采用类 Deepseek-V3-Small 架构,并针对 Moonlight 模型需求进行微调。
数据集
使用 Kimi 团队提供的 5.7 万亿 token 数据集进行预训练。
训练流程
分阶段优化策略:
- 渐进式提升学习率
- 动态调整批量大小
- 多阶段数据质量优化
📊 实验结果
一致性更新 RMS
调整学习率方法(Adjusted LR)表现最优,显著优于:
- 基线方法(Baseline)
- 仅保持与 AdamW 一致 RMS 的方法(Update Norm)
扩展性验证
Muon 在计算最优设置下仅需 52% 训练 FLOPs 即可达到 AdamW 同等性能。
预训练性能
Moonlight 模型在 1.2T token 时的表现显著优于 AdamW 训练的 Moonlight-A 模型。
微调表现
- 优势场景:Muon 预训练+微调模型全面优于 AdamW 预训练+微调模型
- 局限场景:当微调阶段切换优化器时,Muon 优势减弱
关键问题及回答
问题 1:Muon 优化器在大规模模型训练中引入权重衰减的具体作用是什么?
权重衰减在 Muon 优化器中的作用主要是防止模型权重过大,从而避免模型在训练过程中出现梯度爆炸或梯度消失的问题。具体来说,权重衰减通过在更新规则中加入一个正则化项来限制权重的增长。论文中的权重衰减公式如下:
[W_{t}=W_{t - 1}-\eta_{t}\left(O_{t}+\lambda W_{t - 1}\right)]
其中,(\lambda)是权重衰减比率。通过引入权重衰减,Muon 能够在训练过程中更好地控制权重的增长速度,确保模型在训练后期不会出现过大的权重值,从而提高模型的稳定性和泛化能力。实验结果表明,加入权重衰减后,Muon 在大规模模型训练中的性能显著提升。
问题 2:分布式 Muon 实现是如何提高内存效率和减少通信开销的?
分布式 Muon 实现通过以下方式提高内存效率和减少通信开销:
- Reduce-Scatter 操作:在数据并行组上进行梯度聚合,减少了全局通信的需求。
- 局部分区动量:使用局部分区动量进行动量应用,避免了全局动量的传输。
- Newton-Schulz 迭代:在本地计算 Newton-Schulz 迭代,只对需要的部分进行通信。
- DP Gather 操作:将局部分区更新矩阵聚合成全矩阵,这一步骤只在必要时进行一次。
- 减少冗余计算:在计算全更新矩阵后,丢弃不需要的部分,只保留局部更新矩阵进行下一步计算。
通过这些优化措施,分布式 Muon 在保持算法数学性质的同时,显著减少了内存使用和通信开销,提高了大规模模型训练的效率。
问题 3:Moonlight 模型在监督微调阶段的表现如何,与仅使用 AdamW 预训练和微调的模型相比有何优势?
Moonlight 模型在监督微调阶段表现出色,具体优势如下:
- 更高的性能:Muon 预训练和微调的模型在多个基准测试中均表现出比仅使用 AdamW 预训练和微调的模型更高的性能。例如,在 MMLU 和 GSM8k 基准上,Moonlight 模型分别取得了 70.0 和 77.4 的分数,而 AdamW 微调的模型分别为 66.7 和 70.7。
- 一致性优化:Muon 在整个训练过程中保持了更好的优化稳定性,避免了在微调阶段出现的梯度爆炸或梯度消失现象。
- 泛化能力:Muon 预训练和微调的模型在未见数据上的表现也更好,显示出更强的泛化能力。
注意
当微调阶段使用与预训练阶段不同的优化器时,Muon 并未表现出显著优势。这表明,为了充分发挥 Muon 的优势,建议在预训练和微调阶段都使用相同的优化器。
代码
class Muon(torch.optim.Optimizer):
"""
Muon - MomentUm Orthogonalized by Newton-schulz
Muon内部运行标准的SGD动量优化,然后执行一个正交化后处理步骤,
在该步骤中,每个二维参数的更新将被替换为最近的正交矩阵。
为了高效地对每个更新进行正交化,我们使用牛顿 - 舒尔茨迭代,
其优点是可以在GPU上以bfloat16格式稳定运行。
一些警告:
- 我们认为这个优化器不太可能在小批量训练中表现良好。
- 我们认为它可能不太适合微调预训练模型,但我们还没有进行测试。
参数:
muon_params: 要由Muon优化的参数。
lr: 学习率。更新的谱范数将为`lr`。(0.02是一个不错的默认值)
momentum: 内部SGD使用的动量。(0.95是一个不错的默认值)
nesterov: 是否在内部SGD中使用Nesterov风格的动量。(推荐)
ns_steps: 要运行的牛顿 - 舒尔茨迭代的步数。(6步可能总是足够的)
adamw_params: 要由AdamW优化的参数。`muon_params`中任何为{0, 1}维的参数
或者被检测为嵌入层或lm_head的参数也将由AdamW进行优化。
adamw_lr: 内部AdamW的学习率。
adamw_betas: 内部AdamW的betas参数。
adamw_eps: 内部AdamW的epsilon参数。
adamw_wd: 内部AdamW的权重衰减参数。
"""
def __init__(
self,
lr=1e-3,
wd=0.1,
muon_params=None,
momentum=0.95,
nesterov=True,
ns_steps=5,
adamw_params=None,
adamw_betas=(0.95, 0.95),
adamw_eps=1e-8,
):
# 定义默认参数字典
defaults = dict(
lr=lr,
wd=wd,
momentum=momentum,
nesterov=nesterov,
ns_steps=ns_steps,
adamw_betas=adamw_betas,
adamw_eps=adamw_eps,
)
# 将muon_params转换为列表
params = list(muon_params)
# 如果adamw_params不为None,将其转换为列表,否则设为空列表
adamw_params = list(adamw_params) if adamw_params is not None else []
# 将adamw_params的参数添加到params列表中
params.extend(adamw_params)
# 调用父类的初始化方法
super().__init__(params, defaults)
# 将参数分为使用Muon优化的和不使用Muon优化的两类
for p in muon_params:
# 对于muon_params中的每个参数,确保其维度为2
assert p.ndim == 2, p.ndim
# 标记该参数使用Muon进行优化
self.state[p]["use_muon"] = True
for p in adamw_params:
# 对于adamw_params中的每个参数,标记其不使用Muon进行优化
self.state[p]["use_muon"] = False
def adjust_lr_for_muon(self, lr, param_shape):
# 获取参数矩阵的前两个维度
A, B = param_shape[:2]
# 我们根据参数矩阵的大小调整学习率和权重衰减,如论文中所述
adjusted_ratio = 0.2 * math.sqrt(max(A, B))
# 计算调整后的学习率
adjusted_lr = lr * adjusted_ratio
return adjusted_lr
def step(self, closure=None):
"""执行单个优化步骤。
参数:
closure (Callable, optional): 一个闭包,用于重新评估模型并返回损失。
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
############################
# Muon #
############################
# 筛选出使用Muon优化的参数
params = [p for p in group["params"] if self.state[p]["use_muon"]]
# 获取当前组的学习率
lr = group["lr"]
# 获取当前组的权重衰减
wd = group["wd"]
# 获取当前组的动量
momentum = group["momentum"]
# 以分布式方式生成权重更新
for p in params:
# 进行合理性检查
g = p.grad
if g is None:
continue
if g.ndim > 2:
g = g.view(g.size(0), -1)
assert g is not None
# 计算更新
state = self.state[p]
if "momentum_buffer" not in state:
# 如果动量缓冲区不存在,初始化为与梯度相同形状的零张量
state["momentum_buffer"] = torch.zeros_like(g)
buf = state["momentum_buffer"]
# 更新动量缓冲区
buf.mul_(momentum).add_(g)
if group["nesterov"]:
# 如果使用Nesterov动量,更新梯度
g = g.add(buf, alpha=momentum)
else:
g = buf
# 使用牛顿 - 舒尔茨迭代计算正交化更新
u = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
# 调整学习率
adjusted_lr = self.adjust_lr_for_muon(lr, p.shape)
# 应用权重衰减
p.data.mul_(1 - lr * wd)
# 应用更新
p.data.add_(u, alpha=-adjusted_lr)
############################
# AdamW backup #
############################
# 筛选出不使用Muon优化的参数
params = [p for p in group["params"] if not self.state[p]["use_muon"]]
lr = group['lr']
# 获取AdamW的betas参数
beta1, beta2 = group["adamw_betas"]
# 获取AdamW的epsilon参数
eps = group["adamw_eps"]
# 获取AdamW的权重衰减参数
weight_decay = group["wd"]
for p in params:
g = p.grad
if g is None:
continue
state = self.state[p]
if "step" not in state:
# 如果步数信息不存在,初始化步数和一阶、二阶动量
state["step"] = 0
state["moment1"] = torch.zeros_like(g)
state["moment2"] = torch.zeros_like(g)
# 更新步数
state["step"] += 1
step = state["step"]
buf1 = state["moment1"]
buf2 = state["moment2"]
# 更新一阶动量
buf1.lerp_(g, 1 - beta1)
# 更新二阶动量
buf2.lerp_(g.square(), 1 - beta2)
# 计算更新
g = buf1 / (eps + buf2.sqrt())
# 计算偏差修正
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
scale = bias_correction1 / bias_correction2**0.5
# 应用权重衰减
p.data.mul_(1 - lr * weight_decay)
# 应用更新
p.data.add_(g, alpha=-lr / scale)
return loss