深度学习:基于MindSpore实现CycleGAN壁画修复

news2024/11/28 18:59:08

关于CycleGAN的基础知识可参考:

深度学习:CycleGAN图像风格迁移转换-CSDN博客

以及MindSpore官方的教学视频:

CycleGAN图像风格迁移转换_哔哩哔哩_bilibili

本案例将基于CycleGAN实现破损草图到线稿图的转换

数据集

本案例使用的数据集里面的图片为经图线稿图数据。图像被统一缩放为256×256像素大小,其中用于训练的线稿图片25654张、草图图片25654张,用于测试的线稿图片100张、草图图片116张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理。

DexiNed

 DexiNed(Dense Extreme Inception Network for Edge Detection)是一个为边缘检测任务设计的深度卷积神经网络模型。它由两个主要部分组成:Dexi和上采样网络(USNet)。

Dexi

这是模型的主要部分,包含六个编码块,每个块由多个子块组成,子块中包含卷积层、批量归一化和ReLU激活函数。从第二个块开始,引入了跳跃连接(skip connections),以保留不同层次的边缘特征。这些块的输出特征图被送入上采样网络以生成中间边缘图。

上采样网络(USNet)

这个部分由多个上采样块组成,每个块包括卷积层和反卷积层(也称为转置卷积层)。USNet的作用是将Dexi输出的低分辨率特征图上采样到更高的分辨率,以生成清晰的边缘图。

卷积层用于提取特征,反卷积层(或转置卷积层)用于将特征图的空间尺寸增大。

损失函数

 DexiNed模型使用的损失函数是专门为边缘检测任务设计的,它在一定程度上受到了BDCN(Bi-directional Cascade Network)损失函数的启发,并进行了一些修改和优化。这个损失函数的目的是在训练过程中平衡正面(正样本)和负面(负样本)的边缘样本比例,从而提高边缘检测的准确性。

损失函数定义为:

其中,

Li 是第 i 个输出的损失,λi 是对应的权重,用于平衡正负样本的比例。具体的 Li 计算方式为:

 DexiNed数据集

DexiNed模型的训练数据集主要是为边缘检测任务设计的高质量数据集。在论文中提到了两个主要的数据集:

  1. BIPED (Barcelona Images for Perceptual Edge Detection):这是一个特别为边缘检测设计的大规模数据集,包含详细的边缘标注信息。它由250张真实世界的图像组成,图像分辨率为1280×720像素,主要描绘城市环境场景。这些图像的边缘通过手动标注生成,以确保边缘检测的准确性。

  2. MDBD (Multicue Dataset for Boundary Detection):这是一个用于边界检测的数据集,也适用于边缘检测任务。它由100个高清图像组成,每个图像有多个参与者的标注,适用于训练和评估边缘检测算法。

DexiNed模型需要成对的数据来进行训练,即每张输入图像都需要有一个对应的标注图像(Ground Truth, GT)。这些标注图像详细地标出了图像中边缘的位置,模型通过比较预测边缘和这些标注来学习如何准确地检测边缘。

DexiNed在本例中主要用于将彩色图片转化为线稿图,随后将线稿图输入CycleGAN,得到输出。

基于MindSpore的壁画修复

加载数据集

#下载数据集
from download import download

url = "https://6169fb4615b14dbcb6b2cb1c4eb78bb2.obs.cn-north-4.myhuaweicloud.com/Cyc_line.zip"

download(url, "./localdata", kind="zip", replace=True)
from __future__ import division
import math
import numpy as np

import os
import multiprocessing

import mindspore.dataset as de
import mindspore.dataset.vision as vision

"""数据集分布式采样器"""
class DistributedSampler:
    """Distributed sampler."""
    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
            indices = indices.tolist()
            self.epoch += 1
            # change to list type
        else:
            indices = list(range(self.dataset_size))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
# 加载CycleGAN数据集
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']

# 判断当前文件是否为图片
def is_image_file(filename):
    return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)

