扩散模型(Diffusion model)代码详细解读

news2024/11/13 16:19:07

扩散模型代码详细解读

代码地址:denoising-diffusion-pytorch/denoising_diffusion_pytorch.py at main · lucidrains/denoising-diffusion-pytorch (github.com)

前向过程和后向过程的代码都在GaussianDiffusion​这个类中。​

常见问题解决

Why self-conditioning? · Issue #94 · lucidrains/denoising-diffusion-pytorch (github.com)

"pred_x0" preforms better than "pred_noise" · Issue #58 · lucidrains/denoising-diffusion-pytorch (github.com)

What is objective=pred_x0 and how do you use it? · Issue #34 · lucidrains/denoising-diffusion-pytorch (github.com)

Conditional generation · Issue #7 · lucidrains/denoising-diffusion-pytorch (github.com)

Questions About DDPM · Issue #10 · lucidrains/denoising-diffusion-pytorch (github.com)
The difference between pred_x0, pred_v, pred_noise three objectives · Issue #153 · lucidrains/denoising-diffusion-pytorch (github.com)

前向训练过程

p_losses

首先是p_losses函数,这个是训练过程的主体部分。

def p_losses(self, x_start, t, noise = None):
        b, c, h, w = x_start.shape
	# 首先随机生成噪声
        noise = default(noise, lambda: torch.randn_like(x_start))

        # noise sample
	# 噪声采样,注意这个是一次性完成的
        x = self.q_sample(x_start = x_start, t = t, noise = noise)

        # if doing self-conditioning, 50% of the time, predict x_start from current set of times
        # and condition with unet with that
        # this technique will slow down training by 25%, but seems to lower FID significantly

	# 判断是否进行self-condition,就是利用前面步骤预测出的x0来辅助当前的预测
        x_self_cond = None
        if self.self_condition and random() < 0.5:
            with torch.no_grad():
                x_self_cond = self.model_predictions(x, t).pred_x_start
                x_self_cond.detach_()

        # predict and take gradient step

	# 将采样的x和self condition的x一起输入到model当中,这个model是UNet结构
        model_out = self.model(x, t, x_self_cond)
	# 模型预测的目标,分为三种
        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            v = self.predict_v(x_start, t, noise)
            target = v
        else:
            raise ValueError(f'unknown objective {self.objective}')
	# 计算损失
        loss = self.loss_fn(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b (...)', 'mean')

        loss = loss * extract(self.p2_loss_weight, t, loss.shape)
        return loss.mean()

对其中的extract函数进行分析,extract函数实现如下:

def extract(a, t, x_shape):

    # Extract some coefficients at specified timesteps,
    # then reshape to [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    b, *_ = t.shape
    # 使用了gather函数
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

q_sample

然后介绍p_losses函数中使用的其他函数,第一个是q_sample函数,它的作用是加上噪声,对应论文的公式:
在这里插入图片描述

其中self.sqrt_alphas_cumprod​和self.sqrt_one_minus_alphas_cumprod​分别是alpha的累乘值和1-alpha的累乘值,x_start相当于x0,noise相当于z。

def q_sample(self, x_start, t, noise=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

model_predictions

然后是model_predictions函数,它的实现如下:

def model_predictions(self, x, t, x_self_cond = None, clip_x_start = False):
	# 输入到UNet结构中获得输出
        model_output = self.model(x, t, x_self_cond)
        maybe_clip = partial(torch.clamp, min = -1., max = 1.) if clip_x_start else identity
	# 暂不明确它的作用
        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = self.predict_start_from_noise(x, t, pred_noise)
            x_start = maybe_clip(x_start)

        elif self.objective == 'pred_x0':
            x_start = model_output
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)

        elif self.objective == 'pred_v':
            v = model_output
            x_start = self.predict_start_from_v(x, t, v)
            x_start = maybe_clip(x_start)
            pred_noise = self.predict_noise_from_start(x, t, x_start)
	# 返回得到的噪声和
        return ModelPrediction(pred_noise, x_start)

几种objective

model_predictions函数中有一个难点,就是其中的self.objective,它有三种形式:

  • pred_noise:这个相当于是预测噪声,此时UNet模型的输出是噪声
  • pred_x0:这个相当于是预测最开始的x,此时UNet模型的输出是去噪的图像
  • pred_v:这个相当于是预测速度v,它在这篇文章中提出。然后根据速度求出最开始的x,最后预测出噪声。

如图所示:​
在这里插入图片描述

在上面的三种objective中,还涉及到了几种预测方法的实现,具体如下:

(1)predict_start_from_noise:这个函数的作用是根据噪声noise预测最开始的x,也就是去噪的图像。

其中self.sqrt_recip_alphas_cumprod​和self.sqrt_recipm1_alphas_cumprod​来自在这里插入图片描述
公式,它们分别为:在这里插入图片描述
在这里插入图片描述

公式来源文章:DDPM

def predict_start_from_noise(self, x_t, t, noise):
    return (
        extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
        extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
    )

它对应论文中的公式如下:
在这里插入图片描述

(2)predict_noise_from_start:这个函数的作用是根据图像预测噪声,也就是加噪声。

def predict_noise_from_start(self, x_t, t, x0):
    return (
        (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
        extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
    )

它对应论文中的公式如下:
在这里插入图片描述
需要注意它是反推过来的,过程如下:

(3)predict_v:预测速度v

 def predict_v(self, x_start, t, noise):
     return (
         extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
         extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
     )

它对应论文中的公式:在这里插入图片描述

(4)predict_start_from_v:根据速度v预测最初的x,也就是图像

def predict_start_from_v(self, x_t, t, v):
    return (
        extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
        extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
    )

它对应论文中的公式如下:在这里插入图片描述其中zt相当于xt。

后向采样过程

sample函数

@torch.no_grad()
def sample(self, batch_size = 16, return_all_timesteps = False):
    image_size, channels = self.image_size, self.channels
    # 采样的函数
    sample_fn = self.p_sample_loop if not self.is_ddim_sampling else self.ddim_sample
    # 调用该函数
    return sample_fn((batch_size, channels, image_size, image_size), return_all_timesteps = return_all_timesteps)

该函数的作用是获取采样的函数然后进行调用,采样函数分成两种:p_sample_loop和ddim_sample。

p_sample_loop函数

 @torch.no_grad()
 def p_sample_loop(self, shape, return_all_timesteps = False):
     batch, device = shape[0], self.betas.device
     # 随机生成噪声图像
     img = torch.randn(shape, device = device)
     imgs = [img]

     x_start = None
     # 遍历所有的t
     for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
         # 判断是否使用self-condition
	 self_cond = x_start if self.self_condition else None
         # 进行采样,得到去噪的图像
         img, x_start = self.p_sample(img, t, self_cond)
         imgs.append(img)
     # 判断是否返回每个步骤的img还是最后一步的img
     ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
     # 归一化
     ret = self.unnormalize(ret)
     return ret

其中涉及到归一化函数self.unnormalize​,含有两种

# normalization functions
def normalize_to_neg_one_to_one(img):
    return img * 2 - 1
def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

p_sample函数

@torch.no_grad()
def p_sample(self, x, t: int, x_self_cond = None):
    b, *_, device = *x.shape, x.device
    batched_times = torch.full((b,), t, device = x.device, dtype = torch.long)
    # 获得平均值,方差和x0
    model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, x_self_cond = x_self_cond, clip_denoised = True)
    # 随机生成一个噪声	  
    noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0
    # 得到预测的图像,img = 平均值 + exp(0.5 * 方差) * noise
    pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
    return pred_img, x_start

p_mean_variance函数

其中含有p_mean_variance​函数,代码实现如下:

def p_mean_variance(self, x, t, x_self_cond = None, clip_denoised = True):
    # 输入到UNet网络进行预测
    preds = self.model_predictions(x, t, x_self_cond)
    # 得到预测的x0
    x_start = preds.pred_x_start
    # 压缩x0中值的范围至[-1,1]
    if clip_denoised:
        x_start.clamp_(-1., 1.)
    # 得到x0后根据xt和t得到分布的平均值和方差
    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
    return model_mean, posterior_variance, posterior_log_variance, x_start

q_posterior函数

其中q_posterior​函数的实现如下:

def q_posterior(self, x_start, x_t, t):
    # 计算平均值
    posterior_mean = (
        extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
        extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
    )
    # 计算方差
    posterior_variance = extract(self.posterior_variance, t, x_t.shape)
    # 获得一个压缩范围的方差,且取对数
    posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
    return posterior_mean, posterior_variance, posterior_log_variance_clipped

平均值和方差对应的公式如下:

在这里插入图片描述

其中self.posterior_mean_coef1​对应的是x0前面的系数,self.posterior_mean_coef2​对应的是xt前面的系数。

self.posterior_variance​对应的beta那部分的系数。

ddim_sample函数

@torch.no_grad()
def ddim_sample(self, shape, return_all_timesteps = False):
    batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective
    times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)   # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
    times = list(reversed(times.int().tolist()))
    time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]
    img = torch.randn(shape, device = device)
    imgs = [img]
    x_start = None
    for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
        time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
        self_cond = x_start if self.self_condition else None
        pred_noise, x_start, *_ = self.model_predictions(img, time_cond, self_cond, clip_x_start = True)
        imgs.append(img)
        if time_next < 0:
            img = x_start
            continue

        alpha = self.alphas_cumprod[time]
        alpha_next = self.alphas_cumprod[time_next]
        sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
        c = (1 - alpha_next - sigma ** 2).sqrt()
        noise = torch.randn_like(img)
        img = x_start * alpha_next.sqrt() + \
              c * pred_noise + \
              sigma * noise
    ret = img if not return_all_timesteps else torch.stack(imgs, dim = 1)
    ret = self.unnormalize(ret)
    return ret

