昇思25天学习打卡营第18天|Pix2Pix实现图像转换

news2025/2/21 19:24:12

Pix2Pix概述

Pix2Pix是基于条件生成对抗网络实现的一种深度学习图像转换模型。Pix2Pix是将cGAN应用于有监督的图像到图像翻译,包括生成器和判别器。

基础原理

cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。

CGAN的目标损失函数为:

L_{cGAN}(G,D)=E_{(x,y)}[log(D(x,y))]+E_{(x,z)}[log(1-D(x,G(x,z)))]

目标函数是使判别器的损失最大化,而生成器的损失最小化。

pix2pix1

数据准备

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"

download(url, "./dataset", kind="tar", replace=True)

from mindspore import dataset as ds
import matplotlib.pyplot as plt

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

创建网络

生成器G结构

使用U-Net,它分为两个部分,其中左侧是由卷积和降采样操作组成的压缩路径,右侧是由卷积和上采样组成的扩张路径,扩张的每个网络块的输入由上一层上采样的特征和压缩路径部分的特征拼接而成。

pix2pix2

定义UNet Skip Connection Block

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out

基于UNet的生成器

class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)

基于PatchGAN的判别器

生成的矩阵中的每个点代表原图的一小块区域(patch)。通过矩阵中的各个值来判断原图中对应每个Patch的真假。

import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

Pix2Pix的生成器和判别器初始化

实例化Pix2Pix生成器和判别器

import mindspore.nn as nn
from mindspore.common import initializer as init

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix模型网络"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb

训练

包括训练判别器和生成器。训练判别器的目的是最大程度地提高判别图像真伪的概率。训练生成器是希望能产生更好的虚假图像。

代码实现:

import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forword_dis(reala, realb):
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forword_gan(reala, realb):
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        delta = (end_time - start_time).microseconds
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

推理

from mindspore import load_checkpoint, load_param_into_net

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

总结

Pix2Pix作为GAN的一种变体,再生成图像和扩充数据方面有着重要作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1899756.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【三】ubuntu24虚拟机集群配置免密登陆

文章目录 环境背景1. 配置域名映射2. 配置免密登录2.1 在每台机器上生成SSH密钥对:2.2 将公钥分发到其他机器:2.2.1 报错问题2.2.2 修复方法 3. 验证免密登录在 ubuntu1 上:在 ubuntu2 上:在 ubuntu3 上: 测试连接 环境…

应急响应-网站入侵篡改指南Webshell内存马查杀漏洞排查时间分析

查看146天的内存马 方法: 1. 日志 这种地址一般在扫描 还要注意post传参注入 对其进行全局定位 发现有sql注入 我们可以也尝试去sqlmap注入 如果以这种方式注入ua头就会改变 2. 了解自己的中间件,框架,cve,等 因为不知道时间…

linux-5.10.110内核源码分析 - Freescale ls1012a pcie host驱动