# 定义一个函数用于从指定目录中创建数据集列表
def make_dataset(dir_path, max_dataset_size=float("inf")):
    # 初始化一个空列表用来存储图片路径
    images = []
    # 确保提供的dir_path是一个有效的目录
    assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
    # 遍历目录下的所有文件,将图片的文件路径存入images列表
    for root, _, fnames in sorted(os.walk(dir_path)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    # 返回指定长度的图片路径列表
    return images[:min(max_dataset_size, len(images))]

# CycleGAN中没有成对出现,但是分属两个领域的图片数据
class UnalignedDataset:
    '''
    此数据集类能够加载未对齐或未配对的数据集。
    需要两个目录来存放来自领域A和B的训练图片。
    可以使用'--dataroot /path/to/data'这样的标志来训练模型。
    同样,在测试时也需要准备两个目录。
    返回:两个领域的图片路径列表。
    '''
    def __init__(self, dataroot, max_dataset_size=float("inf"), use_random=True):
        # 根据指定根路径生成A\B领域数据的文件夹路径
        self.dir_A = os.path.join(dataroot, 'trainA')
        self.dir_B = os.path.join(dataroot, 'trainB')
        
        # 领域A图片数据的路径
        self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size))
        # 领域B图片数据的路径
        self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))
        # 领域A的数据长度
        self.A_size = len(self.A_paths)
        # 领域B的数据长度
        self.B_size = len(self.B_paths)
        # 根据参数决定是否随机化
        self.use_random = use_random
    # 从数据集中根据给定的索引 index 获取一个样本对(分别来自领域A和领域B的一张图片)
    def __getitem__(self, index):
        # 数据A的索引
        index_A = index % self.A_size
        # 数据B的索引
        index_B = index % self.B_size
        # 每遍历完所有图片后会重新随机排序领域A中的图片路径列表,并且从领域B中随机选取图片。
        if index % max(self.A_size, self.B_size) == 0 and self.use_random:
            random.shuffle(self.A_paths)
            index_B = random.randint(0, self.B_size - 1)
        # 获取指定下标的图片路径
        A_path = self.A_paths[index_A]
        B_path = self.B_paths[index_B]
        # 获取图片对象
        A_img = np.array(Image.open(A_path).convert('RGB'))
        B_img = np.array(Image.open(B_path).convert('RGB'))
        # 返回领域A和B的图片
        return A_img, B_img
    
    def __len__(self):
        return max(self.A_size, self.B_size)
