Conditional Generative Adversarial Nets

news2024/12/21 20:01:30

条件生成对抗网络

1.生成对抗网络

生成对网络由两个“对抗性”模型组成:一个生成模型 G,用于捕获数据分布,另一个判别模型 D,用于估计样本来自训练数据而不是 G 的概率。G 和 D 都可以是非线性映射函数。
为了学习数据 x 上的生成器分布 Pg,生成器构建从先验噪声分布 pz(z) 到数据空间的映射函数 G(z; θg)。判别器 D(x; θd) 输出一个标量,表示 x 来自训练数据而不是 pg 的概率。
G 和 D 都是同时训练的:我们调整 G 的参数以最小化 log(1 − D(G(z)) 并调整 D 的参数以最小化 logD(X),就好像它们遵循两人的最小-最大一样价值函数 V (G, D) 的博弈:
在这里插入图片描述

G(Generator) -> 生成模块
D (Discriminator) -> 鉴别模块(输出就结果可以是二进制也可以是一维的置信度)
在这里插入图片描述

2.条件生成对抗网络

如果生成器和判别器都以一些额外的信息 y 为条件,则生成对抗网络可以扩展到条件模型。y 可以是任何类型的辅助信息,例如类标签或来自其他模态的数据。我们可以通过将 y 作为额外的输入层输入到判别器和生成器中来执行调节。
在生成器中,先验输入噪声 pz(z) 和 y 被组合在联合隐藏表示中,并且对抗性训练框架允许在如何组成该隐藏表示方面具有相当大的灵活性。
在判别器中,x 和 y 作为输入呈现给判别函数(在本例中再次由 MLP 体现)。两人迷你最大游戏的目标函数为:
在这里插入图片描述
在这里插入图片描述

3.判别器损失函数

判别器(Discriminator)
判别器的目标是区分生成器生成的假数据和真实数据。它接受来自生成器的输出或真实数据集的样本作为输入,并输出一个概率值,表示输入样本是真实数据的概率。
生成器(Generator)
生成器(Generator)的损失函数是它在对抗过程中试图最小化的目标。生成器的目标是产生尽可能接近真实数据分布的假数据,以便判别器(Discriminator)难以区分真假数据。
训练过程

  • 初始化:生成器和判别器的参数随机初始化。
  • 对抗训练:
    生成器生成假数据。
    判别器尝试区分真假数据。
    判别器的损失函数是它对真实数据和生成数据的预测误差的总和

生成器的损失函数是它欺骗判别器的成功率,即判别器错误地将生成数据识别为真实数据的概率。
在这里插入图片描述

  • 参数更新:
    判别器根据损失函数更新参数,以更好地区分真假数据。
    生成器根据损失函数更新参数,以生成更逼真的数据,以欺骗判别器。

代码实现

#以fashionMNist
# 损失函数
def d_loss_fn(r_logit, f_logit):
            r_loss = torch.nn.functional.binary_cross_entropy_with_logits(r_logit, torch.ones_like(r_logit))
            f_loss = torch.nn.functional.binary_cross_entropy_with_logits(f_logit, torch.zeros_like(f_logit))
            return r_loss, f_loss

def g_loss_fn(f_logit):
            f_loss = torch.nn.functional.binary_cross_entropy_with_logits(f_logit, torch.ones_like(f_logit))
            return f_loss
# 生成模型
class GeneratorCGAN(nn.Module):

    def __init__(self, z_dim, c_dim, dim=128):
        super(GeneratorCGAN, self).__init__()

        def dconv_bn_relu(in_dim, out_dim, kernel_size=4, stride=2, padding=1, output_padding=0):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, padding, output_padding),
                nn.BatchNorm2d(out_dim),
                nn.ReLU()
            )

        self.ls = nn.Sequential(
            dconv_bn_relu(z_dim + c_dim, dim * 4, 4, 1, 0, 0),  # (N, dim * 4, 4, 4)
            dconv_bn_relu(dim * 4, dim * 2),  # (N, dim * 2, 8, 8)
            dconv_bn_relu(dim * 2, dim),   # (N, dim, 16, 16)
            nn.ConvTranspose2d(dim, 3, 4, 2, padding=1), nn.Tanh()  # (N, 3, 32, 32)
        )

    def forward(self, z, c):
        # z: (N, z_dim), c: (N, c_dim) ->[64, 110]
        x = torch.cat([z, c], 1)
        # [64, 110] -> [64, 3, 32, 32]
        x = self.ls(x.view(x.size(0), x.size(1), 1, 1))
        # print(x.shape)
        # 输出生成的图像结果
        return x


