stable diffusion代码学习笔记

news2025/1/12 18:11:54

前言:本文没有太多公式推理,只有一些简单的公式,以及公式和代码的对应关系。本文仅做个人学习笔记,如有理解错误的地方,请指出。

资源

  1. 本文学习的代码;
  2. 相关文献:
  • Denoising Diffusion Probabilistic Models : DDPM,这个是必看的,推推公式
  • Denoising Diffusion Implicit Models :DDIM,对 DDPM 的改进
  • Pseudo Numerical Methods for Diffusion Models on Manifolds :PNMD/PLMS,对 DDPM 的改进
  • High-Resolution Image Synthesis with Latent Diffusion Models :Latent-Diffusion,必看
  • Neural Discrete Representation Learning : VQVAE,简单翻了翻,示意图非常形象,很容易了解其做法

前向过程(训练)

  • 输入一张图片+随机噪声,训练unet,网络预测图片加上的噪声

反向过程(推理)

  • 给个随机噪声,不断迭代去噪,输出一张图片

总体流程

  • 输入的prompt经过clip encoder编码成(3+3,77,768)特征,正负prompt各3个,默认negative prompt为空‘’,解码时正的和负的latent图片用公式计算一下才是最终结果;time step通过linear层得到(3+3,1280)特征;把prompt和time ebedding和随机生成的图片放入unet,得到的就是我们要的图片。

采样流程 text2img

  • 该函数在PLMSSampler中,输入x(噪声,(3,4,64,64))-----c(输入的prompt,(3,77,768)----t (输入的time step,第几次去噪(3,)。把这三个东西输入unet,得到预测的噪声e_t。
 def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c]) # 积极消极的prompt,解码时按照公式减去消极prompt的图像
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t) # 模型预测的噪声
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) # 输入噪声减去预测噪声得到新的噪声,当前预测的latent图片
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2 # 两次噪声的均值?
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
  • 接下来看公式:
    在这里插入图片描述
  • 网络得到e_t后,进入到get_x_prev_and_pred_x0函数,可以看到pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()就是上述公式,也就是说网络的预测结果通过公式计算,我们可以得到预测的pred_x0原始图片和前一刻的噪声图像x_prev
        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0
  • 前一刻的噪声图像的推理公式如图:
    在这里插入图片描述
    在这里插入图片描述

  • 得到了上一刻的噪声图片x_prev后(也就是函数返回的img),继续迭代,最终生成需要的图片。
    在这里插入图片描述

额外说明

这部分代码应该就是PLMS加速采样用的,论文中有公式推理
在这里插入图片描述
另外,还有一些参数是训练时候保存的,betas逐渐增大,用来控制噪声的强度。变量名解析 log_one_minus_alphas_cumprod其实就是log(1-alpha(右下角t)(头上直线)),没有带prev的都是当前时刻t,带prev的是前一时刻t-1。

在这里插入图片描述

参考文献:

https://blog.csdn.net/Eric_1993/article/details/129600524?spm=1001.2014.3001.5502
https://zhuanlan.zhihu.com/p/630354327

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1374644.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

手敲Mybatis(16章)-一级缓存功能实现

1.实现目的 这一节的目的主要是实现SqlSession级别的缓存,也就是一级缓存,首先看下图一,用户可以通过设置来进行是否开启一级缓存,不设置的化默认开启一级缓存,localCacheScopeSESSION为要设置一级缓存,lo…

【算法】基础算法001之双指针

👀樊梓慕:个人主页 🎥个人专栏:《C语言》《数据结构》《蓝桥杯试题》《LeetCode刷题笔记》《实训项目》《C》《Linux》《算法》 🌝每一个不曾起舞的日子,都是对生命的辜负 目录 前言 1.数组分块&#xf…

光伏项目如何建设和施工?

光伏项目现在已经成为能源行业最受欢迎、最有潜力的项目了,不仅仅可以减少温室气体的排放,也能够促进新能源和经济的发展。那么该如何建设和实施一个成功的光伏项目呢? 首先,选址是光伏项目成功的关键。理想的光伏项目应位于阳光充…

MySQL的三种存储引擎 InnoDB、MyISAM、Memory

InnoDB 1). 介绍 InnoDB是一种兼顾高可靠性和高性能的通用存储引擎,在 MySQL 5.5 之后,InnoDB是默认的MySQL 存储引擎。 2). 特点 DML操作遵循ACID模型,支持事务; 行级锁,提高并发访问性能; 支持外键F…

neo4j图数据库的简单操作记录

知识图谱文件导出 首先停止运行sudo neo4j stop然后导出数据库 导出格式为: 具体命令如下sudo neo4j-admin database dump --to-path/home/ neo4j最后重启sudo neo4j start知识图谱外观修改 在网页点击节点,选中一个表情后点击,可修改其颜…

springBoot-简单实践

基本介绍 在理解了自动依赖&#xff0c;以及自动配置后&#xff0c; 我们来做一个简单的使用springBoot的了解。 1、引入场景依赖 例如 <!--添加web场景--> <!-- springBoot会自动管理依赖&#xff0c;并自动配置--><dependencies><depende…

如何在OpenWRT部署uhttpd搭建服务器实现远程访问本地web站点

文章目录 前言1. 检查uhttpd安装2. 部署web站点3. 安装cpolar内网穿透4. 配置远程访问地址5. 配置固定远程地址 前言 uhttpd 是 OpenWrt/LuCI 开发者从零开始编写的 Web 服务器&#xff0c;目的是成为优秀稳定的、适合嵌入式设备的轻量级任务的 HTTP 服务器&#xff0c;并且和…

