深度学习:CycleGAN图像风格迁移转换

news2024/10/4 8:32:24

基础概念

CycleGAN是一种GAN的变体,它被设计用来在没有成对训练数据的情况下学习两种不同域之间的图像到图像的转换,不需要同一场景或物体在两个不同域中的对应图像。

CycleGAN由Jun-Yan Zhu等人在2017年提出。

CycleGAN的模型架构主要由两组生成器和判别器组成,每组负责一个方向上的图像转换。

具体来说,假设我们有两个不同的图像领域X(比如马的照片)和Y(比如斑马的照片),那么CycleGAN将包含以下组件:

  1. 生成器G:负责将图像从领域X转换到领域Y。
  2. 生成器F:负责将图像从领域Y转换回领域X。
  3. 判别器DY:用于区分领域Y中的真实图像与通过生成器G从领域X转换来的假图像。
  4. 判别器DX:用于区分领域X中的真实图像与通过生成器F从领域Y转换来的假图像。

模型工作流程

  • 当一张来自领域X的图片x被输入到生成器G时,它会产生一张看起来像是属于领域Y的图片G(x)。
  • 判别器DY会尝试判断G(x)是否是真实的领域Y图片。
  • 同样地,当一张来自领域Y的图片y被输入到生成器F时,它会产生一张看起来像是属于领域X的图片F(y)。
  • 判别器DX会尝试判断F(y)是否是真实的领域X图片。

循环一致性

为了确保生成器G和F不仅能够成功地进行单向转换,而且还能保持原始图像的信息不丢失,CycleGAN引入了循环一致性的概念。

前向循环一致性

对于源域中的图像x,首先通过生成器G生成转换图像G(x),随后通过生成器F将G(x)转换回源域F(G(x))。循环一致性损失计算F(G(x))与原始图像x之间的差异。

反向循环一致性

对于目标域中的图像y,首先通过生成器F生成一个转换后图像F(y),然后通过生成器G将F(y)转换回目标域G(F(y))。计算G(F(y))与原始图像y之间的差异。

对抗性损失

生成器G和F需要生成足够真实的图片七篇对应的判别器DY和DX。

几个基本概念

假图像(Fake Image)

假图像是通过生成器网络将一个域的图像转换成另一个域的图像。例如,在人脸年龄变化的任务中,如果有一个年轻人的脸部图片(属于年轻域),生成器可以生成一张看起来更老的脸部图片(属于年老域)。这个新生成的老年脸部图片就是假图像。

在接下来的代码中,fake_a 是从域 B 的真实图像 img_b 通过生成器 net_rg_b 生成的假图像,而 fake_b 是从域 A 的真实图像 img_a 通过生成器 net_rg_a 生成的假图像。

重建图像(Reconstructed Image)

重建图像是指将假图像再次通过相应的生成器网络转换回原始域的过程。这样做是为了确保图像在跨域转换后仍然能够恢复其原始特征。

例如,如果 fake_b 是从 img_a 生成的,那么再用 net_rg_b 将 fake_b 转换回域 A 得到的图像 rec_a 应该尽可能地接近 img_a

这种循环一致性损失有助于保持图像内容的一致性,即使在跨域转换过程中也不会丢失重要信息。

在接下来的代码中,rec_a 是由 fake_b 通过 net_rg_b 重新转换得到的图像,而 rec_b 是由 fake_a 通过 net_rg_a 重新转换得到的图像。

身份映射图像(Identity Mapping Image)

身份映射图像是指将一个域的真实图像直接输入到对应域的生成器网络中,期望输出与输入相同或非常相似的图像。这用于训练生成器学习如何在不改变图像的情况下保持图像不变。

这种损失被称为身份损失,它鼓励生成器在不需要进行跨域转换时保持图像不变。

在接下来的代码中,identity_a 是将域 A 的真实图像 img_a 直接通过 net_rg_b 得到的输出,而 identity_b 是将域 B 的真实图像 img_b 直接通过 net_rg_a 得到的输出。

