昇思25天学习打卡营第19天|munger85

news2024/12/24 20:42:46

Diffusion扩散模型

它并没有那么复杂,它们都将噪声从一些简单分布转换为数据样本,Diffusion也是从纯噪声开始通过一个神经网络学习逐步去噪,最终得到一个实际图像

def rearrange(head, inputs):
b, hc, x, y = inputs.shape
c = hc // head
return inputs.reshape((b, head, c, x * y))

def rsqrt(x):
res = ops.sqrt(x)
return ops.inv(res)

def randn_like(x, dtype=None):
if dtype is None:
dtype = x.dtype
res = ops.standard_normal(x.shape).astype(dtype)
return res

def randn(shape, dtype=None):
if dtype is None:
dtype = ms.float32
res = ops.standard_normal(shape).astype(dtype)
return res

def randint(low, high, size, dtype=ms.int32):
res = ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)
return res

def exists(x):
return x is not None

def default(val, d):
if exists(val):
return val
return d() if callable(d) else d

def _check_dtype(d1, d2):
if ms.float32 in (d1, d2):
return ms.float32
if d1 == d2:
return d1
raise ValueError(‘dtype is not supported.’)

class Residual(nn.Cell):
def init(self, fn):
super().init()
self.fn = fn

def construct(self, x, *args, **kwargs):
    return self.fn(x, *args, **kwargs) + x

这些是辅助的方法
def Upsample(dim):
return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode=“pad”, padding=1)

def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, pad_mode=“pad”, padding=1)
上面是上,下采样
由于噪声是水平的,那么位置就用sin来表示
class SinusoidalPositionEmbeddings(nn.Cell):
def init(self, dim):
super().init()
self.dim = dim
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = np.exp(np.arange(half_dim) * - emb)
self.emb = Tensor(emb, ms.float32)

def construct(self, x):
    emb = x[:, None] * self.emb[None, :]
    emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
    return emb

class Block(nn.Cell):
def init(self, dim, dim_out, groups=1):
super().init()
self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode=“pad”, padding=1)
self.proj = c(dim, dim_out, 3, padding=1, pad_mode=‘pad’)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()

def construct(self, x, scale_shift=None):
    x = self.proj(x)
    x = self.norm(x)

    if exists(scale_shift):
        scale, shift = scale_shift
        x = x * (scale + 1) + shift

    x = self.act(x)
    return x

class ConvNextBlock(nn.Cell):
def init(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
super().init()
self.mlp = (
nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
if exists(time_emb_dim)
else None
)

    self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
    self.net = nn.SequentialCell(
        nn.GroupNorm(1, dim) if norm else nn.Identity(),
        nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
        nn.GELU(),
        nn.GroupNorm(1, dim_out * mult),
        nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
    )

    self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

def construct(self, x, time_emb=None):
    h = self.ds_conv(x)
    if exists(self.mlp) and exists(time_emb):
        assert exists(time_emb), "time embedding must be passed in"
        condition = self.mlp(time_emb)
        condition = condition.expand_dims(-1).expand_dims(-1)
        h = h + condition

    h = self.net(h)
    return h + self.res_conv(x)

这哦深奥了。但是就是构建unet
unet是图像的编解码器,可以捕捉细节

class Attention(nn.Cell):
def init(self, dim, heads=4, dim_head=32):
super().init()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads

    self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
    self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
    self.map = ops.Map()
    self.partial = ops.Partial()

def construct(self, x):
    b, _, h, w = x.shape
    qkv = self.to_qkv(x).chunk(3, 1)
    q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

    q = q * self.scale

    # 'b h d i, b h d j -> b h i j'
    sim = ops.bmm(q.swapaxes(2, 3), k)
    attn = ops.softmax(sim, axis=-1)
    # 'b h i j, b h d j -> b h i d'
    out = ops.bmm(attn, v.swapaxes(2, 3))
    out = out.swapaxes(-1, -2).reshape((b, -1, h, w))

    return self.to_out(out)

class LayerNorm(nn.Cell):
def init(self, dim):
super().init()
self.g = Parameter(initializer(‘ones’, (1, dim, 1, 1)), name=‘g’)

def construct(self, x):
    eps = 1e-5
    var = x.var(1, keepdims=True)
    mean = x.mean(1, keep_dims=True)
    return (x - mean) * rsqrt((var + eps)) * self.g

class LinearAttention(nn.Cell):
def init(self, dim, heads=4, dim_head=32):
super().init()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode=‘valid’, has_bias=False)

    self.to_out = nn.SequentialCell(
        nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
        LayerNorm(dim)
    )

    self.map = ops.Map()
    self.partial = ops.Partial()