def create_dataset(dataroot, batch_size=1, use_random=True, device_num=1, rank=0, max_dataset_size=float("inf"), image_size=256):
    """
    创建数据集
    该数据集类可以加载用于训练或测试的图像。
    参数:
        dataroot (str): 图像根目录。
        batch_size (int): 批处理大小,默认为1。
        use_random (bool): 是否使用随机化,默认为True。
        device_num (int): 设备数量,默认为1。
        rank (int): 当前设备的排名,默认为0。
        max_dataset_size (float): 数据集的最大大小,默认为无穷大。
        image_size (int): 图像的尺寸,默认为256x256。
    返回:
        RGB图像列表。
    """
    shuffle = use_random # 是否打乱数据集
    # 获取系统可用的CPU核心数
    cores = multiprocessing.cpu_count()
    # 计算并行工作的线程数,根据设备数量分配
    num_parallel_workers = min(1, int(cores / device_num))
    # 定义归一化时使用的均值和标准差
    # 三个通道的均值和房擦汗都是127.5
    mean = [0.5 * 255] * 3
    std = [0.5 * 255] * 3
    # 创建数据集(未对齐)
    dataset = UnalignedDataset(dataroot, max_dataset_size=max_dataset_size, use_random=use_random)
    # 使用DistributedSampler来实现分布式采样
    distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
    # 创建GeneratorDataset,指定列名,并使用之前创建的sampler和并行工作线程数
    ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
                             sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
    # 指定数据增强操作
    if use_random:
        trans = [
            # 图片随机裁剪变比例
            vision.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
            # 水平翻转,概率为0.5
            vision.RandomHorizontalFlip(prob=0.5),
            # 图片数据归一化
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:  # 如果不启用随机化,则只进行简单的缩放和归一化
        trans = [
            C.Resize((image_size, image_size)),  # 固定大小缩放
            C.Normalize(mean=mean, std=std),  # 归一化
            C.HWC2CHW()  # 将HWC格式转换为CHW格式
        ]
    # 将数据增强操作映射到数据中
    ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
    # 设置批处理大小,并且丢弃不足一批的数据
    ds = ds.batch(batch_size, drop_remainder=True)
    
    return ds
#根据设备情况调整训练参数

dataroot = "./localdata"
batch_size = 12
device_num = 1
rank = 0
use_random = True
max_dataset_size = 24000
image_size = 256

cyclegan_ds = create_dataset(dataroot=dataroot,max_dataset_size=max_dataset_size,batch_size=batch_size,device_num=device_num,rank = rank,use_random=use_random,image_size=image_size)
datasize = cyclegan_ds.get_dataset_size()
print("Datasize: ", datasize)

构建和训练CycleGAN的代码不再重复贴出,可参考上述提到的博文

构建DexiNed

import os
import cv2
import numpy as np
import time

import mindspore as ms
from mindspore import nn, ops
from mindspore import dataset as ds
from mindspore.amp import auto_mixed_precision
from mindspore.common import initializer as init
# DexiNed边缘检测数据集
class Test_Dataset():
    def __init__(self, data_root, mean_bgr, image_size):
        self.data = []
        # 列出根路径下的所有文件名
        imgs_ = os.listdir(data_root)
        # 初始化两个空列表,分别用于存储图像路径和对应的文件名
        self.names = []
        self.filenames = []
        for img in imgs_:
            if img.endswith(".png") or img.endswith(".jpg"):
                # 构建文件的完整路径
                dir = os.path.join(data_root, img)
                self.names.append(dir)
                self.filenames.append(img)
        self.mean_bgr = mean_bgr
        self.image_size = image_size
        
    def __len__(self):
        return len(self.names)
    
    def __getitem__(self, idx):
        # 使用OpenCV读取指定索引位置的图像,读取模式为彩色
        image = cv2.imread(self.names[idx], cv2.IMREAD_COLOR)
        # 读取图片的长宽
        im_shape = (image.shape[0], image.shape[1])
        # 对图像进行变换处理
        image = self.transform(img=image)
        # 返回图像、图像名、图像形状
        return image, self.filenames[idx], im_shape
    
    def transform(self, img):
        # 裁切
        img = cv2.resize(img, (self.image_size, self.image_size))
        # 将图像转换为浮点数类型
        img = np.array(img, dtype=np.float32)
        # 归一化
        img -= self.mean_bgr
        # 将图像从(H, W, C)格式转换为(C, H, W)格式
        img = img.transpose((2, 0, 1))
        return img
# DexiNed网络结构

# 初始化权重函数
def weight_init(net):
    for name, param in net.parameters_and_names():
        # 使用Xavier分布初始化权重
        if 'weight' in name:
            param.set_data(
                init.initializer(
                    init.XavierNormal(),
                    param.shape,
                    param.dtype))
        # 偏置初始化为0
        if 'bias' in name:
            param.set_data(init.initializer('zeros', param.shape, param.dtype))
# 表示DexiNed中的一个基础的密集连接层,实现具有批量归一化和ReLU激活的卷积层
class _DenseLayer(nn.Cell):
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()
        # 两个ConvNormReLU块
        self.conv1 = nn.Conv2d(
            input_features, out_features, kernel_size=3,
            stride=1, padding=2, pad_mode="pad",
            has_bias=True, weight_init=init.XavierNormal())
        self.norm1 = nn.BatchNorm2d(out_features)
        self.relu1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(
            out_features, out_features, kernel_size=3,
            stride=1, pad_mode="pad", has_bias=True,
            weight_init=init.XavierNormal())
        self.norm2 = nn.BatchNorm2d(out_features)
        self.relu = ops.ReLU()
    
    def construct(self, x):
        x1, x2 = x
        x1 = self.conv1(self.relu(x1))
        x1 = self.norm1(x1)
        x1 = self.relu1(x1)
        x1 = self.conv2(x1)
        new_features = self.norm2(x1)
        # 每一层的输出都会被传递给所有后续层作为输入
        return 0.5 * (new_features + x2), x2
# 基于DenseLayer定义DenseBlock
class _DenseBlock(nn.Cell):
    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        
        self.denselayer1 = _DenseLayer(input_features, out_features)
        
        input_features = out_features
        self.denselayer2 = _DenseLayer(input_features, out_features)
        
        if num_layers == 3:
            self.denselayer3 = _DenseLayer(input_features, out_features)
            self.layers = nn.SequentialCell(
                [self.denselayer1, self.denselayer2, self.denselayer3])
        else:
            self.layers = nn.SequentialCell(
                [self.denselayer1, self.denselayer2])

    def construct(self, x):
        x = self.layers(x)
        return x
# 表示上采样块,这是USNet的一部分,用于将特征图的尺寸增大。
class UpConvBlock(nn.Cell):
    def __init__(self, in_features, up_scale):
        super(UpConvBlock, self).__init__()
         # 定义上采样的因子,默认为2
        self.up_factor = 2
        # 定义一个常量特征数,通常用于控制输出通道的数量
        self.constant_features = 16
        
        layers = self.make_deconv_layers(in_features, up_scale)
        
        assert layers is not None, layers
        self.features = nn.SequentialCell(*layers)
    # # 构建上采样层
    def make_deconv_layers(self, in_features, up_scale):
        layers = []
        all_pads = [0, 0, 1, 3, 7]
        # 根据up_scale循环创建相应的层
        # 逐步放大
        for i in range(up_scale):
            # 定义卷积核大小
            kernel_size = 2 ** up_scale
            # 获取填充大小
            pad = all_pads[up_scale]
            # 计算输出维度
            out_features = self.compute_out_features(i, up_scale)
            # 创建反卷积层
            layers.append(nn.Conv2d(in_features, out_features, 1, has_bias=True))
            layers.append(nn.ReLU())
            layers.append(nn.Conv2dTranspose(
                out_features, out_features, kernel_size,
                stride=2, padding=pad, pad_mode="pad",
                has_bias=True, weight_init=init.XavierNormal()))
            # 更新通道数
            in_features = out_features
        return layers
    # 计算当前输出通道数
    def compute_out_features(self, idx, up_scale):
        return 1 if idx == up_scale - 1 else self.constant_features
    def construct(self, x):
        return self.features(x)
# 单个卷积块,包含Conv和BatchNorm
class SingleConvBlock(nn.Cell):
    def __init__(self, in_features, out_features, stride, use_bs=True):
        super().__init__()
        self.use_batch_norm = use_bs
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, pad_mode='pad', has_bias=True, weight_init=init.XavierNormal())
        self.bn = nn.BatchNorm2d(out_features)
        
    def construct(self, x):
        x = self.conv(x)
        if self.use_batch_norm:
            x = self.bn(x)
        return x