基于MindSpore的CycleGAN

数据集

# 数据集
'''
本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。
图像被统一缩放为256×256像素大小,
其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

对数据进行了随机裁剪、水平随机翻转和归一化的预处理,
为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,
以省略大部分数据预处理的代码。
'''
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"

download(url, ".", kind="zip", replace=True)

# 数据集
'''
本案例使用的数据集里面的图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。
图像被统一缩放为256×256像素大小,
其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。

对数据进行了随机裁剪、水平随机翻转和归一化的预处理,
为了将重点聚焦到模型,此处将数据预处理后的结果转换为 MindRecord 格式的数据,
以省略大部分数据预处理的代码。
'''
from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"

download(url, ".", kind="zip", replace=True)

# 数据集可视化
import numpy as np
import matplotlib.pyplot as plt

mean = 0.5 * 255
std = 0.5 * 255

plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):
    if i < 5:
        show_images_a = data["image_A"].asnumpy()
        show_images_b = data["image_B"].asnumpy()

        plt.subplot(2, 5, i+1)
        show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_a)
        plt.axis("off")

        plt.subplot(2, 5, i+6)
        show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
        plt.imshow(show_images_b)
        plt.axis("off")
    else:
        break
plt.show()

生成器的基本架构

构建生成器基本块

# 构建生成器
# 生成器采用ResNet模型结构
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal
import mindspore as ms
# 初始化权重的方法
weight_init = Normal(sigma=0.01)

# 定义ConvNormReLU块
class ConvNormReLU(nn.Cell):
        def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',
                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):
            super(ConvNormReLU, self).__init__()
            norm = nn.BatchNorm2d(out_planes)
            if norm_mode == 'instance':
                # 参数affine用于控制是否对归一化后的数据应用可学习的仿射变换(即缩放和平移)。
                # 当设置affine=False时,不会对归一化后的数据进行任何线性变换。
                norm = nn.BatchNorm2d(out_planes, affine=False)
            
            has_bias = (norm_mode == 'instance')
            
            if padding is None:
                padding = (kernel_size - 1) // 2
            if pad_mode == 'CONSTANT':
                # 如果需要转置卷积(上采样)构建转置卷积层
                if transpose:
                    conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',
                                          has_bias=has_bias, weight_init=weight_init)
                else:
                    # 无需转置卷积(下采样)
                    conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, padding=padding, weight_init=weight_init)
                # 组合卷积层和正则化层
                layers = [conv, norm]
            else:
                # 创建了一个四元组列表,每个元组表示一个维度上的前后填充量。
                # (0, 0) 对应于批量大小和通道数维度,意味着在这两个维度上不做任何填充。
                # (padding, padding) 分别对应高度和宽度维度,在这两个维度上都会添加相同数量的填充。
                # 高度和宽度的两侧都会各增加1个像素的填充。
                paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
                # nn.Pad类创建了一个填充层实例。
                # paddings 参数指定了具体的填充方式,按照上面定义的paddings变量。
                pad = nn.Pad(paddings=paddings, mode=pad_mode)
                if transpose:
                    conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                          has_bias=has_bias, weight_init=weight_init)
                else:
                    conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, weight_init=weight_init)
                layers = [pad, conv, norm]
            # 如果需要激活函数,并判断是哪种激活函数
            if use_relu:
                relu = nn.ReLU()
                if alpha > 0:
                    relu = nn.LeakyReLU(alpha)
                layers.append(relu)
            # 组装模型
            self.features = nn.SequentialCell(layers)
            
        def construct(self, x):
            output = self.features(x)
            return output

 定义ResNet的残差块

