HunyuanDiT代码笔记

news2025/2/25 0:07:58

HunyuanDiT 是由腾讯发布的文生图模型,适配中英双语。
在模型方面的改进,主要包括:

  • transformer结构
  • text encoder
  • positional encoding

Improving Training Stability To stabilize training, we present three techniques:

  1. We add layer normalization in all the attention modules before computing Q, K, and V. This technique is called
    QK-Norm, which is proposed in [13]. We found it effective for training Hunyuan-DiT as well.
  2. We add layer normalization after the skip module in the decoder blocks to avoid loss explosion during training.
  3. We found certain operations, e.g., layer normalization, tend to overflow with FP16. We specifically switch
    them to FP32 to avoid numerical errors.

HunyuanDiT的模型结构
在这里插入图片描述

使用Diffusers中的HunyuanDiTPipeline。

import torch
from diffusers import HunyuanDiTPipeline

pipe = HunyuanDiTPipeline.from_pretrained("/disk2/modelscope/hub/Xorbits/HunyuanDiT-v1___2-Diffusers", torch_dtype=torch.float16)
pipe.to("cuda")

# You may also use English prompt as HunyuanDiT supports both English and Chinese
# prompt = "An astronaut riding a horse"
prompt = "一个宇航员在骑马"
image = pipe(prompt).images[0]
image.save("astronaut.jpg")

transformer结构

HunyuanDiT 共包括40个HunyuanDiTBlock。其中前20个的结果要skip到后20个模块中。skip的时候,仍然需要norm,然后使用Linear恢复到之前的维度。

        skips = []
        for layer, block in enumerate(self.blocks):
            if layer > self.config.num_layers // 2:
                if controlnet_block_samples is not None:
                    skip = skips.pop() + controlnet_block_samples.pop()
                else:
                    skip = skips.pop()
                hidden_states = block(
                    hidden_states,
                    temb=temb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_rotary_emb=image_rotary_emb,
                    skip=skip,
                )  # (N, L, D)
            else:
                hidden_states = block(
                    hidden_states,   #(2,4096,1408)
                    temb=temb, #(2,1408)
                    encoder_hidden_states=encoder_hidden_states, #(2,333,1024)
                    image_rotary_emb=image_rotary_emb,
                )  # (N, L, D)

            if layer < (self.config.num_layers // 2 - 1):
                skips.append(hidden_states)

HunyuanDiTBlock 的forward函数,和上图一致,看图更直观。
输入需要Norm,每次attention后需要Norm。在计算attention的时候,Q和K还要Norm。

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        image_rotary_emb=None,
        skip=None,
    ) -> torch.Tensor:
        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Long Skip Connection
        if self.skip_linear is not None:
            cat = torch.cat([hidden_states, skip], dim=-1)
            cat = self.skip_norm(cat)
            hidden_states = self.skip_linear(cat)

        # 1. Self-Attention
        norm_hidden_states = self.norm1(hidden_states, temb)  ### checked: self.norm1 is correct
        attn_output = self.attn1(
            norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )
        hidden_states = hidden_states + attn_output

        # 2. Cross-Attention
        hidden_states = hidden_states + self.attn2(
            self.norm2(hidden_states),   ###
            encoder_hidden_states=encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # FFN Layer ### TODO: switch norm2 and norm3 in the state dict
        mlp_inputs = self.norm3(hidden_states)
        hidden_states = hidden_states + self.ff(mlp_inputs)

        return hidden_states

text encoder

HunyuanDiT 使用CLIP和T5两个文本编码器。CLIP提取文本和图像的关系特征,T5则加强对于prompt的理解。
CLIP 生成的embedding 维度为(1,77,1024),T5生成的embedding 维度为 (1,256,2048)。
使用PixArtAlphaTextProjection,将T5的embedding 对齐到CLIP,然后将两个序列拼到一起。

encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1) #(2,333,1024)

positional encoding

在这里插入图片描述
代码预设的长宽是512,公式中的S就是32,

base_size = 512 // 8 // self.transformer.config.patch_size  #32
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)  #((0, 0), (32, 32)),crop latent
image_rotary_emb = get_2d_rotary_pos_embed(
    self.transformer.inner_dim // self.transformer.num_heads,
    grid_crops_coords,
    (grid_height, grid_width),
    device=device,
    output_type="pt",
)

