Diffusion Model: DDPM

news2025/1/11 0:45:57

本文相关内容只记录看论文过程中一些难点问题,内容间逻辑性不强,甚至有点混乱,因此只作为本人“备忘”,不建议其他人阅读。

Denoising Diffusion Probabilistic Models: https://arxiv.org/abs/2006.11239

DDPM

一、基于 x_0 已知的情况下,x_t 分布的推导过程:推导过程中,直接递归迭代即可。同时,过程中使用了 —— 两个高斯分布的和也满足高斯分布,其中均值为两个高斯分布均值的和,方差为两个高斯分布方差的和。

二、逆向过程中,q(x_{t-1}|x_t, x_0) 分布求解

进一步根据 1 中的结果可得:

公式 9 中的 z_{\theta}(x_t,t) 就是 diffusion model 需要估计的噪声均值,而噪声的方式是由 \alpha_t 或者 \beta_t 直接得到的。

三、具体训练过程:训练过程比较直接,利用 一 中的公式即可。

https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py L274

def q_sample(self, x_start, t, noise=None):
    noise = default(noise, lambda: torch.randn_like(x_start))
    return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

def get_loss(self, pred, target, mean=True):
    if self.loss_type == 'l1':
        loss = (target - pred).abs()
        if mean:
            loss = loss.mean()
    elif self.loss_type == 'l2':
        if mean:
            loss = torch.nn.functional.mse_loss(target, pred)
        else:
            loss = torch.nn.functional.mse_loss(target, pred, reduction='none')
    else:
        raise NotImplementedError("unknown loss type '{loss_type}'")

    return loss

# 输入参数说明:
# x_start:原始图像 x0
# t:当前扩散步数
# noise:噪声,需要注意这里的 noise 与 x_start 维度相同;具体含义是每个位置上元素都服从 0-1 高斯分布
def p_losses(self, x_start, t, noise=None):
    # 生成第 t 步的高斯噪声
    noise = default(noise, lambda: torch.randn_like(x_start))

    # 根据本文 一 中推导的公式得到第 t 步加噪后的图像
    x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
    
    # 模型预测结果,根据具体的设置,好像可以回归加的噪声,也可以直接回归原始图像
    model_out = self.model(x_noisy, t)

    loss_dict = {}
    if self.parameterization == "eps":
        # 模型估计噪声
        target = noise
    elif self.parameterization == "x0":
        # 模型直接估计原始图像
        target = x_start
    else:
        raise NotImplementedError(f"Paramterization {self.parameterization} not yet supported")

    # 使用 L1 或者 L2 Loss 计算误差
    loss = self.get_loss(model_out, target, mean=False).mean(dim=[1, 2, 3])

    log_prefix = 'train' if self.training else 'val'

    loss_dict.update({f'{log_prefix}/loss_simple': loss.mean()})
    loss_simple = loss.mean() * self.l_simple_weight

    loss_vlb = (self.lvlb_weights[t] * loss).mean()
    loss_dict.update({f'{log_prefix}/loss_vlb': loss_vlb})

    loss = loss_simple + self.original_elbo_weight * loss_vlb

    loss_dict.update({f'{log_prefix}/loss': loss})

    return loss, loss_dict

四、具体生成(采样)过程:根据 二 中推导的公式,依次计算前一步图像的分布。需要注意:

  1. 具体回归的均值的维度与图像维度完全相同,即图像每个位置(包括不同通道)都建模为高斯分布,均值就是无随机时图像应该有的“样子”
  2. 因此,在 T=0 步得到的均值就是最终生成的图像;不过在 T> 0 步依据均值和方差进行采样,可能的原因是增加生成图像的多样性。

https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py L222

# 根据本文 二 中的公式计算 x_t-1 的均值和方差
def q_posterior(self, x_start, x_t, t):
    posterior_mean = (
            extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )
    posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape)
    posterior_log_variance_clipped = extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

def p_mean_variance(self, x, t, clip_denoised: bool):
    model_out = self.model(x, t)
    if self.parameterization == "eps":
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        x_recon = model_out
    if clip_denoised:
        x_recon.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
    return model_mean, posterior_variance, posterior_log_variance