# 双卷积块
class DoubleConvBlock(nn.Cell):
    def __init__(self, in_features, mid_features,
                 out_features=None,
                 stride=1,
                 use_act=True):
        super(DoubleConvBlock, self).__init__()

        self.use_act = use_act
        if out_features is None:
            out_features = mid_features
        self.conv1 = nn.Conv2d(
            in_features,
            mid_features,
            3,
            padding=1,
            stride=stride,
            pad_mode="pad",
            has_bias=True,
            weight_init=init.XavierNormal())
        self.bn1 = nn.BatchNorm2d(mid_features)
        self.conv2 = nn.Conv2d(
            mid_features,
            out_features,
            3,
            padding=1,
            pad_mode="pad",
            has_bias=True,
            weight_init=init.XavierNormal())
        self.bn2 = nn.BatchNorm2d(out_features)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.use_act:
            x = self.relu(x)
        return x
# 自定义最大汇聚层
class maxpooling(nn.Cell):
    def __init__(self):
        super(maxpooling, self).__init__()
        self.pad = nn.Pad(((0,0),(0,0),(1,1),(1,1)), mode="SYMMETRIC")
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')

    def construct(self, x):
        x = self.pad(x)
        x = self.maxpool(x)
        return x
# 组件DexiNed网络
class DexiNed(nn.Cell):

    def __init__(self):
        super(DexiNed, self).__init__()
        #  DoubleConvBlock(双卷积块)实例,用于构建DexiNed的编码部分。
        # 处理输入图像和第一层的输出。
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2,)
        self.block_2 = DoubleConvBlock(64, 128, use_act=False)
        # 于实现DexiNed中的密集连接层
        # 用于提取多尺度特征
        self.dblock_3 = _DenseBlock(2, 128, 256)  # [128,256,100,100]
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)
        # 最大池化层,用于下采样操作。
        self.maxpool = maxpooling()
        # 用于从不同层次提取特征。
        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        self.side_5 = SingleConvBlock(512, 256, 1)  
        # 用于准备输入到密集块的数据。
        # right skip connections, figure in Journal paper
        self.pre_dense_2 = SingleConvBlock(128, 256, 2)
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)
        # 上采样块,用于恢复特征图的尺寸。
        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        # 单卷积块,用于将多个尺度的特征图融合成一个单一的输出。
        self.block_cat = SingleConvBlock(6, 1, stride=1, use_bs=False)
    # 如果张量的形状与目标形状不匹配,则使用双线性插值进行调整;否则直接返回张量。
    def slice(self, tensor, slice_shape):
        t_shape = tensor.shape
        height, width = slice_shape
        if t_shape[-1] != slice_shape[-1]:
            new_tensor = ops.interpolate(
                tensor,
                sizes=(height, width),
                mode='bilinear',
                coordinate_transformation_mode="half_pixel")
        else:
            new_tensor = tensor
        return new_tensor

    def construct(self, x):
        # 确保输入张量 x 是四维的。
        assert x.ndim == 4, x.shape
        # 通过 block_1 处理输入 x,并通过 side_1 提取特征。
        # Block 1
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)
        # 通过 block_2 处理 block_1 的输出,然后通过 maxpool 降采样,并与 block_1_side 相加。再通过 side_2 提取特征。
        # Block 2
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)
        # 通过 pre_dense_3 处理 block_2_down,并将其与 block_2_add 一起传递给 dblock_3。
        # 然后通过 maxpool 降采样,并与 block_2_side 相加。再通过 side_3 提取特征。
        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)  # [128,256,50,50]
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)
        # 通过 pre_dense_2 和 pre_dense_4 处理 block_2_down 和 block_3_down,并将它们相加后传递给 dblock_4。
        # 然后通过 maxpool 降采样,并与 block_3_side 相加。再通过 side_4 提取特征。
        # Block 4
        block_2_resize_half = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(
            block_3_down + block_2_resize_half)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)
        # 通过 pre_dense_5 处理 block_4_down,并将其与 block_4_add 一起传递给 dblock_5。然后与 block_4_side 相加。
        # Block 5
        block_5_pre_dense = self.pre_dense_5(
            block_4_down)  # block_5_pre_dense_512 +block_4_down
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        block_5_add = block_5 + block_4_side
        # 通过 pre_dense_6 处理 block_5,并将其与 block_5_add 一起传递给 dblock_6。
        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
        # upsampling blocks
        # 对每个块的输出进行上采样,恢复特征图的尺寸。
        out_1 = self.up_block_1(block_1)
        out_2 = self.up_block_2(block_2)
        out_3 = self.up_block_3(block_3)
        out_4 = self.up_block_4(block_4)
        out_5 = self.up_block_5(block_5)
        out_6 = self.up_block_6(block_6)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]
        # 将所有上采样的输出拼接在一起,并通过 block_cat 进行最后的融合,生成最终的边缘图。
        # concatenate multiscale outputs
        op = ops.Concat(1)
        block_cat = op(results)

        block_cat = self.block_cat(block_cat)  # Bx1xHxW
        results.append(block_cat)
        # 返回包含多个尺度的边缘图和最终融合后的边缘图的结果列表。
        return results