# 定义ResNet的残差块
class ResidualBlock(nn.Cell):
    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode='CONSTANT'):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)
            
    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        # 返回 x + out 的做法是实现残差学习的关键。这个设计是为了让网络能够更容易地学习到恒等映射(identity mapping)
        # 从而帮助解决深层网络训练中的梯度消失问题,并允许网络构建得更深而不会导致性能下降。
        return x + out

定义基于ResNet的生成器

# 定义基于ResNet的生成器
class ResNetGenerator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,
                 pad_mode="CONSTANT"):
        super(ResNetGenerator, self).__init__()
        # 数据集图像输入后经过的第一个网络
        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        # 随后对数据进行两次下采样
        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)
        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)
        # 残差网络有9个残差块
        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
        # 组装残差网络
        self.residuals = nn.SequentialCell(layers)
        # 再将图片进行上采样(转置卷积)
        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)
        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)
        # 定义输出层
        if pad_mode == 'CONSTANT':
            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',
                                      padding=3, weight_init=weight_init)
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)
            self.conv_out = nn.SequentialCell([pad, conv])
            
    def construct(self, x):
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.residuals(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        # 将输出压制(-1, 1)
        return ops.tanh(output)

# 实例化生成器
# 创建生成器G和F
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')

net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

定义判别器

# 创建判别器
# 判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。
# 网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。
class Discriminator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):
        super(Discriminator, self).__init__()
        # 定义卷积核大小
        kernel_size = 4
        layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),
                  nn.LeakyReLU(alpha)]
        # 初始化倍增因子
        nf_mult = output_channel
        # 使用倍增因子逐步增大通道数
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * output_channel
            layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * output_channel
        
        layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        # 输出层
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))
        # 组装模型
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

# 判别器初始化
# 初始化两个判别器
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')

net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

定义优化器和损失函数

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)
# 两个损失函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss('mean')

def gan_loss(predict, target):
    # 全一表示真实数据
    target = ops.ones_like(predict) * target
    loss = loss_fn(predict, target)
    return loss

前向计算

# 前向计算

def generator(img_a, img_b):
    # img_a 是来自域 A 的真实图像
    # img_b 是来自域 B 的真实图像
    
    # 使用网络 net_rg_b 将域 B 的图像 img_b 转换为域 A 的假图像 fake_a
    fake_a = net_rg_b(img_b)
    
    # 使用网络 net_rg_a 将域 A 的图像 img_a 转换为域 B 的假图像 fake_b
    fake_b = net_rg_a(img_a)
    
    # 再次使用网络 net_rg_b 将生成的假图像 fake_b 重新转换回域 A 的重建图像 rec_a
    rec_a = net_rg_b(fake_b)
    
    # 再次使用网络 net_rg_a 将生成的假图像 fake_a 重新转换回域 B 的重建图像 rec_b
    rec_b = net_rg_a(fake_a)
    
    # 使用网络 net_rg_b 直接处理域 A 的图像 img_a,期望输出与输入相同或相似,这是为了保持同一性
    identity_a = net_rg_b(img_a)
    
    # 使用网络 net_rg_a 直接处理域 B 的图像 img_b,期望输出与输入相同或相似,这也是为了保持同一性
    identity_b = net_rg_a(img_b)
    
    # 返回生成的假图像、重建图像和身份映射图像
    # 用于计算循环一致性
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

# 定义不同类型的损失权重
lambda_a = 10.0  # 循环一致性损失 A 到 B 的权重
lambda_b = 10.0  # 循环一致性损失 B 到 A 的权重
lambda_idt = 0.5  # 身份映射损失的权重

def generator_forward(img_a, img_b):
     # 创建一个表示真实的标签 Tensor
    true = Tensor(True, dtype.bool_)
    # 调用先前定义的 generator 函数来获取生成的图像及其重建版本
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    # 判别器损失
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    # 循环一致性损失
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    # 身份映射损失
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    # 整合损失
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    # 通过这种方式,生成器不仅学习如何欺骗判别器,还要保证图像经过跨域转换后能够准确地恢复原样(循环一致性),以及在不改变域的情况下尽可能保留原始图像(身份映射)。
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b
# 获取生成器的总损失
def generator_forward_grad(img_a, img_b):
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