# 基于估计的图像每个位置的均值 model_mean 和方差 model_log_variance 生成对应随机图像
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
    b, *_, device = *x.shape, x.device
    model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
    noise = noise_like(x.shape, device, repeat_noise)
    # no noise when t == 0
    nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
    return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

# 从 T 步 ——> T-1 步 ——> ... ——> 0 步,依次进行反向估计
@torch.no_grad()
def p_sample_loop(self, shape, return_intermediates=False):
    device = self.betas.device
    b = shape[0]
    img = torch.randn(shape, device=device)
    intermediates = [img]
    for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
        img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                            clip_denoised=self.clip_denoised)
        if i % self.log_every_t == 0 or i == self.num_timesteps - 1:
            intermediates.append(img)
    if return_intermediates:
        return img, intermediates
    return img

# 采样入口函数,batch_size 一次生成的图像数量
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
    image_size = self.image_size
    channels = self.channels
    return self.p_sample_loop((batch_size, channels, image_size, image_size),
                                return_intermediates=return_intermediates)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1250258.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Linux JumpServer堡垒机本地部署与远程访问

🌈个人主页:聆风吟 🔥系列专栏:网络奇遇记、Cpolar杂谈 🔖少年有梦不应止于心动,更要付诸行动。 文章目录 📋前言一. 安装Jump server二. 本地访问jump server三. 安装 cpolar内网穿透软件四. 配…

mysql索引分为哪几类,聚簇索引和非聚簇索引的区别,MySQL索引失效的情况有哪几种情况,MySQL索引优化的手段,MySQL回表

