扩散模型实战(七):Diffusers蝴蝶图像生成实战

news2024/11/24 5:03:26

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

扩散模型实战(四):从零构建扩散模型

扩散模型实战(五):采样过程

扩散模型实战(六):Diffusers DDPM初探

       在本文中,我们以生成绚丽多彩的蝴蝶图像为例,学习Diffusers库相关知识,并学会训练自己的扩散模型。

一、环境准备工作

1.1 安装Diffusers库

%pip install -qq -U diffusers datasets transformers accelerate ftfy pyarrow==9.0.0

1.2 登录huggingface(optional)

      我们计划把训练好的模型上传到huggingface中,因此我们需要首先登录huggingface,可以通过访问https://huggingface.co/settings/tokens获取huggingface的token。

复制上图的token,然后执行下面的代码,并粘贴到执行如下代码的文本框中:

from huggingface_hub import notebook_loginnotebook_login()

1.3 安装Git LFS以上传模型检查点,代码如下:

%%capture!sudo apt -qq install git-lfs!git config --global credential.helper store

1.4 定义所需的函数

import numpy as npimport torchimport torch.nn.functional as Ffrom matplotlib import pyplot as pltfrom PIL import Image def show_images(x):    """给定一批图像,创建一个网格并将其转换为PIL"""    x = x * 0.5 + 0.5  # 将(-1,1)区间映射回(0,1)区间    grid = torchvision.utils.make_grid(x)    grid_im = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255    grid_im = Image.fromarray(np.array(grid_im).astype(np.uint8))    return grid_im def make_grid(images, size=64): """给定一个PIL图像列表,将它们叠加成一行以便查看"""    output_im = Image.new("RGB", (size * len(images), size))    for i, im in enumerate(images):        output_im.paste(im.resize((size, size)), (i * size, 0))    return output_im # 对于Mac,可能需要设置成device = 'mps'(未经测试)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DreamBooth

       DreamBooth可以对Stable Diffusion进行微调,并在整个过程中引入特定的面部、物体或者风格等额外信息。我们可以初步体验一下Corridor Crew使用DreamBooth制作的视频(https://www.bilibili.com/video/BV18o4y1c7R7/?vd_source=c5a5204620e35330e6145843f4df6ea4),目前模型以及集成到Huggingface上,下面是加载的代码:

from diffusers import StableDiffusionPipeline # https://huggingface.co/sd-dreambooth-library ,这里有来自社区的各种模型model_id = "sd-dreambooth-library/mr-potato-head" # 加载管线pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_   dtype=torch.float16). to(device)
prompt = "an abstract oil painting of sks mr potato head by picasso"image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).    images[0]image

生成的图像如下图所示:

num_inference_steps:表示采样步骤的数量;

guidance_scale:决定模型的输出与Prompt之间的匹配程度;

三、Diffusers核心API

  • Pipeline(管道):高级API,便于部署;
  • Model(模型):定义训练扩散模型时需要的网络结构,比如UNet模型
  • Scheduler(调度器):在推理过程中使用多种不同的技巧来从噪声中生成图像,同时也可以生成训练过程中所需的”带噪“图像;

四、使用Diffusers生成蝴蝶图像案例

4.1 下载蝴蝶数据集

import torchvisionfrom datasets import load_datasetfrom torchvision import transforms dataset = load_dataset("huggan/smithsonian_butterflies_subset",     split="train") # 也可以从本地文件夹中加载图像# dataset = load_dataset("imagefolder", data_dir="path/to/folder") # 我们将在32×32像素的正方形图像上进行训练,但你也可以尝试更大尺寸的图像image_size = 32# 如果GPU内存不足,你可以减小batch_sizebatch_size = 64 # 定义数据增强过程preprocess = transforms.Compose(    [        transforms.Resize((image_size, image_size)),  # 调整大小        transforms.RandomHorizontalFlip(),            # 随机翻转        transforms.ToTensor(),              # 将张量映射到(0,1)区间        transforms.Normalize([0.5], [0.5]), # 映射到(-1, 1)区间    ]) def transform(examples):    images = [preprocess(image.convert("RGB")) for image in        examples["image"]]    return {"images": images} dataset.set_transform(transform) # 创建一个数据加载器,用于批量提供经过变换的图像train_dataloader = torch.utils.data.DataLoader(    dataset, batch_size=batch_size, shuffle=True)

可视化其中部分数据集

xb = next(iter(train_dataloader))["images"].to(device)[:8]print("X shape:", xb.shape)show_images(xb).resize((8 * 64, 64), resample=Image.NEAREST)
# 输出X shape: torch.Size([8, 3, 32, 32])

