【论文精读】DALL·E

news2025/3/15 20:49:30

摘要

       本文利用从互联网上收集的2.5亿个图像/文本对数据,训练了一个120亿参数的自回归transformer,进而得到一个可以通过自然语言/图像控制生成的高保真图像生成模型。在大多数数据集上的表现超越以往的方法。

框架

       本文的目标为通过训练一个自回归transformer,通过将文本和图像tokens自回归建模为单个数据流,进而结合图像解码器进行图像生成,整体分为两个阶段:
image

  • 第一阶段:训练一个离散变分自编码器(dVAE),其编码器会将输入图像从 256 × 256 256 × 256 256×256压缩为 32 × 32 32 × 32 32×32的图像tokens,其中每个token都会映射到 K = 8192 K = 8192 K=8192的codebook向量中。相比于直接使用像素作为图像token,这可以使后续步骤的自回归transfromer的上下文大小减少192倍,同时不会大幅降低视觉质量(如上图)。
  • 第二阶段:将256个由BPE编码的文本tokens 与 32 × 32 = 1024 32 × 32 = 1024 32×32=1024个图像tokens拼接起来,基于此训练一个自回归transformer,实现对文本和图像tokens的联合分布的建模。

       整个过程可以被看作最大化模型在图像 x x x、文本 y y y和tokens z z z上的联合似然的ELBO,通过因子分解可以将该分布建模为 p θ , ψ ( x , y , z ) = p θ ( x ∣ y , z ) p ψ ( y , z ) p_{θ,ψ}(x, y, z) = p_θ(x | y, z)p_ψ(y, z) pθ,ψ(x,y,z)=pθ(xy,z)pψ(y,z),对应的ELBO为:
ln ⁡ p θ , ψ ( x , y ) ≥ E z ∼ q ϕ ( z ∣ x ) ( ln ⁡ p θ ( x ∣ y , z ) − β D K L ( q ϕ ( y , z ∣ x ) , p ψ ( y , z ) ) ) \ln p_{θ,ψ}(x, y) \ge \mathbb{E}_{z \sim q_ϕ(z|x)}(\ln p_θ(x|y,z)-\beta D_{KL}(q_ϕ(y,z|x),p_ψ(y,z))) lnpθ,ψ(x,y)Ezqϕ(zx)(lnpθ(xy,z)βDKL(qϕ(y,zx),pψ(y,z)))

       其中, q ϕ q_ϕ qϕ表示给定图像 x x x经过dVAE编码器生成的 32 × 32 32 × 32 32×32的tokens的分布; p θ p_θ pθ表示由tokens经过dVAE解码器生成的图像的分布; p ψ p_ψ pψ表示由自回归transformer建模的文本和图像tokens的联合分布。该ELBO只适用与 β = 1 \beta = 1 β=1的情况。

Learning the Visual Codebook

       第一阶段的训练目标为最大化 ϕ ϕ ϕ θ θ θ的ELBO,即通过给定图像训练dVAE。先验 p ψ p_ψ pψ初始化为基于codebook( K = 8192 K = 8192 K=8192)向量的均匀分类分布(uniform categorical distribution); q ϕ q_ϕ qϕ初始化为在编码器输出的 32 × 32 × 8192 32 × 32×8192 32×32×8192的logits参数化的分类分布(uniform categorical)。

       由于 p ψ p_ψ pψ是一个离散分布,无法使用梯度进行优化,故此处采用gumbel-softmax松弛,用 q ϕ τ q ^τ_ ϕ qϕτ取代 q ϕ q_ϕ qϕ,当 τ → 0 τ → 0 τ0时,松弛程度会逐渐缩小,逼近原始分布。 p θ p_θ pθ的似然使用log-laplace分布评估,以避免离群值导致的生成模糊问题。

       松弛后的ELBO使用Adam和EMA优化,以下配置对训练稳定性很重要:

  • 松弛temperature和步长的具体退火方法。实验发现 τ τ τ退火到1/16时,松弛ELBO的 q ϕ τ q ^τ_ ϕ qϕτ和真实ELBO的 q ϕ q_ϕ qϕ之间的gap就会消失。
  • 在编码器的末尾和解码器的开头使用1 × 1卷积。 实验发现,通过减少松弛方法周围的卷积层的感受野大小,可以使其泛化到真实ELBO的情况。
  • 将编码器和解码器的输出激活值乘一个小的常数,可以使初始化时的训练更加训练。

       另外,KL权重增加到 β = 6.6 β = 6.6 β=6.6时,可以得到更好codebook,故而使训练结束时的重构误差更小。

