深度学习Day-30:CGAN入门丨生成手势图像丨可控制生成

news2025/2/6 8:56:45

  🍨 本文为:[🔗365天深度学习训练营] 中的学习记录博客
 🍖 原作者:[K同学啊 | 接辅导、项目定制]

要求:

  1. 结合代码进一步了解CGAN
  2. 学习如何运用生成好的生成器生成指定图像

一、 基础配置

  • 语言环境:Python3.8
  • 编译器选择:Pycharm
  • 深度学习环境:
    • torch==1.12.1+cu113
    • torchvision==0.13.1+cu113

二、 前期准备 

1. 导入第三方库

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import os

os.makedirs('./images', exist_ok=True)
os.makedirs('./training_weights', exist_ok=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

得到如下输出: 

cuda

 2. 导入数据

batch_size = 128
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="GAN-3-data", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=0)

3. 数据可视化

运行下述代码:

def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

输出图像为:

4. 定义超参数 

运行下述代码:

latent_dim = 100
n_classes = 3
embedding_dim = 100

5. 构建模型

5.1.初始化权重

def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)

    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

5.2.定义生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

from torchinfo import summary
summary(generator)

输出为:

Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================

 5.3.定义鉴别器

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 3 * 128 * 128)
        )

        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 2, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 3, 2, bias=False),
            nn.BatchNorm2d(64 * 8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),
            nn.Dropout(0.4),
            nn.Linear(4608, 1),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        img, label = inputs

        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)

        concat = torch.cat((img, label_output), dim=1)

        output = self.model(concat)
        return output

discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

summary(discriminator)

输出为:

Discriminator(
  (label_condition_disc): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=49152, bias=True)
  )
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Dropout(p=0.4, inplace=False)
    (13): Linear(in_features=4608, out_features=1, bias=True)
    (14): Sigmoid()
  )
)
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================

三、 训练模型 

1. 定义训练参数

adversarial_loss = nn.BCELoss()

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss

2. 定义优化器

learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

3. 训练模型

num_epochs = 100

D_loss_plot, G_loss_plot = [], []

