Pix2Pix实现图像转换
Pix2Pix概述
Pix2Pix是一种基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks)的图像转换模型,由Phillip Isola等人在2017年提出。它能够将语义/标签图像转换为真实图片、灰度图转换为彩色图、航空图转换为地图、白天图转换为夜晚图、线稿图转换为实物图等。Pix2Pix的创新之处在于使用相同的架构和目标函数,通过不同的数据集训练实现多种图像转换任务。
基础原理
cGAN与传统GAN的区别在于,cGAN的生成器以输入图片为指导信息生成“假”图像,而GAN的生成器则以随机噪声为输入。Pix2Pix的生成器使用U-Net结构,通过编码和解码输入图像生成输出图像;判别器使用PatchGAN结构,通过判断图像的局部区域(Patch)来区分真实图像和生成图像。cGAN的目标是通过生成器和判别器的博弈,使生成器生成的图像越来越接近真实图像。
准备工作
配置环境文件
本案例支持在GPU、CPU和Ascend平台的动静态模式下运行。
准备数据
使用指定的数据集,已处理好的外墙(facades)数据集,可直接使用MindSpore读取。
数据展示
from mindspore import dataset as ds
import matplotlib.pyplot as plt
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
plt.subplot(3, 10, i)
plt.axis("off")
plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()
创建网络
生成器G结构
生成器使用U-Net结构,通过编码和解码输入图像生成输出图像。
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
class UNetSkipConnectionBlock(nn.Cell):
def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False, submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
super(UNetSkipConnectionBlock, self).__init__()
# 定义下采样和上采样的卷积层、激活函数和归一化层
down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4, stride=2, padding=1, has_bias=False, pad_mode='pad')
down_relu = nn.LeakyReLU(alpha)
up_relu = nn.ReLU()
# 定义下采样和上采样的卷积层、激活函数和归一化层
if outermost:
up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, pad_mode='pad')
model = [down_conv] + [submodule] + [up_relu, up_conv, nn.Tanh()]
elif innermost:
up_conv = nn.Conv2dTranspose(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, has_bias=False, pad_mode='pad')
model = [down_relu, down_conv] + [up_relu, up_conv, nn.BatchNorm2d(outer_nc)]
else:
up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, has_bias=False, pad_mode='pad')
model = [down_relu, down_conv, nn.BatchNorm2d(inner_nc)] + [submodule] + [up_relu, up_conv, nn.BatchNorm2d(outer_nc)]
if dropout:
model.append(nn.Dropout(p=0.5))
self.model = nn.SequentialCell(model)
self.skip_connections = not outermost
def construct(self, x):
out = self.model(x)
if self.skip_connections:
out = ops.concat((out, x), axis=1)
return out
基于UNet的生成器
class UNetGenerator(nn.Cell):
def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
super(UNetGenerator, self).__init__()
# 定义UNet生成器的各个层次
unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None, norm_mode=norm_mode, innermost=True)
for _ in range(n_layers - 5):
unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block, norm_mode=norm_mode, dropout=dropout)
unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block, norm_mode=norm_mode)
unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block, norm_mode=norm_mode)
unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block, norm_mode=norm_mode)
self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block, outermost=True, norm_mode=norm_mode)
def construct(self, x):
return self.model(x)
判别器D结构
判别器使用PatchGAN结构,通过判断图像的局部区域(Patch)来区分真实图像和生成图像。
import mindspore.nn as nn
class ConvNormRelu(nn.Cell):
def __init__(self, in_planes, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='batch', pad_mode='CONSTANT', use_relu=True, padding=None):
super(ConvNormRelu, self).__init__()
norm = nn.BatchNorm2d(out_planes)
if norm_mode == 'instance':
norm = nn.BatchNorm2d(out_planes, affine=False)
has_bias = (norm_mode == 'instance')
if not padding:
padding = (kernel_size - 1) // 2
if pad_mode == 'CONSTANT':
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias, padding=padding)
layers = [conv, norm]
else:
paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
pad = nn.Pad(paddings=paddings, mode=pad_mode)
conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
layers = [pad, conv, norm]
if use_relu:
relu = nn.ReLU()
if alpha > 0:
relu = nn.LeakyReLU(alpha)
layers.append(relu)
self.features = nn.SequentialCell(layers)
def construct(self, x):
output = self.features(x)
return output
class Discriminator(nn.Cell):
def __init__(self, in_planes=6, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
super(Discriminator, self).__init__()
kernel_size = 4
layers = [
nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
nn.LeakyReLU(alpha)
]
nf_mult = ndf
for i in range(1, n_layers):
nf_mult_prev = nf_mult
nf_mult = min(2 ** i, 8) * ndf
layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
nf_mult_prev = nf_mult
nf_mult = min(2 ** n_layers, 8) * ndf
layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
self.features = nn.SequentialCell(layers)
def construct(self, x, y):
x_y = ops.concat((x, y), axis=1)
output = self.features(x_y)
return output
Pix2Pix的生成器和判别器初始化
import mindspore.nn as nn
from mindspore.common import initializer as init
继续训练过程
随着训练的进行,我们可以在每个周期结束时保存生成器的检查点,并可视化一些中间结果。我们也可以监控损失的变化来确定模型的训练情况。
训练完成
一旦训练完成,我们可以使用生成器来生成新的图像。通过提供输入图像,生成器可以生成相应的输出图像。
验证生成效果
在训练完成后,我们使用部分数据进行验证,看看生成器的效果。
# 验证生成器效果
from mindspore import load_checkpoint
# 加载生成器模型
load_checkpoint("results/ckpt/Generator.ckpt", net_generator)
# 可视化函数
def visualize_result(input_image, generated_image, target_image, epoch, idx):
plt.figure(figsize=(15, 5), dpi=140)
plt.subplot(1, 3, 1)
plt.axis("off")
plt.title("Input Image")
plt.imshow((input_image.transpose(1, 2, 0) + 1) / 2)
plt.subplot(1, 3, 2)
plt.axis("off")
plt.title("Generated Image")
plt.imshow((generated_image.transpose(1, 2, 0) + 1) / 2)
plt.subplot(1, 3, 3)
plt.axis("off")
plt.title("Target Image")
plt.imshow((target_image.transpose(1, 2, 0) + 1) / 2)
plt.suptitle(f"Epoch {epoch}, Step {idx}")
plt.show()
# 验证集
val_dataset = ds.MindDataset("./dataset/dataset_pix2pix/val.mindrecord", columns_list=["input_images", "target_images"], shuffle=False)
# 验证生成器
val_data_iter = val_dataset.create_dict_iterator(output_numpy=True)
for idx, data in enumerate(val_data_iter):
input_image = Tensor(data["input_images"])
target_image = Tensor(data["target_images"])
generated_image = net_generator(input_image)
visualize_result(input_image.asnumpy(), generated_image.asnumpy(), target_image.asnumpy(), epoch_num, idx)
if idx >= 5: # 仅展示部分验证结果
break
通过本文教程,我们使用MindSpore实现了Pix2Pix模型,包括数据准备、生成器和判别器的搭建、训练和验证。在训练过程中,我们使用了cGAN和L1损失的组合来优化生成器。最终,我们展示了模型的训练效果和生成结果。
在实际应用中,Pix2Pix模型可以用于各种图像到图像的转换任务,比如从卫星图像生成地图、将灰度图像转换为彩色图像等。希望本文的教程对您理解Pix2Pix模型有所帮助,并能在实际项目中应用此模型。