AIGC笔记--基于DDPM实现图片生成

news2024/10/4 10:31:34

目录

1--扩散模型

2--训练过程

3--损失函数

4--生成过程

5--参考


1--扩散模型

完整代码:ljf69/DDPM

扩散模型包含两个过程,前向扩散过程和反向生成过程。

前向扩散过程对一张图像逐渐添加高斯噪声,直至图像变为随机噪声。

反向生成过程从一个随机噪声开始,逐渐去噪声直至生成一张图像。

2--训练过程

通过以下公式对图像进行加噪:

def forward(self, x0, t, eta = None):
    n, c, h, w = x0.shape # 输入图片的shape
    a_bar = self.alpha_bars[t]
    if eta is None:
        eta = torch.randn(n, c, h, w).to(self.device)
    noisy = a_bar.sqrt().reshape(n, 1, 1, 1) * x0 + (1 - a_bar).sqrt().reshape(n, 1, 1, 1) * eta # 加噪
    return noisy # 返回加噪结果

3--损失函数

通过一个UNet网络来预测损失,计算预测损失和真实损失MSE损失:

...
eta = torch.randn_like(x0).to(device) # 产生真实随机噪声
t = torch.randint(0, n_steps, (n,)).to(device)

# 前向扩散过程
noisy_imgs = ddpm(x0, t, eta)

# 通过UNet预测噪声
eta_theta = ddpm.backward(noisy_imgs, t.reshape(n, -1))

# 计算预测噪声和真实随机噪声的MSE损失
loss = mse(eta_theta, eta)
...

4--生成过程

通过以下公式实现图片生成:

x = torch.randn(n_samples, c, h, w).to(device) # 随机初始化噪声
for idx, t in enumerate(list(range(ddpm.n_steps))[::-1]):
    time_tensor = (torch.ones(n_samples, 1) * t).to(device).long()
    eta_theta = ddpm.backward(x, time_tensor)
    alpha_t = ddpm.alphas[t]
    alpha_t_bar = ddpm.alpha_bars[t]

    x = (1 / alpha_t.sqrt()) * (x - (1 - alpha_t) / (1 - alpha_t_bar).sqrt() * eta_theta) # 去噪
    if t > 0:
        z = torch.randn(n_samples, c, h, w).to(device)
        beta_t = ddpm.betas[t]
        sigma_t = beta_t.sqrt()
        x = x + sigma_t * z

5--参考

怎么理解今年 CV 比较火的扩散模型(DDPM)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1124965.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

--initialize specified but the data directory has files in it. Aborting. 问题解决

