CycleGAN图像风格迁移互换

news2024/11/25 22:53:52

tutorials/application/source_zh_cn/generative/cyclegan.ipynb · MindSpore/docs - Gitee.com

本案例运行需要较大内存,建议在Ascend/GPU上运行。

模型介绍

模型简介

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

CycleGAN

为了方便理解,这里以苹果和橘子为例介绍。上图中 𝑋 可以理解为苹果,𝑌 为橘子;𝐺 为将苹果生成橘子风格的生成器,𝐹 为将橘子生成的苹果风格的生成器,𝐷𝑋 和 𝐷𝑌 为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

Cycle Consistency Loss

图中苹果图片 𝑥 经过生成器 𝐺 得到伪橘子 𝑌̂ ,然后将伪橘子 𝑌̂  结果送进生成器 𝐹 又产生苹果风格的结果 𝑥̂ ,最后将生成的苹果风格结果 𝑥̂  与原苹果图片 𝑥 一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

数据集

本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,以省略大部分数据预处理的代码。

数据集下载

使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

%%capture captured_output
# 实验环境已经预装了mindspore==2.3.0,如需更换mindspore版本,可更改下面 MINDSPORE_VERSION 变量
!pip uninstall mindspore -y
%env MINDSPORE_VERSION=2.3.0
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.mirrors.ustc.edu.cn/simple
# 查看当前 mindspore 版本
!pip show mindspore
Name: mindspore
Version: 2.3.0
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"

download(url, ".", kind="zip", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip (466.8 MB)

file_sizes: 100%|█████████████████████████████| 489M/489M [00:02<00:00, 189MB/s]
Extracting zip file...
Successfully downloaded / unzipped to .

数据集加载

使用 MindSpore 的 MindDataset 接口读取和解析数据集。

from mindspore.dataset import MindDataset

# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
data = MindDataset(dataset_files=name_mr)
print("Datasize: ", data.get_dataset_size())

batch_size = 1
dataset = data.batch(batch_size)
datasize = dataset.get_dataset_size()
Datasize:  1019

可视化

通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

import numpy as np
import matplotlib.pyplot as plt

mean = 0.5 * 255
std = 0.5 * 255

plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):
    if i < 5:
        show_images_a = data["image_A"].asnumpy()
        show_images_b = data["image_B"].asnumpy()

        plt.subplot(2, 5, i+1)
        show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_a)
        plt.axis("off")

        plt.subplot(2, 5, i+6)
        show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_b)
        plt.axis("off")
    else:
        break
plt.show()

构建生成器

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

CycleGAN Generator

具体的模型结构请参照下文代码:

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal

weight_init = Normal(sigma=0.02)

class ConvNormReLU(nn.Cell):
    def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',
                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):
        super(ConvNormReLU, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, padding=padding, weight_init=weight_init)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, weight_init=weight_init)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class ResidualBlock(nn.Cell):
    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)

    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        return x + out


class ResNetGenerator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,
                 pad_mode="CONSTANT"):
        super(ResNetGenerator, self).__init__()
        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)
        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)
        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
        self.residuals = nn.SequentialCell(layers)
        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)
        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)
        if pad_mode == "CONSTANT":
            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',
                                      padding=3, weight_init=weight_init)
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)
            self.conv_out = nn.SequentialCell([pad, conv])

    def construct(self, x):
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.residuals(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        return ops.tanh(output)

# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')

net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

构建判别器

判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

# 定义判别器
class Discriminator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),
                  nn.LeakyReLU(alpha)]
        nf_mult = output_channel
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * output_channel
            layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * output_channel
        layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

# 判别器初始化
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')

net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 𝐺 及其判别器 𝐷𝑌 ,目标损失函数定义为:

L_{GAN}(G,D_{Y},X,Y)=E_{y-P_{data}(y)}[logD_{y}(y)]+E_{x-P_{data}(x)}[log(1-D_{y}(G(x)))]

其中 𝐺 试图生成看起来与 𝑌 中的图像相似的图像 𝐺(𝑥) ,而 𝐷𝑌 的目标是区分翻译样本 𝐺(𝑥) 和真实样本 𝑦 ,生成器的目标是最小化这个损失函数以此来对抗判别器。即 min_{G}max_{D_{Y}}L_{GAN}(G,D_{Y},X,Y)

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 𝑋X 的每个图像 𝑥x ,图像转换周期应能够将 𝑥x 带回原始图像,可以称之为正向循环一致性,即 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。对于 𝑌 ,类似的 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

