超详细!DALL · E 文生图模型实践指南

news2024/12/25 9:00:23

最近需要用到 DALL·E的推断功能,在现有开源代码基础上发现还有几个问题需要注意,谨以此篇博客记录之。

我用的源码主要是 https://github.com/borisdayma/dalle-mini 仓库中的Inference pipeline.ipynb 文件。

在这里插入图片描述

运行环境:Ubuntu服务器

⚠️注意:本博客仅涉及 DALL · E 推断,不涉及训练过程。


目录

  • 一、环境配置
  • 二、模型下载
  • 三、程序转换
  • 四、程序运行
  • 五、BUG清除指南


一、环境配置

建议使用anaconda新建一个dalle环境,然后在该环境中进行相关配置,避免与环境中的其他库产生版本冲突。

使用下述命令新建名为dalle的环境:

conda create -n dalle python==3.8.0

在终端分别运行下述命令,安装所需的python库:

# 安装 dalle运行需要的依赖库(注意版本只能是0.3.25)# Required only for colab environments + GPU
pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# 安装 dalle特定的库
pip install dalle-mini
# 安装 VQGAN
pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

PS:如果由于网络连接问题无法通过pip命令下载VQGAN,就采取Plan-B:将仓库 https://github.com/patil-suraj/vqgan-jax 下载到服务器并解压,然后使用cd命令将当前目录到对应的仓库下载路径下,在终端运行python setup.py install安装VQGAN即可。


二、模型下载

由于网络连接问题,我采取「事先把模型下载到本地」的策略对模型进行直接调用,首先要明确的一点是,本项目中使用DALL · E 对图像进行编码,使用VQGAN对图像进行解码,所以我们需要分别下载DALL · E 和 VQGAN 两个模型。

DALL · E 模型下载地址:
mini版本:https://huggingface.co/dalle-mini/dalle-mini/tree/main
mega版本:https://huggingface.co/dalle-mini/dalle-mega/tree/main

VQGAN 模型下载地址:
https://huggingface.co/dalle-mini/vqgan_imagenet_f16_16384/tree/main

下载完毕后,将模型部署到服务器,注意保存路径。


三、程序转换

相较于ipynb文件,我个人更加喜欢操作py文件,所以对于给定的ipynb文件,首先使用命令jupyter nbconvert --to script Inference pipeline.ipynb 将其转为同名py文件,该文件的主要内容如下(不含CLIP排序部分),其中模型路径 DALLE_MODEL和VQGAN_REPO 已改为本地路径(就是第二步中两个模型的保存路径),可以看到文件的注释也比较详细。

# dalle-mini
DALLE_MODEL = "/newdata/SD/dalle-mini/dalle-mini"
DALLE_COMMIT_ID = None
# VQGAN model
VQGAN_REPO = "/newdata/SD/dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

import jax
import jax.numpy as jnp

# check how many devices are available
jax.local_device_count()

# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Load dalle-mini
model, params = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False)
# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False)

# Model parameters are replicated on each device for faster inference.
from flax.jax_utils import replicate
params = replicate(params)
vqgan_params = replicate(vqgan_params)

# Model functions are compiled and parallelized to take advantage of multiple devices.
from functools import partial

# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )

# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)

# Keys are passed to the model on each device to generate unique inference per device.
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

# ## 🖍 Text Prompt
# Our model requires processing prompts.

from dalle_mini import DalleBartProcessor 
# from transformers import AutoProcessor
processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID)  # force_download=True, , local_only=True
# Let's define some text prompts
prompts = [
    "sunset over a lake in the mountains",
    "the Eiffel tower landing on the moon",
]
# print(prompts)
# Note: we could use the same prompt multiple times for faster inference.
tokenized_prompts = processor(prompts)
# Finally we replicate the prompts onto each device.
tokenized_prompt = replicate(tokenized_prompts)

# ## 🎨 We generate images using dalle-mini model and decode them with the VQGAN.

# number of predictions per prompt
n_predictions = 8

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0  # 越高,生成的图像越接近 prompt

from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)  #  jax.device_count()=1,returns the number of available jax devices
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    
    for idx, decoded_img in enumerate(decoded_images):
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
... 

四、程序运行

使用命令 python /newdata/SD/inference_dalle-mini.py 运行程序。理想情况下就能够直接得到dalle生成的图像啦!


五、BUG清除指南

由于外部环境因素和一些不当操作,本人在运行该程序过程中还是遇到一些问题,主要有三个,在此将抱错信息与解决方法一并分享给大家。

  • 因网络问题导致特定文件下载失败,报错信息如下:
...
requests.exceptions.ConnectTimeout: (MaxRetryError("HTTPSConnectionPool(host='huggingface.co', port=443): Max retries exceeded with url: /dalle-mini/dalle-mini/resolve/main/enwiki-words-frequency.txt (Caused by ConnectTimeoutError(<urllib3.connection.HTTPSConnection object at 0x7faae4168460>, 'Connection to huggingface.co timed out. (connect timeout=10)'))"), '(Request ID: 61b7c191-3fb8-4dfa-9025-e9acd4ee4d28)')

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/newdata/SD/inference_dalle-mini.py", line 84, in <module>
    processor = DalleBartProcessor.from_pretrained("/newdata/SD/dalle-mini/dalle-mini", revision=DALLE_COMMIT_ID)  # force_download=True, , local_only=True
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/utils.py", line 25, in from_pretrained
    return super(PretrainedFromWandbMixin, cls).from_pretrained(
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 62, in from_pretrained
    return cls(tokenizer, config.normalize_text, config.max_text_length)
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/processor.py", line 21, in __init__
    self.text_processor = TextNormalizer()
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 215, in __init__
    self._hashtag_processor = HashtagProcessor()
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py", line 25, in __init__
    #     wiki_word_frequency = hf_hub_download(
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/utils/_validators.py", line 118, in _inner_fn
    return fn(*args, **kwargs)
  File "/root/anaconda3/envs/dalle/lib/python3.8/site-packages/huggingface_hub/file_download.py", line 1363, in hf_hub_download
    raise LocalEntryNotFoundError(
huggingface_hub.utils._errors.LocalEntryNotFoundError: An error happened while trying to locate the file on the Hub and we cannot find the requested files in the local cache. Please check your connection and try again or make sure your Internet connection is on.

顺着上面的报错信息,定位到/root/anaconda3/envs/dalle/lib/python3.8/site-packages/dalle_mini/model/text.py文件的如下内容:

...
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
		wiki_word_frequency = hf_hub_download(
		    "dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt"
		)
		self._word_cost = (
		    l.split()[0]
		    for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
		)
...

于是问题的根源就在于,程序运行到这里时,没有找到本地的enwiki-words-frequency.txt文件(经检查该文件其实是存在本地的,不知为何没有找到,很迷),于是尝试通过联网从huggingface官网下载,但由于网络状况欠佳,联网失败,于是报错。解决办法如下:

...
class HashtagProcessor:
    # Adapted from wordninja library
    # We use our wikipedia word count + a good heuristic to make it work
    def __init__(self):
		wiki_word_frequency = "/newdata/SD/dalle-mini/dalle-mini/enwiki-words-frequency.txt"
		self._word_cost = (
		    l.split()[0]
		    for l in Path(wiki_word_frequency).read_text(encoding="utf8").splitlines()
		)
...

也就是将enwiki-words-frequency.txt文件的本地路径直接赋值给wiki_word_frequency变量,其余部份保持不变,问题解决。


  • 因安装不当导致的版本冲突问题
FIx for "Couldn't invoke ptxas --version"

这个错误的产生是不同python库安装时带来的版本冲突导致的,DALLE-mini要求jax和jaxlib版本必须为0.3.25,但是通过pip imstall dalle-mini 命令安装后的jaxlib版本为0.4.13,但使用pip install jaxlib的方式并不能找到0.3.25版本的jaxlib,而且会产生与flax、orbax-checkpoint等其他库的版本不兼容问题……在尝试多种方法合理降低jaxlib版本均失败后,发现答案就在ipynb中……也就是:pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

💡启示:要以官方说明文档为主,可以少走很多弯路!!!


  • 彩蛋:一个非常奇怪的错误:
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/newdata/SD/inference_dalle-mini.py", line 130, in <module>
    decoded_images = p_decode(encoded_images, vqgan_params)
ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (101 of them) had size 512, e.g. axis 0 of argument params['decoder']['conv_in']['bias'] of type float32[512];
  * some axes (71 of them) had size 3, e.g. axis 0 of argument params['decoder']['conv_in']['kernel'] of type float32[3,3,256,512];
  * some axes (69 of them) had size 256, e.g. axis 0 of argument params['decoder']['up_1']['block_0']['norm1']['bias'] of type float32[256];
  * some axes (67 of them) had size 128, e.g. axis 0 of argument params['decoder']['norm_out']['bias'] of type float32[128];
  * some axes (35 of them) had size 1, e.g. axis 0 of argument indices of type int32[1,2,256];
  * one axis had size 16384: axis 0 of argument params['quantize']['embedding']['embedding'] of type float32[16384,256]

后来发现,是因为之前调试的时候不小心把下面这行代码注释掉了……这个bug排得最辛苦,还挺无语的😂

vqgan_params = replicate(vqgan_params)

PS:程序运行过程中还有一些警告,由下述警告也可以看出jax是属于tensoeflow派别的。

2023-11-07 11:30:35.139851: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.257514: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.258648: W external/org_tensorflow/tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2023-11-07 11:30:35.628768: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2023-11-07 11:30:35.628915: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 525.53.0 does not match DSO version 530.41.3 -- cannot find working devices in this configuration
WARNING:jax._src.lib.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Prompts: ['sunset over a lake in the mountains', 'the Eiffel tower landing on the moon']

  0%|          | 0/8 [00:00<?, ?it/s]
/root/anaconda3/envs/dalle/lib/python3.8/site-packages/jax/_src/ops/scatter.py:87: FutureWarning: scatter inputs have incompatible types: cannot safely cast value from dtype=float16 to dtype=float32. In future JAX releases this will result in an error.
  warnings.warn("scatter inputs have incompatible types: cannot safely cast "

后记:第一次接触到基于jax框架编写的程序,还挺新鲜的,感觉和pytorch有一些不一样的地方。了解到jax是tensorflow的轻量级版本。上述博客内容中如果有个人理解不当之处,还望各位批评指正!

参考链接

  1. python pathlib中Path 的使用(解决不同操作系统的路径问题)_python pathlib.path-CSDN博客
  2. python - vmap gives inconsistent shape error when trying to calculate gradient per sample - Stack Overflow
  3. https://github.com/google/jax/issues/9933

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1181606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

elementui-plus el-tree组件数据不显示问题解决

当前情况: 显示: 注意看右侧的树是没有文字的,数据已经渲染,个数是对的,但就是没有文字, 解决: 对比以后发现是template中的#default{data}没有写大括号导致的 所以写上大括号后: 正常显示

卷积神经网络中 6 种经典卷积操作

深度学习的模型大致可以分为两类&#xff0c;一类是卷积神经网络&#xff0c;另外一类循环神经网络&#xff0c;在计算机视觉领域应用最多的就是卷积神经网络&#xff08;CNN&#xff09;。CNN在图像分类、对象检测、语义分割等经典的视觉任务中表现出色&#xff0c;因此也早就…

【React-Native开发3D应用】React Native加载GLB格式3D模型并打包至Android手机端

【React-Native开发3D应用】React Native加载GLB格式3D模型并打包至Android手机端 【加载3D模型】**React Native上如何加载glb格式的模型**第零步&#xff0c;选择相关模型第一步&#xff0c;导入相关模型加载库第二步&#xff0c;自定义GLB模型加载钩子第三步&#xff0c;借助…

浅析淘宝为什么会严查套红包行为,如何从技术层面实现红包检测规避

最近不少做淘系电商的商家&#xff0c;遇到了一个普遍的问题就是&#xff1a;订单存在买手套红包导致被平台稽查的情况。这种情况&#xff0c;东哥了解到不是发生在某一两个商家身上&#xff0c;而是一个普遍现象。 下面东哥从为什么会稽查套红包的行为、稽查后会有什么后果、如…

学习使用JS实现Echarts的图表保存为图片功能:saveAsImage和getDataURL

学习使用JS实现Echarts的图表保存为图片功能 接口getDataURL实现思路 需求分析 实际项目开发过程中经常会有图表展示功能&#xff0c;同时为了满足用户需要&#xff0c;会附带着图表导出功能&#xff0c;主要形式就是保存为图片。在Echarts中本身就提供这种配置项&#xff0c;…

期中考试后,如何DIY一个成绩发布系统?

期中考试结束后&#xff0c;对于老师们来说&#xff0c;一项重要的任务就是公布考试成绩。然而&#xff0c;传统的成绩公布方式不仅耗时&#xff0c;而且容易出错。为了提高效率&#xff0c;减少误差&#xff0c;我们可以通过各种代码和Excel来实现一个让学生自助查询成绩的系统…

一物一码需求,标签制作功能轻松解决

许多行业存在为人员、物品、设备等做一物一码标签的需求&#xff0c;可使用草料标签制作功能。直接选择标签样式&#xff0c;填入数据&#xff0c;即可批量生成标签&#xff0c;还可批量排版&#xff0c;更易落地。还可保存标签样式&#xff0c;后续多次复用样式&#xff0c;批…

基于java web的计算机office课程平台设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

linux入门---消费者生产者模型模拟实现

目录标题 消费者生产者模型的理解单生产单消费模拟实现blockqueue.cpp准备工作MainCp.cpp的准备工作构造函数和析构函数的模拟实现push函数的实现pop函数的实现poductor_func函数的实现consumer_func函数的实现程序的测试程序改进一程序的改进二程序的改进三 多生产多消费模拟实…

什么是CCS Concepts

在撰写论文时&#xff0c;看到了CCS Concepts&#xff0c;注意这是对自己论文的分类&#xff0c;不能随便填写。 在ACM的网页"http://dl.acm.org/ccs/ccs.cfm"中选择自己论文的分类&#xff1a; 然后点击左侧的“Assign This CCS Concept”&#xff0c;再选择相关性…

【TDK 电容 】介电质 代码 对应温度及变化率

JB 电解质是什么&#xff1f;没找到&#xff0c;只有TDK有&#xff0c;也只有这个温度的区别&#xff0c;并且已经停产在售。 对比发现是mouser网站关于电容的描述错误。下图显示正确的&#xff0c;再然后是错误的。 在TDK官网&#xff0c;这样的描述 温度特性 分类标准代码温…

制作电子画册的有好帮手---FLBOOK

随着互联网的发展&#xff0c;越来越多的人开始使用电子书来阅读书籍。而将PDF文件转换成在线翻页电子书&#xff0c;则是一种非常方便的方式。今天&#xff0c;给大家推荐一个可以将PDF转在线翻页电子书的网站。 这个网站就是FLBOOK在线制作电子杂志平台&#xff0c;只需要三步…

C++——类和对象(初始化列表、匿名对象、static成员、类的隐式类型转换和explicit关键字、内部类)

初始化列表、匿名对象、static成员、类的隐式类型转换和explicit关键字、内部类 本章思维导图&#xff1a; 注&#xff1a;本章思维导图对应的xmind文件和.png文件都已同步导入至资源 文章目录 初始化列表、匿名对象、static成员、类的隐式类型转换和explicit关键字、内部类1.…

案例-注册页面(css)

html页面用css控制样式&#xff0c;画一个注册页面。 页面最终效果如下&#xff1a; 页面代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>注册页面</title> <style>*{…

文献阅读 - JADE:具有可选外部存档的自适应差分进化

文章目录 标题摘要关键字结论研究背景I. INTRODUCTION 常用基础理论知识II. BASIC OPERATIONS OF DEIII. ADAPTIVE DE ALGORITHMSA. DESAPB. FADEC. SaDED. jDE 研究内容、成果IV. JADEA. DE/Current-to-pbestB. Parameter AdaptationC. Explanations of the Parameter Adaptat…

WSGI与ASGI:两种Python Web服务器网关接口的比较

在当今的Web开发领域&#xff0c;选择合适的服务器网关接口&#xff08;Server Gateway Interface&#xff0c;简称SGI&#xff09;对于提高Web应用程序的性能和并发性至关重要。在Python中&#xff0c;有两种常见的SGI&#xff1a;WSGI和ASGI。本文将深入探讨这两种SGI的异同点…

中国人民大学与加拿大女王大学金融硕士——在职读研,让人生的火花迸发

每个人都像是一块未经雕琢的宝石&#xff0c;隐藏着无尽的光芒。然而&#xff0c;生活、工作中的困难、挫折和压力&#xff0c;就像尘土一样&#xff0c;掩盖了我们的闪亮之处。只有当我们冲破这些阻碍&#xff0c;才能让内在的光芒照亮世界。中国人民大学与加拿大女王大学金融…

Q-Vision+CANpro Max总线解决方案

智能联网技术在国内的发展势头迅猛&#xff0c;随着汽车智能化、网联化发展大潮的到来&#xff0c;智能网联汽车逐步成为汽车发展的主要趋势。越来越多整车厂诉求&#xff0c;希望可以提供本土的测量软件&#xff0c;特别是关于ADAS测试。而风丘科技推出的Q-Vision软件不仅可支…

一键批量剪辑:视频随机分割新玩法,高效剪辑不再难

随着视频内容的日益丰富&#xff0c;人们对于视频剪辑的需求也日益增长。而传统的视频剪辑方式往往需要耗费大量的时间和精力&#xff0c;让许多非专业人士望而却步。然而&#xff0c;现在有一款名为“云炫AI智剪”的软件&#xff0c;它为我们提供了一种全新的视频剪辑方式——…

数据结构:AVL树的旋转(平衡搜索二叉树)

1、AVL树简介 AVL树是最先发明的自平衡二叉查找树。在AVL树中任何节点的两个子树的高度最大差别为1&#xff0c;所以它也被称为高度平衡树。增加和删除可能需要通过一次或多次树旋转来重新平衡这个树。AVL树得名于它的发明者G. M. Adelson-Velsky和E. M. Landis&#xff0c;他们…