4.2 定义扩散模型的调度器

       在训练扩散模型和使用扩散模型进行推理时都可以由调度器(scheduler)来完成,噪声调度器能够确定在不同迭代周期分别添加多少噪声,通常可以使用如下两种方式来添加噪声,代码如下:

# 仅添加了少量噪声# 方法一:# noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_# start=0.001, beta_end=0.004)# 'cosine'调度方式,这种方式可能更适合尺寸较小的图像# 方法二:# noise_scheduler = DDPMScheduler(num_train_timesteps=1000,# beta_schedule='squaredcos_cap_v2')

参数说明:

beta_start:控制推理阶段开始时beta的值;

beta_end:控制推理阶段结束时beta的值;

beta_schedule:可以通过一个函数映射来为模型推理的每一步生成一个beta值

      无论选择哪个调度器,我们都可以通过noise_scheduler.add_noise为图片添加不同程度的噪声,代码如下:

timesteps = torch.linspace(0, 999, 8).long().to(device)noise = torch.randn_like(xb)noisy_xb = noise_scheduler.add_noise(xb, noise, timesteps)print("Noisy X shape", noisy_xb.shape)show_images(noisy_xb).resize((8 * 64, 64), resample=Image.NEAREST)
# 输出Noisy X shape torch.Size([8, 3, 32, 32])

4.3 定义扩散模型

       下面是Diffusers的核心概念-模型介绍,本文采用UNet,模型结构如图所示:

     下面代码中down_block_types对应下采样模型(绿色部分),up_block_types对应上采样模型(粉色部分)

from diffusers import UNet2DModel # 创建模型model = UNet2DModel(    sample_size=image_size,   # 目标图像分辨率    in_channels=3,            # 输入通道数,对于RGB图像来说,通道数为3     out_channels=3,           # 输出通道数    layers_per_block=2,       # 每个UNet块使用的ResNet层数    block_out_channels=(64, 128, 128, 256), # 更多的通道→更多的参数    down_block_types=(        "DownBlock2D",        # 一个常规的ResNet下采样模块        "DownBlock2D",        "AttnDownBlock2D",    # 一个带有空间自注意力的ResNet下采样模块        "AttnDownBlock2D",    ),    up_block_types=(        "AttnUpBlock2D",        "AttnUpBlock2D",      # 一个带有空间自注意力的ResNet上采样模块        "UpBlock2D",        "UpBlock2D",          # 一个常规的ResNet上采样模块    ),)model.to(device);

4.4 训练扩散模型

# 设定噪声调度器noise_scheduler = DDPMScheduler(    num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2") # 训练循环optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4) losses = [] for epoch in range(30):    for step, batch in enumerate(train_dataloader):        clean_images = batch["images"].to(device)        # 为图片添加采样噪声        noise = torch.randn(clean_images.shape).to(clean_images.           device)        bs = clean_images.shape[0]         # 为每张图片随机采样一个时间步        timesteps = torch.randint(            0, noise_scheduler.num_train_timesteps, (bs,),             device=clean_images.device        ).long()         # 根据每个时间步的噪声幅度,向清晰的图片中添加噪声        noisy_images = noise_scheduler.add_noise(clean_images,            noise, timesteps)         # 获得模型的预测结果        noise_pred = model(noisy_images, timesteps, return_           dict=False)[0]         # 计算损失        loss = F.mse_loss(noise_pred, noise)        loss.backward(loss)        losses.append(loss.item())         # 迭代模型参数        optimizer.step()        optimizer.zero_grad()    if (epoch + 1) % 5 == 0:        loss_last_epoch = sum(losses[-len(train_dataloader) :]) /  len(train_dataloader)        print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

我们绘制一下训练过程中损失变化:

fig, axs = plt.subplots(1, 2, figsize=(12, 4))axs[0].plot(losses)axs[1].plot(np.log(losses))plt.show()

4.5 生成图像

我们使用如下两种方法来生成图像。

方法一:建立一个Pipeline

from diffusers import DDPMPipeline image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)pipeline_output = image_pipe()pipeline_output.images[0]

保存Pipeline到本地文件夹

image_pipe.save_pretrained("my_pipeline")

我们查看一下my_pipeline文件夹里面保存了什么?

!ls my_pipeline/# 输出model_index.json  scheduler  unet

       scheduler和unet两个子文件夹包含了生成图像所需的全部组件,其中unet子文件夹包含了描述模型结构的配置文件config.json和模型参数文件diffusion_pytorch_model.bin

!ls my_pipeline/unet/# 输出config.json  diffusion_pytorch_model.bin

      我们只需要将 scheduler和unet两个子文件上传到Huggingface Hub中,就可以实现模型共享。

