《昇思25天学习打卡营第17天 | CycleGAN图像风格迁移互换》

news2024/11/15 6:51:09

《昇思25天学习打卡营第17天 | CycleGAN图像风格迁移互换》

目录

  • 《昇思25天学习打卡营第17天 | CycleGAN图像风格迁移互换》
    • 模型介绍
      • 模型简介
      • 模型结构
    • 数据集
      • 数据集下载
      • 数据集加载
      • 可视化
    • 构建生成器
    • 构建判别器
    • 优化器和损失函数
    • 前向计算
    • 计算梯度和反向传播
    • 模型训练
    • 模型推理
    • 参考

模型介绍

模型简介

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

CycleGAN

为了方便理解,这里以苹果和橘子为例介绍。上图中 X X X 可以理解为苹果, Y Y Y 为橘子; G G G 为将苹果生成橘子风格的生成器, F F F 为将橘子生成的苹果风格的生成器, D X D_{X} DX D Y D_{Y} DY 为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

Cycle Consistency Loss

图中苹果图片 x x x 经过生成器 G G G 得到伪橘子 Y ^ \hat{Y} Y^,然后将伪橘子 Y ^ \hat{Y} Y^ 结果送进生成器 F F F 又产生苹果风格的结果 x ^ \hat{x} x^,最后将生成的苹果风格结果 x ^ \hat{x} x^ 与原苹果图片 x x x 一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

数据集

本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,以省略大部分数据预处理的代码。

数据集下载

使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# 查看当前 mindspore 版本
!pip show mindspore
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"

download(url, ".", kind="zip", replace=True)
Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip (466.8 MB)

file_sizes: 100%|█████████████████████████████| 489M/489M [00:03<00:00, 154MB/s]
Extracting zip file...
Successfully downloaded / unzipped to .
'.'

数据集加载

使用 MindSpore 的 MindDataset 接口读取和解析数据集。

from mindspore.dataset import MindDataset

# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
data = MindDataset(dataset_files=name_mr)
print("Datasize: ", data.get_dataset_size())

batch_size = 1
dataset = data.batch(batch_size)
datasize = dataset.get_dataset_size()

Datasize: 1019

可视化

通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

import numpy as np
import matplotlib.pyplot as plt

mean = 0.5 * 255
std = 0.5 * 255

plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):
    if i < 5:
        show_images_a = data["image_A"].asnumpy()
        show_images_b = data["image_B"].asnumpy()

        plt.subplot(2, 5, i+1)
        show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_a)
        plt.axis("off")

        plt.subplot(2, 5, i+6)
        show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_b)
        plt.axis("off")
    else:
        break
plt.show()

在这里插入图片描述

构建生成器

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

CycleGAN Generator

具体的模型结构请参照下文代码:

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal

weight_init = Normal(sigma=0.02)

class ConvNormReLU(nn.Cell):
    def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',
                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):
        super(ConvNormReLU, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, padding=padding, weight_init=weight_init)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, weight_init=weight_init)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class ResidualBlock(nn.Cell):
    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)

    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        return x + out


class ResNetGenerator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,
                 pad_mode="CONSTANT"):
        super(ResNetGenerator, self).__init__()
        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)
        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)
        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
        self.residuals = nn.SequentialCell(layers)
        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)
        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)
        if pad_mode == "CONSTANT":
            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',
                                      padding=3, weight_init=weight_init)
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)
            self.conv_out = nn.SequentialCell([pad, conv])

    def construct(self, x):
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.residuals(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        return ops.tanh(output)

# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')

net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

构建判别器

判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2dBatchNorm2dLeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

# 定义判别器
class Discriminator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),
                  nn.LeakyReLU(alpha)]
        nf_mult = output_channel
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * output_channel
            layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * output_channel
        layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

# 判别器初始化
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')

net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 G G G 及其判别器 D Y D_{Y} DY ,目标损失函数定义为:

L G A N ( G , D Y , X , Y ) = E y − p d a t a ( y ) [ l o g D Y ( y ) ] + E x − p d a t a ( x ) [ l o g ( 1 − D Y ( G ( x ) ) ) ] L_{GAN}(G,D_Y,X,Y)=E_{y-p_{data}(y)}[logD_Y(y)]+E_{x-p_{data}(x)}[log(1-D_Y(G(x)))] LGAN(G,DY,X,Y)=Eypdata(y)[logDY(y)]+Expdata(x)[log(1DY(G(x)))]

其中 G G G 试图生成看起来与 Y Y Y 中的图像相似的图像 G ( x ) G(x) G(x) ,而 D Y D_{Y} DY 的目标是区分翻译样本 G ( x ) G(x) G(x) 和真实样本 y y y ,生成器的目标是最小化这个损失函数以此来对抗判别器。即 $ min_{G} max_{D_{Y}}L_{GAN}(G,D_{Y} ,X,Y )$ 。

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 X X X 的每个图像 x x x ,图像转换周期应能够将 x x x 带回原始图像,可以称之为正向循环一致性,即 x → G ( x ) → F ( G ( x ) ) ≈ x x→G(x)→F(G(x))\approx x xG(x)F(G(x))x 。对于 Y Y Y ,类似的 x → G ( x ) → F ( G ( x ) ) ≈ x x→G(x)→F(G(x))\approx x xG(x)F(G(x))x 。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

L c y c ( G , F ) = E x − p d a t a ( x ) [ ∥ F ( G ( x ) ) − x ∥ 1 ] + E y − p d a t a ( y ) [ ∥ G ( F ( y ) ) − y ∥ 1 ] L_{cyc}(G,F)=E_{x-p_{data}(x)}[\Vert F(G(x))-x\Vert_{1}]+E_{y-p_{data}(y)}[\Vert G(F(y))-y\Vert_{1}] Lcyc(G,F)=Expdata(x)[F(G(x))x1]+Eypdata(y)[G(F(y))y1]

循环一致损失能够保证重建图像 F ( G ( x ) ) F(G(x)) F(G(x)) 与输入图像 x x x 紧密匹配。

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")

def gan_loss(predict, target):
    target = ops.ones_like(predict) * target
    loss = loss_fn(predict, target)
    return loss

前向计算

搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms

# 前向计算

def generator(img_a, img_b):
    fake_a = net_rg_b(img_b)
    fake_b = net_rg_a(img_a)
    rec_a = net_rg_b(fake_b)
    rec_b = net_rg_a(fake_a)
    identity_a = net_rg_b(img_a)
    identity_b = net_rg_a(img_b)
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

lambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5

def generator_forward(img_a, img_b):
    true = Tensor(True, dtype.bool_)
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

def generator_forward_grad(img_a, img_b):
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

def discriminator_forward(img_a, img_b, fake_a, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d

def discriminator_forward_a(img_a, fake_a):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    return loss_d_a

def discriminator_forward_b(img_b, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    return loss_d_b

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):
    num_imgs = 0
    image1 = []
    if isinstance(images, Tensor):
        images = images.asnumpy()
    return_images = []
    for image in images:
        if num_imgs < pool_size:
            num_imgs = num_imgs + 1
            image1.append(image)
            return_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, pool_size - 1)

                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)

            else:
                return_images.append(image)
    output = Tensor(return_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output

计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

from mindspore import value_and_grad

# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())

grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())

# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):
    net_d_a.set_grad(False)
    net_d_b.set_grad(False)

    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)

    _, grads_g_a = grad_g_a(img_a, img_b)
    _, grads_g_b = grad_g_b(img_a, img_b)
    optimizer_rg_a(grads_g_a)
    optimizer_rg_b(grads_g_b)

    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib

# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):
    net_d_a.set_grad(True)
    net_d_b.set_grad(True)

    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)
    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)

    loss_d = (loss_d_a + loss_d_b) * 0.5

    optimizer_d_a(grads_d_a)
    optimizer_d_b(grads_d_b)

    return loss_d

