InfoGAN:通过信息最大化生成对抗网络进行可解释的表示学习

news2024/12/23 14:11:39

系列文章目录

一 Conditional Generative Adversarial Nets
二 cGANs with Projection Discriminator
三 Conditional Image Synthesis with Auxiliary Classifier GANs
四 InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets


  • 系列文章目录
  • InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets
    • 1. 生成对抗网络
    • 2.用于引导潜在代码的互信息
      • 模式1
    • 鉴别器与Q同时训练
      • 模式二

InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets

InfoGAN是生成对抗网络的信息论扩展,能够以完全无监督的方式学习解开的表示。InfoGAN 是一个生成对抗网络,它最大化了一小部分潜在变量和观察之间的互信息。
InfoGAN 在 GAN 的基础上只增加了可以忽略不计的计算成本,并且易于训练。使用互信息诱导表示的核心思想可以应用于 VAE等其他方法。实验结果表示:互信息成本增强的生成模型可能是学习解缠结表示的有效方法

1. 生成对抗网络

生成对抗网络(GAN)是一种使用最大最小化博弈训练深度生成模型的框架。目标是学习与真实数据分布 Pdata(x) 匹配的生成器分布 PG(x)。 GAN 没有尝试显式地为数据分布中的每个 x 分配概率,而是学习生成器网络 G,该生成器网络 G 通过将噪声变量 z ∼ Pnoise(z) 转换为样本 G(z),从生成器分布 PG 生成样本。该生成器通过与对抗性判别器网络 D 进行训练,该网络旨在区分真实数据分布 Pdata 和生成器分布 PG 中的样本。因此,对于给定的生成器,最佳判别器是 D(x) = Pdata(x)/(Pdata(x) + PG(x))。更正式地说,极小极大博弈由以下表达式给出:
在这里插入图片描述

2.用于引导潜在代码的互信息

GAN 公式使用简单分解的连续输入噪声向量 z,同时对生成器使用该噪声的方式没有限制。因此,噪声可能会被生成器以高度纠缠的方式使用,导致 z 的各个维度与数据的语义特征不对应
然而,许多领域自然地分解为一组语义上有意义的变化因素。例如,当从 MNIST 数据集生成图像时,**如果模型自动选择分配一个离散随机变量来表示数字 (0-9) 的数字身份,并选择两个数字角度与粗细作为连续变量来表示。**在这种情况下,这些属性既是独立的又是显着的,如果能够在没有任何监督的情况下恢复这些概念,通过简单地指定 MNIST 数字是由独立的 1-of-10 变量和两个独立的连续变量生成的,那将会非常有用变量。
论文建议将输入噪声向量分解为两部分,而不是使用单个非结构化噪声向量:(i)z:它被视为不可压缩噪声源; (ii) c,潜在代码: 针对数据分布的显着结构化语义特征

  • 实现逻辑:
    为生成器网络提供不可压缩噪声 z 和潜在代码 c,因此生成器的形式变为 G(z, c)。然而,在标准 GAN 中,生成器可以通过找到满足 PG(x|c) = PG(x) 的解决方案来自由忽略额外的潜在代码 c。为了解决平凡码的问题,我们提出了一种信息论正则化:潜在码 c 和生成器分布 G(z, c) 之间应该存在高互信息。因此 I(c; G(z, c)) 应该很高。

在信息论中,X 和 Y 之间的互信息,I(X; Y ),衡量从随机变量 Y 的知识中学到的关于另一个随机变量 X 的“信息量”。互信息可以表示为两个随机变量的差值熵项:
在这里插入图片描述
这个定义有一个直观的解释:I(X; Y ) 是观察到 Y 时 X 的不确定性的减少。如果 X 和 Y 是独立的,则 I(X; Y ) = 0,因为了解一个变量并不能揭示另一个变量;相反,如果 X 和 Y 通过确定性可逆函数相关,则可以获得最大互信息。这种解释使得很容易制定成本:给定任何 x ∼ PG(x),我们希望 PG(c|x) 具有较小的熵。换句话说,潜在代码c中的信息不应在生成过程中丢失。类似的互信息启发目标之前已经在聚类背景下被考虑过。因此,我们建议解决以下信息正则化极小极大博弈:
在这里插入图片描述
实际上,互信息项 I(c; G(z, c)) 很难直接最大化,因为它需要访问后验 P (c|x)。幸运的是,我们可以通过定义辅助分布 Q(c|x) 来近似 P(c|x) 来获得它的下界:
在这里插入图片描述
在实践中,我们将辅助分布 Q 参数化为神经网络。在大多数实验中,Q 和 D 共享所有卷积层,并且有一个最终的全连接层来输出条件分布 Q(c|x) 的参数,这意味着 InfoGAN 只为 GAN 添加了可以忽略不计的计算成本。我们还观察到,LI (G, Q) 总是比正常的 GAN 目标收敛得更快,因此 InfoGAN 本质上是免费随 GAN 提供。

