使用Pytorch从零开始构建StyleGAN2

news2024/11/24 20:47:50

这篇博文是关于 StyleGAN2 的,来自论文Analyzing and Improving the Image Quality of StyleGAN,我们将使用 PyTorch 对其进行干净、简单且可读的实现,并尝试尽可能地还原原始论文。

如果您没有阅读 StyleGAN2 论文。或者不知道它是如何工作的并且你想了解它,我强烈建议你看看扫一下原始论文,了解其主要思想。

我们在本博客中使用的数据集是来自 Kaggle 的数据集,其中包含 16240 件女性上衣,分辨率为 256*192。

依赖项加载

一如既往,让我们首先加载我们需要的所有依赖项。

我们首先导入 torch,因为我们将使用 PyTorch,然后从那里导入 nn. 这将帮助我们创建和训练网络,并让我们导入 optim,一个实现各种优化算法(例如 sgd、adam 等)的包。我们从 torchvision 导入数据集和转换来准备数据并应用一些转换。

我们将从 torch.nn 导入 F 函数,从 torch.utils.data 导入 DataLoader 以创建小批量大小,从 torchvision.utils 导入 save_image 以保存一些假样本,log2 和 sqrt 形成数学,Numpy 用于线性代数,操作系统用于交互使用操作系统,tqdm 显示进度条,最后使用 matplotlib.pyplot 绘制一些图像。

import torch
from torch import nn, optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from math import log2, sqrt
import numpy as np
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

超参数

  • 通过真实图像的路径初始化DATASET。
  • 如果可用,则使用 Cuda 初始化设备,否则使用 CPU,将 epoch 数设为 300,将学习率设为 0.001,将批量大小设为 32。
  • 将 LOG_RESOLUTION 初始化为 7,因为我们试图生成 128*128 图像,并且 2^7 = 128。您可以根据所需的假图像的分辨率更改该值。
  • 在原始论文中,他们将 Z_DIM 和 W_DIM 初始化为 512,但我将它们初始化为 256,以减少 VRAM 使用和加速训练。如果我们将它们加倍,我们甚至可能会得到更好的结果。
  • 对于 StyleGAN2,我们可以使用任何我们想要的 GAN 损失函数,因此我使用论文“ Improved Training of Wasserstein GAN”中的 WGAN-GP 。该损失包含一个参数名称 λ,通常设置 λ = 10。
DATASET                 = "Women clothes"
DEVICE                  = "cuda" if torch.cuda.is_available() else "cpu"
EPOCHS                  = 300
LEARNING_RATE           = 1e-3
BATCH_SIZE              = 32
LOG_RESOLUTION          = 7 #for 128*128
Z_DIM                   = 256
W_DIM                   = 256
LAMBDA_GP               = 10

获取数据加载器

现在让我们创建一个函数get_loader来:

  • 对图像应用一些转换(将图像大小调整为我们想要的分辨率(2^LOG_RESOLUTION by 2^LOG_RESOLUTION),将它们转换为张量,然后应用一些增强,最后将它们标准化为从 -1 到1)。
  • 使用 ImageFolder 准备数据集,因为它已经以良好的方式构建。
  • 使用 DataLoader 创建小批量大小,该 DataLoader 通过打乱数据来获取数据集和批量大小。
  • 最后,返回loader。
def get_loader():
    transform = transforms.Compose(
        [
            transforms.Resize((2 ** LOG_RESOLUTION, 2 ** LOG_RESOLUTION)),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(
                [0.5, 0.5, 0.5],
                [0.5, 0.5, 0.5],
            ),
        ]
    )
    dataset = datasets.ImageFolder(root=DATASET, transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    return loader

模型实现

现在让我们使用论文中的关键属性来实现 StyleGAN2 网络。我们将尽力使实现紧凑,但同时保持其可读性和可理解性。具体来说,有以下几个要点:

  • 噪声映射网络
  • 权重解调(而非自适应实例归一化 (AdaIN))
  • 跳跃连接(而非渐进式增长)
  • 感知路径长度标准化

噪声映射网络

让我们创建将从 nn.Module 继承的 MappingNetwork 类。

在init部分,我们发送 z_dim 和 w_din,并定义包含 8 个 EqualizedLinear 的网络映射,这是我们稍后将实现的用于均衡学习率的类,以及作为激活函数的 ReLu
在前一部分中,我们使用像素范数初始化 z_dim,然后返回网络映射。

class MappingNetwork(nn.Module):
    def __init__(self, z_dim, w_dim):
        super().__init__()
        self.mapping = nn.Sequential(
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim),
            nn.ReLU(),
            EqualizedLinear(z_dim, w_dim)
        )

    def forward(self, x):
    	  x = x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + 1e-8)  # for PixelNorm 
        return self.mapping(x)