当电脑输入这条命令以试图初始化数据库的时候,出现这样的错误。 2023-10-23T09:04:21.258180Z 0 [Warning] TIMESTAMP with implicit DEFAULT value is deprecated. Please use --explicit_defaults_for_timestamp server option (see documentation for more deta…

Spark SQL概述与基本操作

目录 一、Spark SQL概述 (1)概念 (2)特点 (3)Spark SQL与Hive异同 (4)Spark的数据抽象 二、Spark Session对象执行环境构建 (1)Spark Session对象 (2)代码演…

Python-字符串(切片操作与内建函数)

目录 一、字符串介绍 1、什么是字符串 2、转义字符 二、字符串的输入和输出 1、字符串输出 2、字符串输入 三、访问字符串中的值 1、字符串的存储方式 2、使用切片截取字符串 四、字符串内建函数 1、find 2、index 3、count 4、replace 5、split 6、capitalize …

Centos 7 Zabbix配置安装

前言 Zabbix是一款开源的网络监控和管理软件,具有高度的可扩展性和灵活性。它可以监控各种网络设备、服务器、虚拟机以及应用程序等,收集并分析性能指标,并发送警报和报告。Zabbix具有以下特点: 1. 支持多种监控方式:可…

Docker容器引擎的介绍

目录 Docker概述 容器受欢迎的原因 Docker与虚拟机的区别 Docker三个核心概念 Docker的安装 1、环境准备 2、安装依赖包 3、设置阿里云镜像源 4、安装 Docker-CE并设置为开机自动启动 Docker命令 1、查看 docker 版本信息 2、docker 信息查看 3、Docker 镜像操作命…

GoLong的学习之路(五)语法之数组

书接上回,上回书说到,循环语句,在go中循环语句的少了whlie这个关键词,但是与之for可以改这个改这个特点。并且在终止关键词中,又有标签可以方便,停止。这次说数组 文章目录 Array(数组)数组的初始化方法一方…

数据结构堆详解

[TOC]堆详解 一,堆 1.1堆的概念 堆的性质: 堆中某个节点的值总是不大于或不小于其父节点的值; 堆总是一棵完全二叉树。 1.2堆的存储模式 我们前面的文章提到过,二叉树的两种存储模式,一个是顺序存储,一…

网络第一颗

✍ 如何理解局域网和广域网? ✍ 路由器和交换机是怎样工作的? ✍ 三层交换机能不能代替路由器? -- 1.局域网 2. 广域网 -- -- 企业网络 运营商架构 数据中心架构 -- 局域网 - 内网 - 私网 -- 通过交换机连接的 转发相同IP地址段的…

NVIDIA显卡算力表--nvidia显卡算力表

参考链接:https://blog.csdn.net/qq_41070955/article/details/108269915 官方链接:https://developer.nvidia.com/cuda-gpus

电压放大器在工业领域有哪些用途

电压放大器在工业领域中有广泛的应用,其主要功能是将传感器或其他信号源的微小电压信号放大为更大幅度的电压信号,以便进行后续的信号处理、控制和监测。以下是电压放大器在工业领域中的一些常见用途: 传感器信号放大:工业生产中经…

Java 通过反射修改字符串 String 类型变量的取值而不改变字符串变量的指向

注意点 由于 JDK 8 中有关反射相关的功能自从 JDK 9 开始就已经被限制了,如:通过反射修改 String 类型变量的 value 字段(final byte[]),所以要能够使用运行此方法,需要在运行项目时,添加虚拟机(VM)选项:-…

map set 使用快速上手【C++】

目录 一,关联式容器 二,键值对 三,set 1)使用参考此文档 2)count 函数 3)multiset类 四,map 1. 模板参数介绍 2.operator[]介绍 3. multimap 英语比较好的同学可以自行查找文档 学…

springboot+avue框架开发的医院绩效考核系统全套源码

医院综合绩效核算系统全套源码 (应用案例自主版权演示) 医院绩效考核系统以医院的发展战略为导向,把科室、员工的绩效考核跟战略发展目标紧密结合,引导医院各个科室、各员工的工作目标跟医院的发展目标结合在一起,实现…

代码随想录Day26 贪心01 LeetCode T53 最大子数组和

LeetCode T53 最大子数组和 题目链接:53. 最大子数组和 - 力扣(LeetCode) 题目思路: 贪心贪的是哪里呢? 如果 -2 1 在一起,计算起点的时候,一定是从 1 开始计算,因为负数只会拉低总和,这就是贪…

VPN访问外网的原理

一.前言 许多人都用VPN翻墙,那么VPN为什么可以做到访问外网? VPN的全称叫“Virtual Private Network”意思就是虚拟私人专用网络,是专用网络的延伸,通过VPN,可以模拟点对点专用连接的方式,通过共享和公共网…

对知识蒸馏的一些理解

知识蒸馏是一种模型压缩技术,它通过从一个大模型(教师模型)中传输知识到一个小模型(学生模型)中来提高学生模型的性能,知识蒸馏也要用到真实的数据集标签。 软损失soft loss就是拿教师模型在蒸馏温度为T的…

Ai写作创作系统ChatGPT网站源码+图文搭建教程+支持GPT4.0+支持ai绘画(Midjourney)/支持OpenAI GPT全模型+国内AI全模型

一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统AI绘画系统,支持OpenAI GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署…

map 和 set 的一起使用

map 和 set 一起使用的场景其实也蛮多的,最近业务上就遇到了。需求是这样的,一条路径(mpls中的lsp)会申请多个 id,这个 id 是独一无二的。这里很显然就就一个”一对多“的情况,合适用这个容器不保存这些信息…

【Java集合类面试九】、介绍一下HashMap的扩容机制

文章底部有个人公众号:热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享? 踩过的坑没必要让别人在再踩,自己复盘也能加深记忆。利己利人、所谓双赢。 面试官:介绍一下HashMap的扩容机…

【Java集合类面试七】、 JDK7和JDK8中的HashMap有什么区别?

文章底部有个人公众号:热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享? 踩过的坑没必要让别人在再踩,自己复盘也能加深记忆。利己利人、所谓双赢。 面试官:JDK7和JDK8中的HashMap有…