腾讯HunyuanDit代码解析

news2025/1/18 21:06:40

注意:本文仅供自己记录学习过程使用。

训练

全参训练过程

  1. 输入图像用VAE编码得到输入的x_start(1,4,128,128);文本的两个特征:bert的encoder feature(1,77,1024)和T5 的feature(1,256,2048),和旋转位置编码freqs_cis_img: cos (4096,88),sin (4096,88)。
  2. 生成随机的时间步长t;生成随机的噪声(1,4,128,128),给输入的x_start加上噪声得到输出的x_t;
    def q_sample(self, x_start, t, noise=None):
        """
        Diffuse the data for a given number of diffusion steps.

        In other words, sample from q(x_t | x_0).

        :param x_start: the initial data batch.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :param noise: if specified, the split-out normal noise.
        :return: A noisy version of x_start.
        """
        if noise is None:
            noise = th.randn_like(x_start)
        assert_shape(noise, x_start)
        return (
            _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

  1. 对T5 的feature(1,256,2048)用mlp降为到(1,256,1024),然后把它和bert的feature cat起来得到text_states (1,33,1024);
  2. 对时间t编码(1,1408),x_t打成path,x(1,4096,1048);
  3. 对t5 feature进行pooling(multihead self-attention)得到extra_vec(1,1024);
  4. 时间t+mlp(extra_vec)=c(1,1408),得到condition;
  5. 上述步骤已得到以下参数:x ,c,text_states,freqs_cis_img。开始迭代处理。
x = block(x, c, text_states, freqs_cis_img)
  1. mlp(c)+x得到self-attention block的输入,把输入分成q/k/v,然后把q/k用旋转位置编码进行编码,得到新的qk。然后mlp提特征,输出x(1,4096,1408);简单来说,就是输入的x和文本的全局特征做了一次注意力提取特征的操作;
    def forward(self, x, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // 2), RoPE for image
        """
        b, s, d = x.shape

        qkv = self.Wqkv(x)
        qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim)  # [b, s, 3, h, d]
        q, k, v = qkv.unbind(dim=2) # [b, s, h, d]
        q = self.q_norm(q).half()   # [b, s, h, d]
        k = self.k_norm(k).half()

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, kk = apply_rotary_emb(q, k, freqs_cis_img)
            assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}'
            q, k = qq, kk

        qkv = torch.stack([q, k, v], dim=2)     # [b, s, 3, h, d]
        context = self.inner_attn(qkv)
        out = self.out_proj(context.view(b, s, d))
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple
def apply_rotary_emb(
        xq: torch.Tensor,
        xk: Optional[torch.Tensor],
        freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
        head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings.   [B, S, H, D]
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.

    """
    xk_out = None
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)    # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        if xk is not None:
            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
    else:
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)   # [S, D//2] --> [1, S, 1, D//2]
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        if xk is not None:
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out
  1. 上面得到的x,以及文本特征text_states,和旋转位置编码freqs_cis_img,作为cross attention block的输入;y是文本特征text_states;x作为q, text_states作为kv,q加上位置编码后,和kv作cross attention,得到输出x(1,4096,1408);
    def forward(self, x, y, freqs_cis_img=None):
        """
        Parameters
        ----------
        x: torch.Tensor
            (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim)
        y: torch.Tensor
            (batch, seqlen2, hidden_dim2)
        freqs_cis_img: torch.Tensor
            (batch, hidden_dim // num_heads), RoPE for image
        """
        b, s1, _ = x.shape     # [b, s1, D]
        _, s2, _ = y.shape     # [b, s2, 1024]

        q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim)       # [b, s1, h, d]
        kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim)  # [b, s2, 2, h, d]
        k, v = kv.unbind(dim=2)                 # [b, s2, h, d]
        q = self.q_norm(q).half()               # [b, s1, h, d]
        k = self.k_norm(k).half()               # [b, s2, h, d]

        # Apply RoPE if needed
        if freqs_cis_img is not None:
            qq, _ = apply_rotary_emb(q, None, freqs_cis_img)
            assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}'
            q = qq                              # [b, s1, h, d]
        kv = torch.stack([k, v], dim=2)         # [b, s1, 2, h, d]
        context = self.inner_attn(q, kv)        # [b, s1, h, d]
        context = context.view(b, s1, -1)       # [b, s1, D]

        out = self.out_proj(context)
        out = self.proj_drop(out)

        out_tuple = (out,)

        return out_tuple
  1. 最后mlp输出x(1,4096,1408)。共有19个hunyuan block,每个block输出的都是(1,4096,1408);(类似于unet encoder的操作,后续就是解码了,但是它这里“编解码”并没有分辨率的概念)。
  2. 开始“解码操作”了,其实就是前面最后输出x和前面block的输出cat起来,然后提取特征,后续步骤和前面是一样的。
    def _forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None):
        # Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([x, skip], dim=-1)
            cat = self.skip_norm(cat)
            x = self.skip_linear(cat)

        # Self-Attention
        shift_msa = self.default_modulation(c).unsqueeze(dim=1)
        attn_inputs = (
            self.norm1(x) + shift_msa, freq_cis_img,
        )
        x = x + self.attn1(*attn_inputs)[0]

        # Cross-Attention
        cross_inputs = (
            self.norm3(x), text_states, freq_cis_img
        )
        x = x + self.attn2(*cross_inputs)[0]

        # FFN Layer
        mlp_inputs = self.norm2(x)
        x = x + self.mlp(mlp_inputs)

        return x
  1. 最后整个网络输出(1,8,128,128);
  2. 网络的输出前4个通道(1,4,128,128)和输入的纯净的x_start作mse loss,后四个通道作什么变分概率误差;至此训练完成;

