昇思25天学习打卡营第19天|CycleGAN图像风格迁移互换

news2025/1/13 2:54:54

CycleGAN图像风格迁移互换

模型介绍

模型简介

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

为了方便理解,这里以苹果和橘子为例介绍。上图中 𝑋可以理解为苹果,𝑌为橘子;𝐺 为将苹果生成橘子风格的生成器,𝐹 为将橘子生成的苹果风格的生成器,𝐷𝑋 和 𝐷𝑌 为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

图中苹果图片 𝑥 经过生成器 𝐺 得到伪橘子 𝑌̂ ,然后将伪橘子 𝑌̂  结果送进生成器 𝐹 又产生苹果风格的结果 𝑥̂ ,最后将生成的苹果风格结果 𝑥̂  与原苹果图片 𝑥 一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

数据集

本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理,为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,以省略大部分数据预处理的代码。

数据集下载

使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

环境配置

Python版本

Python 3.9.19

安装环境

pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

完整的Python环境依赖

pip list
Package                        Version
------------------------------ --------------
absl-py                        2.1.0
aiofiles                       22.1.0
aiosqlite                      0.20.0
altair                         5.3.0
annotated-types                0.7.0
anyio                          4.4.0
argon2-cffi                    23.1.0
argon2-cffi-bindings           21.2.0
arrow                          1.3.0
astroid                        3.2.2
asttokens                      2.0.5
astunparse                     1.6.3
attrs                          23.2.0
auto-tune                      0.1.0
autopep8                       1.5.5
Babel                          2.15.0
backcall                       0.2.0
beautifulsoup4                 4.12.3
black                          24.4.2
bleach                         6.1.0
certifi                        2024.6.2
cffi                           1.16.0
charset-normalizer             3.3.2
click                          8.1.7
cloudpickle                    3.0.0
colorama                       0.4.6
comm                           0.2.1
contextlib2                    21.6.0
contourpy                      1.2.1
cycler                         0.12.1
dataflow                       0.0.1
debugpy                        1.6.7
decorator                      5.1.1
defusedxml                     0.7.1
dill                           0.3.8
dnspython                      2.6.1
download                       0.3.5
easydict                       1.13
email_validator                2.2.0
entrypoints                    0.4
exceptiongroup                 1.2.0
executing                      0.8.3
fastapi                        0.111.0
fastapi-cli                    0.0.4
fastjsonschema                 2.20.0
ffmpy                          0.3.2
filelock                       3.15.3
flake8                         3.8.4
fonttools                      4.53.0
fqdn                           1.5.1
fsspec                         2024.6.0
gitdb                          4.0.11
GitPython                      3.1.43
gradio                         4.26.0
gradio_client                  0.15.1
h11                            0.14.0
hccl                           0.1.0
hccl-parser                    0.1
httpcore                       1.0.5
httptools                      0.6.1
httpx                          0.27.0
huggingface-hub                0.23.4
idna                           3.7
importlib-metadata             7.0.1
importlib_resources            6.4.0
iniconfig                      2.0.0
ipykernel                      6.28.0
ipympl                         0.9.4
ipython                        8.15.0
ipython-genutils               0.2.0
ipywidgets                     8.1.3
isoduration                    20.11.0
isort                          5.13.2
jedi                           0.17.2
Jinja2                         3.1.4
joblib                         1.4.2
json5                          0.9.25
jsonpointer                    3.0.0
jsonschema                     4.22.0
jsonschema-specifications      2023.12.1
jupyter_client                 7.4.9
jupyter_core                   5.7.2
jupyter-events                 0.10.0
jupyter-lsp                    2.2.5
jupyter-resource-usage         0.7.2
jupyter_server                 2.14.1
jupyter_server_fileid          0.9.2
jupyter-server-mathjax         0.2.6
jupyter_server_terminals       0.5.3
jupyter_server_ydoc            0.8.0
jupyter-ydoc                   0.2.5
jupyterlab                     3.6.7
jupyterlab_code_formatter      2.2.1
jupyterlab_git                 0.50.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab-lsp                 4.3.0
jupyterlab_pygments            0.3.0
jupyterlab_server              2.27.2
jupyterlab-system-monitor      0.8.0
jupyterlab-topbar              0.6.1
jupyterlab_widgets             3.0.11
kiwisolver                     1.4.5
markdown-it-py                 3.0.0
MarkupSafe                     2.1.5
matplotlib                     3.9.0
matplotlib-inline              0.1.6
mccabe                         0.6.1
mdurl                          0.1.2
mindspore                      2.2.14
mindvision                     0.1.0
mistune                        3.0.2
ml_collections                 0.1.1
mpmath                         1.3.0
msadvisor                      1.0.0
mypy-extensions                1.0.0
nbclassic                      1.1.0
nbclient                       0.10.0
nbconvert                      7.16.4
nbdime                         4.0.1
nbformat                       5.10.4
nest-asyncio                   1.6.0
notebook                       6.5.7
notebook_shim                  0.2.4
numpy                          1.26.4
op-compile-tool                0.1.0
op-gen                         0.1
op-test-frame                  0.1
opc-tool                       0.1.0
opencv-contrib-python-headless 4.10.0.84
opencv-python                  4.10.0.84
opencv-python-headless         4.10.0.84
orjson                         3.10.5
overrides                      7.7.0
packaging                      23.2
pandas                         2.2.2
pandocfilters                  1.5.1
parso                          0.7.1
pathlib2                       2.3.7.post1
pathspec                       0.12.1
pexpect                        4.8.0
pickleshare                    0.7.5
pillow                         10.3.0
pip                            24.1
platformdirs                   4.2.2
pluggy                         1.5.0
prometheus_client              0.20.0
prompt-toolkit                 3.0.43
protobuf                       5.27.1
psutil                         5.9.0
ptyprocess                     0.7.0
pure-eval                      0.2.2
pycodestyle                    2.6.0
pycparser                      2.22
pydantic                       2.7.4
pydantic_core                  2.18.4
pydocstyle                     6.3.0
pydub                          0.25.1
pyflakes                       2.2.0
Pygments                       2.15.1
pylint                         3.2.3
pyparsing                      3.1.2
pytest                         8.0.0
python-dateutil                2.9.0.post0
python-dotenv                  1.0.1
python-json-logger             2.0.7
python-jsonrpc-server          0.4.0
python-language-server         0.36.2
python-multipart               0.0.9
pytoolconfig                   1.3.1
pytz                           2024.1
PyYAML                         6.0.1
pyzmq                          25.1.2
referencing                    0.35.1
requests                       2.32.3
rfc3339-validator              0.1.4
rfc3986-validator              0.1.1
rich                           13.7.1
rope                           1.13.0
rpds-py                        0.18.1
ruff                           0.4.10
schedule-search                0.0.1
scikit-learn                   1.5.0
scipy                          1.13.1
semantic-version               2.10.0
Send2Trash                     1.8.3
setuptools                     69.5.1
shellingham                    1.5.4
six                            1.16.0
smmap                          5.0.1
sniffio                        1.3.1
snowballstemmer                2.2.0
soupsieve                      2.5
stack-data                     0.2.0
starlette                      0.37.2
sympy                          1.12.1
synr                           0.5.0
te                             0.4.0
terminado                      0.18.1
threadpoolctl                  3.5.0
tinycss2                       1.3.0
toml                           0.10.2
tomli                          2.0.1
tomlkit                        0.12.0
toolz                          0.12.1
tornado                        6.4.1
tqdm                           4.66.4
traitlets                      5.14.3
typer                          0.12.3
types-python-dateutil          2.9.0.20240316
typing_extensions              4.11.0
tzdata                         2024.1
ujson                          5.10.0
uri-template                   1.3.0
urllib3                        2.2.2
uvicorn                        0.30.1
uvloop                         0.19.0
watchfiles                     0.22.0
wcwidth                        0.2.5
webcolors                      24.6.0
webencodings                   0.5.1
websocket-client               1.8.0
websockets                     11.0.3
wheel                          0.43.0
widgetsnbextension             4.0.11
y-py                           0.6.2
yapf                           0.40.2
ypy-websocket                  0.8.4
zipp                           3.17.0

