扩散模型实战(八):微调扩散模型

news2024/11/17 19:42:00

推荐阅读列表:

扩散模型实战(一):基本原理介绍

扩散模型实战(二):扩散模型的发展

扩散模型实战(三):扩散模型的应用

扩散模型实战(四):从零构建扩散模型

扩散模型实战(五):采样过程

扩散模型实战(六):Diffusers DDPM初探

扩散模型实战(七):Diffusers蝴蝶图像生成实战

       微调在LLM中并不是新鲜的概念,从头开始训练一个扩散模型需要很长的时间,特别是使用高分辨率图像训练。那么其实我们可以在已经训练好的”去噪“扩散模型基础上使用微调数据集进行二次微调训练。

       本文将介绍基于蝴蝶数据集上微调人脸生成的扩散模型:

一、环境准备

1.1 安装相关库

!pip install -qq diffusers datasets accelerate wandb open-clip-torch

1.2 登录Huggingface Hub

如果需要开源微调好的模型到Huggingface Hub上,那么需要使用如下代码登录,否则可忽略此步骤:

from huggingface_hub import notebook_loginnotebook_login()

1.3 导入相关库

import numpy as npimport torchimport torch.nn.functional as Fimport torchvisionfrom datasets import load_datasetfrom diffusers import DDIMScheduler, DDPMPipelinefrom matplotlib import pyplot as pltfrom PIL import Imagefrom torchvision import transformsfrom tqdm.auto import tqdmdevice = (    "mps"    if torch.backends.mps.is_available()    else "cuda"    if torch.cuda.is_available()    else "cpu")

二、导入预训练的扩散模型

下面我们导入人脸生成的扩散模型,观察一下生成的效果,代码如下:

image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")image_pipe.to(device);

查看生成的图像,代码如下:

images = image_pipe().imagesimages[0]

       生成的效果虽然不错,但是速度稍微有点慢,其实有更快的采样器可以加速这一过程,比如下面介绍的DDIM

三、DDIM-更快的采样器

       在生成图像的每一步中,模型都会接收一个带有噪声的输入,并且需要预测这个噪声,以此来估计没有噪声的完整图像是什么。这个过程被称为采样过程,在Diffusers库中,采样通过调度器控制的,之前的文章中介绍过DDPMScheduler调度器,本文介绍的DDIMScheduler可以通过更少的迭代周期来产生很好的采样样本(1000多步采样不是必须的)。

# 创建一个新的调度器并设置推理迭代次数scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")scheduler.set_timesteps(num_inference_steps=40)
scheduler.timesteps
# 输出tensor([975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725,         700, 675, 650, 625, 600, 575, 550, 525, 500, 475, 450, 425,         400, 375, 350, 325, 300, 275, 250, 225, 200, 175, 150,         125, 100,  75,  50,  25, 0])

       下面使用4幅随机噪声图像进行循环采样,并观察每一步的输入与输出的”去噪“图像,代码如下:

# 从随机噪声开始x = torch.randn(4, 3, 256, 256).to(device) # batch size为4,三通道,长、宽均为256像素的一组图像# 循环一整套时间步for i, t in tqdm(enumerate(scheduler.timesteps)):     # 准备模型输入:给“带躁”图像加上时间步信息    model_input = scheduler.scale_model_input(x, t)     # 预测噪声    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]     # 使用调度器计算更新后的样本应该是什么样子    scheduler_output = scheduler.step(noise_pred, t, x)     # 更新输入图像    x = scheduler_output.prev_sample     # 时不时看一下输入图像和预测的“去噪”图像    if i % 10 == 0 or i == len(scheduler.timesteps) - 1:        fig, axs = plt.subplots(1, 2, figsize=(12, 5))         grid = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0)        axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)        axs[0].set_title(f"Current x (step {i})")         pred_x0 = (            scheduler_output.pred_original_sample        )         grid = torchvision.utils.make_grid(pred_x0, nrow=4).           permute(1, 2, 0)        axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)        axs[1].set_title(f"Predicted denoised images (step {i})")        plt.show()

     第二步生成图像的采样器是DDPMScheduler,我们可以使用新的DDIMScheduler来代替DDPMScheduler看看image_pipe生成的效果是否有提升,代码如下:

image_pipe.scheduler = schedulerimages = image_pipe(num_inference_steps=40).imagesimages[0]

       上述介绍了生成人脸的扩散模型以及生成的效果,也介绍了更快的采样器DDIMScheduler,下面我们使用蝴蝶数据集来微调人脸生成扩散模型:

四、微调人脸生成扩散模型

4.1 加载蝴蝶数据集

dataset_name = "huggan/smithsonian_butterflies_subset"dataset = load_dataset(dataset_name, split="train")image_size = 256batch_size = 4preprocess = transforms.Compose(    [        transforms.Resize((image_size, image_size)),        transforms.RandomHorizontalFlip(),        transforms.ToTensor(),        transforms.Normalize([0.5], [0.5]),    ]) def transform(examples):    images = [preprocess(image.convert("RGB")) for image in         examples["image"]]    return {"images": images} dataset.set_transform(transform) train_dataloader = torch.utils.data.DataLoader(    dataset, batch_size=batch_size, shuffle=True)

输出4幅蝴蝶图像,便于观察

print("Previewing batch:")batch = next(iter(train_dataloader))grid = torchvision.utils.make_grid(batch["images"], nrow=4)plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

4.2 微调人脸生成扩散模型

num_epochs = 2lr = 1e-5grad_accumulation_steps = 2 optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=lr) losses = [] for epoch in range(num_epochs):    for step, batch in tqdm(enumerate(train_dataloader),       total=len(train_dataloader)):        clean_images = batch["images"].to(device)        # 随机生成一个噪声,稍后加到图像上        noise = torch.randn(clean_images.shape).to(clean_images.           device)        bs = clean_images.shape[0]         # 随机选取一个时间步        timesteps = torch.randint(            0,            image_pipe.scheduler.num_train_timesteps,            (bs,),            device=clean_images.device,        ).long()         # 根据选中的时间步和确定的幅值,在干净图像上添加噪声        # 此处为前向扩散过程        noisy_images = image_pipe.scheduler.add_noise(clean_images,            noise, timesteps)         # 使用“带噪”图像进行网络预测        noise_pred = image_pipe.unet(noisy_images, timesteps,            return_dict=False)[0]         # 对真正的噪声和预测的结果进行比较,注意这里是预测噪声        loss = F.mse_loss(            noise_pred, noise        )         # 保存损失值        losses.append(loss.item())         # 根据损失值更新梯度        loss.backward()         # 进行梯度累积,在累积到一定步数后更新模型参数        if (step + 1) % grad_accumulation_steps == 0:            optimizer.step()            optimizer.zero_grad()     print(        f"Epoch {epoch} average loss: {sum(losses[-len(train_           dataloader):])/len(train_dataloader)}"    ) # 画出损失曲线,效果如图所示plt.plot(losses) 

4.3 使用微调好的模型生成图像

x = torch.randn(8, 3, 256, 256).to(device)for i, t in tqdm(enumerate(scheduler.timesteps)):    model_input = scheduler.scale_model_input(x, t)    with torch.no_grad():        noise_pred = image_pipe.unet(model_input, t)["sample"]    x = scheduler.step(noise_pred, t, x).prev_samplegrid = torchvision.utils.make_grid(x, nrow=4)plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

从图中可以看出生成的图像有蝴蝶数据的风格。

4.4 保持微调好的扩散模型,并且上传到Huggingface Hub中

image_pipe.save_pretrained("my-finetuned-model")
from huggingface_hub import HfApi, ModelCard, create_repo, get_   full_repo_name# 配置Hugging Face Hub,上传文件model_name = "ddpm-celebahq-finetuned-butterflies-2epochs"  # 使用@param 脚本程序对上传到 # Hugging Face Hub的文件进行命名local_folder_name = "my-finetuned-model" # @param脚本程序生成的名字,# 你也可以通过 image_pipe.save_pretrained('savename')自行指定description = "Describe your model here" # @paramhub_model_id = get_full_repo_name(model_name)create_repo(hub_model_id)api = HfApi()api.upload_folder(     folder_path=f"{local_folder_name}/scheduler",path_in_repo="",                     repo_id=hub_model_id )api.upload_folder(     folder_path=f"{local_folder_name}/unet", path_in_repo="",      repo_id=hub_model_id )api.upload_file(     path_or_fileobj=f"{local_folder_name}/model_index.json",     path_in_repo="model_index.json",     repo_id=hub_model_id,) # 添加一个模型卡片,这一步虽然不是必需的,但可以给他人提供一些模型描述信息 content = f"""---license: mittags:- pytorch- diffusers- unconditional-image-generation- diffusion-models-class---# 用法from diffusers import DDPMPipelinepipeline = DDPMPipeline.from_pretrained(' {hub_model_id}')image = pipeline().images[0]image'''"""card = ModelCard(content)card.push_to_hub(hub_model_id)