使用DexiNed进行推理

'''将输入图像规格化到指定范围'''
def image_normalization(img, img_min=0, img_max=255, epsilon=1e-12):
    img = np.float32(img)
    img = (img - np.min(img)) * (img_max - img_min) / \
        ((np.max(img) - np.min(img)) + epsilon) + img_min
    return img
'''对DexiNed模型的输出数据进行后处理'''
# DexiNed 模型会输出多个尺度的边缘图(results列表中包含了多个上采样的结果),
# 这个函数将这些边缘图进行融合,并应用一些图像处理技术来生成最终的边缘检测结果。
def fuse_DNoutput(img):
    edge_maps = []
    tensor = img
    for i in tensor:
        sigmoid = ops.Sigmoid()
        output = sigmoid(i).numpy()
        edge_maps.append(output)
    tensor = np.array(edge_maps)
    idx = 0
    tmp = tensor[:, idx, ...]
    tmp = np.squeeze(tmp)
    preds = []
    for i in range(tmp.shape[0]):
        tmp_img = tmp[i]
        tmp_img = np.uint8(image_normalization(tmp_img))
        tmp_img = cv2.bitwise_not(tmp_img)
        preds.append(tmp_img)
        if i == 6:
            fuse = tmp_img
            fuse = fuse.astype(np.uint8)
    idx += 1
    return fuse
