DDPM(Denoising Diffusion Probabilistic Models)
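Note sources:
1. Denoising Diffusion Probabilistic Models
2. 大白话AI | 图像生成模型DDPM | 扩散模型 | 生成模型 | 概率扩散去噪生成模型 (AI in Plain Language | The DDPM Image Generation Model | Diffusion Models | Generative Models | Denoising Diffusion Probabilistic Models)
3. pytorch-stable-diffusion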
Forward Diffusion Process
How noise is added to an image, step by step:
Derive the next sample $x_t$ from the previous one, $x_{t-1}$.
After some derivation (see below), the $t$-th result $x_t$ can be obtained directly from the initial image $x_0$.
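
The two statements above can be checked numerically with a minimal sketch. It assumes the standard DDPM single-step rule $x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon$ with $\alpha_t = 1 - \beta_t$, which is consistent with the closed form used later in add_noise(). The helper names (`make_betas`, `iterate_coefficients`, `closed_form_coefficients`) are hypothetical, the beta schedule is copied from `DDPMSampler.__init__` in the full listing, and only the standard library is used so that it runs without PyTorch.

```python
import math


def make_betas(num_steps=1000, beta_start=0.00085, beta_end=0.0120):
    """Scaled-linear schedule: linear in sqrt(beta), then squared."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    return [(lo + (hi - lo) * i / (num_steps - 1)) ** 2 for i in range(num_steps)]


def iterate_coefficients(betas, t):
    """Apply the single-step rule for steps 0..t, tracking the x_0 coefficient and the noise variance."""
    signal, noise_var = 1.0, 0.0
    for beta in betas[: t + 1]:
        alpha = 1.0 - beta
        signal *= math.sqrt(alpha)             # x_0 is scaled by sqrt(alpha) at every step
        noise_var = alpha * noise_var + beta   # old noise is scaled, fresh noise is added
    return signal, noise_var


def closed_form_coefficients(betas, t):
    """Jump straight from x_0 to x_t using the cumulative product of alphas."""
    alpha_bar = math.prod(1.0 - b for b in betas[: t + 1])
    return math.sqrt(alpha_bar), 1.0 - alpha_bar


betas = make_betas()
print(iterate_coefficients(betas, 500))
print(closed_form_coefficients(betas, 500))
```

Both calls should print the same pair up to floating-point rounding, which is why the sampler never has to loop over intermediate steps when adding noise.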
What DDPM mainly does:
(1) Add noise to the clean image $x_0$.
(2) Calculate $\tilde{\mu}_t$ (mean) and $\tilde{\beta}_t$ (variance) for the distribution
$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t I)$$
(3) Update $\tilde{\mu}_t$ (mean).
(1) Add noise to the clean image using add_noise()
The derivation of the noising formula above is shown in the figure below.
Implementing add_noise(clean image: $x_0$, timesteps: $t$)
class DDPMSampler:
    def __init__(...):
        ...

    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...

    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep
        ...

    def _get_variance(...):  # Calculate the variance for the given timestep
        ...

    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(...):  # Perform one step of the reverse (denoising) process.
        ...

    def add_noise(  # Add noise to the original samples according to the diffusion (forward) process.
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Add noise to the original samples according to the diffusion process.
        Args:
        - original_samples (torch.FloatTensor): The original samples (images) to which noise will be added.
        - timesteps (torch.IntTensor): The timesteps at which the noise will be added.
        Returns:
        - torch.FloatTensor: The noisy samples.
        """
        # Retrieve the cumulative product of alphas on the same device and with the same dtype as the original samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        # Move timesteps to the same device as the original samples
        timesteps = timesteps.to(original_samples.device)
        # Compute the square root of the cumulative product of alphas for the given timesteps
        # sqrt{hat_alpha_t}
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        # Flatten sqrt_alpha_prod to ensure it's a 1D tensor
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append trailing dimensions until sqrt_alpha_prod broadcasts against original_samples
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        # Compute the square root of (1 - cumulative product of alphas) for the given timesteps
        # sqrt{1-hat_alpha_t}
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        # Flatten sqrt_one_minus_alpha_prod to ensure it's a 1D tensor
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        # Append trailing dimensions until sqrt_one_minus_alpha_prod broadcasts against original_samples
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1)
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        # Sample noise from a normal distribution with the same shape as the original samples
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        # sqrt_alpha_prod * original_samples is the mean: the original samples scaled by sqrt{hat_alpha_t}.
        # sqrt_one_minus_alpha_prod * noise is the noise term: unit Gaussian noise scaled by the
        # standard deviation sqrt{1-hat_alpha_t}.
        # Their sum is the noisy sample; the balance between image and noise depends on the timestep.
        # x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1-hat_alpha_t} * epsilon
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
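
The reparameterization used in add_noise() (a sample from $\mathcal{N}(\mu, \sigma^2)$ written as $\mu + \sigma\epsilon$) can be sanity-checked on a single scalar pixel. The following is only a sketch under stated assumptions: `alpha_bar`, `add_noise_scalar` and `empirical_stats` are hypothetical helper names, the schedule is copied from `DDPMSampler.__init__` in the full listing, and Python's `random` module stands in for `torch.randn`.

```python
import math
import random


def alpha_bar(t, num_steps=1000, beta_start=0.00085, beta_end=0.0120):
    """Cumulative product of (1 - beta_s) for s = 0..t (the hat_alpha_t of the code comments)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    betas = [(lo + (hi - lo) * i / (num_steps - 1)) ** 2 for i in range(t + 1)]
    return math.prod(1.0 - b for b in betas)


def add_noise_scalar(x0, abar, rng):
    """Scalar analogue of add_noise(): x_t = sqrt(abar) * x0 + sqrt(1 - abar) * eps."""
    eps = rng.gauss(0.0, 1.0)
    return math.sqrt(abar) * x0 + math.sqrt(1.0 - abar) * eps


def empirical_stats(x0, t, n=20000, seed=0):
    """Draw n noisy versions of the same x0 and return their sample mean and variance."""
    rng = random.Random(seed)
    abar = alpha_bar(t)
    xs = [add_noise_scalar(x0, abar, rng) for _ in range(n)]
    mean = sum(xs) / n
    var = sum((x - mean) ** 2 for x in xs) / n
    return mean, var


mean, var = empirical_stats(x0=1.0, t=500)
abar = alpha_bar(500)
print("empirical:", mean, var)
print("expected: ", math.sqrt(abar) * 1.0, 1.0 - abar)
```

The empirical mean should sit close to $\sqrt{\bar{\alpha}_t}\,x_0$ and the empirical variance close to $1 - \bar{\alpha}_t$, up to Monte Carlo error.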
(2) Calculate $\tilde{\mu}_t$ (mean) and $\tilde{\beta}_t$ (variance) for the distribution
$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde{\mu}_t, \tilde{\beta}_t I)$$
Note: $\mathcal{N}(\text{output}; \text{mean}, \text{variance})$
The derivation of the mean and variance of this distribution is shown in the figure below.
Implement _get_variance() to compute the variance, and step() to compute the mean and then update it.
class DDPMSampler:
    def __init__(...):
        ...

    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...

    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep
        ...

    def _get_variance(...):  # Calculate the variance for the given timestep
        ...

    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        Perform one step of the reverse (denoising) process.
        Args:
        - timestep (int): The current timestep during sampling.
        - latents (torch.Tensor): The current noisy latents x_t.
        - model_output (torch.Tensor): The output from the diffusion model (the predicted noise).
        """
        t = timestep
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(t)
        # 1. compute alphas, betas
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[t]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # hat_beta_t = 1 - hat_alpha_t
        beta_prod_t = 1 - alpha_prod_t
        # hat_beta_{t-1} = 1 - hat_alpha_{t-1}
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        # beta_t = 1 - alpha_t
        current_beta_t = 1 - current_alpha_t
        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        # x_t = sqrt{1 - hat_alpha_t} * epsilon + sqrt{hat_alpha_t} * x_0
        # x_0 = (x_t - sqrt{1 - hat_alpha_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        # x_0 = (x_t - sqrt{hat_beta_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # x_{t-1} ~ p_{theta}(x_{t-1} | x_t), the learned distribution over x_{t-1} in the reverse process
        #   = N( 1/sqrt{alpha_t} * (x_t - beta_t/sqrt{1-hat_alpha_t} * epsilon(x_t,t))
        #      , beta_t * (1-hat_alpha_{t-1}) / (1-hat_alpha_t) )
        # x_{t-1} ~ q(x_{t-1} | x_t,x_0), the posterior over x_{t-1} of the forward process
        #   = N( frac{sqrt{hat_alpha_{t-1}} beta_t}{1-hat_alpha_t} * x_0 + frac{sqrt{alpha_t}(1-hat_alpha_{t-1})}{1-hat_alpha_t} * x_t
        #      , beta_t * (1-hat_alpha_{t-1}) / (1-hat_alpha_t) )
        # frac{sqrt{hat_alpha_{t-1}} beta_t}{1-hat_alpha_t}
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        # frac{sqrt{alpha_t}(1-hat_alpha_{t-1})}{1-hat_alpha_t}
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # pred_mu_t = coeff_1 * x_0 + coeff_2 * x_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
        # 6. Update pred_mu_t according to pred_beta_t
        ...

    def add_noise(...):
        ...
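
The comments in step() give two expressions for the mean: one written in terms of the predicted $x_0$ (the one the code computes) and one written directly in terms of the predicted noise $\epsilon$. As a hedged check that they agree, the sketch below evaluates both on scalars. The function names are hypothetical, the schedule is copied from `DDPMSampler.__init__` in the full listing, and the stride of 20 assumes 1000 training steps with 50 inference steps.

```python
import math


def schedule(num_steps=1000, beta_start=0.00085, beta_end=0.0120):
    """Return the list of cumulative alpha products (hat_alpha_t in the code comments)."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    betas = [(lo + (hi - lo) * i / (num_steps - 1)) ** 2 for i in range(num_steps)]
    alpha_bars, running = [], 1.0
    for b in betas:
        running *= 1.0 - b
        alpha_bars.append(running)
    return alpha_bars


def posterior_mean_via_x0(x_t, eps, abar_t, abar_prev):
    """Mean as computed in step(): predict x_0 first, then mix x_0 and x_t."""
    alpha_t = abar_t / abar_prev
    beta_t = 1.0 - alpha_t
    x0_pred = (x_t - math.sqrt(1.0 - abar_t) * eps) / math.sqrt(abar_t)
    coeff_x0 = math.sqrt(abar_prev) * beta_t / (1.0 - abar_t)
    coeff_xt = math.sqrt(alpha_t) * (1.0 - abar_prev) / (1.0 - abar_t)
    return coeff_x0 * x0_pred + coeff_xt * x_t


def posterior_mean_via_eps(x_t, eps, abar_t, abar_prev):
    """Mean written directly in terms of the predicted noise."""
    alpha_t = abar_t / abar_prev
    beta_t = 1.0 - alpha_t
    return (x_t - beta_t / math.sqrt(1.0 - abar_t) * eps) / math.sqrt(alpha_t)


abars = schedule()
t, prev_t = 500, 480  # assumed stride of 20
print(posterior_mean_via_x0(0.3, -1.2, abars[t], abars[prev_t]))
print(posterior_mean_via_eps(0.3, -1.2, abars[t], abars[prev_t]))
```

The two printed values should match up to rounding, so predicting $x_0$ first is an algebraic rearrangement rather than a different sampler.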
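Why do we compute the distribution $q(x_{t-1} \mid x_t, x_0)$?
A KL-divergence term appears in the derivation of the Stable Diffusion loss function. It measures how similar two distributions are, and it is what keeps steering the reverse process toward the final image. A detailed explanation is left to a later post.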
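
As a small illustration of what a KL term between two Gaussians measures (not the full loss derivation, which is deferred to the later post), here is a sketch of the well-known closed form for univariate Gaussians. The function name `kl_gaussian` is hypothetical.

```python
import math


def kl_gaussian(mu_p, var_p, mu_q, var_q):
    """KL( N(mu_p, var_p) || N(mu_q, var_q) ) for univariate Gaussians."""
    return 0.5 * (math.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)


# Identical distributions: the divergence vanishes.
print(kl_gaussian(0.0, 1.0, 0.0, 1.0))
# Same variance, different means: only the squared mean gap remains, scaled by 1 / (2 * variance).
print(kl_gaussian(1.0, 0.5, 3.0, 0.5))
```

When both distributions share the same variance, the expression reduces to $(\mu_p - \mu_q)^2 / (2\sigma^2)$, so matching the means is all that is left to optimize in that case.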
(3) Update $\tilde{\mu}_t$ (mean)
Because $\tilde{\beta}_t$ is a variance, the standard deviation is $\sqrt{\tilde{\beta}_t}$. The mean is updated by adding noise scaled by that standard deviation, and the result is the sample $x_{t-1}$:
$$\tilde{\mu}_t \leftarrow \tilde{\mu}_t + \sqrt{\tilde{\beta}_t} \times \epsilon \qquad \left(\text{Note: } \epsilon \sim \mathcal{N}(0, I)\right)$$
class DDPMSampler:
    def __init__(...):
        ...

    def set_inference_timesteps(...):  # Set the number of inference timesteps for the DDPM model.
        ...

    def _get_previous_timestep(...):  # Calculate the previous timestep for the given timestep
        ...

    def _get_variance(...):  # Calculate the variance for the given timestep
        ...

    def set_strength(...):  # Set how much noise to add to the input image.
        ...

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        Perform one step of the reverse (denoising) process.
        Args:
        - timestep (int): The current timestep during sampling.
        - latents (torch.Tensor): The current noisy latents x_t.
        - model_output (torch.Tensor): The output from the diffusion model (the predicted noise).
        """
        ...
        ...
        ...
        # 6. Update pred_mu_t according to pred_beta_t
        variance = 0
        if t > 0:
            # Get the device of model_output
            device = model_output.device
            # Generate random noise with the same shape as model_output
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance for the current timestep as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            # sigma_t * epsilon, where sigma_t = sqrt{variance}
            variance = (self._get_variance(t) ** 0.5) * noise
        # Add the noise term to the predicted previous sample
        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1)
        # the variable "variance" already holds sigma multiplied by the noise N(0, 1)
        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * epsilon (Note: epsilon ~ N(0,1))
        pred_prev_sample = pred_prev_sample + variance
        return pred_prev_sample

    def add_noise(...):
        ...
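
To see what the variance looks like in practice, the sketch below recomputes the quantity from _get_variance() on plain floats, including the final step where the previous timestep is negative and no noise is added. It is an illustration only: the function names are hypothetical, the schedule is copied from `DDPMSampler.__init__` in the full listing, and the stride of 20 assumes 1000 training steps with 50 inference steps.

```python
import math
import random


def alpha_bars(num_steps=1000, beta_start=0.00085, beta_end=0.0120):
    """Cumulative alpha products for the scaled-linear schedule."""
    lo, hi = math.sqrt(beta_start), math.sqrt(beta_end)
    out, running = [], 1.0
    for i in range(num_steps):
        running *= 1.0 - (lo + (hi - lo) * i / (num_steps - 1)) ** 2
        out.append(running)
    return out


def posterior_variance(abars, t, stride):
    """Scalar version of _get_variance(); also returns the effective beta_t over the stride."""
    prev_t = t - stride
    abar_t = abars[t]
    abar_prev = abars[prev_t] if prev_t >= 0 else 1.0
    current_beta_t = 1.0 - abar_t / abar_prev
    variance = (1.0 - abar_prev) / (1.0 - abar_t) * current_beta_t
    return max(variance, 1e-20), current_beta_t


def sample_prev(mean, variance, t, rng):
    """Add sqrt(variance) * noise for t > 0; return the mean unchanged at t == 0."""
    if t > 0:
        return mean + math.sqrt(variance) * rng.gauss(0.0, 1.0)
    return mean


abars = alpha_bars()
print(posterior_variance(abars, 500, 20))  # posterior variance is smaller than beta_t
print(posterior_variance(abars, 0, 20))    # hat_alpha_prev is taken as 1, so the variance collapses to the clamp
print(sample_prev(0.7, 1e-20, 0, random.Random(0)))
```

The posterior variance is always the effective $\beta_t$ shrunk by the factor $(1-\bar{\alpha}_{t-1})/(1-\bar{\alpha}_t) < 1$, and the last step returns the mean itself.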
Full DDPM code (ddpm.py)
import torch
import numpy as np

'''
# Forward noising and reverse sampling
# Add noise to the clean image, compute pred_mu_t and pred_beta_t for the posterior, then update pred_mu_t
# (1) Add noise to the clean image using add_noise()
#     x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1-hat_alpha_t} * epsilon (Note: epsilon ~ N(0,1))
#     see formula (4) from https://arxiv.org/pdf/2006.11239.pdf
# (2) Calculate pred_mu_t and pred_beta_t for the distribution
#     q(x_{t-1}|x_t,x_0) = N(pred_mu_t, pred_beta_t*I)
#     step():          pred_mu_t = coeff_1 * x_0 + coeff_2 * x_t
#     _get_variance(): pred_beta_t = (1-hat_alpha_{t-1})/(1-hat_alpha_t) * beta_t
# (3) Update pred_mu_t
#     step():          pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * noise (Note: noise ~ N(0,1))
#     see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf
'''


class DDPMSampler:
    def __init__(self, generator: torch.Generator, num_training_steps=1000, beta_start: float = 0.00085, beta_end: float = 0.0120):
        # Params "beta_start" and "beta_end" taken from:
        # https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L5C8-L5C8
        # For the naming conventions, refer to the DDPM paper (https://arxiv.org/pdf/2006.11239.pdf)
        """
        Initialize the DDPM (Denoising Diffusion Probabilistic Model) parameters.
        Args:
        - generator (torch.Generator): A PyTorch random number generator.
        - num_training_steps (int, optional): Number of training steps. Default is 1000.
        - beta_start (float, optional): The starting value of beta. Default is 0.00085.
        - beta_end (float, optional): The ending value of beta. Default is 0.0120.
        """
        # Scaled-linear schedule: linear in sqrt(beta), then squared
        self.betas = torch.linspace(beta_start ** 0.5, beta_end ** 0.5, num_training_steps, dtype=torch.float32) ** 2
        # alpha = 1 - beta
        self.alphas = 1.0 - self.betas
        # hat_alpha_t = alpha_t * alpha_{t-1} * ... * alpha_2 * alpha_1
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Define a tensor representing the value 1.0
        self.one = torch.tensor(1.0)
        # Store the generator for random number generation
        self.generator = generator
        # Number of training timesteps
        self.num_train_timesteps = num_training_steps
        # Create a tensor of timesteps in reverse order
        self.timesteps = torch.from_numpy(np.arange(0, num_training_steps)[::-1].copy())

    def set_inference_timesteps(self, num_inference_steps=50):
        """
        Set the number of inference timesteps for the DDPM model.
        Args:
        - num_inference_steps (int, optional): Number of steps to use during inference. Default is 50.
        """
        # Store the number of inference steps
        self.num_inference_steps = num_inference_steps
        # Calculate the ratio between training timesteps and inference timesteps
        step_ratio = self.num_train_timesteps // self.num_inference_steps
        # Generate an array of timesteps for inference:
        # - np.arange(0, num_inference_steps): Create an array from 0 to num_inference_steps-1
        # - Multiply by step_ratio to space out the timesteps
        # - round() to ensure the timesteps are integers
        # - [::-1] to reverse the order, as inference typically proceeds backward through the timesteps
        # - copy() to ensure the array is contiguous in memory
        # - astype(np.int64) to ensure the timesteps are of type int64, which is compatible with PyTorch
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        # Convert the numpy array of timesteps to a PyTorch tensor
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """
        Calculate the previous timestep for the given timestep during inference.
        Args:
        - timestep (int): The current timestep during inference.
        Returns:
        - int: The previous timestep during inference.
        """
        # Calculate the previous timestep by subtracting the step ratio from the current timestep.
        # The step ratio is the integer division of the total number of training timesteps by the number of inference timesteps.
        # timestep t-1 = timestep t - ratio
        prev_t = timestep - self.num_train_timesteps // self.num_inference_steps
        return prev_t

    def _get_variance(self, timestep: int) -> torch.Tensor:
        """
        Calculate the variance for the given timestep during inference.
        Args:
        - timestep (int): The current timestep during inference.
        Returns:
        - torch.Tensor: The variance for the given timestep.
        """
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(timestep)
        # Retrieve the cumulative product of alphas at the current and previous timesteps
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[timestep]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        # beta_t = 1 - alpha_t
        current_beta_t = 1 - alpha_prod_t / alpha_prod_t_prev
        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ P(x_{t-1} | x_t,x_0)
        #   = N(mu, sigma^2)
        #   = N( 1/sqrt{alpha_t} * (x_t - beta_t/sqrt{1-hat_alpha_t} * epsilon)
        #      , beta_t * (1-hat_alpha_{t-1}) / (1-hat_alpha_t) )
        # x_{t-1} ~ N(pred_prev_sample, variance) == add sqrt{variance} * noise to pred_prev_sample
        variance = (1 - alpha_prod_t_prev) / (1 - alpha_prod_t) * current_beta_t
        # Clamp the variance so that it is never exactly zero
        variance = torch.clamp(variance, min=1e-20)
        return variance

    def set_strength(self, strength=1):
        """
        Set how much noise to add to the input image.
        Call set_inference_timesteps() first: this method reads self.num_inference_steps.
        Args:
        - strength (float, optional): A value between 0 and 1 indicating the amount of noise to add.
        - A strength value close to 1 means the output will be further from the input image (more noise).
        - A strength value close to 0 means the output will be closer to the input image (less noise).
        """
        # Calculate the number of inference steps to skip based on the strength
        # Higher strength means fewer steps skipped (more noise added)
        # start_step is the number of noise levels to skip
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        # Update the timesteps to start from the calculated step
        # This effectively sets the starting point for the noise addition process
        self.timesteps = self.timesteps[start_step:]
        # Store the starting step for reference
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_output: torch.Tensor):
        """
        Perform one step of the reverse (denoising) process.
        Args:
        - timestep (int): The current timestep during sampling.
        - latents (torch.Tensor): The current noisy latents x_t.
        - model_output (torch.Tensor): The output from the diffusion model (the predicted noise).
        """
        t = timestep
        # Get the previous timestep using the _get_previous_timestep method
        prev_t = self._get_previous_timestep(t)
        # 1. compute alphas, betas
        # hat_alpha_t
        alpha_prod_t = self.alphas_cumprod[t]
        # hat_alpha_{t-1}
        alpha_prod_t_prev = self.alphas_cumprod[prev_t] if prev_t >= 0 else self.one
        # hat_beta_t = 1 - hat_alpha_t
        beta_prod_t = 1 - alpha_prod_t
        # hat_beta_{t-1} = 1 - hat_alpha_{t-1}
        beta_prod_t_prev = 1 - alpha_prod_t_prev
        # alpha_prod_t / alpha_prod_t_prev = (alpha_t*alpha_{t-1}*...*alpha_1) / (alpha_{t-1}*...*alpha_1) = alpha_t
        current_alpha_t = alpha_prod_t / alpha_prod_t_prev
        # beta_t = 1 - alpha_t
        current_beta_t = 1 - current_alpha_t
        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        # x_t = sqrt{1 - hat_alpha_t} * epsilon + sqrt{hat_alpha_t} * x_0
        # x_0 = (x_t - sqrt{1 - hat_alpha_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        # x_0 = (x_t - sqrt{hat_beta_t} * epsilon(x_t)) / sqrt{hat_alpha_t}
        pred_original_sample = (latents - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # x_{t-1} ~ p_{theta}(x_{t-1} | x_t), the learned distribution over x_{t-1} in the reverse process
        #   = N( 1/sqrt{alpha_t} * (x_t - beta_t/sqrt{1-hat_alpha_t} * epsilon(x_t,t))
        #      , beta_t * (1-hat_alpha_{t-1}) / (1-hat_alpha_t) )
        # x_{t-1} ~ q(x_{t-1} | x_t,x_0), the posterior over x_{t-1} of the forward process
        #   = N( frac{sqrt{hat_alpha_{t-1}} beta_t}{1-hat_alpha_t} * x_0 + frac{sqrt{alpha_t}(1-hat_alpha_{t-1})}{1-hat_alpha_t} * x_t
        #      , beta_t * (1-hat_alpha_{t-1}) / (1-hat_alpha_t) )
        # frac{sqrt{hat_alpha_{t-1}} beta_t}{1-hat_alpha_t}
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * current_beta_t) / beta_prod_t
        # frac{sqrt{alpha_t}(1-hat_alpha_{t-1})}{1-hat_alpha_t}
        current_sample_coeff = current_alpha_t ** (0.5) * beta_prod_t_prev / beta_prod_t
        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        # pred_mu_t = coeff_1 * x_0 + coeff_2 * x_t
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * latents
        # 6. Update pred_mu_t according to pred_beta_t
        variance = 0
        if t > 0:
            # Get the device of model_output
            device = model_output.device
            # Generate random noise with the same shape as model_output
            noise = torch.randn(model_output.shape, generator=self.generator, device=device, dtype=model_output.dtype)
            # Compute the variance for the current timestep as per formula (7) from https://arxiv.org/pdf/2006.11239.pdf
            # sigma_t * epsilon, where sigma_t = sqrt{variance}
            variance = (self._get_variance(t) ** 0.5) * noise
        # Add the noise term to the predicted previous sample
        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1)
        # the variable "variance" already holds sigma multiplied by the noise N(0, 1)
        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # pred_mu_t = pred_mu_t + sqrt{pred_beta_t} * epsilon (Note: epsilon ~ N(0,1))
        pred_prev_sample = pred_prev_sample + variance
        return pred_prev_sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """
        Add noise to the original samples according to the diffusion process.
        Args:
        - original_samples (torch.FloatTensor): The original samples (images) to which noise will be added.
        - timesteps (torch.IntTensor): The timesteps at which the noise will be added.
        Returns:
        - torch.FloatTensor: The noisy samples.
        """
        # Retrieve the cumulative product of alphas on the same device and with the same dtype as the original samples
        alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
        # Move timesteps to the same device as the original samples
        timesteps = timesteps.to(original_samples.device)
        # Compute the square root of the cumulative product of alphas for the given timesteps
        # sqrt{hat_alpha_t}
        sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
        # Flatten sqrt_alpha_prod to ensure it's a 1D tensor
        sqrt_alpha_prod = sqrt_alpha_prod.flatten()
        # Append trailing dimensions until sqrt_alpha_prod broadcasts against original_samples
        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
        # Compute the square root of (1 - cumulative product of alphas) for the given timesteps
        # sqrt{1-hat_alpha_t}
        sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
        # Flatten sqrt_one_minus_alpha_prod to ensure it's a 1D tensor
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
        # Append trailing dimensions until sqrt_one_minus_alpha_prod broadcasts against original_samples
        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
        # Sample from q(x_t | x_0) as in equation (4) of https://arxiv.org/pdf/2006.11239.pdf
        # A sample X ~ N(mu, sigma^2) can be obtained as X = mu + sigma * N(0, 1)
        # here mu = sqrt_alpha_prod * original_samples and sigma = sqrt_one_minus_alpha_prod
        # Sample noise from a normal distribution with the same shape as the original samples
        noise = torch.randn(original_samples.shape, generator=self.generator, device=original_samples.device, dtype=original_samples.dtype)
        # sqrt_alpha_prod * original_samples is the mean: the original samples scaled by sqrt{hat_alpha_t}.
        # sqrt_one_minus_alpha_prod * noise is the noise term: unit Gaussian noise scaled by the
        # standard deviation sqrt{1-hat_alpha_t}.
        # Their sum is the noisy sample; the balance between image and noise depends on the timestep.
        # x_t = sqrt{hat_alpha_t} * x_0 + sqrt{1-hat_alpha_t} * epsilon
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples
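
To illustrate how set_inference_timesteps() and set_strength() interact, here is a minimal sketch that reproduces only their index arithmetic on plain Python lists. The function names `inference_timesteps` and `apply_strength` are hypothetical stand-ins for the two methods, and the defaults (1000 training steps, 50 inference steps) are the ones used in the listing above.

```python
def inference_timesteps(num_train_timesteps=1000, num_inference_steps=50):
    """Evenly spaced timesteps in descending order, as in set_inference_timesteps()."""
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio for i in range(num_inference_steps)][::-1]


def apply_strength(timesteps, strength):
    """Drop the noisiest steps, as in set_strength(); expects the full inference schedule."""
    num_inference_steps = len(timesteps)
    start_step = num_inference_steps - int(num_inference_steps * strength)
    return timesteps[start_step:], start_step


timesteps = inference_timesteps()
print(timesteps[:3], timesteps[-1])   # [980, 960, 940] 0

kept, start_step = apply_strength(timesteps, strength=0.8)
print(start_step, kept[0], len(kept))  # 10 780 40
```

With a strength of 0.8, the ten noisiest of the 50 steps are skipped, so denoising starts at timestep 780 rather than 980. In an image-to-image setting this would also be the timestep at which the input is noised with add_noise() before sampling begins, which is why a lower strength keeps the output closer to the input.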