lora训练过程

训练过程和全参一样,低秩矩阵调用库peft训练的,略;在这里插入图片描述

controlnet训练过程

  1. 架构和hunyuandit一致;
  2. 1-6步和全参训练一样,第六步后有个VAE编码后的control img(1,4,128,128)作为condition, 把它和x_t相加+,得到网络的输入x;其他c,text_states,freqs_cis_img和之前一样;
        condition = self.x_embedder(condition)

        # ========================= Forward pass through HunYuanDiT blocks =========================
        controls = []
        x = x + self.before_proj(condition) # add condition
        for layer, block in enumerate(self.blocks):
            x = block(x, c, text_states, freqs_cis_img)
            controls.append(self.after_proj_list[layer](x)) # zero linear for output
  1. 输出19个block的control feature;与冻结后的hunyuandit的“解码层”特征相加即可;在这里插入图片描述
        for layer, block in enumerate(self.blocks):
            if layer > self.depth // 2:
                if controls is not None:
                    skip = skips.pop() + controls.pop()
                else:
                    skip = skips.pop()
                x = block(x, c, text_states, freqs_cis_img, skip)   # (N, L, D)
            else:
                x = block(x, c, text_states, freqs_cis_img)         # (N, L, D)

            if layer < (self.depth // 2 - 1):
                skips.append(x)
  1. 损失和之前一致,训练完毕;

推理(待续~)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1979924.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

4.8.双向循环神经网络

双向循环神经网络 ​ 在序列模型中&#xff0c;我们总是关注之前的信息&#xff0c;并以此来对下一个输出进行预测&#xff0c;但可能未来的信息也很重要&#xff0c;比如文本序列填空&#xff1a; 我___。我___饿了。我___饿了&#xff0c;我可以吃半头猪。 ​ 我们可以分别…

数据安全复合治理与实践

数据安全复合治理与实践 关键要点理论与实践 本文探讨了数据安全复合治理模式的理论与实践&#xff0c;着重强调了在数字经济迅猛发展的背景下&#xff0c;数据安全的重要性以及面对数据安全挑战时所需采取的综合治理策略。首先&#xff0c;文章概述了数据安全治理的必要性&…

使用GPT-4o mini融合GraphRAG技术进行实战应用

什么是gpt-4o mini OpenAI 推出 GPT-4o mini&#xff0c;这是他们最具成本效益的小型模型。它的定价为每百万输入代币 15 美分&#xff0c;每百万输出代币 60 美分&#xff0c;比之前的 Frontier 型号便宜一个数量级&#xff0c;比 GPT-3.5 Turbo 便宜 60% 以上。目前&#xf…

Java 并发编程:一文了解 synchronized 的使用

大家好&#xff0c;我是栗筝i&#xff0c;这篇文章是我的 “栗筝i 的 Java 技术栈” 专栏的第 027 篇文章&#xff0c;在 “栗筝i 的 Java 技术栈” 这个专栏中我会持续为大家更新 Java 技术相关全套技术栈内容。专栏的主要目标是已经有一定 Java 开发经验&#xff0c;并希望进…

练题模块环境搭建

文章目录 1.数据库表设计1.practice_set 套卷2.practice_set_detail 套卷细节3.practice_info 练习信息4.practice_detail 练习详情5.E-R图 2.架构设计&#xff08;三层架构&#xff09;3.练题微服务架构搭建1.创建一个练题微服务模块1.创建一个maven项目2.把src删除&#xff0…

类中特殊变量的初始化

在C的类中有一些变量的初始化需要进行特殊化的处理&#xff0c;这里我将列举出常见的两种特殊类型的变量初始化。 目录 const 类型数据的初始化 代码实例&#xff1a; static类型数据的初始化 代码实例&#xff1a; const 类型数据的初始化 对于const修饰的数据我们需要在…

Robot Operating System——单线程中启动多个Node

在《Robot Operating System——Service的同步/异步通信》一文中&#xff0c;我们介绍了如何实现Service的客户端和服务端。在例子中&#xff0c;它们分别被编译成libadd_two_ints_client_async_library.so和libadd_two_ints_server_library.so&#xff0c;然后分别被可执行程序…

C:将代码拆分放在多个文件的操作

目录 前言&#xff1a; 1、多个文件 2、将一个程序分为多个文件的好处 3、一定程度上对代码进行隐藏 结语&#xff1a; 前言&#xff1a; 在我们刚开始学习C语言时&#xff0c;编写的代码通常比较简短&#xff0c;因此将其放在一个文件中并不会带来不便。然而&#xff0c;…

17965 幸运之星(优先做)

这个问题可以通过使用递归或者迭代的方法来解决。我们可以使用一个一维数组dp来存储中间结果&#xff0c;dp[i]表示i个人时的“幸运之星”的初始编号。 以下是使用C的代码实现&#xff1a; #include <iostream> using namespace std;const int MAXN 1000000; int dp[M…

力扣:100379. 新增道路查询后的最短距离 I(Java,BFS)

目录 题目描述&#xff1a;示例 &#xff1a;代码实现&#xff1a; 题目描述&#xff1a; 给你一个整数 n 和一个二维整数数组 queries。 有 n 个城市&#xff0c;编号从 0 到 n - 1。初始时&#xff0c;每个城市 i 都有一条单向道路通往城市 i 1&#xff08; 0 < i < …

web高可用群集架构部署----超详细

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:Linux运维老纪的首页…

【笔记1-6】Qt bug记录:Qt6 无法使用qsort函数排序

在进行Qt5向Qt6升级的过程中&#xff0c;发现Qt6会编译时会出现以下错误&#xff0c;找不到qsort的定义 一开始以为应该是需要头文件或者.pro文件追加一些配置的问题&#xff0c;但是按照下面的提示追加了两个头文件后也没有效果 再进一步调查&#xff0c;找到了下面的文章&a…

python3 pyside6图形库学习笔记及实践(三)

目录 前言菜单栏相关控件使用QtDesigner快速构建菜单栏结构语法 上下文菜单概念为窗体添加上下文菜单为控件添加上下文菜单 折叠菜单资源的加载内置图标Rcc的使用创建资源文件加载资源文件 前言 本系列文章为b站PySide6教程以及官方文档的学习笔记 原视频传送门&#xff1a;【…

【Linux 17】进程信号

文章目录 &#x1f308; 一、信号的概念⭐ 1. 什么是信号⭐ 2. 常见的信号⭐ 3. 信号的管理 &#x1f308; 二、进程的运行⭐ 1. 进程运行模式⭐ 2. 查看后台进程⭐ 3. 运行后台进程⭐ 4. 终止后台进程 &#x1f308; 三、信号的产生⭐ 1. 通过键盘产生信号⭐ 2. 调用系统函数向…

YiYi-Web项目技术栈介绍

项目地址&#xff1a;https://gitee.com/jack0240/yiyi-web YiYi后台管理系统&#xff08;不分离版&#xff09;&#xff0c;SpringBoot Thymeleaf Layui 后台管理系统框架。 前端技术栈 HTML JavaScript JQuery Layui、Bootstrap Echarts图表、大屏展示、富文本 进阶&#…

书生大模型实战营第三期——入门岛

第一关&#xff1a;Linux基础知识 任务如下&#xff1a; 任务描述闯关任务完成SSH连接与端口映射并运行hello_world.py可选任务 1将Linux基础命令在开发机上完成一遍可选任务 2使用 VSCODE 远程连接开发机并创建一个conda环境可选任务 3创建并运行test.sh文件 1. 使用密码进行…

数据结构——单向链表

目录 前言 一、单向链表 二、单向链表基本操作 1、链表单创建 2.节点插入 &#xff08;1&#xff09;尾部插入 &#xff08;2&#xff09;任意位置插入 3、单向链表节点删除 4、链表打印 5、释放链表 6、链表逆序 ...... 三、链表测试 总结 前言 链表&#xff08;Linked List&a…

单细胞Seurat的umi矩阵-与feature、counts(用于质控)

目录 关于umi矩阵学习 用umi计算feature、counts值 ①meta数据查看 ②Count和Feature计算(生成Seurat时自动计算) 1)提取UMI矩阵 2)计算 其他指标 评估质量指标(重点) 1)UMI计数 2)基因计数 3)UMIs vs. genes detected 4)线粒体计数比率 5)综合过滤 过…

【C语言篇】文件操作(下篇)

文章目录 前言文件的顺序读写fscanf和fprintffread和fwrite 文件的随机读写fseekftellrewind 文件读取结束的判定容易被错误使用的feof 文件缓冲区 前言 本篇接上一篇文件操作&#xff08;上篇&#xff09;的内容 文件的顺序读写 在上一篇已经介绍了前面四个了&#xff0c;接…

【人工智能基础四】循环神经网络(RNN)与长短时记忆网络(LSTM)

文章目录 一. RNN1. 循环神经网络结构2. 循环神经网络计算2.1. 机器翻译2.2. 循环体 二. 长短时记忆网路&#xff08;LSTM&#xff09;1. 产生背景2. LSTM的设计思想与LSTM的链式结构2.1. LSTM的设计思想2.2. LSTM链式结构图与遗忘门 3. 长短时记忆网络结构 一. RNN RNN出现的…