[DL]深度学习_扩散模型正弦时间编码

news2024/11/28 19:00:21

1  扩散模型时间步嵌入

1.1  时间步正弦编码

        在扩散模型按时间步 t 进行加噪去噪过程时,需要包括反映噪声水平的时间步长 t 作为噪声预测器的额外输入。但是最初与图像配套的时间步 t 是数字,需要将代表时间步 t 的数字编码为向量嵌入。嵌入时间向量的宽度dim是按照输出设定的。

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)
    args = timesteps[:, None].float() * freqs[None]
    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

参数定义

  • timesteps--时间步 t ,包含batch中每个元素的时间步信息,是一个一维张量,形状为[N];
  • dim--编码后时间嵌入的最后一维大小,即输出多宽的向量;
  • max_period--控制嵌入的最小频率,值为10000。 

        假设此时的batch中有4个图像,则此时4个图像有4个对应的扩散时间步,所需编码的时间嵌入宽度是8。timesteps的形状为[4] 。

timesteps = [t1, t2, t3, t4], dim = 8

编码过程

    half = dim // 2

        计算所需输出向量的一半尺寸,以供后续将正弦嵌入和余弦嵌入按照最后维度concat拼接。

    freqs = th.exp(
        -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
    ).to(device=timesteps.device)

        用了固定化编码方式,计算出的频率向量,且该频率向量的维度是由dim的一半half决定的。arange(start=0,end=half)则代表arange(0,4),若以i为代表,则此时 i = 0,1,2,3。公式如下:

freqs=e^{(-\frac{i\times log(maxperiod)}{half})} 

则此时计算出的 freqs = [freqs1,freqs2,freqs3,freqs4],是一个一维张量,形状为[4]。

    args = timesteps[:, None].float() * freqs[None]

        首先,timesteps[:, None]作用是将timesteps的形状从一维张量[4]扩展为2维张量[4x1],之前是 timesteps = [ t1, t2, t3, t4 ] 扩展为 timesteps = [ [ t1 ], [ t2 ], [ t3 ], [ t4 ] ]。

        其次将freqs的形状从一维张量[4]扩展为2维张量[1x4],之前是 freqs = [freqs1,freqs2,freqs3,freqs4] 扩展为 freqs = [ [freqs1,freqs2,freqs3,freqs4] ]。

        这样做的目的是为了将时间步 timesteps 代表的batch内每个图片的时间步与频率 freqs 做乘法,将时间步广播进去。得出args的形状是 [4x4] ,即 [N x half] ,每个时间步对应得到half列,有N个时间步。

    embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)

        将得到的args求cos值和sin值,此时得到的 cos(args) 与 sin(args) 的形状依旧都是 [N x half],按照最后一维拼接 cos(args) 与 sin(args) ,则拼接后得到的embedding的形状是 [4 x 8], [N x 2half] ,即 [N x dim]。对应开始定义的 dim 代表的是编码后的最后维度大小。 

    if dim % 2:
        embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
    return embedding

        最后加一个判断,当 dim % 2 ==1 即dim是奇数时,embedding的形状是 [N x dim-1],需要用零向量补充1个维度,变为 [N x dim] 。

1.2  馈入多层感知机 

        当对时间步长执行正弦编码之后,需要将其馈送到多层感知机中获得隐式时间嵌入。 

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        多层感知机由两个具有倒置瓶颈结构的线性层和一个SiLU激活函数组成。其中dim = channels = 8,time_embed_dim = model_channels * 4 = 32。需要注意是 N 为 batch 大小。

  • linear(model_channels, time_embed_dim)中,输入形状为 [4 x 8] ,[N x dim] 即 [N x model_channels] ,输出形状为 [4 x 32] , [N x time_embed_dim]。
  • 激活函数不会改变形状。
  • linear(time_embed_dim, time_embed_dim)中,输入形状为 [4 x 32] ,[N x time_embed_dim],输出形状为 [4 x 32],[N x time_embed_dim]。

        多层感知机主要作用是将输入的时间步正弦编码向量变得更宽。通过激活函数引入非线性。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249254.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux】网络基本配置命令

一、使用网络配置命令 1.常用网络配置文件 1.1. /etc/hosts “/etc/hosts”文件保存着IP地址和主机名或域名的静态映射关系。当用户使用一个主机名或域名时,系统会在该文件中查找与它对应的IP地址。 [rootlocalhost ~]# cat /etc/hosts 127.0.0.1 localhost l…

如何搭建一个小程序:从零开始的详细指南

在当今数字化时代,小程序以其轻便、无需下载安装即可使用的特点,成为了连接用户与服务的重要桥梁。无论是零售、餐饮、教育还是娱乐行业,小程序都展现了巨大的潜力。如果你正考虑搭建一个小程序,本文将为你提供一个从零开始的详细…

学习threejs,使用设置lightMap光照贴图创建阴影效果

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言1.1 ☘️THREE.MeshLambertMaterial…

【前端】JavaScript中的柯里化(Currying)详解及实现

博客主页: [小ᶻ☡꙳ᵃⁱᵍᶜ꙳] 本文专栏: 前端 文章目录 💯前言💯什么是柯里化?💯柯里化的特点💯柯里化的简单示例💯通用的柯里化实现💯柯里化让代码更易读的原因&#x1f4af…

springboot 整合 rabbitMQ (延迟队列)

