昇思学习打卡营第31天|深度解密 CycleGAN 图像风格迁移:从草图到线稿的无缝转化

news2024/10/5 5:35:29
1. 简介

        图像风格迁移是计算机视觉领域中的一个热门研究方向,其中 CycleGAN (循环对抗生成网络) 在无监督领域取得了显著的突破。与传统需要成对训练数据的模型如 Pix2Pix 不同,CycleGAN 不需要严格的成对数据,只需两类图片域数据,便可实现图像风格的迁移与互换。

        本篇博文将通过一个实际案例演示如何使用 CycleGAN 实现从草图到目标线稿图的图像风格迁移任务,并详细介绍 CycleGAN 的模型结构、数据处理及训练过程。

2. 模型介绍

        CycleGAN 的核心思想源自 "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks" 论文。该模型在不需要成对示例的情况下,学习将源域 X 的图像转换到目标域 Y。其应用领域包括风格迁移、图像增强和域适应等任务。

2.1 CycleGAN 网络结构

        CycleGAN 由两个 GAN 模型组成,其对称的架构允许在不同的域之间来回转换图像。具体而言,CycleGAN 使用两个生成器(G 和 F)和两个判别器(D_X 和 D_Y),生成器负责将域 X 的图像转换到域 Y,并通过判别器对生成结果进行真假判断。

        模型架构如下:

  1. 生成器:生成器采用 ResNet 结构,由 9 个残差块组成,适合处理 256x256 尺寸的图片。
  2. 判别器:判别器通过 PatchGAN 模型检测图像的真实性,以保证生成的图像足够逼真。
2.2 循环一致性损失

        CycleGAN 通过 循环一致性损失 来保证从域 X 到域 Y,再从域 Y 转换回域 X 的图像应尽可能接近原始图像。这种损失机制确保模型不会丢失重要的图像特征。

3. 数据集

        本案例使用的数据集包含线稿图和草图图像,所有图片大小为 256x256 像素。数据集分为训练集和测试集,训练集包含 25654 张图片,测试集包含约 100 张线稿图片和 116 张草图图片。

4. 模型实现
4.1 生成器模型

        生成器模型基于 ResNet 结构,通过卷积、反卷积及残差块实现图像风格的转换。以下是生成器的代码实现:

import mindspore.nn as nn

class ResidualBlock(nn.Cell):
    def __init__(self, dim):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.SequentialCell(
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, pad_mode="pad"),
            nn.BatchNorm2d(dim),
            nn.ReLU(),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1, pad_mode="pad"),
            nn.BatchNorm2d(dim)
        )

    def construct(self, x):
        return x + self.conv_block(x)

class ResNetGenerator(nn.Cell):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(ResNetGenerator, self).__init__()
        model = [
            nn.Conv2d(input_nc, 64, kernel_size=7, padding=3, pad_mode="pad"),
            nn.BatchNorm2d(64),
            nn.ReLU()
        ]

        # Downsampling
        model += [
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        ]

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(256)]

        # Upsampling
        model += [
            nn.Conv2dTranspose(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2dTranspose(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        ]

        model += [
            nn.Conv2d(64, output_nc, kernel_size=7, padding=3, pad_mode="pad"),
            nn.Tanh()
        ]

        self.model = nn.SequentialCell(model)

    def construct(self, x):
        return self.model(x)
4.2 判别器模型

        判别器基于 PatchGAN 的结构,通过卷积网络将输入图片划分为多个小的 patch,并分别进行真假判别。

class Discriminator(nn.Cell):
    def __init__(self, input_nc, ndf=64):
        super(Discriminator, self).__init__()
        self.model = nn.SequentialCell([
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2),

            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2),

            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=1, padding=1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2),

            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=1)
        ])

    def construct(self, x):
        return self.model(x)
4.3 优化器与损失函数

        CycleGAN 采用对抗性损失和循环一致性损失的组合来训练生成器和判别器。优化器选择了 Adam 优化器,学习率设置为 0.0002。

import mindspore as ms

# 定义损失函数和优化器
gan_loss = nn.BCELoss()
cycle_loss = nn.L1Loss()

optimizer_G = nn.Adam(generator.parameters(), learning_rate=0.0002)
optimizer_D = nn.Adam(discriminator.parameters(), learning_rate=0.0002)
5. 训练与推理

        训练过程中,我们交替训练生成器和判别器。判别器通过真假样本的判别进行训练,而生成器则通过对抗判别和循环一致性进行优化。以下是一个训练步骤的实现:

def train_step(real_A, real_B):
    # 生成器前向计算
    fake_B = generator_A2B(real_A)
    fake_A = generator_B2A(real_B)
    
    # 判别器前向计算
    D_A_loss = gan_loss(discriminator_A(fake_A), Tensor(0)) + gan_loss(discriminator_A(real_A), Tensor(1))
    D_B_loss = gan_loss(discriminator_B(fake_B), Tensor(0)) + gan_loss(discriminator_B(real_B), Tensor(1))
    
    # 生成器损失计算
    cycle_A_loss = cycle_loss(generator_B2A(fake_B), real_A)
    cycle_B_loss = cycle_loss(generator_A2B(fake_A), real_B)
    G_loss = cycle_A_loss + cycle_B_loss + D_A_loss + D_B_loss
    
    optimizer_G.step()
    optimizer_D.step()
    
    return G_loss, D_A_loss, D_B_loss

结语

        通过本次的CycleGAN模型实践,我们深入理解了图像风格迁移的基本原理,特别是在无监督情况下如何实现两个域之间的图像转换。CycleGAN的循环一致性损失在保持图像内容一致性的同时,又能实现风格的转换,这是其在域迁移任务中广泛应用的重要原因。在整个实现过程中,不仅对生成器和判别器的构建有了更清晰的理解,同时也进一步熟悉了损失函数的优化策略。

        这次实验的关键在于让模型具备在没有配对数据的情况下,也能够进行风格转换的能力。虽然实验需要较大的计算资源,但我们通过小规模数据集也能够体验到CycleGAN的强大之处。希望通过这个项目,我们不仅能掌握CycleGAN的基本原理,也能为以后的图像生成和风格迁移任务打下坚实的基础。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2189361.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【redis学习篇1】redis基本常用命令

目录 redis存储数据的模式 常用基本命令 一、set 二、keys pattern keys 字符串当中携带问号 keys 字符串当中携带*号 keys 【^字母】 keys * 三、exists 四、del 五、expire 5.1 ttl命令 5.2key删除策略 5.2.1惰性删除 5.2.2定期删除 六、type key的数据类型…

数据结构--线性表(顺序结构)

1.线性表的定义和基本操作 1.1线性表以及基本逻辑 1.1.1线性表 (1)n(>0)个数据元素的有限序列,记作(a1,a2,...an),其中ai是线性表中的数据元素,n是表的长度。 (2)…

【RabbitMQ】RabbitMQ学习

1. 发送流程 生产者 - connection - channel - 交换机 - 对列- channel - connection - 消费者 2. 工作模式 2.1. 简单模式(点对点) 一个消费者一个生产者,直接进行通信。 2.2. 工作对列模式 多个消费者共同消费消息对列中的消息。同一条…

10其他内容补充

如何生成随机数原理详细分析 文章目录 如何生成随机数原理详细分析原理如果使用相同的随机数种子,得到的随机数序列会是相同的吗示例为什么需要随机数种子 动态内存管理前言malloc函数calloc函数realloc函数free函数 - 避免内存泄漏常见的动态内存错误 原理 说到如何生成一个随…

实现TCP Connect的断线重连机制:策略与实践

