1、CycleGAN

news2024/11/13 9:47:20

1、CycleGAN

CycleGAN论文链接:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks

CycleGAN 是一种流行的深度学习模型,用于图像到图像的转换任务,且不需要成对的数据。在介绍CycleGAN之前,必须对于传统的GAN模型有了解。

一、关于GAN

GAN是对抗生成网络的缩写,对抗生成网络简单来说由两部分组成,分别是生成器和判别器

  1. 生成器(Generator): 生成器的任务是从随机噪声中生成虚假的数据(例如图片),以欺骗判别器,让判别器认为这些数据是真实的。生成器通过不断优化,提高生成图像的质量,使其越来越接近真实数据的分布。
  2. 判别器(Discriminator): 判别器的任务是区分真实数据和生成器生成的假数据。它通过判断输入的图像是真实的还是生成的来给出一个概率,并通过不断优化来提高识别假图像的能力。
  3. 对抗训练: 生成器和判别器在训练过程中互相对抗。生成器试图生成更逼真的数据来欺骗判别器,而判别器则努力提高识别能力。这个过程可以视为一个博弈过程,最终生成器能够生成与真实数据非常接近的高质量数据。

其实在GAN中,最重要的就是他的损失函数的构建,有了损失函数,就可以进行参数更新,使生成器和判别器的能力不断的提升,使生成器达到以假乱真的程度。

损失函数由生成器损失函数和判别器损失函数组成,对于生成器的损失函数来说,它希望生成出来的图像能够瞒过判别器 ,所以计算损失函数时,将自己的标签全部设置为1,而对于判别器来说,就要将生成器生成的图像判别为假,将真实正确的判别为真,两者求平均作为损失函数值,来进行梯度更新。

在这里插入图片描述

二、CycleGAN

CycleGAN 是一种流行的深度学习模型,用于图像到图像的转换任务,改变图片的风格,换脸等,其次它最大的特点就是需要训练的数据集不需要是成对出现,仅仅需要两个风格图片,就可以完成一张图片的风格转换。

在这里插入图片描述

CycleGAN架构图

对于CycleGAN,生成器中由六个损失值来控制参数更新,而判别器由两个损失值来控制参数更新

源码思路:

生成器参数更新:

  • 首先由生成器先初始化随机向量Input_B,然后通过学习由Generator_B2A生成Generated_A,计算Input_A和Generated_A的Loss值,在通过Generator_A2B 生成Cyclic_B,然后计算一个Cycle_Loss值。
  • 同时在源码中加入Identify_Loss值,逐点计算Loss值(其中就是将Input_B输入到Generator_A2B,本身B就是正确的,通过Generator_A2B生成identify_B,应该和Input_B越相似越好),Identify_Loss值也加入了参数更新Loss值中。

判别器参数更新:

​ 对于判别器,思路与传统的GAN的Loss值计算是一样的,也是计算将正确的预测为正确的loss值,加上生成的预测为错误的loss值做平均的值。但是具体的计算方法采用PatchGAN的计算思路。

局部判断(Patch-based Discrimination):
PatchGAN 将输入图像划分为多个小的感受野区域(patch),而不是对整张图像进行判断。对于每个感受野区域,PatchGAN 判别器会输出一个结果,判断该区域是真实还是生成的。
判别器的输出结果是一个大小为 N× N的矩阵,每个值对应图像中某个小区域的判别结果。

感受野的作用:
PatchGAN 的感受野是判别器看到的图像的局部区域大小,通常比整张图像小得多。感受野大小取决于网络的架构(如卷积核大小和网络层数)。每个感受野对应输出矩阵中的一个值,该值表示该局部区域是否来自真实图像。通过这种局部判断,PatchGAN 只需要判断局部模式是否与真实数据一致。

损失计算:

  • 预测结果: PatchGAN 的输出是一个 N×N 的矩阵,其中每个元素表示一个局部区域(感受野)的分类结果,即该区域是真实图像的概率。
  • 标签设置: 当计算损失时,真实图像的标签通常设置为一个同样大小为 N×N的矩阵,所有值为 1(表示这些区域都应该被判定为真实)。对于生成的假图像,则期望该矩阵的标签为 0。
  • 损失函数: 判别器的损失是基于这些感受野区域的输出值和标签矩阵的差异来计算的,可以使用二元交叉熵或其他适当的损失函数。