L_{cyc}(G,F)=E_{x-p_{data}(x)}[\left | \left | F(G(x))-x \right | \right |_{1}]+E_{y-p_{data}(y)}[\left | \left | F(G(y))-y \right | \right |_{1}]

循环一致损失能够保证重建图像 𝐹(𝐺(𝑥))与输入图像 𝑥 紧密匹配。

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")

def gan_loss(predict, target):
    target = ops.ones_like(predict) * target
    loss = loss_fn(predict, target)
    return loss

前向计算

搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms

# 前向计算

def generator(img_a, img_b):
    fake_a = net_rg_b(img_b)
    fake_b = net_rg_a(img_a)
    rec_a = net_rg_b(fake_b)
    rec_b = net_rg_a(fake_a)
    identity_a = net_rg_b(img_a)
    identity_b = net_rg_a(img_b)
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

lambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5

def generator_forward(img_a, img_b):
    true = Tensor(True, dtype.bool_)
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

def generator_forward_grad(img_a, img_b):
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

def discriminator_forward(img_a, img_b, fake_a, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d

def discriminator_forward_a(img_a, fake_a):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    return loss_d_a

def discriminator_forward_b(img_b, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    return loss_d_b

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):
    num_imgs = 0
    image1 = []
    if isinstance(images, Tensor):
        images = images.asnumpy()
    return_images = []
    for image in images:
        if num_imgs < pool_size:
            num_imgs = num_imgs + 1
            image1.append(image)
            return_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, pool_size - 1)

                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)

            else:
                return_images.append(image)
    output = Tensor(return_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output

计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

from mindspore import value_and_grad

# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())

grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())

# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):
    net_d_a.set_grad(False)
    net_d_b.set_grad(False)

    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)

    _, grads_g_a = grad_g_a(img_a, img_b)
    _, grads_g_b = grad_g_b(img_a, img_b)
    optimizer_rg_a(grads_g_a)
    optimizer_rg_b(grads_g_b)

    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib

# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):
    net_d_a.set_grad(True)
    net_d_b.set_grad(True)

    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)
    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)

    loss_d = (loss_d_a + loss_d_b) * 0.5

    optimizer_d_a(grads_d_a)
    optimizer_d_b(grads_d_b)

    return loss_d

模型训练

训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2]Ey−pdata(y)[(D(y)−1)2] ;

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2]Ex−pdata(x)[(D(G(x)−1)2] 来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

%%time
import os
import time
import random
import numpy as np
from PIL import Image
from mindspore import Tensor, save_checkpoint
from mindspore import dtype

# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1
save_step_num = 80
save_checkpoint_epochs = 1
save_ckpt_dir = './train_ckpt_outputs/'

print('Start training!')

for epoch in range(epochs):
    g_loss = []
    d_loss = []
    start_time_e = time.time()
    for step, data in enumerate(dataset.create_dict_iterator()):
        start_time_s = time.time()
        img_a = data["image_A"]
        img_b = data["image_B"]
        res_g = train_step_g(img_a, img_b)
        fake_a = res_g[0]
        fake_b = res_g[1]

        res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))
        loss_d = float(res_d.asnumpy())
        step_time = time.time() - start_time_s

        res = []
        for item in res_g[2:]:
            res.append(float(item.asnumpy()))
        g_loss.append(res[0])
        d_loss.append(loss_d)

        if step % save_step_num == 0:
            print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
                  f"step:[{int(step):>4d}/{int(datasize):>4d}], "
                  f"time:{step_time:>3f}s,\n"
                  f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "
                  f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "
                  f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "
                  f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")

    epoch_cost = time.time() - start_time_e
    per_step_time = epoch_cost / datasize
    mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize

    print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
          f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "
          f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")

    if epoch % save_checkpoint_epochs == 0:
        os.makedirs(save_ckpt_dir, exist_ok=True)
        save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))
        save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))
        save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))
        save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))

