《昇思25天学习打卡营第27天》

news2024/11/27 10:44:19

今天我们继续Diffusion扩散模型的后半部分学习

条件U-Net

网络构建过程如下:

  • 首先,将卷积层应用于噪声图像批上,并计算噪声水平的位置

  • 接下来,应用一系列下采样级。每个下采样阶段由2个ResNet/ConvNeXT块 + groupnorm + attention + 残差连接 + 一个下采样操作组成

  • 在网络的中间,再次应用ResNet或ConvNeXT块,并与attention交织

  • 接下来,应用一系列上采样级。每个上采样级由2个ResNet/ConvNeXT块+ groupnorm + attention + 残差连接 + 一个上采样操作组成

  • 最后,应用ResNet/ConvNeXT块,然后应用卷积层

最终,神经网络将层堆叠起来,就像它们是乐高积木一样(但重要的是了解它们是如何工作的)。

代码如下

class Unet(nn.Cell):
    def __init__(
            self,
            dim,
            init_dim=None,
            out_dim=None,
            dim_mults=(1, 2, 4, 8),
            channels=3,
            with_time_emb=True,
            convnext_mult=2,
    ):
        super().__init__()

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.CellList(
                    [
                        block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Downsample(dim_out) if not is_last else nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.CellList(
                    [
                        block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Upsample(dim_in) if not is_last else nn.Identity(),
                    ]
                )
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def construct(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)
        return self.final_conv(x)

正向扩散

  • 我们将正向过程方差设置为常数,从𝛽1=10−41=10−4线性增加到𝛽𝑇=0.02=0.02。

  • 但是,它在(Nichol et al., 2021)中表明,当使用余弦调度时,可以获得更好的结果。

代码

from mindspore.dataset import ImageFolderDataset

image_size = 128
transforms = [
    Resize(image_size, Inter.BILINEAR),
    CenterCrop(image_size),
    ToTensor(),
    lambda t: (t * 2) - 1
]


path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
                             extensions=['.jpg', '.jpeg', '.png', '.tiff'],
                             num_shards=1, shard_id=0, shuffle=False, decode=True)
dataset = dataset.project('image')
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)

文末附上打卡时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961957.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JAVA(IO流-字符流)day 7.30

ok了家人们今天继续学习IO流, 一.字符集 使用字节流输出中文可能有乱码。 因为每次读取的字节没有完全读取一个字的字节。 二.字符流 2.1 字符输出流【Writer】(抽象类) Writer是所有字符流的超类(父类) 字符输出…

蚓链数字化营销系统:“爆省”!“爆赚”!“爆值”!“爆快”!“爆增”!“爆享”!

随着信息技术的飞速发展和消费者行为的深刻变化,数字化营销已成为企业在市场竞争中取得优势的关键手段。蚓链数字化营销系统凭借其创新的功能和策略,为企业带来了一系列“爆”优势! “按效果付费--信息化建设费用爆省”! “按效果…

Win11没有记事本怎么办?更新至win11无法右键新建txt文件?

博主更新至Win11系统后目前用了不到一个月时间,今天突然发现 鼠标右键无法新建txt文件 了,一开始还以为Win11系统不支持txt类型文件,遂查找各种网上恢复教程。本文综合了多篇教程的方法,力求一文解决所有可能出现的情况&#xff0…

网络安全是什么?怎么入门网络安全?

一、网络安全的定义 网络安全,简单来说,就是保护网络系统中的硬件、软件以及其中的数据不因偶然或恶意的原因而遭到破坏、更改、泄露,保障系统连续可靠正常地运行,网络服务不中断。 随着信息技术的飞速发展,网络安全的…

JAVA基础 - 网络编程

目录 一. 网络基础 BS(Browser/Server,浏览器/服务器架构) CS(Client/Server,客户端/服务器架构) 二. TCP Socket通信 三. Socket类 四. 聊天实例 五. UDP Socket 六. 数据交换格式 一. 网络基础 网…

力反馈设备在远程机器人遥操作中的应用实例

随着科技的飞速发展,力反馈设备在远程机器人遥操作中的应用日益广泛,极大地提升了操作的精确性和安全性。其中,Haption Virtuose 6D力反馈设备以其卓越的性能成为该领域的佼佼者。 Haption Virtuose 6D力反馈设备医疗遥操作应用 在医疗领域&a…

公布一批脸书爬虫(facebook)IP地址,真实采集数据

一、数据来源: 1、这批脸书爬虫(facebook)IP来源于尚贤达猎头公司网站采集数据; ​ 2、数据采集时间段:2023年10月-2024年7月; 3、判断标准:主要根据用户代理是否包含“facebook”和IP核实。…

