AIGC学习笔记——DALL-E2详解+测试

news2024/11/19 13:22:06

它主要包括三个部分:CLIP,先验模块prior和img decoder。其中CLIP又包含text encoder和img encoder。(在看DALL·E2之前强烈建议先搞懂CLIP模型的训练和运作机制,之前发过CLIP博客)

论文地址:https://cdn.openai.com/papers/dall-e-2.pdf

代码地址:https://github.com/lucidrains/DALLE2-pytorch

1、简介

DALLE2提出了一个两阶段模型,利用类似CLIP的对比模型学习到的图像表示。第一阶段是一个先验模型,根据文本描述生成CLIP图像嵌入;第二阶段是一个解码器,根据图像嵌入生成相应的图像。我们发现,通过明确生成图像表示,可以提高图像多样性,同时最小程度地损失真实感和描述相似性。我们的解码器在图像表示的条件下,能够产生保留其语义和风格的图像变体,同时变化了图像表示中缺少的非关键细节。此外,CLIP的联合嵌入空间使得可以通过语言引导图像操作,实现零-shot学习。我们采用扩散模型进行解码,并尝试了自回归和扩散模型作为先验模型,结果显示后者在计算上更高效且生成的样本质量更高。

2、模型介绍

DALLE2模型的工作原理很简单,它接受文本描述并通过CLIP将其编码成向量表示,然后通过先验模块生成与文本相关的图像表示,最后,图像解码器将该表示解码成一张具体的图像,实现了根据文本生成对应图像的任务。(在看DALLE2之前强烈建议先搞懂前面引言中的CLIP模型的训练和运作机制)下图中,虚线上方:训练CLIP过程;虚线下方:由文本生成图像过程

3、训练过程

  • 训练CLIP,使其能够编码文本和对应图像

这一步是与CLIP模型的训练方式完全一样的,目的是能够得到训练好的text encoder和img encoder。这么一来,文本和图像都可以被编码到相应的特征空间。对应上图中的虚线以上部分。

  • 训练prior,使文本编码可以转换为图像编码

将CLIP中训练好的text encoder拿出来,输入文本y,得到文本编码Zt。同样的,将CLIP中训练好的img encoder拿出来,输入图像x得到图像编码Zx。我们希望prior能从Zt获取相对应的Zt。假设Zt经过prior输出的特征为Zi',那么我们自然希望Zi与Zi'越接近越好,这样来更新我们的prior模块。最终训练好的prior,将与CLIP的text encoder串联起来,它们可以根据我们的输入文本y生成对应的图像编码特征Zi了。关于具体如何训练prior,有兴趣的小伙伴可以精读一下论文,作者使用了主成分分析法PCA来提升训练的稳定性。(下图借鉴了一篇知乎的博客的图片)

作者团队尝试了两种先验模型:自回归式Autoregressive (AR) prior 和扩散模型Diffusion prior 。实验效果上发现两种模型的性能相似,而因为扩散模型效率较高,因此最终选择了扩散模型作为prior模块。

小辉问:详细说说自回归式Autoregressive (AR) prior 和扩散模型Diffusion prior

小G答:自回归式(Autoregressive,AR)先验模型和扩散模型(Diffusion Model)是两种不同的先验模型,用于生成图像。

自回归式(Autoregressive)先验模型:

  • 工作原理:自回归式先验模型将图像的生成过程建模为一个序列生成任务。在每个时间步,模型会根据之前生成的内容和输入的条件,预测下一个像素点的取值。生成过程从图像的某个起始位置开始,逐渐生成整张图像。
  • 特点:自回归式先验模型通常使用递归神经网络(RNN)或变换器(Transformer)等结构来建模生成过程。由于生成过程是逐步进行的,因此生成的图像往往具有较高的清晰度和连续性。然而,这种逐步生成的方法可能会导致较慢的生成速度,并且难以处理全局一致性。
  • 优点与局限:自回归式先验模型适用于需要考虑图像局部信息和序列相关性的任务,但由于生成过程是顺序执行的,因此可能受限于生成速度和全局一致性。

扩散模型(Diffusion Model)先验模型:

  • 工作原理:扩散模型先验模型的核心思想是逐步“扩散”图像中的随机噪声,从而生成最终的图像。生成过程从一个随机初始化的图像开始,然后通过一系列步骤逐渐减小噪声,生成越来越清晰的图像。
  • 特点:扩散模型先验模型通常利用马尔可夫链蒙特卡洛(MCMC)方法来建模生成过程,每个步骤都会根据当前图像状态和噪声水平生成下一个图像状态。这种逐步“扩散”的方法能够生成具有较高质量和全局一致性的图像,同时也具有较快的生成速度。
  • 优点与局限:扩散模型先验模型在处理全局一致性和生成速度方面表现出色,但可能在捕捉局部细节和序列相关性方面稍显不足。