微调Trick:

  • 设置合适的batch_size,batch_size要在不超过GPU显存的前提下,尽量大一些,这样可以提高GPU计算效果;如果特别小,可以采用梯度累积的方式来更新模型参数,达到和大batch_size类似的效果,也就是多运行几次loss.backward(),再调用optimizer.step()和optimizer.zero_grad();
  • 训练过程中,要时不时生成一些图像样本来观察模型性能;
  • 训练过程中,可以把损失值、生成的图像样本等信息记录在日志中,可以使用Weights and Biases、TensorBoard等工具;

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/946912.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FPGA VR摄像机-拍摄和拼接立体 360 度视频

本文介绍的是 FPGA VR 相机的第二个版本,第一个版本是下面这样: 第一版地址: ❝ https://hackaday.io/project/26974-vr-camera-fpga-stereoscopic-3d-360-camera ❞ 本文主要介绍第二版本,第二版本的 VR 摄像机,能够以…

Python提取JSON文件中的指定数据并保存在CSV或Excel表格文件内

本文介绍基于Python语言,读取JSON格式的数据,提取其中的指定内容,并将提取到的数据保存到.csv格式或.xlsx格式的表格文件中的方法。 JSON格式的数据在数据信息交换过程中经常使用,但是相对而言并不直观;因此&#xff0…

Java学习笔记31——字符流

字符流 字符流为什么出现字符流编码表字符串中的编码解码问题字符流写数据的5中方式字符流读数据的两种方式字符流复制Java文件 字符流 为什么出现字符流 汉字的存储如果是GBK编码占用2个字节,如果是UTF-8占用三个字节 用字节流复制文本文件时,文本文…

华为---OSPF协议优先级、开销(cost)、定时器简介及示例配置

OSPF协议优先级、开销、定时器简介及示例配置 路由协议优先级:由于路由器上可能同时运行多种动态路由协议,就存在各个路由协议之间路由信息共享和选择的问题。系统为每一种路由协议设置了不同的默认优先级,当在不同协议中发现同一条路由时&am…

【USRP】调制解调系列3:2FSK、4FSK、8FSK,基于labview的实现

2FSK、4FSK、8FSK FSK(Frequency-shift keying)是信息传输中使用得较早的一种调制方式,它的主要优点是: 实现起来较容易,抗噪声与抗衰减的性能较好。在中低速数据传输中得到了广泛的应用。最常见的是用两个频率承载二进制1和0的双频FSK系统。 FSK 信号…

uni-app开发小程序,radio单选按钮,点击可以选中,再次点击可以取消

一、实现效果&#xff1a; 二、代码实现&#xff1a; 不适用官方的change方法&#xff0c;自己定义点击方法。 动态判断定义的值是否等于遍历的值进行回显&#xff0c;如果和上一次点击的值一样&#xff0c;就把定义的值改为null <template><view><radio-group&…

国产自主可控C++工业软件可视化图形架构源码

关于国产自主代替的问题是当前热点&#xff0c;尤其是工业软件领域。 “一个功能强大的全自主C跨平台图形可视化架构对开发自主可控工业基础软件至关重要&#xff01;” 作为全球领先的C工业基础图形可视化软件提供商&#xff0c;UCanCode软件有自己的思考&#xff0c;我们认…

Linux中的工具:yum,vim,gcc/g++,make/makefile,gdb

目录 1、yum 1.1 查看软件包&#xff1a; 1.2 安装软件包 1.3 卸载软件 2、vim 2.1 vim的三种模式 2.2 vim的基本操作 2.3. vim正常模式命令集 2.3.1 插入模式 2.3.2 移动光标 2.3.3 删除文字 2.3.4 复制 2.3.5 替换 2.3.6撤销上一次操作 2.3.7 更改 2.3.8 跳至…

WebGL模型矩阵

前言&#xff1a;依赖矩阵库 WebGL矩阵变换库_山楂树の的博客-CSDN博客 先平移&#xff0c;后旋转的模型变换&#xff1a; 1.将三角形沿着X轴平移一段距离。 2.在此基础上&#xff0c;旋转三角形。 先写下第1条&#xff08;平移操作&#xff09;中的坐标方程式。 等式1&am…