Learning the Prior

       第二阶段在固定 ϕ ϕ ϕ θ θ θ的情况下,最大化关于 ψ ψ ψ的ELBO,学习文本和图像token的联合先验分布。其中, p ψ p_ψ pψ是一个120亿参数的稀疏transformer。

       具体,给定一个文本/图像对,首先通过对小写文本进行BPE编码(词汇表大小为16384)得到最多256个文本token ,并对dVAE编码器输出的logits进行argmax采样codebook得到1024个图像token,此处没有添加gumbel噪声。最后,拼接这些文本和图像token作为单个数据流进行自回归建模。

       本文限制文本标题的最大长度为256,每个文本位置都会学习一个特殊的“padding” token,当对应位置没有文本token时使用此token。得到文本和图像token的交叉熵损失后,将文本交叉熵损失乘以1/8,图像交叉熵损失乘以7/8,以对loss归一化,本阶段也使用EMA和Adam进行优化。

Data Collection

       本文从互联网上收集了2.5亿个文本/图像对,创建了一个与JFT-300M相似规模的数据集。该数据集包括一部分Conceptual Captions和YFCC100M的经过滤的子集。

Mixed-Precision Training

       为了节省GPU内存并增加吞吐量,模型的大多数参数、Adam矩阵和模型激活值都以FP16存储,并使用了activation checkpointing技术。

       在训练过程中发现,随着模型变得更深更广,resblocks的激活梯度会单调减少,较深层的resblocks的激活梯度可能小于FP16的最小值,其会被四舍五入为0,这种现象称为下溢(underflow)。实验发现消除下溢可以使训练更加稳定。
image
       故对于模型中每个的resblock,通过执行“gradient scale”可以解决下溢问题,如上图。

Distributed Optimization

image
       DELL-E有120亿参数,以FP16精度存储时会消耗约24GB内存,这超过单张NVIDIA V100 GPU的16GB内存,故使用参数分片。如上图。
image
       本文的实现中,每台机器上的每个GPU都独立地计算其参数分片梯度的低秩因子,而不依赖于其相邻的GPU。一旦计算出低秩因子,每台机器都会将其error buffer设置为其八个GPU上未压缩的参数梯度的平均值(通过reduce-scatter获得)与通过解压缩低秩因子得到的梯度之间的残差(两者偏差)。

       对于一个模型训练集群,其机器之间的带宽远低于同一机器上不同GPU之间的带宽,故机器之间梯度平均操作(all-reduce)成为训练期间的主要速度瓶颈,通过引入PowerSGD压缩梯度,可以大大降低这种成本。 PowerSGD会将未压缩的参数梯度的通信操替换为基于其低秩因子的两个更小的通信操作。给定压缩rank r r r和transformer激活尺寸 d m o d e l d_{model} dmodel,其压缩率为 1 − 5 r / ( 8 d m o d e l ) 1 − 5r/(8d_{model}) 15r/(8dmodel)。如上表显示,无论模型大或小,该方法可以实现约85%的压缩率。

Sample Generation

image
       对于从transformer中生成的一系列图像,本文采用预训练CLIP对生成图像与文本标题的匹配程度来分配分数并排序。如上图显示了给定生成的N张图像,并从中选择的top-k图像。除非另有说明,用于定性和定量结果的所有样本都是在不降低temperature的情况下获得的,并使用N = 512重新排序。

实验

Quantitative Results

image
       上图定性比较了DALL-E和AttnGAN、DM-GAN和DF-GAN的生成。
image
       上图为人类验证实验。给定一个文本标题,相比DF-GAN,DALL-E的生成在93%的情况下与文本标题更好地匹配,而获得更多的人类投票。在90%的情况下,也因为更真实而获得了大多数人类投票。