综上所述,自回归式(Autoregressive)先验模型和扩散模型(Diffusion Model)都是常用的先验模型,各自具有不同的特点和适用场景。选择合适的先验模型取决于具体的任务需求和性能要求。

  • 训练decoder生成最终的图像

也就是说我们要训练decoder模块,从图像特征Zi还原出真实的图像 x,如下图左边所示。这个过程与自编码器类似,从中间特征层还原出输入图像,但又不完全一样。我们需要生成出的图像,只需要保持原始图像的显著特征就可以了,这样以便于多样化生成,例如下图。图像经过img encoder再经decoder得到重建图像。顶部图像为输入。

DALLE2使用的是改进的GLIDE模型。这个模型可以根据CLIP图像编码的Zi,还原出具有相同与x有相同语义,而又不是与x完全一致的图像。

4、推理过程(由文本生成图像过程)

经过以上三个步骤的训练,已经可以完成DALLE2预训练模型的搭建了。我们这时候丢掉CLIP中的img encoder,留下CLIP中的text encoder,以及新训练好的prior和decoder。这么一来流程自然很清晰了。由text encoder将文本进行编码,再由prior将文本编码转换为图像编码,最后由decoder进行解码生成图像。如下图(借鉴的知乎的博客)

5、实验demo理解

5.1、训练CLIP模型

首先初始化一个CLIP模型,然后打印了其结构。接着,在一个循环中,生成了一些虚拟的文本和图像数据,并用于训练CLIP模型。在每一轮训练中,打印了生成的文本和图像数据,以及训练过程中的对比损失,并执行了梯度计算。

import torch
from dalle2_pytorch.x_clip import CLIP


# 初始化CLIP模型
clip = CLIP(
    dim_text = 512,                        # 文本编码维度
    dim_image = 512,                       # 图像编码维度
    dim_latent = 512,                      # 潜在特征维度
    num_text_tokens = 49408,               # 文本token数量
    text_enc_depth = 1,                    # 文本编码器深度
    text_seq_len = 256,                    # 文本序列长度
    text_heads = 8,                        # 文本编码器头数
    visual_enc_depth = 1,                  # 图像编码器深度
    visual_image_size = 256,               # 图像输入尺寸
    visual_patch_size = 32,                # 图像切片尺寸
    visual_heads = 8,                      # 图像编码器头数
    use_all_token_embeds = True,           # 是否使用细粒度对比学习(FILIP)
    decoupled_contrastive_learning = True, # 使用解耦的对比学习(DCL)目标函数,从InfoNCE损失的分母中删除正对比对(CLOOB + DCL)
    extra_latent_projection = True,        # 是否为文本到图像和图像到文本的比较使用单独的投影(CLOOB)
    use_visual_ssl = True,                 # 是否对图像进行自监督学习
    visual_ssl_type = 'simclr',            # 可以是'simclr'或'simsiam',取决于使用DeCLIP还是SLIP
    use_mlm = False,                       # 是否在文本上使用遮蔽语言学习(MLM)(DeCLIP)
    text_ssl_loss_weight = 0.05,           # 文本MLM损失权重
    image_ssl_loss_weight = 0.05           # 图像自监督学习损失权重
).cuda()

# 打印模型结构
print(clip)

# 模拟数据
for i in range(1):
    text = torch.randint(0, 49408, (4, 256)).cuda()  # 随机生成文本数据,shape为(4, 256)
    images = torch.randn(4, 3, 256, 256).cuda()     # 随机生成图像数据,shape为(4, 3, 256, 256)

    print(f"\n--- 第 {i+1} 轮训练 ---")

    print("随机生成的文本数据:")
    print(text)

    print("\n随机生成的图像数据:")
    print(images)

    # 训练

    loss = clip(
        text,
        images,
        return_loss = True  # 需要设置为True以返回对比损失
    )

    print("\n训练过程中的对比损失:")
    print(loss.item())

    loss.backward()

    print("\n梯度计算完毕。")

# 在循环中尽可能多地使用文本和图像来执行以上操作

5.2、训练解码器

使用一个训练好的CLIP模型来辅助生成。首先,加载了训练好的CLIP模型,并创建了用于解码器的Unet模型。然后,创建了解码器,其中包含Unet模型和CLIP模型。接着,生成了一些虚拟图片数据,并将其输入解码器进行训练。最后,通过反向传播更新解码器的参数,重复这个过程多次,直到模型学会根据CLIP图像嵌入生成图片。