模型训练

训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 E y − p d a t a ( y ) [ ( D ( y ) − 1 ) 2 ] E_{y-p_{data}(y)}[(D(y)-1)^2] Eypdata(y)[(D(y)1)2]

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 E x − p d a t a ( x ) [ ( D ( G ( x ) − 1 ) 2 ] E_{x-p_{data}(x)}[(D(G(x)-1)^2] Expdata(x)[(D(G(x)1)2] 来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

import os
import time
import random
import numpy as np
from PIL import Image
from mindspore import Tensor, save_checkpoint
from mindspore import dtype

# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1
save_step_num = 80
save_checkpoint_epochs = 1
save_ckpt_dir = './train_ckpt_outputs/'

print('Start training!')

for epoch in range(epochs):
    g_loss = []
    d_loss = []
    start_time_e = time.time()
    for step, data in enumerate(dataset.create_dict_iterator()):
        start_time_s = time.time()
        img_a = data["image_A"]
        img_b = data["image_B"]
        res_g = train_step_g(img_a, img_b)
        fake_a = res_g[0]
        fake_b = res_g[1]

        res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))
        loss_d = float(res_d.asnumpy())
        step_time = time.time() - start_time_s

        res = []
        for item in res_g[2:]:
            res.append(float(item.asnumpy()))
        g_loss.append(res[0])
        d_loss.append(loss_d)

        if step % save_step_num == 0:
            print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
                  f"step:[{int(step):>4d}/{int(datasize):>4d}], "
                  f"time:{step_time:>3f}s,\n"
                  f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "
                  f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "
                  f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "
                  f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")

    epoch_cost = time.time() - start_time_e
    per_step_time = epoch_cost / datasize
    mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize

    print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
          f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "
          f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")

    if epoch % save_checkpoint_epochs == 0:
        os.makedirs(save_ckpt_dir, exist_ok=True)
        save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))
        save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))
        save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))
        save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))

print('End of training!')