上面部分依据的公式为:(文章)
在这里插入图片描述
在这里插入图片描述

训练的模型(UNet)

后续会继续更新!
对您有帮助请点赞收藏哦!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/182396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

盒相关样式-----diaplay:block、inline

盒的基本类型 CSS中的盒分为block类型与inline类型&#xff0c;例如&#xff0c; div元素与p元素属于block类型&#xff0c; span元素与a元素属于inline类型。 block类型的盒对应的是html中的块级元素&#xff0c;inline类型的盒对应了html中的行内元素。 行内元素与块级元素…

JavaScript 练手小技巧:键盘事件

键盘事件应该是鼠标事件之外&#xff0c;使用频率最高的 JS 事件了吧&#xff1f; 一般用于全局或者表单。 键盘事件由用户击打键盘触发&#xff0c;主要有keydown、keypress、keyup三个事件。 keydown&#xff1a;按下键盘时触发。Ctrl、Shift、Alt 等和其它按键组合时&…

BCNF与3NF

今天学了一下午这个BCNFBCNFBCNF与3NF3NF3NF&#xff0c;有感而发&#xff0c;特来总结。好像好久不打键盘了&#xff0c;这手好像刚长出来的一样。本文浅显的分析一下两种范式的关系与不同以及判断方法和分解算法&#xff0c;以做总结。 BCNFBCNFBCNF范式的定义如下: 设属性集…

