Mindspore框架CycleGAN模型实现图像风格迁移|(一)Cycle神经网络模型构建

news2024/11/14 20:48:42

Mindspore框架:CycleGAN模型实现图像风格迁移算法

  1. Mindspore框架CycleGAN模型实现图像风格迁移|(一)CycleGAN神经网络模型构建
  2. Mindspore框架CycleGAN模型实现图像风格迁移|(二)实例数据集(苹果2橘子)
  3. Mindspore框架CycleGAN模型实现图像风格迁移|(三)损失函数计算
  4. Mindspore框架CycleGAN模型实现图像风格迁移|(四)CycleGAN模型训练
  5. Mindspore框架CycleGAN模型实现图像风格迁移|(五)CycleGAN模型推理与资源下载

CycleGAN神经网络模型构建

1. 关于CycleGAN神经网络模型

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络。
该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。实现图像风格迁移。

经典风格迁移模型Pix2Pix,要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)成对出现的图片是相当困难的。 CycleGAN是一种新的无监督的图像迁移网络。更有可行性。

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下:
在这里插入图片描述
X是一种风格,比如苹果;Y 是另一种风格,比如橘子。
𝐺为将苹果风格X生成橘子风格Y的生成器网络;𝐹为将橘子生成的苹果风格的生成器网络。
𝐷𝑋和 𝐷𝑌为其相应判别器网络。
模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

2.CycleGAN构建

在这里插入图片描述
输入大小为256×256以上的,需要采用9个残差块相连,超参数 n_layers=9 参数控制残差块数。
构建ConvNormReLUResidualBlock子网络,搭建ResNetGenerator实现G、F生成器。

  1. 生成器网络构建
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal

weight_init = Normal(sigma=0.02)

class ConvNormReLU(nn.Cell):
    def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',
                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):
        super(ConvNormReLU, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, padding=padding, weight_init=weight_init)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, weight_init=weight_init)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class ResidualBlock(nn.Cell):
    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)

    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        return x + out


class ResNetGenerator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,
                 pad_mode="CONSTANT"):
        super(ResNetGenerator, self).__init__()
        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)
        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)
        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
        self.residuals = nn.SequentialCell(layers)
        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)
        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)
        if pad_mode == "CONSTANT":
            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',
                                      padding=3, weight_init=weight_init)
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)
            self.conv_out = nn.SequentialCell([pad, conv])

    def construct(self, x):
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.residuals(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        return ops.tanh(output)

# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')

net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

  1. 构建判别器网络

判别器其实是一个二分类网络模型,输出判定该图像为真实图的概率。网络模型使用的是 Patch 大小为 70x70 的 PatchGANs 模型。通过一系列的 Conv2d 、 BatchNorm2d 和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数得到最终概率。

# 定义判别器
class Discriminator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=3, alpha=0.2, norm_mode='instance'):
        super(Discriminator, self).__init__()
        kernel_size = 4
        layers = [nn.Conv2d(input_channel, output_channel, kernel_size, 2, pad_mode='pad', padding=1, weight_init=weight_init),
                  nn.LeakyReLU(alpha)]
        nf_mult = output_channel
        for i in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** i, 8) * output_channel
            layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8) * output_channel
        layers.append(ConvNormReLU(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))
        layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1, weight_init=weight_init))
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

# 判别器初始化
net_d_a = Discriminator()
net_d_a.update_parameters_name('net_d_a.')

net_d_b = Discriminator()
net_d_b.update_parameters_name('net_d_b.')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1932441.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

彻底搞定C指针系列

1.指针是什么&#xff1f; 运行一下代码 #include <stdio.h> int main() { int i 39; printf("%d\n", i); printf("%d\n",&i); return 0; } 运行结果&#xff1a; 39 618561996 运行测试代码&#xff1a; #include <stdio.…

计组 外围设备