class DiscriminatorCGAN(nn.Module):

    def __init__(self, x_dim, c_dim, dim=96, norm='none', weight_norm='spectral_norm'):
        super(DiscriminatorCGAN, self).__init__()

        norm_fn = _get_norm_fn_2d(norm)
        weight_norm_fn = _get_weight_norm_fn(weight_norm)

        def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):
            return nn.Sequential(
                weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),
                norm_fn(out_dim),
                nn.LeakyReLU(0.2)
            )

        self.ls = nn.Sequential(  # (N, x_dim+c_dim, 32, 32)
            conv_norm_lrelu(x_dim + c_dim, dim),
            conv_norm_lrelu(dim, dim),
            conv_norm_lrelu(dim, dim, stride=2),  # (N, dim , 16, 16)

            conv_norm_lrelu(dim, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2, stride=2),  # (N, dim*2, 8, 8)

            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),  # (N, dim*2, 6, 6)

            nn.AvgPool2d(kernel_size=6),  # (N, dim*2, 1, 1)
            torchlib.Reshape(-1, dim * 2),  # (N, dim*2)
            weight_norm_fn(nn.Linear(dim * 2, 1))  # (N, 1)
        )

    def forward(self, x, c):
        # x: (N, x_dim, 32, 32), c: (N, c_dim)
        # [64, 10] -> [64, 10, 32, 32]
        c = c.view(c.size(0), c.size(1), 1, 1) * torch.ones([c.size(0), c.size(1), x.size(2), x.size(3)], dtype=c.dtype, device=c.device)
        # 常规损失函数 [64, 10, 32, 32] ->[64, 1]
        logit = self.ls(torch.cat([x, c], 1))
        # 输出置信度
        return logit
# model:鉴别器输入维度3:三通道图像,输出维度10:对应类别
D = DiscriminatorCGAN(x_dim=3, c_dim=c_dim)
# 生成器模型:编码维度,输出维度10:对应类别
G = GeneratorCGAN(z_dim=z_dim, c_dim=c_dim)

训练架构

  # 训练鉴别器模型输入与输出
  # 图像
  x = x.to(device)
  # 对应类别
  c_dense = c_dense.to(device)
  # 随机图像
  z = torch.randn(batch_size, z_dim).to(device)
  # 条件标签
  c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
  # 随机数与条件输入生成器生成伪图像
  x_f = G(z, c).detach()
  # 原始图像与条件输入鉴别器计算标签图像分数
  x_gan_logit = D(x, c)  # [batchsize,1]
  # 输入伪图像与条件计算伪图像分数
  x_f_gan_logit = D(x_f, c) # [batchsize,1]
_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
  # 训练生成器模型输入与输出
  z = torch.randn(batch_size, z_dim).to(device)
  # 生成器中计算损失函数
  x_f = G(z, c)
  x_f_gan_logit = D(x_f, c)
  g_gan_loss = g_loss_fn(x_f_gan_logit)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2185204.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

设计模式-生成器模式/建造者模式Builder