# 这个函数同时处理来自域 A 和域 B 的图像,并计算两个判别器的总损失。
def discriminator_forward(img_a, img_b, fake_a, fake_b):
    # 假图像标签
    false = Tensor(False, dtype.bool_)
    # 真图像标签
    true = Tensor(True, dtype.bool_)
    # 判别器a
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    # 判别器b
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    # 计算判别器a的损失
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    # 计算判别器b的损失
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    # 加权计算总损失
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d
# 只处理域 A 的图像,计算 net_d_a 判别器的损失。
def discriminator_forward_a(img_a, fake_a):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    return loss_d_a
# 只处理域 B 的图像,计算 net_d_b 判别器的损失。
def discriminator_forward_b(img_b, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    return loss_d_b

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
'''
为了减少模型振荡,遵循 Shrivastava 等人的策略[,
使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。
'''
pool_size = 50
def image_pool(images):
    num_imgs = 0
    image1 = []
    if isinstance(images, Tensor):
        images = images.asnumpy()
    return_images = []
    for image in images:
        if num_imgs < pool_size:
            num_imgs = num_imgs + 1
            image1.append(image)
            return_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, pool_size - 1)

                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)

            else:
                return_images.append(image)
    output = Tensor(return_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output

梯度计算和反向传播

from mindspore import value_and_grad
# 梯度计算和反向传播
# 实例化求梯度的方法
# 生成器a梯度
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
# 生成器b梯度
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())
# 判别器a梯度
grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
# 判别器d梯度
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())

# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):
    # 对于 net_d 网络中的所有参数,停止计算它们的梯度。
    net_d_a.set_grad(False)
    net_d_b.set_grad(False)

    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)

    _, grads_g_a = grad_g_a(img_a, img_b)
    _, grads_g_b = grad_g_b(img_a, img_b)
    optimizer_rg_a(grads_g_a)
    optimizer_rg_b(grads_g_b)

    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib

# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):
    net_d_a.set_grad(True)
    net_d_b.set_grad(True)

    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)
    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)

    loss_d = (loss_d_a + loss_d_b) * 0.5

    optimizer_d_a(grads_d_a)
    optimizer_d_b(grads_d_b)

    return loss_d

模型训练

import os  # 操作系统接口模块
import time  # 时间处理模块
import random  # 用于生成随机数
import numpy as np  # 数值计算库
from PIL import Image  # 图像处理库
from mindspore import Tensor, save_checkpoint  # MindSpore 库中的张量和保存检查点功能
from mindspore import dtype  # MindSpore 库中的数据类型定义

# 由于时间原因,epochs设置为1,可根据需求进行调整
epochs = 1  # 训练轮次
save_step_num = 80  # 每隔多少步打印一次信息
save_checkpoint_epochs = 1  # 每隔多少个epoch保存一次模型
save_ckpt_dir = './train_ckpt_outputs/'  # 保存模型检查点的目录

print('Start training!')  # 打印开始训练的信息