方法二:加入时间t进行循环采样

        我们按照不同时间步t进行逐步采样,我们看一下生成的效果,代码如下:

# 随机初始化(8张随机图片)sample = torch.randn(8, 3, 32, 32).to(device) for i, t in enumerate(noise_scheduler.timesteps):     # 获得模型的预测结果    with torch.no_grad():        residual = model(sample, t).sample     # 根据预测结果更新图像    sample = noise_scheduler.step(residual, t, sample).prev_sample show_images(sample)

4.6 上传模型到Huggingface Hub

我们定义上传模型的名称,名称会包含用户名,代码如下:

from huggingface_hub import get_full_repo_name model_name = "sd-class-butterflies-32"hub_model_id = get_full_repo_name(model_name)hub_model_id
# 输出Arron/sd-class-butterflies-32

在Huggingface Hub上创建一个模型仓库并将其上传,代码如下:

from huggingface_hub import HfApi, create_repo create_repo(hub_model_id)api = HfApi()api.upload_folder(    folder_path="my_pipeline/scheduler", path_in_repo="",     repo_id=hub_model_id)api.upload_folder(folder_path="my_pipeline/unet", path_in_repo="",     repo_id=hub_model_id)api.upload_file(    path_or_fileobj="my_pipeline/model_index.json",    path_in_repo="model_index.json",    repo_id=hub_model_id,)
# 输出https://huggingface.co/Arron/sd-class-butterflies-32/blob/main/model_index.json

创建一个模型卡片以便描述模型的细节,代码如下:

from huggingface_hub import ModelCard content = f"""---license: mittags:- pytorch- diffusers- unconditional-image-generation- diffusion-models-class---# 这个模型用于生成蝴蝶图像的无条件图像生成扩散模型 '''pythonfrom diffusers import DDPMPipeline pipeline = DDPMPipeline.from_pretrained('{hub_model_id}')image = pipeline().images[0]image"""card = ModelCard(content) card.push_to_hub(hub_model_id)
# 输出https://huggingface.co/Arron/sd-class-butterflies-32/blob/main/README.md

       至此,我们已经把自己训练好的模型上传到Huggingface Hub了,下面就可以使用DDPMPipeline的from_pretrained方法来加载模型了:

from diffusers import DDPMPipeline image_pipe = DDPMPipeline.from_pretrained(hub_model_id)pipeline_output = image_pipe()pipeline_output.images[0]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/942113.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ASUS华硕天选4笔记本电脑FA507XV原厂Windows11系统22H2

天选四FA507X原装系统自带所有驱动、出厂主题壁纸LOGO、Office办公软件 华硕电脑管家、奥创控制中心等预装程序,恢复出厂状态W11 链接:https://pan.baidu.com/s/1SPoFW7wR5KawGu-yMckNzg?pwdayxd 提取码:ayxd

听会议整理的几个问题整理

听会议整理的几个问题整理 AR与NAR正样本和负样本数据蒸馏 AR与NAR 正样本和负样本 对于ar 和nar ar虽然从一个人来看是串行的,但是可以对用户进行并行 nar对于一个人的需求是并行的。但是对于多用户是无法并行的 所以ar并非一定效率低 数据蒸馏

CUDA小白 - NPP(2) -图像处理-算数和逻辑操作

cuda小白 原文链接 NPP GPU架构近些年也有不少的变化,具体的可以参考别的博主的介绍,都比较详细。还有一些cuda中的专有名词的含义,可以参考《详解CUDA的Context、Stream、Warp、SM、SP、Kernel、Block、Grid》 常见的NppStatus&#xff0c…

手机无人直播软件有哪些,又有哪些优势?

如今,随着智能手机的普及和移动互联网的发展,手机无人直播成为了一个炙手可热的领域。手机无人直播软件为用户提供了便捷、灵活的直播方式,让更多商家人能够实现自己的直播带货的梦想。接下来,我们将探讨手机无人直播软件有哪些&a…

解密算法与数据结构面试:程序员如何应对挑战

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

el-upload调用内部方法删除文件

从Element UI 的官方文档中, Upload 上传组组件提供了on-remove和before-remove的文件删除的钩子属性(回调方法名),但如何调用组件删除方法(让该方法删除本地上传文件列表以及触发这两个钩子)并无相关说明。…

解锁安全高效办公——私有化部署的WorkPlus即时通讯软件

在当今信息时代,高效的沟通与协作对于企业的成功至关重要。然而,随着信息技术的发展,保护敏感信息和数据安全也变得越来越重要。为了满足企业对于安全沟通和高效办公的需求,我们隆重推出私有化部署的WorkPlus即时通讯软件&#xf…

Marin说PCB之如何使用CAM350做Gerber compare ?

