【AI 绘画】模型转换与快速生图(基于diffusers)

news2024/12/27 15:41:11

AI 绘画- 模型转换与快速生图(基于diffusers)

1. 本章介绍

本次主要展示一下不同框架内文生图模型转换,以及快速生成图片的方法。

SDXL文生图

2. sdxl_lightning基本原理

模型基本原理介绍如下

利用蒸馏方法获取小参数模型。首先,论文从128步直接蒸馏到32步,并使用MSE损失。在早期阶段,论文发现MSE已足够。此外,在此阶段,论文仅应用了无分类器指导(CFG),并使用了6的指导尺度,而没有使用任何负面提示。

接着,论文通过对抗性损失按照以下顺序进一步减少步数:32 → 8 → 4 → 2 → 1。在每个阶段,论文首先使用条件目标进行训练,以保持概率流动,然后使用无条件目标进行训练,以放松模式覆盖。

在每个阶段,论文首先使用LoRA并结合这两个目标进行训练,然后合并LoRA,并进一步用无条件目标训练整个UNet。论文发现微调整个UNet可以获得更好的性能,而LoRA模块可以用于其他基础模型。论文的LoRA设置与LCM-LoRA 相同,即在所有卷积和线性权重上使用64的秩,但不包括输入和输出卷积以及共享的时间嵌入线性层。论文没有在判别器上使用LoRA,并且在每个阶段都会重新初始化判别器。

3. 环境安装

diffusers是Hugging Face推出的一个diffusion库,它提供了简单方便的diffusion推理训练pipe,同时拥有一个模型和数据社区,代码可以像torchhub一样直接从指定的仓库去调用别人上传的数据集和pretrain checkpoint。除此之外,安装方便,代码结构清晰,注释齐全,二次开发会十分有效率。

# pip
pip install --upgrade diffusers[torch]
# conda
conda install -c conda-forge diffusers

4. 代码实现

主要测试代码:

4.1 sdxl_lightning文生图


from diffusers import DPMSolverMultistepScheduler,UNet2DConditionModel,StableDiffusionXLPipeline,DiffusionPipeline
import torch
from safetensors.torch import load_file

device = "cuda"

# load both base & refiner
# stabilityai/stable-diffusion-xl-base-1.0
# base = DiffusionPipeline.from_pretrained(
#     "./data/data282269/",device_map=None,torch_dtype=torch.float16, variant="fp16", use_safetensors=True
# )
# !unzip  ./data/data283423/SDXL.zip -D ./data/data283423/
# load base model
unet = UNet2DConditionModel.from_config("./data/data283423/SDXL/unet/config.json").to( device, torch.float16)
unet.load_state_dict(load_file("./data/data283423/sdxl_lightning_4step_unet.safetensors", device= device))


base = StableDiffusionXLPipeline.from_pretrained(
    "./data/data283423/SDXL/", unet=unet, torch_dtype=torch.float16, variant="fp16"
).to( device)

# # scheduler
# base.scheduler = DPMSolverMultistepScheduler.from_config(
#     base.scheduler.config, timestep_spacing="trailing"
# )

base.to("cuda")


# Define how many steps and what % of steps to be run on each experts (80/20) here
n_steps = 4
high_noise_frac = 0.8

prompt = "masterpiece, best quality,Realistic, cinematic quality,A majestic lion jumping from a big stone at night "#"A majestic lion jumping from a big stone at night"
negative_prompt = ('flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,')
# run both experts
image = base(
    prompt=prompt,
    negative_prompt = negative_prompt,
    num_inference_steps=n_steps,
  #  denoising_end=high_noise_frac,
    #output_type="latent",
).images[0]

image.save("./data/section-1/h5.png")

4.2 safetensors模型加载

如果想将safetensors模型加载到diffusers中,需要使用如下代码


pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)

转换为


from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