image
       上图(a上)为在MS-COCO数据集上验证的定量结果,DALL-E与之前最佳方法只差2个点的FID分数。由于DALL-E的训练数据中包含一个YFCC100M的过滤子集,其中包含MS-COCO验证集中大约21%的图像,故为了隔离这种影响,另外分别计算了验证集有这些图像(实线)和没有这些图像(虚线)的FID信息,结果没有明显变化。

       用dVAE编码器的token训练transformer,可以使模型学习更多的图像低频信息,使图像在视觉上更真实。,但这也不利于模型学习产生高频细节。为了验证模型的高频建模能力,本实验对验证图像和模型生成的样本应用了不同半径的高斯滤波器,并计算对应IS值。结果如上图(a下),随着模糊半径的增加,DALL-E和其他方法之间的差距越拉越大,当模糊半径大于等于2时,DALL-E取得了最佳结果。

       DALL-E在CUB数据集上的表现比较差,如上图(b),和之前的主要方法有近40点FID的差距。经过检测发现,训练数据集中包含12%的CUB数据,但去除这些数据后模型表现仍旧不佳。故推测zero-shot DALL-E不太可能在CUB等专业分布的数据集上获得优势。

       上图(c)显示了当用于CLIP重排序的样本增加时,DALL-E的FID有了明显改进。
image
       上图显示了DALL-E在CUB数据集中不同文本标题下的生成示例。

Qualitative Findings

image
       通过验证发现DALL-E有不以最初预期的方式进行泛化的能力。当给出文本标题“a tapir made of accordion… ”,该模型似乎画了一个以手风琴为身 体的貘(上图a)。这表明,其发展出一种基本的能力,可以在较高的抽象层次上组合概念。

       DALL-E似乎也能进行组合泛化,例如在渲染如“an illustration of a baby hedgehog in a christmas sweater walking a dog”这样的句子(上图b、c)。

       在有限的可靠性程度上,还发现该模型能够由自然语言控制图像到图像的翻译。当模型被赋予标题“the exact same cat on the top as a sketch at the bottom”时,其能够在底部画一个类似的猫的草图(上图d)。 这也适用于其他几种类型的转换,包括图像操作(例如改变图像的颜色、将其转换为灰度或翻转图像)和样式转换(例如在贺卡、邮票或手机壳上画猫)。一些只涉及改变动物颜色的转换,表明DALL-E能够执行基本的对象分割。

Appendix

Details for Discrete VAE

Architecture

       dVAE编码器和解码器都为具有bottleneck-style resblocks的ResNets。编码器的第一层卷积核尺寸为 7 × 7 7 × 7 7×7,编码器的最后一层卷积核尺寸为 1 × 1 1×1 1×1(输出尺寸为 32 × 32 × 8192 32 × 32 × 8192 32×32×8192,用作图像token的分类分布的logits)。解码器的第一层卷积和最后一层卷积核尺寸都为 1 × 1 1×1 1×1。编码器使用最大池化下采样,解码器使用最近邻上采样。

Training

image
       dVAE在与transformer相同的数据集上进行训练,使用上图中给出的数据增强代码。以下量在训练过程中使用余弦退火进行衰减:

  • KL权重 β β β在前5000次迭代中从0增加到6.6
  • 松弛 τ τ τ在前150000次迭代中从1退火到1/16
  • 在1200000迭代中,step size从 1 ⋅ 1 0 − 4 1\cdot10^{−4} 1104退火到 1.25 ⋅ 1 0 − 6 1.25 \cdot 10^{−6} 1.25106

       使用 β 1 = 0.9 , β 2 = 0.999 , ϵ = 1 0 − 8 β_1 = 0.9, β_2 = 0.999, ϵ = 10^{−8} β1=0.9,β2=0.999,ϵ=108的AdamW和 1 0 − 4 10^{−4} 104的weight decay和 0.999 0.999 0.999的EMA优化模型。该模型在64个16 GB NVIDIA V100 gpu上使用混合精度训练,每个gpu的batch size为8,总batch size为512,总共3000000次。

Details for Transformer

Architecture

image
       本文第二阶段模型是一个仅解码器的稀疏transformer,其输入tokens embedding格式如上图。其包括64个注意力层,每个层使用62个注意力头,每个头的维度大小为64。
image
       该模型使用三种稀疏注意力mask,如上图。给定自注意力层的索引 i i i i ∈ [ 1 , 63 ] i ∈ [1, 63] i[1,63]),如果 i − 2   m o d   4 = 0 i − 2\ mod \ 4 = 0 i2 mod 4=0,则使用列注意力mask(c),否则使用行注意力mask,例如,前四个自注意力层分别使用row、column、row、row。卷积注意力mask(d)仅用于最后的自注意力层。

Training

