Mindspore框架:CycleGAN模型实现图像风格迁移算法
- Mindspore框架CycleGAN模型实现图像风格迁移|(一)CycleGAN神经网络模型构建
- Mindspore框架CycleGAN模型实现图像风格迁移|(二)实例数据集(苹果2橘子)
- Mindspore框架CycleGAN模型实现图像风格迁移|(三)损失函数计算
- Mindspore框架CycleGAN模型实现图像风格迁移|(四)CycleGAN模型训练
- Mindspore框架CycleGAN模型实现图像风格迁移|(五)CycleGAN模型推理与资源下载
实例数据集(苹果2橘子)
安装依赖库:
pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
pip install download
下载数据集到本地:图片来源于ImageNet,该数据集共有17个数据包,本文只使用了其中的苹果橘子部分。图像被统一缩放为256×256像素大小,其中用于训练的苹果图片996张、橘子图片1020张,用于测试的苹果图片266张、橘子图片248张。
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/CycleGAN_apple2orange.zip"
download(url, ".", kind="zip", replace=True)
加载数据集:使用 MindSpore 的 MindDataset 接口读取和解析数据集。
from mindspore.dataset import MindDataset
# 读取MindRecord格式数据
name_mr = "./CycleGAN_apple2orange/apple2orange_train.mindrecord"
data = MindDataset(dataset_files=name_mr)
print("Datasize: ", data.get_dataset_size())
batch_size = 1
dataset = data.batch(batch_size)
datasize = dataset.get_dataset_size()
可视化数据集:通过 create_dict_iterator 函数将数据转换成字典迭代器,然后使用 matplotlib 模块可视化部分训练数据。
import numpy as np
import matplotlib.pyplot as plt
mean = 0.5 * 255
std = 0.5 * 255
plt.figure(figsize=(12, 5), dpi=60)
for i, data in enumerate(dataset.create_dict_iterator()):
if i < 5:
show_images_a = data["image_A"].asnumpy()
show_images_b = data["image_B"].asnumpy()
plt.subplot(2, 5, i+1)
show_images_a = (show_images_a[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
plt.imshow(show_images_a)
plt.axis("off")
plt.subplot(2, 5, i+6)
show_images_b = (show_images_b[0] * std + mean).astype(np.uint8).transpose((1, 2, 0))
plt.imshow(show_images_b)
plt.axis("off")
else:
break
plt.show()