BigGAN

news2024/11/6 7:19:31

1、BIGGAN 解读

1.1、作者

Andrew Brock、Jeff Donahue、Karen Simonyan

1.2、摘要

尽管最近在生成图像建模方面取得了进展,但从 ImageNet 等复杂数据集中 成功生成高分辨率、多样化的样本仍然是一个难以实现的目标。为此,我们以迄 今为止最大的规模训练生成对抗网络,并研究该规模特有的不稳定性。我们发现, 对生成器应用正交正则化使其易于使用简单的“截断技巧”,通过减少生成器输 入的方差,可以精细控制样本保真度和品种之间的权衡。我们的修改导致模型在 类条件图像合成中设置了新的技术状态。当以 128×128 分辨率在 ImageNet 上训 练时,BigGANs 的 IS 分数为 166.5,FID 分数为 7.4,比之前最好的 IS 为 52.52 和 FID 为 18.65 有所改进。

1.3、模型

GResidualBlock块代码如下:

class GResidualBlock(nn.Module):
    ''' Implements a residual block in BigGAN's generator '''
    def __init__(
        self,
        c_dim: int,
        in_channels: int,
        out_channels: int,
    ):
        super().__init__()

        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))

        self.bn1 = ClassConditionalBatchNorm2d(c_dim, in_channels)
        self.bn2 = ClassConditionalBatchNorm2d(c_dim, out_channels)

        self.activation = nn.ReLU()
        self.upsample_fn = nn.Upsample(scale_factor=2)     # upsample occurs in every gblock

        self.mixin = (in_channels != out_channels)
        if self.mixin:
            self.conv_mixin = nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0))

    def forward(self, x, y):
        # x,y输入给BatchNorm
        h = self.bn1(x, y) # BatchNorm
        h = self.activation(h)# ReLU
        h = self.upsample_fn(h) # Upsample
        h = self.conv1(h)# 3x3Conv

        # x卷积后成h,y输入给BatchNorm
        h = self.bn2(h, y) # BatchNorm
        h = self.activation(h)# ReLU
        h = self.conv2(h)# 3x3Conv

        # x输入给Upsample
        x = self.upsample_fn(x)# Upsample
        if self.mixin:
            x = self.conv_mixin(x)# 1x1Conv

        # 1x1卷积后的x + 经过两次3x3卷积后的x
        return h + x # add

Non-Local Block的代码如下:

# Self-Attention module == Non-Local block
class AttentionBlock(nn.Module):
    ''' Implements a self-attention block from SA-GAN '''
    def __init__(self, channels: int):
        super().__init__()

        self.channels = channels

        self.theta = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=False))
        self.phi = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 8, kernel_size=1, padding=0, bias=False))
        self.g = nn.utils.spectral_norm(nn.Conv2d(channels, channels // 2, kernel_size=1, padding=0, bias=False))
        self.o = nn.utils.spectral_norm(nn.Conv2d(channels // 2, channels, kernel_size=1, padding=0, bias=False))

        self.gamma = nn.Parameter(torch.tensor(0.), requires_grad=True)

    def forward(self, x):
        spatial_size = x.shape[2] * x.shape[3]

        # apply convolutions to get query (theta), key (phi), and value (g) transforms
        theta = self.theta(x)
        phi = F.max_pool2d(self.phi(x), kernel_size=2)
        g = F.max_pool2d(self.g(x), kernel_size=2)

        # reshape spatial size for self-attention
        theta = theta.view(-1, self.channels // 8, spatial_size)
        phi = phi.view(-1, self.channels // 8, spatial_size // 4)
        g = g.view(-1, self.channels // 2, spatial_size // 4)

        # compute dot product attention with query (theta) and key (phi) matrices
        beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), dim=-1)

        # compute scaled dot product attention with value (g) and attention (beta) matrices
        o = self.o(torch.bmm(g, beta.transpose(1, 2)).view(-1, self.channels // 2, x.shape[2], x.shape[3]))

        # apply gain and residual
        return self.gamma * o + x

BigGAN的Generation结构如图所示:

根据上图代码如下:

class Generator(nn.Module):
    ''' Implements the BigGAN generator '''
 
    def __init__(
        self,
        base_channels: int = 96,
        bottom_width: int = 4,# yml里面是2
        z_dim: int = 120,
        shared_dim: int = 128,
        n_classes: int = 1000,
    ):
        super().__init__()
 
        n_chunks = 6    # 5 (generator blocks) + 1 (generator input)
        self.z_chunk_size = z_dim // n_chunks # 120//6 == 20
        self.z_dim = z_dim
        self.shared_dim = shared_dim
        self.bottom_width = bottom_width
        self.n_classes = n_classes
 
        # no spectral normalization on embeddings, which authors observe to cripple the generator
        self.shared_emb = nn.Embedding(n_classes, shared_dim)

        # Linear层 Linear(20,16*96*2**2)
        self.proj_z = nn.Linear(self.z_chunk_size, 16 * base_channels * bottom_width ** 2)
 
        # 不能用一个大nn。连续的,因为我们在每个块上添加class+noise
        self.g_blocks = nn.ModuleList([
            # ResBlock up 16ch → 16ch
            GResidualBlock(shared_dim + self.z_chunk_size, 16 * base_channels, 16 * base_channels),
            # ResBlock up 16ch → 8ch
            GResidualBlock(shared_dim + self.z_chunk_size, 16 * base_channels, 8 * base_channels),
            # ResBlock up 8ch → 4ch
            GResidualBlock(shared_dim + self.z_chunk_size, 8 * base_channels, 4 * base_channels),
            # ResBlock up 4ch → 2ch
            GResidualBlock(shared_dim + self.z_chunk_size, 4 * base_channels, 2 * base_channels),
            # Non-Local Block (64 × 64)
            AttentionBlock(2 * base_channels),
            # ResBlock up 2ch → ch
            GResidualBlock(shared_dim + self.z_chunk_size, 2 * base_channels, base_channels),
        ])
        self.proj_o = nn.Sequential(
            # BN, ReLU, 3 × 3 Conv ch → 3, Tanh
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True),
            nn.utils.spectral_norm(nn.Conv2d(base_channels, 3, kernel_size=1, padding=0)),
            nn.Tanh(),
        )
 
    def forward(self, z, y):
        '''
        z: random noise with size self.z_dim
        y: one-hot class embeddings with size self.shared_dim
        '''
        y = self.shared_emb(y)# class

        # 块z并连接到共享类嵌入
        zs = torch.split(z, self.z_chunk_size, dim=1)
        z = zs[0]
        ys = [torch.cat([y, z], dim=1) for z in zs[1:]] # Split的结果+Class
 
        # project noise and reshape to feed through generator blocks
        h = self.proj_z(z)# Linear层
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
 
        # feed through generator blocks
        idx = 0
        for g_block in self.g_blocks:
            if isinstance(g_block, AttentionBlock):
                h = g_block(h)
            else:
                h = g_block(h, ys[idx])
                idx += 1
 
        # project to 3 RGB channels with tanh to map values to [-1, 1]
        h = self.proj_o(h)
 
        return h

1.4、试验

1.4.1、不同 Batch size 对性能的影响

作者发现简单地将 Batch size 增大就可以实现性能上较好的提升,文章做 了实验验证。在 Batch size 增大到原来 8 倍的时候,生成性能上的 IS 提高 了 46%。文章推测这可能是每批次覆盖更多模式的结果,为生成和判别两个网 络提供更好的梯度。增大 Batch size 还会带来在更少的时间训练出更好性能的 模型,但增大 Batch size 也会使得模型在训练上稳定性下降,后续再分析如何 提高稳定性。

在实验上,单单提高 Batch size 还受到限制,文章在每层的通道数也做了 相应的增加,当通道增加 50%,大约两倍于两个模型中的参数数量。这会导致 IS 进一步提高 21%。文章认为这是由于模型的容量相对于数据集的复杂性而增加。

1.4.2、选择先验分布

z 通过实验对比了 N(0,1)、Bernoulli{0,1}、Censored Normal max(N(0,1), 0),根据参考训练速度、模型性能,文章最终选择了 z∼ N(0,I)。

1.4.3、选择阈值

所谓的“截断技巧”就是通过对从先验分布 z 采样,通过设置阈值的方式 来截断 z 的采样,其中超出范围的值被重新采样以落入该范围内。这个阈值可 以根据生成质量指标 IS 和 FID 决定。 通过实验可以知道通过对阈值的设定,随着阈值的下降生成的质量会越来越 好,但是由于阈值的下降、采样的范围变窄,就会造成生成上取向单一化,造成 生成的多样性不足的问题。往往 IS 可以反应图像的生成质量,FID 则会更假注 重生成的多样性。

1.4.4、尝试控制 G

在探索模型的稳定性上,文章在训练期间监测一系列权重、梯度和损失统计 数据,以寻找可能预示训练崩溃开始的指标。实验发现每个权重矩阵的前三个奇 异值 σ0,σ1,σ2 是最有用的,它们可以使用 Alrnoldi 迭代方法进行有效计 算。

对于奇异值 σ0,大多数 G 层具有良好的光谱规范,但有些层(通常是 G 中 的第一层而非卷积)则表现不佳,光谱规范在整个训练过程中增长,在崩溃时爆 炸。

一顿操作后,文章得出了调节 G 可以改善模型的稳定性,但是无法确保一 直稳定,从而文章转向对 D 的控制。

1.4.5、尝试控制 D

考虑 D 网络的光谱,试图寻找额外的约束来寻求稳定的训练。使用正交正 则化,DropOut 和 L2 的各种正则思想重复该实验,揭示了这些正则化策略的都 有类似行为:对 D 的惩罚足够高,可以实现训练稳定性但是性能成本很高,但 是在图像生成性能上也是下降的,而且降的有点多。

实验还发现 D 在训练期间的损失接近于零,但在崩溃时经历了急剧的向上 跳跃,这种行为的一种可能解释是 D 过度拟合训练集,记忆训练样本而不是学 习真实图像和生成图像之间的一些有意义的边界。

为了评估这一猜测,文章在 ImageNet 训练和验证集上评估判别器,并测量 样本分类为真实或生成的百分比。虽然在训练集下精度始终高于 98%,但验证 准确度在 50-55% 的范围内,这并不比随机猜测更好(无论正则化策略如何)。

这证实了 D 确实记住了训练集,也符合 D 的角色:不断提炼训练数据并为 G 提 供有用的学习信号。 可以通过约束 D 来强制执行稳定性,但这样做会导致性能上的巨大成本。 使用现有技术,通过放松这种调节并允许在训练的后期阶段发生崩溃(人为把握 训练实际),可以实现更好的最终性能,此时模型被充分训练以获得良好的结果。

1.4.6、用分辨率评估模型

在 ImageNet 数据集下做评估,实验在 ImageNet ILSVRC 2012(大家都在 用的 ImageNet 的数据集)上 128×128,256×256 和 512×512 分辨率评估模 型。

1.4.7、验证 G 网络并非是记住训练集

为了进一步说明 G 网络并非是记住训练集,在固定 z 下通过调节条件标签 c 做插值生成,通过下图的实验结果可以发现,整个插值过程是流畅的,也能说 明 G 并非是记住训练集,而是真正做到了图像生成。

1.5、与 GAN 的对比

BigGAN 的主要改进有一下三部分:

(1)通过大规模 GAN 的应用,BigGAN 实现了生成上的巨大突破,参数量 扩大两到四倍,batchsize 扩大八倍;

(2)采用先验分布 z 的“截断技巧”,允许对样本多样性和保真度进行精 细控制;

(3)在大规模 GAN 的实现上不断克服模型训练问题,采用技巧减小训练的 不稳定,但完全的稳定性只能以极高的性能成本实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/374656.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用MUI与H5+构建移动端app

前言 通过mui构建APP 效果图: <!DOCTYPE html> <html> <head><meta charset

c语言指针(结构体)

*结构体&#xff1a;-箭头&#xff08;->&#xff09;&#xff1a;左边必须为指针&#xff1b;-点号&#xff08;.&#xff09;&#xff1a;左边必须为实体。*函数传数组用指针传递&#xff1a;-传的是第一个的元素的指针-也就是说在函数里&#xff0c;形参只是一个指针&…

Delphi 中 FireDAC 数据库连接(总览)

本系列包含一组文章&#xff0c;描述了如何用在Delphi中使用FireDAC设置数据库驱动和管理数据库连接。通过这一些列文章的学习&#xff0c;将熟练掌握FireDAC数据库连接管理应用。自由使用FireDAC&#xff01;主题说明定义连接描述了如何存储和使用FireDAC连接参数以及连接定义…

sqlserver 数据库优化工具,安全性设置,并发设置,SQL耗时优化

当前数据库版本查询&#xff0c;硬件资源最大限度决定当前数据库实例的性能.数据库并发设置&#xff0c;理论最大连接数&#xff0c;当前实例数据库实例设置的最大连接数&#xff0c;决定高并发的连接数数据库连接字符串优化&#xff0c;是否复用&#xff0c;生命周期&#xff…

英语学术写作作业文章分析1

Abstract:a concise overview of the study研究简明概述 Introduction&#xff1a;1.providing the background 提供研究背景 2.indentifying the problem under study 了解正在研究的问题 3.evaluating previous studies评估之前的研究 4.identifying the research gap认识研…

GAIDC 2023盛会迎来大模型论坛“主场”,百度飞桨护航大模型产业发展

‍‍‍‍2月25日-26日&#xff0c;2023全球人工智能开发者先锋大会&#xff08;GAIDC&#xff09;在上海临港举行&#xff0c;大会以“向光而行的AI开发者”为主题&#xff0c;汇聚了当前科技和产业革命中的开发者先锋力量。百度深度参与本次大会&#xff0c;飞桨联合上海市人工…

-source1.5中不支持diamond运算符终极解决办法

1. 常规办法 在File->Setting中设置如下&#xff1a; 然后检查&#xff1a;File->Project Structure里面的相关配置&#xff1a; 以上办法能解决问题的概率在90%&#xff0c;如果还不行&#xff0c;那么请按照以下方法&#xff0c;基本上100%可以解决。 如果还没解决则…

前端基础之CSS扫盲

文章目录一. CSS基本规范1. 基本语法格式2. 在HTML引入CSS3. 选择器分类二. CSS常用属性1. 文本属性2. 文本格式3. 背景属性4. 圆角矩形和圆5. 元素的显示模式6. CSS盒子模型7. 弹性布局光使用HTML来写一个前端页面的话其实只是写了一个大体的框架, 整体的页面并不工整美观, 而…

01 导入已有环境配置(上课使用)

一、导入虚拟机 通过 VMware 打开已经安装好的 集群 打开三个 将导入的虚拟机打开 第一次打开会需要选择方式 “我已移动虚拟机” 不会改变原虚拟机的配置 “我已复制虚拟机” 会去改变虚拟机mac地址&#xff0c;需要重新的去配置 mac 地址 选择 未列出 &#xff0c;以 root…

Matlab论文插图绘制模板第79期—无线条等高线填充图

资源群里有朋友问如何绘制等高线填充图&#xff0c;但删除线条&#xff0c;只保留填充颜色的那种。 那么&#xff0c;本期就来分享一下无线条等高线填充图的绘制模板。 先来看一下成品效果&#xff1a; 特别提示&#xff1a;Matlab论文插图绘制模板系列&#xff0c;旨在降低大…

【C语言进阶】指针与数组、转移表详解

前言 大家好我是程序猿爱打拳&#xff0c;我们在学习完指针的基本概念后知道了指针就是地址&#xff0c;我们可以通过这个地址并对它进行解引用从而改变一些数据。但只学习指针的基础是完全不够的&#xff0c;因此学习完指针的基础后我们可以学习关于指针的进阶&#xff0c;其中…

WireShark如何进行USB包协议分析

USB协议学习的步骤之一就是从抓包看协议通信,进而学习usb设备开发是怎么回事。这里发现一个工具就是wireshark。 WireShark如果要抓取usb设备的包,需要在安装的时候,选择usbpcap一并进行安装。

惠普星14Pro电脑开机不了显示错误代码界面怎么办?

惠普星14Pro电脑开机不了显示错误代码界面怎么办&#xff1f;有用户电脑开机之后&#xff0c;进入了一个错误界面&#xff0c;里面有一些错误代码。重启电脑之后依然是无法进入到桌面中&#xff0c;那么这个情况怎么去进行解决呢&#xff1f;我们可以重装一个新系统&#xff0c…

Android自定义View实现打钩签到动画

效果图实现原理我们看实现的动画效果&#xff0c;其实是分为1. 绘制未选中状态图形&#xff08;圆弧和对号&#xff09;2. 绘制选中状态圆弧的旋转的动画3. 绘制选中状态圆弧向中心收缩铺满动画4. 绘制选中状态对号5. 绘制选中状态下圆的放大回弹动画6. 暴露接口接口回调传递选…

FEBC2022|打造VR内容生态闭环 佳创视讯持续加码轻量化内容建设

2月24日&#xff0c;由陀螺科技主办的未来商业生态链接大会作为 2023 癸卯兔年开年率先召开的行业重要影响力盛会在深圳成功召开。今年大会云集了科技、软件、游戏、XR等元宇宙领域的世界500强、上市公司及行业独角兽企业&#xff0c;围绕游戏、元宇宙、XR、数字营销等多项热门…

ZigBee基本概念

本文主要介绍zigbee中profile、cluster、attribute、command的概念&#xff0c;以及zigbee的一些基本思想。zigbee联盟为了不同厂商的zigbee设备之间能够互联互通&#xff0c;于是制订了的zigbee协议标准&#xff0c;到今天&#xff08;2016.3.28&#xff09;已经到了版本3.0。…

【数据结构】知识点总结(C语言)

线性表、栈和队列、串、数组和广义表、树和二叉树、图、查找、排序线性表线性表&#xff08;顺序表示&#xff09;线性表是具有相同特性元素的一个有限序列&#xff0c;数据元素之间是线性关系&#xff0c;起始元素称为线性起点&#xff0c;终端元素称为线性终点。线性表的顺序…

python线程通信

在现实生活中&#xff0c;如果一个人团队正在共同完成任务&#xff0c;那么他们之间应该有通信&#xff0c;以便正确完成任务。 同样的比喻也适用于线程。 在编程中&#xff0c;要减少处理器的理想时间&#xff0c;我们创建了多个线程&#xff0c;并为每个线程分配不同的子任务…

线上监控诊断神器arthas

目录 什么是arthas 常用命令列表 1、dashboard仪表盘 2、heapdump dumpJAVA堆栈快照 3、jvm 4、thread 5、memory 官方文档 安装使用 1、云安装arthas 2、获取需要监控进程ID 3、运行arthas 4、进入仪表盘 5、其他命令使用查看官方文档 什么是arthas arthas是阿…

【自然语言处理】BERT GPT

BERT & GPT近年来&#xff0c;随着大规模预训练语言模型的发展&#xff0c;自然语言处理领域发生了巨大变革。BERT 和 GPT 是其中最流行且最有影响力的两种模型。在本篇博客中&#xff0c;我们将讨论 BERT 和 GPT 之间的区别以及它们的演变过程。 1.起源 201820182018 年&a…