Start training!
Epoch:[ 1/ 1], step:[ 0/1019], time:197.459362s,
loss_g:17.66, loss_d:0.98, loss_g_a: 0.95, loss_g_b: 1.06, loss_c_a: 3.98, loss_c_b: 6.43, loss_idt_a: 1.99, loss_idt_b: 3.24
Epoch:[ 1/ 1], step:[ 80/1019], time:0.426588s,
loss_g:17.94, loss_d:0.42, loss_g_a: 0.55, loss_g_b: 0.41, loss_c_a: 6.44, loss_c_b: 5.68, loss_idt_a: 2.07, loss_idt_b: 2.77
Epoch:[ 1/ 1], step:[ 160/1019], time:0.447278s,
loss_g:10.87, loss_d:0.53, loss_g_a: 0.95, loss_g_b: 0.30, loss_c_a: 2.77, loss_c_b: 3.83, loss_idt_a: 1.23, loss_idt_b: 1.79
Epoch:[ 1/ 1], step:[ 240/1019], time:0.425675s,
loss_g:12.92, loss_d:0.49, loss_g_a: 0.38, loss_g_b: 0.30, loss_c_a: 3.48, loss_c_b: 5.07, loss_idt_a: 1.56, loss_idt_b: 2.13
Epoch:[ 1/ 1], step:[ 320/1019], time:0.441207s,
loss_g:7.27, loss_d:0.34, loss_g_a: 0.31, loss_g_b: 0.51, loss_c_a: 2.60, loss_c_b: 1.84, loss_idt_a: 1.15, loss_idt_b: 0.87
Epoch:[ 1/ 1], step:[ 400/1019], time:0.445558s,
loss_g:6.25, loss_d:0.31, loss_g_a: 0.33, loss_g_b: 0.54, loss_c_a: 2.30, loss_c_b: 1.45, loss_idt_a: 1.06, loss_idt_b: 0.58
Epoch:[ 1/ 1], step:[ 480/1019], time:0.443621s,
loss_g:4.59, loss_d:0.43, loss_g_a: 0.29, loss_g_b: 0.61, loss_c_a: 1.11, loss_c_b: 1.53, loss_idt_a: 0.44, loss_idt_b: 0.60
Epoch:[ 1/ 1], step:[ 560/1019], time:0.429509s,
loss_g:5.63, loss_d:0.38, loss_g_a: 0.49, loss_g_b: 0.35, loss_c_a: 1.78, loss_c_b: 1.59, loss_idt_a: 0.70, loss_idt_b: 0.71
Epoch:[ 1/ 1], step:[ 640/1019], time:0.434424s,
loss_g:5.38, loss_d:0.32, loss_g_a: 0.63, loss_g_b: 0.32, loss_c_a: 1.05, loss_c_b: 2.13, loss_idt_a: 0.45, loss_idt_b: 0.80
Epoch:[ 1/ 1], step:[ 720/1019], time:0.429569s,
loss_g:3.74, loss_d:0.50, loss_g_a: 0.25, loss_g_b: 0.36, loss_c_a: 1.19, loss_c_b: 0.88, loss_idt_a: 0.55, loss_idt_b: 0.51
Epoch:[ 1/ 1], step:[ 800/1019], time:0.428782s,
loss_g:4.55, loss_d:0.38, loss_g_a: 0.59, loss_g_b: 0.21, loss_c_a: 1.43, loss_c_b: 1.21, loss_idt_a: 0.64, loss_idt_b: 0.47
Epoch:[ 1/ 1], step:[ 880/1019], time:0.438998s,
loss_g:4.58, loss_d:0.32, loss_g_a: 0.37, loss_g_b: 0.55, loss_c_a: 1.36, loss_c_b: 1.15, loss_idt_a: 0.64, loss_idt_b: 0.51
Epoch:[ 1/ 1], step:[ 960/1019], time:0.425754s,
loss_g:4.38, loss_d:0.46, loss_g_a: 0.45, loss_g_b: 0.30, loss_c_a: 1.15, loss_c_b: 1.47, loss_idt_a: 0.48, loss_idt_b: 0.53
Epoch:[ 1/ 1], epoch time:642.74s, per step time:0.63, mean_g_loss:7.50, mean_d_loss:0.44
End of training!

模型推理

下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。

%%time
import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net

# 加载权重文件
def load_ckpt(net, ckpt_dir):
    param_GA = load_checkpoint(ckpt_dir)
    load_param_into_net(net, param_GA)

g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'

load_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)

# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def eval_data(dir_path, net, a):

    def read_img():
        for dir in os.listdir(dir_path):
            path = os.path.join(dir_path, dir)
            img = Image.open(path).convert('RGB')
            yield img, dir

    dataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])
    trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]
    dataset = dataset.map(operations=trans, input_columns=["image"])
    dataset = dataset.batch(1)
    for i, data in enumerate(dataset.create_dict_iterator()):
        img = data["image"]
        fake = net(img)
        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))

        fig.add_subplot(2, 8, i+1+a)
        plt.axis("off")
        plt.imshow(img.asnumpy())

        fig.add_subplot(2, 8, i+9+a)
        plt.axis("off")
        plt.imshow(fake.asnumpy())

eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)
plt.show()
import time

print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'wujianxingzhe')

在这里插入图片描述

参考

[1] I. Goodfellow. NIPS 2016 tutorial: Generative ad-versarial networks. arXiv preprint arXiv:1701.00160,2016. 2, 4, 5

[2] A. Shrivastava, T. Pfister, O. Tuzel, J. Susskind, W. Wang, R. Webb. Learning from simulated and unsupervised images through adversarial training. In CVPR, 2017. 3, 5, 6, 7

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1923916.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大模型-基于大模型的数据标注