print('End of training!')
Start training!
Epoch:[  1/  1], step:[   0/1019], time:49.117887s,
loss_g:21.40, loss_d:0.94, loss_g_a: 0.95, loss_g_b: 0.96, loss_c_a: 6.28, loss_c_b: 6.72, loss_idt_a: 3.13, loss_idt_b: 3.36
Epoch:[  1/  1], step:[  80/1019], time:0.679636s,
loss_g:11.84, loss_d:0.37, loss_g_a: 0.26, loss_g_b: 0.36, loss_c_a: 3.62, loss_c_b: 4.23, loss_idt_a: 1.52, loss_idt_b: 1.86
Epoch:[  1/  1], step:[ 160/1019], time:0.644315s,
loss_g:7.08, loss_d:0.45, loss_g_a: 0.44, loss_g_b: 0.49, loss_c_a: 2.18, loss_c_b: 2.06, loss_idt_a: 0.89, loss_idt_b: 1.03
Epoch:[  1/  1], step:[ 240/1019], time:0.675089s,
loss_g:7.64, loss_d:0.58, loss_g_a: 0.48, loss_g_b: 0.33, loss_c_a: 2.64, loss_c_b: 1.98, loss_idt_a: 1.24, loss_idt_b: 0.97
Epoch:[  1/  1], step:[ 320/1019], time:0.647689s,
loss_g:5.71, loss_d:0.30, loss_g_a: 0.29, loss_g_b: 0.39, loss_c_a: 1.00, loss_c_b: 2.57, loss_idt_a: 0.43, loss_idt_b: 1.03
Epoch:[  1/  1], step:[ 400/1019], time:0.667818s,
loss_g:6.55, loss_d:0.48, loss_g_a: 0.40, loss_g_b: 0.13, loss_c_a: 2.34, loss_c_b: 1.98, loss_idt_a: 1.04, loss_idt_b: 0.66
Epoch:[  1/  1], step:[ 480/1019], time:0.657683s,
loss_g:8.32, loss_d:0.34, loss_g_a: 0.29, loss_g_b: 0.43, loss_c_a: 4.49, loss_c_b: 1.49, loss_idt_a: 1.06, loss_idt_b: 0.55
Epoch:[  1/  1], step:[ 560/1019], time:0.681585s,
loss_g:8.00, loss_d:0.24, loss_g_a: 0.40, loss_g_b: 0.35, loss_c_a: 2.45, loss_c_b: 2.76, loss_idt_a: 0.93, loss_idt_b: 1.12
Epoch:[  1/  1], step:[ 640/1019], time:0.644428s,
loss_g:6.03, loss_d:0.44, loss_g_a: 0.22, loss_g_b: 0.69, loss_c_a: 2.50, loss_c_b: 0.89, loss_idt_a: 1.23, loss_idt_b: 0.50
Epoch:[  1/  1], step:[ 720/1019], time:0.702698s,
loss_g:3.67, loss_d:0.53, loss_g_a: 0.50, loss_g_b: 0.36, loss_c_a: 0.93, loss_c_b: 1.19, loss_idt_a: 0.35, loss_idt_b: 0.35
Epoch:[  1/  1], step:[ 800/1019], time:0.670582s,
loss_g:5.37, loss_d:0.35, loss_g_a: 0.46, loss_g_b: 0.30, loss_c_a: 2.08, loss_c_b: 1.19, loss_idt_a: 0.93, loss_idt_b: 0.41
Epoch:[  1/  1], step:[ 880/1019], time:0.670877s,
loss_g:4.33, loss_d:0.85, loss_g_a: 0.58, loss_g_b: 0.07, loss_c_a: 1.61, loss_c_b: 1.04, loss_idt_a: 0.59, loss_idt_b: 0.44
Epoch:[  1/  1], step:[ 960/1019], time:0.665170s,
loss_g:4.84, loss_d:0.50, loss_g_a: 0.14, loss_g_b: 0.56, loss_c_a: 1.89, loss_c_b: 0.91, loss_idt_a: 0.87, loss_idt_b: 0.49
Epoch:[  1/  1], epoch time:730.05s, per step time:0.72, mean_g_loss:6.86, mean_d_loss:0.45
End of training!
CPU times: user 43min 41s, sys: 11min 41s, total: 55min 22s

模型推理

下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。

import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net

# 加载权重文件
def load_ckpt(net, ckpt_dir):
    param_GA = load_checkpoint(ckpt_dir)
    load_param_into_net(net, param_GA)

g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'

load_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)

# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def eval_data(dir_path, net, a):

    def read_img():
        for dir in os.listdir(dir_path):
            path = os.path.join(dir_path, dir)
            img = Image.open(path).convert('RGB')
            yield img, dir

    dataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])
    trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]
    dataset = dataset.map(operations=trans, input_columns=["image"])
    dataset = dataset.batch(1)
    for i, data in enumerate(dataset.create_dict_iterator()):
        img = data["image"]
        fake = net(img)
        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))

        fig.add_subplot(2, 8, i+1+a)
        plt.axis("off")
        plt.imshow(img.asnumpy())

        fig.add_subplot(2, 8, i+9+a)
        plt.axis("off")
        plt.imshow(fake.asnumpy())

eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)
plt.show()

参考

[1] I. Goodfellow. NIPS 2016 tutorial: Generative ad-versarial networks. arXiv preprint arXiv:1701.00160,2016. 2, 4, 5

[2] A. Shrivastava, T. Pfister, O. Tuzel, J. Susskind, W. Wang, R. Webb. Learning from simulated and unsupervised images through adversarial training. In CVPR, 2017. 3, 5, 6, 7

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2187710.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java】—— 集合框架:Collection子接口:Set不同实现类的对比及使用(HashSet、LinkedHashSet、TreeSet)

目录 5. Collection子接口2&#xff1a;Set 5.1 Set接口概述 5.2 Set主要实现类&#xff1a;HashSet 5.2.1 HashSet概述 5.2.2 HashSet中添加元素的过程&#xff1a; 5.2.3 重写 hashCode() 方法的基本原则 5.2.4 重写equals()方法的基本原则 5.2.5 练习 5.3 Set实现类…

map部分重点

1.map的方括号 给key,返回value的引用&#xff0c;如果没有key&#xff0c;就插入一个key,无参构造的value的pair<> 适用&#xff1a;没有就插入&#xff0c;有就拿找到的值 insert返回pair<iterator,bool>,[]返回值 #include<iostream> #include<map&…

更美观的HTTP性能监测工具:httpstat

reorx/httpstat是一个旨在提供更美观和详细HTTP请求统计信息的cURL命令行工具&#xff0c;它能够帮助开发者和运维人员深入理解HTTP请求的性能和状态。 1. 基本概述 项目地址&#xff1a;https://github.com/reorx/httpstat语言&#xff1a;该工具主要是以Python编写&#xff…

偏标记学习+图像分类(论文复现)

偏标记学习图像分类&#xff08;论文复现&#xff09; 本文所涉及所有资源均在传知代码平台可获取 文章目录 偏标记学习图像分类&#xff08;论文复现&#xff09;概述算法原理核心逻辑效果演示使用方式 概述 本文复现论文提出的偏标记学习方法&#xff0c;随着深度神经网络的发…

异常场景分析

优质博文&#xff1a;IT-BLOG-CN 为了防止黑客从前台异常信息&#xff0c;对系统进行攻击。同时&#xff0c;为了提高用户体验&#xff0c;我们都会都抛出的异常进行拦截处理。 一、异常处理类 Java把异常当做是破坏正常流程的一个事件&#xff0c;当事件发生后&#xff0c;…

CMU 10423 Generative AI:lec16(Mixture of Experts 混合专家模型)

关于MoE推荐博客&#xff1a; https://huggingface.co/blog/zh/moehttps://www.paddlepaddle.org.cn/documentation/docs/zh/guides/06_distributed_training/moe_cn.html 1 概述 这个文档是关于Mixture of Experts (MoE) 的介绍和实现&#xff0c;主要内容如下&#xff1a;…

virtualbox配置为NAT模式后物理机和虚拟机互通

virtualbox配置为 NAT模式后&#xff0c;虚拟机分配到的 IP地址一般是 10.xx网段的&#xff0c;虚拟机可以通过网络地址转换访问物理机所在的网络&#xff0c;但若不做任何配置&#xff0c;则物理机无法直接访问虚拟机。 virtualbox在提供 NAT配置模式时&#xff0c;也提供了端…

深度学习:CycleGAN图像风格迁移转换

基础概念 CycleGAN是一种GAN的变体&#xff0c;它被设计用来在没有成对训练数据的情况下学习两种不同域之间的图像到图像的转换&#xff0c;不需要同一场景或物体在两个不同域中的对应图像。 CycleGAN由Jun-Yan Zhu等人在2017年提出。 CycleGAN的模型架构主要由两组生成器和…

mac配置python出现DataDirError: Valid PROJ data directory not found错误的解决

最近在利用python下载SWOT数据时出现以下的问题&#xff1a; import xarray as xr import s3fs import cartopy.crs as ccrs from matplotlib import pyplot as plt import earthaccess from earthaccess import Auth, DataCollections, DataGranules, Store import os os.env…