image
       对于训练transformer的训练,在使用dVAE编码器编码图像之前,首先对图像进行如上图代码所示的数据增强。在用BPE编码文本标题时,还应用了10%的BPE dropout。该模型使用逐resblock缩放和梯度压缩进行训练,总压缩rank为896(每个GPU的参数分片使用112的压缩rank)。

       使用 β 1 = 0.9 , β 2 = 0.96 , ϵ = 1 0 − 8 β_1 = 0.9, β_2 = 0.96, ϵ = 10^{−8} β1=0.9,β2=0.96,ϵ=108的AdamW与和 4.5 ⋅ 1 0 − 2 4.5 \cdot 10^{−2} 4.5102的weight decay和0.99的EMA优化参数。在应用Adam更新前,会使用阈值为4的norm对解压后的梯度进行裁剪,梯度裁剪仅在训练开始的预热阶段运行。为了节省内存,大部分Adam矩阵以FP16格式存储,其中运行平均值为1-6-9格式(即1位用于符号,6位用于指数,9位用于尾数),运行方差为0-6-10格式,在更新参数或动量之前,会将运行方差裁剪为5。其次,还会异步地将模型参数从GPU复制到CPU(每25次更新复制一次),以获得更稳定的更新。

       该模型在1024个16 GB NVIDIA V100 gpu和总batch size为1024的设置下训练模型,总共进行了430000更新。step size在前5000次迭代中,线性退火到 4.5 ⋅ 1 0 − 4 4.5 · 10^{−4} 4.5104,并在每次训练损失趋于稳定时将step size减半,训练周期内,总共减半了5次,比初始步长小32倍的最终步长结束训练。

Details for Human Evaluation Experiments

image
       对于人类验证实验,本文使每个模型对每个文本标题生成一个示例图像,并给定文本和示例图像让人类给出比较结果,实验提交给了亚马逊的Mechanical Turk,每组生成都由五名不同的人类回答。工作人员被要求比较两张图像并选择答案:(1)哪张图像最真实,(2)哪张图像最匹配文本标题。提供给人类的实验设置如上图。

Zero-Shot Image-to-Image Translation

image
       上图显示了DALL-E的zero-shot图像到图像转换的示例。

reference

Ramesh, A. , Pavlov, M. , Goh, G. , Gray, S. , Voss, C. , & Radford, A. , et al. (2021). Zero-shot text-to-image generation.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1453209.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql调优实战

EXPLAIN执行分析 id:值越大越先执行相同时,由上向下执行。 possible_key: 可能走索引的键。 key:真正走索引的键rows:根据表统计信息及索引选用情况,大致估算出找到所需的记录所需要读取的行数,也就是说,用的越少越好 …

004 - Hugo, 分类

004 - Hugo, 分类content文件夹 004 - Hugo, 分类 content文件夹 ├─.obsidian ├─categories │ ├─Python │ └─Test ├─page │ ├─about │ ├─archives │ ├─links │ └─search └─post├─chinese-test├─emoji-support├─Git教程├─Hugo分类├─…

如何在CSS中实现背景图片的渐变?

--引言 在CSS中,实现背景图片的渐变通常需要使用linear-gradient或者radial-gradient函数,这些函数可以与背景图像一起使用来创建渐变效果。然而,CSS的渐变并不直接支持使用图像作为渐变的颜色停止点。但你可以通过一些技巧来实现类似的效果…

2024年【高处安装、维护、拆除】模拟考试题库及高处安装、维护、拆除实操考试视频

题库来源:安全生产模拟考试一点通公众号小程序 高处安装、维护、拆除模拟考试题库是安全生产模拟考试一点通生成的,高处安装、维护、拆除证模拟考试题库是根据高处安装、维护、拆除最新版教材汇编出高处安装、维护、拆除仿真模拟考试。2024年【高处安装…

得物面试:Redis用哈希槽,而不是一致性哈希,为什么?

尼恩说在前面 在40岁老架构师 尼恩的读者交流群(50)中,最近有小伙伴拿到了一线互联网企业如得物、阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试资格,遇到很多很重要的面试题: Redis为何用哈希槽而不用一致性哈希? 最近…

Prompt Tuning:深度解读一种新的微调范式

阅读该博客,您将系统地掌握如下知识点: 什么是预训练语言模型? 什么是prompt?为什么要引入prompt?相比传统fine-tuning有什么优势? 自20年底开始,prompt的发展历程,哪些经典的代表…

Sora时代,我们的AI应该何去何从?——关于Sora大模型的思考