代码随想录算法训练营第二十五天| 回溯总结

回溯是递归的副产品&#xff0c;只要有递归就会有回溯&#xff0c;所以回溯法也经常和二叉树遍历&#xff0c;深度优先搜索混在一起&#xff0c;因为这两种方式都是用了递归。 回溯算法能解决如下问题&#xff1a; 组合问题&#xff1a;N个数里面按一定规则找出k个数的集合排…

CentOs 环境下使用 Docker 部署 Ruoyi-Vue

CentOs 环境下使用 Docker 部署 Ruoyi-Vue RuoYi-Vue 项目下载地址 RuoYi-Vue: &#x1f389; 基于SpringBoot&#xff0c;Spring Security&#xff0c;JWT&#xff0c;Vue & Element 的前后端分离权限管理系统&#xff0c;同时提供了 Vue3 的版本 (gitee.com) Docker 部…

存储卷(数据卷)—主要是nfs方式挂载

1、定义 容器内的目录和宿主机的目录进行挂载 容器在系统上的生命周期是短暂的&#xff0c;一旦容器被删除&#xff0c;数据会丢失。k8s基于控制器创建的pod&#xff0c;delete相当于重启&#xff0c;容器的状态会恢复到原始状态。一旦回到原始状态&#xff0c;后天编辑的文件…

RocketMQ Dashboard可视化工具

RocketMQ Dashboard 将 RocketMQ的相关指标展示在web页面 &#xff0c;支持以可视化工具代替 Topic 配置、Broker 管理等命令行操作。 官方文档地址&#xff1a;RocketMQ Dashboard | RocketMQ 目录 1.下载安装 1.1 系统要求&#xff1a; 1.2 源码安装 1.3 访问页面 2.功…

记录仪可作为XCP从站进行数据转发

车辆数据采集系统通常包含多种数据采集设备、多路总线或传感器信号&#xff0c;为了集中监控和管理&#xff0c;需要将这些设备的实时数据传输到上位机。对此&#xff0c;我们将使用基于XCP&#xff08;Universal Measurement and Calibration Protocol&#xff09;协议的数据记…

干货抢先看:SOLIDWORKS阵列操作的技巧与要点

SOLIDWORKS软件中的阵列功能十分常用且强大。本文将介绍一些关于SOLIDWORKS阵列的技巧&#xff0c;以帮助您更加高效地应用该功能。 1.线性阵列方向识别度增强 想使用线性阵列打孔的时候&#xff0c;模型上没有可以选中的参考线作为阵列方向怎么办&#xff1f;使用圆柱面也可…

基于ssm智慧社区停车管理系统设计与实现【附源码】

基于ssm智慧社区停车管理系统设计与实现 &#x1f345; 作者主页 央顺技术团队 &#x1f345; 欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; &#x1f345; 文末获取源码联系方式 &#x1f4dd; 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&am…

深入理解.NET框架中的CLR(公共语言运行时)

深入理解.NET框架中的CLR&#xff08;公共语言运行时&#xff09; 引言 .NET框架中的CLR&#xff08;公共语言运行时&#xff09;是.NET应用程序运行的核心。本文将继续探索CLR的核心功能&#xff0c;并详细介绍.NET程序启动时是如何自动加载关键的库和服务来提供这些功能的。…

2024在视频号开店怎么样?平台现状如下,有电商经验者优先!

我是王路飞。 现在开网店、做电商的平台有很多&#xff0c;但是有着绝对流量优势的&#xff0c;除了抖音之外就是视频号了。 但是抖音跟视频号相比&#xff0c;已经属于一个很成熟的平台了&#xff0c;商家们也开始进入到内卷阶段了。 所以&#xff0c;如果你们2024年想做电…

工业智能网关:HiWoo Box远程采集设备数据

工业智能网关&#xff1a;HiWoo Box远程采集设备数据 在工业4.0和智能制造的浪潮下&#xff0c;工业互联网已成为推动产业升级、提升生产效率的关键。而在这其中&#xff0c;工业智能网关扮演着至关重要的角色。今天&#xff0c;我们就来深入探讨一下工业智能网关。 一、什么…

学习笔记17——通俗易懂的三次握手四次挥手

提供一种博主本人觉得很好理解的三次握手和四次挥手场景&#xff0c;帮助记忆 三次握手过程 初始状态&#xff1a;客户端处于closed状态&#xff0c;服务器处于listen监听转台客户端向服务器发送一个SYN连接请求&#xff0c;并告诉对方自己此时初始化序列号为x&#xff0c;发送…

k8s的策略

集群调度&#xff1a; Scheduler的调度算法&#xff1a; 预算策略 过滤出合适的节点 优先策略 选择部署的节点 NodeName&#xff1a;硬策略&#xff0c;不走调度策略&#xff0c;node1 nodeSelector&#xff1a;根据节点的标签选择&#xff0c;会走一个调度算法 只要是…

MySQL:索引失效场景总结

1 执行计划查索引 通过执行计划命令可以查看查询语句使用了什么索引。 EXPLAIN SELECT * FROM ods_finebi_area WHERE areaName = 福建 执行查询计划后,key列的值就是被使用的索引的名称,若key列没有值表示查询未使用索引。 2 在什么列上创建索引 (1)列经常被用于where…