昇思 25 天学习打卡营第 24 天 | MindSpore Pix2Pix 实现图像转换

news2024/9/16 11:18:00

1. 背景:

使用 MindSpore 学习神经网络,打卡第 24 天;主要内容也依据 mindspore 的学习记录。

2. PixPix 介绍:

MindSpore 的 Pix2Pix 图像转换

  • 介绍
    Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Phillip Isola等作者在2017年CVPR上提出的,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器和判别器。

  • 论文
    Conditional Generative Adversarial Networks 论文地址
    Image-to-Image Translation with Conditional Adversarial Networks

  • 基本原理:
    cGAN 与 GAN 区别:
    a. 输入不同:
    cGAN 生成器输入图片,作为指导信息,生成假图像。本质是:输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射。
    GAN 输入是一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成。

MindSpore 的 docs 中有详细的说明;
https://gitee.com/mindspore/docs/blob/r2.3.0rc2/tutorials/application/source_zh_cn/generative/pix2pix.ipynb

3. 具体实现:

mindspore 使用 pix2pix 数据;

3.1 数据下载:

# 数据下载
ffrom download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"

download(url, "./dataset", kind="tar", replace=True)

3.2 构造网络

生成器G用到的是U-Net结构,输入的轮廓图 𝑥 编码再解码成真是图片,判别器D用到的是作者自己提出来的条件判别器PatchGAN,判别器D的作用是在轮廓图 𝑥 的条件下,对于生成的图片 𝐺(𝑥) 判断为假,对于真实判断为真。

  • 生成器的 G 结构如下所示:
    使用的是 U-Net 结构
    在这里插入图片描述
  • 定义 UNet Skip Connection Block
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out
  • 基于 UNet 的生成器
class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)

Pix2Pix在训练和测试时都使用了dropout,这样可以生成多样性的结果。

  • 构建判别器,基于 PatchGAN:
    判别器使用的PatchGAN结构,可看做卷积。生成的矩阵中的每个点代表原图的一小块区域(patch)。通过矩阵中的各个值来判断原图中对应每个Patch的真假。
import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

3.3 生成器与判别器的初始化

实例化Pix2Pix生成器和判别器

import mindspore.nn as nn
from mindspore.common import initializer as init

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix模型网络"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb

3.4 训练

训练判别器和训练生成器。训练判别器的目的是最大程度地提高判别图像真伪的概率。训练生成器是希望能产生更好的虚假图像。在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计。

  • 损失函数:
    定义了 Generator 和 Discriminator 后,损失函数使用MindSpore中二进制交叉熵损失函数BCELoss ;这里生成器和判别器都是使用Adam优化器。
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forword_dis(reala, realb):
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forword_gan(reala, realb):
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        delta = (end_time - start_time).microseconds
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

3.5 推理

获取上述训练过程完成后的ckpt文件,通过load_checkpoint和load_param_into_net将ckpt中的权重参数导入到模型中,获取数据进行推理并对推理的效果图进行演示(由于时间问题,训练过程只进行了3个epoch,可根据需求调整epoch)

from mindspore import load_checkpoint, load_param_into_net

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

4. 相关链接:

  • https://xihe.mindspore.cn/events/mindspore-training-camp
  • [1] Phillip Isola,Jun-Yan Zhu,Tinghui Zhou,Alexei A. Efros. Image-to-Image Translation with Conditional Adversarial Networks.[J]. CoRR,2016,abs/1611.07004. https://arxiv.org/pdf/1611.07004

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961569.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Oracle如何跨越incarnation进行数据恢复

作者介绍:老苏,10余年DBA工作运维经验,擅长Oracle、MySQL、PG、Mongodb数据库运维(如安装迁移,性能优化、故障应急处理等) 公众号:老苏畅谈运维 欢迎关注本人公众号,更多精彩与您分享…

Skywalking 入门与实战

一 什么是 Skywalking? Skywalking 时一个开源的分布式追踪系统,用于检测、诊断和优化分布式系统的功能。它可以帮助开发者和运维人员深入了解分布式系统中各个组件之间的调用关系、性能瓶颈以及异常情况,从而提供系统级的性能优化和故障排查。 1.1 为…

笑谈“八股文”,人生不成文

一、“八股文”在实际工作中是助力、阻力还是空谈? 作为现在各类大中小企业面试程序员时的必问内容,“八股文”似乎是很重要的存在。但“八股文”是否能在实际工作中发挥它“敲门砖”应有的作用呢?有IT人士不禁发出疑问:程序员面试…

AcWing3302. 表达式求值

代码解释 while(j<str.size()&&isdigit(str[j])){xx*10str[j]-0;}把字符串中里面连续的数字转化为int类型变量&#xff0c;比如输入996/3328,正常的挨个字符扫描只能扫到’9’,‘9’,‘6’,但是按照上面代码的算法是重新开了一个循环&#xff0c;直接把’9’,‘9’,…

【网络请求调试神器,curl -vvv 返回都有什么】

curl -vvv 是一个用于在命令行中执行 HTTP 请求的命令&#xff0c;其中 -vvv 是一个选项&#xff0c;用于启用详细的调试输出。 vvv: 这是一个选项&#xff0c;表示启用详细的调试输出。每个 v 增加调试信息的详细程度&#xff0c;vvv 是最高级别的详细输出。 详细输出包括&a…

【shell脚本快速一键部署项目】