实践代码

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"

download(url, ".", kind="zip", replace=True)

数据集加载

使用 MindSpore 的 MindDataset 接口读取和解析数据集。

from mindspore.dataset import MindDataset

# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
data = MindDataset(dataset_files=name_mr)
print("Datasize: ", data.get_dataset_size())

batch_size = 1
dataset = data.batch(batch_size)
datasize = dataset.get_dataset_size()

可视化

通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

import numpy as np
import matplotlib.pyplot as plt

mean = 0.5 * 255
std = 0.5 * 255

plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):
    if i < 5:
        show_images_a = data["image_A"].asnumpy()
        show_images_b = data["image_B"].asnumpy()

        plt.subplot(2, 5, i+1)
        show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_a)
        plt.axis("off")

        plt.subplot(2, 5, i+6)
        show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_b)
        plt.axis("off")
    else:
        break
plt.show()

构建生成器

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

具体的模型结构请参照下文代码:

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal

weight_init = Normal(sigma=0.02)

class ConvNormReLU(nn.Cell):
    def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',
                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):
        super(ConvNormReLU, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, padding=padding, weight_init=weight_init)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, weight_init=weight_init)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class ResidualBlock(nn.Cell):
    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)

    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        return x + out


class ResNetGenerator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,
                 pad_mode="CONSTANT"):
        super(ResNetGenerator, self).__init__()
        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)
        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)
        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
        self.residuals = nn.SequentialCell(layers)
        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)
        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)
        if pad_mode == "CONSTANT":
            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',
                                      padding=3, weight_init=weight_init)
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)
            self.conv_out = nn.SequentialCell([pad, conv])

    def construct(self, x):
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.residuals(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        return ops.tanh(output)

# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')

net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

构建判别器

判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

# 定义判别器
class Discriminator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),
                  nn.LeakyReLU(alpha)]
        nf_mult = output_channel
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * output_channel
            layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * output_channel
        layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