base = download_from_original_stable_diffusion_ckpt(from_safetensors = True,
   checkpoint_path_or_dict = "./data/data282269/SDXL_doll.safetensors"
)

4.2 safetensors模型转换

如果想将safetensors模型转化为diffusers常用格式,需要使用如下代码


"""Conversion script for the LDM checkpoints."""

import argparse
import importlib

import torch

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--checkpoint_path", default="./data/data282269/SDXL_doll.safetensors", type=str, help="Path to the checkpoint to convert." #required=True, 
    )
    # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml
    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        help="The YAML config file corresponding to the original architecture.",
    )
    parser.add_argument(
        "--config_files",
        default=None,
        type=str,
        help="The YAML config file corresponding to the architecture.",
    )
    parser.add_argument(
        "--num_in_channels",
        default=None,
        type=int,
        help="The number of input channels. If `None` number of input channels will be automatically inferred.",
    )
    parser.add_argument(
        "--scheduler_type",
        default="pndm",
        type=str,
        help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']",
    )
    parser.add_argument(
        "--pipeline_type",
        default=None,
        type=str,
        help=(
            "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'"
            ". If `None` pipeline will be automatically inferred."
        ),
    )
    parser.add_argument(
        "--image_size",
        default=None,
        type=int,
        help=(
            "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
            " Base. Use 768 for Stable Diffusion v2."
        ),
    )
    parser.add_argument(
        "--prediction_type",
        default=None,
        type=str,
        help=(
            "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
            " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2."
        ),
    )
    parser.add_argument(
        "--extract_ema",
        action="store_true",
        help=(
            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
        ),
    )
    parser.add_argument(
        "--upcast_attention",
        action="store_true",
        help=(
            "Whether the attention computation should always be upcasted. This is necessary when running stable"
            " diffusion 2.1."
        ),
    )
    parser.add_argument(
        "--from_safetensors",
        default= "true",
      #  action="store_true",
        help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.",
    )
    parser.add_argument(
        "--to_safetensors",
        action="store_true",
        help="Whether to store pipeline in safetensors format or not.",
    )
    parser.add_argument("--dump_path", default="./data/data282269/", type=str,  help="Path to the output model.")
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
    parser.add_argument(
        "--stable_unclip",
        type=str,
        default=None,
        required=False,
        help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.",
    )
    parser.add_argument(
        "--stable_unclip_prior",
        type=str,
        default=None,
        required=False,
        help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.",
    )
    parser.add_argument(
        "--clip_stats_path",
        type=str,
        help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.",
        required=False,
    )
    parser.add_argument(
        "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint."
    )
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
    parser.add_argument(
        "--vae_path",
        type=str,
        default=None,
        required=False,
        help="Set to a path, hub id to an already converted vae to not convert it again.",
    )
    parser.add_argument(
        "--pipeline_class_name",
        type=str,
        default=None,
        required=False,
        help="Specify the pipeline class name",
    )

    args = parser.parse_args()

    if args.pipeline_class_name is not None:
        library = importlib.import_module("diffusers")
        class_obj = getattr(library, args.pipeline_class_name)
        pipeline_class = class_obj
    else:
        pipeline_class = None

    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=args.checkpoint_path,
        original_config_file=args.original_config_file,
        config_files=args.config_files,
        image_size=args.image_size,
        prediction_type=args.prediction_type,
        model_type=args.pipeline_type,
        extract_ema=args.extract_ema,
        scheduler_type=args.scheduler_type,
        num_in_channels=args.num_in_channels,
        upcast_attention=args.upcast_attention,
        from_safetensors=args.from_safetensors,
        device=args.device,
        stable_unclip=args.stable_unclip,
        stable_unclip_prior=args.stable_unclip_prior,
        clip_stats_path=args.clip_stats_path,
        controlnet=args.controlnet,
        vae_path=args.vae_path,
        pipeline_class=pipeline_class,
    )

    if args.half:
        pipe.to(dtype=torch.float16)

    if args.controlnet:
        # only save the controlnet model
        pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
    else:
        pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

