DALLE 2 文生图模型实践指南

news2024/12/23 13:55:46

前言:最近在运行dalle2模型进行推断,本篇博客记录相关资料。

相关博客:超详细!DALL · E 文生图模型实践指南


在这里插入图片描述

目录

  • 1. 环境搭建和预训练模型准备
    • 环境搭建
    • 预训练模型下载
  • 2. 代码
  • 3. BUG&DEBUG
    • URLError
    • RuntimeError
    • CUDA error


1. 环境搭建和预训练模型准备

本文使用的代码仓库为:https://github.com/lucidrains/DALLE2-pytorch

环境搭建

pip install dalle2-pytorch

预训练模型下载

地址:https://huggingface.co/laion/DALLE2-PyTorch

2. 代码

DALLE2 for inference 完整推断流程如下(from cest_andre):

import torch
from torchvision.transforms import ToPILImage
from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdapter, Decoder, DALLE2
from dalle2_pytorch.train_configs import TrainDiffusionPriorConfig, TrainDecoderConfig


prior_config = TrainDiffusionPriorConfig.from_json_path("weights/prior_config.json").prior
prior = prior_config.create().cuda()

prior_model_state = torch.load("weights/prior_latest.pth")
prior.load_state_dict(prior_model_state, strict=True)

decoder_config = TrainDecoderConfig.from_json_path("weights/decoder_config.json").decoder
decoder = decoder_config.create().cuda()

decoder_model_state = torch.load("weights/decoder_latest.pth")["model"]

for k in decoder.clip.state_dict().keys():
    decoder_model_state["clip." + k] = decoder.clip.state_dict()[k]

decoder.load_state_dict(decoder_model_state, strict=True)

dalle2 = DALLE2(prior=prior, decoder=decoder).cuda()

images = dalle2(
    ['your prompt here'],
    cond_scale = 2.
).cpu()

print(images.shape)

for img in images:
    img = ToPILImage()(img)
    img.show()

3. BUG&DEBUG

URLError

报错信息如下:

Traceback (most recent call last):
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1350, in do_open
    h.request(req.get_method(), req.selector, req.data, headers,
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1255, in request
    self._send_request(method, url, body, headers, encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1301, in _send_request
    self.endheaders(body, encode_chunked=encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1250, in endheaders
    self._send_output(message_body, encode_chunked=encode_chunked)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1010, in _send_output
    self.send(msg)
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 950, in send
    self.connect()
  File "/root/anaconda3/envs/ldm/lib/python3.8/http/client.py", line 1424, in connect
    self.sock = self._context.wrap_socket(self.sock,
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 500, in wrap_socket
    return self.sslsocket_class._create(
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1040, in _create
    self.do_handshake()
  File "/root/anaconda3/envs/ldm/lib/python3.8/ssl.py", line 1309, in do_handshake
    self._sslobj.do_handshake()
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 11, in <module>
    prior = prior_config.create().cuda()
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 185, in create
    clip = self.clip.create()
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/train_configs.py", line 122, in create
    return OpenAIClipAdapter(self.model)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/dalle2_pytorch/dalle2_pytorch.py", line 313, in __init__
    openai_clip, preprocess = clip.load(name)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 122, in load
    model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/clip/clip.py", line 59, in _download
    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 222, in urlopen
    return opener.open(url, data, timeout)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 525, in open
    response = self._open(req, data)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 542, in _open
    result = self._call_chain(self.handle_open, protocol, protocol +
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 502, in _call_chain
    result = func(*args)
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1393, in https_open
    return self.do_open(http.client.HTTPSConnection, req,
  File "/root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py", line 1353, in do_open
    raise URLError(err)
urllib.error.URLError: <urlopen error [Errno 104] Connection reset by peer>

我使用的是https://github.com/lucidrains/DALLE2-pytorch这个网址。

找到 /root/anaconda3/envs/ldm/lib/python3.8/urllib/request.py 中对应的位置,我这里是第1349行,修改方式也在下面代码中一并给出。

try:
    h.request(req.get_method(), req.selector, req.data, headers,
              encode_chunked=req.has_header('Transfer-encoding'))
    time.sleep(0.5)  # 添加的一行
except OSError as err: # timeout error
    raise URLError(err)

RuntimeError

Traceback (most recent call last):
  File "/newdata/SD/extra/dalle2_cest.py", line 14, in <module>
    prior.load_state_dict(prior_model_state, strict=True)
  File "/root/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1604, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DiffusionPrior:
        Missing key(s) in state_dict: "net.null_text_encodings", "net.null_text_embeds", "net.null_image_embed". 
        Unexpected key(s) in state_dict: "net.null_text_embed". 

解决办法:load_state_dict()函数中的 strict=True 改为 strict=False,如下:

...
prior.load_state_dict(prior_model_state, strict=False)

decoder.load_state_dict(decoder_model_state, strict=False)
...

CUDA error

RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

解决方法:版本不匹配,更换与系统cuda相匹配的pytorch版本。比如我的cuda版本是12.0,可以使用如下命令安装pytorch:

pip install torch==2.0.0+cu118 torchvision==0.15.1+cu118 torchaudio==2.0.1+cu118 -f https://download.pytorch.org/whl/torch_stable.html

到这里,模型就可以完成推断过程啦~嘻嘻!


参考链接

  1. https://github.com/lucidrains/DALLE2-pytorch/issues/282
  2. python requests请求报错ConnectionError: (‘Connection aborted.‘, error(104, ‘Connection reset by peer‘))_铁朵斯提的博客-CSDN博客
  3. GPU版本pytorch(Cuda12.1)清华源快速安装一步一步教!小白教学~_清华源安装torch-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1205128.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【C语言 | 指针】C指针详解(经典,非常详细)

&#x1f601;博客主页&#x1f601;&#xff1a;&#x1f680;https://blog.csdn.net/wkd_007&#x1f680; &#x1f911;博客内容&#x1f911;&#xff1a;&#x1f36d;嵌入式开发、Linux、C语言、C、数据结构、音视频&#x1f36d; &#x1f923;本文内容&#x1f923;&a…

这样书写Python代码的方式,实在是太优雅了~

文章目录 前言一、在Python中配合pipe灵活使用链式写法二 、pipe中常用的管道操作函数1.使用traverse()展平嵌套数组2.使用dedup()进行顺序去重3.使用filter()进行值过滤4.使用groupby()进行分组运算5.使用select()对上一步结果进行自定义遍历运算6.使用sort()进行排序 总结关于…

thinkphp8 多级控制器调用

在使用这个目录的时候正常访问时 http://tp.com/index.php/user2.login/index, 这个多级目录时不允许使用的&#xff0c;想要使用就的使用路由 在route/app.php 里面配置&#xff1a;Route::get(user2/login,user2.Login/index); 第一个参数时外部访问参数&#xff0c;第二个是…

Android权限动态申请(包括悬浮窗)

目录 效果图 一、环境配置 二、新建工具类 三、开始使用 备注&#xff08;一&#xff09;&#xff1a;用户手动设置权限 手动设置效果图 备注&#xff08;二&#xff09;&#xff1a;在Fragment中如何调用动态权限申请 备注&#xff08;三&#xff09;&#xff1a;悬浮窗…

SDL2 显示文字

1.简介 SDL本身没有显示文字功能&#xff0c;它需要用扩展库SDL_ttf来显示文字。ttf是True Type Font的缩写&#xff0c;ttf是Windows下的缺省字体&#xff0c;它有美观&#xff0c;放大缩小不变形的优点&#xff0c;因此广泛应用很多场合。 使用ttf库的第一件事要从Windows的…

【leetcode】8.字符串转换整数

题目 请你来实现一个 myAtoi(string s) 函数&#xff0c;使其能将字符串转换成一个 32 位有符号整数&#xff08;类似 C/C 中的 atoi 函数&#xff09;。 函数 myAtoi(string s) 的算法如下&#xff1a; 读入字符串并丢弃无用的前导空格 检查下一个字符&#xff08;假设还未…

【文末送书】深入浅出嵌入式虚拟机原理

欢迎关注博主 Mindtechnist 或加入【智能科技社区】一起学习和分享Linux、C、C、Python、Matlab&#xff0c;机器人运动控制、多机器人协作&#xff0c;智能优化算法&#xff0c;滤波估计、多传感器信息融合&#xff0c;机器学习&#xff0c;人工智能等相关领域的知识和技术。关…

基于回溯搜索算法优化概率神经网络PNN的分类预测 - 附代码

基于回溯搜索算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于回溯搜索算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于回溯搜索优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针对PNN神…

二维码智慧门牌管理系统升级解决方案:数据可视化助力运营精准决策

文章目录 前言一、升级版二维码智慧门牌管理系统的特点二、数据可视化助力运营精准决策 前言 随着科技的不断进步&#xff0c;传统的门牌管理系统已经无法满足现代社会的需求。为了提高管理效率&#xff0c;减少人力成本&#xff0c;我们引入了升级版的二维码智慧门牌管理系统…

[PyTorch][chapter 62][强化学习-基本概念]

前言&#xff1a; 目录&#xff1a; 强化学习概念 马尔科夫决策 Bellman 方程 格子世界例子 一 强化学习 强化学习 必须在尝试之后&#xff0c;才能发现哪些行为会导致奖励的最大化。 当前的行为可能不仅仅会影响即时奖赏&#xff0c;还有影响下一步奖赏和所有奖赏 强…

如何应对招聘中的职业性格测评?

很多同学听说要做性格测试&#xff0c;第一反应是如何让自己的性格让HR看起来更好....没办法为了顺利入职&#xff0c;咱不能老实作答&#xff0c;因为性格测评搞不好是真刷人的&#xff0c;刷人的&#xff08;无视你的专业能力和笔试成绩&#xff09;..... 可是....很多性格测…

eNSP-打开华为USG6000V1防火墙web管理页面方法

一、本地打开防火墙web管理页面 1.先在ensp中启动USG6000V1防火墙&#xff0c;启动后&#xff0c;需要输入原始username和password&#xff08;username&#xff1a;admin&#xff0c;password&#xff1a;Admin123&#xff09;&#xff0c;并修改原始密码后&#xff0c;才能配…

SQL学习(CTFhub)整数型注入,字符型注入,报错注入 -----手工注入+ sqlmap注入

目录 整数型注入 手工注入 为什么要将1设置为-1呢&#xff1f; sqlmap注入 sqlmap注入步骤&#xff1a; 字符型注入 手工注入 sqlmap注入 报错注入 手工注入 sqlmap注入 整数型注入 手工注入 先输入1 接着尝试2&#xff0c;3&#xff0c;2有回显&#xff0c;而3没有回显…

做一个springboot用户信息模块

目录 用户信息部分 1、获取用户详细信息 前言 代码分析 代码实现 测试 2、更新用户信息 前言 代码实现 测试 3、更新用户头像 前言 代码实现 测试 4、更新用户密码 前言 代码实现 测试 用户信息部分 1、获取用户详细信息 前言 承接上一篇博客登录注册功能…

快速批量去除文件夹名称中多余重复文字!一键轻松优化文件夹命名!

您是否曾经因为文件夹名称中多余重复文字而烦恼&#xff1f;是否因为文件夹重命名而浪费大量时间&#xff1f;现在&#xff0c;我们为您推荐一款全新的文件夹批量改名工具——快速批量去除文件夹名称中多余重复文字&#xff0c;轻松实现文件夹改名优化&#xff0c;让您的整理效…

Leetcode_2:两数相加

题目描述&#xff1a; 给你两个 非空 的链表&#xff0c;表示两个非负的整数。它们每位数字都是按照 逆序 的方式存储的&#xff0c;并且每个节点只能存储 一位 数字。 请你将两个数相加&#xff0c;并以相同形式返回一个表示和的链表。 你可以假设除了数字 0 之外&#xff…

106.am40刷机(linux)折腾记2-前期的准备工作2-软件使用

最终的目标是刷入firefly的3399的镜像&#xff0c;同时更新内核到linux5.10版本&#xff08;4.4的内核应该是相同的方法&#xff0c;我目前没有去折腾&#xff0c;暂时不用了&#xff09;。 1. 平台&#xff1a; rk3399 am40 4g32g 2. 内核&#xff1a;暂无 3. 交叉编译工…

数据结构----顺序栈的操作

1.顺序栈的存储结构 typedef int SElemType; typedef int Status; typedef struct{SElemType *top,*base;//定义栈顶和栈底指针int stacksize;//定义栈的容量 }SqStack; 2.初始化栈 Status InitStack(SqStack &S){//初始化一个空栈S.basenew SElemType[MAXSIZE];//为顺序…

macOS文本编辑器 BBEdit 最新 for mac

BBEdit是一款功能强大的文本编辑器&#xff0c;适用于Mac操作系统。它由Bare Bones Software开发&#xff0c;旨在为开发者和写作人员提供专业级的文本编辑工具。 以下是BBEdit的一些主要特点和功能&#xff1a; 多语言支持&#xff1a;BBEdit支持多种编程语言和标记语言&…

jstack java堆栈跟踪工具

jstack java堆栈跟踪工具 1、jstack介绍 jstack&#xff08;stack trace for java&#xff09;是java虚拟机自带的一种堆栈跟踪工具。 jstack主要用于生成java虚拟机当前时刻的线程快照&#xff0c;线程快照是当前java虚拟机内每一条线程正在执行的方法 堆栈的集合&#xf…