目录 一、环境拓扑图二、主机环境描述三、注意四、需求描述五、shell代码的编写六、总结 一、环境拓扑图 二、主机环境描述 主机名主机地址需要提供的服务content.exam.com172.25.250.101提供基于 httpd/nginx 的 YUM仓库服务ntp.exam.com172.25.250.102提供基于Chronyd 的 NT…

GPU池化:点燃Jupyter Notebook中的AI算力之火

数据科学的火花在Jupyter Notebook中点燃&#xff0c;而GPU的加入&#xff0c;让这火焰更加炽热&#xff01;随着人工智能领域的飞速发展&#xff0c;利用GPU加速已成为数据科学和机器学习领域的新常态。 今天&#xff0c;我们要探索的&#xff0c;是Jupyter Notebook与GPU池化…

PHP学习:PHP基础

以.php作为后缀结尾的文件&#xff0c;由服务器解析和运行的语言。 一、语法 PHP 脚本可以放在文档中的任何位置。 PHP 脚本以 <?php 开始&#xff0c;以 ?> 结束。 <!DOCTYPE html> <html> <body><h1>My first PHP page</h1><?php …

spaCy语言模型下载

spaCy 是一个基于 Python 编写的开源自然语言处理&#xff08;NLP&#xff09;库&#xff0c;它提供了一系列的工具和功能&#xff0c;用于文本预处理、文本解析、命名实体识别、词性标注、句法分析和文本分类等任务。 spaCy支持多种语言模型对文本进行处理&#xff0c;包括中文…

自己在Vmware中搭建mqtt服务器

前言 在学习某个HMI的使用的时候&#xff0c;这个HMI带有MQTT功能&#xff0c;就想着自己是不是能够搭建一个自己的MQTT的服务器呢&#xff1f; 一、mqtt 自己搭建之一&#xff1a;Mosquitto 自己搭建MQTT服务器需要安装和运行MQTT服务软件&#xff0c;比如常用的是Mosquitto…

Tkinter简介与实战(1)

Tkinter简介与实战---实现一个计算器 Tkinter简介安装环境和安装命令WindowsmacOSLinux 注意事项使用正确的包管理器&#xff1a;检查安装完整性&#xff1a;更新 Python&#xff1a;使用虚拟环境&#xff1a; 一个实战例子-----计算器1.创建窗口&#xff1a;2.创建 GUI 组件&a…

学习大数据DAY27 Linux最终阶段测试

满分&#xff1a;100 得分&#xff1a;72 目录 一选择题&#xff08;每题 3 分&#xff0c;共计 30 分&#xff09; 二、编程题&#xff08;共 70…

ANSYS仿真DDR4的眼图

1 眼图的基本知识 对于数字信号&#xff0c;高低电平转换可以组合在多个序列中。以3位为例&#xff0c;总共有000-111和8种组合。在时域中&#xff0c;根据某个参考点对足够多的序列进行对齐&#xff0c;然后将波形叠加形成眼图&#xff0c;如下图所示。 图&#xff1a;眼图中…

JavaScript object find 示例

https://andi.cn/page/621631.html

从信息论的角度看微博推荐算法

引言 在数字时代&#xff0c;推荐系统已成为社交媒体和其他在线服务平台的核心组成部分。它们通过分析用户行为和偏好&#xff0c;为用户提供个性化的内容&#xff0c;从而提高用户满意度和平台的参与度。推荐系统不仅能够增强用户体验&#xff0c;还能显著提升广告投放的效率…

angular入门基础教程(一)环境配置与新建项目

ng已经更新到v18了&#xff0c;我对他的印象还停留在v1,v2的版本&#xff0c;最近研究了下&#xff0c;与react和vue是越来越像了&#xff0c;所以准备正式上手了。 新官网地址:https://angular.cn/ 准备条件 nodejs > 18.0vscodeng版本18.x(最新的版本) {"name&qu…

C# Unity 面向对象补全计划 之 继承(字段与属性)

本文仅作学习笔记与交流&#xff0c;不作任何商业用途&#xff0c;作者能力有限&#xff0c;如有不足还请斧正 本系列旨在通过补全学习之后&#xff0c;给出任意类图都能实现并做到逻辑上严丝合缝 Q&#xff1a;为什么要单讲继承字段与属性&#xff0c;不讲继承方法了吗&#x…

【SuperMap GIS 信创部署系列】-- 金蝶V10中间件

⼀、安装包获取 本⽂以10.2.1版本安装为例&#xff0c;官⽹下载iserver war包即可。 下载地址&#xff1a;http://support.supermap.com.cn/DownloadCenter/DownloadPage.aspx?id1852 ⼆、部署 iServer.war 1.解压安装包 将下载的supermap-iserver-10.2.1-war.zip包进⾏解…

基于单片机的步进电机系统设计方法探究

摘 要&#xff1a; 单片机是步进电机系统的重要组成部分&#xff0c;对于步进电机系统的驱动控制具有重要的影响。通过分析步进电机系统的工作原理&#xff0c;对步进电机系统进行规划设置。达到降低步进电机的使用成本&#xff0c;提高步进电机的效率与性能的效果。文章主要探…

WIFI 接收机和发射机同步问题+CFO/SFO频率偏移问题

Synchronization Between Sender and Receiver & CFO Correction 解决同步问题和频率偏移问题是下面论文的关键&#xff0c;接下来结合论文进行详细解读 解读论文&#xff1a;Verification and Redesign of OFDM Backscatter 论文pdf&#xff1a;https://www.usenix.org/s…