Mindspore框架利用扩散模型DDPM生成高分辨率图像|(三)模型训练与推理实践

news2024/9/20 16:55:43

利用扩散模型DDPM生成高分辨率图像(生成高保真图像项目实践)

Mindspore框架利用扩散模型DDPM生成高分辨率图像|(一)关于denoising diffusion probabilistic model (DDPM)模型
Mindspore框架利用扩散模型DDPM生成高分辨率图像|(二)数据集准备与处理
Mindspore框架利用扩散模型DDPM生成高分辨率图像|(三)模型训练与推理实践


一、扩散模型DDPM模型训练

扩散模型生成新图像是通过反转扩散过程来实现。理想情况下,我们最终会得到一个看起来像是来自真实数据分布的图像。

def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)

    if t_index == 0:
        return model_mean
    posterior_variance_t = extract(posterior_variance, t, x.shape)
    noise = randn_like(x)
    return model_mean + ops.sqrt(posterior_variance_t) * noise

def p_sample_loop(model, shape):
    b = shape[0]
    # 从纯噪声开始
    img = randn(shape, dtype=None)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)
        imgs.append(img.asnumpy())
    return imgs

def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

1.1 模型初始化

# 定义动态学习率
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)

# 定义 Unet模型
unet_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)

name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
    name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
    item.name = name_list[i]
    i += 1

# 定义优化器
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
loss_scaler = DynamicLossScaler(65536, 2, 1000)

# 定义前向过程
def forward_fn(data, t, noise=None):
    loss = p_losses(unet_model, data, t, noise)
    return loss

# 计算梯度
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

# 梯度更新
def train_step(data, t, noise):
    loss, grads = grad_fn(data, t, noise)
    optimizer(grads)
    return loss

1.2 模型训练

import time

epochs = 10

iterator = dataset.create_tuple_iterator(num_epochs=epochs)
for epoch in range(epochs):
    begin_time = time.time()
    for step, batch in enumerate(iterator):
        unet_model.set_train()
        batch_size = batch[0].shape[0]
        t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
        noise = randn_like(batch[0])
        loss = train_step(batch[0], t, noise)

        if step % 500 == 0:
            print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
    end_time = time.time()
    times = end_time - begin_time
    print("training time:", times, "s")
    # 展示随机采样效果
    unet_model.set_train(False)
    samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
    plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")

在这里插入图片描述

二、模型推理预测

2.1 推理过程

从模型中采样

# 采样64个图片
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
# 展示一个随机效果
# random_index = 5
# plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

创建去噪过程

import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
animate.save('diffusion.gif')
plt.show()

在这里插入图片描述

三、参考文献

  1. The Annotated Diffusion Model
  2. 由浅入深了解Diffusion Model
  3. Diffusion扩散模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1995631.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

告别杂音,从 AI 音频降噪开始

生活中,音频无处不在。无论是聆听动人的音乐,还是参与重要的电话会议,又或是沉浸于精彩的网课学习,清晰、纯净的音频质量都至关重要。然而,音频中的噪声却像不速之客,扰乱着这份美好。 音频中的噪声形式多样…

封装一个给 .NET Framework 用的内存缓存帮助类

前言 .NET Core 中已经内置了内存缓存相关的类和操作方法,直接就能使用,非常方便。但在 .NET Framework 中,如果想要使用内存缓存,需要自己进行封装。本文分享一个我自己项目中封装的内存缓存帮助类,有需要的童鞋可以…

前端已经学会vue,做粒子效果

目录 1. Canvas API 2. WebGL 3. 粒子系统 4. 动画与性能优化 5. 现有库和框架 6. Vue 组件和状态管理 实践项目建议 案例1 案例2雪花 已经熟悉了 Vue、TypeScript 和 JavaScript,下面是一些你可以学习的内容,以帮助你实现粒子效果的界面&#…

深度学习基础 - 梯度垂直于等高线的切线