"""DexiNed 检测."""

def test(imgs,dexined_ckpt):
    
    if not os.path.isfile(dexined_ckpt):
        raise FileNotFoundError(
            f"Checkpoint file not found: {dexined_ckpt}")
    print(f"DexiNed ckpt path : {dexined_ckpt}")
    # os.makedirs(dexined_output_dir, exist_ok=True)
    model = DexiNed()
    # model = auto_mixed_precision(model, 'O2')
    ms.load_checkpoint(dexined_ckpt, model)
    model.set_train(False)
    preds = []
    origin = []
    total_duration = []
    print('Start dexined testing....')
    for img in imgs.create_dict_iterator():
        filename = str(img["names"])[2:-2]
        # print(filename)
        # output_dir_f = os.path.join(dexined_output_dir, filename)
        image = img["data"]
        origin.append(filename)
        end = time.perf_counter()
        # 使用DexiNed进行预测
        pred = model(image)
        # 获取图片宽高
        img_h = img["img_shape"][0, 0]
        img_w = img["img_shape"][0, 1]
        # 调用 fuse_DNoutput 函数对模型的输出进行后处理。
        pred = fuse_DNoutput(pred)
        # 将处理后的边缘图调整为原始图像的尺寸。
        dexi_img = cv2.resize(
            pred, (int(img_w.asnumpy()), int(img_h.asnumpy())))
        # cv2.imwrite("output.jpg", dexi_img)
        tmp_duration = time.perf_counter() - end
        total_duration.append(tmp_duration)
        preds.append(pred)
    total_duration_f = np.sum(np.array(total_duration))
    print("FPS: %f.4" % (len(total_duration) / total_duration_f))
    # 返回处理后的边缘图列表 preds 和原始图像文件名列表 origin。
    return preds,origin

DexiNed结合CycleGAN对壁画进行修复

import os
import numpy as np
from PIL import Image
import mindspore.dataset as ds
import matplotlib.pyplot as plt
import mindspore.dataset.vision as vision
from mindspore.dataset import transforms
from mindspore import load_checkpoint, load_param_into_net

# 加载权重文件
def load_ckpt(net, ckpt_dir):
    param_GA = load_checkpoint(ckpt_dir)
    load_param_into_net(net, param_GA)

    
#模型参数地址
g_a_ckpt = './ckpt/G_A_120.ckpt'
dexined_ckpt = "./ckpt/dexined.ckpt"

#图片输入地址
img_path='./ckpt/jt'
#输出地址
save_path='./result'

load_ckpt(net_rg_a, g_a_ckpt)

os.makedirs(save_path, exist_ok=True)
# 图片推理
fig = plt.figure(figsize=(16, 4), dpi=64)
def eval_data(dir_path, net, a):
    my_dataset = Test_Dataset(
        dir_path, mean_bgr=[167.15, 146.07, 124.62], image_size=512)
    
    dataset = ds.GeneratorDataset(
        my_dataset, column_names=[
            "data", "names", "img_shape"])
    dataset = dataset.batch(1, drop_remainder=True)
    # 使用DexiNed将原图转为线稿图
    preds ,origin= test(dataset,dexined_ckpt)
    for i, data in enumerate(preds):
        # 读取线稿图
        img =ms.Tensor((np.array([data,data,data])/255-0.5)*2).unsqueeze(0)
        # 将线稿图放入生成器,生成假图
        fake = net(img.to(ms.float32))
        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        # 找到线稿图对应的原图,用作后续输出
        img = (Image.open(os.path.join(img_path,origin[i])).convert('RGB'))
        # 保存CycleGAN生成的结果
        fake_pil=Image.fromarray(fake.asnumpy())
        fake_pil.save(f"{save_path}/{i}.jpg")
        # 展示推理结果
        if i<8:
            fig.add_subplot(2, 8, min(i+1+a, 16))
            plt.axis("off")
            plt.imshow(np.array(img))

            fig.add_subplot(2, 8, min(i+9+a, 16))
            plt.axis("off")
            plt.imshow(fake.asnumpy())