生成器

在下图中,您可以看到生成器架构,它以初始常量开始。然后它有一系列的块。每个块的特征图分辨率加倍。每个块输出一个 RGB 图像,它们被放大并求和以获得最终的 RGB 图像。

toRGB还有一个风格调制,为简单起见,图中未显示。

为了使代码尽可能简洁,在生成器的实现中,我们将使用稍后定义的三个类(StyleBlock、toRGB 和 GeneratorBlock)。
在这里插入图片描述

  • 在初始化部分,我们发送 log_resolution,它是图像分辨率的 log2​,W_DIM,它是w 的维数, n_featurese,它 是最高分辨率(最终块)卷积层中的特征数量,max_features,它是最大值任何生成器块中的功能数量。我们计算每个块的特征数量,得到生成器块的数量,并初始化可训练的 4x4 常量、4×4 分辨率的第一个样式块、获取 RGB 的层和生成器块。
  • 在前一部分中,我们为每个生成器块发送 w ,它具有形状 [ n_blocks, batch_size, W-dim ] 和 input_noise ,它是每个块的噪声,它是噪声张量对的列表,因为每个块(除了初始)在每个卷积层之后有两个噪声输入(见上图)。我们获取批量大小,扩展学习的常量以匹配批量大小,将其运行到第一个样式块,获取 RGB 图像,然后在上采样后再次将其运行到其余的生成器块中。最后,以 tanh 作为激活函数返回最后一张 RGB 图像。我们使用 tanh 的原因是它将作为输出(生成的图像)​​,并且我们希望像素的范围在 1 到 -1 之间。
class Generator(nn.Module):

    def __init__(self, log_resolution, W_DIM, n_features = 32, max_features = 256):

        super().__init__()

        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 2, -1, -1)]
        self.n_blocks = len(features)

        self.initial_constant = nn.Parameter(torch.randn((1, features[0], 4, 4)))

        self.style_block = StyleBlock(W_DIM, features[0], features[0])
        self.to_rgb = ToRGB(W_DIM, features[0])

        blocks = [GeneratorBlock(W_DIM, features[i - 1], features[i]) for i in range(1, self.n_blocks)]
        self.blocks = nn.ModuleList(blocks)

    def forward(self, w, input_noise):

        batch_size = w.shape[1]

        x = self.initial_constant.expand(batch_size, -1, -1, -1)
        x = self.style_block(x, w[0], input_noise[0][1])
        rgb = self.to_rgb(x, w[0])

        for i in range(1, self.n_blocks):
            x = F.interpolate(x, scale_factor=2, mode="bilinear")
            x, rgb_new = self.blocks[i - 1](x, w[i], input_noise[i])
            rgb = F.interpolate(rgb, scale_factor=2, mode="bilinear") + rgb_new

        return torch.tanh(rgb)

生成器block

在下图中,您可以看到生成器block架构,它由两个风格blocks(带有风格调制的 3×3 卷积)和 RGB 输出组成。
在这里插入图片描述

class GeneratorBlock(nn.Module):

    def __init__(self, W_DIM, in_features, out_features):

        super().__init__()

        self.style_block1 = StyleBlock(W_DIM, in_features, out_features)
        self.style_block2 = StyleBlock(W_DIM, out_features, out_features)

        self.to_rgb = ToRGB(W_DIM, out_features)

    def forward(self, x, w, noise):

        x = self.style_block1(x, w, noise[0])
        x = self.style_block2(x, w, noise[1])

        rgb = self.to_rgb(x, w)

        return x, rgb

  • 在init部分,我们发送 W_DIM(即 w 的维数)、 in_features(即输入特征图中的特征数量)和 out_features(即输出特征图中的特征数量),然后我们初始化两个风格blocks并到RGB层。
  • 在前向部分中,我们发送形状为 [ batch_size, in_features, height, width ] 的输入特征图 x,形状为 [ batch_size, W_DIM ] 的 w,以及​​形状为两个噪声张量的元组的噪声。 [ batch_size, 1, height, width ],然后我们将 x 运行到两个风格blocks中,并使用 toRGB 层获得 RGB 图像。最后,我们返回 x 和 RGB 图像。