代码实现

class QInfoGAN2(nn.Module):

    def __init__(self, x_dim, c_dim, dim=96, norm='batch_norm', weight_norm='none'):
        super(QInfoGAN2, self).__init__()

        norm_fn = _get_norm_fn_2d(norm)
        weight_norm_fn = _get_weight_norm_fn(weight_norm)

        def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):
            return nn.Sequential(
                weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),
                norm_fn(out_dim),
                nn.LeakyReLU(0.2)
            )

        self.ls = nn.Sequential(  # (N, x_dim, 32, 32)
            conv_norm_lrelu(x_dim, dim),
            conv_norm_lrelu(dim, dim),
            conv_norm_lrelu(dim, dim, stride=2),  # (N, dim , 16, 16)

            conv_norm_lrelu(dim, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2, stride=2),  # (N, dim*2, 8, 8)

            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),  # (N, dim*2, 6, 6)

            nn.AvgPool2d(kernel_size=6),  # (N, dim*2, 1, 1)
            torchlib.Reshape(-1, dim * 2),  # (N, dim*2)
            nn.Linear(dim * 2, c_dim)  # (N, c_dim)
        )

    def forward(self, x):
        # x: (N, x_dim, 32, 32)
        logit = self.ls(x)
        return logit

class DiscriminatorInfoGAN2(nn.Module):

    def __init__(self, x_dim, dim=96, norm='none', weight_norm='spectral_norm'):
        super(DiscriminatorInfoGAN2, self).__init__()

        norm_fn = _get_norm_fn_2d(norm)
        weight_norm_fn = _get_weight_norm_fn(weight_norm)

        def conv_norm_lrelu(in_dim, out_dim, kernel_size=3, stride=1, padding=1):
            return nn.Sequential(
                weight_norm_fn(nn.Conv2d(in_dim, out_dim, kernel_size, stride, padding)),
                norm_fn(out_dim),
                nn.LeakyReLU(0.2)
            )

        self.ls = nn.Sequential(  # (N, x_dim, 32, 32)
            conv_norm_lrelu(x_dim, dim),
            conv_norm_lrelu(dim, dim),
            conv_norm_lrelu(dim, dim, stride=2),  # (N, dim , 16, 16)

            conv_norm_lrelu(dim, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2),
            conv_norm_lrelu(dim * 2, dim * 2, stride=2),  # (N, dim*2, 8, 8)

            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=3, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),
            conv_norm_lrelu(dim * 2, dim * 2, kernel_size=1, stride=1, padding=0),  # (N, dim*2, 6, 6)

            nn.AvgPool2d(kernel_size=6),  # (N, dim*2, 1, 1)
            torchlib.Reshape(-1, dim * 2),  # (N, dim*2)
            weight_norm_fn(nn.Linear(dim * 2, 1))  # (N, 1)
        )

    def forward(self, x):
        # x: (N, x_dim, 32, 32)
        logit = self.ls(x)
        return logit

class GeneratorCGAN(nn.Module):

    def __init__(self, z_dim, c_dim, dim=128):
        super(GeneratorCGAN, self).__init__()

        def dconv_bn_relu(in_dim, out_dim, kernel_size=4, stride=2, padding=1, output_padding=0):
            return nn.Sequential(
                nn.ConvTranspose2d(in_dim, out_dim, kernel_size, stride, padding, output_padding),
                nn.BatchNorm2d(out_dim),
                nn.ReLU()
            )

        self.ls = nn.Sequential(
            dconv_bn_relu(z_dim + c_dim, dim * 4, 4, 1, 0, 0),  # (N, dim * 4, 4, 4)
            dconv_bn_relu(dim * 4, dim * 2),  # (N, dim * 2, 8, 8)
            dconv_bn_relu(dim * 2, dim),   # (N, dim, 16, 16)
            nn.ConvTranspose2d(dim, 3, 4, 2, padding=1), nn.Tanh()  # (N, 3, 32, 32)
        )

    def forward(self, z, c):
        # z: (N, z_dim), c: (N, c_dim) ->[64, 110]
        x = torch.cat([z, c], 1)
        # [64, 110] -> [64, 3, 32, 32]
        x = self.ls(x.view(x.size(0), x.size(1), 1, 1))
        # print(x.shape)
        return x
# model
D = DiscriminatorInfoGAN2(x_dim=3).to(device)
Q = model.QInfoGAN2(x_dim=3, c_dim=c_dim).to(device)
G = model.GeneratorInfoGAN2(z_dim=z_dim, c_dim=c_dim).to(device)

模式1

