昇思25天学习打卡营第23天 | Pix2Pix实现图像转换

news2024/9/8 23:31:17

内容介绍:

Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks )实现的一种深度学习图像转换模型,该模型是由Phillip Isola等作者在2017年CVPR上提出的,可以实现语义/标签到真实图片、灰度图到彩色图、航空图到地图、白天到黑夜、线稿图到实物图的转换。Pix2Pix是将cGAN应用于有监督的图像到图像翻译的经典之作,其包括两个模型:生成器判别器

传统上,尽管此类任务的目标都是相同的从像素预测像素,但每项都是用单独的专用机器来处理的。而Pix2Pix使用的网络作为一个通用框架,使用相同的架构和目标,只在不同的数据上进行训练,即可得到令人满意的结果,鉴于此许多人已经使用此网络发布了他们自己的艺术作品。

cGAN的生成器与传统GAN的生成器在原理上有一些区别,cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成,这是cGAN和GAN的在图像翻译任务中的差异。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。

具体内容:

1. 导包:

from download import download
from mindspore import dataset as ds
import matplotlib.pyplot as plt
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.nn as nn
import mindspore.nn as nn
from mindspore.common import initializer as init
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor
from mindspore import load_checkpoint, load_param_into_net

2. 下载数据集

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"

download(url, "./dataset", kind="tar", replace=True)

3. 数据显示

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

4. 网络构建

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out

5. UNet生成器

class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)

6. PatchGAN判别器

import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

7. Pix2Pix生成器和判别器初始化

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix模型网络"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb

8. 训练

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forword_dis(reala, realb):
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forword_gan(reala, realb):
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        delta = (end_time - start_time).microseconds
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

9. 推理

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

Pix2Pix的学习过程让我深刻体会到了深度学习在图像处理领域的强大能力。通过训练一对相互竞争的神经网络——生成器与判别器,Pix2Pix能够学习到输入图像与输出图像之间复杂的映射关系。这种端到端的学习方式,无需人工设计复杂的特征提取与转换规则,极大地简化了图像转换的流程,同时也提高了转换结果的质量和多样性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1916910.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

二分法求函数的零点 信友队

题目ID&#xff1a;15713 必做题 100分 时间限制: 1000ms 空间限制: 65536kB 题目描述 有函数&#xff1a;f(x) 已知f(1.5) > 0&#xff0c;f(2.4) < 0 且方程 f(x) 0 在区间 [1.5,2.4] 有且只有一个根&#xff0c;请用二分法求出该根。 输入格式 &#xff08;无…

reduce规约:深入理解java8中的规约reduce

&#x1f370; 个人主页:_小白不加班__ &#x1f35e;文章有不合理的地方请各位大佬指正。 &#x1f349;文章不定期持续更新&#xff0c;如果我的文章对你有帮助➡️ 关注&#x1f64f;&#x1f3fb; 点赞&#x1f44d; 收藏⭐️ 文章目录 常见场景图示reduce中的BiFunction和…

【linux】阿里云centos配置邮件服务

目录 1.安装mailx服务 2./etc/mail.rc 配置增加 3.QQ邮箱开启smtp服务&#xff0c;获取授权码 4.端口设置&#xff1a;Linux 防火墙开放端口-CSDN博客 5.测试 1.安装mailx服务 yum -y install mailx 2./etc/mail.rc 配置增加 #邮件发送人 set from924066173qq.com #阿里…

完美解决AttributeError: ‘list‘ object has no attribute ‘shape‘的正确解决方法,亲测有效!!!

完美解决AttributeError: ‘list‘ object has no attribute ‘shape‘的正确解决方法&#xff0c;亲测有效&#xff01;&#xff01;&#xff01; 亲测有效 完美解决AttributeError: ‘list‘ object has no attribute ‘shape‘的正确解决方法&#xff0c;亲测有效&#xff0…

Java对象引用的访问方式是什么?

哈喽&#xff0c;大家好&#x1f389;&#xff0c;我是世杰。 本文我为大家介绍面试官经常考察的**「Java对象引用相关内容」** 照例在开头留一些面试考察内容~~ 面试连环call Java对象引用都有哪些类型?Java参数传递是值传递还是引用传递? 为什么?Java对象引用访问方式有…

解释 C 语言中的递归函数

&#x1f345;关注博主&#x1f397;️ 带你畅游技术世界&#xff0c;不错过每一次成长机会&#xff01; &#x1f4d9;C 语言百万年薪修炼课程 通俗易懂&#xff0c;深入浅出&#xff0c;匠心打磨&#xff0c;死磕细节&#xff0c;6年迭代&#xff0c;看过的人都说好。 文章目…

各向异性含水层中地下水三维流基本微分方程的推导

各向异性含水层中地下水三维流基本微分方程的推导 参考文献&#xff1a; [1] 刘欣怡,付小莉.论连续性方程的推导及几种形式转换的方法[J].力学与实践,2023,45(02):469-474. 文章链接 水均衡的基本思想&#xff1a; ∑ 流 入 − ∑ 流 出 Δ V \sum 流入-\sum 流出\Delta V ∑…

