DALLE2: Hierarchical Text-Conditional Image Generation with CLIP Latents
paper: https://cdn.openai.com/papers/dall-e-2.pdf
github: https://github.com/lucidrains/DALLE2-pytorch
DALLE2概览:
- CLIP模型:
用于生成text embedding zt 和image embedding zi
- prior模型:
1) 模型输入:为 the encoded text,the CLIP text embedding,time_embed,image_embed,learned_queries,(文本整体embedding,文本序列embedding,时间步embedding,当前t步对应的图片embedding,用于输出transformer 结果手动构造用于学习的embedding )
2) 模型: diffusion model使用transformer(不是unet)直接预测x0,然后通过diffusion递推公式生成前一步图片embedding.
3)最终输出:为 image Embedding (不同于上面CLIP生成的image embedding )
- decoder 模型
1)模型输入:为 prior 输出的image Embedding
2)模型:diffusion model使用unet网络,预测噪声z (不同于prior模型直接预测x0)
3)模型输出:经过T步去噪后,最后一步x0即为模型输出
0 Abstract
基于对比学习思想,我们提出了两阶段模型,
①一个先验模型prior:
- 在给定文本条件下生成CLIP的 image embedding
② 一个decoder模型:
- 在给定imge embedding 条件下,生成图片
We use diffusion models for the decoder and experiment with both autoregressive and diffusion models for the prior, finding that the latter are computationally more efficient and produce higher quality samples.
我们使用diffusion 模型作为decoder 模型,实验了自回归autoregressive 和diffusion模型作为prior模型,发现diffusion 模型作为先验模型效过更好
1 Introduction
- 虚线上面的是CLIP模型,通过CLIP模型可以学习到text 和image的embedding,
- 虚线以下是文本到图片的生成过程,
① CLIP的 text embedding 喂给autoregressive或者diffusion模型( prior模型),生成image embedding
② 然后根据上面的image embedding喂给decoder 模型,生成最终的图片image
2 Method
- Our training dataset consists of pairs (x, y) of images x and their corresponding captions y. Given an image x,let zi and zt be its CLIP image and text embeddings, respectively. We design our generative stack to produce images from captions using two components:
- 我们训练数据集由成对的(x,y)组成,x是图片,y是文本,给定x和y,通过CLIP模型,可以分别生成image 和text embedding,zi和 zt。
- A prior P(zi|y) that produces CLIP image embeddings zi conditioned on captions y.
一个prior 模型用在给定文本时,生成image embedding zi. - A decoder P(x|zi, y) that produces images x conditioned on CLIP image embeddings zi (and optionally text captions y).
decoder 模型用于在给定条件zi时,生成最终图片 x。
整个过程如下所示
2.1 Decoder
-
We use diffusion models to produce images conditioned on CLIP image embeddings (and optionally text captions).
-
在prior模型生成的image embedding的基础上, 我们使用 diffusion models生成image。
-
将image embedding作为条件直接加上timestep embedding(也可以选择添加加text embedding,实验发现用处不大),然后通过下面的diffusion 去噪公式 ,选择unet网络预测噪声,生成最终的图片x
μ ˉ t = 1 α t ( x t − 1 − α t 1 − α ˉ t z t ) \bar \mu_t=\frac{1 } {\sqrt \alpha_{t}} (x_t -\frac{1-\alpha_t } {\sqrt{1- \bar \alpha_{t}}} z_t) μˉt=αt1(xt−1−αˉt1−αtzt)
2.2 Prior
• While a decoder can invert CLIP image embeddings zi to produce images x, we need a prior model that produces zi from captions y to enable image generations from text captions.
decoder 模型输入 image embedding zi 生成image x,需要prior模型生成的zi.
• Diffusion prior: The continuous vector zi is directly modelled using a Gaussian diffusion model conditioned on the caption y.
Diffusion prior : 给定文本y(clip 模型生成的文本向量)时,通过Gaussian diffusion model 直接生成 zi。为了改善样本质量,训练时我们随机mask掉10%的文本数据。
- 对于 diffusion prior,我们训练一个 decoder-only的Transformer模型,对输入序列使用causal attention mask。用于预测x0 (重点:不是噪声zt)
- Transformer模型的输入: the encoded text,the CLIP text embedding,time_embed,image_embed,learned_queries,(文本整体embedding,文本序列embedding,时间步embedding,当前t步对应的图片embedding,用于输出transformer 结果手动构造用于学习的embedding )
- diffusion 过程: 随机初始化xt,dffusion通过下面公式反向传播公式生成x(t-1)数据(transformer 模型直接生成x0),直到最后一步x0
μ ˉ t ( x t , x 0 ) = α t ( 1 − α ˉ t − 1 ) 1 − α ˉ t x t + α ˉ t − 1 ( 1 − α t ) 1 − α ˉ t x 0 \bar \mu_t(x_t,x_0)=\frac{\sqrt{\alpha_{t}}(1-\bar \alpha_{t-1} ) } {1- \bar \alpha_{t}} x_t +\frac{\sqrt{\bar \alpha_{t-1}}(1-\alpha_t) } {1- \bar \alpha_{t}} x_0 μˉt(xt,x0)=1−αˉtαt(1−αˉt−1)xt+1−αˉtαˉt−1(1−αt)x0