昇思25天学习打卡营第24天 | Pix2Pix实现图像转换

news2025/1/19 3:18:40

昇思25天学习打卡营第24天 | Pix2Pix实现图像转换

文章目录

  • 昇思25天学习打卡营第24天 | Pix2Pix实现图像转换
    • Pix2Pix模型
      • cGAN
      • CGAN的损失函数
    • 数据
    • 网络构建
      • 生成器
      • 判别器
      • Pix2Pix网络
    • 总结
    • 打卡

Pix2Pix模型

Pix2Pix是基于条件生成对抗网络(cGAN, Condition Generative Adversarial Networks)的一种图像转换模型,可以用于图像到图像的翻译。

cGAN

cGAN的生成器将输入图片作为指导信息,由输入图片不断生成“假”图像,完成从像素到像素的映射,而传统的GAN则是由随机噪声生成“假”图像。

CGAN的损失函数

cGAN的损失函数为:
L c G A N ( G , D ) = E ( x , y ) [ log ⁡ ( D ( x , y ) ) ] + E ( x , z ) [ log ⁡ ( 1 − D ( x , G ( x , z ) ) ) ) ] L_{cGAN}(G,D)=E_{(x,y)}[\log(D(x,y))]+E_{(x,z)}[\log(1-D(x,G(x,z))))] LcGAN(G,D)=E(x,y)[log(D(x,y))]+E(x,z)[log(1D(x,G(x,z))))]

  • 判别器 D D D需要尽可能区分真实图像和假图像,即使 log ⁡ ( D ( x , y ) ) \log(D(x,y)) log(D(x,y))最大化;
  • 生成器 G G G需要生成最真实的“假”图像 y y y以欺骗 D D D,即使 log ⁡ ( 1 − D ( G ( x , z ) ) ) \log(1-D(G(x,z))) log(1D(G(x,z)))最小化。

则cGAN的目标可以简化为:
arg ⁡ min ⁡ G max ⁡ D L c G A N ( G , D ) \arg\min_G\max_D L_{cGAN}(G,D) argGminDmaxLcGAN(G,D)
pix2pix1

数据

实验使用处理过的外墙(facades)数据

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"

download(url, "./dataset", kind="tar", replace=True)

from mindspore import dataset as ds
import matplotlib.pyplot as plt

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):
    plt.subplot(3, 10, i)
    plt.axis("off")
    plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

在这里插入图片描述

网络构建

生成器

生成器 G G G采用U-Net结构,由两个部分组成:

  • 压缩路径:由卷积和降采样组成;
  • 扩张路径:由卷积和上采样组成。
    pix2pix2
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops

class UNetSkipConnectionBlock(nn.Cell):
    def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,
                 submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):
        super(UNetSkipConnectionBlock, self).__init__()
        down_norm = nn.BatchNorm2d(inner_nc)
        up_norm = nn.BatchNorm2d(outer_nc)
        use_bias = False
        if norm_mode == 'instance':
            down_norm = nn.BatchNorm2d(inner_nc, affine=False)
            up_norm = nn.BatchNorm2d(outer_nc, affine=False)
            use_bias = True
        if in_planes is None:
            in_planes = outer_nc
        down_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,
                              stride=2, padding=1, has_bias=use_bias, pad_mode='pad')
        down_relu = nn.LeakyReLU(alpha)
        up_relu = nn.ReLU()
        if outermost:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, pad_mode='pad')
            down = [down_conv]
            up = [up_relu, up_conv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv]
            up = [up_relu, up_conv, up_norm]
            model = down + up
        else:
            up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,
                                         kernel_size=4, stride=2,
                                         padding=1, has_bias=use_bias, pad_mode='pad')
            down = [down_relu, down_conv, down_norm]
            up = [up_relu, up_conv, up_norm]

            model = down + [submodule] + up
            if dropout:
                model.append(nn.Dropout(p=0.5))
        self.model = nn.SequentialCell(model)
        self.skip_connections = not outermost

    def construct(self, x):
        out = self.model(x)
        if self.skip_connections:
            out = ops.concat((out, x), axis=1)
        return out

class UNetGenerator(nn.Cell):
    def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):
        super(UNetGenerator, self).__init__()
        unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,
                                             norm_mode=norm_mode, innermost=True)
        for _ in range(n_layers - 5):
            unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,
                                                 norm_mode=norm_mode, dropout=dropout)
        unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,
                                             norm_mode=norm_mode)
        self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,
                                             outermost=True, norm_mode=norm_mode)

    def construct(self, x):
        return self.model(x)