for epoch in range(epochs):  # 对每个epoch进行迭代
    g_loss = []  # 初始化生成器损失列表
    d_loss = []  # 初始化判别器损失列表
    start_time_e = time.time()  # 记录当前epoch开始的时间
    for step, data in enumerate(dataset.create_dict_iterator()):  # 对数据集中的每一步进行迭代
        start_time_s = time.time()  # 记录当前步开始的时间
        img_a = data["image_A"]  # 从数据中获取域A的图像
        img_b = data["image_B"]  # 从数据中获取域B的图像
        res_g = train_step_g(img_a, img_b)  # 调用生成器的训练步骤并获取结果
        fake_a = res_g[0]  # 获取生成的假图像A
        fake_b = res_g[1]  # 获取生成的假图像B

        # 调用判别器的训练步骤,使用图像池来存储假图像,并传递给判别器
        res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))
        loss_d = float(res_d.asnumpy())  # 将判别器的损失转换为浮点数
        step_time = time.time() - start_time_s  # 计算当前步的耗时

        # 将生成器的其他损失项转换为浮点数
        res = []
        for item in res_g[2:]:
            res.append(float(item.asnumpy()))
        g_loss.append(res[0])  # 添加总的生成器损失到列表
        d_loss.append(loss_d)  # 添加判别器损失到列表

        if step % save_step_num == 0:  # 如果是需要打印信息的步数
            print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "  # 打印当前epoch/总epoch
                  f"step:[{int(step):>4d}/{int(datasize):>4d}], "  # 打印当前步/总步数
                  f"time:{step_time:>3f}s,\n"  # 打印当前步耗时
                  f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "  # 打印生成器和判别器的损失
                  f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "  # 打印生成器A和B的GAN损失
                  f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "  # 打印循环一致性损失
                  f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")  # 打印身份映射损失

    epoch_cost = time.time() - start_time_e  # 计算当前epoch的总耗时
    per_step_time = epoch_cost / datasize  # 计算每步的平均耗时
    mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize  # 计算平均损失

    # 打印当前epoch的平均损失和耗时
    print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
          f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "
          f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")

    if epoch % save_checkpoint_epochs == 0:  # 如果是需要保存检查点的epoch
        os.makedirs(save_ckpt_dir, exist_ok=True)  # 确保保存目录存在
        # 保存生成器和判别器的模型检查点
        save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))
        save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))
        save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))
        save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))

print('End of training!')  # 打印训练结束的信息

模型推理

import os
from PIL import Image
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
from mindspore import load_checkpoint, load_param_into_net

# 加载权重文件
def load_ckpt(net, ckpt_dir):
    param_GA = load_checkpoint(ckpt_dir)
    load_param_into_net(net, param_GA)

g_a_ckpt = './CycleGAN_apple2orange/ckpt/g_a.ckpt'
g_b_ckpt = './CycleGAN_apple2orange/ckpt/g_b.ckpt'

load_ckpt(net_rg_a, g_a_ckpt)
load_ckpt(net_rg_b, g_b_ckpt)

# 图片推理
fig = plt.figure(figsize=(11, 2.5), dpi=100)
def eval_data(dir_path, net, a):

    def read_img():
        for dir in os.listdir(dir_path):
            path = os.path.join(dir_path, dir)
            img = Image.open(path).convert('RGB')
            yield img, dir

    dataset = ds.GeneratorDataset(read_img, column_names=["image", "image_name"])
    trans = [vision.Resize((256, 256)), vision.Normalize(mean=[0.5 * 255] * 3, std=[0.5 * 255] * 3), vision.HWC2CHW()]
    dataset = dataset.map(operations=trans, input_columns=["image"])
    dataset = dataset.batch(1)
    for i, data in enumerate(dataset.create_dict_iterator()):
        img = data["image"]
        fake = net(img)
        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        img = (img[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))

        fig.add_subplot(2, 8, i+1+a)
        plt.axis("off")
        plt.imshow(img.asnumpy())

        fig.add_subplot(2, 8, i+9+a)
        plt.axis("off")
        plt.imshow(fake.asnumpy())

eval_data('./CycleGAN_apple2orange/predict/apple', net_rg_a, 0)
eval_data('./CycleGAN_apple2orange/predict/orange', net_rg_b, 4)
plt.show()

结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2187697.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mac配置python出现DataDirError: Valid PROJ data directory not found错误的解决

最近在利用python下载SWOT数据时出现以下的问题&#xff1a; import xarray as xr import s3fs import cartopy.crs as ccrs from matplotlib import pyplot as plt import earthaccess from earthaccess import Auth, DataCollections, DataGranules, Store import os os.env…