风格blocks

在这里插入图片描述

  • 在init部分,我们发送 W_DIM、in_features 和 out_features,然后用从 w 获得的风格向量(图中用A表示)初始化 to_style,并使用稍后实现的均衡学习率线性层 (EqualizedLinear) 、权重调制卷积层、噪声尺度、偏差和激活函数。
  • 在前向部分,我们发送x、w和噪声,然后得到风格向量s,将x和s运行到权重调制卷积中,缩放并添加噪声,最后添加偏差并评估激活函数。
class StyleBlock(nn.Module):

    def __init__(self, W_DIM, in_features, out_features):

        super().__init__()

        self.to_style = EqualizedLinear(W_DIM, in_features, bias=1.0)
        self.conv = Conv2dWeightModulate(in_features, out_features, kernel_size=3)
        self.scale_noise = nn.Parameter(torch.zeros(1))
        self.bias = nn.Parameter(torch.zeros(out_features))

        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x, w, noise):

        s = self.to_style(w)
        x = self.conv(x, s)
        if noise is not None:
            x = x + self.scale_noise[None, :, None, None] * noise
        return self.activation(x + self.bias[None, :, None, None])

转RGB

在这里插入图片描述

  • 在初始化部分,我们发送 W_DIM 和特征,然后通过从 w 获得的风格向量(图中用A表示)、权重调制卷积层、偏差和激活函数来初始化 to_style 。
  • 在前向部分,我们发送 x 和 w,然后我们得到样式向量 style,我们将 x 和 style 运行到权重调制卷积中,最后,我们添加偏差并评估激活函数。
class ToRGB(nn.Module):

    def __init__(self, W_DIM, features):

        super().__init__()
        self.to_style = EqualizedLinear(W_DIM, features, bias=1.0)

        self.conv = Conv2dWeightModulate(features, 3, kernel_size=1, demodulate=False)
        self.bias = nn.Parameter(torch.zeros(3))
        self.activation = nn.LeakyReLU(0.2, True)

    def forward(self, x, w):

        style = self.to_style(w)
        x = self.conv(x, style)
        return self.activation(x + self.bias[None, :, None, None])

卷积与权重调制和解调

此类通过样式向量缩放卷积权重,并通过对其进行归一化来解调。

  • 在init部分,我们发送 in_features、out_features、kernel_size、demodulates(是否按标准差对权重进行归一化的标志)和 eps(用于归一化的ϵ),然后初始化输出特征的数量、解调、填充大小,使用我们稍后将实现的类 EqualizedWeight 和 eps 来设置具有均衡学习率的权重参数。
  • 在前向部分,我们发送输入特征图 x 和基于样式的缩放张量 s,然后我们从 x 中获取批量大小、高度和宽度,重塑尺度,获得均衡的学习率权重,然后调制 x 和 s,如果 demodulates 为 True,则使用以下方程解调它们,其中i是输入通道,j是输出通道,k是内核索引。最后,我们返回 x。
    在这里插入图片描述
class Conv2dWeightModulate(nn.Module):

    def __init__(self, in_features, out_features, kernel_size,
                 demodulate = True, eps = 1e-8):

        super().__init__()
        self.out_features = out_features
        self.demodulate = demodulate
        self.padding = (kernel_size - 1) // 2

        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
        self.eps = eps

    def forward(self, x, s):

        b, _, h, w = x.shape

        s = s[:, None, :, None, None]
        weights = self.weight()[None, :, :, :, :]
        weights = weights * s

        if self.demodulate:
            sigma_inv = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps)
            weights = weights * sigma_inv

        x = x.reshape(1, -1, h, w)

        _, _, *ws = weights.shape
        weights = weights.reshape(b * self.out_features, *ws)

        x = F.conv2d(x, weights, padding=self.padding, groups=b)

        return x.reshape(-1, self.out_features, h, w)