for epoch in range(1, num_epochs + 1):

    D_loss_list, G_loss_list = [], []

    for index, (real_images, labels) in enumerate(train_loader):
        D_optimizer.zero_grad()

        real_images = real_images.to(device)
        labels = labels.to(device)

        labels = labels.unsqueeze(1).long()

        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)

        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))

        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)

        D_total_loss.backward()
        D_optimizer.step()

        G_optimizer.zero_grad()
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)

        G_loss.backward()
        G_optimizer.step()

    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
        (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)),
        torch.mean(torch.FloatTensor(G_loss_list))))

    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch % 10 == 0:
        save_image(generated_image.data[:50], './images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

输出为:

Epoch: [1/100]: D_loss: 0.285, G_loss: 2.018
Epoch: [2/100]: D_loss: 0.331, G_loss: 2.298
Epoch: [3/100]: D_loss: 0.403, G_loss: 1.715
Epoch: [4/100]: D_loss: 0.467, G_loss: 1.416
Epoch: [5/100]: D_loss: 0.490, G_loss: 1.618
Epoch: [6/100]: D_loss: 0.490, G_loss: 1.585
Epoch: [7/100]: D_loss: 0.379, G_loss: 1.674
Epoch: [8/100]: D_loss: 0.443, G_loss: 1.889
Epoch: [9/100]: D_loss: 0.541, G_loss: 2.067
Epoch: [10/100]: D_loss: 0.565, G_loss: 1.751
Epoch: [11/100]: D_loss: 0.528, G_loss: 1.495
Epoch: [12/100]: D_loss: 0.555, G_loss: 1.461
Epoch: [13/100]: D_loss: 0.569, G_loss: 1.490
Epoch: [14/100]: D_loss: 0.531, G_loss: 1.498
Epoch: [15/100]: D_loss: 0.504, G_loss: 1.532
Epoch: [16/100]: D_loss: 0.487, G_loss: 1.612
Epoch: [17/100]: D_loss: 0.457, G_loss: 1.776
Epoch: [18/100]: D_loss: 0.462, G_loss: 1.767
Epoch: [19/100]: D_loss: 0.437, G_loss: 1.946
Epoch: [20/100]: D_loss: 0.446, G_loss: 1.848
Epoch: [21/100]: D_loss: 0.463, G_loss: 1.718
Epoch: [22/100]: D_loss: 0.473, G_loss: 1.748
Epoch: [23/100]: D_loss: 0.503, G_loss: 1.579
Epoch: [24/100]: D_loss: 0.482, G_loss: 1.410
Epoch: [25/100]: D_loss: 0.489, G_loss: 1.440
Epoch: [26/100]: D_loss: 0.494, G_loss: 1.425
Epoch: [27/100]: D_loss: 0.510, G_loss: 1.398
Epoch: [28/100]: D_loss: 0.475, G_loss: 1.410
Epoch: [29/100]: D_loss: 0.473, G_loss: 1.459
Epoch: [30/100]: D_loss: 0.473, G_loss: 1.489
Epoch: [31/100]: D_loss: 0.462, G_loss: 1.484
Epoch: [32/100]: D_loss: 0.448, G_loss: 1.520
Epoch: [33/100]: D_loss: 0.457, G_loss: 1.548
Epoch: [34/100]: D_loss: 0.418, G_loss: 1.558
Epoch: [35/100]: D_loss: 0.433, G_loss: 1.667
Epoch: [36/100]: D_loss: 0.402, G_loss: 1.665
Epoch: [37/100]: D_loss: 0.401, G_loss: 1.709
Epoch: [38/100]: D_loss: 0.425, G_loss: 1.841
Epoch: [39/100]: D_loss: 0.399, G_loss: 1.711
Epoch: [40/100]: D_loss: 0.429, G_loss: 1.873
Epoch: [41/100]: D_loss: 0.374, G_loss: 1.857
Epoch: [42/100]: D_loss: 0.382, G_loss: 1.869
Epoch: [43/100]: D_loss: 0.431, G_loss: 1.935
Epoch: [44/100]: D_loss: 0.355, G_loss: 1.871
Epoch: [45/100]: D_loss: 0.363, G_loss: 1.875
Epoch: [46/100]: D_loss: 0.485, G_loss: 2.011
Epoch: [47/100]: D_loss: 0.391, G_loss: 1.994
Epoch: [48/100]: D_loss: 0.331, G_loss: 1.924
Epoch: [49/100]: D_loss: 0.317, G_loss: 1.930
Epoch: [50/100]: D_loss: 0.353, G_loss: 2.035
Epoch: [51/100]: D_loss: 0.334, G_loss: 2.072
Epoch: [52/100]: D_loss: 0.387, G_loss: 2.092
Epoch: [53/100]: D_loss: 0.380, G_loss: 2.139
Epoch: [54/100]: D_loss: 0.302, G_loss: 2.077
Epoch: [55/100]: D_loss: 0.311, G_loss: 2.055
Epoch: [56/100]: D_loss: 0.326, G_loss: 2.169
Epoch: [57/100]: D_loss: 0.309, G_loss: 2.239
Epoch: [58/100]: D_loss: 0.323, G_loss: 2.207
Epoch: [59/100]: D_loss: 0.285, G_loss: 2.239
Epoch: [60/100]: D_loss: 0.306, G_loss: 2.304
Epoch: [61/100]: D_loss: 0.287, G_loss: 2.254
Epoch: [62/100]: D_loss: 0.295, G_loss: 2.406
Epoch: [63/100]: D_loss: 0.305, G_loss: 2.499
Epoch: [64/100]: D_loss: 0.298, G_loss: 2.462
Epoch: [65/100]: D_loss: 0.255, G_loss: 2.418
Epoch: [66/100]: D_loss: 0.480, G_loss: 2.714
Epoch: [67/100]: D_loss: 0.265, G_loss: 2.379
Epoch: [68/100]: D_loss: 0.256, G_loss: 2.453
Epoch: [69/100]: D_loss: 0.252, G_loss: 2.465
Epoch: [70/100]: D_loss: 0.240, G_loss: 2.600
Epoch: [71/100]: D_loss: 0.250, G_loss: 2.516
Epoch: [72/100]: D_loss: 0.228, G_loss: 2.534
Epoch: [73/100]: D_loss: 0.249, G_loss: 2.566
Epoch: [74/100]: D_loss: 0.385, G_loss: 2.915
Epoch: [75/100]: D_loss: 0.232, G_loss: 2.566
Epoch: [76/100]: D_loss: 0.335, G_loss: 2.776
Epoch: [77/100]: D_loss: 0.243, G_loss: 2.703
Epoch: [78/100]: D_loss: 0.232, G_loss: 2.650
Epoch: [79/100]: D_loss: 0.216, G_loss: 2.736
Epoch: [80/100]: D_loss: 0.219, G_loss: 2.725
Epoch: [81/100]: D_loss: 0.272, G_loss: 2.869
Epoch: [82/100]: D_loss: 0.218, G_loss: 2.839
Epoch: [83/100]: D_loss: 0.219, G_loss: 2.836
Epoch: [84/100]: D_loss: 0.233, G_loss: 2.948
Epoch: [85/100]: D_loss: 0.209, G_loss: 2.952
Epoch: [86/100]: D_loss: 0.251, G_loss: 3.052
Epoch: [87/100]: D_loss: 0.198, G_loss: 2.905
Epoch: [88/100]: D_loss: 0.193, G_loss: 3.054
Epoch: [89/100]: D_loss: 0.215, G_loss: 2.995
Epoch: [90/100]: D_loss: 0.193, G_loss: 3.081
Epoch: [91/100]: D_loss: 0.446, G_loss: 3.269
Epoch: [92/100]: D_loss: 0.227, G_loss: 2.871
Epoch: [93/100]: D_loss: 0.191, G_loss: 3.008
Epoch: [94/100]: D_loss: 0.200, G_loss: 3.066
Epoch: [95/100]: D_loss: 0.200, G_loss: 3.142
Epoch: [96/100]: D_loss: 0.186, G_loss: 3.113
Epoch: [97/100]: D_loss: 0.207, G_loss: 3.159
Epoch: [98/100]: D_loss: 0.219, G_loss: 3.213
Epoch: [99/100]: D_loss: 0.177, G_loss: 3.205
Epoch: [100/100]: D_loss: 0.184, G_loss: 3.258

4. 可视化

4.1.LOSS图

G_loss_list = [i.item() for i in G_loss_plot]
D_loss_list = [i.item() for i in D_loss_plot]

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100

plt.figure(figsize=(8,4))
plt.title("Generator and Descriminator Loss During Training")
plt.plot(G_loss_list,label = "G")
plt.plot(D_loss_list,label = "D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出图像为:

 

4.2.生成指定图像

from numpy.random import randn

generator.load_state_dict(torch.load("./training_weights/generator_epoch_100.pth"), strict = False)
generator.eval()

interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()

predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()

import warnings

warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100


plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1 ) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

输出图像为:

四、理论基础

        CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。

        CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

        CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

  1. 有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
  2. 联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
  3. 可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
  4. 使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

        相比于传统的GAN,CGAN的主要异同点包括条件信息的输入、训练稳定性、损失函数、网络结构等,其具体内容为:

  1. 条件信息的输入:CGAN引入了条件变量,使得生成器和判别器都能接收到更多的信息来指导训练过程,这是传统GAN所不具备的。
  2. 训练稳定性:传统GAN在训练过程中容易产生模式崩溃(mode collapse)的问题,而CGAN由于有了额外的条件信息,可以提高训练的稳定性和生成数据的多样性。
  3. 损失函数:虽然CGAN的损失函数仍然保留了传统GAN的对抗损失函数的形式,但额外添加的条件信息使得损失计算更加复杂且有针对性。
  4. 网络结构:在实现上,CGAN可以采用更深更复杂的网络结构,如卷积神经网络,这有助于处理更为复杂的数据类型,比如高分辨率图像。

        CGAN网络结构如下图所示:                由上图的网络结构可知,条件信息y作为额外的输入被引入对抗网络中,与生成器中的噪声z合并作为隐含层表达;而在判别器D中,条件信息y则与原始数据x合并作为判别函数的输入。这种改进在以后的诸多方面研究中被证明是非常有效的,也为后续的相关工作提供了积极的指导作用

        综上所述,CGAN的核心在于它通过引入条件信息来增强模型的生成能力和可控性,与传统GAN相比,它提供了更明确的训练目标和更好的生成效果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2056496.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

功能测试与自动化测试详解

🍅 点击文末小卡片,免费获取软件测试全套资料,资料在手,涨薪更快 什么是自动化测试? 自动化测试是指利用软件测试工具自动实现全部或部分测试,它是软件测试的一个重要组成 部分,能完成许多手工测试无法实…

【C++】————智能指针

作者主页: 作者主页 本篇博客专栏:C 创作时间 :2024年8月20日 一,什么是智能指针 在C中没有垃圾回收机制,必须自己释放分配的内存,否则就会造成内存泄露。解决这个问题最有效的方法是使用智能指针&…

传染病防控宣传小程序的设计

管理员账户功能包括:系统首页,个人中心,用户管理,防控知识管理,医院信息管理,健康上报管理,医疗捐赠管理,捐赠信息管理,系统管理 微信端账号功能包括:系统首…

力扣面试经典算法150题:买卖股票的最佳时机 II

买卖股票的最佳时机 II 今天的题目是力扣面试经典150题中的数组的中等难度题:买卖股票的最佳时机 II。 题目链接:https://leetcode.cn/problems/best-time-to-buy-and-sell-stock-ii/description/?envTypestudy-plan-v2&envIdtop-interview-150 问…

EfficientFormer 系列算法

1. EfficientFormer V1 模型 论文地址:https://proceedings.neurips.cc/paper_files/paper/2022/file/5452ad8ee6ea6e7dc41db1cbd31ba0b8-Paper-Conference.pdf EfficientFormer V1 基于 ViT 的模型中使用的网络架构和具体的算子,找到端侧低效的原因。然…

深入剖析资产负债率与净资产收益率,掌握财务报表解读技巧

一、概述 财务报表中蕴含了丰富的信息,如果我们在解读时没有清晰的思路,忽略重点,就很容易被庞杂的数据搞得晕头转向。本文将从几个关键指标出发,包括资产负债率的分析、净资产收益率的解读,以及如何计算销售复合增长…

企业高性能web服务器——nginx

一、web基础介绍 Apache 和 Nginx 是当今为互联网提供动力的最流行的Web 服务器。 1.1、apache服务器 1.1.1、Apache prefork 模型 预派生模式,有一个主控制进程,然后生成多个子进程,使用select模型,最大并发1024每个子进程有一…

萌啦数据ozon怎么用,萌啦数据ozon使用教程

在跨境电商的浩瀚蓝海中,Ozon作为俄罗斯及独联体地区领先的电商平台,正吸引着越来越多中国卖家的目光。而“萌啦数据”作为专为跨境电商卖家打造的数据分析工具,其针对Ozon平台的功能更是让众多商家如虎添翼。今天,我们就来详细探…

后悔和父母出游的年轻人,正在计划带宠物旅行

文 | 螳螂观察 作者 | 青月 美编 |赵倩 相比于和父母一起出门远游,现在越来越多的95后“铲屎官”似乎更愿意和自家的宠物们组“旅游搭子”。 这听起来可能有些刺耳,但其实是当下很多年轻人的心声。 “带父母一起去北京玩,本来打算第二天…

【 每日一题 | 计算机网络】定长子网划分

重要知识点讲解 我们首先需要了解一下无分类CIDR的编址格式x.x.x/24,表示有24位的网路号,那么相应的主机号为32-248位子网掩码(很重要),用来表示IP地址中标识网络号以及子网号的,也就是说如果要进行子网划…

鸿蒙内核源码分析(中断切换篇) | 系统因中断活力四射

关于中断部分系列篇将用三篇详细说明整个过程. 中断概念篇 中断概念很多,比如中断控制器,中断源,中断向量,中断共享,中断处理程序等等.本篇做一次整理.先了解透概念才好理解中断过程.用海公公打比方说明白中断各个概念…

Windows 环境下 Go 语言使用第三方压缩包 gozstd 的报错处理

该文章主要记录在windows平台用go语言使用gozstd包时,遇到的错误及处理过程(踩坑之旅)! 一、gozstd简介 gozstd是一个针对Zstandard(简称Zstd)的Go语言包装器,它提供了简单且高效的API&#xf…

金山云Q2调整后EBITDA率提升至3.2% 高质量发展驱动经营质效双增

8月20日,金山云公布了2024年第二季度业绩。 季度内,金山云整体业绩延续向好态势,实现收入规模、盈利能力、经营现金流的联动共赢。财报显示,金山云Q2营收18.9亿元,公有云实现收入12.3亿元,行业云实现收入6…

The Sandbox 新提案: 2024 年亚洲和拉丁美洲区块链活动预算

理事会建议: 积极 🙂 内容 此提案请求为2024年第四季度,The Sandbox 在东南亚和拉丁美洲的主要区块链活动中的激活分配 94,500 美元的 SAND 倡议预算。(具体活动列表见下方活动描述) 原因 区域团队希望在这些现场活…

国际校企合作|深信服、常州信息职业技术学院、马来西亚汽车工业大学三方国际化人才培养合作签约仪式圆满成功

2024年8月19日,深信服科技股份有限公司与常州信息职业技术学院、马来西亚汽车工业大学正式签署了具有里程碑意义的国际校企合作协议。此次签约不仅是“教随产出、校企同行”理念的一次成功实践,更是中马两国友谊与合作的象征。 常州信息职业技术学院党委…

面试题目:(4)给表达式添加运算符

目录 题目 代码 思路解析 例子 题目 题目 给定一个仅包含数字 0-9 的字符串 num 和一个目标值整数 target &#xff0c;在 num 的数字之间添加 二元 运算符&#xff08;不是一元&#xff09;、- 或 * &#xff0c;返回 所有能够得到 target 的表达式。1 < num.length &…

【JVM】深入理解类加载机制(一)

深入理解类加载机制 Klass模型 Java的每个类&#xff0c;在JVM中都有一个对应的Klass类实例与之对应&#xff0c;存储类的元信息如:常量池、属性信息、方法信息…从继承关系上也能看出来&#xff0c;类的元信息是存储在元空间的。普通的Java类在JVM中对应的是InstanceKlass(C)…

便利店(超市)管理系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言具体实现截图详细视频演示技术栈系统测试为什么选择我官方认证玩家&#xff0c;服务很多代码文档&#xff0c;百分百好评&#xff0c;战绩可查&#xff01;&#xff01;入职于互联网大厂&#xff0c;可以交流&#xff0c;共同进步。有保障的售后 代码参考数据库参…

层次聚类算法原理及Python实现

层次聚类算法&#xff08;Hierarchical Clustering Method&#xff09;是一种基于簇间相似度在不同层次上分析数据&#xff0c;从而形成树形聚类结构的算法。它主要分为两种形式&#xff1a;凝聚层次聚类&#xff08;自下而上&#xff09;和分裂层次聚类&#xff08;自上而下&a…

ansible --------拓展

编辑 hosts 配置文件 [rootmo ~]# vim /etc/ansible/hosts # 创建目录 [rootmo ~]# mkdir /etc/ansible/playbook # 编辑配置文件 [rootmo ~]# vim /etc/ansible/playbook/nginx.yml # 执行测试 [rootmo ~]# ansible-playbook /etc/ansible/playbook/nginx.yml roles 修…