判别器

判别器使用PatchGAN结构。

import mindspore.nn as nn

class ConvNormRelu(nn.Cell):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=4,
                 stride=2,
                 alpha=0.2,
                 norm_mode='batch',
                 pad_mode='CONSTANT',
                 use_relu=True,
                 padding=None):
        super(ConvNormRelu, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if not padding:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',
                             has_bias=has_bias, padding=padding)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class Discriminator(nn.Cell):
    def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [
            nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),
            nn.LeakyReLU(alpha)
        ]
        nf_mult = ndf
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * ndf
            layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * ndf
        layers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))
        self.features = nn.SequentialCell(layers)

    def construct(self, x, y):
        x_y = ops.concat((x, y), axis=1)
        output = self.features(x_y)
        return output

Pix2Pix网络

import mindspore.nn as nn
from mindspore.common import initializer as init

g_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'


net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,
                              ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))


net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,
                                  alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():
    if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):
        if init_type == 'normal':
            cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))
        elif init_type == 'xavier':
            cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))
        elif init_type == 'constant':
            cell.weight.set_data(init.initializer(0.001, cell.weight.shape))
        else:
            raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
    elif isinstance(cell, nn.BatchNorm2d):
        cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))
        cell.beta.set_data(init.initializer('zeros', cell.beta.shape))

class Pix2Pix(nn.Cell):
    """Pix2Pix模型网络"""
    def __init__(self, discriminator, generator):
        super(Pix2Pix, self).__init__(auto_prefix=True)
        self.net_discriminator = discriminator
        self.net_generator = generator

    def construct(self, reala):
        fakeb = self.net_generator(reala)
        return fakeb

总结

这一节介绍了Pix2Pix网络,该网络使用cGAN模型实现了图像到图像的映射。相对于传统GAN使用随机噪声来生成图像,cGAN直接使用输入图像来生成新图像。在Pix2Pix的实现中,使用了U-Net结构来保留不同分辨率下的细节信息。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1949357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何在测试中保护用户隐私!

在当今数据驱动的时代,用户隐私保护成为了企业和开发团队关注的焦点。在软件测试过程中,处理真实用户数据时保护隐私尤为重要。本文将介绍如何在测试中保护用户隐私,并提供具体的方案和实战演练。 用户隐私保护的重要性 用户隐私保护不仅是法…

Qt自定义带前后缀图标的PushButton

写在前面 Qt提供QPushButton不满足带前后缀图标的需求,因此考虑自定义实现带前后缀图标的PushButton,方便后续快速使用。 效果如下: 同时可设置前后缀图标和文本之间间隙: 代码实现 通过前文介绍的Qt样式表底层实现 可以得…

linux ftp操作记录