鉴别器

在下图中,您可以看到鉴别器架构。它首先将分辨率为 2 L O G _ R E S O L U T I O N x 2 L O G _ R E S O L U T I O N 2 ^{LOG\_RESOLUTION} x 2^{LOG\_RESOLUTION} 2LOG_RESOLUTIONx2LOG_RESOLUTION的图像转换 为相同分辨率的特征图,然后通过一系列具有残差连接的块来运行它。每个块的分辨率下采样 2 倍,同时特征数量加倍。
在这里插入图片描述

  • 在init部分,我们发送log_resolution、n_feautures和max_features,计算每个块的特征数量,然后初始化一个名为from_rgb的层,将RGB图像转换为具有n_features特征数量、鉴别器数量的特征图块、鉴别器块、添加标准差图后的特征数、最终的 3×3 卷积层和最终的线性层以获得分类。
  • 对于判别器上的 Minibatch std,我们在为每个示例(跨所有通道和像素)获取 std 时添加minibatch_std部分,然后我们对单个通道重复它并将其与图像连接。通过这种方式,鉴别器将获得有关批次/图像变化的信息。
  • 在前向部分,我们发送 x,它是形状 [ batch_size, 3, height, width ] 的输入图像,然后运行它并抛出 from_RGB 层、鉴别器块、minibatch_std、3×3 卷积、展平和分类分数。
class Discriminator(nn.Module):

    def __init__(self, log_resolution, n_features = 64, max_features = 256):

        super().__init__()

        features = [min(max_features, n_features * (2 ** i)) for i in range(log_resolution - 1)]

        self.from_rgb = nn.Sequential(
            EqualizedConv2d(3, n_features, 1),
            nn.LeakyReLU(0.2, True),
        )
        n_blocks = len(features) - 1
        blocks = [DiscriminatorBlock(features[i], features[i + 1]) for i in range(n_blocks)]
        self.blocks = nn.Sequential(*blocks)

        final_features = features[-1] + 1
        self.conv = EqualizedConv2d(final_features, final_features, 3)
        self.final = EqualizedLinear(2 * 2 * final_features, 1)

    def minibatch_std(self, x):
        batch_statistics = (
            torch.std(x, dim=0).mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        )
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x):

        x = self.from_rgb(x)
        x = self.blocks(x)

        x = self.minibatch_std(x)
        x = self.conv(x)
        x = x.reshape(x.shape[0], -1)
        return self.final(x)

鉴别器blocks

在下图中,您可以看到判别器blocks架构,它由两个带有残差连接的 3×3 卷积组成。
在这里插入图片描述

  • 在init部分,我们发送in_features和out_features,并初始化包含下采样和用于残差连接的1×1卷积层的残差块,该块层包含两个以Leaky Rely作为激活的3×3卷积函数,使用 AvgPool2d 的 down_sample 层,以及添加残差后我们将使用的比例因子。
  • 在前向部分中,我们发送 x 并运行它抛出残差连接以获得名为残差的变量,然后运行 ​​x 抛出卷积和下采样,然后添加残差和缩放,然后返回它。
