注意:本文仅供自己记录学习过程使用。
训练
全参训练过程
- 输入图像用VAE编码得到输入的x_start(1,4,128,128);文本的两个特征:bert的encoder feature(1,77,1024)和T5 的feature(1,256,2048),和旋转位置编码
freqs_cis_img
:cos
(4096,88),sin
(4096,88)。 - 生成随机的时间步长t;生成随机的噪声(1,4,128,128),给输入的x_start加上噪声得到输出的x_t;
def q_sample(self, x_start, t, noise=None):
"""
Diffuse the data for a given number of diffusion steps.
In other words, sample from q(x_t | x_0).
:param x_start: the initial data batch.
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
:param noise: if specified, the split-out normal noise.
:return: A noisy version of x_start.
"""
if noise is None:
noise = th.randn_like(x_start)
assert_shape(noise, x_start)
return (
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
* noise
)
- 对T5 的feature(1,256,2048)用mlp降为到(1,256,1024),然后把它和bert的feature cat起来得到
text_states
(1,33,1024); - 对时间t编码(1,1408),x_t打成path,
x
(1,4096,1048); - 对t5 feature进行pooling(multihead self-attention)得到extra_vec(1,1024);
- 时间t+mlp(extra_vec)=
c
(1,1408),得到condition; - 上述步骤已得到以下参数:
x ,c,text_states,freqs_cis_img
。开始迭代处理。
x = block(x, c, text_states, freqs_cis_img)
- mlp(c)+x得到self-attention block的输入,把输入分成q/k/v,然后把q/k用
旋转位置编码
进行编码,得到新的qk。然后mlp提特征,输出x(1,4096,1408);简单来说,就是输入的x和文本的全局特征做了一次注意力提取特征的操作;
def forward(self, x, freqs_cis_img=None):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // 2), RoPE for image
"""
b, s, d = x.shape
qkv = self.Wqkv(x)
qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, h, d]
q, k, v = qkv.unbind(dim=2) # [b, s, h, d]
q = self.q_norm(q).half() # [b, s, h, d]
k = self.k_norm(k).half()
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, kk = apply_rotary_emb(q, k, freqs_cis_img)
assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
q, k = qq, kk
qkv = torch.stack([q, k, v], dim=2) # [b, s, 3, h, d]
context = self.inner_attn(qkv)
out = self.out_proj(context.view(b, s, d))
out = self.proj_drop(out)
out_tuple = (out,)
return out_tuple
def apply_rotary_emb(
xq: torch.Tensor,
xk: Optional[torch.Tensor],
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
if xk is not None:
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
if xk is not None:
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
- 上面得到的x,以及文本特征text_states,和旋转位置编码freqs_cis_img,作为cross attention block的输入;y是文本特征text_states;x作为q, text_states作为kv,q加上位置编码后,和kv作cross attention,得到输出x(1,4096,1408);
def forward(self, x, y, freqs_cis_img=None):
"""
Parameters
----------
x: torch.Tensor
(batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
y: torch.Tensor
(batch, seqlen2, hidden_dim2)
freqs_cis_img: torch.Tensor
(batch, hidden_dim // num_heads), RoPE for image
"""
b, s1, _ = x.shape # [b, s1, D]
_, s2, _ = y.shape # [b, s2, 1024]
q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d]
kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d]
k, v = kv.unbind(dim=2) # [b, s2, h, d]
q = self.q_norm(q).half() # [b, s1, h, d]
k = self.k_norm(k).half() # [b, s2, h, d]
# Apply RoPE if needed
if freqs_cis_img is not None:
qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
q = qq # [b, s1, h, d]
kv = torch.stack([k, v], dim=2) # [b, s1, 2, h, d]
context = self.inner_attn(q, kv) # [b, s1, h, d]
context = context.view(b, s1, -1) # [b, s1, D]
out = self.out_proj(context)
out = self.proj_drop(out)
out_tuple = (out,)
return out_tuple
- 最后mlp输出x(1,4096,1408)。共有19个hunyuan block,每个block输出的都是(1,4096,1408);(类似于unet encoder的操作,后续就是解码了,但是它这里“编解码”并没有分辨率的概念)。
- 开始“解码操作”了,其实就是前面最后输出x和前面block的输出cat起来,然后提取特征,后续步骤和前面是一样的。
def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
# Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([x, skip], dim=-1)
cat = self.skip_norm(cat)
x = self.skip_linear(cat)
# Self-Attention
shift_msa = self.default_modulation(c).unsqueeze(dim=1)
attn_inputs = (
self.norm1(x) + shift_msa, freq_cis_img,
)
x = x + self.attn1(*attn_inputs)[0]
# Cross-Attention
cross_inputs = (
self.norm3(x), text_states, freq_cis_img
)
x = x + self.attn2(*cross_inputs)[0]
# FFN Layer
mlp_inputs = self.norm2(x)
x = x + self.mlp(mlp_inputs)
return x
- 最后整个网络输出(1,8,128,128);
- 网络的输出前4个通道(1,4,128,128)和输入的纯净的x_start作mse loss,后四个通道作什么变分概率误差;至此训练完成;
lora训练过程
训练过程和全参一样,低秩矩阵调用库peft训练的,略;
controlnet训练过程
- 架构和hunyuandit一致;
- 1-6步和全参训练一样,第六步后有个VAE编码后的control img(1,4,128,128)作为condition, 把它和x_t相加+,得到网络的输入
x
;其他c,text_states,freqs_cis_img
和之前一样;
condition = self.x_embedder(condition)
# ========================= Forward pass through HunYuanDiT blocks =========================
controls = []
x = x + self.before_proj(condition) # add condition
for layer, block in enumerate(self.blocks):
x = block(x, c, text_states, freqs_cis_img)
controls.append(self.after_proj_list[layer](x)) # zero linear for output
- 输出19个block的control feature;与冻结后的hunyuandit的“解码层”特征相加即可;
for layer, block in enumerate(self.blocks):
if layer > self.depth // 2:
if controls is not None:
skip = skips.pop() + controls.pop()
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)
- 损失和之前一致,训练完毕;