Code Principle Analysis
1. Core Idea
This code implements a reinforcement-learning policy network based on a diffusion model. The diffusion model generates actions through a gradual denoising process. The core ideas are:
• Forward process: over T steps, Gaussian noise is gradually added to the expert action until it becomes pure noise
• Reverse process: a neural network is trained to predict the noise, and actions are generated by denoising step by step over T steps
• Mathematical basis: the DDPM (Denoising Diffusion Probabilistic Models) framework
Algorithm steps:
1.1 Forward noising: Gaussian noise is added to the action step by step, transforming the distribution of real actions into a Gaussian distribution
q(\mathbf{a}_t \mid \mathbf{a}_{t-1}) = \mathcal{N}(\mathbf{a}_t;\ \sqrt{1-\beta_t}\,\mathbf{a}_{t-1},\ \beta_t\mathbf{I})
where \beta_t is the noise schedule parameter.
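To make this forward kernel concrete, here is a minimal sketch that applies it step by step to a toy 2-D action (the action values are illustrative placeholders; the β schedule and T match the code further below):

import torch

T = 20
betas = torch.linspace(1e-4, 0.02, T)  # noise variance schedule β_t

def forward_step(a_prev, t):
    # One draw from q(a_t | a_{t-1}) = N(√(1-β_t)·a_{t-1}, β_t·I)
    return torch.sqrt(1 - betas[t]) * a_prev + torch.sqrt(betas[t]) * torch.randn_like(a_prev)

a = torch.tensor([0.5, -0.3])  # hypothetical expert action a_0
for t in range(T):
    a = forward_step(a, t)     # the action is progressively corrupted by Gaussian noise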
1.2 Reverse denoising: actions are generated by denoising conditioned on the observation \mathbf{o}_t
p_\theta(\mathbf{a}_{t-1} \mid \mathbf{a}_t, \mathbf{o}_t) = \mathcal{N}(\mathbf{a}_{t-1};\ \mu_\theta(\mathbf{a}_t, \mathbf{o}_t, t),\ \Sigma_t)
The denoising network \mu_\theta is trained to predict the noise residual.
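A minimal sketch of a single reverse step under this model, using Σ_t = β_t·I (the same variance choice as the executable code later in this section); the function name, the `denoiser` argument, and the tensor shapes are assumptions for illustration:

import torch

def reverse_step(denoiser, obs, a_t, t, betas, alphas, alpha_bars, T):
    # One draw from p_θ(a_{t-1} | a_t, o_t) with Σ_t = β_t·I
    inp = torch.cat([obs, a_t, torch.tensor([[t / T]], dtype=torch.float32)], dim=1)
    eps = denoiser(inp)  # ε_θ(a_t, o_t, t)
    mean = (a_t - (1 - alphas[t]) / torch.sqrt(1 - alpha_bars[t]) * eps) / torch.sqrt(alphas[t])
    if t == 0:
        return mean  # no noise is added at the final step
    return mean + torch.sqrt(betas[t]) * torch.randn_like(a_t)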
1.3 Training objective: minimize the noise-prediction error
\mathcal{L} = \mathbb{E}_{t,\mathbf{a}_0,\epsilon}\left[ \|\epsilon - \epsilon_\theta(\sqrt{\bar{\alpha}_t}\,\mathbf{a}_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ \mathbf{o}_t,\ t)\|^2 \right]
where \bar{\alpha}_t = \prod_{s=1}^{t}(1-\beta_s).
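A minimal sketch of this objective for a single batch, using placeholder data and a linear stand-in for the denoising network (the batch size, dimensions, and stand-in network are assumptions matching the defaults in the code below):

import torch

T = 20
betas = torch.linspace(1e-4, 0.02, T)
alpha_bars = torch.cumprod(1 - betas, dim=0)       # ᾱ_t

obs = torch.randn(32, 4)                           # placeholder observations o_t
a0 = torch.randn(32, 2)                            # placeholder expert actions a_0
eps_theta = torch.nn.Linear(4 + 2 + 1, 2)          # stand-in for the denoising network ε_θ

t = torch.randint(0, T, (32,))
eps = torch.randn_like(a0)
a_t = torch.sqrt(alpha_bars[t]).unsqueeze(1) * a0 + torch.sqrt(1 - alpha_bars[t]).unsqueeze(1) * eps
pred = eps_theta(torch.cat([obs, a_t, (t.float() / T).unsqueeze(1)], dim=1))
loss = ((eps - pred) ** 2).mean()                  # Monte-Carlo estimate of L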
2. Key Mathematical Formulas
• Forward process (diffusion):
  q(a_t | a_{t-1}) = N(a_t; √α_t · a_{t-1}, (1-α_t) I)
  α_t = 1 - β_t,  ᾱ_t = ∏_{i=1}^t α_i
  a_t = √ᾱ_t · a_0 + √(1-ᾱ_t) · ε,  where ε ~ N(0, I)
• Training objective (noise prediction):
  L = ||ε - ε_θ(a_t, o, t)||²
• Reverse process (sampling):
  p_θ(a_{t-1} | a_t, o) = N(a_{t-1}; μ_θ(a_t, o, t), Σ_t)
  μ_θ = (1/√α_t) · (a_t - β_t/√(1-ᾱ_t) · ε_θ)
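As a quick numeric check of these definitions (using the same β schedule as the code below): with T = 20 and β_t ranging from 1e-4 to 0.02, ᾱ_T works out to roughly 0.82, so a_T still retains most of the original action signal rather than becoming pure noise; this is one concrete reason the closing remarks suggest tuning the noise schedule.

import torch

betas = torch.linspace(1e-4, 0.02, 20)      # β_t, as in the code below
alphas = 1 - betas                          # α_t = 1 - β_t
alpha_bars = torch.cumprod(alphas, dim=0)   # ᾱ_t = ∏ α_i
print(alpha_bars[-1].item())                # ≈ 0.82: √ᾱ_T ≈ 0.9, so a_T keeps much of a_0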
Line-by-Line Code Annotation
import torch
import gymnasium as gym
import numpy as np

class DiffusionPolicy(torch.nn.Module):
    def __init__(self, state_dim=4, action_dim=2, T=20):
        super().__init__()
        self.T = T  # total number of diffusion steps
        self.betas = torch.linspace(1e-4, 0.02, T)  # noise variance schedule
        self.alphas = 1 - self.betas  # forward-process parameters α_t = 1 - β_t
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)  # cumulative products ᾱ_t
        # Denoising network (input dim: state(4) + action(2) + timestep(1) = 7)
        self.denoiser = torch.nn.Sequential(
            torch.nn.Linear(7, 64),  # input layer
            torch.nn.ReLU(),         # activation
            torch.nn.Linear(64, 2)   # output: predicted noise
        )
        self.optimizer = torch.optim.Adam(self.denoiser.parameters(), lr=1e-3)

    def train_step(self, states, expert_actions):
        batch_size = states.size(0)
        t = torch.randint(0, self.T, (batch_size,))    # sample random timesteps
        alpha_bar_t = self.alpha_bars[t].unsqueeze(1)  # look up the corresponding ᾱ_t
        # Forward noising (closed-form formula)
        noise = torch.randn_like(expert_actions)       # Gaussian noise
        noisy_actions = torch.sqrt(alpha_bar_t) * expert_actions + \
                        torch.sqrt(1 - alpha_bar_t) * noise
        # Concatenate inputs (state, noised action, normalized timestep)
        inputs = torch.cat([
            states,
            noisy_actions,
            (t.float() / self.T).unsqueeze(1)  # timestep normalized to [0, 1]
        ], dim=1)  # final shape: batch_size x 7
        pred_noise = self.denoiser(inputs)  # predict the noise
        loss = torch.mean((noise - pred_noise) ** 2)  # MSE loss
        return loss

    def sample_action(self, state):
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        a_t = torch.randn(1, 2)  # start from pure noise (action dim 2)
        # Reverse denoising process (to be completed)
        for t in reversed(range(self.T)):
            # Steps to implement:
            # 1. Look up the parameters for the current timestep
            # 2. Concatenate the input (state, current action, timestep)
            # 3. Predict the noise ε_θ
            # 4. Compute the mean μ from the formula
            # 5. Sample the new action (no noise added at the final step)
            pass
        return a_t.detach().numpy()[0]  # return the final action
Execution Walkthrough
Training flow (a minimal sketch of one gradient update follows this list)
- Sample random timesteps: for each sample, draw a diffusion step t ∈ [0, T-1]
- Forward noising: add the corresponding amount of noise to the expert action according to the closed-form formula
- Input construction: concatenate the state, the noised action, and the normalized timestep
- Noise prediction: the neural network predicts the noise that was added
- Loss computation: minimize the MSE between the predicted noise and the true noise
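These steps all happen inside DiffusionPolicy.train_step above; a minimal sketch of wiring one gradient update around it (placeholder batch tensors; assumes the DiffusionPolicy class defined above is in scope, as it is in the executable code below):

import torch

policy = DiffusionPolicy()                          # class defined above
states = torch.randn(32, 4)                         # placeholder batch of states
expert_actions = torch.randn(32, 2)                 # placeholder batch of expert actions
loss = policy.train_step(states, expert_actions)    # the five training-flow steps happen inside
policy.optimizer.zero_grad()
loss.backward()
policy.optimizer.step()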
Sampling flow (to be completed; a sketch of the missing loop follows this list)
- Initialization: start from Gaussian noise
- Iterative denoising: denoise step by step from t = T down to t = 1 (indices T-1 down to 0 in the code)
  • Predict the noise from the current action and state
  • Compute the mean of the previous step
  • Add random noise (except at the final step)
- Output: the final denoised action
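Following these steps, the loop left as `pass` in sample_action above could be completed as sketched below; this mirrors the update used in the final executable code (Σ_t = β_t·I is an assumption of this sketch):

# Inside sample_action, replacing the `pass` placeholder:
for t in reversed(range(self.T)):
    alpha_t = self.alphas[t]                         # step 1: parameters for the current timestep
    alpha_bar_t = self.alpha_bars[t]
    inputs = torch.cat([state_tensor, a_t,
                        torch.tensor([[t / self.T]], dtype=torch.float32)], dim=1)  # step 2: concatenate input
    pred_noise = self.denoiser(inputs)               # step 3: predict ε_θ
    # step 4: compute the mean μ_θ
    a_t = (a_t - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
    if t > 0:                                        # step 5: add noise, except at the final step
        a_t += torch.sqrt(self.betas[t]) * torch.randn_like(a_t)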
Key Improvement Suggestions
- Implement the reverse process: the timestep loop and the denoising formula still need to be filled in
- Add a variance schedule: use a more refined variance computation during sampling
- Timestep embedding: a sinusoidal positional encoding could replace the simple normalization (see the sketch below)
- Network architecture: consider a Transformer or conditional batch normalization
This implementation demonstrates the core idea of a diffusion policy, but a complete diffusion policy still needs the full reverse sampling process implemented, and the noise schedule parameters will likely need tuning for better performance.
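As an illustration of the timestep-embedding suggestion above, here is a minimal sketch of a sinusoidal encoding that could replace the scalar t/T input (the embedding dimension and the helper name are hypothetical, not part of the code below):

import math
import torch

def timestep_embedding(t, dim=16):
    # Sinusoidal embedding of integer timesteps, commonly used in diffusion models
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    angles = t.float().unsqueeze(1) * freqs.unsqueeze(0)             # (batch, dim/2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (batch, dim)

emb = timestep_embedding(torch.tensor([3, 7]))  # example timesteps; the denoiser's first Linear
                                                # layer would then take 4 + 2 + 16 = 22 inputs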
Final executable code:
import torch
import gymnasium as gym
import numpy as np

class DiffusionPolicy(torch.nn.Module):
    def __init__(self, state_dim=4, action_dim=2, T=20):
        super().__init__()
        self.T = T
        self.betas = torch.linspace(1e-4, 0.02, T)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        # Denoising network (input dim: 4 + 2 + 1 = 7)
        self.denoiser = torch.nn.Sequential(
            torch.nn.Linear(7, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 2)
        )
        self.optimizer = torch.optim.Adam(self.denoiser.parameters(), lr=1e-3)

    def train_step(self, states, expert_actions):
        batch_size = states.size(0)
        t = torch.randint(0, self.T, (batch_size,))
        alpha_bar_t = self.alpha_bars[t].unsqueeze(1)
        # Forward noising: a_t = √ᾱ_t·a_0 + √(1-ᾱ_t)·ε
        noise = torch.randn_like(expert_actions)
        noisy_actions = torch.sqrt(alpha_bar_t) * expert_actions + torch.sqrt(1 - alpha_bar_t) * noise
        # Concatenate inputs (dimensions aligned)
        inputs = torch.cat([
            states,
            noisy_actions,
            (t.float() / self.T).unsqueeze(1)
        ], dim=1)  # final shape: batch_size x 7
        pred_noise = self.denoiser(inputs)
        loss = torch.mean((noise - pred_noise) ** 2)
        return loss

    def sample_action(self, state):
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        a_t = torch.randn(1, 2)  # start from noise in the 2-D action space
        # Reverse denoising process
        with torch.no_grad():
            for t in reversed(range(self.T)):
                alpha_t = self.alphas[t]
                alpha_bar_t = self.alpha_bars[t]
                inputs = torch.cat([
                    state_tensor,
                    a_t,
                    torch.tensor([[t / self.T]], dtype=torch.float32)
                ], dim=1)
                pred_noise = self.denoiser(inputs)
                # Mean update: μ_θ = (a_t - (1-α_t)/√(1-ᾱ_t)·ε_θ) / √α_t
                a_t = (a_t - (1 - alpha_t) / torch.sqrt(1 - alpha_bar_t) * pred_noise) / torch.sqrt(alpha_t)
                if t > 0:
                    a_t += torch.sqrt(self.betas[t]) * torch.randn_like(a_t)
        return torch.argmax(a_t).item()  # discrete action selection

if __name__ == "__main__":
    env = gym.make('CartPole-v1')
    policy = DiffusionPolicy()
    # Collect random-policy data; key fix: keep the state data in a consistent shape
    states, actions = [], []
    state, _ = env.reset()
    for _ in range(1000):
        action = env.action_space.sample()
        next_state, _, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        # Force the state into a flat numpy array and check its dimension
        state = np.array(state, dtype=np.float32).flatten()
        if len(state) != 4:
            raise ValueError(f"Invalid state shape: {state.shape}")
        states.append(state)  # each stored state is a (4,) array
        actions.append(action)
        if done:
            state, _ = env.reset()
        else:
            state = next_state
    # Dimension validation and conversion
    states_array = np.stack(states)  # shape (1000, 4)
    if states_array.shape != (1000, 4):
        raise ValueError(f"States shape error: {states_array.shape}")
    actions_onehot = np.eye(2)[np.array(actions)]  # one-hot encode the discrete actions
    states_tensor = torch.FloatTensor(states_array)
    actions_tensor = torch.FloatTensor(actions_onehot)
    # Training loop
    for epoch in range(100):
        loss = policy.train_step(states_tensor, actions_tensor)
        policy.optimizer.zero_grad()
        loss.backward()
        policy.optimizer.step()
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
    # Test
    state, _ = env.reset()
    for _ in range(200):
        action = policy.sample_action(state)
        state, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break