昇思25天学习打卡营第21天|CycleGAN 图像风格迁移互换

news2024/9/22 5:39:19

今天是参加昇思25天学习打卡营的第21天,今天打卡的课程是“CycleGAN 图像风格迁移互换”,这里做一个简单的分享。

1.简介

从今天开始到第25天的学习内容都是生成式网络的内容。今天要学习的第一个生成式网络是CycleGAN,目标是实现图像风格迁移互换。

CycleGAN(Cycle Generative Adversarial Network) 即循环对抗生成网络,来自论文 Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks 。该模型实现了一种在没有配对示例的情况下学习将图像从源域 X 转换到目标域 Y 的方法。

该模型一个重要应用领域是域迁移(Domain Adaptation),可以通俗地理解为图像风格迁移。其实在 CycleGAN 之前,就已经有了域迁移模型,比如 Pix2Pix ,但是 Pix2Pix 要求训练数据必须是成对的,而现实生活中,要找到两个域(画风)中成对出现的图片是相当困难的,因此 CycleGAN 诞生了,它只需要两种域的数据,而不需要他们有严格对应关系,是一种新的无监督的图像迁移网络。

2.模型架构

  • 模型结构

CycleGAN 网络本质上是由两个镜像对称的 GAN 网络组成,其结构如下图所示(图片来源于原论文):

CycleGAN

为了方便理解,这里以苹果和橘子为例介绍。上图中 𝑋𝑋 可以理解为苹果,𝑌𝑌 为橘子;𝐺𝐺 为将苹果生成橘子风格的生成器,𝐹𝐹 为将橘子生成的苹果风格的生成器,𝐷𝑋𝐷𝑋 和 𝐷𝑌𝐷𝑌 为其相应判别器,具体生成器和判别器的结构可见下文代码。模型最终能够输出两个模型的权重,分别将两种图像的风格进行彼此迁移,生成新的图像。

该模型一个很重要的部分就是损失函数,在所有损失里面循环一致损失(Cycle Consistency Loss)是最重要的。循环损失的计算过程如下图所示(图片来源于原论文):

Cycle Consistency Loss

图中苹果图片 𝑥𝑥 经过生成器 𝐺𝐺 得到伪橘子 𝑌̂ 𝑌^,然后将伪橘子 𝑌̂ 𝑌^ 结果送进生成器 𝐹𝐹 又产生苹果风格的结果 𝑥̂ 𝑥^,最后将生成的苹果风格结果 𝑥̂ 𝑥^ 与原苹果图片 𝑥𝑥 一起计算出循环一致损失,反之亦然。循环损失捕捉了这样的直觉,即如果我们从一个域转换到另一个域,然后再转换回来,我们应该到达我们开始的地方。

  • 生成器构建

本案例生成器的模型结构参考的 ResNet 模型的结构,参考原论文,对于128×128大小的输入图片采用6个残差块相连,图片大小为256×256以上的需要采用9个残差块相连,所以本文网络有9个残差块相连,超参数 n_layers 参数控制残差块数。

生成器的结构如下所示:

CycleGAN Generator

具体的模型结构请参照下文代码:

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal

weight_init = Normal(sigma=0.02)

class ConvNormReLU(nn.Cell):
    def __init__(self, input_channel, out_planes, kernel_size=4, stride=2, alpha=0.2, norm_mode='instance',
                 pad_mode='CONSTANT', use_relu=True, padding=None, transpose=False):
        super(ConvNormReLU, self).__init__()
        norm = nn.BatchNorm2d(out_planes)
        if norm_mode == 'instance':
            norm = nn.BatchNorm2d(out_planes, affine=False)
        has_bias = (norm_mode == 'instance')
        if padding is None:
            padding = (kernel_size - 1) // 2
        if pad_mode == 'CONSTANT':
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='same',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, padding=padding, weight_init=weight_init)
            layers = [conv, norm]
        else:
            paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))
            pad = nn.Pad(paddings=paddings, mode=pad_mode)
            if transpose:
                conv = nn.Conv2dTranspose(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                          has_bias=has_bias, weight_init=weight_init)
            else:
                conv = nn.Conv2d(input_channel, out_planes, kernel_size, stride, pad_mode='pad',
                                 has_bias=has_bias, weight_init=weight_init)
            layers = [pad, conv, norm]
        if use_relu:
            relu = nn.ReLU()
            if alpha > 0:
                relu = nn.LeakyReLU(alpha)
            layers.append(relu)
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output