前言: 延迟队列是一个内部有序的数据结构,其主要功能体现在其延时特性上。这种队列存储的元素都设定了特定的处理时间,意味着它们需要在规定的时间点或者延迟之后才能被取出并进行相应的处理。简而言之,延时队列被设计用于存放那…

Java代码操作Zookeeper(使用 Apache Curator 库)

1. Zookeeper原生客户端库存在的缺点 复杂性高:原生客户端库提供了底层的 API,需要开发者手动处理很多细节,如连接管理、会话管理、异常处理等。这增加了开发的复杂性,容易出错。连接管理繁琐:使用原生客户端库时&…

Django实现智能问答助手-基础配置

设置 Django 项目、创建应用、定义模型和视图、实现问答逻辑,并设计用户界面。下面是一步一步的简要说明: 目录: QnAAssistant/ # 项目目录 │ ├── QnAAssistant/ # 项目文件夹 │ ├── init.py # 空文件 │ ├── settings.py # 项目配…

【ESP32CAM+Android+C#上位机】ESP32-CAM在STA或AP模式下基于UDP与手机APP或C#上位机进行视频流/图像传输

前言: 本项目实现ESP32-CAM在STA或AP模式下基于UDP与手机APP或C#上位机进行视频流/图像传输。本项目包含有ESP32源码(arduino)、Android手机APP源码以及C#上位机源码,本文对其工程项目的配置使用进行讲解。实战开发,亲测无误。 AP模式,就是ESP32发出一个WIFI/热点提供给电…

从〇开始深度学习(0)——背景知识与环境配置

从〇开始深度学习(0)——背景知识与环境配置 文章目录 从〇开始深度学习(0)——背景知识与环境配置写在前面1.背景知识1.1.Pytorch1.2.Anaconda1.3.Pycharm1.4.CPU与GPU1.5.整体关系 2.环境配置2.1.准备工作2.1.1.判断有无英伟达显卡2.1.2.清理电脑里的旧环境 2.1.安装Anaconda…

mac下Gpt Chrome升级成GptBrowser书签和保存的密码恢复

cd /Users/自己的用户名/Library/Application\ Support/ 目录下有 GPT\ Chrome/ Google/ GptBrowser/ GPT\ Chrome 为原来的chrome浏览器的文件存储目录. GptBrowser 为升级后chrome浏览器存储目录 书签所在的文件 Bookmarks 登录账号Login 相关的文件 拷贝到GptBrow…

圆域函数的傅里叶变换和傅里叶逆变换

空域圆域函数的傅里叶变换 空域圆域函数(也称为空间中的圆形区域函数)通常指的是在二维空间中,以原点为中心、半径为 a a a的圆内取值为1,圆外取值为0的函数。这种函数可以表示为: f ( x , y ) { 1 if x 2 y 2 ≤ …

Java基础——类型转化(强制转化)

目录 1.数字间的类型转换 (1) 隐式类型转换 (2)显式类型转换(强制类型转换) 2.类对象间的强制转换 (1) 向上转型 (2) 向下转型 将一个类型强制转换成另…

数据结构C语言描述5(图文结合)--广义表讲解与实现

前言 这个专栏将会用纯C实现常用的数据结构和简单的算法;有C基础即可跟着学习,代码均可运行;准备考研的也可跟着写,个人感觉,如果时间充裕,手写一遍比看书、刷题管用很多,这也是本人采用纯C语言…

23种设计模式-装饰器(Decorator)设计模式

文章目录 一.什么是装饰器设计模式?二.装饰器模式的特点三.装饰器模式的结构四.装饰器模式的优缺点五.装饰器模式的 C 实现六.装饰器模式的 Java 实现七.代码解析八.总结 类图: 装饰器设计模式类图 一.什么是装饰器设计模式? 装饰器模式&…

构建英语知识网站:Spring Boot框架解析

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常…

数据结构之数组与链表的差异

一、数组 数组(Array)是由相同类型的元素(element)的集合所组成的数据结构,分配一块连续的内存来存储。利用元素的索引(index)可以计算出该元素对应的存储地址。最简单的数据结构类型是一维数组…

RabbitMQ7:消息转换器

欢迎来到“雪碧聊技术”CSDN博客! 在这里,您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者,还是具有一定经验的开发者,相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导,我将…

Ubuntu20.04+ROS 进行机械臂抓取仿真:环境搭建(一)

目录 一、从官网上下载UR机械臂 二、给UR机械臂添加夹爪 三、报错解决 本文详细介绍如何在Ubuntu20.04ROS环境中为Universal Robots的UR机械臂添加夹爪。首先从官方和第三方源下载必要的软件包,包括UR机械臂驱动、夹爪插件和相关依赖。然后,针对gazeb…

(即插即用模块-Attention部分) 二十、(2021) GAA 门控轴向注意力

文章目录 1、Gated Axial-Attention2、代码实现 paper:Medical Transformer: Gated Axial-Attention for Medical Image Segmentation Code:https://github.com/jeya-maria-jose/Medical-Transformer 1、Gated Axial-Attention 论文首先分析了 ViTs 在训…

Git 进程占用报错-解决方案

背景 大仓库,由于开发者分支较多,我们在使用 git pull 或 git push 等命令时(与远端仓库交互的命令),不知之前配置了什么,我的电脑会必现以下报错(有非常长一大串报错-不同分支的git进程占用报…