一.ftp 创建用户 passwd: user ftpuser does not exist 如果你遇到 passwd: user ftpuser does not exist 的错误,这意味着系统中不存在名为 ftpuser 的用户。你需要首先确认FTP用户是否是系统用户,还是FTP服务器软件(如Pure-FTPd&#xff…

类和对象:完结

1.再深构造函数 • 之前我们实现构造函数时,初始化成员变量主要使⽤函数体内赋值,构造函数初始化还有⼀种⽅ 式,就是初始化列表,初始化列表的使⽤⽅式是以⼀个冒号开始,接着是⼀个以逗号分隔的数据成 员列表&#xf…

redis的使用场景

1. redis的使用场景 redis使用场景的案例:[1]热点数据的缓存[2]分布式锁[3]短信业务(登录注册时)2. redis实现注册登录功能 代码 在发送验证码时,先判断数据库是否有该手机号,有则发送验证码(此时redis缓存…

基于微信小程序+SpringBoot+Vue的自习室选座与门禁系统(带1w+文档)

基于微信小程序SpringBootVue的自习室选座与门禁系统(带1w文档) 基于微信小程序SpringBootVue的自习室选座与门禁系统(带1w文档) 本课题研究的研学自习室选座与门禁系统让用户在小程序端查看座位,预定座位,支付座位价格,该系统让用户预定座位…

Jmeter三种方式获取数组中多个数据并将其当做下个接口参数入参【附带JSON提取器和CSV格式化】

目录 一、传统方式-JOSN提取器获取接口返回值 1、接口调用获取返回值 2、添加JSON提取器 3、调试程序查看结果 4、添加循环控制器 5、设置count计数器 6、添加请求 7、执行请求 二、CSV参数化 1、将结果写入后置处理程序 2、设置循环处理器 3、添加CSV文件 4、设置…

【机器学习】用Jupyter Notebook实现并探索单变量线性回归的代价函数以及遇到的一些问题

引言 在机器学习中,代价函数(Cost Function)是一个用于衡量模型预测值与实际值之间差异的函数。在监督学习中,代价函数是评估模型性能的关键工具,它可以帮助我们了解模型在训练数据上的表现,并通过优化过程…

IPD推行成功的核心要素(十五)项目管理提升IPD相关项目交付效率和用户体验

研发项目往往包含很多复杂的流程和具体的细节。因此,一套完整且标准的研发项目管理制度和流程对项目的推进至关重要。研发项目管理是成功推动创新和技术发展的关键因素。然而在实际管理中,研发项目管理常常面临着需求不确定、技术风险、人员素质、成本和…

PyTorch安装CUDA标准流程(可解决大部分GPU无法使用问题)

最近一段时间在研究PyTorch中的GPU的使用方法,之前曾经安装过CUDA,不过在PyTorch中调用CUDA时无法使用。考虑到是版本不兼容问题,卸载后尝试了其他的版本,依旧没有能解决问题,指导查阅了很多资料后才找到了解决方案。 …

uni-app声生命周期

应用的生命周期函数在App.vue页面 onLaunch:当uni-app初始化完成时触发(全局触发一次) onShow:当uni-app启动,或从后台进入前台时显示 onHide:当uni-app从前台进入后台 onError:当uni-app报错时触发,异常信息为err 页面的生命周期 onLoad…

数据治理之“财务一张表”

前言 信息技术的发展,伴随企业业务系统的纷纷建设,提升业务处理效率的同时,也将企业的整体主价值链流程分成了一段一段的业务子流程,很多情况下存在数据上报延迟、业务协作不顺畅、计划反馈不及时、库存积压占资多……都可以从数据…

20240725java的Controller、DAO、DO、Mapper、Service层、反射、AOP注解等内容的学习

在Java开发中,‌controller、‌dao、‌do、‌mapper等概念通常与MVC(‌Model-View-Controller)‌架构和分层设计相关。‌这些概念各自承担着不同的职责,‌共同协作以构建和运行一个应用程序。‌以下是这些概念的解释:‌…

深度学习趋同性的量化探索:以多模态学习与联合嵌入为例

深度学习趋同性的量化探索:以多模态学习与联合嵌入为例 参考文献 据说是2024年最好的人工智能论文,是否有划时代的意义? [2405.07987] The Platonic Representation Hypothesis (arxiv.org) ​arxiv.org/abs/2405.07987 趋同性的量化表达 …

OAK相机支持的图像传感器有哪些?

相机支持的传感器 在 RVC2 上,固件必须具有传感器配置才能支持给定的相机传感器。目前,我们支持下面列出的相机传感器的开箱即用(固件中)传感器配置。 名称 分辨率 传感器类型 尺寸 最大 帧率 IMX378 40563040 彩色 1/2.…

产品经理-简历的筛选标准(22)

什么是简历 简要地描述过往的经历—一份简历的核心要素就是介绍你所经历过的事情 因此准备简历的关键是“简”“要”二字: 一方面是你挑选出来的事情,一定是重要的、能给予你所谋求的位置提供竞争力的事情;另一方面是在描述这件事情的时候 要…

《Java初阶数据结构》----6.<优先级队列之PriorityQueue底层:堆>

前言 大家好,我目前在学习java。之前也学了一段时间,但是没有发布博客。时间过的真的很快。我会利用好这个暑假,来复习之前学过的内容,并整理好之前写过的博客进行发布。如果博客中有错误或者没有读懂的地方。热烈欢迎大家在评论区…

Golang | Leetcode Golang题解之第290题单词规律

题目: 题解: func wordPattern(pattern string, s string) bool {word2ch : map[string]byte{}ch2word : map[byte]string{}words : strings.Split(s, " ")if len(pattern) ! len(words) {return false}for i, word : range words {ch : patt…

【Python实战因果推断】56_因果推理概论6

目录 Causal Quantities: An Example Bias Causal Quantities: An Example 让我们看看在我们的商业问题中,你如何定义这些量。首先,你要注意到,你永远无法知道价格削减(即促销活动)对某个特定商家的确切影响&#xf…

算法 定长按组翻转链表

一、题目 已知一个链表的头部head,每k个结点为一组,按组翻转。要求返回翻转后的头部 k是一个正整数,它的值小于等于链表长度。如果节点总数不是k的整数倍,则剩余的结点保留原来的顺序。示例如下: (要求不…