在这里插入图片描述

生成器参数更新
  1. 首先由生成器先初始化随机向量Input_B,然后通过学习由Generator_B2A生成Generated_A,计算Input_A和Generated_A的Loss值,在通过Generator_A2B 生成Cyclic_B,然后计算一个Cycle_Loss值。
  2. 同时在源码中加入Identify_Loss值,逐点计算Loss值(其中就是将Input_B输入到Generator_A2B,本身B就是正确的,通过Generator_A2B生成identify_B,应该和Input_B越相似越好),Identify_Loss值也加入了参数更新Loss值中。

源码解析:

class CycleGANModel(BaseModel):
    def backward_G(self):
        """计算生成器 G_A 和 G_B 的损失函数"""

        # 获取参数 lambda_identity, lambda_A, lambda_B
        # lambda_identity 控制身份损失项(identity loss)的权重
        # lambda_A 和 lambda_B 控制循环一致性损失项(cycle loss)的权重
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # 如果 lambda_idt > 0,计算身份损失(identity loss)
        # 身份损失确保生成器在输入与输出相同图像时不会改变输入
        if lambda_idt > 0:
            # 对于生成器 G_A 来说,当输入真实图像 B 时,输出应为 B,即 G_A(B) ≈ B
            # ||G_A(B) - B|| 的损失
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt

            # 对于生成器 G_B 来说,当输入真实图像 A 时,输出应为 A,即 G_B(A) ≈ A
            # ||G_B(A) - A|| 的损失
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            # 如果 lambda_idt = 0,跳过身份损失,设为 0
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # 生成器 G_A 的 GAN 损失:试图让判别器 D_A 认为生成的图像 fake_B 是真实的
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)

        # 生成器 G_B 的 GAN 损失:试图让判别器 D_B 认为生成的图像 fake_A 是真实的
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # 正向循环一致性损失:确保 G_B(G_A(A)) ≈ A,生成器 G_A 和 G_B 生成的图像应能还原到原图像
        # || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A

        # 反向循环一致性损失:确保 G_A(G_B(B)) ≈ B
        # || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # 综合生成器的所有损失,包括 GAN 损失、循环一致性损失和身份损失
        self.loss_G = (self.loss_G_A + self.loss_G_B + 
                       self.loss_cycle_A + self.loss_cycle_B + 
                       self.loss_idt_A + self.loss_idt_B)

        # 计算梯度并进行反向传播,以更新生成器的参数
        self.loss_G.backward()
判别器参数更新

对于判别器,思路与传统的GAN的Loss值计算是一样的,也是计算将正确的预测为正确的loss值,加上生成的预测为错误的loss值做平均的值。但是具体的计算方法采用PatchGAN的计算思路。

局部判断(Patch-based Discrimination):
PatchGAN 将输入图像划分为多个小的感受野区域(patch),而不是对整张图像进行判断。对于每个感受野区域,PatchGAN 判别器会输出一个结果,判断该区域是真实还是生成的。
判别器的输出结果是一个大小为 N× N的矩阵,每个值对应图像中某个小区域的判别结果。

感受野的作用:
PatchGAN 的感受野是判别器看到的图像的局部区域大小,通常比整张图像小得多。感受野大小取决于网络的架构(如卷积核大小和网络层数)。每个感受野对应输出矩阵中的一个值,该值表示该局部区域是否来自真实图像。通过这种局部判断,PatchGAN 只需要判断局部模式是否与真实数据一致。
在这里插入图片描述

损失计算:

  • 预测结果: PatchGAN 的输出是一个 N×N 的矩阵,其中每个元素表示一个局部区域(感受野)的分类结果,即该区域是真实图像的概率。
  • 标签设置: 当计算损失时,真实图像的标签通常设置为一个同样大小为 N×N的矩阵,所有值为 1(表示这些区域都应该被判定为真实)。对于生成的假图像,则期望该矩阵的标签为 0。
  • 损失函数: 判别器的损失是基于这些感受野区域的输出值和标签矩阵的差异来计算的,可以使用二元交叉熵或其他适当的损失函数。