CSS3--美开二度

免责声明&#xff1a;本文仅做分享&#xff01; 目录 定位 相对定位 绝对定位 定位居中 固定定位 堆叠层级 z-index 定位-小结 CSS 精灵 京东案例 字体图标 下载字体 使用字体 上传矢量图 CSS 修饰属性 垂直对齐方式 vertical-align 过渡 transition 透明度 opa…

【西门子V20变频器】 变频器运行时报A922报警

报警说明 原因&#xff1a; 1.变频器未接负载 2.变频器设定的电机参数与实际电机不匹配 3.查看P2179查看 无负载监控 设定的电流极限值&#xff0c;出厂默认为“3.0”

mysql事务 -- 事务的隔离性(测试实验+介绍,脏读,不可重复读,可重复度读,幻读)

目录 事务的隔离性 引入 测试 读未提交 脏读 读提交 不可重复读 属于问题吗? 例子 可重复读 幻读 串行化 原理 总结 事务的隔离性 引入 当我们让两个客户端共同执行begin语句时,就开始了两个事务并发访问 在这个过程中,可能会出现sql交叉的问题 但我们不希望因为…

项目定位与服务器(SERVER)模块划分

目录 定位 HTTP协议以及HTTP服务器 高并发服务器 单Reactor单线程 单Reactor多线程 多Reactor多线程 模块划分 SERVER模块划分 Buffer 模块 Socket模块 Channel 模块 Connection模块 Acceptor模块 TimerQueue模块 Poller模块 EventLoop模块 TcpServer模块 SE…

【ADC】噪声(1)噪声分类

概述 本文学习于TI 高精度实验室课程&#xff0c;总结 ADC 的噪声分类&#xff0c;并简要介绍量化噪声和热噪声。 文章目录 概述一、ADC 中的噪声类型二、量化噪声三、热噪声四、量化噪声与热噪声对比 一、ADC 中的噪声类型 ADC 固有噪声由两部分组成&#xff1a;第一部分是量…

【树莓派系列】树莓派wiringPi库详解,官方外设开发

树莓派wiringPi库详解&#xff0c;官方外设开发 文章目录 树莓派wiringPi库详解&#xff0c;官方外设开发一、安装wiringPi库二、wiringPi库API大全1.硬件初始化函数2.通用GPIO控制函数3.时间控制函数4.串口通信串口API串口通信配置多串口通信配置串口自发自收测试串口间通信测…

Django 后端数据传给前端

Step 1 创建一个数据库 Step 2 在Django中点击数据库连接 Step 3 连接成功 Step 4 settings中找DATABASES Step 5 将数据库挂上面 将数据库引擎和数据库名改成自己的 Step 6 在_init_.py中加上数据库的支持语句 import pymysql pymysql.install_as_MySQLdb() Step7 简单创建两…

以企业的视角进行大学生招聘

课程来源&#xff1a;中国计算机学会---朱颖韶&#xff08;资深人力资源领域--HR&#xff09; 一、招聘流程 1.简历->门槛 注重&#xff1a;专业学历、行业经验 2.笔试面试->专业知识与技能 3.简历面试-> 过往的成果 4.面试 沟通能力、学习力-----了解动机、价值观…

Pikachu-Sql Inject-insert/update/delete注入

insert 注入 插入语句 insert into tables values(value1,value2,value3); 如&#xff1a;插入用户表 insert into users (id,name,password) values (id,username,password); 当点击注册 先判断是否有SQL注入漏洞&#xff0c;经过判断之后发现存在SQL漏洞。构造insert的pa…

8644 堆排序

### 思路 堆排序是一种基于堆数据结构的排序算法。堆是一种完全二叉树&#xff0c;分为最大堆和最小堆。堆排序的基本思想是将待排序数组构造成一个最大堆&#xff0c;然后依次将堆顶元素与末尾元素交换&#xff0c;并调整堆结构&#xff0c;直到排序完成。 ### 伪代码 1. 读取…