class ResidualBlock(nn.Cell):
    def __init__(self, dim, norm_mode='instance', dropout=False, pad_mode="CONSTANT"):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode)
        self.conv2 = ConvNormReLU(dim, dim, 3, 1, 0, norm_mode, pad_mode, use_relu=False)
        self.dropout = dropout
        if dropout:
            self.dropout = nn.Dropout(p=0.5)

    def construct(self, x):
        out = self.conv1(x)
        if self.dropout:
            out = self.dropout(out)
        out = self.conv2(out)
        return x + out


class ResNetGenerator(nn.Cell):
    def __init__(self, input_channel=3, output_channel=64, n_layers=9, alpha=0.2, norm_mode='instance', dropout=False,
                 pad_mode="CONSTANT"):
        super(ResNetGenerator, self).__init__()
        self.conv_in = ConvNormReLU(input_channel, output_channel, 7, 1, alpha, norm_mode, pad_mode=pad_mode)
        self.down_1 = ConvNormReLU(output_channel, output_channel * 2, 3, 2, alpha, norm_mode)
        self.down_2 = ConvNormReLU(output_channel * 2, output_channel * 4, 3, 2, alpha, norm_mode)
        layers = [ResidualBlock(output_channel * 4, norm_mode, dropout=dropout, pad_mode=pad_mode)] * n_layers
        self.residuals = nn.SequentialCell(layers)
        self.up_2 = ConvNormReLU(output_channel * 4, output_channel * 2, 3, 2, alpha, norm_mode, transpose=True)
        self.up_1 = ConvNormReLU(output_channel * 2, output_channel, 3, 2, alpha, norm_mode, transpose=True)
        if pad_mode == "CONSTANT":
            self.conv_out = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad',
                                      padding=3, weight_init=weight_init)
        else:
            pad = nn.Pad(paddings=((0, 0), (0, 0), (3, 3), (3, 3)), mode=pad_mode)
            conv = nn.Conv2d(output_channel, 3, kernel_size=7, stride=1, pad_mode='pad', weight_init=weight_init)
            self.conv_out = nn.SequentialCell([pad, conv])

    def construct(self, x):
        x = self.conv_in(x)
        x = self.down_1(x)
        x = self.down_2(x)
        x = self.residuals(x)
        x = self.up_2(x)
        x = self.up_1(x)
        output = self.conv_out(x)
        return ops.tanh(output)

# 实例化生成器
net_rg_a = ResNetGenerator()
net_rg_a.update_parameters_name('net_rg_a.')

net_rg_b = ResNetGenerator()
net_rg_b.update_parameters_name('net_rg_b.')

  • 优化器和损失函数
    在这里插入图片描述
