Stable Diffusion (version x.x) 文生图模型实践指南

news2024/9/21 11:14:20

前言:本篇博客记录使用Stable Diffusion模型进行推断时借鉴的相关资料和操作流程。

相关博客:
超详细!DALL · E 文生图模型实践指南
DALL·E 2 文生图模型实践指南

目录

  • 1. 环境搭建和预训练模型准备
    • 环境搭建
    • 预训练模型下载
  • 2. 代码


1. 环境搭建和预训练模型准备

环境搭建

pip install diffusers transformers accelerate scipy safetensors

预训练模型下载

关于 huggingface 网站总是崩溃的情况,找到一个解决办法,就是可以通过脚本来下载

第一步:安装 huggingface_hub,使用命令 pip install huggingface_hub
第二步:下载具体模型,使用命令 python model_download.py --repo_id model_id,其中,model_id 为要下载的模型,比如SD v2.1 版本的model_id可以是 stabilityai/stable-diffusion-2-1;SD v1.5 版本的model_id可以是 runwayml/stable-diffusion-v1-5. model_id 的查找方式是在huggingface 网站直接搜索需要的模型(如下图),得到的「模型来源/版本」的组合即为所需。

在这里插入图片描述

model_download.py文件来自这个链接。

# usage     : python model_download.py --repo_id repo_id
# example   : python model_download.py --repo_id facebook/opt-350m
import argparse
import time
import requests
import json
import os
from huggingface_hub import snapshot_download
import platform
from tqdm import tqdm
from urllib.request import urlretrieve


def _log(_repo_id, _type, _msg):
    date1 = time.strftime('%Y-%m-%d %H:%M:%S')
    print(date1 + " " + _repo_id + " " + _type + " :" + _msg)


def _download_model(_repo_id, _repo_type):
    if _repo_type == "model":
        _local_dir = 'dataroot/models/' + _repo_id
    else:
        _local_dir = 'dataroot/datasets/' + _repo_id
    try:
        if _check_Completed(_repo_id, _local_dir):
            return True, "check_Completed ok"
    except Exception as e:
        return False, "check_Complete exception," + str(e)
    _cache_dir = 'caches/' + _repo_id

    _local_dir_use_symlinks = True
    if platform.system().lower() == 'windows':
        _local_dir_use_symlinks = False
    try:
        if _repo_type == "model":
            snapshot_download(repo_id=_repo_id, cache_dir=_cache_dir, local_dir=_local_dir, local_dir_use_symlinks=_local_dir_use_symlinks,
                              resume_download=True, max_workers=4)
        else:
            snapshot_download(repo_id=_repo_id, cache_dir=_cache_dir, local_dir=_local_dir, local_dir_use_symlinks=_local_dir_use_symlinks,
                              resume_download=True, max_workers=4, repo_type="dataset")
    except Exception as e:
        error_msg = str(e)
        if ("401 Client Error" in error_msg):
            return True, error_msg
        else:
            return False, error_msg
    _removeHintFile(_local_dir)
    return True, ""


def _writeHintFile(_local_dir):
    file_path = _local_dir + '/~incomplete.txt'
    if not os.path.exists(file_path):
        if not os.path.exists(_local_dir):
            os.makedirs(_local_dir)
        open(file_path, 'w').close()


def _removeHintFile(_local_dir):
    file_path = _local_dir + '/~incomplete.txt'
    if os.path.exists(file_path):
        os.remove(file_path)


def _check_Completed(_repo_id, _local_dir):
    _writeHintFile(_local_dir)
    url = 'https://huggingface.co/api/models/' + _repo_id
    response = requests.get(url)
    if response.status_code == 200:
        data = json.loads(response.text)
    else:
        return False
    for sibling in data["siblings"]:
        if not os.path.exists(_local_dir + "/" + sibling["rfilename"]):
            return False
    _removeHintFile(_local_dir)
    return True


def download_model_retry(_repo_id, _repo_type):
    i = 0
    flag = False
    msg = ""
    while True:
        flag, msg = _download_model(_repo_id, _repo_type)
        if flag:
            _log(_repo_id, "success", msg)
            break
        else:
            _log(_repo_id, "fail", msg)
            if i > 1440:
                msg = "retry over one day"
                _log(_repo_id, "fail", msg)
                break
            timeout = 60
            time.sleep(timeout)
            i = i + 1
            _log(_repo_id, "retry", str(i))
    return flag, msg