CSS3--美开二度

免责声明&#xff1a;本文仅做分享&#xff01; 目录 定位 相对定位 绝对定位 定位居中 固定定位 堆叠层级 z-index 定位-小结 CSS 精灵 京东案例 字体图标 下载字体 使用字体 上传矢量图 CSS 修饰属性 垂直对齐方式 vertical-align 过渡 transition 透明度 opa…

【西门子V20变频器】 变频器运行时报A922报警

报警说明 原因&#xff1a; 1.变频器未接负载 2.变频器设定的电机参数与实际电机不匹配 3.查看P2179查看 无负载监控 设定的电流极限值&#xff0c;出厂默认为“3.0”

mysql事务 -- 事务的隔离性(测试实验+介绍,脏读,不可重复读,可重复度读,幻读)

目录 事务的隔离性 引入 测试 读未提交 脏读 读提交 不可重复读 属于问题吗? 例子 可重复读 幻读 串行化 原理 总结 事务的隔离性 引入 当我们让两个客户端共同执行begin语句时,就开始了两个事务并发访问 在这个过程中,可能会出现sql交叉的问题 但我们不希望因为…

项目定位与服务器(SERVER)模块划分

目录 定位 HTTP协议以及HTTP服务器 高并发服务器 单Reactor单线程 单Reactor多线程 多Reactor多线程 模块划分 SERVER模块划分 Buffer 模块 Socket模块 Channel 模块 Connection模块 Acceptor模块 TimerQueue模块 Poller模块 EventLoop模块 TcpServer模块 SE…

【ADC】噪声(1)噪声分类

概述 本文学习于TI 高精度实验室课程&#xff0c;总结 ADC 的噪声分类&#xff0c;并简要介绍量化噪声和热噪声。 文章目录 概述一、ADC 中的噪声类型二、量化噪声三、热噪声四、量化噪声与热噪声对比 一、ADC 中的噪声类型 ADC 固有噪声由两部分组成&#xff1a;第一部分是量…

【树莓派系列】树莓派wiringPi库详解,官方外设开发

树莓派wiringPi库详解&#xff0c;官方外设开发 文章目录 树莓派wiringPi库详解&#xff0c;官方外设开发一、安装wiringPi库二、wiringPi库API大全1.硬件初始化函数2.通用GPIO控制函数3.时间控制函数4.串口通信串口API串口通信配置多串口通信配置串口自发自收测试串口间通信测…

Django 后端数据传给前端

Step 1 创建一个数据库 Step 2 在Django中点击数据库连接 Step 3 连接成功 Step 4 settings中找DATABASES Step 5 将数据库挂上面 将数据库引擎和数据库名改成自己的 Step 6 在_init_.py中加上数据库的支持语句 import pymysql pymysql.install_as_MySQLdb() Step7 简单创建两…

以企业的视角进行大学生招聘

课程来源&#xff1a;中国计算机学会---朱颖韶&#xff08;资深人力资源领域--HR&#xff09; 一、招聘流程 1.简历->门槛 注重&#xff1a;专业学历、行业经验 2.笔试面试->专业知识与技能 3.简历面试-> 过往的成果 4.面试 沟通能力、学习力-----了解动机、价值观…

Pikachu-Sql Inject-insert/update/delete注入

insert 注入 插入语句 insert into tables values(value1,value2,value3); 如&#xff1a;插入用户表 insert into users (id,name,password) values (id,username,password); 当点击注册 先判断是否有SQL注入漏洞&#xff0c;经过判断之后发现存在SQL漏洞。构造insert的pa…

8644 堆排序

### 思路 堆排序是一种基于堆数据结构的排序算法。堆是一种完全二叉树&#xff0c;分为最大堆和最小堆。堆排序的基本思想是将待排序数组构造成一个最大堆&#xff0c;然后依次将堆顶元素与末尾元素交换&#xff0c;并调整堆结构&#xff0c;直到排序完成。 ### 伪代码 1. 读取…

自闭症干预寄宿学校:专业治疗帮助孩子发展

自闭症干预寄宿学校&#xff1a;星贝育园的专业治疗助力孩子全面发展 在自闭症儿童的教育与康复领域&#xff0c;寄宿学校以其独特的教育模式和全面的关怀体系&#xff0c;为众多家庭提供了重要的选择。广州星贝育园自闭症儿童寄宿制学校&#xff0c;作为这一领域的佼佼者&…