linux 常用指令大全

目录一、基本指令指令基本格式1、ls1.1 ls相关选项2、pwd3、cd4、mkdir4.1、mkdir相关选项5、touch6、cp6.1 cp相关选项7、mv8、rm8.1、rm相关选项9、输出重定向10、cat11、df11.1、df 相关选项12、free12.1、free 相关选项13、head13.1、head相关选项14、tail14.1 tail相关选项…

day13 二叉树 | 144、二叉树的前序遍历 145、二叉树的后序遍历 54、二叉树的中序遍历

二叉树基础 二叉搜索树 二叉搜索树是一个有序树。 若它的左子树不空&#xff0c;则左子树上所有结点的值均小于它的根结点的值&#xff1b;若它的右子树不空&#xff0c;则右子树上所有结点的值均大于它的根结点的值&#xff1b;它的左、右子树也分别为二叉排序树 下面这两棵…

零食商城|基于springboot的零食商城

作者主页&#xff1a;编程指南针 作者简介&#xff1a;Java领域优质创作者、CSDN博客专家 、掘金特邀作者、多年架构师设计经验、腾讯课堂常驻讲师 主要内容&#xff1a;Java项目、毕业设计、简历模板、学习资料、面试题库、技术互助 收藏点赞不迷路 关注作者有好处 文末获取源…

GuLi商城-简介-项目介绍、分布式基础概念、微服务架构图

一、项目简介 1 、项目背景 1 &#xff09;、电商模式 市面上有 5 种常见的电商模式 B2B、B2C、C2B、C2C、O2O&#xff1b; 1 、 B2B 模式 B2B (Business to Business)&#xff0c; 是指商家与商家建立的商业关系。 如&#xff1a;阿里巴巴 2 、 B2C 模式 B2C (Business…

Win7安装高版本的NodeJS方法,亲测可用

Win7安装高版本的NodeJS方法 正常情况下&#xff0c;Win7所能支持的Node.js最高版本为:V13.14&#xff0c;在开发过程中&#xff0c;git下来的项目由于node版本比较高的原因&#xff0c;好多package都不能还原或出现诸多警告 网络大神分享的安装高版本的方法&#xff1a; 1、…

Express 通过 CORS 或 JSONP 解决跨域问题

文章目录参考描述同源策略同源同源策略示例CSRF 攻击解决跨域问题CORSCORS 响应头部Access-Control-Allow-Origin简单请求预检请求预检请求包含的两次请求解决CORS 中间件使用 CORS 中间件处理跨域请求JSONP通过原生 JS 向服务器端发起 JSONP 请求通过 jQuery 向客户端发起 JSO…