5. 资源链接

https://www.liblib.art/modelinfo/8345679083144158adb64b80c58e3afd

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2054077.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

三、前后端分离通用权限系统(3)

🌻🌻目录 一、角色管理1.1、测试 controller 层1.2、整合 Swagger21.2.1、Swagger 介绍1.2.2、集成 knife4j1.2.2.1 添加依赖1.2.2.2 添加 knife4j 配置类1.2.2.3 Controller 层添加注解1.2.2.4、测试 1.3、定义统一返回结果对象1.3.1、定义统一返回结果…

备战秋招60天算法挑战,Day21

题目链接: https://leetcode.cn/problems/number-of-1-bits/ 视频题解: https://www.bilibili.com/video/BV1ir421M7XU/ LeetCode 191.位1的个数 题目描述 编写一个函数,输入是一个无符号整数 (以二进制串的形式)&am…

C语言 ——— 学习并使用calloc和realloc函数

目录 calloc函数的功能 学习并使用calloc函数​编辑 realloc函数的功能 学习并使用realloc函数​编辑 calloc函数的功能 calloc函数的功能和malloc函数的功能类似,于malloc函数的区别只在于calloc函数会再返回地址之前把申请的空间的每个字节初始化为全0 C语言…

tweens运动详解

linear 线性匀速运动效果Sine.easeIn 正弦曲线的缓动(sin(t))/ 从0开始加速的缓动,也就是先慢后快Sine.easeOut 正弦曲线的缓动(sin(t))/ 减速到0的缓动,也就是先快后慢Sine.easeInOut 正弦曲线的缓动(sin(t))/ 前半段从0开始加速,后半段减速到0的缓动Quad.easeIn 二次…

c语言基础-------指针变量和变量指针

在 C 语言中,“变量指针”和“指针变量”这两个术语虽然经常交替使用,但它们的侧重点有所不同。 指针变量 “指针变量”是指其值为内存地址的变量。指针变量的类型定义了它所指向的数据类型,例如 int * 是一个指向整型数据的指针变量。 以下是一个指针变量的例子: int v…

数据埋点系列 16| 数据可视化高级技巧:从洞察到视觉故事

