Open-Sora代码详细解读(2):时空3D VAE

news2024/9/19 7:29:55

Diffusion Models视频生成

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

3D VAE原理

代码剖析

2D VAE

时间VAE

因果3D卷积


3D VAE原理

之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。

Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。

代码剖析

2D VAE

来自华为pixart sdxl vae:

    vae_2d = dict(
        type="VideoAutoencoderKL",
        from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
        subfolder="vae",
        micro_batch_size=micro_batch_size,
        local_files_only=local_files_only,
    )

时间VAE

    vae_temporal = dict(
        type="VAE_Temporal_SD",
        from_pretrained=None,
    )
@MODELS.register_module()
class VAE_Temporal(nn.Module):
    def __init__(
        self,
        in_out_channels=4,
        latent_embed_dim=4,
        embed_dim=4,
        filters=128,
        num_res_blocks=4,
        channel_multipliers=(1, 2, 2, 4),
        temporal_downsample=(True, True, False),
        num_groups=32,  # for nn.GroupNorm
        activation_fn="swish",
    ):
        super().__init__()

        self.time_downsample_factor = 2 ** sum(temporal_downsample)
        # self.time_padding = self.time_downsample_factor - 1
        self.patch_size = (self.time_downsample_factor, 1, 1)
        self.out_channels = in_out_channels

        # NOTE: following MAGVIT, conv in bias=False in encoder first conv
        self.encoder = Encoder(
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim * 2,
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            activation_fn=activation_fn,
        )
        self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)

        self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)
        self.decoder = Decoder(
            in_out_channels=in_out_channels,
            latent_embed_dim=latent_embed_dim,
            filters=filters,
            num_res_blocks=num_res_blocks,
            channel_multipliers=channel_multipliers,
            temporal_downsample=temporal_downsample,
            num_groups=num_groups,  # for nn.GroupNorm
            activation_fn=activation_fn,
        )

    def get_latent_size(self, input_size):
        latent_size = []
        for i in range(3):
            if input_size[i] is None:
                lsize = None
            elif i == 0:
                time_padding = (
                    0
                    if (input_size[i] % self.time_downsample_factor == 0)
                    else self.time_downsample_factor - input_size[i] % self.time_downsample_factor
                )
                lsize = (input_size[i] + time_padding) // self.patch_size[i]
            else:
                lsize = input_size[i] // self.patch_size[i]
            latent_size.append(lsize)
        return latent_size

    def encode(self, x):
        time_padding = (
            0
            if (x.shape[2] % self.time_downsample_factor == 0)
            else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor
        )
        x = pad_at_dim(x, (time_padding, 0), dim=2)
        encoded_feature = self.encoder(x)
        moments = self.quant_conv(encoded_feature).to(x.dtype)
        posterior = DiagonalGaussianDistribution(moments)
        return posterior

    def decode(self, z, num_frames=None):
        time_padding = (
            0
            if (num_frames % self.time_downsample_factor == 0)
            else self.time_downsample_factor - num_frames % self.time_downsample_factor
        )
        z = self.post_quant_conv(z)
        x = self.decoder(z)
        x = x[:, :, time_padding:]
        return x

    def forward(self, x, sample_posterior=True):
        posterior = self.encode(x)
        if sample_posterior:
            z = posterior.sample()
        else:
            z = posterior.mode()
        recon_video = self.decode(z, num_frames=x.shape[2])
        return recon_video, posterior, z

因果3D卷积

class CausalConv3d(nn.Module):
    def __init__(
        self,
        chan_in,
        chan_out,
        kernel_size: Union[int, Tuple[int, int, int]],
        pad_mode="constant",
        strides=None,  # allow custom stride
        **kwargs,
    ):
        super().__init__()
        kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size

        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop("dilation", 1)
        stride = strides[0] if strides is not None else kwargs.pop("stride", 1)

        self.pad_mode = pad_mode
        time_pad = dilation * (time_kernel_size - 1) + (1 - stride)
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.time_pad = time_pad
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)

        stride = strides if strides is not None else (stride, 1, 1)
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        x = self.conv(x)
        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2136311.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