文章目录 索引分为哪几类?聚簇索引和非聚簇索引的区别什么是[聚簇索引](https://so.csdn.net/so/search?q聚簇索引&spm1001.2101.3001.7020)?(重点)非聚簇索引 聚簇索引和非聚簇索引的区别主要有以下几个:什么叫回…

vcsa6.7 5480无法登录

停电维护硬件后,发现vcsa异常,https://ip:5480无法登录,https://ip/ui正常,ssh登录页正常 kb资料 通过端口 5480 登录到 VMware vCenter Server Appliance Web 控制台失败 (2120477) 操作过程 Connecting to 192.16.20.31:22..…

LLMLingua:集成LlamaIndex,对提示进行压缩,提供大语言模型的高效推理

大型语言模型(llm)的出现刺激了多个领域的创新。但是在思维链(CoT)提示和情境学习(ICL)等策略的驱动下,提示的复杂性不断增加,这给计算带来了挑战。这些冗长的提示需要大量的资源来进行推理,因此需要高效的解决方案,本文将介绍LLM…

2023大模型安全解决方案白皮书

今天分享的是大模型系列深度研究报告:《2023大模型安全解决方案白皮书》。 (报告出品方:百度安全) 报告共计:60页 前言 在当今迅速发展的数字化时代,人工智能技术正引领着科技创新的浪潮而其中的大模型…

一键填充字幕——Arctime pro

之前的博客中,我们聊到了PR这款专业的视频制作软件,但是pr有许多的功能需要搭配使用,相信不少小伙伴在剪辑视频时会发现一个致命的问题,就是字幕编写。伴随着人们对字幕需求的逐渐增加,这款软件便应运而生~ 相信应该有…

汽车业务增长乏力!又被法雷奥告上法庭,英伟达有点「难」

随着智能汽车进入「降本增效」的关键周期,对于上游产业链,尤其是芯片的影响也在持续发酵。 本周,英伟达发布截至2023年10月29日的第三季度财报数据,整体业务收入为181.2亿美元,比去年同期增长206%,比上一季…

【vue_1】console.log没有反应

1、打印不出来?2、警告也会出现问题3、插播:如何使用if-else 语句来处理逻辑 1、打印不出来? 要做一个权限不够的弹出消息框 const authority_message () > {ElMessage({type: warrnings,message: 当前用户的权限不够});console.log(he…

GPS 定位信息分析:航向角分析及经纬度坐标转局部XY坐标

GPS 定位信息分析(1) 从下面的数据可知,raw data 的提取和经纬度的计算应该是没问题的 在相同的经纬度下, x 和 y 还会发生变化,显然是不正确的 raw data:3150.93331124 11717.59467080 5.3 latitude: 31.8489 long…

Int8量化算子在移动端CPU的性能优化

本文介绍了Depthwise Convolution 的Int8算子在移动端CPU上的性能优化方案。ARM架构的升级和相应指令集的更新不断提高移动端各算子的性能上限,结合数据重排和Sdot指令能给DepthwiseConv量化算子的性能带来较大提升。 背景 MNN对ConvolutionDepthwise Int8量化算子在…

计算机组成原理-固态硬盘SSD

文章目录 总览机械硬盘vs固态硬盘固态硬盘的结构固态硬盘与机械硬盘相比的特点磨损均衡技术例题 总览 机械硬盘vs固态硬盘 固态硬盘采用闪存技术,是电可擦除ROM 下图右边黑色的块块就是一块一块的闪存芯片 固态硬盘的结构 块大小16KB~512KB 页大小512B~4KB 对固…

ES6之class类

ES6提供了更接近传统语言的写法,引入了Class类这个概念,作为对象的模板。通过Class关键字,可以定义类,基本上,ES6的class可以看作只是一个语法糖,它的绝大部分功能,ES5都可以做到,新…

数据库的事务的基本特性,事务的隔离级别,事务隔离级别如何在java代码中使用,使用MySQL数据库演示不同隔离级别下的并发问题

文章目录 数据库的事务的基本特性事务的四大特性(ACID)4.1、原子性(Atomicity)4.2、一致性(Consistency)4.3、隔离性(Isolation)4.4、持久性(Durability) 事务的隔离级别5.1、事务不…

6.11左叶子之和(LC404-E)

用java定义树: public class TreeNode {int val;TreeNode left;TreeNode right; //一个空构造方法TreeNode(),用于初始化节点的默认值。TreeNode() {} //一个构造方法TreeNode(int val),用于初始化节点的值,并设置默认的左右子节…

算法笔记:OPTICS 聚类

1 基本介绍 OPTICS(Ordering points to identify the clustering structure)是一基于密度的聚类算法 OPTICS算法是DBSCAN的改进版本 在DBCSAN算法中需要输入两个参数: ϵ 和 MinPts ,选择不同的参数会导致最终聚类的结果千差万别,因此DBCSAN…

分布式篇---第六篇

系列文章目录 文章目录 系列文章目录前言一、说说什么是漏桶算法二、说说什么是令牌桶算法三、数据库如何处理海量数据?前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码…

可观测性建设实践之 - 日志分析的权衡取舍

指标、日志、链路是服务可观测性的三大支柱,在服务稳定性保障中,通常指标侧重于发现故障和问题,日志和链路分析侧重于定位和分析问题,其中日志实际上是串联这三大维度的一个良好桥梁。 但日志分析往往面临成本和效果之间的权衡问…

Spring Boot Actuator 2.2.5 基本使用

1. pom文件 &#xff0c;添加 Actuator 依赖 <dependency> <groupId>org.springframework.boot</groupId> <artifactId>spring-boot-starter-actuator</artifactId> </dependency> 2.application.properties 文件中添加以下配置 …

优秀软件设计特征与原则

1.摘要 一款软件产品好不好用, 除了拥有丰富的功能和人性化的界面设计之外, 还有其深厚的底层基础, 而设计模式和算法是构建这个底层基础的基石。好的设计模式能够让产品开发快速迭代且稳定可靠, 迅速抢占市场先机&#xff1b;而好的算法能够让产品具有核心价值, 例如字节跳动…

2、用命令行编译Qt程序生成可执行文件exe

一、创建源文件 1、新建一个文件夹&#xff0c;并创建一个txt文件 2、重命名为main.cpp 3、在main.cpp中添加如下代码 #include <QApplication> #include <QDialog> #include <QLabel> int main(int argc, char *argv[]) { QApplication a(argc, argv); QDi…