外围设备 peripheral device 概述 基本组成部分&#xff1a;存储介质、驱动装置、控制电路 分类&#xff1a;输入设备、输出设备、外存设备、数据通信设备和过程控制设备 磁盘存储设备 磁化元/存储元&#xff1a;记录一个二进制信息位的最小单位 两个稳定的剩磁状态 读写磁…

2.javaWeb_请求和响应的处理(Request,Response)

2.请求和响应的处理 文章目录 2.请求和响应的处理一、动态资源和静态资源javax.servlet(包) 二、Servlet体系1.简介2.HttpServlet3.Servlet生命周期 三、Request对象1.ServletRequest1)ServletRequest主要功能有&#xff1a;2)ServletRequest类的常用方法: 2.HttpServletReques…

InterSystems IRIS Python 使用 DB-API连接 ,比odbc简单

1、下载安装驱动 同步的gitee下载地址&#xff1a;&#xff1a;于光/iris-driver-distribution - Gitee.com pip install intersystems_irispython-version-py3-none-any.whl 2、python代码测试 import iris as dbapi conn dbapi.connect(hostname127.0.0.1, port1972, names…

只讲干货!!自己喜欢的只能靠自己努力才能拿到!!今天拿下:IO流技术介绍

什么是IO 输入(Input) 指的是&#xff1a;可以让程序从外部系统获得数据&#xff08;核心含义是 “读 ” &#xff0c;读取外部数据&#xff09;。 输出(Output) 指的是&#xff1a;程序输出数据给外部系统从而可以操作外部 系统&#xff08;核心含义是“ 写 ” &#xff0c;将…

基于MindSpore实现BERT对话情绪识别

讲解视频&#xff1a; 基于MindSpore实现BERT对话情绪识别

防御笔记第七天(时需更新)

1.防火墙的可靠性&#xff1a; 因为防火墙不仅需要同步配置信息&#xff0c;还需要同步状态信息&#xff08;会话表等&#xff09;&#xff0c;所以防火墙不能像路由器那样单纯靠动态协议来进行切换&#xff0c;还需要用到双击热备技术。 双机---目前双机技术仅仅支持两台防火…

AI 大事件:超级明星 Andrej Karpathy 创立AI教育公司 Eureka Labs

&#x1f9e0; AI 大事件&#xff1a;超级明星 Andrej Karpathy 创立AI教育公司 Eureka Labs 摘要 Andrej Karpathy 作为前 OpenAI 联合创始人、Tesla AI 团队负责人&#xff0c;他的专业性和实力备受瞩目。Karpathy 对 AI 的普及和教育充满热情&#xff0c;从 YouTube 教程到…

一招轻松解决猫毛 最值得买的浮毛空气净化器排名

作为一名6年资深铲屎官&#xff0c;我常常被朋友问到关于宠物空气净化器的各种问题。有的人认为这是个神器&#xff0c;而有的人则认为这完全是花钱买智商税。其实我刚开始对购买宠物空气净化器也持怀疑态度&#xff0c;心想这么多钱花下去真的有效吗&#xff1f;但使用后&…

【Linux】将IDEA项目部署到云服务器上,让其成为后台进程(保姆级教学,满满的干货~~)

目录 部署项目到云服务器什么是部署一、 创建MySQL数据库二、 修改idea配置项三、 数据打包四、 部署云服务器五、开放端口号六 、 验证程序 部署项目到云服务器 什么是部署 ⼯作中涉及到的"环境" 开发环境:开发⼈员写代码⽤的机器.测试环境:测试⼈员测试程序使⽤…

Springboot整合MyBatis实现数据库查询(二)

目录 第一章、准备1.1&#xff09;准备数据库表1.2&#xff09;创建springboot项目&#xff0c;添加依赖1.3&#xff09;使用mybatis逆向工程 第二章、代码开发2.1&#xff09;建包并编写代码2.2&#xff09;application配置文件2.3&#xff09;设置编译位置 第三章、测试访问3…