# 构建生成器,判别器优化器
optimizer_rg_a = nn.Adam(net_rg_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_rg_b = nn.Adam(net_rg_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

optimizer_d_a = nn.Adam(net_d_a.trainable_params(), learning_rate=0.0002, beta1=0.5)
optimizer_d_b = nn.Adam(net_d_b.trainable_params(), learning_rate=0.0002, beta1=0.5)

# GAN网络损失函数,这里最后一层不使用sigmoid函数
loss_fn = nn.MSELoss(reduction='mean')
l1_loss = nn.L1Loss("mean")

def gan_loss(predict, target):
    target = ops.ones_like(predict) * target
    loss = loss_fn(predict, target)
    return loss
  • 前向计算

为了减少模型振荡[1],这里遵循 Shrivastava 等人的策略[2],使用生成器生成图像的历史数据而不是生成器生成的最新图像数据来更新鉴别器。这里创建 image_pool 函数,保留了一个图像缓冲区,用于存储生成器生成前的50个图像。

import mindspore as ms

# 前向计算

def generator(img_a, img_b):
    fake_a = net_rg_b(img_b)
    fake_b = net_rg_a(img_a)
    rec_a = net_rg_b(fake_b)
    rec_b = net_rg_a(fake_a)
    identity_a = net_rg_b(img_a)
    identity_b = net_rg_a(img_b)
    return fake_a, fake_b, rec_a, rec_b, identity_a, identity_b

lambda_a = 10.0
lambda_b = 10.0
lambda_idt = 0.5

def generator_forward(img_a, img_b):
    true = Tensor(True, dtype.bool_)
    fake_a, fake_b, rec_a, rec_b, identity_a, identity_b = generator(img_a, img_b)
    loss_g_a = gan_loss(net_d_b(fake_b), true)
    loss_g_b = gan_loss(net_d_a(fake_a), true)
    loss_c_a = l1_loss(rec_a, img_a) * lambda_a
    loss_c_b = l1_loss(rec_b, img_b) * lambda_b
    loss_idt_a = l1_loss(identity_a, img_a) * lambda_a * lambda_idt
    loss_idt_b = l1_loss(identity_b, img_b) * lambda_b * lambda_idt
    loss_g = loss_g_a + loss_g_b + loss_c_a + loss_c_b + loss_idt_a + loss_idt_b
    return fake_a, fake_b, loss_g, loss_g_a, loss_g_b, loss_c_a, loss_c_b, loss_idt_a, loss_idt_b

def generator_forward_grad(img_a, img_b):
    _, _, loss_g, _, _, _, _, _, _ = generator_forward(img_a, img_b)
    return loss_g

def discriminator_forward(img_a, img_b, fake_a, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    loss_d = (loss_d_a + loss_d_b) * 0.5
    return loss_d

def discriminator_forward_a(img_a, fake_a):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_a = net_d_a(fake_a)
    d_img_a = net_d_a(img_a)
    loss_d_a = gan_loss(d_fake_a, false) + gan_loss(d_img_a, true)
    return loss_d_a

def discriminator_forward_b(img_b, fake_b):
    false = Tensor(False, dtype.bool_)
    true = Tensor(True, dtype.bool_)
    d_fake_b = net_d_b(fake_b)
    d_img_b = net_d_b(img_b)
    loss_d_b = gan_loss(d_fake_b, false) + gan_loss(d_img_b, true)
    return loss_d_b

# 保留了一个图像缓冲区,用来存储之前创建的50个图像
pool_size = 50
def image_pool(images):
    num_imgs = 0
    image1 = []
    if isinstance(images, Tensor):
        images = images.asnumpy()
    return_images = []
    for image in images:
        if num_imgs < pool_size:
            num_imgs = num_imgs + 1
            image1.append(image)
            return_images.append(image)
        else:
            if random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, pool_size - 1)

                tmp = image1[random_id].copy()
                image1[random_id] = image
                return_images.append(tmp)

            else:
                return_images.append(image)
    output = Tensor(return_images, ms.float32)
    if output.ndim != 4:
        raise ValueError("img should be 4d, but get shape {}".format(output.shape))
    return output
  • 计算梯度和反向传播

其中梯度计算也是分开不同的模型来进行的:

from mindspore import value_and_grad

# 实例化求梯度的方法
grad_g_a = value_and_grad(generator_forward_grad, None, net_rg_a.trainable_params())
grad_g_b = value_and_grad(generator_forward_grad, None, net_rg_b.trainable_params())

grad_d_a = value_and_grad(discriminator_forward_a, None, net_d_a.trainable_params())
grad_d_b = value_and_grad(discriminator_forward_b, None, net_d_b.trainable_params())

# 计算生成器的梯度,反向传播更新参数
def train_step_g(img_a, img_b):
    net_d_a.set_grad(False)
    net_d_b.set_grad(False)

    fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib = generator_forward(img_a, img_b)

    _, grads_g_a = grad_g_a(img_a, img_b)
    _, grads_g_b = grad_g_b(img_a, img_b)
    optimizer_rg_a(grads_g_a)
    optimizer_rg_b(grads_g_b)

    return fake_a, fake_b, lg, lga, lgb, lca, lcb, lia, lib

# 计算判别器的梯度,反向传播更新参数
def train_step_d(img_a, img_b, fake_a, fake_b):
    net_d_a.set_grad(True)
    net_d_b.set_grad(True)

    loss_d_a, grads_d_a = grad_d_a(img_a, fake_a)
    loss_d_b, grads_d_b = grad_d_b(img_b, fake_b)

    loss_d = (loss_d_a + loss_d_b) * 0.5

    optimizer_d_a(grads_d_a)
    optimizer_d_b(grads_d_b)

    return loss_d

3.小结

CycleGAN是两个镜像对称的GAN网络组成,属于无监督学习类型的模型,不需要准备完全匹配的成对样本,这大大降低了数据准备的难度,也使得模型可以应用的范围更加广泛。通俗的理解,在这个模型中,先通过构建生成网络模型,提取图像的内容和风格特征,并尝试进行风格迁移图像生成,通过定义损失函数来衡量生成图像与内容图像和风格图像的相似度。然后一个的镜像的结构,进行逆向的风格迁移,通过循环一致损失函数来评估一次循环的损失来纠正模型学习的方向。

以上是第21天的学习内容,附上今日打卡记录:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1926357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

苹果将在2025年春季通过iOS更新大幅提升Siri的智能|TodayAI

据彭博社&#xff08;Bloomberg&#xff09;最新报道&#xff0c;苹果公司&#xff08;Apple&#xff09;计划在2025年春季推出其最新的Apple Intelligence AI系统&#xff0c;通过iOS 18.4版本更新为广大iPhone用户带来更智能的Siri。这一消息由科技行业知名记者马克古尔曼在其…

【从0到1进阶Redis】主从复制

1、概念 主从复制&#xff0c;是指将一个台 Redis 服务器的数据&#xff0c;复制到其他的 Redis 服务器。前者称为主节点&#xff08;master/leader&#xff09;&#xff0c;后者称为从节点&#xff08;slave/follower&#xff09;&#xff1b;数据的复制是单向的&#xff0c;…

计算机网络生成树协议介绍与实践

生成树协议 1.环路 二层环路&#xff1a;数据链路层&#xff0c;交换机&#xff08;二层设备&#xff09;通过线路连接环状。即物理成环并且没有开启防环协议。 危害&#xff1a;广播风暴&#xff1a;交换机将未知帧广播&#xff0c;收到后的交换机继续广播&#xff0c;不断…

C语言--递归

曾经有一个段子&#xff1a;上大学时&#xff0c;我们的c语言老师说&#xff1a;学c时&#xff0c;如果有50%的同学死在了循环上面&#xff0c;那么就有90%的同学死在了递归上面。接下来&#xff0c;就来看看递归是怎么个事&#xff1f; 一.递归的介绍 递归是指一个函数直接或…

CV09_深度学习模块之间的缝合教学(4)--调参

深度学习就像炼丹。炉子就是模型&#xff0c;火候就是那些参数&#xff0c;材料就是数据集。 1.1 参数有哪些 调参调参&#xff0c;参数到底是哪些参数&#xff1f; 1.网络相关的参数&#xff1a;&#xff08;1&#xff09;神经网络网络层 &#xff08;2&#xff09;隐藏层…

偶数位的数c++

题目描述 给你两个整数 l,r&#xff0c;求 l∼r 范围内有多少个位数为偶数的数。 输入 一行两个整数 l,r。 输出 输出位数为偶数的数的数量。 样例输入 5 15样例输出 6 提示 样例解释 10,11,12,13,14,15 位数为偶数&#xff0c;都是两位数。 数据规模与约定 对于 1…

过滤器、监听器、拦截器

目录 一、 主要内容 二、 过滤器 2.1 介绍 2.2 实现 MyFilter01 MyFilter02 MyServlet01 MyServlet02 2.3 说明 2.4 执行顺序 1. 使用 web.xml 2. 使用注解和Order 3. 使用FilterRegistrationBean 2.5字符乱码 三、监听器 3.1介绍 3.2实现 3.3在线人数统计 O…

克洛托光电再度合作福晶科技,高精度光学镜头装调仪正式交付

近日&#xff0c;苏州东方克洛托光电技术有限公司&#xff08;下称“克洛托光电”&#xff09;高精度光学镜头装调仪正式交付于福建福晶科技股份有限公司&#xff0c;研发人员在现场完成设备安装调试并介绍使用方法。据悉&#xff0c;这已是双方第二次展开合作。 前沿产品力助推…

企业知识库用不起来?试一下用HelpLook同步钉钉组织架构

提升企业管理和协同效率已成为增强竞争力的关键。企业通过知识管理&#xff0c;搭建内部知识库&#xff0c;将分散的经验和知识转化为系统化流程&#xff0c;减少重复解释&#xff0c;促进业务高效运作。这为企业提供了坚实的基础。 企业知识库面临的挑战 尽管传统知识库内容丰…

银河麒麟高级服务器操作系统V10加固操作指南

1:检查系统openssh安全配置: 2:检查是否设置口令过期前警告天数: 3:检查账户认证失败次数限制: 修改/etc/pam.d/system-auth文件中deny的参数即可 4:检查是否配置SSH方式账户认证失败次数限制:

github相关命令

如果我们要从 GitHub 上拉取一个项目到本地&#xff0c;进行修改并上传回去&#xff0c;通常需要以下步骤&#xff1a; 1. 克隆远程仓库到本地 使用 git clone 命令将 GitHub 上的项目克隆到本地&#xff1a; (网址示例如下所示&#xff09; git clone https://github.com/你的…

多旋翼无人机挂载多功能抛投器技术详解

多旋翼无人机&#xff0c;作为一种具有高效、灵活、稳定等特性的无人驾驶飞行器&#xff0c;在现代社会的多个领域得到了广泛应用。其中&#xff0c;挂载多功能抛投器技术&#xff0c;使得无人机在物资投送、救援等任务中发挥出更加重要的作用。以下将详细介绍多旋翼无人机挂载…

正则表达式怎么控制匹配的字符串更近的一个

http((?!http).)*m3u8 正则表达式怎么控制匹配的字符串更近的一个 正则如何匹配最近的字符 正则如何匹配最近的两个字符 怎么控制只要离字符串b匹配更近一点的字符串a 解释 a.b&#xff0c;它将会匹配最长的以a开始&#xff0c;以b结束的字符串 a.?b匹配最短的&#xff…

测试开发面经总结(三)

TCP三次握手 TCP 是面向连接的协议&#xff0c;所以使用 TCP 前必须先建立连接&#xff0c;而建立连接是通过三次握手来进行的。 一开始&#xff0c;客户端和服务端都处于 CLOSE 状态。先是服务端主动监听某个端口&#xff0c;处于 LISTEN 状态 客户端会随机初始化序号&…

微信小程序密码 显示隐藏 真机兼容问题

之前使用type来控制&#xff0c;发现不行&#xff0c;修改为password属性即可 <van-fieldright-icon"{{passwordType password? closed-eye:eye-o}}"model:value"{{ password }}"password"{{passwordType password ? true: false}}"borde…

数仓工具—Hive语法之排除特定列

排除特定列 Apache Hive是一个基于Hadoop HDFS的数据仓库框架,用于存储和分析大量数据。Apache Hive支持大多数关系数据库功能,如对大型表进行分区和根据分区列存储值。 在本文中,我们将检查从SELECT查询中排除Hive分区列的方法。 这个在我们需要表中大量列的时候,例如一…

【python】OpenCV—European Article Number

参考学习来自&#xff1a;OpenCV基础&#xff08;25&#xff09;条码和二维码扫的生成与识别 1 条形码介绍 EAN-13是欧洲物品编码&#xff08;European Article Number&#xff09;的缩写&#xff0c;是一种广泛使用的条形码标准&#xff0c;特别是在超级市场和其它零售业中。…

第二周周日学习总结

题目总结 1. 给你一个仅由数字组成的字符串 s&#xff0c;在最多交换一次 相邻 且具有相同 奇偶性 的数字后&#xff0c;返回可以得到的 字典序最小的字符串 。 如果两个数字都是奇数或都是偶数&#xff0c;则它们具有相同的奇偶性。例如&#xff0c;5 和 9、2 和 4 奇偶性…

zookeeper基础知识学习

官网&#xff1a;Apache ZooKeeper 下载地址&#xff1a;Index of /dist/zookeeper/zookeeper-3.5.7Index of /dist/zookeeperIndex of /dist/zookeeper/zookeeper-3.5.7 ZK配置参数说明&#xff1a; 1、tickTime2000&#xff1a;通讯心跳时间&#xff0c;zookeeper服务器与客…

Reinforced Causal Explainer for GNN论文笔记

论文&#xff1a;TPAMI 2023 图神经网络的强化因果解释器 论文代码地址&#xff1a;代码 目录 Abstract Introduction PRELIMINARIES Causal Attribution of a Holistic Subgraph​ individual causal effect (ICE)​ *Causal Screening of an Edge Sequence Reinforc…