齐活儿了:一文读懂ERP和MRP、MES、CRM、WMS、SRM、APS等系统

ERP,即企业资源计划系统,是驱动企业资源整合与高效管理的核心引擎。它覆盖了企业财务、人力资源、研发创新、生产制造、供应链管理、采购活动、销售市场、客户服务以及资产管理这九大核心业务领域,形成了一个全方位、多层次的企业价值链管理体…

a-table 定时平滑轮播组件

效果图&#xff1a; 实现代码如下&#xff1a; <template><div class"scroll-container" mouseenter"stopScroll" mouseleave"startScroll"><a-table:columns"columns":data-source"visibleData":paginatio…

【BFS专项】— 解决最短路问题

1、迷宫中离入口最近的出口 - 力扣&#xff08;LeetCode&#xff09; 思路&#xff1a; 利用BFS层序遍历&#xff0c;新建一个变量统计步数代码&#xff1a; class Solution {int dx[] {0, 0, -1, 1};int dy[] {1, -1, 0, 0};public int nearestExit(char[][] maze, int[] en…

安泰高压放大器在基于EHD电喷的柔性压力传感器制备研究中的应用

实验名称&#xff1a;基于EHD电喷的柔性压力传感器制备技术研究 研究方向&#xff1a;EHD电喷技术是近年来出现的一种微纳尺度的新型3D打印技术&#xff0c;因其打印精度高、设备操作简单、材料来源广泛以及成本低等特点受到广泛关注&#xff0c;在柔性电子、生物医疗和可穿戴设…

C++ 继承【一篇让你学会继承】

1. 继承的概念及定义 1.1 继承的概念 继承机制是面向对象程序设计使代码可以复用的最重要的手段&#xff0c;它允许程序员在保持原有类特征的基础上进行扩展&#xff0c;增加功能&#xff0c;这样产生新的类&#xff0c;称派生类。继承呈现了面向对象程序设计的层次结构&…

基于JavaWeb开发的java springboot+mybatis电影售票网站管理系统前台+后台设计和实现

基于JavaWeb开发的java springbootmybatis电影售票网站管理系统前台后台设计和实现 &#x1f345; 作者主页 网顺技术团队 &#x1f345; 欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; &#x1f345; 文末获取源码联系方式 &#x1f4dd; &#x1f345; 查看下方微信号获…

【Linux实践】实验二:LINUX操作基础

【Linux实践】实验二&#xff1a;LINUX操作基础 实验目的实验内容实验步骤及结果1. 打开终端2. 关闭计算机命令3. 查看帮助文档4. 修改计算机主机名5. 显示月历和时间6. 统计行数、字符数、单词数 这章开始要涉及到命令了&#xff0c;其他关于命令的内容可以看我 2021年写的笔记…

量水堰计的校准与维护:确保测量结果的准确性

量水堰计作为水利工程中用于测量和调节水流量的重要设备&#xff0c;其准确性和可靠性直接关系到水利设施的正常运行及数据收集的精度。因此&#xff0c;定期校准与维护量水堰计是确保测量结果准确性的关键步骤。本文将从量水堰计的校准方法和周期&#xff0c;以及日常维护保养…

wifi贴码推广能赚钱吗?wifi贴码怎么跟商家沟通?

大家好&#xff0c;我是鲸天科技千千&#xff0c;大家都知道我是做开发的&#xff0c;平时会给大家分享一些互联网相关的创业项目和网络技巧&#xff0c;感兴趣的可以给我点个关注。 最近WiFi这个项目很多朋友来问我&#xff0c;我是前两年就接触过这个&#xff0c;所以比较了…

望繁信科技携流程智能解决方案亮相CNDS 2024新能源产业数智峰会

9月13日&#xff0c;CNDS 2024中国新能源产业数智峰会在北京圆满落幕。本次峰会以“走向数字新能源”为主题&#xff0c;汇聚了来自新能源领域的顶尖领袖、专家学者及知名企业代表&#xff0c;共同探讨数字化技术在新能源行业中的创新应用和发展趋势。上海望繁信科技有限公司&a…

中秋出游热度十足!喆啡酒店如何巧妙捕捉多元旅游需求?