谷粒商城实战笔记-92~96-商品发布和查询

文章目录 Spu列表检索接口。Sku列表检索接口。仓库列表接口。问题记录 这一篇包含如下内容: 92-商品服务-API-新增商品-商品保存其他问题处理93-商品服务-API-商品管理-SPU检索94-商品服务-API-商品管理-SKU检索95-仓储服务-API-仓库管理-整合ware服务&获取仓库…

【云原生】Kubernetes中的定时任务CronJob的详细用法与企业级应用案例分享

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

皮尔逊(Person)相关系数

目录 一、总体和样本 二、总体皮尔逊相关系数 三、样本皮尔逊相关系数 四、皮尔逊相关系数的理解误区 五、总结 六、相关系数大小的解释(并不严格) 七、皮尔逊相关系数的计算 一、总体和样本 二、总体皮尔逊相关系数 三、样本皮尔逊相关系数 四、…

Java并发之ThreadLocal

1. 简介 ThreadLocal是 Java 中提供的一种用于实现线程局部变量的工具类。它允许每个线程都拥有自己的独立副本,从而实现线程隔离,用于解决多线程中共享对象的线程安全问题。 通常,我们会使用 synchronzed 关键字 或者 lock 来控制线程对临…

iPhone 手机如何查看自己的电话号码?这两种方法都可以

设置应用程序查看 第一种查看自己的电话号码的方法是在设置应用程序中的电话选项中查看当前手机的电话号码,下面是具体的操作步骤: 首先我们先打开设置应用程序,然后往下滑动找到电话选项,点击进入。 然后就可以看见界面中有“本…

【100个深度学习实战项目】目标检测、语义分割、目标追踪、图像分类等应有尽有,持续更新~~~

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】共同学习交流~ 👍感谢小伙伴们点赞、关注! 引言 本文主要介绍深度学习在各…

智能城市管理系统设计思路详解:集成InfluxDB、Grafana和MQTTx协议(代码示例)

引言 随着城市化进程的加快,城市管理面临越来越多的挑战。智能城市管理系统的出现,为城市的基础设施管理、资源优化和数据分析提供了现代化的解决方案。本文将详细介绍一个基于开源技术的智能城市管理系统,涵盖系统功能、技术实现、环境搭建…

【扒模块】DFF

图 医学图像分割任务 代码 import torch import torch.nn as nnfrom timm.models.layers import DropPath # 论文:D-Net:具有动态特征融合的动态大核,用于体积医学图像分割(3D图像任务) # https://arxiv.org/abs/2403…

[C#]基于wpf实现的一百多种音色的Midi键盘软件

键盘 音色库 源码地址:https://download.csdn.net/download/FL1623863129/89599322

前端必知必会-html表单的基本使用

文章目录 HTML 表单<form> 元素<input> 元素文本字段<label> 元素单选按钮复选框提交按钮<input> 的 Name 属性总结 HTML 表单 HTML 表单用于收集用户输入。用户输入通常发送到服务器进行处理。 <form> 元素 HTML <form> 元素用于创建 H…

无人机环保行业解决方案-河道自动巡检

搭配大疆机场&#xff0c;智能化巡检 轻量一体化设计 相较于其他市面产品&#xff0c;大疆机场更加小巧&#xff0c;占地面积不足1平方米。 展开状态&#xff1a;长1675 mm&#xff0c;宽895 mm&#xff0c;高530mm&#xff08;不含气象站&#xff09; 闭合状态&#xff1a;…

C++之引用(详解,引用与指针的区别)

目录 1. 引⽤的概念和定义 2. 引⽤的特性 3. 引⽤的使⽤ 4. const引⽤ 5. 指针和引⽤的关系 1. 引⽤的概念和定义 引⽤不是新定义⼀个变量&#xff0c;⽽是给已存在变量取了⼀个别名(相当于是给变量起了个外号)&#xff0c;编译器不会为引⽤变量开辟内存空间&#xff0c;它…

冷月大佬的EVA-GAN: 可扩展的音频生成GAN结构-论文阅读笔记

前言&#xff1a; 最近在调研asr和tts相关的内容&#xff0c;准备开始学习相关的经典工作&#xff0c;为了倒逼自己学进脑子&#xff0c;我尝试将这些看到的一些好的学习笔记分享出来。 萌新学徒&#xff0c;欢迎相关领域的朋友推荐学习路径和经典工作&#xff01;拜谢&#x…