# 判别器初始化
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')

net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 𝐺及其判别器 𝐷𝑌 ,目标损失函数定义为:

其中 𝐺 试图生成看起来与 𝑌 中的图像相似的图像 𝐺(𝑥),而 𝐷𝑌 的目标是区分翻译样本 𝐺(𝑥)和真实样本 𝑦 ,生成器的目标是最小化这个损失函数以此来对抗判别器。即 𝑚𝑖𝑛𝐺𝑚𝑎𝑥𝐷𝑌𝐿𝐺𝐴𝑁(𝐺,𝐷𝑌,𝑋,𝑌) 。

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 𝑋的每个图像 𝑥 ,图像转换周期应能够将 𝑥 带回原始图像,可以称之为正向循环一致性,即 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥。对于 𝑌,类似的 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

循环一致损失能够保证重建图像 𝐹(𝐺(𝑥)) 与输入图像 𝑥 紧密匹配。

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")

def gan_loss(predict, target):
    target = ops.ones_like(predict) * target
    loss = loss_fn(predict, target)
    return loss

前向计算

搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡,这里遵循 Shrivastava 等人的策略,使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms

# 前向计算

def generator(img_a, img_b):
    fake_a = net_rg_b(img_b)
    fake_b = net_rg_a(img_a)
    rec_a = net_rg_b(fake_b)
    rec_b = net_rg_a(fake_a)
    identity_a = net_rg_b(img_a)
    identity_b = net_rg_a(img_b)
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

lambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5

def generator_forward(img_a, img_b):
    true = Tensor(True, dtype.bool_)
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

def generator_forward_grad(img_a, img_b):
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

def discriminator_forward(img_a, img_b, fake_a, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d

def discriminator_forward_a(img_a, fake_a):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    return loss_d_a

def discriminator_forward_b(img_b, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    return loss_d_b

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):
    num_imgs = 0
    image1 = []
    if isinstance(images, Tensor):
        images = images.asnumpy()
    return_images = []
    for image in images:
        if num_imgs < pool_size:
            num_imgs = num_imgs + 1
            image1.append(image)
            return_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, pool_size - 1)

                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)

            else:
                return_images.append(image)
    output = Tensor(return_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output

计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

from mindspore import value_and_grad

# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())

grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())

# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):
    net_d_a.set_grad(False)
    net_d_b.set_grad(False)

    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)

    _, grads_g_a = grad_g_a(img_a, img_b)
    _, grads_g_b = grad_g_b(img_a, img_b)
    optimizer_rg_a(grads_g_a)
    optimizer_rg_b(grads_g_b)

    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib

# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):
    net_d_a.set_grad(True)
    net_d_b.set_grad(True)

    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)
    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)

    loss_d = (loss_d_a + loss_d_b) * 0.5

    optimizer_d_a(grads_d_a)
    optimizer_d_b(grads_d_b)

    return loss_d

模型训练

训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 𝐸𝑦−𝑝𝑑𝑎𝑡𝑎(𝑦)[(𝐷(𝑦)−1)2];

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 𝐸𝑥−𝑝𝑑𝑎𝑡𝑎(𝑥)[(𝐷(𝐺(𝑥)−1)2] 来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

import os
import time
import random
import numpy as np
from PIL import Image
from mindspore import Tensor, save_checkpoint
from mindspore import dtype

# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1
save_step_num = 80
save_checkpoint_epochs = 1
save_ckpt_dir = './train_ckpt_outputs/'

print('Start training!')

for epoch in range(epochs):
    g_loss = []
    d_loss = []
    start_time_e = time.time()
    for step, data in enumerate(dataset.create_dict_iterator()):
        start_time_s = time.time()
        img_a = data["image_A"]
        img_b = data["image_B"]
        res_g = train_step_g(img_a, img_b)
        fake_a = res_g[0]
        fake_b = res_g[1]

        res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))
        loss_d = float(res_d.asnumpy())
        step_time = time.time() - start_time_s

        res = []
        for item in res_g[2:]:
            res.append(float(item.asnumpy()))
        g_loss.append(res[0])
        d_loss.append(loss_d)

        if step % save_step_num == 0:
            print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
                  f"step:[{int(step):>4d}/{int(datasize):>4d}], "
                  f"time:{step_time:>3f}s,\n"
                  f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "
                  f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "
                  f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "
                  f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")

    epoch_cost = time.time() - start_time_e
    per_step_time = epoch_cost / datasize
    mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize

    print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
          f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "
          f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")

    if epoch % save_checkpoint_epochs == 0:
        os.makedirs(save_ckpt_dir, exist_ok=True)
        save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))
        save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))
        save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))
        save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))

print('End of training!')

模型推理

下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。

%%time
import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net

# 加载权重文件
def load_ckpt(net, ckpt_dir):
    param_GA = load_checkpoint(ckpt_dir)
    load_param_into_net(net, param_GA)

g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'

load_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)

# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def eval_data(dir_path, net, a):

    def read_img():
        for dir in os.listdir(dir_path):
            path = os.path.join(dir_path, dir)
            img = Image.open(path).convert('RGB')
            yield img, dir

    dataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])
    trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]
    dataset = dataset.map(operations=trans, input_columns=["image"])
    dataset = dataset.batch(1)
    for i, data in enumerate(dataset.create_dict_iterator()):
        img = data["image"]
        fake = net(img)
        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))

        fig.add_subplot(2, 8, i+1+a)
        plt.axis("off")
        plt.imshow(img.asnumpy())

        fig.add_subplot(2, 8, i+9+a)
        plt.axis("off")
        plt.imshow(fake.asnumpy())

eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1921175.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

jenkins打包java项目报错Error: Unable to access jarfile tlm-admin.jar

jenkins打包boot项目 自动重启脚本失败 查看了一下项目日志报错&#xff1a; Error: Unable to access jarfile tlm-admin.jar我检查了一下这个配置&#xff0c;感觉没有问题&#xff0c;包可以正常打&#xff0c; cd 到项目目录下面&#xff0c;手动执行这个sh脚本也是能正常…