mysql:如何在windows环境下配置并随意切换两种mysql版本

系列文章目录 文章目录系列文章目录前言一、去官网下载zip安装包二、配置创建my.ini文件2.环境变量3、使用管理员身份打开dos命令窗口4、安装mysql8的服务和初始化data5、启动6 错误解决&#xff1a;修改mysql8服务的注册表最后前言 之前安装过5.7的版本 后来由于需要 就安装了…

天龙八部TLBB从0到1搭建教程-上

服务器的配置选择与购买 我们需要准备的东西,是环境安装和4核8G的服务器一台。 其实购买服务器的地方很多以下这些服务商都可以,具体看服务器的配置选择,像这种4核8G的 服务器价格在260-400之间一台仅供参考,当然还有带防的服务器价格就偏高了阿里云、腾讯、百度、西部数码…

年后公司新来一00后卷王,我们这帮老油条真干不过.....

都说00后躺平了&#xff0c;但是有一说一&#xff0c;该卷的还是卷。这不&#xff0c;我们公司来了个00后&#xff0c;工作没两年&#xff0c;跳槽到我们公司起薪18K&#xff0c;都快接近我了。后来才知道人家是个卷王&#xff0c;从早干到晚就差搬张床到工位睡觉了。 2023年春…

MySQL深分页 + 多字段排序场景的优化方案【三百万级数据量】

需求背景 目前产品需要针对一个大范围地区内的所有用户做排行榜功能&#xff0c;且这个排行榜有几个比较蛋疼的附加需求&#xff1a; 排行榜需要全量展示所有用户&#xff0c;且做分页展示&#xff08;大坑&#x1f4a5;&#xff09; 排行榜有4种排序条件&#xff0c;且每个排…

fatal error怎么解决,有什么快捷的解决方法

fatal error怎么解决&#xff0c;其实是有多种的解决方法的&#xff0c;主要是看你想用哪种解决方法去进行解决&#xff0c;下面一起来看看。 一.fatal error的解决方法 1、按winR&#xff0c;弹出运行窗口。 2、输入regedit点击确定&#xff0c;弹出注册表编辑器。 3、在注…

2023年数据库优化顶级原理

毫不夸张的说咱们后端工程师&#xff0c;无论在哪家公司&#xff0c;呆在哪个团队&#xff0c;做哪个系统&#xff0c;遇到的第一个让人头疼的问题绝对是数据库性能问题。如果我们有一套成熟的方法论&#xff0c;能让大家快速、准确的去选择出合适的优化方案&#xff0c;我相信…

Acwing-1116. 马走日

本题求有多少路径遍历棋盘上的所有点&#xff0c;属于外部搜索&#xff0c;所以需要回溯。另外&#xff0c;对于递归终止条件&#xff0c;我们添加一个参数用来表示当前遍历到第几个点&#xff0c;如果是n*m表明已经将棋盘遍历一遍了&#xff0c;方案数1&#xff0c;return即可…

Linux常用命令——screen命令

在线Linux命令查询工具(http://www.lzltool.com/LinuxCommand) screen 用于命令行终端切换 补充说明 Screen是一款由GNU计划开发的用于命令行终端切换的自由软件。用户可以通过该软件同时连接多个本地或远程的命令行会话&#xff0c;并在其间自由切换。GNU Screen可以看作是…

Linux学习笔记本(不定期持续更新)

一、概述 2023年&#xff0c;打算系统自学一遍Linux&#xff0c;分享到这里来&#xff0c;和大家一起相互学习&#xff0c;探讨。 二、Linux基础知识 Linux学习环境搭建学习每一门技术&#xff0c;系统环境很重要&#xff0c;好的系统环境能够极大提高学习效率。学习Linux也是一…

33. 实战:实现某网站店铺信息的查询与批量抓取(附源码)

目录 前言 目的 思路 代码实现 1. 请求URL&#xff0c;获取源代码 2. 解析源代码&#xff0c;获取数据 3. 完善保存数据的函数save_data 4. 理清main函数逻辑&#xff0c;循环传递每一页有效信息的参数 完整代码 运行效果 总结 前言 近日&#xff0c;我们每周四都能…

ESP-C3入门5. 使用通用计时器

ESP-C3入门5. 使用通用计时器一、 简介二、使用步骤三、操作函数1. 基本操作&#xff08;1&#xff09;定时器实例 gptimer_handle_t &#xff08;2&#xff09; 定时器配置结构体 gptimer_config_t&#xff08;3&#xff09; 定时器初始化 timer_init()&#xff08;3&#xff…