基于MindSpore实现CycleGAN壁画修复

news2024/10/6 9:27:28

基于MindSpore实现CycleGAN壁画修复_哔哩哔哩_bilibili

本案例运行需要较大内存,建议在Ascend/GPU上运行。

模型介绍

模型简介

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

CycleGAN

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

Cycle Consistency Loss

循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。详细的训练过程见下文代码。

本任务旨在实现破损草图到目标线稿图之间的转化。

sketch_lineart

数据集

本案例使用的数据集里面的图片为经图线稿图数据。图像被统一缩放为256×256像素大小,其中用于训练的线稿图片25654张、草图图片25654张,用于测试的线稿图片100张、草图图片116张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理.

数据集下载

使用 download 接口下载数据集,并将下载后的数据集自动解压到当前目录下。数据下载之前需要使用 pip install download 安装 download 包。

%%capture captured_output
# 实验环境已经预装了mindspore==2.3.0,如需更换mindspore版本,可更改下面 MINDSPORE_VERSION 变量
!pip uninstall mindspore -y
%env MINDSPORE_VERSION=2.3.0
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.mirrors.ustc.edu.cn/simple
# 查看当前 mindspore 版本
!pip show mindspore
Name: mindspore
Version: 2.3.0
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 
import mindspore

#设置使用的设备
mindspore.set_context(device_target='Ascend')##gpu,cpu,Ascend
print(mindspore.get_context(attr_key='device_target'))
Ascend
#下载数据集
from download import download

url = "https://6169fb4615b14dbcb6b2cb1c4eb78bb2.obs.cn-north-4.myhuaweicloud.com/Cyc_line.zip"

download(url, "./localdata", kind="zip", replace=True)
Downloading data from https://6169fb4615b14dbcb6b2cb1c4eb78bb2.obs.cn-north-4.myhuaweicloud.com/Cyc_line.zip (1.88 GB)