本地Kali系统开启SSH服务并使用内网穿透生成公网地址实现ssh远程连接

文章目录 前言1. 启动kali ssh 服务2. kali 安装cpolar 内网穿透3. 配置kali ssh公网地址4. 远程连接5. 固定连接SSH公网地址6. SSH固定地址连接测试 前言 本文主要介绍如何在本地Kali Linux系统启动ssh服务&#xff0c;并结合cpolar内网穿透软件生成公网地址&#xff0c;轻松…

提示词工程(Prompt Engineering)是什么?

一、定义 Prompt Engineering 提示词工程&#xff08;Prompt Engineering&#xff09;是一项通过优化提示词&#xff08;Prompt&#xff09;和生成策略&#xff0c;从而获得更好的模型返回结果的工程技术。 二、System message 系统指令 System message可以被广泛应用在&am…

平凯星辰黄东旭出席 2024 全球数字经济大会 · 开放原子开源数据库生态论坛

7 月 5 日&#xff0c;以“开源生态筑基础&#xff0c;数字经济铸未来”为主题的 2024 全球数字经济大会——开放原子开源数据库生态论坛在北京成功举办。平凯星辰&#xff08;北京&#xff09;科技有限公司联合创始人黄东旭发表了题为《TiDB 助力金融行业关键业务系统实践》的…

【TS】typescript 获取函数入参类型、返回值类型、promise返回值类型

