AIGC笔记--U-ViT的简单代码实现

news2025/4/4 11:35:51

1--前言

原论文:All are Worth Words: A ViT Backbone for Diffusion Models

完整可debug的代码:Simple_U-ViT

2--结构

3--简单代码

        以视频作为输入,实现上图红色框的计算:

import torch
import torch.nn as nn
from einops import rearrange

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels = 3, patch_size = 16, emb_dim = 768):
        super().__init__()
        self.patch_size = patch_size
        # 由于是视频作为输入,因此使用Conv3d
        self.proj = nn.Conv3d(in_channels = in_channels, out_channels = emb_dim, kernel_size = (1, patch_size, patch_size), stride = (1, patch_size, patch_size))
    
    def forward(self, x):
        # x.shape: [batch_size, channels, frames, height, width]
        x = self.proj(x) # (batch_size, emd_dim, frames, height // patch_size, width // patch_size)
        x = rearrange(x, 'b e t h w -> b (t h w) e') # flatten into patches
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, emb_dim = 768, num_heads = 8, ff_dim = 1024, dropout = 0.1, skip = False):
        super().__init__()
        self.layernorm1 = nn.LayerNorm(emb_dim)
        self.self_attn = nn.MultiheadAttention(embed_dim = emb_dim, num_heads = num_heads, dropout = dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.layernorm2 = nn.LayerNorm(emb_dim)
        self.ffn = nn.Sequential(
            nn.Linear(emb_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, emb_dim)
        )
        self.dropout2 = nn.Dropout(dropout)
        if skip:
            self.skip_linear = nn.Linear(2 * emb_dim, emb_dim)

    def forward(self, x, skip = None):
        # x.shape: [b, thw, c]
        if(skip != None): # long skip connection
            x = self.skip_linear(torch.cat([x, skip], dim=-1)) # [b, thw, 2c] -> [b, thw, c]
        x2 = self.layernorm1(x) # norm
        x2, _ = self.self_attn(x2, x2, x2) # attention
        x = x + self.dropout1(x2) # skip
        x2 = self.layernorm2(x) # norm
        x2 = self.ffn(x2) # mlp
        x = x + self.dropout2(x2) # skip
        return x
    
class UViT(nn.Module):
    def __init__(self, in_channels = 3, patch_size = 16, emb_dim = 768, depth = 5, num_heads = 8, ff_dim = 1024, dropout = 0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(in_channels, patch_size, emb_dim)
        self.in_blocks = nn.ModuleList([
            TransformerEncoder(emb_dim = emb_dim, num_heads = num_heads, ff_dim = ff_dim, dropout = dropout) for _ in range(depth // 2)
        ])
        self.mid_block = TransformerEncoder(emb_dim = emb_dim, num_heads = num_heads, ff_dim = ff_dim, dropout = dropout)
        self.out_blocks = nn.ModuleList([
            TransformerEncoder(emb_dim = emb_dim, num_heads = num_heads, ff_dim = ff_dim, dropout = dropout, skip = True) for _ in range(depth // 2)
        ])
        self.layernorm = nn.LayerNorm(emb_dim)
    
    def forward(self, x): # [b, c, t, h, w]
        x = self.patch_embed(x) # [b, t * (h//patch) * (w//patch), c']

        skips = [] # 存储in_blocks的输出, 作为long skip, 本质上是一个栈(后进先出)
        for blk in self.in_blocks:
            x = blk(x)
            skips.append(x)
        
        x = self.mid_block(x)

        for blk in self.out_blocks:
            x = blk(x, skips.pop()) # 栈的pop,作为long skip
    
        x = self.layernorm(x)
        return x
    
if __name__ == "__main__":
    batch_size = 2
    in_channels = 3
    frames = 8
    height = 128
    width = 128
    video_tensor = torch.randn(batch_size, in_channels, frames, height, width).cuda() # 这里以视频作为输入,原论文是以图片作为输入

    # depth 必须是奇数,因为Long skip connection的存在
    model = UViT(in_channels = in_channels, patch_size = 16, emb_dim = 768, depth = 5, num_heads = 8, ff_dim = 1024, dropout = 0.1).cuda()
    output = model(video_tensor)
    print("output.shape: ", output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1866635.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

6.2 通过构建情感分类器训练词向量

在上一节中,我们简要地了解了词向量,但并没有去实现它。在本节中,我们将下载一个名为IMDB的数据集(其中包含了评论),然后构建一个用于计算评论的情感是正面、负面还是未知的情感分类器。在构建过程中,还将为 IMDB 数据…

分享一个 MySQL 简单快速进行自动备份和还原的脚本和方法

前言 数据备份和还原在信息技术领域中具有非常重要的作用,不论是人为误操作、硬件故障、病毒感染、自然灾害还是其他原因,数据丢失的风险都是存在的。如果没有备份,一旦数据丢失,可能对个人、企业甚至整个组织造成巨大的损失。 …

2-17 基于matlab的改进的遗传算法(IGA)对城市交通信号优化分析

基于matlab的改进的遗传算法(IGA)对城市交通信号优化分析。根据交通流量以及饱和流量,对城市道路交叉口交通信号灯实施合理优化控制,考虑到交通状况的动态变化,及每个交叉口的唯一性。通过实时监测交通流量&#xff0c…

部署企业级AI知识库最重要的是什么?✍

随着人工智能技术的迅猛发展,企业级AI知识库成为提升企业管理效率和信息获取能力的重要工具。那么,在部署企业级AI知识库时,最重要的是什么呢?本文将从数据质量、系统可扩展性、用户体验以及智能化这四个关键方面进行详细分析。 …

单片机是否有损坏,怎沫判断

目录 1、操作步骤: 2、单片机损坏常见原因: 3、 单片机不工作的原因: 参考:细讲寄存器读写与Bit位操作原理--单片机C语言编程Bit位的与或非屏蔽运算--洋桃电子大百科P019_哔哩哔哩_bilibili 1、操作步骤: 首先需要…

Objects and Classes (对象和类)

Objects and Classes [对象和类] 1. Procedural and Object-Oriented Programming (过程性编程和面向对象编程)2. Abstraction and Classes (抽象和类)2.1. Classes in C (C 中的类)2.2. Implementing Class Member Functions (实现类成员函数)2.3. Using Classes References O…

第 28 篇 : SSH秘钥登录

1 生成秘钥 ssh-keygen -t rsa ls -a ./.ssh/一直回车就行了 2. 修改配置 vi /etc/ssh/sshd_config放开注释 公钥的位置修改 关闭密码登录 PubkeyAuthentication yes AuthorizedKeysFile .ssh/id_rsa.pub PasswordAuthentication no3. 下载id_rsa私钥, 自行解决 注意…

Vue中数组的【响应式】操作

在 Vue.js 中,当你修改数组时,Vue 不能检测到以下变动的数组: 当你利用索引直接设置一个项时,例如:vm.items[indexOfItem] newValue当你修改数组的长度时,例如:vm.items.length newLength 为…

常见图像分割模型介绍:FCN、U-Net、SegNet、Mask R-CNN

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

AI助力校园安全:EasyCVR视频智能技术在校园欺凌中的应用

一、背景分析 近年来,各地深入开展中小学生欺凌行为治理工作,但有的地方学生欺凌事件仍时有发生,严重损害学生身心健康,引发社会广泛关注。为此,教育部制定了《防范中小学生欺凌专项治理行动工作方案》进一步防范和遏…

软件构造 | 期末查缺补漏

软件构造 | 期末查缺补漏 总体观 软件构造的三维度八度图是由软件工程师Steve McConnell提出的概念,用于描述软件构建过程中的三个关键维度和八个要素。这些维度和要素可以帮助软件开发团队全面考虑软件构建的方方面面,从而提高软件质量和开发效率。 下…

文件批量重命名001到100 最简单的数字序号递增的改名技巧

文件批量重命名001到100 最简单的数字序号递增的改名方法。最近看到很多人都在找怎么批量修改文件名称,还要按固定的ID需要递增,这个办法用F2或者右键改名是不能做到的。 这时候我们可以通过一个专业的文件批量重命名软件来批量处理这些文档。 芝麻文件…

A-8 项目开源 qt1.0

A-8 2024/6/26 项目开源 由于大家有相关的需求,就创建一个项目来放置相关的代码和项目 欢迎交流,QQ:963385291 介绍 利用opencascade和vulkanscene实现stp模型的查看器打算公布好几个版本的代码放在不同的分支下,用qt实现&am…

如何获取特定 HIVE 库的元数据信息如其所有分区表和所有分区

如何获取特定 HIVE 库的元数据信息如其所有分区表和所有分区 1. 问题背景 有时我们需要获取特定 HIVE 库下所有分区表,或者所有分区表的所有分区,以便执行进一步的操作,比如通过 使用 HIVE 命令 MSCK REPAIR TABLE table_name sync partiti…

Redis实战—基于setnx的分布式锁与Redisson

本博客为个人学习笔记,学习网站与详细见:黑马程序员Redis入门到实战 P56 - P63 目录 分布式锁介绍 基于SETNX的分布式锁 SETNX锁代码实现 修改业务代码 SETNX锁误删问题 SETNX锁原子性问题 Lua脚本 编写脚本 代码优化 总结 Redisson 前言…

基于盲信号处理的人声分离

1.问题描述 在实际生活中,存在一种基本现象称为“鸡尾酒效应”,该效应指即使在非常嘈杂的环境中,人依然可以从噪声中提取出自己所感兴趣的声音。 在实际应用中,我们可能需要对混合的声音进行分离,此时已知的只有混合…

springcloud第4季 springcloud-alibaba之openfegin+sentinel整合案例

一 介绍说明 1.1 说明 1.1.1 消费者8081 1.1.2 openfegin接口 1.1.3 提供者9091 9091微服务满足: 1 openfegin 配置fallback逻辑,作为统一fallback服务降级处理。 2.sentinel访问触发了自定义的限流配置,在注解sentinelResource里面配置…

DVWA 靶场 SQL Injection 通关解析

前言 DVWA代表Damn Vulnerable Web Application,是一个用于学习和练习Web应用程序漏洞的开源漏洞应用程序。它被设计成一个易于安装和配置的漏洞应用程序,旨在帮助安全专业人员和爱好者了解和熟悉不同类型的Web应用程序漏洞。 DVWA提供了一系列的漏洞场…

数学建模--Matlab求解线性规划问题两种类型实际应用

1.约束条件的符号一致 (1)约束条件的符号一致的意思就是指的是这个约束条件里面的,像这个下面的实例里面的三个约束条件,都是小于号,这个我称之为约束条件符号一致; (2)下面的就是上…