class DiscriminatorBlock(nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.residual = nn.Sequential(nn.AvgPool2d(kernel_size=2, stride=2), # down sampling using avg pool
                                      EqualizedConv2d(in_features, out_features, kernel_size=1))

        self.block = nn.Sequential(
            EqualizedConv2d(in_features, in_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
            EqualizedConv2d(in_features, out_features, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, True),
        )

        self.down_sample = nn.AvgPool2d(
            kernel_size=2, stride=2
        )  # down sampling using avg pool

        self.scale = 1 / sqrt(2)

    def forward(self, x):
        residual = self.residual(x)

        x = self.block(x)
        x = self.down_sample(x)

        return (x + residual) * self.scale

学习率均衡线性层

现在是时候实现EqualizedLinear了,我们之前在几乎每个类中都使用它来均衡线性层的学习率。

  • 在init部分,我们发送 in_features、out_features 和偏差。我们通过稍后定义的类 EqualizedWeight 来初始化权重,并初始化偏差。
  • 在前向部分,我们发送 x 并返回 x、权重和偏差的线性变换.
class EqualizedLinear(nn.Module):

    def __init__(self, in_features, out_features, bias = 0.):

        super().__init__()
        self.weight = EqualizedWeight([out_features, in_features])
        self.bias = nn.Parameter(torch.ones(out_features) * bias)

    def forward(self, x: torch.Tensor):
        return F.linear(x, self.weight(), bias=self.bias)

学习率均衡 2D 卷积层

现在让我们实现之前用来均衡卷积层学习率的EqualizedConv2d 。

  • 在init部分,我们发送 in_features、out_features、kernel_size 和 padding。我们通过稍后定义的类 EqualizedWeight 初始化填充、​​权重以及偏差。
  • 在前向部分,我们发送 x 并返回 x、权重、偏差和填充的卷积。
class EqualizedConv2d(nn.Module):

    def __init__(self, in_features, out_features,
                 kernel_size, padding = 0):

        super().__init__()
        self.padding = padding
        self.weight = EqualizedWeight([out_features, in_features, kernel_size, kernel_size])
        self.bias = nn.Parameter(torch.ones(out_features))

    def forward(self, x: torch.Tensor):
        return F.conv2d(x, self.weight(), bias=self.bias, padding=self.padding)

学习率均衡权重参数

现在让我们实现在学习率均衡线性层和学习率均衡 2D 卷积层中使用的EqualizedWeight类。

这是基于 ProGAN 论文中引入的均衡学习率。他们不是将权重初始化为 N(0, c ),而是将权重初始化为 N(0,1),然后在使用时将其乘以c。

  • 在初始化部分,我们以权重参数的形式发送,我们用 N(0,1) 初始化常数 c 和权重。
  • 在前面的部分,我们将权重乘以c并返回。
class EqualizedWeight(nn.Module):

    def __init__(self, shape):

        super().__init__()

        self.c = 1 / sqrt(np.prod(shape[1:]))
        self.weight = nn.Parameter(torch.randn(shape))

    def forward(self):
        return self.weight * self.c

感知路径长度标准化

感知路径长度归一化鼓励w中的固定大小步长,以导致图像中固定大小的变化。
在这里插入图片描述
其中 J w J_w Jw使用以下等式计算,w 从映射网络中采样,y是带有噪声 N(0, I) 的图像,a是训练过程中的指数移动平均值。
在这里插入图片描述

  • 在 init部分, 我们发送 beta,它是用于计算指数移动平均线a 的常数β 。初始化beta,steps为计算出的步数N, exp_sum_a为 J w T y J_w^T y JwTy的指数和。
  • 在前向部分,我们发送x,它是形状为[ batch_size, W_DIM ]的w的批次,x是生成的形状为[ batch_size, 3, height, width ]的图像,获取设备和像素数,计算上面的方程,更新指数和,增加N,并返回惩罚。
class PathLengthPenalty(nn.Module):

    def __init__(self, beta):

        super().__init__()

        self.beta = beta
        self.steps = nn.Parameter(torch.tensor(0.), requires_grad=False)

        self.exp_sum_a = nn.Parameter(torch.tensor(0.), requires_grad=False)

    def forward(self, w, x):

        device = x.device
        image_size = x.shape[2] * x.shape[3]
        y = torch.randn(x.shape, device=device)

        output = (x * y).sum() / sqrt(image_size)
        sqrt(image_size)

        gradients, *_ = torch.autograd.grad(outputs=output,
                                            inputs=w,
                                            grad_outputs=torch.ones(output.shape, device=device),
                                            create_graph=True)

        norm = (gradients ** 2).sum(dim=2).mean(dim=1).sqrt()

        if self.steps > 0:

            a = self.exp_sum_a / (1 - self.beta ** self.steps)

            loss = torch.mean((norm - a) ** 2)
        else:
            loss = norm.new_tensor(0)

        mean = norm.mean().detach()
        self.exp_sum_a.mul_(self.beta).add_(mean, alpha=1 - self.beta)
        self.steps.add_(1.)

        return loss

Utils

梯度惩罚

在下面的代码片段中,您可以找到 WGAN-GP 损失的gradient_penalty 函数。

def gradient_penalty(critic, real, fake,device="cpu"):
    BATCH_SIZE, C, H, W = real.shape
    beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
    interpolated_images = real * beta + fake.detach() * (1 - beta)
    interpolated_images.requires_grad_(True)

    # Calculate critic scores
    mixed_scores = critic(interpolated_images)
 
    # Take the gradient of the scores with respect to the images
    gradient = torch.autograd.grad(
        inputs=interpolated_images,
        outputs=mixed_scores,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
    return gradient_penalty

Sample W

该函数对 Z 进行随机采样,并从映射网络中获取 W。

def get_w(batch_size):

    z = torch.randn(batch_size, W_DIM).to(DEVICE)
    w = mapping_network(z)
    return w[None, :, :].expand(LOG_RESOLUTION, -1, -1)

噪声生成

该函数为每个生成器block组生成噪声

def get_noise(batch_size):
    
        noise = []
        resolution = 4

        for i in range(LOG_RESOLUTION):
            if i == 0:
                n1 = None
            else:
                n1 = torch.randn(batch_size, 1, resolution, resolution, device=DEVICE)
            n2 = torch.randn(batch_size, 1, resolution, resolution, device=DEVICE)

            noise.append((n1, n2))

            resolution *= 2

        return noise

在下面的代码片段中,您可以找到generate_examples函数,它接受生成器gen 、epoch数和n=100。该函数的目标是生成n 个假图像并将它们保存为每个epoch的结果。

def generate_examples(gen, epoch, n=100):
    
    gen.eval()
    alpha = 1.0
    for i in range(n):
        with torch.no_grad():
            w     = get_w(1)
            noise = get_noise(1)
            img = gen(w, noise)
            if not os.path.exists(f'saved_examples/epoch{epoch}'):
                os.makedirs(f'saved_examples/epoch{epoch}')
            save_image(img*0.5+0.5, f"saved_examples/epoch{epoch}/img_{i}.png")

    gen.train()

训练

在本节中,我们将训练 StyleGAN2。

让我们首先创建训练函数,该函数采用判别器/批评器、生成器 gen、每 16 个 epoch 使用的 path_length_penalty、加载器和网络优化器。我们首先循环使用 DataLoader 创建的所有小批量大小,并且只获取图像,因为我们不需要标签。

然后,当我们想要最大化E(critic(real)) - E(critic(fake))时,我们为判别器\Critic 设置训练。这个方程意味着评论家可以区分真实和虚假图像的程度。

之后,当我们想要最大化E(critic(fake))时,我们为生成器和映射网络设置训练,并且每 16 个时期向该函数添加一个感知路径长度。

最后,我们更新循环。

def train_fn(
    critic,
    gen,
    path_length_penalty,
    loader,
    opt_critic,
    opt_gen,
    opt_mapping_network,
):
    loop = tqdm(loader, leave=True)

    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        w     = get_w(cur_batch_size)
        noise = get_noise(cur_batch_size)
        with torch.cuda.amp.autocast():
            fake = gen(w, noise)
            critic_fake = critic(fake.detach())
            
            critic_real = critic(real)
            gp = gradient_penalty(critic, real, fake, device=DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )

        critic.zero_grad()
        loss_critic.backward()
        opt_critic.step()

        gen_fake = critic(fake)
        loss_gen = -torch.mean(gen_fake)

        if batch_idx % 16 == 0:
            plp = path_length_penalty(w, fake)
            if not torch.isnan(plp):
                loss_gen = loss_gen + plp

        mapping_network.zero_grad()
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()
        opt_mapping_network.step()

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
        )

现在让我们初始化加载器、网络和优化器,并使网络处于训练模式

loader              = get_loader()

gen                 = Generator(LOG_RESOLUTION, W_DIM).to(DEVICE)
critic              = Discriminator(LOG_RESOLUTION).to(DEVICE)
mapping_network     = MappingNetwork(Z_DIM, W_DIM).to(DEVICE)
path_length_penalty = PathLengthPenalty(0.99).to(DEVICE)

opt_gen             = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_critic          = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
opt_mapping_network = optim.Adam(mapping_network.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))

gen.train()
critic.train()
mapping_network.train()

现在让我们使用训练循环来训练网络,并在每 50 个 epoch 中保存一些假样本。

loader = get_loader()  

for epoch in range(EPOCHS):
    train_fn(
        critic,
        gen,
        path_length_penalty,
        loader,
        opt_critic,
        opt_gen,
        opt_mapping_network,
    )
    if epoch % 50 == 0:
    	generate_examples(gen, epoch)

结论

在本文中,我们使用 PyTorch 从头开始​​为 StyleGAN2 这个大型项目制作了一个干净、简单且可读的实现。我们尝试尽可能地复制原始论文。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1312215.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql 5.7.34升级到5.7.44修补漏洞

mysql 5.7.34旧版本,漏扫有漏洞,升级到最新版本 旧版本5.7.34在 /home/mysql/mysql中安装 备份旧版本数据还有目录 数据库备份升级 tar -xf mysql-5.7.44-el7-x86_64.tar #覆盖旧版本数据库文件 #注意看看文件是否和你起服务的用户一样 \cp -r mysql-5…

开发者必备的5类AI工具,不容错过!

在当今快节奏和激烈竞争的时代,提高工作效率和产品质量变得尤为重要。作为软件开发者,也必须紧跟现代化工具的步伐,以保持领先优势。在这篇文章中,笔者总结了2023年开发者必备的5类AI工具,这些工具将帮助您提升工作效率…

【六】python观察者设计模式

6.1行为型模式简介 观察者设计模式是最简单的行为型模式之一,所以我们先简单了解一下行为型模式 创建型模式的工作原理是基于对象的创建机制的。由于这些模式隔离了对象的创建细 节,所以使得代码能够与要创建的对象的类型相互独立。结构型模式用于设计对象和类的结…

精准运维的利器:风险控制模型

导读: 前期在《承载运维成功之梦:精准运维》一文中阐述了精准运维的原理、方法和实例。所谓精准运维,就是通过一系列方法掌握服务对象所使用信息系统的特性及其所服务企业的业务特性,通过掌控信息系统运行风险、运行特点、资源调…

C语言——字符函数和字符串函数(二)

📝前言: 上一篇文章C语言——字符函数和字符串函数(一)对字符函数和字符串函数strlen,strcpy和strncpy,strcat和strncat进行了初步的讲解 这篇文章主要再讲解几个我们常用到的其他字符串函数(附…

el-tree 高亮某些节点

:render-content"renderContent"

DevOps 和人工智能 – 天作之合

如今,人工智能和机器学习无处不在,所以它们开始在 DevOps 领域崭露头角也毫不令人意外。人工智能和机器学习正在通过自动化任务改变 DevOps,并使各企业的软件开发生命周期更高效、更深刻和更安全。我们在 DevOps 趋势中简要讨论过这一问题&am…

山姆·奥特曼重新掌舵OpenAI,为人工智能“保驾护航”

原创 | 文 BFT机器人 OpenAI首席执行官Sam Altman于2023年12月11日星期一在美国乔治亚州亚特兰大举行的全球论坛年会上发表讲话。来自40个国家的5200多名代表参加了此次会议,旨在重新构想全球经济,让大型科技企业的利益和机会惠及到所有人。 山姆奥特曼…

Unity | Shader基础知识(第五集:案例<小彩球>)

目录 一、本节介绍 1 上集回顾 2 本节介绍 二、原理分析 1 现实中出现彩色的原因 2 软件里的彩色的原理 3 方案 三、 实现数字由【-1,1】映射为【0,1】 1 结论 2 原理 四、代码实现 1 注意事项 2 详解结构体appdata_base 3 接收数据 4 映射数据 5 输出给SV_TAR…

Spring Cloud + Vue前后端分离-第5章 单表管理功能前后端开发

Spring Cloud Vue前后端分离-第5章 单表管理功能前后端开发 完成单表的增删改查 控台单表增删改查的前后端开发,重点学习前后端数据交互,vue ajax库axios的使用等 通用组件开发:分页、确认框、提示框、等待框等 常用的公共组件:确认框、提示框、等待…

eNSP中ping通不同VLAN中的计算机

以一边为例 LSW3 <Huawei>sys [Huawei]undo info en//关闭提示 [Huawei]vlan batch 13 24 [Huawei] int e0/0/2 [Huawei-Ethernet0/0/2]port link-type a [Huawei-Ethernet0/0/2] port de vlan 13 [Huawei-Ethernet0/0/2] q//退出 [Huawei] int e0/0/3 [Huawei-Ethernet0…

一个非常不错的源码和教程资源下载网站整站打包代码,适合用来搭建资源网站或者知识付费网站

找了好多资源类网站代码&#xff0c;目前发现这个不错。基于wordpress开发的&#xff0c;集成了ripro9.2的主题和一些美化的子主题样式&#xff0c;效果非常不错。更难得的是这个网站源码是全开源的&#xff0c;没有任何加密代码&#xff0c;想二次开发的话&#xff0c;非常适合…

jmeter,取“临时重定向的登录接口”响应头中的cookie

1、线程组--创建线程组&#xff1b; 2、线程组--添加--取样器--HTTP请求&#xff1b; 3、Http请求--添加--后置处理器--正则表达式提取器&#xff1b; 4、线程组--添加--监听器--查看结果树&#xff1b; 5、线程组--添加--取样器--调试取样器。 首先理解 自动重定向 与跟随…

kubernetes 学习笔记

1. Kubernetes 介绍 1.1 应用部署方式的演变 在部署应用程序的方式上&#xff0c;主要经理了三个时代&#xff1a; 传统部署&#xff1a;互联网早期&#xff0c;会直接将应用程序部署在物理机上。虚拟化部署&#xff1a;可以在一台物理机上运行多个虚拟机&#xff0c;每个虚…

一文讲清 QWidget 大小位置

一文讲清 QWidget 大小位置 前言 ​ QWidget 的位置基于桌面坐标系&#xff0c;以左上角为原点&#xff0c;向右x轴增加&#xff0c;向下y轴增加。 一、图解 ​ ​ 如上图所示&#xff0c;当窗口为顶层窗口时&#xff08;即没有任何父窗口&#xff09;&#xff0c;系统会自…

一款基于分布式文件存储的数据库MongoDB的介绍及基本使用教程

MongoDB 是由C语言编写的&#xff0c;是一个基于分布式文件存储的开源数据库系统。 在高负载的情况下&#xff0c;添加更多的节点&#xff0c;可以保证服务器性能。 MongoDB 旨在为WEB应用提供可扩展的高性能数据存储解决方案。 MongoDB 将数据存储为一个文档&#xff0c;数据结…

RocketMQ 跟踪消息发送轨迹

目录 概述实践如何启用消息轨迹配置创建Topic代码测试 结束 概述 阅读此文可以解决 RocketMQ 中消息是否发送成功&#xff0c;是否消费成功。 查询消息轨迹可作为生产环境中排查问题强有力的数据支持 &#xff0c;也是研发同学解决线上问题的重要武器之一。 详细如下&#x…

Navicat16 无限试用 亲测有效

Navicat16 无限试用 亲测有效 亲测有效&#xff01;&#xff01;&#xff01; 吐槽下&#xff0c;有的用不了&#xff0c;有的是图片&#xff0c;更甚者还有收费的&#xff0c;6的一批 粘贴下面的代码&#xff0c;保存到桌面&#xff0c;命名为 trial-navicat16.bat echo off…

DDOS攻击方式有哪些,要如何防护

DDOS攻击我们也称之为流量攻击&#xff0c;分布式拒绝服务攻击(英文意思是Distributed Denial of Service&#xff0c;简称DDOS&#xff09;于不同位置的多个攻击者同时向一个或数个目标发动攻击&#xff0c;或者一个攻击者控制了位于不同位置的多台机器并利用这些机器对受害者…

【漏洞复现】I Doc View在线文档预览任意文件读取 1day

漏洞描述 I Doc View在线文档预览是一款在线文档预览系统&#xff0c;可以实现文档的预览及文档协作编辑功能。其存在代码执行漏洞&#xff0c;使得攻击者可以通过利用这个接口&#xff0c;触发服务器下载并解析恶意文件&#xff0c;从而导致远程命令执行漏洞。进而控制服务器…