def backward_D_basic(self, netD, real, fake):
    """计算判别器的 GAN 损失

    参数:
        netD (network)      -- 判别器网络 D
        real (tensor array) -- 真实图像
        fake (tensor array) -- 生成器生成的假图像

    返回判别器的损失。
    同时,我们调用 loss_D.backward() 来计算梯度。
    """
    # 计算对真实图像的判别结果
    pred_real = netD(real)
    # 判别器的真实图像损失,期望判别器将真实图像判定为真实 (标签为 True)
    loss_D_real = self.criterionGAN(pred_real, True)

    # 对生成图像的判别结果,使用 .detach() 确保生成图像的梯度不会传回生成器
    pred_fake = netD(fake.detach())
    # 判别器的假图像损失,期望判别器将生成的假图像判定为假 (标签为 False)
    loss_D_fake = self.criterionGAN(pred_fake, False)

    # 将真实图像损失和假图像损失合并,取平均值作为判别器的总损失
    loss_D = (loss_D_real + loss_D_fake) * 0.5

    # 反向传播,计算损失相对于判别器权重的梯度
    loss_D.backward()

    # 返回判别器的总损失
    return loss_D

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2148493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Msf之Python分离免杀

Msf之Python分离免杀 ——XyLin. 成果展示: VT查杀率:8/73 (virustotal.com) 火绒和360可以过掉,但Windows Defender点开就寄掉了 提示:我用360测的时候,免杀过了,但360同时也申报了,估计要不了多久就寄…

《Linux运维总结:基于Ubuntu 22.04操作系统+x86_64架构CPU部署二进制mongodb 7.0.14分片集群》

总结:整理不易,如果对你有帮助,可否点赞关注一下? 更多详细内容请参考:《Linux运维篇:Linux系统运维指南》 一、简介 1、应用场景 当您遇到如下问题时,可以使用分片集群解决: a、 存储容量受单机限制,即磁盘资源遭遇瓶颈。 b、 读写能力受单机限制,可能是CPU、内…

开关磁阻电机(SRM)系统的matlab性能仿真与分析

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 4.1 SRM的基本结构 4.2 SRM的电磁关系 4.3 SRM的输出力矩 5.完整工程文件 1.课题概述 开关磁阻电机(SRM)系统的matlab性能仿真与分析,对比平均转矩vs相电流,转矩脉动vs相电流&a…

Python OpenCV精讲系列 - 高级图像处理技术(九)

💖💖⚡️⚡️专栏:Python OpenCV精讲⚡️⚡️💖💖 本专栏聚焦于Python结合OpenCV库进行计算机视觉开发的专业教程。通过系统化的课程设计,从基础概念入手,逐步深入到图像处理、特征检测、物体识…

JavaWeb---纯小白笔记01:JavaWeb概述和Tomcat安装

本次将对WEB开发的相关的概念和Tomcat等进行介绍。 Web开发简介: C/S和B/S是两种常用的网络架构模式 区别: C/S:client/server --客户端与服务器之间直接进行通信,对用户,本地电脑要求高 B/S:browser/server--通过…

人工智能-大语言模型-微调技术-LoRA及背后原理简介

1. 《LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS》 LORA: 大型语言模型的低秩适应 摘要: 随着大规模预训练模型的发展,全参数微调变得越来越不可行。本文提出了一种名为LoRA(低秩适应)的方法,通过在Transf…

K8S容器实例Pod安装curl-vim-telnet工具

在没有域名的情况下,有时候需要调试接口等需要此工具 安装curl、telnet、vim等 直接使用 apk add curlapk add vimapk add tennet

Angular: ‘ng’ is not recognized as an internal or external command

背景 运行新项目的前端angular2项目时,需要全局安装angular-cli,然后使用ng serve --open命令启动项目。我安装好angular-cli后,在cmd里输入ng命令,死活无法识别。 解决过程 我按照网上的说法,去配置npm环境变量&am…

软考高级:数据库规范化: 1NF、2NF、3NF和 BCNF AI 解读