file_sizes: 100%|███████████████████████████| 2.02G/2.02G [00:09<00:00, 222MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./localdata
#下载预训练模型参数
from download import download

url = "https://6169fb4615b14dbcb6b2cb1c4eb78bb2.obs.cn-north-4.myhuaweicloud.com/checkpoints.zip"

download(url, "./ckpt", kind="zip", replace=True)
Downloading data from https://6169fb4615b14dbcb6b2cb1c4eb78bb2.obs.cn-north-4.myhuaweicloud.com/checkpoints.zip (165.6 MB)

file_sizes: 100%|█████████████████████████████| 174M/174M [00:01<00:00, 148MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./ckpt

数据集加载

from __future__ import division
import math
import numpy as np

import os
import multiprocessing

import mindspore.dataset as de
import mindspore.dataset.vision as C

"""数据集分布式采样器"""
class DistributedSampler:
    """Distributed sampler."""
    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
            indices = indices.tolist()
            self.epoch += 1
            # change to list type
        else:
            indices = list(range(self.dataset_size))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
"""加载Cycle GAN 数据集."""

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']

def is_image_file(filename):
    """Judge whether it is a picture."""
    return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)

def make_dataset(dir_path, max_dataset_size=float("inf")):
    """Return image list in dir."""
    images = []
    assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path

    for root, _, fnames in sorted(os.walk(dir_path)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]

class UnalignedDataset:
    """
    This dataset class can load unaligned/unpaired datasets.
    It requires two directories to host training images from domain A '/path/to/data/trainA'
    and from domain B '/path/to/data/trainB' respectively.
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    Similarly, you need to prepare two directories:
    '/path/to/data/testA' and '/path/to/data/testB' during test time.
    Returns:
        Two domain image path list.
    """

    def __init__(self, dataroot, max_dataset_size=float("inf"), use_random=True):
        self.dir_A = os.path.join(dataroot, 'trainA')
        self.dir_B = os.path.join(dataroot, 'trainB')

        self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size))  # load images from '/path/to/data/trainA'
        self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))  # load images from '/path/to/data/trainB'
        self.A_size = len(self.A_paths)  # get the size of dataset A
        self.B_size = len(self.B_paths)  # get the size of dataset B
        self.use_random = use_random

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int)      -- a random integer for data indexing

        Returns a dictionary that contains A, B, A_paths and B_paths
            A (tensor)       -- an image in the input domain
            B (tensor)       -- its corresponding image in the target domain
            A_paths (str)    -- image paths
            B_paths (str)    -- image paths
        """
        index_B = index % self.B_size
        if index % max(self.A_size, self.B_size) == 0 and self.use_random:
            random.shuffle(self.A_paths)
            index_B = random.randint(0, self.B_size - 1)
        A_path = self.A_paths[index % self.A_size]
        B_path = self.B_paths[index_B]
        A_img = np.array(Image.open(A_path).convert('RGB'))
        B_img = np.array(Image.open(B_path).convert('RGB'))

        return A_img, B_img

    def __len__(self):
        """Return the total number of images in the dataset.
        """
        return max(self.A_size, self.B_size)
def create_dataset(dataroot,batch_size=1,use_random=True,device_num=1,rank=0,max_dataset_size=float("inf"),image_size=256):
    """
    Create dataset
    This dataset class can load images for train or test.
    Args:
        dataroot (str): Images root directory.
    Returns:
        RGB Image list.
    """
    shuffle = use_random
    cores = multiprocessing.cpu_count()
    num_parallel_workers = min(1, int(cores / device_num))

    mean = [0.5 * 255] * 3
    std = [0.5 * 255] * 3
    
    dataset = UnalignedDataset(dataroot, max_dataset_size=max_dataset_size, use_random=use_random)
    distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
    ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
                             sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
    if use_random:
        trans = [
            C.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        trans = [
            C.Resize((image_size, image_size)),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
    ds = ds.batch(batch_size, drop_remainder=True)
        
    return ds
#根据设备情况调整训练参数

dataroot = "./localdata"
batch_size = 12
device_num = 1
rank = 0
use_random = True
max_dataset_size = 24000
image_size = 256

cyclegan_ds = create_dataset(dataroot=dataroot,max_dataset_size=max_dataset_size,batch_size=batch_size,device_num=device_num,rank = rank,use_random=use_random,image_size=image_size)
datasize = cyclegan_ds.get_dataset_size()
print("Datasize: ", datasize)
Datasize:  2000

可视化

通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。

"""图片展示"""
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 5), dpi=60)

for i, data in enumerate(cyclegan_ds.create_dict_iterator()):
    if i < 5:
        show_images_a = data["image_A"]
        show_images_b = data["image_B"]
        show_images_a = (show_images_a.asnumpy()[0]+1)/2  # 转换为NumPy数组并去除多余的维度
        show_images_b = (show_images_b.asnumpy()[0]+1)/2 # 转换为NumPy数组并去除多余的维度

        # 确保通道顺序正确(假设从RGB转换为BGR)
        show_images_a = show_images_a.transpose((1, 2, 0))  # 将通道维度移动到最后
        show_images_b = show_images_b.transpose((1, 2, 0))  # 将通道维度移动到最后
        show_images_a = show_images_a[..., ::-1]  # 反转通道顺序(从RGB到BGR)
        show_images_b = show_images_b[..., ::-1]  # 反转通道顺序(从RGB到BGR)

        plt.subplot(2, 5, i+1)
        plt.imshow(show_images_a)
        plt.axis("off")

        plt.subplot(2, 5, i+6)
        plt.imshow(show_images_b)
        plt.axis("off")
    else:
        break
plt.show()

构建生成器

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

CycleGAN Generator

具体的模型结构请参照下文代码:

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal

weight_init = Normal(sigma=0.02)

class ConvNormReLU(nn.Cell):
    def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',
                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):
        super(ConvNormReLU, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, padding=padding, weight_init=weight_init)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, weight_init=weight_init)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class ResidualBlock(nn.Cell):
    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)

    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        return x + out


class ResNetGenerator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,
                 pad_mode="CONSTANT"):
        super(ResNetGenerator, self).__init__()
        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)
        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)
        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
        self.residuals = nn.SequentialCell(layers)
        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)
        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)
        if pad_mode == "CONSTANT":
            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',
                                      padding=3, weight_init=weight_init)
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)
            self.conv_out = nn.SequentialCell([pad, conv])

    def construct(self, x):
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.residuals(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        return ops.tanh(output)

# 实例化生成器
net_rg_a = ResNetGenerator()
# net_rg_a.update_parameters_name('net_rg_a.')

net_rg_b = ResNetGenerator()
# net_rg_b.update_parameters_name('net_rg_b.')

构建判别器

判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

# 定义判别器
class Discriminator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),
                  nn.LeakyReLU(alpha)]
        nf_mult = output_channel
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * output_channel
            layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * output_channel
        layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))
        self.features = nn.SequentialCell(layers)
​
    def construct(self, x):
        output = self.features(x)
        return output
​
# 判别器初始化
net_d_a = Discriminator()
# net_d_a.update_parameters_name('net_d_a.')
​
net_d_b = Discriminator()
# net_d_b.update_parameters_name('net_d_b.')

优化器和损失函数

根据不同模型需要单独的设置优化器,这是训练过程决定的。

对生成器 𝐺G 及其判别器 𝐷𝑌DY ,目标损失函数定义为:

L_{GAN}(G,D_{Y},X,Y)=E_{y-P_{data}(y)}[logD_{y}(y)]+E_{x-P_{data}(x)}[log(1-D_{y}(G(x)))]

其中 𝐺 试图生成看起来与 𝑌 中的图像相似的图像 𝐺(𝑥) ,而 𝐷𝑌的目标是区分翻译样本 𝐺(𝑥)G(x) 和真实样本 𝑦 ,生成器的目标是最小化这个损失函数以此来对抗判别器。即min_{G}max_{D_{Y}}L_{GAN}(G,D_{Y},X,Y)

单独的对抗损失不能保证所学函数可以将单个输入映射到期望的输出,为了进一步减少可能的映射函数的空间,学习到的映射函数应该是周期一致的,例如对于 𝑋 的每个图像 𝑥 ,图像转换周期应能够将 𝑥 带回原始图像,可以称之为正向循环一致性,即 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥 。对于 𝑌 ,类似的 𝑥→𝐺(𝑥)→𝐹(𝐺(𝑥))≈𝑥。可以理解采用了一个循环一致性损失来激励这种行为。

循环一致损失函数定义如下:

L_{cyc}(G,F)=E_{x-p_{data}(x)}[\left | \left | F(G(x))-x \right | \right |_{1}]+E_{y-p_{data}(y)}[\left | \left | F(G(y))-y \right | \right |_{1}]

循环一致损失能够保证重建图像 𝐹(𝐺(𝑥)) 与输入图像 𝑥 紧密匹配。

# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")

def gan_loss(predict, target):
    target = ops.ones_like(predict) * target
    loss = loss_fn(predict, target)
    return loss

前向计算

搭建模型前向计算损失的过程,过程如下代码。

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms

# 前向计算

def generator(img_a, img_b):
    fake_a = net_rg_b(img_b)
    fake_b = net_rg_a(img_a)
    rec_a = net_rg_b(fake_b)
    rec_b = net_rg_a(fake_a)
    identity_a = net_rg_b(img_a)
    identity_b = net_rg_a(img_b)
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

lambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5

def generator_forward(img_a, img_b):
    true = Tensor(True, dtype.bool_)
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

def generator_forward_grad(img_a, img_b):
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

def discriminator_forward(img_a, img_b, fake_a, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d

def discriminator_forward_a(img_a, fake_a):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    return loss_d_a

def discriminator_forward_b(img_b, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    return loss_d_b

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):
    num_imgs = 0
    image1 = []
    if isinstance(images, Tensor):
        images = images.asnumpy()
    return_images = []
    for image in images:
        if num_imgs < pool_size:
            num_imgs = num_imgs + 1
            image1.append(image)
            return_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, pool_size - 1)

                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)

            else:
                return_images.append(image)
    output = Tensor(return_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output

计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的,详情见如下代码:

from mindspore import value_and_grad

# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())

grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())

# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):
    net_d_a.set_grad(False)
    net_d_b.set_grad(False)

    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)

    _, grads_g_a = grad_g_a(img_a, img_b)
    _, grads_g_b = grad_g_b(img_a, img_b)
    optimizer_rg_a(grads_g_a)
    optimizer_rg_b(grads_g_b)

    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib

# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):
    net_d_a.set_grad(True)
    net_d_b.set_grad(True)

    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)
    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)

    loss_d = (loss_d_a + loss_d_b) * 0.5

    optimizer_d_a(grads_d_a)
    optimizer_d_b(grads_d_b)

    return loss_d

模型训练

训练分为两个主要部分:训练判别器和训练生成器,在前文的判别器损失函数中,论文采用了最小二乘损失代替负对数似然目标。

  • 训练判别器:训练判别器的目的是最大程度地提高判别图像真伪的概率。按照论文的方法需要训练判别器来最小化 E_{y-p_{data}(y)}[(D(y)-1)^{2}]

  • 训练生成器:如 CycleGAN 论文所述,我们希望通过最小化 E_{x-p_{data}(x)}[(D(x)-1)^{2}]来训练生成器,以产生更好的虚假图像。

下面定义了生成器和判别器的训练过程:

%%time
import os
import time
import random
import numpy as np
from PIL import Image
from mindspore import Tensor, save_checkpoint
from mindspore import dtype
from mindspore import load_checkpoint, load_param_into_net

def load_ckpt(net, ckpt_dir):
    param_GA = load_checkpoint(ckpt_dir)
    load_param_into_net(net, param_GA)

epochs = 1
save_step_num = 200
save_checkpoint_epochs = 1
save_ckpt_dir = './train_ckpt_outputs/'
save_img_dir='./image_outputs/'
resume=True

if(resume):
    g_a_ckpt = './ckpt/G_A_120.ckpt'
    g_b_ckpt = './ckpt/G_B_120.ckpt'
    d_a_ckpt = './ckpt/D_A_120.ckpt'
    d_b_ckpt = './ckpt/D_B_120.ckpt'

    load_ckpt(net_rg_a, g_a_ckpt)
    load_ckpt(net_rg_b, g_b_ckpt)
    load_ckpt(net_d_a, d_a_ckpt)
    load_ckpt(net_d_b, d_b_ckpt)
    
print('Start training!')



def train_loop(epoch):
    g_loss = []
    d_loss = []
    start_time_e = time.time()
    for step, data in enumerate(cyclegan_ds.create_dict_iterator()):
        start_time_s = time.time()
        img_a = data["image_A"]
        img_b = data["image_B"]
        res_g = train_step_g(img_a, img_b)
        fake_a = res_g[0]
        fake_b = res_g[1]
        if step==0:
            os.makedirs(save_img_dir, exist_ok=True)
            img_a_pil=Image.fromarray((img_a[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0)).asnumpy())
            img_b_pil=Image.fromarray((img_b[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0)).asnumpy())
            fake_a_pil=Image.fromarray((fake_a[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0)).asnumpy())
            fake_b_pil=Image.fromarray((fake_b[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0)).asnumpy())

            img_a_pil.save(f"./{save_img_dir}/{epoch}_img_a.jpg")
            img_b_pil.save(f"./{save_img_dir}/{epoch}_img_b.jpg")
            fake_a_pil.save(f"./{save_img_dir}/{epoch}_fake_a.jpg")
            fake_b_pil.save(f"./{save_img_dir}/{epoch}_fake_b.jpg")
            
        res_d = train_step_d(img_a, img_b, image_pool(fake_a), image_pool(fake_b))
        loss_d = float(res_d.asnumpy())
        step_time = time.time() - start_time_s

        res = []
        for item in res_g[2:]:
            res.append(float(item.asnumpy()))
        g_loss.append(res[0])
        d_loss.append(loss_d)

        if step % save_step_num == 0:
            print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
                  f"step:[{int(step):>4d}/{int(datasize):>4d}], "
                  f"time:{step_time:>3f}s,\n"
                  f"loss_g:{res[0]:.2f}, loss_d:{loss_d:.2f}, "
                  f"loss_g_a: {res[1]:.2f}, loss_g_b: {res[2]:.2f}, "
                  f"loss_c_a: {res[3]:.2f}, loss_c_b: {res[4]:.2f}, "
                  f"loss_idt_a: {res[5]:.2f}, loss_idt_b: {res[6]:.2f}")

    epoch_cost = time.time() - start_time_e
    per_step_time = epoch_cost / datasize
    mean_loss_d, mean_loss_g = sum(d_loss) / datasize, sum(g_loss) / datasize

    print(f"Epoch:[{int(epoch + 1):>3d}/{int(epochs):>3d}], "
          f"epoch time:{epoch_cost:.2f}s, per step time:{per_step_time:.2f}, "
          f"mean_g_loss:{mean_loss_g:.2f}, mean_d_loss:{mean_loss_d :.2f}")

    if epoch % save_checkpoint_epochs == 0:
        os.makedirs(save_ckpt_dir, exist_ok=True)
        save_checkpoint(net_rg_a, os.path.join(save_ckpt_dir, f"g_a_{epoch}.ckpt"))
        save_checkpoint(net_rg_b, os.path.join(save_ckpt_dir, f"g_b_{epoch}.ckpt"))
        save_checkpoint(net_d_a, os.path.join(save_ckpt_dir, f"d_a_{epoch}.ckpt"))
        save_checkpoint(net_d_b, os.path.join(save_ckpt_dir, f"d_b_{epoch}.ckpt"))

            
for t in range(epochs):
    train_loop(t)
    

print('End of training!')
Start training!
Epoch:[  1/  1], step:[   0/2000], time:34.012208s,
loss_g:2.02, loss_d:0.59, loss_g_a: 0.48, loss_g_b: 0.28, loss_c_a: 0.57, loss_c_b: 0.19, loss_idt_a: 0.17, loss_idt_b: 0.33
Epoch:[  1/  1], step:[ 200/2000], time:1.041357s,
loss_g:2.11, loss_d:0.71, loss_g_a: 0.11, loss_g_b: 0.88, loss_c_a: 0.38, loss_c_b: 0.35, loss_idt_a: 0.20, loss_idt_b: 0.19
Epoch:[  1/  1], step:[ 400/2000], time:1.043393s,
loss_g:2.03, loss_d:0.59, loss_g_a: 0.18, loss_g_b: 0.23, loss_c_a: 0.70, loss_c_b: 0.47, loss_idt_a: 0.18, loss_idt_b: 0.28
Epoch:[  1/  1], step:[ 600/2000], time:1.041804s,
loss_g:1.38, loss_d:0.61, loss_g_a: 0.14, loss_g_b: 0.10, loss_c_a: 0.31, loss_c_b: 0.45, loss_idt_a: 0.13, loss_idt_b: 0.25
Epoch:[  1/  1], step:[ 800/2000], time:1.041061s,
loss_g:2.66, loss_d:0.66, loss_g_a: 0.92, loss_g_b: 0.23, loss_c_a: 0.56, loss_c_b: 0.56, loss_idt_a: 0.25, loss_idt_b: 0.14
Epoch:[  1/  1], step:[1000/2000], time:1.040926s,
loss_g:1.66, loss_d:0.57, loss_g_a: 0.16, loss_g_b: 0.74, loss_c_a: 0.29, loss_c_b: 0.23, loss_idt_a: 0.14, loss_idt_b: 0.10
Epoch:[  1/  1], step:[1200/2000], time:1.042564s,
loss_g:1.22, loss_d:0.59, loss_g_a: 0.08, loss_g_b: 0.30, loss_c_a: 0.28, loss_c_b: 0.27, loss_idt_a: 0.15, loss_idt_b: 0.15
Epoch:[  1/  1], step:[1400/2000], time:1.042071s,
loss_g:94884.72, loss_d:87463.26, loss_g_a: 94495.55, loss_g_b: 384.89, loss_c_a: 1.75, loss_c_b: 1.10, loss_idt_a: 0.88, loss_idt_b: 0.55
Epoch:[  1/  1], epoch time:2139.83s, per step time:1.07, mean_g_loss:1690704.17, mean_d_loss:1467356.82
End of training!
CPU times: user 2h 8min 28s, sys: 30min 18s, total: 2h 38min 47s
Wall time: 35min 40s

模型推理

下面我们通过加载生成器网络模型参数文件来对原图进行风格迁移,结果中第一行为原图,第二行为对应生成的结果图。配合边缘检测模型可以实现从原图提取线稿图的作用。

边缘检测Dexined模型

import os
import cv2
import numpy as np
import time

import mindspore as ms
from mindspore import nn, ops
from mindspore import dataset as ds
from mindspore.amp import auto_mixed_precision
from mindspore.common import initializer as init
"""DexiNed边缘检测数据集"""
class Test_Dataset():
    """自定义数据集"""

    def __init__(self, data_root, mean_bgr, image_size):
        self.data = []
        imgs_ = os.listdir(data_root)
        self.names = []
        self.filenames = []
        for img in imgs_:
            if img.endswith(".png") or img.endswith(".jpg"):
                dir = os.path.join(data_root, img)
                self.names.append(dir)
                self.filenames.append(img)
        self.mean_bgr = mean_bgr
        self.image_size = image_size

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        image = cv2.imread(self.names[idx], cv2.IMREAD_COLOR)
        im_shape = (image.shape[0], image.shape[1])
        image = self.transform(img=image)
        return image, self.filenames[idx], im_shape

    def transform(self, img):
        img = cv2.resize(img, (self.image_size, self.image_size))
        img = np.array(img, dtype=np.float32)
        img -= self.mean_bgr
        img = img.transpose((2, 0, 1))
        return img
"""DexiNed 网络结构"""

def weight_init(net):
    for name, param in net.parameters_and_names():
        if 'weight' in name:
            param.set_data(
                init.initializer(
                    init.XavierNormal(),
                    param.shape,
                    param.dtype))
        if 'bias' in name:
            param.set_data(init.initializer('zeros', param.shape, param.dtype))


class CoFusion(nn.Cell):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_ch, 64, kernel_size=3,
            stride=1, padding=1, has_bias=True,
            pad_mode="pad", weight_init=init.XavierNormal())
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=3,
            stride=1, padding=1, has_bias=True,
            pad_mode="pad", weight_init=init.XavierNormal())
        self.conv3 = nn.Conv2d(
            64, out_ch, kernel_size=3,
            stride=1, padding=1, has_bias=True,
            pad_mode="pad", weight_init=init.XavierNormal())
        self.relu = nn.ReLU()
        self.norm_layer1 = nn.GroupNorm(4, 64)
        self.norm_layer2 = nn.GroupNorm(4, 64)

    def construct(self, x):
        attn = self.relu(self.norm_layer1(self.conv1(x)))
        attn = self.relu(self.norm_layer2(self.conv2(attn)))
        attn = ops.softmax(self.conv3(attn), axis=1)

        return ((x * attn).sum(1)).expand_dims(1)


class _DenseLayer(nn.Cell):
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()

        self.conv1 = nn.Conv2d(
            input_features, out_features, kernel_size=3,
            stride=1, padding=2, pad_mode="pad",
            has_bias=True, weight_init=init.XavierNormal())
        self.norm1 = nn.BatchNorm2d(out_features)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(
            out_features, out_features, kernel_size=3,
            stride=1, pad_mode="pad", has_bias=True,
            weight_init=init.XavierNormal())
        self.norm2 = nn.BatchNorm2d(out_features)
        self.relu = ops.ReLU()

    def construct(self, x):
        x1, x2 = x
        x1 = self.conv1(self.relu(x1))
        x1 = self.norm1(x1)
        x1 = self.relu1(x1)
        x1 = self.conv2(x1)
        new_features = self.norm2(x1)
        return 0.5 * (new_features + x2), x2


class _DenseBlock(nn.Cell):
    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        self.denselayer1 = _DenseLayer(input_features, out_features)
        input_features = out_features
        self.denselayer2 = _DenseLayer(input_features, out_features)
        if num_layers == 3:
            self.denselayer3 = _DenseLayer(input_features, out_features)
            self.layers = nn.SequentialCell(
                [self.denselayer1, self.denselayer2, self.denselayer3])
        else:
            self.layers = nn.SequentialCell(
                [self.denselayer1, self.denselayer2])

    def construct(self, x):
        x = self.layers(x)
        return x


class UpConvBlock(nn.Cell):
    def __init__(self, in_features, up_scale):
        super(UpConvBlock, self).__init__()
        self.up_factor = 2
        self.constant_features = 16

        layers = self.make_deconv_layers(in_features, up_scale)

        assert layers is not None, layers
        self.features = nn.SequentialCell(*layers)

    def make_deconv_layers(self, in_features, up_scale):
        layers = []
        all_pads = [0, 0, 1, 3, 7]
        for i in range(up_scale):
            kernel_size = 2 ** up_scale
            pad = all_pads[up_scale]  # kernel_size-1
            out_features = self.compute_out_features(i, up_scale)
            layers.append(nn.Conv2d(
                in_features, out_features,
                1, has_bias=True))
            layers.append(nn.ReLU())
            layers.append(nn.Conv2dTranspose(
                out_features, out_features, kernel_size,
                stride=2, padding=pad, pad_mode="pad",
                has_bias=True, weight_init=init.XavierNormal()))
            in_features = out_features
        return layers

    def compute_out_features(self, idx, up_scale):
        return 1 if idx == up_scale - 1 else self.constant_features

    def construct(self, x):
        return self.features(x)


class SingleConvBlock(nn.Cell):
    def __init__(self, in_features, out_features, stride,
                 use_bs=True
                 ):
        super().__init__()
        self.use_bn = use_bs
        self.conv = nn.Conv2d(
            in_features,
            out_features,
            1,
            stride=stride,
            pad_mode="pad",
            has_bias=True,
            weight_init=init.XavierNormal())
        self.bn = nn.BatchNorm2d(out_features)

    def construct(self, x):
        x = self.conv(x)
        if self.use_bn:
            x = self.bn(x)
        return x


class DoubleConvBlock(nn.Cell):
    def __init__(self, in_features, mid_features,
                 out_features=None,
                 stride=1,
                 use_act=True):
        super(DoubleConvBlock, self).__init__()

        self.use_act = use_act
        if out_features is None:
            out_features = mid_features
        self.conv1 = nn.Conv2d(
            in_features,
            mid_features,
            3,
            padding=1,
            stride=stride,
            pad_mode="pad",
            has_bias=True,
            weight_init=init.XavierNormal())
        self.bn1 = nn.BatchNorm2d(mid_features)
        self.conv2 = nn.Conv2d(
            mid_features,
            out_features,
            3,
            padding=1,
            pad_mode="pad",
            has_bias=True,
            weight_init=init.XavierNormal())
        self.bn2 = nn.BatchNorm2d(out_features)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.use_act:
            x = self.relu(x)
        return x


class maxpooling(nn.Cell):
    def __init__(self):
        super(maxpooling, self).__init__()
        self.pad = nn.Pad(((0,0),(0,0),(1,1),(1,1)), mode="SYMMETRIC")
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')

    def construct(self, x):
        x = self.pad(x)
        x = self.maxpool(x)
        return x
class DexiNed(nn.Cell):

    def __init__(self):
        super(DexiNed, self).__init__()
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2,)
        self.block_2 = DoubleConvBlock(64, 128, use_act=False)
        self.dblock_3 = _DenseBlock(2, 128, 256)  # [128,256,100,100]
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)

        self.maxpool = maxpooling()

        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        self.side_5 = SingleConvBlock(512, 256, 1)  
        
        # right skip connections, figure in Journal paper
        self.pre_dense_2 = SingleConvBlock(128, 256, 2)
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)

        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        self.block_cat = SingleConvBlock(6, 1, stride=1, use_bs=False)

    def slice(self, tensor, slice_shape):
        t_shape = tensor.shape
        height, width = slice_shape
        if t_shape[-1] != slice_shape[-1]:
            new_tensor = ops.interpolate(
                tensor,
                sizes=(height, width),
                mode='bilinear',
                coordinate_transformation_mode="half_pixel")
        else:
            new_tensor = tensor
        return new_tensor

    def construct(self, x):
        assert x.ndim == 4, x.shape

        # Block 1
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)

        # Block 2
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)
        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)  # [128,256,50,50]
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)
        # Block 4
        block_2_resize_half = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(
            block_3_down + block_2_resize_half)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)
        # Block 5
        block_5_pre_dense = self.pre_dense_5(
            block_4_down)  # block_5_pre_dense_512 +block_4_down
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        block_5_add = block_5 + block_4_side
        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
        # upsampling blocks
        out_1 = self.up_block_1(block_1)
        out_2 = self.up_block_2(block_2)
        out_3 = self.up_block_3(block_3)
        out_4 = self.up_block_4(block_4)
        out_5 = self.up_block_5(block_5)
        out_6 = self.up_block_6(block_6)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]
        # concatenate multiscale outputs
        op = ops.Concat(1)
        block_cat = op(results)

        block_cat = self.block_cat(block_cat)  # Bx1xHxW
        results.append(block_cat)
        return results
'''将输入图像规格化到指定范围'''
def image_normalization(img, img_min=0, img_max=255, epsilon=1e-12):
    img = np.float32(img)
    img = (img - np.min(img)) * (img_max - img_min) / \
        ((np.max(img) - np.min(img)) + epsilon) + img_min
    return img
'''对DexiNed模型的输出数据进行后处理'''
def fuse_DNoutput(img):
    edge_maps = []
    tensor = img
    for i in tensor:
        sigmoid = ops.Sigmoid()
        output = sigmoid(i).numpy()
        edge_maps.append(output)
    tensor = np.array(edge_maps)
    idx = 0
    tmp = tensor[:, idx, ...]
    tmp = np.squeeze(tmp)
    preds = []
    for i in range(tmp.shape[0]):
        tmp_img = tmp[i]
        tmp_img = np.uint8(image_normalization(tmp_img))
        tmp_img = cv2.bitwise_not(tmp_img)
        preds.append(tmp_img)
        if i == 6:
            fuse = tmp_img
            fuse = fuse.astype(np.uint8)
    idx += 1
    return fuse
"""DexiNed 检测."""

def test(imgs,dexined_ckpt):
    
    if not os.path.isfile(dexined_ckpt):
        raise FileNotFoundError(
            f"Checkpoint file not found: {dexined_ckpt}")
    print(f"DexiNed ckpt path : {dexined_ckpt}")
    # os.makedirs(dexined_output_dir, exist_ok=True)
    model = DexiNed()
    # model = auto_mixed_precision(model, 'O2')
    ms.load_checkpoint(dexined_ckpt, model)
    model.set_train(False)
    preds = []
    origin = []
    total_duration = []
    print('Start dexined testing....')
    for img in imgs.create_dict_iterator():
        filename = str(img["names"])[2:-2]
        # print(filename)
        # output_dir_f = os.path.join(dexined_output_dir, filename)
        image = img["data"]
        origin.append(filename)
        end = time.perf_counter()
        pred = model(image)
        img_h = img["img_shape"][0, 0]
        img_w = img["img_shape"][0, 1]
        pred = fuse_DNoutput(pred)
        dexi_img = cv2.resize(
            pred, (int(img_w.asnumpy()), int(img_h.asnumpy())))
        # cv2.imwrite("output.jpg", dexi_img)
        tmp_duration = time.perf_counter() - end
        total_duration.append(tmp_duration)
        preds.append(pred)
    total_duration_f = np.sum(np.array(total_duration))
    print("FPS: %f.4" % (len(total_duration) / total_duration_f))
    return preds,origin
import os
import numpy as np
from PIL import Image
import mindspore.dataset as ds
import matplotlib.pyplot as plt
import mindspore.dataset.vision as vision
from mindspore.dataset import transforms
from mindspore import load_checkpoint, load_param_into_net

# 加载权重文件
def load_ckpt(net, ckpt_dir):
    param_GA = load_checkpoint(ckpt_dir)
    load_param_into_net(net, param_GA)

    
#模型参数地址
g_a_ckpt = './ckpt/G_A_120.ckpt'
dexined_ckpt = "./ckpt/dexined.ckpt"

#图片输入地址
img_path='./ckpt/jt'
#输出地址
save_path='./result'

load_ckpt(net_rg_a, g_a_ckpt)

os.makedirs(save_path, exist_ok=True)
# 图片推理
fig = plt.figure(figsize=(16, 4), dpi=64)
def eval_data(dir_path, net, a):
    my_dataset = Test_Dataset(
        dir_path, mean_bgr=[167.15, 146.07, 124.62], image_size=512)
    
    dataset = ds.GeneratorDataset(
        my_dataset, column_names=[
            "data", "names", "img_shape"])
    dataset = dataset.batch(1, drop_remainder=True)
    pres ,origin= test(dataset,dexined_ckpt)
    for i, data in enumerate(pres):
        img =ms.Tensor((np.array([data,data,data])/255-0.5)*2).unsqueeze(0)
        fake = net(img.to(ms.float32))
        
        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        img = (Image.open(os.path.join(img_path,origin[i])).convert('RGB'))

        fake_pil=Image.fromarray(fake.asnumpy())
        fake_pil.save(f"{save_path}/{i}.jpg")
        
        if i<8:
            fig.add_subplot(2, 8, min(i+1+a, 16))
            plt.axis("off")
            plt.imshow(np.array(img))

            fig.add_subplot(2, 8, min(i+9+a, 16))
            plt.axis("off")
            plt.imshow(fake.asnumpy())

eval_data(img_path,net_rg_a, 0)

plt.show()
DexiNed ckpt path : ./ckpt/dexined.ckpt
Start dexined testing....
FPS: 9.150511.4

参考

[1] I. Goodfellow. NIPS 2016 tutorial: Generative ad-versarial networks. arXiv preprint arXiv:1701.00160,2016. 2, 4, 5

[2] A. Shrivastava, T. Pfister, O. Tuzel, J. Susskind, W. Wang, R. Webb. Learning from simulated and unsupervised images through adversarial training. In CVPR, 2017. 3, 5, 6, 7

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2191572.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

指针赋值or常数赋值

int main (){int a 10;int b ;b a;int *c &a;int *d c; } 常数 a,b赋值&#xff1a; 都是将存储的值&#xff08;10&#xff09;赋值给别人。 指针赋值也是类似的&#xff1a; 指针存储的值&#xff08;&a&#xff09;为地址&#xff0c;就是把c指向的地址赋值给…

C语言 | Leetcode C语言题解之第458题可怜的小猪

题目&#xff1a; 题解&#xff1a; int poorPigs(int buckets, int minutesToDie, int minutesToTest){int base minutesToTest / minutesToDie 1;int pigs ceil(log(buckets)/log(base));return pigs; }

24-10-2-读书笔记(二十二)-《契诃夫文集》(一)上([俄] 契诃夫 [译] 汝龙)啊!真想生活。

文章目录 《契诃夫文集》&#xff08;一&#xff09;上&#xff08;[俄] 契诃夫 [译] 汝龙 &#xff09;早期生活——塔甘罗格&#xff08;人物家庭简介&#xff09;学生时期——莫斯科&#xff08;写作与学习&#xff09;流浪时期——哈萨林&#xff08;游历与流浪&#xff09…

VCSEL驱动电路

1.1 驱动电路 发射端可用MOS管控制VCSEL二极管负极方式发出脉冲光(正极对地)&#xff0c;具体作用过程如下&#xff1a; Step 1: MOS管断开, C2 电容充电(左侧HV)&#xff1b; Step 2: 信号控制MOS管打开&#xff1b; Step 3: MOS管打开后, C2电容左侧电压降为0V, 右侧变为…

当你系统有多个模块的时候,要设计统一入口页了。

本次给大家带来一批系统统一入口页的界面&#xff0c;这次都是科技感十足的界面。入口页将所有系统集中展示出来&#xff0c;并辅以其他设计元素进行氛围烘托&#xff0c;别看简单&#xff0c;但真的不好设计。

免费送源码:Java+B/S+MySQL 基于springboot网上书店管理系统 计算机毕业设计原创定制

基于springboot网上书店管理系统 摘 要 网上书店管理系统采用B/S结构、java开发语言、以及Mysql数据库等技术。系统主要分为管理员和用户两部分&#xff0c;管理员管理主要功能包括&#xff1a;首页、网站管理&#xff08;轮播图、网站公告&#xff09;人员管理&#xff08;管…

MATLAB - 机械臂手眼标定(眼在手内) - 估计安装在机器人上的移动相机的姿态

系列文章目录 前言 本示例展示了如何为装有手眼构型摄像头的机械臂或机械手执行和验证手眼校准。 一、概述 执行手眼校准有助于操作配备末端执行器&#xff08;简称 “手”&#xff09;的机械臂&#xff0c;该末端执行器依赖于摄像头提供的视觉数据。一旦完成了眼在手外的校准&…

数据结构阶段测试2的一点小补充

数据结构阶段测试2的一点小补充 1.已知⼩根堆为8,15,10,21,34,16,12&#xff0c;删除关键字8之后需重建堆&#xff0c;最后的叶⼦ 节点为() A. 34 B. 21 C. 16 D. 12 解题思路 向下调整算法删除堆顶元素 &#x1f4a1; 答案&#xff1a;C 删除堆顶元素的思路&#xff1a; …

环境对象刺激,recordnunber,记忆柱,记忆柱群,主注意对象,目的对象,状态中枢,奖惩预期,思维等等之间的联系

我们清醒状态下&#xff0c;随时都有目的&#xff0c;目的控制影响着我们思想行为的方向。目的是用对象来表征的&#xff08;目的对象&#xff09;&#xff0c;对象之所以能够表征目的&#xff0c;是因为对象能够被&#xff08;状态性&#xff09;赋值&#xff08;任何赋值都是…

UE4 材质学习笔记03(翻书(Flipbook)动画/环境混合)

一.FlipBook Animation 如果你想让游戏以每秒30帧的速度运行&#xff0c;所有内容都必须在33毫秒内渲染出来&#xff0c; 如果你想让游戏以每秒60帧的速度运行的话&#xff0c;必须在16毫秒内。 所以当一个效果需要很多细节的时候&#xff0c;往往会离线创建它&#xff0c;然…

【Qt】控件概述(4)—— 输出类控件

输出类控件 1. QLineEdit——单行输入框2. QTextEdit——多行输入框3. QComboBox——下拉框4. QSpinBox——微调框5. QDateEdit && QTimeEdit && QDateTimeEdit6 QDial——旋钮7. QSlider——滑动条 1. QLineEdit——单行输入框 QLineEdit是一个单行的输入框&…

BUU刷题-Pwn-shanghai2018_baby_arm(ARM_ROP_csu_init,ARM架构入门)

解题思路&#xff1a; 泄露或修改内存数据&#xff1a; 堆地址&#xff1a;无需栈地址&#xff1a;无需libc地址&#xff1a;无需BSS段地址&#xff1a;无需 劫持程序执行流程&#xff1a;ARM_ROP && mprotect函数(运行内存权限修改) && [[ARM_ROP_csu_init]…

【AI自然语言处理应用】通义晓蜜CCAI

通义晓蜜CCAI-对话分析AIO 对话分析AIO&#xff0c;即对话分析all-in-one API&#xff0c;是基于深度调优的对话大模型&#xff0c; 为营销服类产品提供智能化升级所需的生成式摘要总结、质检、分析等能力的官方应用。 面向对象&#xff1a;开发者、自研企业、传统呼叫中心采购…

02 nth_element 与第k小

题目&#xff1a; 方案一&#xff1a;sort排序 #include<bits/stdc.h> using namespace std;int main() {int n;int k;cin>>n>>k;int a[n]{0};for(int i0;i<n;i){cin>>a[i];}sort(a,an); cout<<a[k]<<endl;}方案二&#xff1a;…

【机器学习(十一)】糖尿病数据集分类预测案例分析—XGBoost分类算法—Sentosa_DSML社区版

文章目录 一、XGBoost算法二、Python代码和Sentosa_DSML社区版算法实现对比(一) 数据读入和统计分析(二)数据预处理(三)模型训练与评估(四)模型可视化 三、总结 一、XGBoost算法 关于集成学习中的XGBoost算法原理&#xff0c;已经进行了介绍与总结&#xff0c;相关内容可参考【…

leetcode面试题17.04:消失的数字(C语言版)

思路1 先排序&#xff0c;再依次查找&#xff0c;如果下一个值不等于前一个1&#xff0c;那么下一个值就是消失数字。 时间复杂度分析&#xff1a;冒泡排序的时间复杂度为O(N^2)&#xff0c;qsort排序时间复杂度为O(N*logN)。因此该思路不可行。 思路2 求和0到N&#xff0c;再减…

Python爬虫使用实例-mdrama

一个Python爬虫使用实例&#xff1a;主要用于下载指定的剧集音频。分别从网页和json文件中获取剧集的title和剧集中所存在音频的id&#xff0c;调用you-get&#xff0c;最后自动重命名下载文件夹为剧集名title。 目标网址&#xff1a; https://www.missevan.com/mdrama/其中为…

【C++】关键字+命名空间

大家好&#xff0c;我是苏貝&#xff0c;本篇博客带大家了解C的命名空间&#xff0c;如果你觉得我写的还不错的话&#xff0c;可以给我一个赞&#x1f44d;吗&#xff0c;感谢❤️ 目录 一. 关键字二. 命名空间2.1 命名空间的定义2.2 命名空间的使用a. 命名空间名称作用域限定…

R包的安装、加载以及如何查看帮助文档

0x01 如何安装R包 一、通过R 内置函数安装&#xff08;常用&#xff09; 1.安装CRAN的R包 install.packages()是一个用于安装 R 包的重要函数。 语法&#xff1a;install.packages(pkgs, repos getOption("repos"),...) 其中&#xff1a; pkgs&#xff1a;要安…

SpringCloud Alibaba - Eureka注册中心,Nacos配置中心

Eureka 1、创建服务端 server:port: 8761 # eureka 默认端口spring:application:name: eureka-server # 应用名称&#xff08;微服务中建议必须定义应用名称&#xff09; SpringBootApplication EnableEurekaServer // 开启eureka注册中心功能 public class EurekaServerAppli…