Hunuan-DiT代码阅读

news2024/11/24 17:56:21

一 整体架构

该模型是以SD为基础的文生图模型,具体扩散模型原理参考https://zhouyifan.net/2023/07/07/20230330-diffusion-model/,代码地址https://github.com/Tencent/HunyuanDiT,这里介绍 Full-parameter Training

二 输入数据处理

这里主要包括图像和文本数据输入处理

2.1 图像处理

这里代码参考 hydit/data_loader/arrow_load_stream.py,生成1024*1024的图片,对于输入图片进行random_crop,之后包括随机水平翻转,转tensor,以及Normalize(减均值0.5, 除以标准差0.5,为什么是这个,是因为通过PIL Image读图之后转到tensor范围是0-1之间,不是opencv读出来像素值在0-255之间),得到最终image( B ∗ 3 ∗ 1024 ∗ 1024 B*3*1024*1024 B310241024

2.2 文本处理

输入的文本,通过BertTokenizer,进行映射,同时补齐长度到77,不够的补0,同时生成相应的attention_mask;同时还有T5TokenizerFast,对于T5的输入,会随机小于uncond_p_t5(目前给出的设置uncond_p_t5=5),输入为空,否则为文本输入,补齐长度256,同时生成相应的attention_mask

2.3 图像编码

对于输入图像,采用VAE encoder 进行编码,生成隐空间特征latents( B ∗ 4 ∗ 128 ∗ 128 B*4*128*128 B4128128,就是输入8倍下采样,计算过程latents = vae.encode(image).latent_dist.sample().mul_(vae_scaling_factor),具体VAE相关后续补充)

2.4 文本编码

包括两个部分,一个是CLIP的text编码,采用bert layer,生成encoder_hidden_states( B ∗ 77 ∗ 1024 B*77*1024 B771024);第二部分是mT5的text编码,生成encoder_hidden_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048

2.5 位置编码

这里是采用根据预设的分辨率,提前生成好的位置编码,这里采用ROPE,生成cos_cis_img, sin_cis_img (分别都是 4096 ∗ 88 4096*88 409688)

最终生成图像编码latents,文本编码(encoder_hidden_states以及对应的attention_mask,encoder_hidden_states_t5以及对应的attention_mask),以及位置编码cos_cis_img, sin_cis_img

三 DIT模型

3.1 add noise过程

  • 根据上一步的输出latents,作为x_start,随机选取一个time step,根据q_sample,得到增加噪声之后的输出x_t(具体公式参考如下,x0对应x_start,xt对应x_t)
    在这里插入图片描述

3.2 HunYuanDiT模型训练过程

  • 对于输入的文本编码,包括text_states( B ∗ 77 ∗ 1024 B*77*1024 B771024),text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048)以及相应的attention_mask,对于text_states_t5通过Linear+Silu+Linear,转成 B ∗ 256 ∗ 1024 B*256*1024 B2561024,然后对着两个进行concat,得到text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于attention_mask也concat得到clip_t5_mask( B ∗ 333 B*333 B333);这里会生成一个可学习的text_embedding_padding特征( B ∗ 333 ∗ 1024 B*333*1024 B3331024),对于clip_t5_mask中通过补0得到的特征全部替换成text_embedding_padding特征
  • 对于输入time step 先走timestep_embedding(就是sinusoidal编码),然后通过Linear+Silu+Linear得到最终t ( B ∗ 1408 B*1408 B1408)
  • 对于输入x(就是上一步的x_t),通过PatchEmbed(就是VIT前面对图像进行patch),得到x( B ∗ 4096 ∗ 1408 , 4096 是 64 ∗ 64 B*4096*1408,4096是64*64 B4096140840966464
  • 对于text_states_t5( B ∗ 256 ∗ 2048 B*256*2048 B2562048),添加一个AttentionPool模块,就是对于输入在256维度上,进行mean,当成query,然后将输入和query concat一起得到257维,作为key和value,(其中query,key,value都添加位置编码)做multi_head_attention,得到最终输出extra_vec( B ∗ 1024 B*1024 B1024
  • 对于extra_vec 通过Linear+Silu+Linear得到( B ∗ 1408 B*1408 B1408),然后与通过time step得到的t相加,得到c( B ∗ 1408 B*1408 B1408,作为所有extra_vectors)

3.2.1 进入Dit Block

一共40个block,前面0到18个block的生成输入,中间19,20作为middle block,剩余的block会增加一个前面19个block输出的结果作为skip

3.2.1.1 前面0到18共19个block
  • 前面一共19个block的过程,输入x( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),c( B ∗ 1408 B*1408 B1408),text_states( B ∗ 333 ∗ 1024 B*333*1024 B3331024),位置编码freqs_cis_img (cos_cis_img, sin_cis_img,分别都是 B ∗ 4096 ∗ 88 B*4096*88 B409688
HunYuanDiTBlock(
  (norm1): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
  (attn1): FlashSelfMHAModified(
    (Wqkv): Linear(in_features=1408, out_features=4224, bias=True)
    (q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
    (k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
    (inner_attn): FlashSelfAttention(
      (drop): Dropout(p=0.0, inplace=False)
    )
    (out_proj): Linear(in_features=1408, out_features=1408, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (norm2): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
  (mlp): Mlp(
    (fc1): Linear(in_features=1408, out_features=6144, bias=True)
    (act): GELU(approximate='tanh')
    (drop1): Dropout(p=0, inplace=False)
    (norm): Identity()
    (fc2): Linear(in_features=6144, out_features=1408, bias=True)
    (drop2): Dropout(p=0, inplace=False)
  )
  (default_modulation): Sequential(
    (0): FP32_SiLU()
    (1): Linear(in_features=1408, out_features=1408, bias=True)
  )
  (attn2): FlashCrossMHAModified(
    (q_proj): Linear(in_features=1408, out_features=1408, bias=True)
    (kv_proj): Linear(in_features=1024, out_features=2816, bias=True)
    (q_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
    (k_norm): LayerNorm((88,), eps=1e-06, elementwise_affine=True)
    (inner_attn): FlashCrossAttention(
      (drop): Dropout(p=0.0, inplace=False)
    )
    (out_proj): Linear(in_features=1408, out_features=1408, bias=True)
    (proj_drop): Dropout(p=0.0, inplace=False)
  )
  (norm3): FP32_Layernorm((1408,), eps=1e-06, elementwise_affine=True)
)
  • 对于c 通过default_modulation,得到shift_msa( B ∗ 4096 ∗ 1408 B*4096*1408 B40961408),与经过norm1之后的x进行相加作为attn1的输入(就是Flash Self Attention)
  • 将attn1的输出与原始的x进行残差相加,在经过norm3,与text_states一起作为attn2的输入(就是Flash Cross Attention)
  • 在将经过残差相加之后的x与attn2的输出在进行残差相加,作为输入,走FFN,即先经过norm2,在经过mlp,之后与输入残差相加
3.2.1.2 第19和20 middle block
  • 中间第19 和 20 两个block作为middle block,方式和上面一样
3.2.1.3 后面21到39共19个block
  • 从第21个block开始,增加一个输入,例如第21个block,会将第18个block的输出作为输入
  (skip_norm): FP32_Layernorm((2816,), eps=1e-06, elementwise_affine=True)
  (skip_linear): Linear(in_features=2816, out_features=1408, bias=True)
  • 就是对于新的输入skip,将skip与x进行concat之后,经过skip norm,然后在经过skip linear,得到输出x,剩余步骤与前面一样

3.2.2 最后FInal layer处理

  • 输入x和c,x是上面所有dit block的输出,c是上面的extra_vectors;对于c先进行SILU+Linear,得到( B ∗ 2816 B*2816 B2816),并彩分成shift 和 scale(分别为 B ∗ 1408 B*1408 B1408),最终通过x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1),然后通过Linear,得到最终输出x( B ∗ 4096 ∗ 32 B*4096*32 B409632),然后通过转换得到输出imgs ( B ∗ 8 ∗ 128 ∗ 128 B*8*128*128 B8128128

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2203164.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一、Spring入门

文章目录 1. 课程内容介绍2. Spring5 框架概述3. Spring5 入门案例 1. 课程内容介绍 2. Spring5 框架概述 3. Spring5 入门案例

Java反射专题

目录 一.反射机制 1.Java Reflection 2.反射相关的主要类 3.反射的优缺点 4.反射调用优化—关闭访问检查 二.Class类 1.基本介绍 2.常用方法 3.获取Class对象的方式 4.那些类型有Class对象 三.类加载 1.介绍 2.类加载时机 3.类加载各阶段 四.获取类结构的信息 1…

25.第二阶段x86游戏实战2-背包属性补充

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 本次游戏没法给 内容参考于:微尘网络安全 本人写的内容纯属胡编乱造,全都是合成造假,仅仅只是为了娱乐,请不要…

机器学习项目——运用机器学习洞察青年消费趋势

1. 项目背景 在21世纪的第三个十年,全球经济和技术的飞速发展正深刻影响着各个领域,尤其是青年消费市场。随着数字化进程的加速,尤其是移动互联网的广泛普及,青年的消费行为和生活方式发生了前所未有的转变。 然而,面对…

开源版GPT-4o来了!互腾讯引领风潮,开源VITA多模态大模型,开启自然人机交互新纪元[文末领福利]

目录 总览 VITA 模型的详细介绍 2.1 LLM 指令微调 2.2 多模态对齐 2.2.1 视觉模态 2.3 音频模态 多模态指令微调 3.1 训练数据 3.1.1 训练过程 模型部署:双工策略 4.1 非唤醒交互 4.2 音频中断交互 评估 5.1 语言表现 5.2 音频表现 5.3 多模态表现 …

离散微分几何中的网格(Meshes)

https://zhuanlan.zhihu.com/p/893338073 一、引言 ![](https://img-blog.csdnimg.cn/img_convert/c5e06e652822e0003cf6be91d26436b7.png) 在离散微分几何的广袤领域中,网格(Meshes)作为一个核心概念,犹如一座桥梁,…

小灰:从0到年入100万+,从程序员到自由职业者他经历了什么?

这是开发者说的第20期,这次给大家带来的畅销书《漫画算法》作者、自媒体创作者程序员小灰。 小灰做自媒体的时间已经有8年了,目前在全网有60w粉丝,同时《漫画算法》系列和《漫画ChatGPT》的书籍,在全网卖了12万册,靠写…

rocky9 samba共享

1. 安装samba服务,设置开机自启。 2. 创建四个用户user1,user2,sale1,manager,user1,user2属于finance组,sale1属于sales组,manager属于manager组。 3. 建立共享目录/opt/finance_…

模版进阶 非类型模版参数

一.模板参数分类类型形参与非类型形参。 类型形参即:出现在模板参数列表中,跟在class或者typename之类的参数类型名称。 非类型形参,就是用一个常量作为类(函数)模板的一个参数,在类(函数)模板中可将该参数当成常量来使用。 #i…

打印1000年到2000年之间的闰年

我们要打印1000年到2000年之间的闰年,首先我们先输出1000年到2000年之间的所有的年份,同时我们将闰年的判断方法输入到其中 闰年需要满足下列两个条件的其中之一: 1.能被4整除但不能被100整除 2.能被400整除 打印1000年到2000年之间的闰年…

PCL 贪婪三角投影三角化

目录 一、概述二、代码三、结果 一、概述 PCL中贪婪三角投影三角化的简单使用案例 二、代码 greedy_projection.cpp #include <pcl/point_types.h> #include <pcl/io/pcd_io.h> #include <pcl/search/kdtree.h> // for KdTree #include <pcl/features/…

【软件系统架构设计师-案例-1】架构风格

1. 请用200字以内说明系统可靠性的定义及包含的4个子特性&#xff0c;并简要指出提高系统可靠性一般采用哪些技术&#xff1f; &#xff08;1&#xff09;可靠性定义&#xff1a;系统在规定的时间或环境条件下&#xff0c;完成规定功能的能力&#xff0c;就是系统无故障运行的…

【AD速成】半小时入门AltiumDesigner(速通基础)

一.创建工程 1.工程 文件->新的->项目 PCB选择<Default>Project Name填入自己的工程名称Folder选择工程保存的路径 创建后如图&#xff1a; 这里的.prjPcb的文件即为AD的工程文件。 如果没有Project栏可以在视图->面板->Projects中勾选Projects CtrlS保存工…

Java学习-JVM调优

目录 1. JDK工具包 1.1 jps 1.2 jstat 1.3 jinfo 1.4 jmap 1.5 jhat 1.6 jstack 1.7 VisualVM 2. 第三方工具 2.1 GCEasy 2.2 MAT 2.3 Arthas 3. JVM参数 3.1 标准参数 3.2 非标准参数 3.3 不稳定参数 4. 调优 4.1 什么时候调优 4.2 调优调什么 4.3 调优原…

LINUX 系统管理操作

基础编辑 Tab 单击一次补全 双击列举候选 CTRL U 删除光标前 K 删除光标后 L 清屏&#xff08;只剩新命令行&#xff09; C 取消当前操作 反斜杠“\” 在需要转行的时候输入反斜杠 “\”回车 在>后继续输入 帮助命令 help 命令 大部分内建命令 格式&#xff1a;h…

直播预告 | 药品安全与合规保障难?智能温度监测助您领先制药工业4.0!

您是否在为温度敏感药品的运输和存储合规而苦恼&#xff1f; 是否担心冷链物流中的温度监控漏洞导致药品质量下降&#xff1f; 制药环境中的温湿度监控是否让您无从下手&#xff1f; 这些问题不仅影响药品的安全性&#xff0c;也直接影响企业的合规性和市场竞争力。如何确保环…

Android 保存图片到相册却不在“照片”中显示,只在相簿中显示

背景 需要从网络上下载图片到本地&#xff0c; 并显示在相册中 问题 将图片保存到内存中&#xff0c; 通过媒体API插入到媒体库后&#xff0c;图片只在“相簿”中的“所有项目”中显示&#xff0c;第一个页面的“照片”却不显示 解决办法 图片被保存到 Pictures/AppName 目录…

Linux系统通过编辑crontab来设置定时任务---定时关机

在Linux系统中&#xff0c;crontab 是用来设置周期性被执行的指令的守护进程。通过编辑 crontab&#xff0c;您可以安排定时任务&#xff0c;比如定时关机、定时备份文件、定时运行脚本等。以下是如何编辑 crontab 来设置定时任务的步骤&#xff1a; 打开终端&#xff1a;您可以…

基于springboot+vue的在线宠物用品交易网站

一、系统架构 前端&#xff1a;vue | element-ui | html 后端&#xff1a;springboot | mybatis-plus 环境&#xff1a;jdk1.8 | mysql | maven | nodejs 二、代码及数据库 三、功能介绍 01. web端-首页1 02. web端-首页2 03. web端-注册 04. web端-登录 05. w…

“万万没想到”,“人工智能”获得2024年诺贝尔物理学奖

近日&#xff0c;2024年诺贝尔物理学奖颁发给了机器学习与神经网络领域的研究者&#xff0c;这是历史上首次出现这样的情况。这项奖项原本只授予对自然现象和物质的物理学研究作出重大贡献的科学家&#xff0c;如今却将全球范围内对机器学习和神经网络的研究和开发作为了一种能…