数据可视化是将复杂数据转化为直观、易懂的视觉表现的艺术和科学。本文将探讨一些高级的数据可视化技巧,帮助您创建更具吸引力和洞察力的数据展示。 目录 1. 高级图表类型1.1 桑基图(Sankey Diagram)1.2 树状图(Treemap&#xf…

3、目标定位(视觉测距)

目标定位的目的:获取物品相对于视觉模块的三维坐标,并将其转换为物品相对于机械臂坐标原点的三维坐标。 要获取物品三维坐标,则首先要测量物品距离摄像头的距离,又因为摄像头安装在机械臂末端上方,所以获取物品相对于摄…

基于springboot的高校学生服务平台的设计与实现--附源码91686

目录 1 绪论 1.1 选题背景与意义 1.2国内外研究现状 1.3论文结构与章节安排 2系统分析 2.1 可行性分析 2.2 系统流程分析 2.2.1系统开发流程 2.2.2 用户登录流程 2.2.3 系统操作流程 2.2.4 添加信息流程 2.2.5 修改信息流程 2.2.6 删除信息流程 2.3 系统功能分析 …

代码随想录算法训练营第二十天| 235. 二叉搜索树的最近公共祖先 701.二叉搜索树中的插入操作 450.删除二叉搜索树中的节点

目录 一、LeetCode 235. 二叉搜索树的最近公共祖先思路:C代码 二、LeetCode 701.二叉搜索树中的插入操作思路C代码 三、LeetCode 450.删除二叉搜索树中的节点思路C代码 总结 一、LeetCode 235. 二叉搜索树的最近公共祖先 题目链接:LeetCode 235. 二叉搜…

C语言:for、while、do-while循环语句

目录 前言 一、while循环 1.1 while语句的执行流程 1.2 while循环的实践 1.3 while循环中的break和continue 1.3.1 break 1.3.2 continue 二、for循环 2.1 语法形式 2.2 for循环的执行流程 2.3 for循环的实践 2.4 for循环中的break和continue 2.4.1 break 2.4.2 …

Java数组03:数组边界、数组的使用

本节内容视频链接:https://www.bilibili.com/video/BV12J41137hu?p55&vd_sourceb5775c3a4ea16a5306db9c7c1c1486b5https://www.bilibili.com/video/BV12J41137hu?p55&vd_sourceb5775c3a4ea16a5306db9c7c1c1486b5 1.数组边界 数组下标的合法区间[ 0, Len…

综合监管云平台存在信息泄露漏洞_中科智远综合监管云平台

0x01阅读须知 本文章仅供参考,此文所提供的信息只为网络安全人员对自己所负责的网站、服务器等(包括但不限于)进行检测或维护参考。本文章仅用于信息安全防御技术分享,因用于其他用途而产生不良后果,作者不承担任何法律责任&#…

昇腾 - 快速理解AscendCL(Ascend Computing Language)基础概念的HelloWord

昇腾 - 快速理解AscendCL(Ascend Computing Language)基础概念的HelloWord flyfish AscendCL(Ascend Computing Language)是一套用于在昇腾平台上开发深度神经网络应用的C语言API库,提供运行资源管理、内存管理、模型…

鸿蒙(API 12 Beta3版)【录像流二次处理(C/C++)】媒体相机开发指导

通过ImageReceiver创建录像输出,获取录像流实时数据,以供后续进行图像二次处理,比如应用可以对其添加滤镜算法等。 开发步骤 导入NDK接口,接口中提供了相机相关的属性和方法,导入方法如下。 // 导入NDK接口头文件#in…

ArcGIS Pro基础:软件的常用设置:中文语言、自动保存、默认底图

上图所示,在【选项】(Options)里找到【语言】设置,将语言切换为中文选项,记得在安装软件时,需要提前安装好ArcGIS语言包。 上图所示,在【选项】里找到【编辑】设置,可以更改软件默认…

Java面试八股之如何保证消息队列中消息不重复消费

如何保证消息队列中消息不重复消费 要保证消息队列中的消息不被重复消费,通常需要从以下几个方面来着手: 消息确认机制: 对于像RabbitMQ这样的消息队列系统,可以使用手动确认(manual acknowledge)机制来…

Eureka原理与实践:构建高效的微服务架构

Eureka原理与实践:构建高效的微服务架构 Eureka的核心原理Eureka Server:服务注册中心Eureka Client:服务提供者与服务消费者 Eureka的实践应用集成Eureka到Spring Cloud项目中创建Eureka Server创建Eureka Client(服务提供者&…

ISA95 企业控制集成标准

ANSI/ISA-95 企业控制系统集成介绍及其全系列最新标准下载(转)https://www.cnblogs.com/TonyJia/p/17616347.html ANSI 1. 综述 ISA-95 简称S95,也有称作SP95。ISA-95 是企业系统与控制系统集成国际标准,由国际自动化学会(ISA…

react最好用的swiper插件和拖动插件 react-tiny-slider react-draggable

react移动端项目,其实有挺多的ui框架的,但是我们公司的项目,都是自己封装的ui库,又不可能为了一个轮播图就去再安装一个ui库 所以找了很多的轮播插件,都是不能满足需求 最后找到了它,react-tiny-slider&…

1 什么是linux驱动

1 目录 1 一、什么是linux驱动? 1、驱动的作用 2、 3、驱动的分类 4、linux源码 5、最简单的linux驱动 二、如何编译驱动程序 -- 有两种编译方法: -- 什么是Linux内核模块? -- Linux内核模块的编译 一、什么是linux驱动&#xff…