鉴别器与Q同时训练



 # train D and Q
 x = x.to(device)
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
z = torch.randn(batch_size, z_dim).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)

x_f = G(z, c).detach()
x_gan_logit = D(x)
x_f_gan_logit = D(x_f)
x_f_c_logit = Q(x_f)

d_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)
gp = model.gradient_penalty(D, x, x_f, mode=gp_mode)
d_loss = d_x_gan_loss + d_x_f_gan_loss + gp * gp_coef
d_q_loss = d_x_f_c_logit
## 训练生成器
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
            c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
            z = torch.randn(batch_size, z_dim).to(device)

            x_f = G(z, c)
            x_f_gan_logit = D(x_f)
            x_f_c_logit = Q(x_f)

            g_gan_loss = g_loss_fn(x_f_gan_logit)
            d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)

模式二

生成器与Q同时进行训练

# train D
x = x.to(device)
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
z = torch.randn(batch_size, z_dim).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
# 输入随机图像与随机条件
x_f = G(z, c).detach()
# 输入图像,鉴别器判断
x_gan_logit = D(x)
# 输入伪图像鉴别器进行判断
x_f_gan_logit = D(x_f)
d_x_gan_loss, d_x_f_gan_loss = d_loss_fn(x_gan_logit, x_f_gan_logit)
# train G and Q
c_dense = torch.tensor(np.random.randint(c_dim, size=[batch_size])).to(device)
c = torch.tensor(np.eye(c_dim)[c_dense.cpu().numpy()], dtype=z.dtype).to(device)
z = torch.randn(batch_size, z_dim).to(device)
x_f = G(z, c)
x_f_gan_logit = D(x_f)
# Q与类别计算损失
 x_f_c_logit = Q(x_f)

g_gan_loss = g_loss_fn(x_f_gan_logit)
d_x_f_c_logit = torch.nn.functional.cross_entropy(x_f_c_logit, c_dense)
 g_loss = g_gan_loss + d_x_f_c_logit

在这里插入图片描述
在这里插入图片描述

源代码链接
参考链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186266.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python 棒棒糖图

结果: import matplotlib.pyplot as plt import matplotlib.ticker as ticker import numpy as npdef lolly_plot(x, y, color_lis, breaks, back_color,title, sub_title):# 获取每个点的颜色colors [assign_color(temperature, breaks, color_lis) for temperatu…

想学道家智慧,误打误撞被儒家引导读《道德经》?这是怎么回事?

想学道家智慧,却误打误撞被儒家引导读《道德经》?这是怎么回事? 原来,这其中的缘由可以追溯到汉代。董仲舒等人在整理文献时,对《老子》进行了修改和补充,形成了所谓的《道德经》。这一版本不仅颠覆了原本…

【玩转 JS 函数式编程_004】1.4 如何应对 JavaScript 的不同版本

本节目录 1.4 如何应对 JavaScript 的不同版本 How do we work with JavaScript?1.4.1. 使用转译工具 Using transpilers1.4.2. 应用在线环境 Working online1.4.3. 测试环境 Testing 1.4 如何应对 JavaScript 的不同版本 How do we work with JavaScript? 上面介绍的语言特…

netty之Netty传输文件、分片发送、断点续传

前言 1:在实际应用中我们经常使用到网盘服务,他们可以高效的上传下载较大文件。那么这些高性能文件传输服务,都需要实现的分片发送、断点续传功能。 2:在Java文件操作中有RandomAccessFile类,他可以支持文件的定位读取…

【递归】13. leetcode 1457. 二叉树中的伪回文路径