1、dts pcie设备树 1.1、pcie设备树 pcie1: pcie3400000 {compatible "fsl,ls1012a-pcie";reg <0x00 0x03400000 0x0 0x00100000 /* controller registers */0x40 0x00000000 0x0 0x00002000>; /* configuration space */reg-names "regs", &…

Linux-DNS

DNS域名解析服务 1.DNS介绍 DNS 是域名系统 (Domain Name System) 的缩写&#xff0c;是因特网的一项核心服务&#xff0c;它作为可以将域名和IP地址相互映射的一个分布式数据库&#xff0c;能够使人更方便的访问互联网&#xff0c;而不用去记住能够被机器直接读取的IP数串。…

乐鑫ESPRESSIF芯片开发简介

乐鑫科技&#xff08;Espressif Systems&#xff0c;通常简称乐鑫或ESPRESSIF&#xff09;是一家全球化的无晶圆厂半导体公司&#xff0c;专注于研发无线通信微控制器单元&#xff08;MCU&#xff09;芯片&#xff0c;特别在物联网&#xff08;IoT&#xff09;领域有着显著的影…

【CentOS 7.6】Linux版本 portainer本地镜像导入docker安装配置教程,不需要魔法拉取!(找不着镜像的来看我)

吐槽 我本来根本不想写这篇博客&#xff0c;但我很不解也有点生气&#xff0c;CSDN这么大没有人把现在需要魔法才能拉取的镜像放上来。 你们都不放&#xff0c;根本不方便。我来上传资源。 portainer-ce-latest.tar Linux/amd64 镜像下载地址&#xff1a; 链接&#xff1a;h…

windows下搭建python+jupyter notebook

一.下载python 下面网址下载python3 https://www.python.org/ 二. 安装jupyter notebook 三. 修改配置 四. 检测是否正常运行

【IT领域新生必看】 Java编程中的重写(Overriding)规则:初学者轻松掌握的全方位指南

文章目录 引言什么是方法重写&#xff08;Overriding&#xff09;&#xff1f;方法重写的基本示例 方法重写的规则1. 方法签名必须相同示例&#xff1a; 2. 返回类型可以是子类型&#xff08;协变返回类型&#xff09;示例&#xff1a; 3. 访问修饰符不能比父类的更严格示例&am…

《C++20设计模式》代理模式

文章目录 一、前言二、实现1、UML类图2、实现 一、前言 这代理模式和装饰器模式很像啊。都是套一层类。&#x1f630; 主要就是功能差别 装饰器&#xff1a; 为了强化原有类的功能。代理模式&#xff1a; 不改变原有功能&#xff0c;只是强化原有类的潜在行为。 我觉的书上有…

spark on k8s两种方式的原理与对比

spark on k8s两种方式的原理与对比 1、spark on k8s 方式 spark-submit可以直接用来向 Kubernetes 集群提交 Spark 应用&#xff0c;提交机制如下&#xff1a; 1、Spark 创建一个在Kubernetes pod中运行的 Spark 驱动程序。 2、驱动程序创建在 Kubernetes Pod 中运行的执行器…

Python创建MySQL数据库

一、使用Docker部署本地MySQL数据库 docker run --restartalways -p 3307:3306 --name mysql -e MYSOL_ROOT_PASSWORDlms123456 -d mysql:8.0.25 参数解析: 用户名:root 密码:lms123456 端口:3307 二、在Pycharm开发工具中配置连接MySQL数据库 三、安装zdppy_mysql pip inst…

《向量数据库指南》——Milvus Cloud索引增强如何提升 RAG Pipeline 效果?

索引增强 1.自动合并块 在建立索引时&#xff0c;分两个粒度搭建&#xff0c;一个是chunk本身&#xff0c;另一个是chunk所在的parent chunk。先搜索更细粒度的chunks&#xff0c;接着采用一种合并的策略——如果前k个子chunk中超过n个chunk属于同一个parent chunk&#xff0c…

centos下编译安装redis最新稳定版

一、目标 编译安装最新版的redis 二、安装步骤 1、redis官方下载页面 Downloads - Redis 2、下载最新版的redis源码包 注&#xff1a;此时的最新稳定版是 redis 7.2.5 wget https://download.redis.io/redis-stable.tar.gz 3、安装编译环境 yum install -y gcc gcc-c …

使用patch-package自动修改node_modules中的内容/打补丁

背景 在使用VuePress搭建个人博客的过程中&#xff0c;我需要使用到一个用来复制代码块的插件uepress-plugin-nuggets-style-copy。 问题&#xff1a;插件可以正常安装&#xff0c;但是启动会报错。通过查看错误信息&#xff0c;定位是插件中的copy.vue文件出现错误&#xff0c…

学习笔记——动态路由——OSPF聚合(汇总)

十一、OSPF聚合(汇总) 1、路由聚合(汇总) 路由汇总是一种重要的思想&#xff0c;在大型的项目中是必须考虑的一个重点事项。随着网络的规模越来越大&#xff0c;网络中的设备所需维护的路由表项也就会越来越多&#xff0c;路由表的规模也就会逐渐变大&#xff0c;而路由表是需…

【智能算法应用】麻雀搜索算法SSA优化Kmeans图像分割

目录 1.算法原理2.数学模型3.结果展示4.参考文献5.代码获取 1.算法原理 【智能算法】麻雀搜索算法&#xff08;SSA&#xff09;原理及实现 2.数学模型 Kmeans是一种无监督的聚类算法,由于参数简洁,时间复杂度低已成功应用于图像分割,取得了良好的分割效果。但传统的 K 均值聚…

45 mysql truncate 的实现

前言 truncate 是一个我们也经常会使用到的命令 其作用类似于 delete from $table; 但是 他会比 delete 块很多&#xff0c;这里我们来看一下 它的实现 delete 的时候会逐行进行处理, 打上 删除标记, 然后 由后台任务 进行数据处理 truncate table 的实现 执行 sql 如下 …

计算机图形学入门24:材质与外观

1.前言 想要得到一个漂亮准确的场景渲染效果&#xff0c;不只需要物理正确的全局照明算法&#xff0c;也要了解现实中各种物体的表面外观和在图形学中的模拟方式。而物体的外观和材质其实就是同一个意思&#xff0c;不同的材质在光照下就会表现出不同的外观&#xff0c;所以外观…

HTTP与HTTPS的主要区别

HTTP&#xff08;超文本传输协议&#xff09;与HTTPS&#xff08;超文本传输安全协议&#xff09;的主要区别在于安全性、数据传输方式、默认使用的端口以及对网站的影响。 一、安全性&#xff1a; HTTP是一种无加密的协议&#xff0c;数据在传输过程中以明文形式发送&#x…