def construct(self, x):
    b, _, h, w = x.shape
    qkv = self.to_qkv(x).chunk(3, 1)
    q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

    q = ops.softmax(q, -2)
    k = ops.softmax(k, -1)

    q = q * self.scale
    v = v / (h * w)

    # 'b h d n, b h e n -> b h d e'
    context = ops.bmm(k, v.swapaxes(2, 3))
    # 'b h d e, b h d n -> b h e n'
    out = ops.bmm(context.swapaxes(2, 3), q)

    out = out.reshape((b, -1, h, w))
    return self.to_out(out)
    这是注意力模块,也就是网络的权重吧
    class PreNorm(nn.Cell):
def __init__(self, dim, fn):
    super().__init__()
    self.fn = fn
    self.norm = nn.GroupNorm(1, dim)

def construct(self, x):
    x = self.norm(x)
    return self.fn(x)
    把U-Net的卷积/注意层与群归一化

class Unet(nn.Cell):
def init(
self,
dim,
init_dim=None,
out_dim=None,
dim_mults=(1, 2, 4, 8),
channels=3,
with_time_emb=True,
convnext_mult=2,
):
super().init()

    self.channels = channels

    init_dim = default(init_dim, dim // 3 * 2)
    self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

    dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
    in_out = list(zip(dims[:-1], dims[1:]))

    block_klass = partial(ConvNextBlock, mult=convnext_mult)

    if with_time_emb:
        time_dim = dim * 4
        self.time_mlp = nn.SequentialCell(
            SinusoidalPositionEmbeddings(dim),
            nn.Dense(dim, time_dim),
            nn.GELU(),
            nn.Dense(time_dim, time_dim),
        )
    else:
        time_dim = None
        self.time_mlp = None

    self.downs = nn.CellList([])
    self.ups = nn.CellList([])
    num_resolutions = len(in_out)

    for ind, (dim_in, dim_out) in enumerate(in_out):
        is_last = ind >= (num_resolutions - 1)

        self.downs.append(
            nn.CellList(
                [
                    block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                    block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    Downsample(dim_out) if not is_last else nn.Identity(),
                ]
            )
        )

    mid_dim = dims[-1]
    self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
    self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
    self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

    for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
        is_last = ind >= (num_resolutions - 1)

        self.ups.append(
            nn.CellList(
                [
                    block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                    block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                    Upsample(dim_in) if not is_last else nn.Identity(),
                ]
            )
        )

    out_dim = default(out_dim, channels)
    self.final_conv = nn.SequentialCell(
        block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
    )

def construct(self, x, time):
    x = self.init_conv(x)

    t = self.time_mlp(time) if exists(self.time_mlp) else None

    h = []

    for block1, block2, attn, downsample in self.downs:
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)
        h.append(x)

        x = downsample(x)

    x = self.mid_block1(x, t)
    x = self.mid_attn(x)
    x = self.mid_block2(x, t)

    len_h = len(h) - 1
    for block1, block2, attn, upsample in self.ups:
        x = ops.concat((x, h[len_h]), 1)
        len_h -= 1
        x = block1(x, t)
        x = block2(x, t)
        x = attn(x)

        x = upsample(x)
    return self.final_conv(x)
    总是就是为了把各个零件合在一起称为大网络
    正向扩散
    def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)
正向传播就是加噪声。

扩散200步

timesteps = 200

定义 beta schedule

betas = linear_beta_schedule(timesteps=timesteps)

定义 alphas

alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)

sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))

计算 q(x_{t-1} | x_t, x_0)

posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)

def extract(a, t, x_shape):
b = t.shape[0]
out = Tensor(a).gather(t, -1)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))

下载猫猫图像

url = ‘https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip’
path = download(url, ‘./’, kind=“zip”, replace=True)

在这里插入图片描述
在这里插入图片描述
但是这里就只有1张照片啊。
为什么要做这么复杂的操作
在这里插入图片描述
通过上面的代码正向了,加了噪声了。
每一步加的噪声后是不一样的
在这里插入图片描述
在这里插入图片描述

