深入浅出 diffusion(2):pytorch 实现 diffusion 加噪过程

news2024/11/15 23:38:11

         我在上篇博客深入浅出 diffusion(1):白话 diffusion 原理(无公式)中介绍了 diffusion 的一些基本原理,其中谈到了 diffusion 的加噪过程,本文用pytorch 实现下到底是怎么加噪的。

import torch
import math
import numpy as np
from PIL import Image
import requests
import matplotlib.pyplot as plot
import cv2


def linear_beta_schedule(timesteps):
    """
    linear schedule, proposed in original ddpm paper
    """
    scale = 1000 / timesteps
    beta_start = scale * 0.0001
    beta_end = scale * 0.02
    return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64)

def cosine_beta_schedule(timesteps, s = 0.008):
    """
    cosine schedule
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype = torch.float64) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
    
   
# 时间步(timestep)定义为1000
timesteps = 1000

# 定义Beta Schedule, 选择线性版本,同DDPM原文一致,当然也可以换成cosine_beta_schedule
betas = linear_beta_schedule(timesteps=timesteps)

# 根据beta定义alpha 
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# 计算前向过程 diffusion q(x_t | x_{t-1}) 中所需的
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)


def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)

# 前向加噪过程: forward diffusion process
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
        cv2.imwrite('noise.png', noise.numpy()*255)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )
    
    print('sqrt_alphas_cumprod_t :', sqrt_alphas_cumprod_t)
    print('sqrt_one_minus_alphas_cumprod_t :', sqrt_one_minus_alphas_cumprod_t)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

# 图像后处理
def get_noisy_image(x_start, t):
  # add noise
  x_noisy = q_sample(x_start, t=t)

  # turn back into PIL image
  noisy_image = x_noisy.squeeze().numpy()

  return noisy_image

...

# 展示图像, t=0, 50, 100, 500的效果
x_start = cv2.imread('img.png') / 255.0
x_start = torch.tensor(x_start, dtype=torch.float)
cv2.imwrite('img_0.png', get_noisy_image(x_start, torch.tensor([0])) * 255.0)
cv2.imwrite('img_50.png', get_noisy_image(x_start, torch.tensor([50])) * 255.0)
cv2.imwrite('img_100.png', get_noisy_image(x_start, torch.tensor([100])) * 255.0)
cv2.imwrite('img_500.png', get_noisy_image(x_start, torch.tensor([500])) * 255.0)
cv2.imwrite('img_999.png', get_noisy_image(x_start, torch.tensor([999])) * 255.0)


sqrt_alphas_cumprod_t : tensor([[[0.9999]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.0100]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.9849]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.1733]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.9461]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.3238]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.2789]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[0.9603]]], dtype=torch.float64)
sqrt_alphas_cumprod_t : tensor([[[0.0064]]], dtype=torch.float64)
sqrt_one_minus_alphas_cumprod_t : tensor([[[1.0000]]], dtype=torch.float64)

        以下分别为原图,t = 0, 50, 100, 500, 999 的结果。

        可见,随着 t 的加大,原图对应的比例系数减小,噪声的强度系数加大,t = 500的时候,隐约可见人脸轮廓,t = 999 的时候,人脸彻底淹没在噪声里面了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1411529.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

最小覆盖子串(Leetcode76)

例题: 分析: 比如现在有字符串(s),s "ADOBECODEBANC", 给出目标字符串 t "ABC", 题目就是要从原始字符串(s)中找到一个子串(res)可以覆盖目标字符串 t &…

UE使用C++添加FGameplayTag(游戏标签)

首先Ue会有一个UGameplayTagsManager类型的对象 游戏标签管理器(全局中就有一个) 我们直接通过 UGameplayTagsManager::Get()静态函数拿到 全局唯一的游戏标签管理器的实例 返回的是个左值引用 之后通过调用 AddNativeGameplayTag()函数就可添加游戏标签了 就这么简单 第…

Java+Spring Cloud +Vue+UniApp微服务智慧工地云平台源码

目录 智慧工地云平台功能 【劳务工种】所属工种有哪些? 1.管理人员 2.信息采集 3.证件管理 4.考勤管理 5.考勤明细 6.工资管理 7.现场统计 8.WIFI教育 9.课程库管理 10.工种管理 11.分包商管理 12.班组管理 13.项目管理 智慧工地管理平台是以物联网、…

算法题 — 删除排序数组中的重复项

问题:一个有序数组 nums,原地删除重复出现的元素,使每个元素只出现一次,返回删除后数组的新长度。 注:不能使用额外的数组空间,必须在原地修改输入数组并在使用 O(1) 额外空间的条件下完成。 例&#xff…

【考研结束了,不管上不上岸,我建议你先....】

*** 考研结束,一定要做这几件事! 又一年考研季的落幕,经历了漫长考研岁月的学子们,终于迎来了期盼已久的解脱。参加考研的同学们必须都顺利上岸。 然而对于技术类专业的考生而言,新的征程与机遇才刚刚启航。 此时此刻…

专业144总分410+华南理工大学811信号与系统考研经验华工电子信息与通信

今年专业811信号与系统144(二战,感谢信息通信Jenny老师专业课对我的巨大提高,第一年自己复习只考了90,主要栽专业课和数学)总分410含泪(二战的同学都知道苦,成功来之不易)考上华南理…

InterSystem IRIS BS BP BO配置

应用:根据请求的BS,通过BP,到BO的处理,集成平台BO获取数据并推送给指定第三方 操作步骤: 一、事前准备: 创建交互服务前提前将SQL网关创建和连接好。需记录网关连接名称,配置在BS设置的DSN处…

HMI-Board以太网数据监视器(二)MQTT和LVGL

E ∫ d E ∫ k d q r 2 k L ∫ d q r 2 E \int dE \int \frac{kdq}{r^2} \frac{k}{L} \int \frac{dq}{r^2} E∫dE∫r2kdq​Lk​∫r2dq​ E Q 2 π ϵ L 2 E \frac{Q}{2\pi\epsilon L^2} E2πϵL2Q​ Γ ( n ) ( n − 1 ) ! ∀ n ∈ N \Gamma(n) (n-1)!\quad\forall n…

实习日志5

活字格图片上传功能(批量) 这个报错真的恶心,又看不了他服务器源码,接口文档又是错的 活字格V9获取图片失败bug,报错404-CSDN博客 代码BUG记录: 问题:上传多个文件的base64编码被最后一个文…

eclipse启动Java服务及注意事项

1、导入项目 选择file——》import…——》Generate——》Exiting Projects into Workspace——》选择要导入的项目 2、添加tomcat 1)点击Serves——》No servers are available. Click this link to create a new server… 2)点击“Add…” 3&…

【Servlet】如何编写第一个Servlet程序

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【Servlet】 本专栏旨在分享学习Servlet的一点学习心得,欢迎大家在评论区交流讨论💌 Servlet是Java编写的服务器端…

STM32第一节——初识STM32

1 硬件介绍 1.1 硬件平台 配套硬件:以野火的STM32 F1霸道开发板为平台,若用的是别的开发板,可自己进行移植。 1.2 什么是STM32 STM32是由意法半导体(STMicroelectronics)公司推出的一系列32位的ARM Cortex-M微控制…

CMIP6数据驱动WRF和WRF-Chem模式、WRF-Chem的未来大气污染变化模拟

目录 一、CMIP6数据及运行平台建设 二、CMIP6数据驱动WRF和WRF-Chem模式 三、WRF-Chem的未来情景模拟 更多应用 对模式比较计划的全球气候预估数据进行动力降尺度,结合预估的未来气候变化,运用区域气候模式和气候-化学耦合模式,实现对未来…

【数据结构】数据结构初识

前言: 数据结构是计算存储,组织数据的方式。数据结构是指相互间存在一种或多种特定关系的数据元素的集合。通常情况下,精心选择的数据结构可以带来更高的运行或者存储效率。数据结构往往同高效的检索算法和索引技术有关。 Data Structure Vi…

day2 C++

封装一个矩形类(Rect)&#xff0c;拥有私有属性:宽度(width)、高度(height)&#xff0c; 定义公有成员函数: 初始化函数:void init(int w, int h) 更改宽度的函数:set_w(int w) 更改高度的函数:set_h(int h) 输出该矩形的周长和面积函数:void show() #include <iostre…

笔记--写代码好习惯

原文&#xff1a;写代码有这16个好习惯&#xff0c;可以减少80%非业务的bug

uniapp vuecli项目融合[小记]:将多个项目融合,打包成一个小程序/App,拆分多个H5应用

前言&#xff1a; 目前两个uniapp vuecli开发的项目【A、B】&#xff0c;新规划的项目C&#xff1a;需要融合项目B 80%的功能模块&#xff0c;同时也需要涵盖项目A的所有功能模块。 应用需求&#xff1a; 1、新项目C【小程序】可支持切换到应用A/C界面【内部通过初始化、路由跳…

对于小微企业而言,数字化转型的重要性是什么?

数字化转型对于小微企业至关重要&#xff0c;原因如下&#xff1a; 1.效率和生产力&#xff1a;采用数字工具和技术可以简化业务流程、自动化重复性任务并提高整体效率。这使得小型企业能够用更少的资源完成更多的工作&#xff0c;最终提高生产力。 2.节省成本&#xff1a;数…

Gradle学习笔记:Gradle的简介、下载与安装

文章目录 一、什么是Gradle二、为什么选择Gradle三、下载并安装Gradle四、Gradle的bin目录添加到环境变量五、测试Gradle是否安装正常 一、什么是Gradle Gradle是一个开源构建自动化工具&#xff0c;专为大型项目设计。它基于DSL&#xff08;领域特定语言&#xff09;编写&…

江科大STM32 中

目录 6、TIM&#xff08;Timer&#xff09;定时器基本定时器通用定时器高级定时器示例程序&#xff08;定时器定时中断&定时器外部时钟&#xff09;TIM输出比较示例程序&#xff08;PWM驱动LED呼吸灯&PWM驱动舵机&PWM驱动直流电机&#xff09;TIM输入捕获示例程序&…