# get_resize_crop_region_for_grid实现公式
def get_resize_crop_region_for_grid(src, tgt_size):
    th = tw = tgt_size
    h, w = src

    r = h / w

    # resize
    if r > 1:
        resize_height = th
        resize_width = int(round(th / h * w))
    else:
        resize_width = tw
        resize_height = int(round(tw / w * h))

    crop_top = int(round((th - resize_height) / 2.0))
    crop_left = int(round((tw - resize_width) / 2.0))

    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)

既然是图像的RoPE,那么也只能加在图像的Q和K上,在cross_attention中,K和V来自于prompt。

        # Apply RoPE if needed
        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb)
            if not attn.is_cross_attention:
                key = apply_rotary_emb(key, image_rotary_emb)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2279779.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DDD - 如何设计支持快速交付的DDD技术中台

文章目录 Pre概述打造快速交付团队烟囱式的开发团队(BAD)大前端技术中台(GOOD) 技术中台的特征简单易用的技术中台建设总结 Pre DDD - 软件退化原因及案例分析 DDD - 如何运用 DDD 进行软件设计 DDD - 如何运用 DDD 进行数据库设计 DDD - 服务、实体与值对象的两种设计思路…

服务器硬盘RAID速度分析

​ 在现代数据中心和企业环境中&#xff0c;服务器的存储性能至关重要&#xff0c;RAID&#xff08;独立磁盘冗余阵列&#xff09;技术通过将多块硬盘组合成一个逻辑单元&#xff0c;提供了数据冗余和性能优化&#xff0c;本文将详细探讨不同RAID级别对服务器硬盘速度的影响&am…

【Docker】搭建一个功能强大的自托管虚拟浏览器 - n.eko

前言 本教程基于群晖的NAS设备DS423的docker功能进行搭建&#xff0c;DSM版本为 DSM 7.2.2-72806 Update 2。 n.eko 支持多种类型浏览器在其虚拟环境中运行&#xff0c;本次教程使用 Chromium​ 浏览器镜像进行演示&#xff0c;支持访问内网设备和公网地址。 简介 n.eko 是…

五、华为 RSTP

RSTP&#xff08;Rapid Spanning Tree Protocol&#xff0c;快速生成树协议&#xff09;是 STP 的优化版本&#xff0c;能实现网络拓扑的快速收敛。 一、RSTP 原理 快速收敛机制&#xff1a;RSTP 通过引入边缘端口、P/A&#xff08;Proposal/Agreement&#xff09;机制等&…

“深入浅出”系列之C++:(9)线程分离

线程分离的基本概念 线程分离是通过调用 std::thread::detach() 方法实现的。当线程被分离时&#xff0c;它会成为一个独立的线程&#xff0c;并且会自动管理自己的资源。当该线程完成执行时&#xff0c;它会自动清理资源&#xff0c;父线程不再需要等待或回收这个线程。 线程…

Day 13 卡玛笔记

这是基于代码随想录的每日打卡 144. 二叉树的前序遍历 给你二叉树的根节点 root &#xff0c;返回它节点值的 前序 遍历。 示例 1&#xff1a; 输入&#xff1a; root [1,null,2,3] 输出&#xff1a;[1,2,3] 解释&#xff1a; 示例 2&#xff1a; 输入&#xff1a; ro…

【STM32项目实战系列】系列开篇导语

【这个系列到底是什么】 简单来讲就是基于STM32的主控芯片的实际应用项目的介绍&#xff08;当然根据不同的项目功能特性需要使用不同的系列的ST主控芯片&#xff09;&#xff0c;这里面会涉及到基础工程的建立、各种驱动外设、中断和时钟的配置、RTOS的移植方法、文件系统的移…

产业园管理系统赋能企业精细管理与效益提升新路径

内容概要 现在的企业运营面临着越来越复杂的管理挑战&#xff0c;尤其是在园区管理领域。为了提升管理效率和经营效益&#xff0c;产业园管理系统的推出无疑为众多企业提供了全新的解决方案。这套系统通过智能化技术&#xff0c;将资产管理、租赁管理与财务监控等多个功能有机…

论文笔记(六十二)Diffusion Reward Learning Rewards via Conditional Video Diffusion

Diffusion Reward Learning Rewards via Conditional Video Diffusion 文章概括摘要1 引言2 相关工作3 前言4 方法4.1 基于扩散模型的专家视频建模4.2 条件熵作为奖励4.3 训练细节 5 实验5.1 实验设置5.2 主要结果5.3 零样本奖励泛化5.4 真实机器人评估5.5 消融研究 6 结论 文章…

鸿蒙中选择地区