去噪声的就是unet,他学习了如何区分噪声和真实的有意义的图。
训练开始

下载MNIST数据集

url = ‘https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip’
path = download(url, ‘./’, kind=“zip”, replace=True)
这些是一些小的衣服图
在这里插入图片描述

transforms = [
RandomHorizontalFlip(),
ToTensor(),
lambda t: (t * 2) - 1
]

dataset = dataset.project(‘image’)
dataset = dataset.shuffle(64)
dataset = dataset.map(transforms, ‘image’)
dataset = dataset.batch(16, drop_remainder=True)

训练这个unet

定义动态学习率

lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)

定义 Unet模型

unet_model = Unet(
dim=image_size,
channels=channels,
dim_mults=(1, 2, 4,)
)

name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
item.name = name_list[i]
i += 1

定义优化器

optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
loss_scaler = DynamicLossScaler(65536, 2, 1000)
这个是调整loss,不让他太大或者太小。

定义前向过程

def forward_fn(data, t, noise=None):
loss = p_losses(unet_model, data, t, noise)
return loss

计算梯度

grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

梯度更新

def train_step(data, t, noise):
loss, grads = grad_fn(data, t, noise)
optimizer(grads)
return loss
在这里插入图片描述
这和以往的训练类似
那么正向传播的就是模糊的图
在这里插入图片描述
由于像素太低,你得用代码变小,才看得出来这是个衣服
在这里插入图片描述
请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1942374.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据平台之HBase

HBase是一个高可靠性、高性能、面向列、可伸缩的分布式存储系统,是Apache Hadoop生态系统的重要组成部分。它特别适合大规模结构化和半结构化数据的存储和检索,能够处理实时读写和批处理工作负载。以下是对HBase的详细介绍。 1. 核心概念 1.1 表&#x…

TIA博途V19无法勾选来自远程对象的PUT/GET访问的解决办法

TIA博途V19无法勾选来自远程对象的PUT/GET访问的解决办法 TIA博途升级到V19之后,1500CPU也升级到了V3.1的固件,1200CPU升级到了V4.6.1的固件, 固件升级之后,又出现了很多问题,如下图所示,在组态的时候会多出一些东西, 添加CPU之后,在属性界面可以看到“允许来自远程对象…

第二讲:NJ网络配置

Ethernet/IP网络拓扑结构 一. NJ EtherNet/IP 1、网络端口位置 NJ的CPU上面有两个RJ45的网络接口,其中一个是EtherNet/IP网络端口(另一个是EtherCAT的网络端口) 2、网络作用 如图所示,EtherNet/IP网络既可以做控制器与控制器之间的通信,也可以实现与上位机系统的对接通…

python爬虫基础——Webbot库介绍

本文档面向对自动化网页交互、数据抓取和网络自动化任务感兴趣的Python开发者。无论你是初学者还是有经验的开发者,Webbot库都能为你的自动化项目提供强大的支持。 Webbot库概述 Webbot是一个专为Python设计的库,用于简化网页自动化任务。它基于Seleniu…

高速ADC模拟输入接口设计

目录 基本输入接口考虑 输入阻抗 输入驱动 带宽和通带平坦度 噪声 失真 变压器耦合前端 有源耦合前端网络 基本输入接口考虑 采用高输入频率、高速模数转换器(ADC)的系统设计是一 项具挑战性的任务。ADC输入接口设计有6个主要条件: 输入阻抗、输入驱动、带宽…

【RaspberryPi】树莓派系统UI优化

接上文,如何去定制一个树莓派的桌面系统,还是以CM4为例。 解除CM4上电USB无法使用问题 将烧录好的tf卡通过读卡器插入到电脑上,进入boot磁盘,里面有一个Config文件,双击用记事本打开,在【pi4】一栏里加入一…

农业农村大数据底座:实现智慧农业的关键功能

随着信息技术的快速发展,农业领域也在逐步实现数字化转型。农业农村大数据底座作为支持智慧农业发展的重要基础设施,承载了多种关键功能,为农业生产、管理和决策提供了前所未有的支持和可能性。 ### 1. 数据采集与监测 农业农村大数据底座首…

【k8s故障处理篇】calico-kube-controllers状态为“ImagePullBackOff”解决办法