深度学习基础 - 梯度垂直于等高线的切线 flyfish 梯度 给定一个标量函数 f ( x , y ) f(x, y) f(x,y),它的梯度(gradient)是一个向量,表示为 ∇ f ( x , y ) \nabla f(x, y) ∇f(x,y),定义为: ∇ f ( x…

单片机GPIO模式和应用

Push pull 推挽输出 定义:推挽输出是一种输出模式,其中引脚可以输出高电平或低电平,且两种电平状态下都具有较强的驱动能力。 特点: 无论输出高电平还是低电平,都有较强的电流驱动能力。 适用于驱动外部数字电路…

宝塔面板启用 QUIC 与 Brotli 的完整教程

环境 系统:Ubuntu 22.04.4 LTS x86_64 宝塔版本:7.7.0 (可使用本博客提供的一键安装优化脚本) nginx版本:1.26.1 开放UDP端口 注意:在你的服务器商家那里也要开放443 udp端口 sudo ufw allow 443/udp然后重新加载 UFW 以使新…

【漏洞复现】maxView Storage Manager 远程代码执行漏洞

maxView Storage Manager使查看、监控和配置系统中基于Microsemi RAID适配器构建的所有存储变得简单。⽅便的图形⽤户界⾯(GUI)在Microsemi产品线和⽀持的操作系统(包括 Windows、Linux、VMWare和Solaris)中的外观和操作都相同。使…

多线程编译

多线程与多进程一样,为了能同时执行多个任务 区别 多进程 创建子进程,子进程会拷贝父进程的数据段的所有内存 进程是资源的获取单位 每个进程完全独立运行 更加关注两个进程之间的通信问题 多线程 线程是进程的最小组成单位,每个进程…

代码随想录算法训练营Day32 | 56. 合并区间 | 738.单调递增的数字 | 968.监控二叉树

今日任务 56. 合并区间 题目链接&#xff1a; https://leetcode.cn/problems/merge-intervals/题目描述&#xff1a; Code class Solution { public:vector<vector<int>> merge(vector<vector<int>>& intervals) {ranges::sort(intervals, [&…

Spring:springboot集成jetcache循环依赖问题

springboot版本&#xff1a;2.6.14 jetcache版本&#xff1a;2.6.2 启动项目报错如下&#xff1a; 解决方案&#xff1a; jetcache版本升级到2.6.4 https://github.com/alibaba/jetcache/issues/624

IT运维岗适用的6本证书

作为IT从业人员&#xff0c;不断提升自身的专业技能和知识是提升职场竞争力、助力升职加薪的重要途径。特别是在运维领域&#xff0c;虽然工作看似简单&#xff0c;但实际上需要掌握的技术知识却相当全面。为了全面提升自己的技术能力&#xff0c;并证明自己的专业能力&#xf…

每周心赏|七夕这样玩也太超前了吧,速来AI一下!

明天就是七夕节了&#xff0c;是时候给七夕节来点大震撼了&#xff0c;AI带你玩点不一样的&#xff01; 给大家挖掘了几个有梗又有爱的智能体。信我&#xff0c;快来试玩&#xff01; 不知道大家是什么人&#xff1f;反正&#xff0c;我是一个很爱测评的人&#x1f92d;&#…

【GaussDB(DWS)】数仓部署架构与物理结构分析

数仓架构与物理结构分析 一、部署架构二、物理结构三、测试验证 一、部署架构 华为数据仓库服务DWS&#xff0c;集群版本8.1.3.x 集群拓扑结构&#xff1a; 上述拓扑结构为DWS单AZ高可靠部署架构&#xff0c;为减少硬件故障对系统可用性的影响&#xff0c;建议集群部署方案遵…

制造企业技术图纸不受控的影响与规避方法

在制造企业中&#xff0c;技术图纸是产品设计、制造与检验的核心依据。若技术图纸不受控&#xff0c;将对企业造成诸多不利影响。 首先&#xff0c;产品质量无法得到保障。不受控的图纸可能存在设计缺陷、尺寸误差或工艺不合理等问题&#xff0c;导致生产出的产品不合格&#…

独辟蹊径:用Python打造你的副业帝国,迈向财富自由

在当今这个数字化时代&#xff0c;掌握一门编程语言如同拥有了一把开启无限可能的钥匙。Python&#xff0c;以其简洁的语法、强大的库支持和广泛的应用领域&#xff0c;成为了许多人实现副业收入乃至财富自由的首选工具。本文将探讨如何利用Python技能开启副业&#xff0c;并逐…

mysql中的表查询操作

performance_schema 系统数据库用于收集Mysql服务器的性能参数&#xff0c;以便数据库管理员了解产生性能瓶颈的原因。information_schema 系统数据库定义了所有数据库对象的元数据信息。 表的常规操作&#xff08;增删改查&#xff09; 我们经常对表进行以下操作 插入&#x…

OceanMind海睿思受邀参加第41届CCF中国数据库学术会议

CCF 中国数据库学术会议始于1977年&#xff0c;是由数据库专业委员会举办的中国数据库领域的最高学术会议&#xff0c;第41届中国数据库学术会议&#xff08;NDBC 2024&#xff09;将于2024年8月7日-8月10日在新疆乌鲁木齐举行。中新赛克副总兼大数据产品线总经理卢云川先生受邀…

匹配格值的前半部分

Excel有多列含空格的源数据&#xff0c;如C3:D19&#xff1b;还有若干用于比较的数据项&#xff0c;由"-"隔为前后两部分&#xff0c;如F3:F7。 要求用源数据的每列与数据项的前半部分进行比较&#xff0c;将匹配上的数据项填在该列下面。 使用 SPL XLL spl("d…

DALL•E 3 重新定义图像生成的人工智能

在人工智能的不断发展中&#xff0c;图像生成技术一直是一个备受关注的领域。OpenAI 的 DALL-E 系列自发布以来&#xff0c;便因其卓越的图像生成能力而备受瞩目。作为这一系列的最新成员&#xff0c;DALL-E 3 再次突破了技术的界限&#xff0c;为图像生成带来了全新的可能性。…

嵌入式day23

实现minishell minishell功能&#xff1a; 1,cp 复制文件 cp 1 2 把文件1复制成文件2 2,cat 查看文件 cat 1 查看文件到内容 3,cd 切换路径 cd 1 切换到目录1中 4,ls 查看当前目录下到文件 ls 或 ls /home 5,ll 查看当前目录下到文件 ll 或 ll /home 6,ln -s 创建软链接…