eval_data(img_path,net_rg_a, 0)

plt.show()

输出结果如下:

详细可以参考MindSpore官方教学视频:

基于MindSpore实现CycleGAN壁画修复_哔哩哔哩_bilibili

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2189401.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Qt系统学习篇(6)-QMainWindow

QMainWindow是一个为用户提供主窗口程序的类&#xff0c;包含一个菜单栏(menu bar)、多个工具栏(tool bars)、多个锚接部件(dock widgets)、一个状态栏(status bar)及一个中心部件(central widget)&#xff0c;是许多应用程序的基础&#xff0c;如文本编辑器&#xff0c;图片编…

webpack信息泄露

先看看webpack中文网给出的解释 webpack 是一个模块打包器。它的主要目标是将 JavaScript 文件打包在一起,打包后的文件用于在浏览器中使用,但它也能够胜任转换、打包或包裹任何资源。 如果未正确配置&#xff0c;会生成一个.map文件&#xff0c;它包含了原始JavaScript代码的映…

算法笔记(九)——栈

文章目录 删除字符串中的所有相邻重复项比较含退格的字符串基本计算机II字符串解码验证栈序列 栈是一种先进后出的数据结构&#xff0c;其操作主要有 进栈、压栈&#xff08;Push&#xff09; 出栈&#xff08;Pop&#xff09; 常见的使用栈的算法题 中缀转后缀逆波兰表达式求…

关注、取关、Redis实现共同关注、 博客推送与分页查询