自闭症干预寄宿学校:专业治疗帮助孩子发展

自闭症干预寄宿学校&#xff1a;星贝育园的专业治疗助力孩子全面发展 在自闭症儿童的教育与康复领域&#xff0c;寄宿学校以其独特的教育模式和全面的关怀体系&#xff0c;为众多家庭提供了重要的选择。广州星贝育园自闭症儿童寄宿制学校&#xff0c;作为这一领域的佼佼者&…

达梦core文件分析(学习笔记)

目录 1、core 文件生成 1.1 前置条件说明 1.2 关于 core 文件生成路径的说明 1.3查看 core 文件的前置条件 2、查看 core 文件堆栈信息 2.1 使用gdb 2.2 使用达梦dmrdc 3、core 分析过程 3.1 服务端主动 core 3.2因未知异常原因导致的 core 4、测试案例 4.1测试环境…

(十八)、登陆 k8s 的 kubernetes-dashboard 更多可视化工具

文章目录 1、回顾 k8s 的安装2、确认 k8s 运行状态3、通过 token 登陆3.1、使用现有的用户登陆3.2、新加用户登陆 4、k8s 可视化工具 1、回顾 k8s 的安装 Mac 安装k8s 2、确认 k8s 运行状态 kubectl proxy kubectl cluster-info kubectl get pods -n kubernetes-dashboard3、…

网页前端开发之Javascript入门篇(4/9):循环控制

Javascript循环控制 什么是循环控制&#xff1f; 答&#xff1a;其概念跟 Python教程 介绍的一样&#xff0c;只是语法上有所变化。 参考流程图如下&#xff1a; 其对应语法&#xff1a; var i 0; // 设置起始值 var minutes 15; // 设置结束值&#xff08;15分钟…

Stream流的终结方法(一)

1.Stream流的终结方法 2.forEach 对于forEach方法&#xff0c;用来遍历stream流中的所有数据 package com.njau.d10_my_stream;import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.function.Consumer; import java.util…

安全帽头盔检测数据集 3类 12000张 安全帽数据集 voc yolo

安全帽头盔检测数据集 3类 12000张 安全帽数据集 voc yolo 安全帽头盔检测数据集介绍 数据集名称 安全帽头盔检测数据集 (Safety Helmet and Person Detection Dataset) 数据集概述 该数据集专为训练和评估基于YOLO系列目标检测模型&#xff08;包括YOLOv5、YOLOv6、YOLOv7…

SpringCloud入门(十一)路由过滤器和路由断言工厂

一、路由过滤器 路由过滤器&#xff08; GatewayFilter &#xff09;是网关中提供的一种过滤器&#xff0c;可以对进入网关的请求和微服务返回的响应做处理&#xff1a; 如图&#xff1a;网关路由过滤器&#xff1a; 路由过滤器的作用是&#xff1a; 1.对路由的请求或响应做加…

第二十章(自定义类型,联合和枚举)

1. 联合体类型的声明 2. 联合体的特点 3. 联合体⼤⼩的计算 4. 枚举类型的声明 5. 枚举类型的优点 6. 枚举类型的使⽤ 光阴如骏马加鞭一、联合体 概念&#xff1a;像结构体一样&#xff0c;联合体也是由一个或者多个成员组成的&#xff0c;这些成员也可以是不同的类型。 …

JavaSE篇:文件IO

一 认识文件 在硬盘这种持久化存储的I/O设备或其他存储介质中 &#xff0c;当我们想要进行数据保存时&#xff0c;往往不是保存成⼀个整体&#xff0c;⽽是独⽴成⼀个个的单位进⾏保存&#xff0c;这个独⽴的单位就被抽象成⽂件的概念。就类似办公桌上的⼀份份真实的⽂件⼀般。…