1 题目描述 题目链接:二叉树中的伪回文路径 2 解答思路 第一步:挖掘出相同的子问题 (关系到具体函数头的设计) 第二步:只关心具体子问题做了什么 (关系到具体函数体怎么写,是一个宏观的过…

【重学 MySQL】四十九、阿里 MySQL 命名规范及 MySQL8 DDL 的原子化

【重学 MySQL】四十九、阿里 MySQL 命名规范及 MySQL8 DDL 的原子化 阿里 MySQL 命名规范MySQL8 DDL的原子化 阿里 MySQL 命名规范 【强制】表名、字段名必须使用小写字母或数字,禁止出现数字开头,禁止两个下划线中间只出现数字。数据库字段名的修改代价…

使用 Elastic 将 AI 摘要添加到你的网站

作者:来自 Elastic Gustavo Llermaly 我们目前所知道的搜索(搜索栏、结果、过滤器、页面等)已经取得了长足的进步,并实现了多种不同的功能。当我们知道找到所需内容所需的关键字或知道哪些文档包含我们想要的信息时,尤…

云计算SLA响应时间的matlab模拟与仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 用matlab模拟,一个排队理论。输入一堆包,经过buffer(一个或者几个都行)传给server,这些包会在buffer里…

热网无线监测系统/config.aspx接口存在反射性XSS漏洞

漏洞描述 热网无线监测系统/config.aspx接口存在反射性XSS漏洞 漏洞复现 FOFA body"Downloads/HDPrintInstall.rar" || body"skins/login/images/btn_login.jpg" POC IP/config.aspx POC <script>alert(1)</script> 点击确认成功弹窗1

C++继承与菱形继承(一文了解全部继承相关基础知识和面试点!)

目的减少重复代码冗余 Class 子类(派生类) &#xff1a; 继承方式 父类&#xff08;基类&#xff09; 继承方式共有三种&#xff1a;公共、保护、私有 父类的私有成员private无论哪种继承方式都不可以被子类使用 保护protected权限的内容在类内是可以访问&#xff0c;但是在…

kubernetes-强制删除命名空间

一、故障现象 1、删除命名空间卡住、强制删除也卡住 2、其他终端显示命名空间下无资源 二、处理步骤 1、kubectl get namespace cilium-test -o json > temp.json 获取你需要删除的命名空间json描述文件。 2、修改finalize字段 3、替换 kubectl replace --raw "/api/v1…

csdn加目录,标题

最简单的办法是&#xff0c; 目录 先写文章&#xff0c; 在给需要加标题的&#xff0c; 给他添加格式为标题一&#xff0c; 在点目录 先写文章&#xff0c; 在给需要加标题的&#xff0c; 给他添加格式为标题一&#xff0c; 最后点目录就会生成目录在文章前面了&#x…

AMD GPU推理:三步让你了解AI推理的游戏规则

我们常听到“AI推理”这个词,但很多朋友可能只了解个大概。如果你对AI有基本的认识,可能会好奇:AMD的GPU怎么在推理中大展拳脚?它跟常见的NVIDIA相比又如何?今天,我们就来聊聊这个问题。你可能在考虑选购GPU或者在项目中使用AMD,本文将带你从实际应用出发,一步步拆解。…

云原生(四十一)| 阿里云ECS服务器介绍

文章目录 阿里云ECS服务器介绍 一、云计算概述 二、什么是公有云 三、公有云优缺点 1、优点 2、缺点 四、公有云品牌 五、市场占有率 六、阿里云ECS概述 七、阿里云ECS特点 阿里云ECS服务器介绍 一、云计算概述 云计算是一种按使用量付费的模式&#xff0c;这种模式…

Stable Diffusion绘画 | 来训练属于自己的模型:炼丹参数调整--步数设置与计算

要想训练一个优质的模型&#xff0c;一定要认识和了解模型训练中&#xff0c;参数的作用和意义。 整个模型训练的过程&#xff0c;参数并不是一成不变的&#xff0c;也没有固定的模板&#xff0c; 当我们修改了模型训练里面的某个参数&#xff0c;很可能就需要连带其他一系列…

USB外设的Device与Host的差异

USB&#xff0c;Universal Serial Bus。 USB协会&#xff1a; USB-IF协会认证&#xff1a;USB IF全称USB Implementers Forum&#xff0c;是由一群开发通用串行总线规范的公司创立的非营利性组织。USB-IF组织的成立旨在推广通用串行总线技术并提供相应的技术规范&#xff0c;…

手机使用指南:如何在没有备份的情况下从 Android 设备恢复已删除的联系人

在本指南中&#xff0c;您将了解如何从 Android 手机内存中恢复已删除的联系人。Android 诞生、见证并征服了 80% 的智能手机行业。有些人可能将此称为“非常大胆的宣言”&#xff0c;但最近的统计数据完全支持我们的说法。灵活性、高度改进的可用性和快速性是 Android 操作系统…

【C语言】内存函数的使用和模拟实现

文章目录 一、memcpy的使用和模拟实现二、memmove的使用和模拟实现三、memset的使用四、memcmp的使用 一、memcpy的使用和模拟实现 在之前我们学习了使用和模拟实现strncpy函数&#xff0c;它是一个字符串函数&#xff0c;用来按照给定的字节个数来拷贝字符串&#xff0c;那么问…

“衣依”服装销售平台:Spring Boot框架的设计与实现

3系统分析 3.1可行性分析 通过对本“衣依”服装销售平台实行的目的初步调查和分析&#xff0c;提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本“衣依”服装销售平台采用JAVA作为开发语言&#xff…

TOGAF框架在企业数字化转型中从理论到实践的全面应用指南

数字化转型的背景与意义 随着全球技术的快速发展&#xff0c;数字化已成为现代企业生存和发展的核心驱动力。企业数字化转型不仅意味着引入新技术&#xff0c;还要求在业务模式、组织架构和运营方式上进行深度变革。然而&#xff0c;数字化转型的实施通常面临诸多挑战&#xf…