Mybatis1.4 多条件查询

1.4 多条件查询 1.4.1 编写接口方法1.4.2 编写SQL语句1.4.3 编写测试方法1.4.4 动态SQL 我们经常会遇到如上图所示的多条件查询&#xff0c;将多条件查询的结果展示在下方的数据列表中。而我们做这个功能需要分析最终的SQL语句应该是什么样&#xff0c;思考两个问题 条件表达式…

即时通讯开发中的性能优化技巧

即时通讯开发在如今的数字化社会中扮演着重要角色&#xff0c;然而&#xff0c;随着用户对即时通讯应用的需求不断增长&#xff0c;开发者们面临着使其应用保持高性能和可靠性的挑战。本文将探讨即时通讯开发中关键的性能优化技巧&#xff0c;帮助开发者们提升应用的用户体验和…

带你深入了解分布式系统

一.前言 当我们进行购物的时候,不知道大家有没有想过,每个人有那么多订单,要浏览海量商品,要加载许多网页,屏幕背后的网站是怎么完成这一系列的网页响应,数据存储的?本文将带大家深入了解这背后的机制和原理. 在进⾏技术学习过程中&#xff0c;由于⼤部分人没有经历过⼀些中⼤…

C语言(第三十五天)

3. 移位操作符 << 左移操作符 >> 右移操作符 注&#xff1a;移位操作符的操作数只能是整数。 3.1 左移操作符 移位规则&#xff1a;左边抛弃、右边补0 3.2 右移操作符 移位规则&#xff1a;首先右移运算分两种&#xff1a; 1. 逻辑右移&#xff1a;左边用0填充&a…

spring创建bean的三种方法

Spring支持如下三种方式创建Bean 1&#xff1a;调用构造器创建Bean 2&#xff1a;调用静态工厂方法创建Bean 3&#xff1a;调用实例工厂方法创建Bean 注解Bean方式 Xml方式 BeanDefinitionBuilder方式 1.注解Bean方式 通过Bean方式创建自定义Bean是最明显的方式&#xff0c;…

代码随想录笔记--链表篇

目录 1--虚拟头节点的使用 2--设计链表 3--反转链表 4--两两交换链表中的节点 5--快慢指针 5-1--删除链表倒数第N个节点 5-2--环形链表 5-3--环形链表II 1--虚拟头节点的使用 在链表相关题目中&#xff0c;常新定义一个虚拟头结点 dummynode 来指向原链表的头结点&…

常见矿石材质鉴定VR实训模拟操作平台提高学员的学习效果和实践能力

随着“元宇宙”概念的不断发展&#xff0c;在矿山领域中&#xff0c;长期存在传统培训内容不够丰富、教学方式单一、资源消耗大等缺点&#xff0c;无法适应当前矿山企业发展需求的长期难题。元宇宙企业借助VR虚拟现实、web3d开发和计算机技术构建的一个虚拟世界&#xff0c;为用…

华为云新生代开发者招募

开发者您好&#xff0c;我们是华为2012UCD的研究团队 为了解年轻开发者的开发现状和趋势 正在邀请各位先锋开发者&#xff0c;与我们进行2小时的线上交流&#xff08;江浙沪附近可线下交流&#xff09; 聊聊您日常开发工作中的产品使用需求 成功参与访谈者将获得至少300元京…

Angular安全专辑之三:授权绕过,利用漏洞控制管理员账户

这篇文章是针对实际项目中所出现的问题所做的一个总结。简单来说&#xff0c;就是授权绕过问题&#xff0c;管理员帐户被错误的接管。 详细情况是这样的&#xff0c;我们的项目中通常都会有用户身份验证功能&#xff0c;不同的用户拥有不同的权限。相对来说管理员账户所对应的…

嵌入式学习之进程

今天主要学习了进程&#xff0c;对fork的相关知识有了更加清楚的理解。 进程退出 正常调用&#xff1a; Main函数调用return ; 进程调用exit(),属于标准的C库&#xff1b;进程调用_exit&#xff08;&#xff09;或者_Exit(),属于系统调用 补充&#xff1a; 1.进程最后一个…

flutter高德地图大头针

1、效果图 2、pub get #地图定位 amap_flutter_map: ^3.0.0 amap_flutter_location: ^3.0.0 3、上代码 import dart:async; import dart:io;import package:amap_flutter_location/amap_flutter_location.dart; import package:amap_flutter_location/amap_location_option…