import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# 加载训练好的 CLIP 模型
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# 创建用于解码器的 Unet 模型
unet = Unet(
    dim = 128,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults=(1, 2, 4, 8)
).cuda()

# 创建解码器,包含 Unet 和 CLIP
decoder = Decoder(
    unet = unet,
    clip = clip,
    timesteps = 100,
    image_cond_drop_prob = 0.1,
    text_cond_drop_prob = 0.5
).cuda()

# 打印模型结构
print("CLIP Model Architecture:")
print(clip)
print("\nUnet Model Architecture:")
print(unet)
print("\nDecoder Model Architecture:")
print(decoder)


# 创建虚拟图片数据(获取大量数据)
images = torch.randn(4, 3, 256, 256).cuda()

# 循环训练
for epoch in range(10):  # 假设训练10个epoch
    # 输入数据并进行训练
    loss = decoder(images)
    loss.backward()
    # 输出训练信息
    print(f"Epoch [{epoch+1}/10], Loss: {loss.item()}")

# 训练完成
print("Training completed.")

# 重复以上步骤多次,让模型学会根据 CLIP 图像嵌入生成图片

5.3、训练扩散先验网络

从给定的文本描述生成对应的图像嵌入。代码首先创建了一个包含文本和图像编码功能的CLIP模型,然后建立了一个包含自回归Transformer的先验网络,并将CLIP模型和先验网络结合在一起形成扩散先验网络。接着通过虚拟数据进行训练,在训练循环中反复迭代,使网络逐渐学习从文本到图像嵌入的映射关系。

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP

# 从第一步获取训练好的 CLIP 模型

clip = CLIP(
    dim_text=512,  # 文本编码维度
    dim_image=512,  # 图像编码维度
    dim_latent=512,  # 潜在空间维度
    num_text_tokens=49408,  # 文本词汇表大小
    text_enc_depth=6,  # 文本编码器的深度
    text_seq_len=256,  # 文本序列长度
    text_heads=8,  # 文本注意力头数
    visual_enc_depth=6,  # 图像编码器的深度
    visual_image_size=256,  # 图像输入大小
    visual_patch_size=32,  # 图像分块大小
    visual_heads=8,  # 图像注意力头数
).cuda()

# 设置包含自回归 Transformer 的先验网络

prior_network = DiffusionPriorNetwork(
    dim=512,  # 输入维度
    depth=6,  # 网络深度
    dim_head=64,  # 注意力头维度
    heads=8  # 注意力头数
).cuda()

# 创建扩散先验网络,其中包含上述的 CLIP 模型和网络(带有 Transformer)

diffusion_prior = DiffusionPrior(
    net=prior_network,  # 先验网络
    clip=clip,  # CLIP 模型
    timesteps=100,  # 时间步数
    cond_drop_prob=0.2  # 条件丢失的概率
).cuda()

# 创建虚拟数据

text = torch.randint(0, 49408, (4, 256)).cuda()  # 随机生成文本数据
images = torch.randn(4, 3, 256, 256).cuda()  # 随机生成图像数据

# 打印一些数据信息
print("Text shape:", text.shape)
print("Images shape:", images.shape)

# 模拟训练循环
for step in range(10):  # 循环10次
    # 将文本和图像输入扩散先验网络
    loss = diffusion_prior(text, images)
    # 打印损失值
    print(f"Step {step + 1}, Loss: {loss.item()}")
    # 反向传播并更新参数
    loss.backward()
    # 清空梯度
    diffusion_prior.zero_grad()

# 现在扩散先验网络可以从文本嵌入生成图像嵌入

上面demo只是为了理解DALLE2的原理, 最后的效果很糟糕,下面我想用预训练模型推理一下,看看效果

6、测试结果

预训练模型地址:https://huggingface.co/laion/DALLE2-PyTorch

推理脚本

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig

# 从预训练模型配置文件中加载 Diffusion Prior 模型配置
prior_config = TrainDiffusionPriorConfig.from_json_path("./weights/prior_config.json").prior
# 创建并加载 Diffusion Prior 模型
prior = prior_config.create().cuda()