def _fetchFileList(files):
    _files = []
    for file in files:
        if file['type'] == 'dir':
            filesUrl = 'https://e.aliendao.cn/' + file['path'] + '?json=true'
            response = requests.get(filesUrl)
            if response.status_code == 200:
                data = json.loads(response.text)
                for file1 in data['data']['files']:
                    if file1['type'] == 'dir':
                        filesUrl = 'https://e.aliendao.cn/' + \
                            file1['path'] + '?json=true'
                        response = requests.get(filesUrl)
                        if response.status_code == 200:
                            data = json.loads(response.text)
                            for file2 in data['data']['files']:
                                _files.append(file2)
                    else:
                        _files.append(file1)
        else:
            if file['name'] != '.gitattributes':
                _files.append(file)
    return _files


def _download_file_resumable(url, save_path, i, j, chunk_size=1024*1024):
    headers = {}
    r = requests.get(url, headers=headers, stream=True, timeout=(20, 60))
    if r.status_code == 403:
        _log(url, "download", '下载资源发生了错误,请使用正确的token')
        return False
    bar_format = '{desc}{percentage:3.0f}%|{bar}|{n_fmt}M/{total_fmt}M [{elapsed}<{remaining}, {rate_fmt}]'
    _desc = str(i) + ' of ' + str(j) + '(' + save_path.split('/')[-1] + ')'
    total_length = int(r.headers.get('content-length'))
    if os.path.exists(save_path):
        temp_size = os.path.getsize(save_path)
    else:
        temp_size = 0
    retries = 0
    if temp_size >= total_length:
        return True
    # 小文件显示
    if total_length < chunk_size:
        with open(save_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
        with tqdm(total=1, desc=_desc, unit='MB', bar_format=bar_format) as pbar:
            pbar.update(1)
    else:
        headers['Range'] = f'bytes={temp_size}-{total_length}'
        r = requests.get(url, headers=headers, stream=True,
                         verify=False, timeout=(20, 60))
        data_size = round(total_length / 1024 / 1024)
        with open(save_path, 'ab') as fd:
            fd.seek(temp_size)
            initial = temp_size//chunk_size
            for chunk in tqdm(iterable=r.iter_content(chunk_size=chunk_size), initial=initial, total=data_size, desc=_desc, unit='MB', bar_format=bar_format):
                if chunk:
                    temp_size += len(chunk)
                    fd.write(chunk)
                    fd.flush()
    return True


def _download_model_from_mirror(_repo_id, _repo_type, _token, _e):
    if _repo_type == "model":
        filesUrl = 'https://e.aliendao.cn/models/' + _repo_id + '?json=true'
    else:
        filesUrl = 'https://e.aliendao.cn/datasets/' + _repo_id + '?json=true'
    response = requests.get(filesUrl)
    if response.status_code != 200:
        _log(_repo_id, "mirror", str(response.status_code))
        return False
    data = json.loads(response.text)
    files = data['data']['files']
    for file in files:
        if file['name'] == '~incomplete.txt':
            _log(_repo_id, "mirror", 'downloading')
            return False
    files = _fetchFileList(files)
    i = 1
    for file in files:
        url = 'http://61.133.217.142:20800/download' + file['path']
        if _e:
            url = 'http://61.133.217.139:20800/download' + \
                file['path'] + "?token=" + _token
        file_name = 'dataroot/' + file['path']
        if not os.path.exists(os.path.dirname(file_name)):
            os.makedirs(os.path.dirname(file_name))
        i = i + 1
        if not _download_file_resumable(url, file_name, i, len(files)):
            return False
    return True


def download_model_from_mirror(_repo_id, _repo_type, _token, _e):
    if _download_model_from_mirror(_repo_id, _repo_type, _token, _e):
        return
    else:
        #return download_model_retry(_repo_id, _repo_type)
        _log(_repo_id, "download", '下载资源发生了错误,请使用正确的token')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--repo_id', default=None, type=str, required=True)
    parser.add_argument('--repo_type', default="model",
                        type=str, required=False)  # models,dataset
    # --mirror为从aliendao.cn镜像下载,如果aliendao.cn没有镜像,则会转到hf
    # 默认为True
    parser.add_argument('--mirror', action='store_true',
                        default=True, required=False)
    parser.add_argument('--token', default="", type=str, required=False)
    # --e为企业付费版
    parser.add_argument('--e', action='store_true',
                        default=False, required=False)
    args = parser.parse_args()
    if args.mirror:
        download_model_from_mirror(
            args.repo_id, args.repo_type, args.token, args.e)
    else:
        download_model_retry(args.repo_id, args.repo_type)

2. 代码

Stable Diffusion 完整推断流程如下(from https://huggingface.co/stabilityai/stable-diffusion-2-1):

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

model_id = "/dataroot/models/stabilityai/stable-diffusion-2-1"  # 预训练模型的下载路径

# Use the DPMSolverMultistepScheduler (DPM-Solver++) scheduler here instead
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
    
image.save("astronaut_rides_horse.png")

参考文献

  1. https://aliendao.cn/model_download.py
  2. https://github.com/Stability-AI/stablediffusion

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1212368.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python基础练习题库实验八

文章目录 前言题目1代码 题目2代码 题目3代码 总结 前言 &#x1f388;关于python小题库的这模块我已经两年半左右没有更新了&#xff0c;主要是在实习跟考研&#xff0c;目前已经上岸武汉某211计算机&#xff0c;目前重新学习这门课程&#xff0c;也做了一些新的题目 &#x…

部署LCM(Latent Consistency Models)实现快速出图

LCM&#xff08;Latent Consistency Models&#xff09;可以通过很少的迭代次数就可以生成高清晰度的图片&#xff0c;目前只可以使用一个模型Dreamshaper_v7&#xff0c;基于SD版本Dreamshaper微调而来的。 LCM模型下载&#xff1a; https://huggingface.co/SimianLuo/LCM_D…

使用flutter的Scaffold脚手架开发一个最简单的带tabbar的app模板

flutter自带的scaffold脚手架可以说还是挺好用的&#xff0c;集成了appBar&#xff0c;还有左侧抽屉&#xff0c;还有底部tabbar&#xff0c;可以说拿来就可以用了啊&#xff0c;所以我今天也体验了一下&#xff0c;做了一个最简单的demo&#xff0c;就当是学习记录了。 效果展…

后端接口错误总结

今天后端错误总结&#xff1a; 1.ConditionalOnExpression(“${spring.kafka.exclusive-group.enable:false}”) 这个标签负责加载Bean&#xff0c;因此这个位置必须打开&#xff0c;如果这个标签不打开就会报错 问题解决&#xff1a;这里的配置在application.yml文件中 kaf…

Spring Framework 简介与起源

Spring是用于企业Java应用程序开发的最流行的应用程序开发框架。全球数百万开发人员使用Spring Framework创建高性能、易于测试和可重用的代码。 Spring Framework是一个开源的Java平台。它最初由Rod Johnson编写&#xff0c;并于2003年6月在Apache 2.0许可下首次发布。 Spri…

laravel日期字段carbon 输出格式转换

/*** The attributes that should be cast.** var array*/ protected $casts [created_at > datetime:Y-m-d, ]; 滑动验证页面https://segmentfault.com/q/1010000043327049

PyQt中QFrame窗口中的组件不显示的原因

文章目录 问题代码&#xff08;例&#xff09;原因和解决方法 问题代码&#xff08;例&#xff09; from PyQt5.QtWidgets import * from PyQt5.QtGui import QFont, QIcon, QCursor, QPixmap import sysclass FrameToplevel(QFrame):def __init__(self, parentNone):super().…

Newman

近期在复习Postman的基础知识&#xff0c;在小破站上跟着百里老师系统复习了一遍&#xff0c;也做了一些笔记&#xff0c;希望可以给大家一点点启发。 一&#xff09;如何安装Newman 1、下载并安装NodeJs 在官网下载NodeJs&#xff1a; Download | Node.js&#xff08;官网的…

首周聚焦百度智能云千帆大模型平台使用,《大模型应用实践》实训营11月16日开讲!

百度智能云千帆大模型平台官方出品的《大模型应用实践》实训营本周正式上线&#xff01;这是百度智能云推出的首个系列课程&#xff0c;课程内容满满干货&#xff01; 11月16日本周四即将开课&#xff0c;首周由百度智能云千帆大模型平台产品经理以及百度智能云千帆资深用户知…

DM8数据守护集群安装部署_手动切换

一.安装前准备 1.1 硬件环境建议 数据守护集群安装部署前需要额外注意网络环境和磁盘 IO 配置情况&#xff0c;其他环境配置项建议请参考安装前准备工作。 1.1.1 网络环境 心跳网络对 mal 通讯系统的影响非常大&#xff0c;如果网络丢包或者延迟较大&#xff0c;则会严重影…

Linux C 进程间通信

进程间通信 概述进程间通信方式管道概述管道函数无名管道 pipe有名管道 makefifo删除有名管道 rmove 有名管道实现 双人无序聊天 例子 信号信号概述信号处理过程信号函数传送信号给指定的进程 kill注册信号 signal查询或设置信号处理方式 sigaction设置信号传送闹钟 alarm 有名…

天软特色因子看板 (2023.11 第10期)

该因子看板跟踪天软特色因子A05006(近一月单笔流入流出金额之比(%)该因子为近一个月单笔流入流出金额之比(%)均值因子&#xff0c;用以刻画 市场日内分时成交中流入、流出成交金额的差异性特点&#xff0c;发掘市场主力资金的作用机制。 今日为该因子跟踪第10期&#xff0c;跟踪…

《红蓝攻防对抗实战》十三.内网穿透之利用HTTP协议进行隧道穿透

内网穿透之利用HTTP协议进行隧道穿透 一.前言二.前文推荐三.利用HTTP协议进行隧道穿透1. Reduh进行端口转发2. ReGeorg进行隧道穿透3. Neo-reGeorg加密隧道穿透4. Tunna进行隧道穿透5 .Abptts加密隧道穿透6. Pivotnacci加密隧道穿透 四.本篇总结 一.前言 本文介绍了利用HTTP协…

python 实验7

姓名&#xff1a;轨迹 学号&#xff1a;6666 专业年级&#xff1a;2021级软件工程 班级&#xff1a; 66 实验的准备阶段 (指导教师填写) 课程名称 Python开发与应用 实验名称 文件异常应用 实验目的 &#xff08;1&#xff09;掌握基本文件读写的方式&#xff1b; …

基于消息队列+多进程编写的银行模拟系统

银行模拟系统 概述客户端 client.c服务端 serve.c开户 enroll.c存款 save.c转账 transfer.c取款 take.cmakefile文件 概述 该案例大体过程为&#xff0c;服务器先启动&#xff0c;初始化消息队列和信号&#xff0c;用多线程技术启动开户、存钱、转账、取钱模块&#xff0c;并且…

Python基础-解释器安装

一、下载 网址Welcome to Python.orgPython更新到13了&#xff0c;我们安装上一个12版本。 这里我保存到网盘里了&#xff0c;不想从官网下的&#xff0c;可以直接从网盘里下载。 链接&#xff1a;百度网盘 请输入提取码百度网盘为您提供文件的网络备份、同步和分享服务。空间…

IDEA从Gitee拉取代码,推送代码教程

打开IDEA&#xff0c;选择Get from Version Control 输入Gitee 仓库项目的URL地址 URL地址输入后点击Clone&#xff0c;即拉取成功 向Gitee提交推送代码 右键选中项目&#xff0c;选中Git 第一步先点击 Add 第二步 点击Commit填写提交信息&#xff0c;点击Commit就会出现下面…

Jenkins的一些其他操作

Jenkins的一些其他操作 1、代码仓库Gogs的搭建与配置 Gogs 是一款极易搭建的自助 Git 服务&#xff0c;它的目标在于打造一个最简单、快速和轻松的方式搭建 Git 服务。使用 Go 语言开发的它能够通过独立的二进制进行分发&#xff0c;支持了 Go 语言支持的所有平台&#xff0…

Vue修饰符(Vue事件修饰符、Vue按键修饰符)

目录 前言 Vue事件修饰符 列举较常用的事件修饰符 .stop .prevent .capture .once Vue按键修饰符 四个特殊键 获取某个键的按键修饰符 前言 本文介绍Vue修饰符&#xff0c;包括Vue事件修饰符以及按键修饰符 Vue事件修饰符 列举较常用的事件修饰符 .stop: …

【Linux基础IO篇】深入理解文件系统、动静态库

【Linux基础IO篇】深入理解文件系统、动静态库 目录 【Linux基础IO篇】深入理解文件系统、动静态库再次理解文件系统操作系统内存管理模块&#xff08;基础&#xff09;操作系统如何管理内存 Linux中task_struct源码结构 动态库和静态库动静态库介绍&#xff1a;生成静态库库搜…