Sora时代,我们的AI应该何去何从?——关于Sora大模型的思考 一、Sora大模型:横空出世,让AI生成所有领域瑟瑟发抖二、Sora的出现代表了相关行业的灭亡?三、我们将何去何从? 一、Sora大模型:横空出世&#xf…

计算机毕业设计SSM基于的高校学习资源共享系统

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: vue mybatis Maven mysql5.7或8.0等等组成,B…

C++ 多起点的bfs(五十九)【第六篇】

今天我们来学习多起点的bfs 1.多起点的bfs 在普通的广度优先搜索问题中,为了得到从初始状态到达目标状态的最小操作数,则将初始状态放入队列中。离初始状态由近及远地不断扩展出新的状态,直到搜索到目的状态,或队列为空&#xff…

使用Docker部署JDK镜像

构建镜像 我们将已经准备好的docker-demo.jar包以及Dockerfile拷贝到虚拟机的/root/demo目录: 然后,执行命令,构建镜像: # 直接指定Dockerfile目录 docker build -t docker-demo:1.0 /root/demo 查看镜像列表: # 查看…

神经网络算法原理

目录 得分函数 数学表示 计算方法 损失函数 ​编辑 前向传播 反向传播 ​编辑 整体架构 正则化的作用 数据预处理 ​过拟合解决方法 得分函数 得分函数是在机器学习和自然语言处理中常用的一种函数,用于评估模型对输入数据的预测结果的准确性或匹配程度。…

函数、极限、连续——刷题(5

目录 1.题目:2.解题思路和步骤:3.总结:小结: 1.题目: 2.解题思路和步骤: 首先可能想到的是答案为0,但是不可以把 直接化简为n 这里要用到分子分母的平方差,sin^2的周期为π&#x…

WebServer 之 http连接处理(下)

目录 ✊请求报文--解析 流程图 && 状态机 状态机 -- 状态转移图 主状态机 从状态机 http 报文解析 HTTP_CODE 含义 从状态机 逻辑 主状态机 逻辑 🐞请求报文--响应 基础API stat mmap iovec writev 流程图 HTTP_CODE 含义(2) 代码分析 …

及其详细的Markdown基础-学习笔记(附有使用案例)

Markdown 基础语法 查看更多学习笔记:GitHub:LoveEmiliaForever 标题创建 标题语法格式 在文字前添加一至六个#即可创建标题 标题是有等级的,具体等级根据#个数决定 由于标题等级参与构建整篇文章的架构,编写时应该遵循如下规…

【C->Cpp】由C迈向Cpp(3)

正文开始: 目录 (一)函数重载 (1)函数重载 (2)函数重载实现原理 (二) 引用 (1)引用 (2)语法 i ,别名&am…

输入捕获模式测频率PWM输入模式(PWMI)测占空比

一、概念介绍 输出比较: 比较电路输入的CNT、CCR大小关系 ,在通道引脚输出高低电平 二、*频率知识、测量方法补充 * N/fc得到标准频率的时长,也就是待测频率的周期 测频法代码实现:修改对射式红外传感器计次(上升沿…

51_蓝桥杯_蜂鸣器与继电器

一 电路 二 蜂鸣器与继电器工作原理 2.1蜂鸣器与继电器 2.2 十六进制与二进制 二进制 0000 0001 0010 0011 0100 0101 0110 0111 1000 1001 1010 1011 1100 1101 1110 1111 十六进制 0 1 2 3 4 5 6 7 8 9 A B C D E F 2.3非门 二 代码 …

数据集合

目录 并集 union union all 区别 交集 intersect 差集 minus 错误操作 Oracle从入门到总裁:https://blog.csdn.net/weixin_67859959/article/details/135209645 常用的数学集合有:交集、并集、差集、补集 每一次查询实际上都会返回数据集合,…

【Anaconda】conda创建、删除、查看虚拟环境,安装pytorch

1.删除环境 首先退出现有的环境 conda deactivate然后查看要删除的环境名称与路径 conda env list接下来就可以删除环境了 有两种方法 方法1: conda env remove -p 要删除的虚拟环境路径对我来说就是: conda env remove -p D:\Anaconda3\envs\MVDet…

Screw自动生成数据库文档

Screw简介 官方地址 Screw可以根据数据库中的表自动生成HTML、Word、Markdown格式的文档。 Springboot 3.1集成 生成Springboot项目 Spring Initializr Maven依赖 <dependency><groupId>cn.smallbun.screw</groupId><artifactId>screw-core</…