数据库的规范化是数据库设计中的一个重要过程,旨在减少数据冗余和提高数据一致性。它通过一系列规则(称为范式)来优化数据库表的结构。 常见的范式有1NF、2NF、3NF和BCNF。让我们分别来解释这些范式。 生活化例子 想象你在整理一个家庭成…

吴泳铭:AI最大的想象力不在手机屏幕,而是改变物理世界

刚刚,阿里巴巴集团CEO、阿里云智能集团董事长兼CEO吴泳铭在2024云栖大会上发表主题演讲—— “ 过去22个月,AI发展速度超过任何历史时期,但我们依然还处于AGI变革的早期。生成式AI最大的想象力,绝不是在手机屏幕上做一两个新的超…

【论文阅读】Slim Fly: A Cost Effective Low-Diameter Network Topology 一种经济高效的小直径网络拓扑

文章目录 Slim Fly: A Cost Effective Low-Diameter Network Topology文章总结1. 摘要2. indroduction3. 主要工作 主要思想references Slim Fly: A Cost Effective Low-Diameter Network Topology Slim Fly:一种经济高效的小直径网络拓扑 SC’14 Maciej Besta 苏…

毕业设计选题:基于ssm+vue+uniapp的农产品自主供销小程序

开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…

实战OpenCV之图像阈值处理

基础入门 图像阈值处理是一种二值化技术,它基于预设的阈值,可以将图像中的像素分为两大类:一大类是背景,另一大类是前景或目标对象。这个过程涉及将图像中的每个像素值与阈值进行比较,并根据比较结果决定保留原始值还是…

已解决 Termius双击左键复制时,会自动输入Ctrl+C ^C

已解决 Termius双击左键复制时,会自动输入CtrlC ^C 一、问题现象 使用Termius双击左键复制时,会自动输入CtrlC,如图 二、解决办法 查阅了资料,又说是某翻译软件鼠标取词的问题,有说是输入法问题,众说纷纭…

AI免费UI页面生成

https://v0.dev/chat v0 - UI设计 cursor - 编写代码 参考:https://www.youtube.com/watch?vIyIVvAu1KZ4 界面和claude类似,右侧展示效果和代码 https://pagen.so/

【Python常用模块】_cx_Oracle模块详解

课 程 推 荐我 的 个 人 主 页:👉👉 失心疯的个人主页 👈👈入 门 教 程 推 荐 :👉👉 Python零基础入门教程合集 👈👈虚 拟 环 境 搭 建 :👉👉 Python项目虚拟环境(超详细讲解) 👈👈PyQt5 系 列 教 程:👉👉 Python GUI(PyQt5)教程合集 👈👈…

【代码随想录训练营第42期 Day61打卡 - 图论Part11 - Floyd 算法与A * 算法

目录 一、Floyd算法与A * 算法 1、Floyd算法 思想 伪代码 2、 A * 算法 思想 伪代码 二、经典题目 题目一:卡码网 97. 小明逛公园 题目链接 题解:Floyd 算法 题目二:卡码网 127. 骑士的攻击 题目链接 题解:A * 算法&a…

基于java的工费医疗报销管理系统设计与实现

博主介绍:专注于Java vue .net php phython 小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设,从业十五余年开发设计教学工作 ☆☆☆ 精彩专栏推荐订阅☆☆☆☆☆不然下次找不到哟 我的博客空间发布了1000毕设题目 方便大家学习使用 感兴趣的…

单细胞BisqueRNA和BayesPrism去卷积分析工具简单比较

曾老师发来了一个工具,BisqueRNA,这个工具也可以用于单细胞/bulk数据的反卷积~ 因此本次就对这两个工具简单测评一下 ~ 生信菜鸟团:https://mp.weixin.qq.com/s/3dZQxDdY6M1WwMMcus5Gkg 笔者也曾经写过一个推文简单的介绍过,有…

C++的初阶模板和STL

C的初阶模板和STL 回顾之前的内存管理,我们还要补充一个概念:内存池 也就是定位new会用到的场景,内存池只会去开辟空间。 申请内存也就是去找堆,一个程序中会有很多地方要去找堆,这样子效率会很低下,为了…