09 深度推荐模型演化中的“平衡与不平衡“规律

你好&#xff0c;我是大壮。08 讲我们介绍了深度推荐算法中的范式方法&#xff0c;并简单讲解了组合范式推荐方法&#xff0c;其中还提到了多层感知器&#xff08;MLP&#xff09;。因此&#xff0c;这一讲我们就以 MLP 组件为基础&#xff0c;讲解深度学习范式的其他组合推荐方…

Why is LlamaCPP freezing during inference?

题意&#xff1a;为什么LlamaCPP在推理过程中会冻结 问题背景&#xff1a; Im using the following code to try and recieve a response from LlamaCPP, used through the LlamaIndex library. My model is stored locally in a gguf file. Im trying to do inference on the…

【数学建模】——多领域资源优化中的创新应用-六大经典问题解答

目录 题目1&#xff1a;截取条材 题目 1.1问题描述 1.2 数学模型 1.3 求解 1.4 解答 题目2&#xff1a;商店进货销售计划 题目 2.1 问题描述 2.2 数学模型 2.3 求解 2.4 解答 题目3&#xff1a;货船装载问题 题目 3.1问题重述 3.2 数学模型 3.3 求解 3.4 解…

Beelzebub过程记录及工具集

文章目录 靶场搭建靶场测试过程安装dirsearch扫描目录wpscan扫描破解 靶场搭建 https://download.vulnhub.com/beelzebub/Beelzebub.zip 下载解压镜像&#xff0c;从vmware打开。 一键式开机即可。 打开后配置网络。 确保网络可达。 靶场测试过程 首先使用nmap扫描网段的存…

为什么品牌需要做 IP 形象?

品牌做IP形象的原因有多方面&#xff0c;这些原因共同构成了IP形象在品牌建设中的重要性和价值&#xff0c;主要原因有以下几个方面&#xff1a; 增强品牌识别度与记忆点&#xff1a; IP形象作为品牌的视觉符号&#xff0c;具有独特性和辨识性&#xff0c;能够在消费者心中留…

提高自动化测试脚本编写效率 5大关键注意事项

提高自动化测试脚本编写效率能加速测试周期&#xff0c;减少人工错误&#xff0c;提升软件质量&#xff0c;促进项目按时交付&#xff0c;增强团队生产力和项目成功率。而自动化测试脚本编写效率低下&#xff0c;往往会导致测试周期延长&#xff0c;增加项目成本&#xff0c;延…

【C#】已知有三个坐标点:P0、P1、P2,当满足P3和P4连成的一条直线 与 P0和P1连成一条直线平行且长度一致,该如何计算P3、P4?

问题描述 已知有三个坐标点&#xff1a;P0、P1、P2&#xff0c;当满足P3和P4连成的一条直线 与 P0和P1连成一条直线平行且长度一致&#xff0c;该如何计算P3、P4&#xff1f; 解决办法 思路一&#xff1a;斜率及点斜式方程 # 示例坐标 x0, y0 1, 1 # P0坐标 x1, y1 4, 4 # …

MySQL执行状态查看与分析

当mysql出现性能问题时&#xff0c;一般会查看mysql的执行状态&#xff0c;执行命令&#xff1a; show processlist 各列的含义 列名含义id一个标识&#xff0c;你要kill一个语句的时候使用&#xff0c;例如 mysql> kill 207user显示当前用户&#xff0c;如果不是root&…

生信软件27 - 基于python的基因注释数据查询/检索库mygene

1. mygene库简介 MyGene.info提供简单易用的REST Web服务来查询/检索基因注释数据&#xff0c;具有以下特点&#xff1a; mygene技术文档&#xff1a; https://docs.mygene.info/en/latest/ 多物种支持: 包括人、小鼠、大鼠、斑马鱼等多个模式生物&#xff1b; 多数据源聚合…