1.首页ui import { CustomDialogExampleSelectRegion } from ./selectRegion/SelectRegionDialog;Entry Component struct Index {State selectedRegion: string 选择地区// 地区dialogControllerSelectRegion: CustomDialogController | null new CustomDialogController({b…

【HarmonyOS NAPI 深度探索12】创建你的第一个 HarmonyOS NAPI 模块

【HarmonyOS NAPI 深度探索12】创建你的第一个 HarmonyOS NAPI 模块 在本篇文章中&#xff0c;我们将一步步走过如何创建一个简单的 HarmonyOS NAPI 模块。通过这个模块&#xff0c;你将能够更好地理解 NAPI 的工作原理&#xff0c;并在你的应用中开始使用 C 与 JavaScript 的…

excel实用工具

持续更新… 文章目录 1. 快捷键1.1 求和 2. 命令2.1 查找 vloopup 1. 快捷键 1.1 求和 windows: alt mac : command shift T 2. 命令 2.1 查找 vloopup vlookup 四个入参数 要查找的内容 &#xff08;A2 6xx1&#xff09;查找的备选集 &#xff08;C2:C19&#xff09;…

Linux中的基本指令(一)

一、Linux中指令的存在意义 Linux中&#xff0c;通过输入指令来让操作系统执行&#xff0c;以此达到控制操作系统的目的&#xff0c;类似于Windows中的双击&#xff0c;右键新建文件&#xff0c;新建文件夹等 1.补&#xff1a;关于屏幕的几个操作指令 ①清屏指令 clear 回…

深入解析 C++17 中的 u8 字符字面量:提升 Unicode 处理能力

在现代软件开发中&#xff0c;处理多语言文本是一个常见需求&#xff0c;特别是在全球化的应用场景下。C17 标准引入的 u8 字符字面量为开发者提供了一个强大的工具&#xff0c;以更有效地处理和表示 UTF-8 编码的字符串。本文将详细探讨 u8 字符字面量的技术细节、实际应用&am…

2025年国产化推进.NET跨平台应用框架推荐

2025年国产化推进.NET跨平台应用框架推荐 1. .NET MAUI NET MAUI是一个开源、免费&#xff08;MIT License&#xff09;的跨平台框架&#xff08;支持Android、iOS、macOS 和 Windows多平台运行&#xff09;&#xff0c;是 Xamarin.Forms 的进化版&#xff0c;从移动场景扩展到…

C++和OpenGL实现3D游戏编程【连载21】——父物体和子物体模式实现

欢迎来到zhooyu的专栏。 &#x1f525;C和OpenGL实现3D游戏编程【专题总览】 1、本节要实现的内容 上节课我们已经创建了一个基础Object类&#xff0c;以后所有的游戏元素都可以从这个基类中派生出来。同时为了操作方便&#xff0c;我们可以为任意两个Object类&#xff08;及其…

unity插件Excel转换Proto插件-ExcelToProtobufferTool

unity插件Excel转换Proto插件-ExcelToProtobufferTool **ExcelToProtobufTool 插件文档****1. 插件概述****2. 默认配置类&#xff1a;DefaultIProtoPathConfig****属性说明** **3. 自定义配置类****定义规则****示例代码** **4. 使用方式****4.1 默认路径****4.2 自定义路径**…

【数据结构篇】顺序表 超详细!

目录 一.顺序表的定义 1.顺序表的概念及结构 1.1线性表 2.顺序表的分类 2.1静态顺序表 2.2动态顺序表 二.动态顺序表的实现 1.准备工作和注意事项 2.顺序表的基本接口&#xff1a; 2.0 创建一个顺序表 2.1 顺序表的初始化 2.2 顺序表的销毁 2.3 顺序表的打印 3.顺序…

vulnhub靶场【IA系列】之Tornado

前言 靶机&#xff1a;IA-Tornado&#xff0c;IP地址为192.168.10.11 攻击&#xff1a;kali&#xff0c;IP地址为192.168.10.2 都采用虚拟机&#xff0c;网卡为桥接模式 本文所用靶场、kali镜像以及相关工具&#xff0c;我放置在网盘中&#xff0c;可以复制后面链接查看 htt…

云上贵州多彩宝荣获仓颉社区先锋应用奖 | 助力数字政务新突破

在信息技术应用创新的浪潮中&#xff0c;仓颉社区吸引了众多企业和开发者的积极参与&#xff0c;已有多个应用成功落地&#xff0c;展现出蓬勃的创新活力。仓颉编程语言精心遴选了在社区建设、应用创新、开源共建、技术布道等方面做出突出贡献的优秀项目应用&#xff0c;并颁发…