构建起模式:将一个复杂类的表示与其构造分离,使得相同的构建过程能够得出不同的表示。(建造者其实和工厂模式差不多) 详细的UML类图 图文说明:距离相同的构建过程 得出不同的展示。此时就用两个类(文本生成…

探索未来:hbmqtt,Python中的AI驱动MQTT

文章目录 **探索未来:hbmqtt,Python中的AI驱动MQTT**1. 背景介绍2. hbmqtt是什么?3. 安装hbmqtt4. 简单的库函数使用方法4.1 连接到MQTT服务器4.2 发布消息4.3 订阅主题4.4 接收消息4.5 断开连接 5. 应用场景示例5.1 智能家居控制5.2 环境监测…

WebGIS之Cesium三维软件开发

目录 第 1 章 三维 WebGIS 概述 1.1 Google Earth 1 1.2 SkylineGlobe 2 1.3 LocaSpace Viewe 2 1.4 Cesium 3 1.5 Cesium API 概要 4 第 2 章 Cesium 快速入门 2.1 Cesium 环境搭建 7 2.1.1 安装 Node.js 环境 7 2.1.2 配置 Cesium 依赖 8 2.2 搭建第一个 Cesi…

【2006.07】UMLS工具——MetaMap原理深度解析

文献:《MetaMap: Mapping Text to the UMLS Metathesaurus》2006 年 7 月 14 日 https://lhncbc.nlm.nih.gov/ii/information/Papers/metamap06.pdf MetaMap:将文本映射到 UMLS 元数据库 总结 解决的问题 自动概念映射问题:解决如何将文本…

Vue3丨进一步了解这 20 个响应式 API,写码如有神

前面说的话 在 Vue2 中,个人觉得对于数据的操作比较 “黑盒” 。而 Vue3 把响应式系统更显式地暴露出来,使得我们对数据的操作有了更多的灵活性。所以,对于 Vue3 的几个响应式的 API ,我们需要更加的理解掌握,才能在实…

【MySQL】子查询、合并查询、表的连接

目录 一、子查询 1、单行子查询 显示SMITH同一部门的员工信息 2、多行子查询 in关键字 查询和10号部门的工作岗位相同的雇员的名字、岗位、工资、部门号,但是筛选出的雇员的部门不能有10号部门 all关键字 查询工资比30号部门中所有雇员工资高的雇员的姓名、…

TS(type,属性修饰符,抽象类,interface)一次性全部总结

目录 1.type 1.基本用法 2.联合类型 3.交叉类型 2.属性修饰符 1.public 属性修饰符 属性的简写形式 2.proteced 属性修饰符 3.private 属性修饰符 4.readonly 属性修饰符 3.抽象类 4.interface 1.定义类结构 2.定义对象结构 3.定义函数结构 4.接口之间的继…

postgresql|数据库|postgis编译完成后的插件迁移应该如何做(postgis插件最终章)

一、 本文的写作理由 postgis插件一般是编译安装,编译安装的原因是可以选择自己喜欢的版本,但编译的难度也是比较高的,因为有各种依赖,依赖之间还有依赖,非常容易形成依赖循环,因此,失败率是比…

【Python】CSVKit:强大的命令行CSV工具套件

CSVKit 是一个基于命令行的工具集,用于简化 CSV 文件的处理和管理。它提供了从数据转换、筛选、格式化到分析的全方位支持,特别适合需要处理复杂表格数据的用户。相比传统的 Excel 操作,CSVKit 更高效且功能更强大,非常适合数据分…

VSOMEIP代码阅读整理(1) - 网卡状态监听

一. 概述 在routing进程所使用的配置文件中,存在如下配置项目:{"unicast" : "192.168.56.101",..."service-discovery" :{"enable" : "true","multicast" : "224.244.224.245",…

线程和进程的关系和区别

目录 进程 概念 特点 生命周期 进程的通信 应用场景 线程 概念 特点 类型 状态 调度 应用场景 线程和进程的关系与区别 关系 区别 总结 僵尸进程 产生原因 解决方法 进程 概念 第一,进程是一个实体。每一个进程都有它自己的地址空间&#xff…

数字通信中不同信道类型对通信系统性能影响matlab仿真分析,对比AWGN,BEC,BSC以及多径信道

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 5.算法完整程序工程 1.算法运行效果图预览 (完整程序运行后无水印) 2.算法运行软件版本 matlab2022a 3.部分核心程序 (完整版代码包含详细中文注释和操作步骤视频&#xff09…

C0013.Clion中利用C++调用opencv打开摄像头

下载opencv https://opencv.org/get-started/ 直接官网下载opencv-4.9.0-windows.exe 安装opencv opencv配置环境变量 如上安装配置完成。

SpringBoot框架下的健康信息管理解决方案

第1章 绪论 1.1背景及意义 随着社会的快速发展,计算机的影响是全面且深入的。人们生活水平的不断提高,日常生活中人们对医院管理方面的要求也在不断提高,由于老龄化人数更是不断增加,使得师生健康信息管理系统的开发成为必需而且紧…

第三批安全可靠评测名单公布,几家欢喜几家忧

9月30号,赶在国庆长假之前,中国信息安全评测中心发布了《安全可靠评测结果公告(2024年第2号)》,测试结果自发布之日起有效期三年。 本期测试分为集中式数据库、分布式数据库和中央处理器三个大类,结果共有14家公司的16个产品入围&…

AI绘画实现数字人2D形象生成及3D数字人视频生成

概述 随着人工智能技术的不断进步,AI绘画已经成为数字艺术创作领域的重要工具。本章将详细介绍如何利用AI绘画技术生成数字人的2D形象,并进一步将其转化为3D数字人视频。通过一系列实践步骤和Python代码示例,您将能够掌握从平台使用到系统部…

计算机毕业设计之:音乐媒体播放及周边产品运营平台(源码+文档+讲解)

博主介绍: ✌我是阿龙,一名专注于Java技术领域的程序员,全网拥有10W粉丝。作为CSDN特邀作者、博客专家、新星计划导师,我在计算机毕业设计开发方面积累了丰富的经验。同时,我也是掘金、华为云、阿里云、InfoQ等平台…

看Threejs好玩示例,学习创新与技术(Noise)

给图像加一点噪声效果,可以起到朦胧背景的效果,比如下面这幅画。 除了普通的图片外,我们可以把这个效果应用到地图或其他方面,比如超过范围不允许用户了解更详细的内容。当然,也可以采用雾Fog效果,但后处理…

鸿蒙ArkUI实战开发-主打自研语言及框架

ArkUI 是 HarmonyOS 的声明式 UI 开发框架,而 ArkUI-X 是基于 ArkUI 框架扩展而来的跨平台开发框架。ArkUI-X 支持 HarmonyOS、OpenHarmony、Android 和 iOS 平台,允许开发者使用一套代码构建支持多平台的应用程序。 一、ArkUI-X 的实战开发步骤 在实战开…

(c++)在堆区创建一个数组并且访问与释放

在堆区创建一个数组,然后利用一个指针指向这个数组的首地址,通过这个指针来访问这个数组。 代码展示了三种赋值的方式: 1.直接利用数组访问赋值 2.利用循环结构(和1原理一样) 3.循环结构键盘输入赋值 然后输出这个…