方法来自于这篇论文&#xff1a; Can Generalist Foundation Models Outcompete Special-Purpose Tuning? Case Study in Medicine。 一.背景 假设&#xff0c;存在一批标注好的数据D_labeled&#xff0c;其包含m个标注样本(x, y)。 目标是&#xff0c;基于D_labeled&#…

Redis集群和高可用

文章目录 一、Redis主从复制redis主从复制架构主从复制实现主从复制故障恢复主从复制优化主从复制过程 主从同步优化配置 二、哨兵模式 (Sentinel)redis集群介绍哨兵 (Sentinel)工作原理实现哨兵主从复制哨兵配置文件 三、Redis cluster架构工作原理Redis cluster架构实现集群 …

电表及销售统计Python应用及win程序

暑假每天都要填表算账很烦躁&#xff0c;就整了个小程序来减轻压力 程序可以做到记录输入的每一条数据&#xff0c;并用新数据减去旧数据算新增的量&#xff0c;同时记录填写时间 Python代码 import json import os # 导入os模块 from datetime import datetime from tkint…

防火墙NAT智能选举综合实验

目录 实验拓扑 实验要求 实验思路 实验配置 需求7 需求8 需求9 需求10 需求11 实验拓扑 实验要求 7.办公区设备可以通过电信链路和移动链路上网(多对多的NAT&#xff0c;并且需要保留一个公网IP不能用来转换) 8.分公司设备可以通过总公司的移动链路和电信链路访问到d…

Python数据分析案例52——基于SSA-LSTM的风速预测(麻雀优化)

案例背景 又要开始更新时间序列水论文的系列的方法了&#xff0c;前面基于各种不同神经网络层&#xff0c;还有注意力机制做了一些缝合模型。 其实论文里面用的多的可能是优化算法和模态分解&#xff0c;这两个我还没出专门的例子&#xff0c;这几天正好出一个优化算法的例子来…

RocketMQ~架构了解

简介 RocketMQ 具有高性能、高可靠、高实时、分布式 的特点。它是一个采用 Java 语言开发的分布式的消息系统&#xff0c;由阿里巴巴团队开发&#xff0c;在 2016 年底贡献给 Apache&#xff0c;成为了 Apache 的一个顶级项目。 在阿里内部&#xff0c;RocketMQ 很好地服务了集…

优化Cocos Creator 包体体积

优化Cocos Creator 包体体积 引言一、优化图片文件体积&#xff1a;二、优化声音文件体积&#xff1a;三、优化引擎代码体积&#xff1a;四、 优化字体字库文件大小&#xff1a; 引言 优化Cocos Creator项目的包体体积是一个常见且重要的任务&#xff0c;尤其是在移动设备和网…

【高中数学/幂函数】比较a=2^0.3,b=3^0.2,c=7^0.1的大小

【问题】 比较a2^0.3,b3^0.2,c7^0.1的大小 【解答】 a2^0.32^3/10(2^3)^1/108^1/10 b3^0.23^2/10(3^2)^1/109^1/10 c7^0.17^1/10 由于yx^1/10在x正半轴是增函数&#xff0c;底数大的得数就大。 因为9>8>7,所以b>a>c 【图像】 在图像上绘出曲线yx^1/10&…

红日靶场----(三)1.漏洞利用

上期已经信息收集阶段已经完成&#xff0c;接下来是漏洞利用。 靶场思路 通过信息收集得到两个吧靶场的思路 1、http://192.168.195.33/phpmyadmin/&#xff08;数据库的管理界面&#xff09; root/root 2、http://192.168.195.33/yxcms/index.php?radmin/index/login&am…

杆塔倾斜在线监测装置

概述 我国约960万平方公里已经基本实现电网和基站通讯全覆盖&#xff0c;但我国地貌复杂多样&#xff0c;大部分杆塔需要安装在野外&#xff0c;在安装时并不能保证地基的结实可靠&#xff0c;一不小心就可能导致杆塔的倾斜倒塌。 在通信铁塔倾斜现象发生发展的初期&#xff0…

HarmonyOS(43) @BuilderParam标签使用指南

BuilderParam BuilderParam使用举例定义模板定义具体实现BuilderParam初始化 demo源码参考资料 BuilderParam 该标签有的作用有点类似于设计模式中的模板模式&#xff0c;类似于指定一个UI占位符&#xff0c;具体的实现交给具体的Builder&#xff0c;顾名思义&#xff0c;可以…

面试内容集合

用例设计方法 &#xff08;一&#xff09;等价类划分  常见的软件测试面试题划分等价类: 等价类是指某个输入域的子集合.在该子集合中,各个输入数据对于揭露程序中的错误都是等效的.并合理地假定:测试某等价类的代表值就等于对这一类其它值的测试.因此,可以把全部输入数据合理…

腾讯云如何设置二级域名?

什么是二级域名&#xff1f; 例如我已申请的域名为&#xff1a; test.com //顶级域名 现在我开发的应用要部署到二级域名&#xff1a; blog.test.com 1、打开腾讯云控制台的我的域名&#xff0c;然后点击解析 2、在我的解析页面点击添加记录&#xff0c;然后需注意红色方框处…

js 请求blob:https:// 图片

方式1 def get_file_content_chrome(driver, uri):result driver.execute_async_script("""var uri arguments[0];var callback arguments[1];var toBase64 function(buffer){for(var r,nnew Uint8Array(buffer),tn.length,anew Uint8Array(4*Math.ceil(t/…

[WUSTCTF2020]funnyre

【【反调试】花指令patch与原理分析】https://www.bilibili.com/video/BV1mK411A75G?vd_source7ad69e0c2be65c96d9584e19b0202113 B站这个视频和这道题的花指令一样的 这个call百分之一万是辣鸡 重编译u他 经典辣鸡花指令 nop掉 下面一共有四处,一样的操作 然后回到main函…

奥利奥广告策略解析「扭一扭、舔一舔、泡一泡」广告为何深入人心?

作为一个多年的广告人&#xff0c;我认为奥利奥的「扭一扭、舔一舔、泡一泡」广告策略非常巧妙。今天可以从专业的角度来分析分析一下&#xff0c;大概应该有三大原因吧。 品牌识别度与记忆点&#xff1a; “扭一扭、舔一舔、泡一泡”这句广告语简洁易记&#xff0c;富有节奏…

如何30分钟下载完368G的Android系统源码?

如何30分钟下载完368G的Android系统源码&#xff1f; Android系统开发的一个痛点问题就是Android系统源码庞大&#xff0c;小则100G,大则&#xff0c;三四百G。如标题所言&#xff0c;本文介绍通过局域网高速网速下载源码的方法。 制作源码mirror 从源码git服务器A&#xff0c…

AGI 之 【Hugging Face】 的【问答系统】的 [评估并改进问答Pipeline] / [ 生成式问答 ] 的简单整理

AGI 之 【Hugging Face】 的【问答系统】的 [评估并改进问答Pipeline] / [ 生成式问答 ] 的简单整理 目录 AGI 之 【Hugging Face】 的【问答系统】的 [评估并改进问答Pipeline] / [ 生成式问答 ] 的简单整理 一、简单介绍 二、构建问答系统 三、评估并改进问答pipeline 1…

总结单例模式的写法

一、单例模式的概念 1.1 单例模式的概念 单例模式&#xff08;Singleton Pattern&#xff09;是 Java 中最简单的设计模式之一。这种类型的设计模式属于创建型模式&#xff0c;它提供了一种创建对象的最佳方式。就是当前进程确保一个类全局只有一个实例。 1.2 单例模式的优…

【postgresql】时间函数和操作符

日期/时间操作符 加减操作符&#xff1a; 和 - 可以用于日期、时间、时间戳和时间间隔的加减操作。 SELECT 2024-01-01::date INTERVAL 1 day as "date"; ; -- 结果&#xff1a;2024-01-02SELECT 2024-01-01 12:00:00::timestamp - INTERVAL 2 hours as "…