中秋假期临近&#xff0c;多家旅游OTA平台陆续发布旅游热度预测&#xff0c;皆认为中秋小长假有望延续暑期旅游热度。马蜂窝大数据显示&#xff0c;“中秋去哪”关键词近一周热度环比上涨110%&#xff0c;且“中秋3日游”关键词的热度更是大涨175%。消费趋势方面&#xff0c;受…

CAT1 DTU软硬件设计开源资料分析(TCP协议+GNSS定位版本 )

一、CAT1 DTU方案简介&#xff1a; 远程终端单元DTU&#xff0c;一种针对通信距离较长和工业现场环境恶劣而设计的具有模块化结构的、特殊的计算机测控单元&#xff0c;它将末端检测仪表和执行机构与远程控制中心相连接。 奇迹TCP DTUGNSS版本DTU&#xff0c;用于将远程现场的…

【面试干货】软件测试面试题汇总

我把软件测试面试的整个题库都搬来啦&#xff0c;面试能拿下80%&#xff0c;剩下就看你满不满意公司的开价咯。以下答案都是我自己写的&#xff0c;大家根据自己的经历稍作改动&#xff0c;答案仅供参考哦&#xff01;题库持续更新&#xff0c;需要PDF版可以点击文末小卡片领取…

Unity3d 以鼠标位置点为中心缩放视角(正交模式下)

思路整理&#xff1a; 缩放前&#xff1a; 缩放后&#xff1a; 记录缩放前鼠标的屏幕坐标 A&#xff0c;计算鼠标位置对应的世界坐标 A_world 缩放完成后&#xff0c;根据当前屏幕下A所对应的世界坐标A1_world 计算A1_world 和 A_world 的偏移量 移动摄像机 代码&#xff…

将5s1的搜索难度曲线二次归一化

在行列可自由变换的平面上&#xff0c;5点结构只有34个 (A,B)---6*30*2---(0,1)(1,0) 分类A和B&#xff0c;让A是34个5点结构&#xff0c;让B全是0。收敛误差为7e-4&#xff0c;收敛199次取迭代次数平均值&#xff0c; 让训练集A-B矩阵的高分别是5&#xff0c;6.当高为5的时候…

陪护小程序|陪护小程序成品|陪护小程序源码

陪护系统是为提供病人及其家属更好的服务而开发的一种软件系统。在开发陪护系统时&#xff0c;有一些注意事项是需要考虑的。 首先&#xff0c;需要明确陪护系统的主要功能和目标群体。陪护系统可以包括病人信息管理、医护人员协作、医药管理、预约挂号等功能。我们需要确定开发…

项目管理 | 一文读懂什么是敏捷开发管理

在快速变化的商业环境中&#xff0c;项目管理方式也在不断演进&#xff0c;其中敏捷开发管理因其高效、灵活和适应性强的特点&#xff0c;逐渐成为众多企业和团队的首选。本文将详细解析敏捷开发管理的定义、具体内容及其核心角色&#xff0c;帮助读者全面理解这一先进的项目管…

Python基础语法(3)上

函数 函数是什么 编程中的函数和数学中的函数有一定的相似之处. 数学上的函数&#xff0c;比如 y sin x&#xff0c;x 取不同的值&#xff0c;y 就会得到不同的结果 编程中的函数是一段可以被重复使用的代码片段 代码示例&#xff1a;求数列的和&#xff0c;不使用函数 …

教育培训小程序开发,简单实用的入门指南

教育培训小程序可以帮助教育机构和个人老师提供更灵活的在线教学服务&#xff0c;满足学生的学习需求。对于初学者来说&#xff0c;开发一个功能齐全的教育培训小程序并不复杂&#xff0c;只需掌握一些基础的开发知识和工具即可。本文将带你了解如何使用微信小程序开发工具&…

如何准备技术面试?

大家好&#xff0c;我是老三&#xff0c;好久没更新了&#xff0c;翻出之前的一篇旧稿&#xff0c;是一篇总纲性质的文章——如何准备一场技术面试。这篇文章原本的开头是写给金三银四的&#xff0c;转眼就“金九银十”了&#xff0c;每一年都是最差的一年&#xff0c;又是未来…