最近小编在追一部东北武侠喜剧(鹊刀门传奇),大部分人员都是乡村爱情的人员演的,这部剧真的是超级搞笑,小编我以人格担保要是不搞笑的话,你来找我。 正当小编我周日在家里追剧的时候,手机上弹出了…

Git仓库简介

1、工作区、暂存区、仓库 工作区:电脑里能看到的目录。 暂存区:工作区有一个隐藏目录.git,是Git的版本库,Git的版本库里存了很多东西,其中最重要的就是称为stage(或者叫index)的暂存区&#xf…

【技术】SpringBoot Word 模板替换

SpringBoot Word 模板替换 什么是 Word 模板替换如何实现 Word 模板替换 什么是 Word 模板替换 模板一般是具有固定格式的内容,其中一部分需要替换。Word 模板通俗的讲是以 Word 的形式制作模板,固定格式和内容,然后将其中的一部分数据替换掉…

第 1 章 绪论 (三元组)

1. 示例代码: 1)status.h /* DataStructure 预定义常量和类型头文件 */#ifndef STATUS_H #define STATUS_H/* 函数结果状态码 */ #define TRUE 1 /* 返回值为真 */ #define FALSE 0 /* 返回值为假 */ #define RET_OK 0 /* 返回值…

电脑不安装软件,怎么将手机文件传输到电脑?

很多人都知道,AirDroid有网页版(web.airdroid.com)。 想要文件传输,却不想在电脑安装软件时,AirDroid的网页版其实也可以传输文件。 然而,要将文件从手机传输文件到网页端所在的电脑时,如果按…

Vue05_关于插槽和指令封装的运用

Vue_05 文章目录 Vue_05Vue 插槽01-插槽-默认插槽默认插槽-基本语法 02-插槽-后备内容(默认值)默认值设置方法 03-插槽-具名插槽具名插槽-语法 04-插槽-作用域插槽默认插槽-语法代码示例 Vue自定义指令- v-loading封装01-自定义指令自定义指令的两种注册…

【算法专题突破】双指针 - 盛最多水的容器(4)

目录 1. 题目解析 2. 算法原理 3. 代码编写 写在最后: 1. 题目解析 题目链接:11. 盛最多水的容器 - 力扣(Leetcode) 这道题目也不难理解, 两边的柱子的盛水量是根据短的那边的柱子决定的, 而盛水量…

如何清空小程序会员卡的电子票

​电子票不仅方便了用户的购票和消费,还提升了用户的购物体验和忠诚度。然而,在一些特殊情况下,可能需要手动清空会员的电子票。那么,下面我们就来探讨一下在小程序中如何手动清空会员的电子票。 1. 找到指定的会员卡。在管理员后…

15.CSS发光按钮的悬停特效

效果 源码 <!DOCTYPE html> <html> <head><title>CSS Modern Button</title><link rel="stylesheet" type="text/css" href="style.css"> </head> <body><a href="#" style=&quo…

Flink CDC学习笔记

第一章 CDC简介 1.1 什么是CDC ​ CDC (Change Data Capture 变更数据获取&#xff09;的简称。核心思想就是&#xff0c;检测并获取数据库的变动&#xff08;增删查改&#xff09;&#xff0c;将这些变更按发生的顺序记录下来&#xff0c;写入到消息中间件以供其它服务进行订…

公网远程访问局域网SQL Server数据库

文章目录 1.前言2.本地安装和设置SQL Server2.1 SQL Server下载2.2 SQL Server本地连接测试2.3 Cpolar内网穿透的下载和安装2.3 Cpolar内网穿透的注册 3.本地网页发布3.1 Cpolar云端设置3.2 Cpolar本地设置 4.公网访问测试5.结语 1.前言 数据库的重要性相信大家都有所了解&…

Deep Learning With Pytorch - 数据预处理,以导入LUNA16数据集为例

文章目录 数据集简介什么是CT扫描&#xff1f;导入大型数据集并不是一份轻松的工作 在Jupyter Notebook中导入LUNA16数据集导入可能用到的第三方库&#xff1a;LUNA16存放路径&#xff1a;用 pandas 读取 candidates.csv&#xff1b;读取 annotations.csv导入subset0和subset1的…

[FPGA IP系列] BRAM IP参数配置与使用示例

FPGA开发中使用频率非常高的两个IP就是FIFO和BRAM&#xff0c;上一篇文章中已经详细介绍了Vivado FIFO IP&#xff0c;今天我们来聊一聊BRAM IP。 本文将详细介绍Vivado中BRAM IP的配置方式和使用技巧。 一、BRAM IP核的配置 1、打开BRAM IP核 在Vivado的IP Catalog中找到B…