【系统架构设计师】九、软件工程(软件测试)

目录 八、软件测试 8.1 测试分类 8.2 静态方法 8.2.1 静态测试 8.2.2 动态测试 8.2.3 自动化测试 8.3 测试阶段 8.3.1 单元测试 8.3.2 集成测试 8.3.3 确认测试 8.3.4 系统测试 8.3.5 性能测试 8.3.6 验收测试 8.3.7 其他测试 8.4 测试用例设计 8.4.1 黑…

使用 Python 绘制美国选举分级统计图

「AI秘籍」系列课程&#xff1a; 人工智能应用数学基础 人工智能Python基础 人工智能基础核心知识 人工智能BI核心知识 人工智能CV核心知识 如何创建美国选举结果的时间序列分级统计图 数据地址为源地址&#xff0c;如果失效请与我联系。 2024 年美国大选将至&#xff0c;…

算法通关:004_1选择排序

代码一定要自己手敲理解 public class _004 {//选择排序&#xff0c;冒泡排序&#xff0c;插入排序//交换public static void swap(int[] arr,int i ,int j){int temp arr[i];arr[i] arr[j];arr[j] temp;}//选择排序public static void selectSort(int[] arr){if(arr null…

C++ | Leetcode C++题解之第225题用队列实现栈

题目&#xff1a; 题解&#xff1a; class MyStack { public:queue<int> q;/** Initialize your data structure here. */MyStack() {}/** Push element x onto stack. */void push(int x) {int n q.size();q.push(x);for (int i 0; i < n; i) {q.push(q.front());…

LabVIEW实现LED显示屏视觉检测

为了满足LED显示屏在生产过程中的严格质量检测需求&#xff0c;引入自动化检测系统是十分必要的。传统人工检测方式存在检测强度高、效率低、准确性差等问题&#xff0c;自动化检测系统则能显著提高检测效率和准确性。视觉检测系统的构建主要包含硬件和软件两个部分。 视觉系统…

JDK中不能继承的类:final类的作用与意义

JDK中不能继承的类&#xff1a;final类的作用与意义 1、 为什么要用final类&#xff1f;2、JDK中有哪些final类&#xff1f;3、总结 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 在Java编程中&#xff0c;有些类被标记为final&#xff0c…

前端面试题47(在动态控制路由时,如何防止未授权用户访问受保护的页面?)

在Vue中&#xff0c;防止未授权用户访问受保护页面通常涉及到使用路由守卫&#xff08;Route Guards&#xff09;。路由守卫允许你在路由发生改变前或后执行一些逻辑&#xff0c;比如检查用户是否已登录或者有访问某个页面的权限。下面是一些常见的路由守卫类型及其使用方式&am…

MapReduce底层原理详解:大案例解析(第32天)

系列文章目录 一、MapReduce概述 二、MapReduce工作机制 三、Map&#xff0c;Shuffle&#xff0c;reduce阶段详解 四、大案例解析 文章目录 系列文章目录前言一、MapReduce概述二、MapReduce工作机制1. 角色与组件2. 作业提交与执行流程1. 作业提交&#xff1a;2. Map阶段&…

IntelliJ IDEA社区版在Windows电脑中的下载、安装方法

本文介绍IntelliJ IDEA软件Community&#xff08;社区版&#xff09;在Windows操作系统中的下载、安装、运行与使用方法。 IntelliJ IDEA软件是一款由JetBrains公司开发的集成开发环境&#xff08;IDE&#xff09;&#xff0c;主要用于Java语言的开发&#xff0c;但同时也支持其…

面试经验之谈

优质博文&#xff1a;IT-BLOG-CN ​通常面试官会把每一轮面试分为三个环节&#xff1a;① 行为面试 ② 技术面试 ③ 应聘者提问 行为面试环节 面试开始的5~10分钟通常是行为面试的时间&#xff0c;面试官会参照简历和你的自我介绍了解应聘者的过往经验和项目经历。由于面试官…

读书笔记-Java并发编程的艺术-第4章(Java并发编程基础)-第4节(线程应用实例)

文章目录 4.4 线程应用实例4.4.1 等待超时模式4.4.2 一个简单的数据库连接池示例4.4.3 线程池技术及其示例4.4.4 一个基于线程池技术的简单 Web 服务器 4.4 线程应用实例 4.4.1 等待超时模式 开发人员经常会遇到这样的方法调用场景&#xff1a;调用一个方法时等待一段时间(一…

postgres 的dblink使用,远程连接数据库

一.安装下载 dblink create extension if not exists dblink 查看是否已经安装 select * from pg_extension;二.运行&#xff0c;查询数据 其中&#xff0c;第一个参数是dblink名字&#xff0c;也可以是连接字符串。 第二个参数是要执行的SQL查询语句。AS子句用于指定返回结…