Resourceprivate StringRedisTemplate stringRedisTemplate;Resourceprivate IUserService userService;Overridepublic Result follow(Long followUserId, Boolean isFollow) {//1.获取登陆的用户Long userId UserHolder.getUser().getId();//1.判断是关注还是取关if(isFollo…

基于Springboot+Vue的小区运动中心预约管理系统的设计与实现 (含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 这段数…

FPGA-UART串口接收模块的理解

UART串口接收模块 背景 在之前就有写过关于串口模块的文章——《串口RS232的学习》。工作后很多项目都会用到串口模块&#xff0c;又来重新理解一下FPGA串口接收的代码思路。 关于串口相关的参数&#xff0c;以及在文章《串口RS232的学习》中已有详细的描述&#xff0c;这里就…

Linux启动mysql报错

甲方公司意外停电&#xff0c;所有服务器重启后&#xff0c;发现部署在Linux上的mysql数据库启动失败.再加上老员工离职&#xff0c;新接手项目&#xff0c;对Linux系统了解不多&#xff0c;解决起来用时较多&#xff0c;特此记录。 1.启动及报错 1.1 启动语句1 启动语句1&a…

Java编程基础(Scanner类==>循环语句)

文章目录 前言一、Scanner类1.创建Scanner对象2.使用3.实践 二、if条件语句1.简单if语句2.if-else语句3.if-else if-else语句3.实践 三、switch 开关语句四、循环语句1.for语句2.while语句3.do-while语句4.break和continue语句 总结 前言 我们发现在学习Java语言编程基础时&am…

【GEE数据库】WRF常用数据集总结

【GEE数据库】WRF常用数据集总结 GEE数据集介绍数据1:MODIS数据集LAI(叶面积指数)和Fpar(绿色植被率)年尺度土地利用类型数据2:月反射率(Monthly Albedo)数据3:LULC和ISA参考GEE数据集介绍 GEE数据搜索网址-A planetary-scale platform for Earth science data &…

(PyTorch) 深度学习框架-介绍篇

前言 在当今科技飞速发展的时代&#xff0c;人工智能尤其是深度学习领域正以惊人的速度改变着我们的世界。从图像识别、语音处理到自然语言处理&#xff0c;深度学习技术在各个领域都取得了显著的成就&#xff0c;为解决复杂的现实问题提供了强大的工具和方法。 PyTorch 是一个…

9.30学习记录(补)

手撕线程池: 1.进程:进程就是运行中的程序 2.线程的最大数量取决于CPU的核数 3.创建线程 thread t1; 在使用多线程时&#xff0c;由于线程是由上至下走的&#xff0c;所以主程序要等待线程全部执行完才能结束否则就会发生报错。通过thread.join()来实现 但是如果在一个比…

08.STL简介

1. 什么是STL STL(standard template libaray-标准模板库)&#xff1a;是C标准库的重要组成部分&#xff0c;不仅是一个可复用的 组件库&#xff0c;而且是一个包罗数据结构与算法的软件框架 2.发展历史 1. 起源与早期探索&#xff08;20世纪80年代初期&#xff09;&#xff…

可变形卷积(Deformable Convolution)是什么?

普通卷积 普通卷积&#xff08;dilation1&#xff09; 普通卷积就是特征图与卷积核的权重W相乘再求和 y(p0​) 表示输出特征图在位置 p0​ 的值。&#x1d465;(&#x1d45d;0&#x1d45d;&#x1d45b;)表示输入特征图在位置 pn​ 的值。&#x1d464;(&#x1d45d;&…

烟火烟雾检测数据集 9600张 烟雾火焰检测 带标注 voc yolo 2类 烟火数据集 烟雾数据集 烟火检测烟雾检测

烟火检测数据集 9600张 烟雾火焰检测 带标注 voc yolo 烟火检测数据集介绍 数据集名称 烟火检测数据集 (Fire and Smoke Detection Dataset) 数据集概述 该数据集专为训练和评估基于YOLO系列目标检测模型&#xff08;包括YOLOv5、YOLOv6、YOLOv7等&#xff09;而设计&#x…

malloc源码分析之 ----- 你想要啥chunk

文章目录 malloc源码分析之 ----- 你想要啥chunktcachefastbinsmall binunsorted binbin处理top malloc源码分析之 ----- 你想要啥chunk tcache malloc源码&#xff0c;这里以glibc-2.29为例&#xff1a; void * __libc_malloc (size_t bytes) {mstate ar_ptr;void *victim;vo…

Qt Quick 3D 入门:QML 3D场景详解

随着 Qt 6 的发布&#xff0c;QtQuick3D 模块带来了新的 3D 渲染和交互能力&#xff0c;使得在 Qt 中创建 3D 场景变得更加简单和直观。本文将带您从一个简单的 QML 3D 应用开始&#xff0c;详细讲解各个相关领域的概念、代码实现以及功能特点。 什么是 Qt Quick 3D&#xff1…

C++拾趣——绘制Console中Check Box

大纲 居中显示窗口清屏并重设光标绘制窗口绘制窗口顶部绘制复选项绘制按钮行绘制窗口底部 修改终端默认行为对方向键的特殊处理过程控制Tab键的处理Enter键的处理上下左右方向键的处理 完整代码代码地址 这次我们要绘制复选框&#xff0c;如下图。 居中显示窗口 按照界面库的…

网约班车升级手机端退票

背景 作为老古董程序员&#xff0c;不&#xff0c;应该叫互联网人员&#xff0c;因为我现在做的所有的事情&#xff0c;都是处于爱好&#xff0c;更多的时间是在和各行各业的朋友聊市场&#xff0c;聊需求&#xff0c;聊怎么通过IT互联网 改变实体行业的现状&#xff0c;准确的…

ExcelToWord-Excel套打Word-Word邮件合并工具分享

Excel to Word转换工具分享 在日常工作或学习中&#xff0c;我们常常需要将Excel中的数据导出到Word文档中&#xff0c;以便更好地展示信息。市场上有许多Excel to Word的转换工具&#xff0c;它们各有特色。今天&#xff0c;我们就来推荐几款这样的工具&#xff0c;并探讨一下…

如何使用虚拟机充当软路由

文章目录 前言下载系统把 iso 转为虚拟机使用 VMware 创建虚拟机 前言 很多人需要软路由&#xff0c;但是软路由需要设备的投入&#xff0c;我这里使用虚拟机充当软路由。省下了设备的投入。不过多花了电费。大家自己取舍吧。 下载系统 ImmortalWrt Firmware Selector 在上…