🍑个人主页:Jupiter. 🚀 所属专栏:Linux从入门到进阶 欢迎大家点赞收藏评论😊 断线重连机制,它成为确保应用在网络不稳定情况下仍能持续提供服务的关键技术之一。本文旨在深入探讨TCP(传输控制协…

浅聊前后端分离开发和前后端不分离开发模式

1.先聊聊Web开发的开发框架Spring MVC 首先要知道,Spring MVC是Web开发领域的一个知名框架,可以开发基于请求-响应模式的Web应用。而Web开发的本质是遵循HTTP(Hyper Text Transfer Protocol: 超文本传输协议)协议【发请求&#xf…

仿RabbitMQ实现消息队列客户端

文章目录 客⼾端模块实现订阅者模块信道管理模块异步⼯作线程实现连接管理模块生产者客户端消费者客户端 客⼾端模块实现 在RabbitMQ中,提供服务的是信道,因此在客⼾端的实现中,弱化了Client客⼾端的概念,也就是说在RabbitMQ中并…

V2M2引擎源码BlueCodePXL源码完整版

V2M2引擎源码BlueCodePXL源码完整版 链接: https://pan.baidu.com/s/1ifcTHAxcbD2CyY7gDWRVzQ?pwdmt4g 提取码: mt4g 参考资料:BlueCodePXL源码完整版_1234FCOM专注游戏工具及源码例子分享

图解大模型计算加速系列:vLLM源码解析3,块管理器(BlockManager)上篇

vllm块管理器又分成朴素块管理器(UncachedBlockAllocator)和prefix caching型块管理器(CachedBlockAllocator)。本篇我们先讲比较简单的前者,下篇我们来细看更有趣也是更难的后者。 【全文目录如下】 【1】前情提要…

阿里巴巴开源的FastJson 1反序列化漏洞复现攻击保姆级教程

免责申明 本文仅是用于学习检测自己搭建的靶场环境有关FastJson1反序列化漏洞的原理和攻击实验,请勿用在非法途径上,若将其用于非法目的,所造成的一切后果由您自行承担,产生的一切风险和后果与笔者无关;本文开始前请认真详细学习《‌中华人民共和国网络安全法》‌及其所在…

Linux高级编程_29_信号

文章目录 进程间通讯 - 信号信号完整的信号周期信号的编号信号的产生发送信号1 kill 函数(他杀)作用:语法:示例: 2 raise函数(自杀)作用:示例: 3 abort函数(自杀)作用:语法:示例: 4 …

macos安装git并连接gitCode远程仓库

文章目录 资料地址下载和安装初始化配置本地全局配置,SSH公私密钥生成远程SSH key配置 新建代码仓库,并关联到本地 资料地址 git官网地址gitCode地址 下载和安装 打开git官网地址,直接下载。【不建议使用brew,因为本人实践&…

【Qt】控件概述(2)—— 按钮类控件

控件概述(2) 1. PushButton2. RadioButton——单选按钮2.1 使用2.2 区分信号 clicked,clicked(bool),pressed,released,toggled(bool)2.3 QButtonGroup分组 3. CheckBox——复选按钮 1. PushButton QPushB…

B树系列解析

我最近开了几个专栏,诚信互三! > |||《算法专栏》::刷题教程来自网站《代码随想录》。||| > |||《C专栏》::记录我学习C的经历,看完你一定会有收获。||| > |||《Linux专栏》&#xff1…

etcd 快速入门

简介 随着go与kubernetes的大热,etcd作为一个基于go编写的分布式键值存储,逐渐为开发者所熟知,尤其是其还作为kubernetes的数据存储仓库,更是引起广泛专注。 本文我们就来聊一聊etcd到底是什么及其工作机制。 首先,…

查找回收站里隐藏的文件

在Windows里,每个磁盘分区都有一个隐藏的回收站Recycle, 回收站里保存着用户删除的文件、图片、视频等数据,比如,C盘的回收站为C:\RECYCLE.BIN\,D盘的的回收站为D:\RECYCLE.BIN\,E盘的的回收站为E:\RECYCLE…

【解决方案】JVM调优:给定资源条件下减少Full GC频率

1 缘起 在一次其他团队技术分享时,有幸进行了旁听, 谈到一个应用场景,服务端在给定的资源下,频繁Full GC, 降低了服务请求处理能力以及任务处理能力,频繁Full GC,导致服务处理能力下降, 服务在Full GC期间无法处理用户请求以及其他任务,服务不稳定,可以理解为服务在…

【C++算法】9.双指针_四数之和

文章目录 题目链接:题目描述:解法C 算法代码:图解 题目链接: 18.四数之和 题目描述: 解法 解法一:排序暴力枚举利用set去重 解法二:排序双指针 从左往右依次固定一个数a在a后面的区间里&#x…

坐标系变换总结

二维情况下的转换 1 缩放变换 形象理解就是图像在x方向和y方向上放大或者缩小。 代数形式: { x ′ k x x y ′ k y y \begin{cases} x k_x x \\ y k_y y \end{cases} {x′kx​xy′ky​y​ 矩阵形式: ( x ′ y ′ ) ( k x 0 0 k y ) ( x y ) \be…

【C语言】数据在内存中的存储(万字解析)

文章目录 一、大小端字节序和字节序判断1.案例引入2.什么是大小端字节序3.大小端字节序判断 二、整数在内存中的存储以及相关练习1.整型在内存中的存储2.练习练习1:练习2练习3练习4练习5:练习6 三、浮点数在内存中的存储1.案例引入2.浮点数在内存中的存储…