# 加载预训练的 Diffusion Prior 模型参数
prior_model_state = torch.load("./weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

# 从预训练模型配置文件中加载 Decoder 模型配置
decoder_config = TrainDecoderConfig.from_json_path("./weights/decoder_config.json").decoder
# 创建并加载 Decoder 模型
decoder = decoder_config.create().cuda()

# 加载预训练的 Decoder 模型参数
decoder_model_state = torch.load("./weights/decoder_latest.pth")["model"]

# 将预训练模型参数应用到 Decoder 的 CLIP 模型中
for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

# 加载预训练的 Decoder 模型参数
decoder.load_state_dict(decoder_model_state, strict=True)

# 创建 DALL-E2 模型,将加载的 Diffusion Prior 和 Decoder 放在一起
dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

# 生成图像,你需要替换 ['your prompt here'] 为你的提示文本
images = dalle2(
    ['a red car'],
    cond_scale = 2.
).cpu()

print(images.shape)

# 保存图像
for i, img in enumerate(images):
    img_pil = ToPILImage()(img)  # 将张量转换为 PIL 图像
    img_pil.save(f'image_{i}.png')  # 保存 PIL 图像为文件

在这个过程中会有一些报错,可以参考can not generate normal image with pretrained model · Issue #282 · lucidrains/DALLE2-pytorch · GitHub解决,首先我测试的预训练模型比较小,所以效果可能不是那么好,其次是模型生成的很慢.后续还要再研究研究看看,怎么训练

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1466551.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringMVC 学习(二)之第一个 SpringMVC 案例

目录 1 通过 Maven 创建一个 JavaWeb 工程 2 配置 web.xml 文件 3 创建 SpringMVC 配置文件 spring-mvc.xml 4 创建控制器 HelloController 5 创建视图 index.jsp 和 success.jsp 6 运行过程 7 参考文档 1 通过 Maven 创建一个 JavaWeb 工程 可以参考以下博文&#x…

QT Widget自定义菜单

此文以设置QListWidget的自定义菜单为例,其他继承于QWidget的类也都可以按类似的方法去实现。 1、ui文件设置contextMenuPolicy属性为CustomContextMenu 2、添加槽函数 /*** brief onCustomContextMenuRequested 右键弹出菜单* param pos 右键的坐标*/void onCusto…

Stable Diffusion 模型分享:FenrisXL(芬里斯XL)

本文收录于《AI绘画从入门到精通》专栏,专栏总目录:点这里。 文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八案例九案例十

台式电脑电源功率越大越费电吗?装机选购多少W电源

要组装一台电脑,我们首先需要选择硬件。 硬件搭配最关键的一点就是CPU和主板的兼容性。 硬件、电源等之间的平衡都需要仔细考虑。 那么台式电脑电源多大功率合适呢? 下面分享组装电脑电源瓦数选购指南,教您正确选择合适的电源瓦数。 让我们来…

集成TinyMCE富文本编辑器

若依的基础上集成TinyMCE富文本编辑器 前端bootstrap TinyMCE官网链接 TinyMCE所需静态资源下载链接 开源项目-若依链接 将TinyMCE静态资源包放入项目中&#xff1b; 代码引入css&#xff1a; <!-- 引入TinyMCE CSS --><link th:href"{/ajax/libs/tinymce/j…

axios是如何实现的(源码解析)

1 axios的实例与请求流程 在阅读源码之前&#xff0c;先大概了解一下axios实例的属性和请求整体流程&#xff0c;带着这些概念&#xff0c;阅读源码可以轻松不少&#xff01; 下图是axios实例属性的简图。 可以看到axios的实例上&#xff0c;其实主要就这三个东西&#xff1a…

第九节HarmonyOS 常用基础组件24-Navigation

1、描述 Navigation组件一般作为Page页面的根容器&#xff0c;通过属性设置来展示的标题栏、工具栏、导航栏等。 2、子组件 可以包含子组件&#xff0c;推荐与NavRouter组件搭配使用。 3、接口 Navigation() 4、属性 名称 参数类型 描述 title string|NavigationComm…

Python 实现 ATR 指标计算(真实波幅):股票技术分析的利器系列(10)

Python 实现 ATR 指标计算&#xff08;真实波幅&#xff09;&#xff1a;股票技术分析的利器系列&#xff08;10&#xff09; 介绍算法解释 代码rolling函数介绍核心代码 完整代码 介绍 ATR&#xff08;真实波幅&#xff09;是一种技术指标&#xff0c;用于衡量市场波动性的程…

RabbitMQ的死信队列和延迟队列

文章目录 死信队列如何配置死信队列死信队列的应用场景Spring Boot实现RabbitMQ的死信队列 延迟队列方案优劣&#xff1a;延迟队列的实现有两种方式&#xff1a; 死信队列 1&#xff09;“死信”是RabbitMQ中的一种消息机制。 2&#xff09;消息变成死信&#xff0c;可能是由于…

基于Python网络爬虫的IT招聘就业岗位可视化分析推荐系统

文章目录 基于Python网络爬虫的IT招聘就业岗位可视化分析推荐系统项目概述招聘岗位数据爬虫分析系统展示用户注册登录系统首页IT招聘数据开发岗-javaIT招聘数据开发岗-PythonIT招聘数据开发岗-Android算法方面运维方面测试方面招聘岗位薪资多维度精准预测招聘岗位分析推荐 结语…

《TCP/IP详解 卷一》第6章 DHCP

目录 6.1 引言 6.2 DHCP 6.2.1 地址池和租用 6.2.2 DHCP和BOOTP消息格式 6.2.3 DHCP和BOOTP选项 6.2.4 DHCP协议操作 6.2.5 DHCPv6 6.2.6 DCHP中继 6.2.7 DHCP认证 6.2.8 重新配置扩展 6.2.9 快速确认 6.2.10 位置信息&#xff08;LCI和LoST&#xff09; 6.2.11 移…

GPT-SoVITS 快速声音克隆使用案例:webui、api接口

参考: https://github.com/RVC-Boss/GPT-SoVITS 环境: Python 3.10 PyTorch 2.1.2, CUDA 12.0 安装包: 1、使用: 1)下载项目 git clone https://github.com/RVC-Boss/GPT-SoVITS.git2)下载预训练模型 https://huggingface.co/lj1995/GPT-SoVITS 下载模型文件放到GPT…

Vue2响应式原理分析(数据代理与数据劫持)

综述&#xff1a; 我们都知道&#xff0c;每个Vue的应用都是通过new一个Vue构造函数从而创造出来一个vm实例对象&#xff0c;el&#xff08;elect&#xff09;配置项为通过id选择器#root选择index页面中的根dom元素进行绑定&#xff0c;data配置项则为vue模板中用到的源数据。 …

python 层次分析(AHP)

文章目录 一、算法原理二、案例分析2.1 构建指标层判断矩阵2.2 求各指标权重2.2.1 算术平均法&#xff08;和积法&#xff09;2.2.2 几何平均法&#xff08;方根法&#xff09; 2.3 一致性检验2.3.1 求解最大特征根值2.3.2 求解CI、RI、CR值2.3.3 一致性判断 2.4 分别求解方案层…

算法沉淀——FloodFill 算法(leetcode真题剖析)

算法沉淀——FloodFill 算法 01.图像渲染02.岛屿数量03.岛屿的最大面积04.被围绕的区域05.太平洋大西洋水流问题06.扫雷游戏07.衣橱整理 Flood Fill&#xff08;泛洪填充&#xff09;算法是一种图像处理的基本算法&#xff0c;用于填充连通区域。该算法通常从一个种子点开始&am…

【DDD】学习笔记-薪资管理系统的测试驱动开发2

测试驱动开发的过程 满足简单设计并编写新的测试 当代码满足重用性和可读性之后&#xff0c;就应遵循简单设计的第四条原则“若无必要&#xff0c;勿增实体”&#xff0c;不要盲目地考虑为其增加新的软件元素。这时&#xff0c;需要暂时停止重构&#xff0c;编写新的测试。 …

2.23数据结构

单向循环链表 创建单向循环链表&#xff0c;创建节点 &#xff0c;头插&#xff0c;按位置插入&#xff0c;输出&#xff0c;尾删&#xff0c;按位置删除功能 //main.c #include "loop_list.h" int main() {loop_p Hcreate_head();insert_head(H,12);insert_head(…

计算机网络-网络层,运输层,应用层

网络层/网际层 网络层的主要任务包括&#xff1a; 提供逻辑上的端到端通信&#xff1a;网络层负责确定数据的传输路径&#xff0c;使数据能够从源主机传输到目标主机&#xff0c;即实现端到端的通信。数据包的路由和转发&#xff1a;网络层根据目标主机的地址信息&#xff0c…

vue项目使用vue2-org-tree

实现方式 安装依赖 npm i vue2-org-tree使用的vue页面引入 <template><div class"container"><div class"oTree" ><vue2-org-tree name"test":data"data":horizontal"horizontal":collapsable"…

【服务器数据恢复】通过reed-solomon算法恢复raid6数据的案例

服务器数据恢复环境&#xff1a; 一台网站服务器中有一组由6块磁盘组建的RAID6磁盘阵列&#xff0c;操作系统层面运行MySQL数据库和存放一些其他类型文件。 服务器故障&#xff1a; 该服务器在工作过程中&#xff0c;raid6磁盘阵列中有两块磁盘先后离线&#xff0c;不知道是管理…