【k8s故障处理篇】calico-kube-controllers状态为“ImagePullBackOff”解决办法 一、环境介绍1.1 本次环境规划1.2 kubernetes简介1.3 kubernetes特点二、本次实践介绍2.1 本次实践介绍2.2 报错场景三、查看报错日志3.1 查看pod描述信息3.2 查看pod日志四、报错分析五、故障处理…

【Docker】Docker Desktop - WSL update failed

问题描述 Windows上安装完成docker desktop之后,第一次启动失败,提示:WSL update failed 解决方案 打开Windows PowerShell 手动执行: wsl --set-default-version 2 wsl --update

使用C#手搓Word插件

WordTools主要功能介绍 编码语言:C#【VSTO】 1、选择 1.1、表格 作用:全选文档中的表格; 1.2、表头 作用:全选文档所有表格的表头【第一行】; 1.3、表正文 全选文档中所有表格的除表头部分【除第一行部分】 1.…

便携式自动气象站:科技赋能气象观测

便携式自动气象站,顾名思义,就是一款集成了多种气象传感器,能够自动进行气象观测和数据记录的设备。它体积小巧、重量轻,便于携带和快速部署,可以在各种环境下进行气象数据的实时监测。同时,通过内置的无线…

Flex布局中元素主轴上平均分布 多余的向左对齐

content:父元素 content-item: 子元素 主轴上子元素平均分布 .content {display: flex;flex-wrap: wrap;justify-content: space-between;.service-item {display: flex;flex-direction: column;justify-content: center;align-items: center;width: 80px;height:…

万字长文之分库分表里无分库分表键如何查询【后端面试题 | 中间件 | 数据库 | MySQL | 分库分表 | 其他查询】

在很多业务里,分库分表键都是根据主要查询筛选出来的,那么不怎么重要的查询怎么解决呢? 比如电商场景下,订单都是按照买家ID来分库分表的,那么商家该怎么查找订单呢?或是买家找客服,客服要找到对…

ubuntu一些好用的开发工具及其配置

1 终端模糊搜索fzf https://github.com/junegunn/fzf 输入某命令,比如 conda ,按下ctrlR,会显示和该命令匹配的历史命令的列表 有了这个工具再也不用记忆太复杂的命令,只需要知道大概几个单词,输入即可搜索。 其搜索…

SSD基本架构与工作原理

SSD的核心由一个或多核心的CPU控制器、DRAM缓存以及多个NAND闪存芯片组成。CPU控制器负责管理所有读写操作,并通过DRAM缓存存储映射表等元数据,以加速寻址过程。 NAND闪存则是数据存储的实际介质,其组织结构从大到小依次为通道(包…

C++实现LRU缓存(新手入门详解)

LRU的概念 LRU(Least Recently Used,最近最少使用)是一种常用的缓存淘汰策略,主要目的是在缓存空间有限的情况下,优先淘汰那些最长时间没有被访问的数据项。LRU 策略的核心思想是: 缓存空间有限&#xff1…

航片转GIS数据自动化管线

近年来,计算机视觉领域的进步已显著改善了物体检测和分割任务。一种流行的方法是 YOLO(You Only Look Once)系列模型。YOLOv8 是 YOLO 架构的演进,兼具准确性和效率,是各种应用的绝佳选择,包括分割卫星航拍…

借助Python将txt文本内容导入到数据库

安装数据库并创建admin账号 #Create mariadb user CREATE USER admin% IDENTIFIED BY password; GRANT SELECT, INSERT, UPDATE, DELETE ON hosts_info.* TO admin%; FLUSH PRIVILEGES;创建库并创建数据表 #创建库 CREATE DATABASE hosts_info; #创建表 CREATE TABLE host_tm…

shell条件语句

一,条件测试 1 . test命令 测试表达式是否成立,若成立返回0,否则返回其他数值 1.1 格式 test 条件表达式 [ 条件表达式 ] 2 . 文件测试 2.1 格式 [ 操作符 文件或目录 ] 例 test -d /home/user 2.2 常用的测试操作符 -d:测试是否为目录(Directory)-e:测试目…

安装Ubuntu24.04服务器版本

Ubuntu系统安装 一.启动安装程序二.执行 Ubuntu Server 安装向导1.选择安装程序语言,通常选择「English」2.设置键盘布局,默认「English US」即可3.选择安装方式 三.配置网络1.按Tab键选择网络接口(例如 ens160),然后按…