文章目录 1. 准备工作2. 获取函数入参的类型3. 获取函数返回值类型4. 获取promise返回值类型 1. 准备工作 创建 utils.ts interface User {id: number;name: string;age: number; } interface Params {method: string;url: string; }function getUserList(params: Params,other…

RocketMQ 消费者之顺序消费和流程详解附源码解析

1. 背景 本文是 RocketMQ 消费者系列的第六篇&#xff0c;上一篇主要介绍并发消费&#xff0c;而本片主要介绍 RocketMQ 顺序消费的设计和流程。 我把 RocketMQ 消费分成如下几个步骤 重平衡 消费者拉取消息 Broker 接收拉取请求后从存储中查询消息并返回 消费者消费消息 顺序…

算法学习day10(贪心算法)

贪心算法&#xff1a;由局部最优->全局最优 贪心算法一般分为如下四步&#xff1a; 将问题分解为若干个子问题找出适合的贪心策略求解每一个子问题的最优解将局部最优解堆叠成全局最优解 一、摆动序列&#xff08;理解难&#xff09; 连续数字之间的差有正负的交替&…

GO channel 学习

引言 单纯地将函数并发执行是没有意义的。函数与函数间需要交换数据才能体现并发执行函数的意义。 虽然可以使用共享内存进行数据交换&#xff0c;但是共享内存在不同的goroutine中容易发生竞态问题。为了保证数据交换的正确性&#xff0c;必须使用互斥量对内存进行加锁&#…

视频语音转文字工具用哪个好?推荐6款优质的视频转文字工具

在沉浸于电影情节时&#xff0c;周遭的喧嚣往往成了享受视听的障碍&#xff0c;这时&#xff0c;字幕的重要性便不言而喻。 字幕的作用远不止于此&#xff0c;它是听力受限观众的桥梁&#xff0c;也是语言学习者的得力助手。幸运的是&#xff0c;将视频语音转文字字幕现已变得…

红酒与未来科技:传统与创新的碰撞

在岁月的长河中&#xff0c;红酒以其深邃的色泽、丰富的口感和不同的文化魅力&#xff0c;成为人类文明中的一颗璀璨明珠。而未来科技&#xff0c;则以其迅猛的发展速度和无限的可能性&#xff0c;领着人类走向一个崭新的时代。当红酒与未来科技相遇&#xff0c;一场传统与创新…

电脑自动重启是什么原因呢?99%人都不知道的解决办法,直接打破循环

当你的电脑突然毫无预警地自动重启&#xff0c;不仅打断了工作流程&#xff0c;还可能导致未保存的数据丢失&#xff0c;这无疑是一件令人沮丧的事情。那么&#xff0c;电脑自动重启是什么原因呢&#xff1f;有什么方法可以解决呢&#xff1f;别担心&#xff0c;在大多数情况下…

对象与键值对数组的相互转换Object.entries与Object.fromEntries

Object.entries是JavaScript中的一个内置方法&#xff0c;它可以将一个对象的属性和值转换为一个包含键值对的数组。 let obj {name: mike,age: 18,sex: man } Object.entries(obj)Object.entries的使用场景: 1、动态更新对象属性 let obj {name: mike, age: 18, sex: man…

《昇思25天学习打卡营第02天|qingyun201003》

日期 心得 通过这次的学习&#xff0c;主要是了解过张量的基础概念&#xff0c;同时也知道有关构造张量的方法。通过索引查询张量&#xff0c;张量的运算。通过concat\stack 将张量进行维度链接。Tensor与NumPy的互相转换&#xff0c;但是我似乎并不了解什么它们的概念。也认知…

电脑视频去水印软件哪个好用,电脑上视频去水印的软件

在数字化时代&#xff0c;视频创作已成为许多人展示才华和创意的重要途径。然而&#xff0c;视频中的水印常常让人感到头疼&#xff0c;尤其是当水印影响了视频的整体美观时。本文将为你揭秘如何在电脑上使用各种软件去除视频水印&#xff0c;让你的作品更加专业&#xff01; 方…

时光穿梭机:AI如何让老照片焕发新生,跃然“动”起来

在岁月的长河中&#xff0c;每一张老照片都是时间的低语&#xff0c;承载着过往的记忆与温情。它们静静地躺在相册的角落&#xff0c;或是泛黄的相纸上&#xff0c;定格了某个瞬间的欢笑与泪水&#xff0c;却也因此失去了那份生动的活力。然而&#xff0c;随着人工智能&#xf…

如何做一个迟钝不受伤的打工人?

一、背景 在当前激烈的职场环境中&#xff0c;想要成为一个相对“迟钝”且不易受伤的打工人&#xff0c;以下是一些建议&#xff0c;但请注意&#xff0c;这里的“迟钝”并非指智力上的迟钝&#xff0c;而是指在应对复杂人际关系和压力时展现出的豁达与钝感力&#xff1a; 尊重…

如何在Linux系统安装openGauss数据库并使用固定公网地址远程连接

文章目录 前言1. Linux 安装 openGauss2. Linux 安装cpolar3. 创建openGauss主节点端口号公网地址4. 远程连接openGauss5. 固定连接TCP公网地址6. 固定地址连接测试 前言 本文主要介绍如何在Linux系统如何安装openGauss数据库管理系统&#xff0c;并结合cpolar内网穿透工具生成…

docker基础知识以及windows上的docker desktop 安装

记录以供备忘 基础概念&#xff1a; 什么是docker 将程序和环境一起打包&#xff0c;以在不同操作系统上运行的工具软件 什么是基础镜像 选一个基础操作系统和语言后&#xff0c;将对应的文件系统、依赖库、配置等打包为一个类似压缩包的文件&#xff0c;就是基础镜像 什么是…

【银河麒麟高级服务器操作系统】数据中心系统异常卡死分析处理建议

了解银河麒麟操作系统更多全新产品&#xff0c;请点击访问&#xff1a;https://product.kylinos.cn 1.服务器环境以及配置 【机型】浪潮NF5280M5 处理器&#xff1a; Intel 内存&#xff1a; 1T 【内核版本】 4.19.90-24.4.v2101.ky10.x86_64 【OS镜像版本】 银河麒麟…

python 屏幕显示一个文本窗口,文本窗口显示在当前鼠标在的位置,该文字窗口跟随鼠标移动,并且始终保持最前面显示 ,可以根据文字的多少来自动调节窗口大小

python 屏幕显示一个文本窗口,我有一段文字需要显示,鼠标在那里,文本窗口就在哪里显示,该文字窗口需要跟随鼠标移动,并且始终保持最前面显示,可以根据文字的多少来自动调节窗口大小 仅仅使用 tkinter # -*- coding:utf-8 -*-import tkinter as tkdef update_position(e…