第G4周:CGAN|生成手势图像 | 可控制生成

news2024/11/13 10:16:47

本文为🔗365天深度学习训练营 中的学习记录博客
原作者:K同学啊

理论知识:
条件生成对抗网络(CGAN)是在生成对抗网络(GAN)的基础上进行了一些改进。对于原始GAN的生成器而言,其生成的图像数据是随机不可预测的,因此我们无法控制网络的输出,在实际操作中的可控性不强。
针对上述原始GAN无法生成具有特定属性的图像数据的问题,Mehdi Mirza等人在2014年提出了条件生成对抗网络,通过给原始生成对抗网络中的生成器G和判别器D增加额外的条件,例如我们需要生成器G生成一张没有阴影的图像,此时判别器D就需要判断生成器所生成的图像是否是一张没有阴影的图像。条件生成对抗网络的本质是将额外添加的信息融入到生成器和判别器中,其中添加的信息可以是图像的类别、人脸表情和其他辅助信息等,旨在把无监督学习的GAN转化为有监督学习的CGAN,便于网络能够在我们的掌控下更好地进行训练。CGAN网络结构如图1所示。

图1:条件生成对抗网络结构
在这里插入图片描述

由图1的网络结构可知,条件信息y作为额外的输入被引入对抗网络中,与生成器中的噪声z合并作为隐含层表达;而在判别器D中,条件信息y则与原始数据x合并作为判别函数的输入。这种改进在以后的诸多方面研究中被证明是非常有效的,也为后续的相关工作提供了积极的指导作用。

我的环境:
语言环境:Python3.10.11
编译器:Jupyter Notebook
深度学习框架:Pytorch 2.2.2+cpu
数据集:手势数据集

一、准备工作

from torchvision       import datasets, transforms
from torch.autograd    import Variable
from torchvision.utils import save_image, make_grid
from torchsummary      import summary
import matplotlib.pyplot as plt
import numpy             as np
import torch.nn          as nn
import torch.optim       as optim
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

代码输出结果:

device(type='cpu')
  1. 导入数据
batch_size = 128

train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root='E:/365-jinjieying/GAN rumenshizhan/G4/data/rps/rps/', 
                                     transform=train_transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True,
                                           num_workers=6)
  1. 数据可视化

关于ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))详解:

1、ax: 这是一个matplotlib的轴对象(axis),用于在图形上放置图像。通常,它用于创建子图。
2、make_grid(images.detach(), nrow=10): 这是一个函数调用。make_grid函数的作用是将一组图像拼接成一个网格。它接受两个参数:images和nrow。images是一个包含图像的张量,nrow是可选参数,表示每行显示的图像数量。在这里,它将图像进行拼接,并设置每行显示10个图像。
3、permute(1, 2, 0): 这是一个张量的操作,用于交换维度的顺序。在这里,对于一个3维的张量(假设图像维度为(C,H,W),其中C是通道数,H是高度,W是宽度),permute(1, 2, 0)将把通道维度(C)移动到最后,而将高度和宽度维度(H,W)放在前面。这样做是为了符合matplotlib对图像的要求,因为matplotlib要求图像的维度为(H,W,C)。
4、imshow(…): 这是matplotlib的一个函数,用于显示图像。在这里,它接受一个拼接好并且维度已经调整好的图像张量,并将其显示在之前创建的轴对象(ax)上。

# 可视化第一个 batch 的数据
def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

代码输出结果:
在这里插入图片描述

二、构建模型

latent_dim    = 100
n_classes     = 3
embedding_dim = 100
  1. 权重初始化
# 自定义权重初始化函数,用于初始化生成器和判别器的权重
def weights_init(m):
    # 获取当前层的类名
    classname = m.__class__.__name__

    # 如果当前层是卷积层(类名中包含 'Conv' )
    if classname.find('Conv') != -1:
        # 使用正态分布随机初始化权重,均值为0,标准差为0.02
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    
    # 如果当前层是批归一化层(类名中包含 'BatchNorm' )
    elif classname.find('BatchNorm') != -1:
        # 使用正态分布随机初始化权重,均值为1,标准差为0.02
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        # 将偏置项初始化为全零
        torch.nn.init.zeros_(m.bias)
  1. 构建生成器
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # 定义条件标签的生成器部分,用于将标签映射到嵌入空间中
        # n_classes:条件标签的总数
        # embedding_dim:嵌入空间的维度
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),  # 使用Embedding层将条件标签映射为稠密向量
            nn.Linear(embedding_dim, 16)             # 使用线性层将稠密向量转换为更高维度
        )

        # 定义潜在向量的生成器部分,用于将噪声向量映射到图像空间中
        # latent_dim:潜在向量的维度
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),  # 使用线性层将潜在向量转换为更高维度
            nn.LeakyReLU(0.2, inplace=True)  # 使用LeakyReLU激活函数进行非线性映射
        )

        # 定义生成器的主要结构,将条件标签和潜在向量合并成生成的图像
        self.model = nn.Sequential(
            # 反卷积层1:将合并后的向量映射为64x8x8的特征图
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  # 批标准化
            nn.ReLU(True),  # ReLU激活函数
            # 反卷积层2:将64x8x8的特征图映射为64x4x4的特征图
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # 反卷积层3:将64x4x4的特征图映射为64x2x2的特征图
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # 反卷积层4:将64x2x2的特征图映射为64x1x1的特征图
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # 反卷积层5:将64x1x1的特征图映射为3x64x64的RGB图像
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()  # 使用Tanh激活函数将生成的图像像素值映射到[-1, 1]范围内
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        # 通过条件标签生成器将标签映射为嵌入向量
        label_output = self.label_conditioned_generator(label)
        # 将嵌入向量的形状变为(batch_size, 1, 4, 4),以便与潜在向量进行合并
        label_output = label_output.view(-1, 1, 4, 4)
        # 通过潜在向量生成器将噪声向量映射为潜在向量
        latent_output = self.latent(noise_vector)
        # 将潜在向量的形状变为(batch_size, 512, 4, 4),以便与条件标签进行合并
        latent_output = latent_output.view(-1, 512, 4, 4)
        
        # 将条件标签和潜在向量在通道维度上进行合并,得到合并后的特征图
        concat = torch.cat((latent_output, label_output), dim=1)
        # 通过生成器的主要结构将合并后的特征图生成为RGB图像
        image = self.model(concat)
        return image
generator = Generator().to(device)
generator.apply(weights_init)
print(generator)

代码输出结果:

Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
from torchinfo import summary

summary(generator)

代码输出结果:

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Generator                                --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       1,616
├─Sequential: 1-2                        --
│    └─Linear: 2-3                       827,392
│    └─LeakyReLU: 2-4                    --
├─Sequential: 1-3                        --
│    └─ConvTranspose2d: 2-5              4,202,496
│    └─BatchNorm2d: 2-6                  1,024
│    └─ReLU: 2-7                         --
│    └─ConvTranspose2d: 2-8              2,097,152
│    └─BatchNorm2d: 2-9                  512
│    └─ReLU: 2-10                        --
│    └─ConvTranspose2d: 2-11             524,288
│    └─BatchNorm2d: 2-12                 256
│    └─ReLU: 2-13                        --
│    └─ConvTranspose2d: 2-14             131,072
│    └─BatchNorm2d: 2-15                 128
│    └─ReLU: 2-16                        --
│    └─ConvTranspose2d: 2-17             3,072
│    └─Tanh: 2-18                        --
=================================================================
Total params: 7,789,308
Trainable params: 7,789,308
Non-trainable params: 0
=================================================================
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)
generator((a,b))

代码输出结果:

tensor([[[[ 1.0031e-03,  2.0869e-03,  1.9845e-03,  ...,  3.0130e-03,
            2.6252e-03,  2.2065e-03],
          [ 1.1618e-04, -3.6158e-03,  5.6959e-03,  ..., -8.4364e-03,
            3.1406e-03,  5.5108e-04],
          [ 1.4334e-03,  4.0221e-03,  4.0934e-03,  ...,  2.8333e-04,
            2.6435e-03,  1.3885e-03],
          ...,
          [-4.1486e-04, -2.7967e-03,  6.7204e-04,  ..., -4.9740e-03,
            7.0077e-04,  1.1249e-04],
          [ 6.3077e-04,  6.6072e-03,  1.5604e-03,  ...,  3.0677e-03,
            2.2500e-03,  2.8187e-03],
          [ 1.7245e-04, -2.3452e-03,  2.3115e-03,  ..., -3.5958e-04,
           -2.4765e-03,  7.9562e-04]],

         [[-3.7547e-04, -4.1838e-03, -2.8569e-03,  ..., -9.1987e-04,
           -2.1239e-03,  1.5380e-03],
          [-4.7561e-04, -7.1297e-04, -1.6785e-03,  ...,  1.6375e-03,
            1.3578e-03,  1.7257e-04],
          [-2.1179e-03, -2.9905e-03, -1.1075e-03,  ..., -4.4021e-03,
           -7.6593e-03, -2.1271e-03],
          ...,
          [ 7.7395e-04, -2.2313e-03, -4.4506e-04,  ..., -3.2353e-04,
            3.1063e-04, -7.9383e-04],
          [-1.8639e-03, -1.7142e-03, -1.6943e-03,  ..., -3.4480e-03,
           -7.4053e-03, -8.9081e-04],
          [ 2.2234e-03,  4.7096e-03,  2.6137e-03,  ...,  5.2392e-03,
            3.5531e-03,  9.4842e-04]],

         [[-1.8115e-04,  1.1916e-03,  1.8226e-03,  ...,  1.0432e-03,
           -5.6756e-04, -1.5127e-03],
          [-8.8927e-04,  2.5289e-03, -6.3127e-04,  ...,  3.6392e-03,
           -1.7966e-03,  4.8274e-04],
          [-1.2093e-03, -3.8731e-03, -4.6937e-03,  ...,  1.2670e-03,
           -6.5182e-03,  1.5792e-03],
          ...,
          [ 7.5084e-04,  8.0225e-04, -1.1725e-03,  ...,  1.4274e-05,
           -3.2307e-05, -1.4468e-04],
          [-1.7094e-03, -3.2152e-03, -3.3625e-03,  ..., -9.7879e-04,
           -5.1849e-03,  2.4819e-03],
          [-2.5147e-04, -1.2123e-03, -2.9802e-04,  ..., -1.9884e-03,
           -5.4956e-04, -1.6835e-03]]]], grad_fn=<TanhBackward0>)
  1. 构建鉴别器
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # 定义一个条件标签的嵌入层,用于将类别标签转换为特征向量
        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),     # 嵌入层将类别标签编码为固定长度的向量
            nn.Linear(embedding_dim, 3*128*128)         # 线性层将嵌入的向量转换为与图像尺寸相匹配的特征张量
        )
        
        # 定义主要的鉴别器模型
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),       # 输入通道为6(包含图像和标签的通道数),输出通道为64,4x4的卷积核,步长为2,padding为1
            nn.LeakyReLU(0.2, inplace=True),             # LeakyReLU激活函数,带有负斜率,增加模型对输入中的负值的感知能力
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    # 输入通道为64,输出通道为64*2,4x4的卷积核,步长为3,padding为2
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  # 批量归一化层,有利于训练稳定性和收敛速度
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  # 输入通道为64*2,输出通道为64*4,4x4的卷积核,步长为3,padding为2
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  # 输入通道为64*4,输出通道为64*8,4x4的卷积核,步长为3,padding为2
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),                               # 将特征图展平为一维向量,用于后续全连接层处理
            nn.Dropout(0.4),                            # 随机失活层,用于减少过拟合风险
            nn.Linear(4608, 1),                         # 全连接层,将特征向量映射到输出维度为1的向量
            nn.Sigmoid()                                # Sigmoid激活函数,用于输出范围限制在0到1之间的概率值
        )

    def forward(self, inputs):
        img, label = inputs
        
        # 将类别标签转换为特征向量
        label_output = self.label_condition_disc(label)
        # 重塑特征向量为与图像尺寸相匹配的特征张量
        label_output = label_output.view(-1, 3, 128, 128)
        
        # 将图像特征和标签特征拼接在一起作为鉴别器的输入
        concat = torch.cat((img, label_output), dim=1)
        
        # 将拼接后的输入通过鉴别器模型进行前向传播,得到输出结果
        output = self.model(concat)
        return output
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

代码输出结果:

Discriminator(
  (label_condition_disc): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=49152, bias=True)
  )
  (model): Sequential(
    (0): Conv2d(6, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (6): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(3, 3), padding=(2, 2), bias=False)
    (9): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Flatten(start_dim=1, end_dim=-1)
    (12): Dropout(p=0.4, inplace=False)
    (13): Linear(in_features=4608, out_features=1, bias=True)
    (14): Sigmoid()
  )
)
summary(discriminator)

代码输出结果:

=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
Discriminator                            --
├─Sequential: 1-1                        --
│    └─Embedding: 2-1                    300
│    └─Linear: 2-2                       4,964,352
├─Sequential: 1-2                        --
│    └─Conv2d: 2-3                       6,144
│    └─LeakyReLU: 2-4                    --
│    └─Conv2d: 2-5                       131,072
│    └─BatchNorm2d: 2-6                  256
│    └─LeakyReLU: 2-7                    --
│    └─Conv2d: 2-8                       524,288
│    └─BatchNorm2d: 2-9                  512
│    └─LeakyReLU: 2-10                   --
│    └─Conv2d: 2-11                      2,097,152
│    └─BatchNorm2d: 2-12                 1,024
│    └─LeakyReLU: 2-13                   --
│    └─Flatten: 2-14                     --
│    └─Dropout: 2-15                     --
│    └─Linear: 2-16                      4,609
│    └─Sigmoid: 2-17                     --
=================================================================
Total params: 7,729,709
Trainable params: 7,729,709
Non-trainable params: 0
=================================================================
a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)
c = discriminator((a,b))
c.size()

代码输出结果:

torch.Size([2, 1])

三、训练模型

  1. 定义损失函数
adversarial_loss = nn.BCELoss() 

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
  1. 定义优化器
learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))
  1. 训练模型

# 设置训练的总轮数
num_epochs = 300
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []

# 循环进行训练
for epoch in range(1, num_epochs + 1):
    
    # 初始化每轮训练中判别器和生成器损失的临时列表
    D_loss_list, G_loss_list = [], []
    
    # 遍历训练数据加载器中的数据
    for index, (real_images, labels) in enumerate(train_loader):
        # 清空判别器的梯度缓存
        D_optimizer.zero_grad()
        # 将真实图像数据和标签转移到GPU(如果可用)
        real_images = real_images.to(device)
        labels      = labels.to(device)
        
        # 将标签的形状从一维向量转换为二维张量(用于后续计算)
        labels = labels.unsqueeze(1).long()
        # 创建真实目标和虚假目标的张量(用于判别器损失函数)
        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        # 计算判别器对真实图像的损失
        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
        
        # 从噪声向量中生成假图像(生成器的输入)
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))
        
        # 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        # 计算判别器总体损失(真实图像损失和假图像损失的平均值)
        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss.item())

        # 反向传播更新判别器的参数
        D_total_loss.backward()
        D_optimizer.step()

        # 清空生成器的梯度缓存
        G_optimizer.zero_grad()
        # 计算生成器的损失
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss.item())

        # 反向传播更新生成器的参数
        G_loss.backward()
        G_optimizer.step()

    # 打印当前轮次的判别器和生成器的平均损失
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), 
            torch.mean(torch.FloatTensor(G_loss_list))))
    
    # 将当前轮次的判别器和生成器的平均损失保存到列表中
    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch%10 == 0:
        # 将生成的假图像保存为图片文件
        save_image(generated_image.data[:50], './G4/images/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        # 将当前轮次的生成器和判别器的权重保存到文件
        torch.save(generator.state_dict(), './G4/training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './G4/training_weights/discriminator_epoch_%d.pth' % (epoch))

代码输出结果:

Epoch: [1/300]: D_loss: 0.317, G_loss: 1.466
Epoch: [2/300]: D_loss: 0.119, G_loss: 3.441
Epoch: [3/300]: D_loss: 0.194, G_loss: 3.075
Epoch: [4/300]: D_loss: 0.326, G_loss: 2.409
Epoch: [5/300]: D_loss: 0.303, G_loss: 2.349
Epoch: [6/300]: D_loss: 0.247, G_loss: 2.302
Epoch: [7/300]: D_loss: 0.253, G_loss: 2.743
Epoch: [8/300]: D_loss: 0.330, G_loss: 2.541
Epoch: [9/300]: D_loss: 0.424, G_loss: 2.501
Epoch: [10/300]: D_loss: 0.486, G_loss: 1.698
Epoch: [11/300]: D_loss: 0.447, G_loss: 1.587
Epoch: [12/300]: D_loss: 0.497, G_loss: 1.635
Epoch: [13/300]: D_loss: 0.466, G_loss: 1.665
Epoch: [14/300]: D_loss: 0.482, G_loss: 1.764
Epoch: [15/300]: D_loss: 0.526, G_loss: 1.563
Epoch: [16/300]: D_loss: 0.488, G_loss: 1.584
Epoch: [17/300]: D_loss: 0.457, G_loss: 1.556
Epoch: [18/300]: D_loss: 0.441, G_loss: 1.597
Epoch: [19/300]: D_loss: 0.457, G_loss: 1.696
Epoch: [20/300]: D_loss: 0.461, G_loss: 1.671
Epoch: [21/300]: D_loss: 0.434, G_loss: 1.704
Epoch: [22/300]: D_loss: 0.419, G_loss: 1.773
Epoch: [23/300]: D_loss: 0.402, G_loss: 1.798
Epoch: [24/300]: D_loss: 0.465, G_loss: 1.836
Epoch: [25/300]: D_loss: 0.419, G_loss: 1.789
Epoch: [26/300]: D_loss: 0.443, G_loss: 1.719
Epoch: [27/300]: D_loss: 0.414, G_loss: 1.698
Epoch: [28/300]: D_loss: 0.451, G_loss: 1.607
Epoch: [29/300]: D_loss: 0.467, G_loss: 1.563
Epoch: [30/300]: D_loss: 0.478, G_loss: 1.504
Epoch: [31/300]: D_loss: 0.462, G_loss: 1.471
Epoch: [32/300]: D_loss: 0.468, G_loss: 1.462
Epoch: [33/300]: D_loss: 0.467, G_loss: 1.461
Epoch: [34/300]: D_loss: 0.463, G_loss: 1.457
Epoch: [35/300]: D_loss: 0.474, G_loss: 1.438
Epoch: [36/300]: D_loss: 0.450, G_loss: 1.390
Epoch: [37/300]: D_loss: 0.454, G_loss: 1.521
Epoch: [38/300]: D_loss: 0.431, G_loss: 1.494
Epoch: [39/300]: D_loss: 0.487, G_loss: 1.557
Epoch: [40/300]: D_loss: 0.439, G_loss: 1.563
Epoch: [41/300]: D_loss: 0.425, G_loss: 1.586
Epoch: [42/300]: D_loss: 0.456, G_loss: 1.626
Epoch: [43/300]: D_loss: 0.433, G_loss: 1.605
Epoch: [44/300]: D_loss: 0.428, G_loss: 1.625
Epoch: [45/300]: D_loss: 0.433, G_loss: 1.649
Epoch: [46/300]: D_loss: 0.421, G_loss: 1.677
Epoch: [47/300]: D_loss: 0.449, G_loss: 1.694
Epoch: [48/300]: D_loss: 0.399, G_loss: 1.697
Epoch: [49/300]: D_loss: 0.500, G_loss: 1.936
Epoch: [50/300]: D_loss: 0.392, G_loss: 1.715
Epoch: [51/300]: D_loss: 0.372, G_loss: 1.660
Epoch: [52/300]: D_loss: 0.417, G_loss: 1.767
Epoch: [53/300]: D_loss: 0.412, G_loss: 1.732
Epoch: [54/300]: D_loss: 0.492, G_loss: 1.843
Epoch: [55/300]: D_loss: 0.412, G_loss: 1.892
Epoch: [56/300]: D_loss: 0.354, G_loss: 1.720
Epoch: [57/300]: D_loss: 0.407, G_loss: 1.831
Epoch: [58/300]: D_loss: 0.374, G_loss: 1.772
Epoch: [59/300]: D_loss: 0.388, G_loss: 1.871
Epoch: [60/300]: D_loss: 0.400, G_loss: 1.936
Epoch: [61/300]: D_loss: 0.436, G_loss: 1.920
Epoch: [62/300]: D_loss: 0.344, G_loss: 1.839
Epoch: [63/300]: D_loss: 0.342, G_loss: 1.917
Epoch: [64/300]: D_loss: 0.352, G_loss: 1.939
Epoch: [65/300]: D_loss: 0.368, G_loss: 2.024
Epoch: [66/300]: D_loss: 0.375, G_loss: 2.005
Epoch: [67/300]: D_loss: 0.484, G_loss: 2.181
Epoch: [68/300]: D_loss: 0.344, G_loss: 1.982
Epoch: [69/300]: D_loss: 0.332, G_loss: 1.977
Epoch: [70/300]: D_loss: 0.331, G_loss: 2.023
Epoch: [71/300]: D_loss: 0.355, G_loss: 2.098
Epoch: [72/300]: D_loss: 0.342, G_loss: 2.120
Epoch: [73/300]: D_loss: 0.351, G_loss: 2.168
Epoch: [74/300]: D_loss: 0.350, G_loss: 2.141
Epoch: [75/300]: D_loss: 0.338, G_loss: 2.206
Epoch: [76/300]: D_loss: 0.295, G_loss: 2.190
Epoch: [77/300]: D_loss: 0.317, G_loss: 2.214
Epoch: [78/300]: D_loss: 0.392, G_loss: 2.321
Epoch: [79/300]: D_loss: 0.303, G_loss: 2.225
Epoch: [80/300]: D_loss: 0.328, G_loss: 2.273
Epoch: [81/300]: D_loss: 0.584, G_loss: 2.558
Epoch: [82/300]: D_loss: 0.298, G_loss: 2.130
Epoch: [83/300]: D_loss: 0.286, G_loss: 2.182
Epoch: [84/300]: D_loss: 0.281, G_loss: 2.258
Epoch: [85/300]: D_loss: 0.315, G_loss: 2.300
Epoch: [86/300]: D_loss: 0.293, G_loss: 2.285
Epoch: [87/300]: D_loss: 0.263, G_loss: 2.310
Epoch: [88/300]: D_loss: 0.295, G_loss: 2.370
Epoch: [89/300]: D_loss: 0.310, G_loss: 2.462
Epoch: [90/300]: D_loss: 0.291, G_loss: 2.466
Epoch: [91/300]: D_loss: 0.565, G_loss: 2.573
Epoch: [92/300]: D_loss: 0.291, G_loss: 2.351
Epoch: [93/300]: D_loss: 0.255, G_loss: 2.417
Epoch: [94/300]: D_loss: 0.258, G_loss: 2.455
Epoch: [95/300]: D_loss: 0.278, G_loss: 2.496
Epoch: [96/300]: D_loss: 0.534, G_loss: 2.803
Epoch: [97/300]: D_loss: 0.561, G_loss: 2.472
Epoch: [98/300]: D_loss: 0.285, G_loss: 2.326
Epoch: [99/300]: D_loss: 0.255, G_loss: 2.395
Epoch: [100/300]: D_loss: 0.283, G_loss: 2.444
Epoch: [101/300]: D_loss: 0.272, G_loss: 2.447
Epoch: [102/300]: D_loss: 0.281, G_loss: 2.547
Epoch: [103/300]: D_loss: 0.246, G_loss: 2.540
Epoch: [104/300]: D_loss: 0.255, G_loss: 2.487
Epoch: [105/300]: D_loss: 0.276, G_loss: 2.647
Epoch: [106/300]: D_loss: 0.243, G_loss: 2.581
Epoch: [107/300]: D_loss: 0.284, G_loss: 2.665
Epoch: [108/300]: D_loss: 0.356, G_loss: 2.767
Epoch: [109/300]: D_loss: 0.325, G_loss: 2.608
Epoch: [110/300]: D_loss: 0.258, G_loss: 2.635
Epoch: [111/300]: D_loss: 0.260, G_loss: 2.638
Epoch: [112/300]: D_loss: 0.245, G_loss: 2.639
Epoch: [113/300]: D_loss: 0.235, G_loss: 2.760
Epoch: [114/300]: D_loss: 0.268, G_loss: 2.682
Epoch: [115/300]: D_loss: 0.254, G_loss: 2.810
Epoch: [116/300]: D_loss: 0.373, G_loss: 2.840
Epoch: [117/300]: D_loss: 0.251, G_loss: 2.751
Epoch: [118/300]: D_loss: 0.225, G_loss: 2.770
Epoch: [119/300]: D_loss: 0.235, G_loss: 2.848
Epoch: [120/300]: D_loss: 0.601, G_loss: 3.162
Epoch: [121/300]: D_loss: 0.235, G_loss: 2.735
Epoch: [122/300]: D_loss: 0.230, G_loss: 2.797
Epoch: [123/300]: D_loss: 0.205, G_loss: 2.831
Epoch: [124/300]: D_loss: 0.204, G_loss: 2.881
Epoch: [125/300]: D_loss: 0.228, G_loss: 2.906
Epoch: [126/300]: D_loss: 0.219, G_loss: 2.861
Epoch: [127/300]: D_loss: 0.277, G_loss: 3.000
Epoch: [128/300]: D_loss: 0.302, G_loss: 3.143
Epoch: [129/300]: D_loss: 0.296, G_loss: 2.934
Epoch: [130/300]: D_loss: 0.227, G_loss: 2.874
Epoch: [131/300]: D_loss: 0.239, G_loss: 3.007
Epoch: [132/300]: D_loss: 0.284, G_loss: 2.962
Epoch: [133/300]: D_loss: 0.234, G_loss: 2.940
Epoch: [134/300]: D_loss: 0.217, G_loss: 3.016
Epoch: [135/300]: D_loss: 0.212, G_loss: 3.039
Epoch: [136/300]: D_loss: 0.227, G_loss: 3.104
Epoch: [137/300]: D_loss: 0.497, G_loss: 3.176
Epoch: [138/300]: D_loss: 0.229, G_loss: 2.908
Epoch: [139/300]: D_loss: 0.207, G_loss: 3.000
Epoch: [140/300]: D_loss: 0.321, G_loss: 3.066
Epoch: [141/300]: D_loss: 0.227, G_loss: 3.061
Epoch: [142/300]: D_loss: 0.194, G_loss: 3.103
Epoch: [143/300]: D_loss: 0.199, G_loss: 3.133
Epoch: [144/300]: D_loss: 0.287, G_loss: 3.104
Epoch: [145/300]: D_loss: 0.216, G_loss: 3.148
Epoch: [146/300]: D_loss: 0.207, G_loss: 3.178
Epoch: [147/300]: D_loss: 0.226, G_loss: 3.231
Epoch: [148/300]: D_loss: 0.184, G_loss: 3.292
Epoch: [149/300]: D_loss: 0.302, G_loss: 3.468
Epoch: [150/300]: D_loss: 0.574, G_loss: 2.866
Epoch: [151/300]: D_loss: 0.218, G_loss: 3.066
Epoch: [152/300]: D_loss: 0.208, G_loss: 2.998
Epoch: [153/300]: D_loss: 0.212, G_loss: 3.073
Epoch: [154/300]: D_loss: 0.207, G_loss: 3.224
Epoch: [155/300]: D_loss: 0.176, G_loss: 3.168
Epoch: [156/300]: D_loss: 0.198, G_loss: 3.295
Epoch: [157/300]: D_loss: 0.191, G_loss: 3.271
Epoch: [158/300]: D_loss: 0.315, G_loss: 3.324
Epoch: [159/300]: D_loss: 0.717, G_loss: 3.085
Epoch: [160/300]: D_loss: 0.245, G_loss: 2.947
Epoch: [161/300]: D_loss: 0.199, G_loss: 3.082
Epoch: [162/300]: D_loss: 0.181, G_loss: 3.173
Epoch: [163/300]: D_loss: 0.180, G_loss: 3.343
Epoch: [164/300]: D_loss: 0.182, G_loss: 3.205
Epoch: [165/300]: D_loss: 0.179, G_loss: 3.409
Epoch: [166/300]: D_loss: 0.195, G_loss: 3.355
Epoch: [167/300]: D_loss: 0.195, G_loss: 3.273
Epoch: [168/300]: D_loss: 0.264, G_loss: 3.330
Epoch: [169/300]: D_loss: 0.269, G_loss: 3.302
Epoch: [170/300]: D_loss: 0.180, G_loss: 3.333
Epoch: [171/300]: D_loss: 0.200, G_loss: 3.352
Epoch: [172/300]: D_loss: 0.163, G_loss: 3.417
Epoch: [173/300]: D_loss: 0.267, G_loss: 3.426
Epoch: [174/300]: D_loss: 0.175, G_loss: 3.398
Epoch: [175/300]: D_loss: 0.201, G_loss: 3.431
Epoch: [176/300]: D_loss: 0.162, G_loss: 3.589
Epoch: [177/300]: D_loss: 0.271, G_loss: 3.522
Epoch: [178/300]: D_loss: 0.203, G_loss: 3.478
Epoch: [179/300]: D_loss: 0.234, G_loss: 3.439
Epoch: [180/300]: D_loss: 0.178, G_loss: 3.442
Epoch: [181/300]: D_loss: 0.172, G_loss: 3.618
Epoch: [182/300]: D_loss: 0.221, G_loss: 3.475
Epoch: [183/300]: D_loss: 0.157, G_loss: 3.620
Epoch: [184/300]: D_loss: 0.158, G_loss: 3.660
Epoch: [185/300]: D_loss: 0.235, G_loss: 3.713
Epoch: [186/300]: D_loss: 0.246, G_loss: 3.531
Epoch: [187/300]: D_loss: 0.237, G_loss: 3.516
Epoch: [188/300]: D_loss: 0.198, G_loss: 3.526
Epoch: [189/300]: D_loss: 0.184, G_loss: 3.643
Epoch: [190/300]: D_loss: 0.200, G_loss: 3.579
Epoch: [191/300]: D_loss: 0.322, G_loss: 3.644
Epoch: [192/300]: D_loss: 0.218, G_loss: 3.369
Epoch: [193/300]: D_loss: 0.171, G_loss: 3.682
Epoch: [194/300]: D_loss: 0.202, G_loss: 3.514
Epoch: [195/300]: D_loss: 0.185, G_loss: 3.708
Epoch: [196/300]: D_loss: 0.170, G_loss: 3.656
Epoch: [197/300]: D_loss: 0.175, G_loss: 3.640
Epoch: [198/300]: D_loss: 0.293, G_loss: 3.677
Epoch: [199/300]: D_loss: 0.165, G_loss: 3.678
Epoch: [200/300]: D_loss: 0.195, G_loss: 3.716
Epoch: [201/300]: D_loss: 0.179, G_loss: 3.801
Epoch: [202/300]: D_loss: 0.189, G_loss: 3.747
Epoch: [203/300]: D_loss: 0.219, G_loss: 3.774
Epoch: [204/300]: D_loss: 0.212, G_loss: 3.667
Epoch: [205/300]: D_loss: 0.265, G_loss: 3.833
Epoch: [206/300]: D_loss: 0.518, G_loss: 3.526
Epoch: [207/300]: D_loss: 0.166, G_loss: 3.600
Epoch: [208/300]: D_loss: 0.171, G_loss: 3.729
Epoch: [209/300]: D_loss: 0.166, G_loss: 3.676
Epoch: [210/300]: D_loss: 0.143, G_loss: 3.771
Epoch: [211/300]: D_loss: 0.169, G_loss: 3.847
Epoch: [212/300]: D_loss: 0.192, G_loss: 3.752
Epoch: [213/300]: D_loss: 0.146, G_loss: 3.923
Epoch: [214/300]: D_loss: 0.138, G_loss: 3.975
Epoch: [215/300]: D_loss: 0.182, G_loss: 3.931
Epoch: [216/300]: D_loss: 0.259, G_loss: 3.830
Epoch: [217/300]: D_loss: 0.359, G_loss: 3.688
Epoch: [218/300]: D_loss: 0.191, G_loss: 3.619
Epoch: [219/300]: D_loss: 0.197, G_loss: 3.712
Epoch: [220/300]: D_loss: 0.170, G_loss: 3.892
Epoch: [221/300]: D_loss: 0.158, G_loss: 3.830
Epoch: [222/300]: D_loss: 0.174, G_loss: 3.958
Epoch: [223/300]: D_loss: 0.356, G_loss: 3.756
Epoch: [224/300]: D_loss: 0.163, G_loss: 3.833
Epoch: [225/300]: D_loss: 0.165, G_loss: 3.827
Epoch: [226/300]: D_loss: 0.245, G_loss: 3.814
Epoch: [227/300]: D_loss: 0.215, G_loss: 3.875
Epoch: [228/300]: D_loss: 0.143, G_loss: 3.891
Epoch: [229/300]: D_loss: 0.162, G_loss: 3.822
Epoch: [230/300]: D_loss: 0.165, G_loss: 3.825
Epoch: [231/300]: D_loss: 0.141, G_loss: 3.997
Epoch: [232/300]: D_loss: 0.349, G_loss: 3.999
Epoch: [233/300]: D_loss: 0.186, G_loss: 3.756
Epoch: [234/300]: D_loss: 0.203, G_loss: 3.852
Epoch: [235/300]: D_loss: 0.141, G_loss: 3.948
Epoch: [236/300]: D_loss: 0.150, G_loss: 4.067
Epoch: [237/300]: D_loss: 0.124, G_loss: 3.990
Epoch: [238/300]: D_loss: 0.140, G_loss: 4.152
Epoch: [239/300]: D_loss: 0.134, G_loss: 4.240
Epoch: [240/300]: D_loss: 0.287, G_loss: 4.138
Epoch: [241/300]: D_loss: 0.241, G_loss: 3.737
Epoch: [242/300]: D_loss: 0.163, G_loss: 3.928
Epoch: [243/300]: D_loss: 0.156, G_loss: 4.052
Epoch: [244/300]: D_loss: 0.156, G_loss: 4.118
Epoch: [245/300]: D_loss: 0.139, G_loss: 4.274
Epoch: [246/300]: D_loss: 0.275, G_loss: 4.058
Epoch: [247/300]: D_loss: 0.484, G_loss: 3.686
Epoch: [248/300]: D_loss: 0.151, G_loss: 3.774
Epoch: [249/300]: D_loss: 0.135, G_loss: 3.973
Epoch: [250/300]: D_loss: 0.145, G_loss: 3.991
Epoch: [251/300]: D_loss: 0.370, G_loss: 3.788
Epoch: [252/300]: D_loss: 0.147, G_loss: 3.953
Epoch: [253/300]: D_loss: 0.118, G_loss: 3.983
Epoch: [254/300]: D_loss: 0.423, G_loss: 3.963
Epoch: [255/300]: D_loss: 0.185, G_loss: 3.787
Epoch: [256/300]: D_loss: 0.137, G_loss: 3.969
Epoch: [257/300]: D_loss: 0.121, G_loss: 4.092
Epoch: [258/300]: D_loss: 0.123, G_loss: 4.210
Epoch: [259/300]: D_loss: 0.117, G_loss: 4.198
Epoch: [260/300]: D_loss: 0.256, G_loss: 3.950
Epoch: [261/300]: D_loss: 0.154, G_loss: 4.174
Epoch: [262/300]: D_loss: 0.156, G_loss: 4.155
Epoch: [263/300]: D_loss: 0.162, G_loss: 4.018
Epoch: [264/300]: D_loss: 0.157, G_loss: 4.300
Epoch: [265/300]: D_loss: 0.230, G_loss: 4.131
Epoch: [266/300]: D_loss: 0.171, G_loss: 4.015
Epoch: [267/300]: D_loss: 0.141, G_loss: 4.234
Epoch: [268/300]: D_loss: 0.137, G_loss: 4.182
Epoch: [269/300]: D_loss: 0.181, G_loss: 4.344
Epoch: [270/300]: D_loss: 0.337, G_loss: 4.148
Epoch: [271/300]: D_loss: 0.190, G_loss: 4.132
Epoch: [272/300]: D_loss: 0.146, G_loss: 4.203
Epoch: [273/300]: D_loss: 0.114, G_loss: 4.289
Epoch: [274/300]: D_loss: 0.139, G_loss: 4.265
Epoch: [275/300]: D_loss: 0.125, G_loss: 4.394
Epoch: [276/300]: D_loss: 0.116, G_loss: 4.547
Epoch: [277/300]: D_loss: 1.516, G_loss: 3.457
Epoch: [278/300]: D_loss: 0.391, G_loss: 2.590
Epoch: [279/300]: D_loss: 0.290, G_loss: 3.127
Epoch: [280/300]: D_loss: 0.205, G_loss: 3.510
Epoch: [281/300]: D_loss: 0.167, G_loss: 3.637
Epoch: [282/300]: D_loss: 0.189, G_loss: 3.763
Epoch: [283/300]: D_loss: 0.154, G_loss: 3.826
Epoch: [284/300]: D_loss: 0.146, G_loss: 4.023
Epoch: [285/300]: D_loss: 0.119, G_loss: 4.056
Epoch: [286/300]: D_loss: 0.134, G_loss: 4.139
Epoch: [287/300]: D_loss: 0.141, G_loss: 4.132
Epoch: [288/300]: D_loss: 0.134, G_loss: 4.178
Epoch: [289/300]: D_loss: 0.140, G_loss: 4.251
Epoch: [290/300]: D_loss: 0.134, G_loss: 4.323
Epoch: [291/300]: D_loss: 0.175, G_loss: 4.230
Epoch: [292/300]: D_loss: 0.223, G_loss: 4.081
Epoch: [293/300]: D_loss: 0.123, G_loss: 4.274
Epoch: [294/300]: D_loss: 0.109, G_loss: 4.293
Epoch: [295/300]: D_loss: 0.132, G_loss: 4.391
Epoch: [296/300]: D_loss: 0.204, G_loss: 4.356
Epoch: [297/300]: D_loss: 0.166, G_loss: 4.285
Epoch: [298/300]: D_loss: 0.150, G_loss: 4.346
Epoch: [299/300]: D_loss: 0.151, G_loss: 4.312
Epoch: [300/300]: D_loss: 0.445, G_loss: 3.976

四、模型分析

  1. Loss 图
G_loss_list = [i.item() for i in G_loss_plot]
D_loss_list = [i.item() for i in D_loss_plot]
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

plt.figure(figsize=(8,4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_loss_list,label="G")
plt.plot(D_loss_list,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

代码输出结果:
在这里插入图片描述

  1. 生成指定图像
    predictions = predictions.permute(0,2,3,1)详解
    这行代码是一个PyTorch(深度学习框架)中的操作,用于维度的重新排列。让我们逐步解释这行代码的意思:
    假设predictions是一个PyTorch张量(tensor),它的维度为 (batch_size, height, width, channels),其中:

● batch_size:批量大小,表示张量中有多少个样本。
● height:高度,表示图像的高度(或特征图的高度)。
● width:宽度,表示图像的宽度(或特征图的宽度)。
● channels:通道数,表示图像或特征图的通道数,例如RGB图像的通道数为3。

现在,让我们来解释这行代码的操作:

predictions.permute(0, 2, 3, 1)

permute是PyTorch中的一个函数,用于对张量的维度进行重新排列。在这个代码中,permute函数将张量的维度进行重新排列,以得到一个新的张量。具体地说,它将原始张量中的维度按照指定的顺序进行重新排列。
参数说明:

0, 2, 3, 1:这是一个指定新维度顺序的元组。在这里,它表示将原始维度中的第0维移到新张量的第0维,第2维移到新张量的第1维,第3维移到新张量的第2维,最后,第1维移到新张量的第3维。

所以,假设原始张量的形状是 (batch_size, height, width, channels),通过这行代码后,新张量的形状将变为 (batch_size, width, channels, height)。
这种维度重新排列在深度学习中非常常见,尤其是在卷积神经网络(Convolutional Neural Networks,CNNs)中,因为在某些情况下,不同的层需要不同的维度排列。permute函数就是为了帮助我们方便地处理这种情况,使得在不同层之间传递数据时更加高效和便捷。

# 导入所需的库
from numpy.random import randint, randn
from numpy        import linspace
from matplotlib   import pyplot, gridspec

# 导入生成器模型
# generator.load_state_dict(torch.load('./G4/training_weights/generator_epoch_300.pth'), strict=False)
generator.load_state_dict(torch.load('./G4/generator_epoch_300/generator_epoch_300.pth',map_location=torch.device('cpu')), strict=False)

generator.eval()   

interpolated = randn(100)  # 生成两个潜在空间的点
# 将数据转换为torch张量并将其移至GPU(假设device已正确声明为GPU)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

label  = 0  # 手势标签,可在0,1,2之间选择
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()

# 使用生成器生成插值结果
predictions = generator((interpolated, labels))
predictions = predictions.permute(0,2,3,1).detach().cpu()
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

plt.figure(figsize=(8, 3))

pred = (predictions[0, :, :, :] + 1 ) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

代码输出结果(label = 0时,如上代码所述):
在这里插入图片描述

代码输出结果(如果label = 1时):
在这里插入图片描述
代码输出结果(如果label = 2时):
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1951516.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C#基础——类、构造函数和静态成员

类 类是一个数据类型的蓝图。构成类的方法和变量称为类的成员&#xff0c;对象是类的实例。类的定义规定了类的对象由什么组成及在这个对象上可执行什么操作。 class 类名 { (访问属性) 成员变量; (访问属性) 成员函数; } 访问属性&#xff1a;public&#xff08;公有的&…

MinIO对象生命周期

Object Management — MinIO Object Storage for Windowshttps://min.io/docs/minio/windows/administration/object-management.html 1.概念 MinIO 对象生命周期是指对存储在 MinIO 中的对象进行自动管理的一套策略。这些策略可以用于自动删除旧对象、转移对象到不同的存储类别…

Golang高效合并(拼接)多个gzip压缩文件

有时我们可能会遇到需要把多个 gzip 文件合并成单个 gzip 文件的场景&#xff0c;最简单最容易的方式是把每个gzip文件都先解压&#xff0c;然后合并成一个文件后再次进行压缩&#xff0c;最终得到我们想要的结果&#xff0c;但这种先解压后压缩的方式显然效率不高&#xff0c;…

day08:订单状态定时处理、来单提醒和客户催单

文章目录 Spring Task介绍cron表达式入门案例 订单状态定时处理需求分析代码开发扩展 WebSocket介绍入门案例特点 来单提醒需求分析和设计代码实现 客户催单需求分析和设计代码实现 Spring Task 介绍 Spring Task 是Spring框架提供的任务调度工具&#xff0c;可以按照约定的时…

爬虫提速!用Python实现多线程下载器!

✨ 内容&#xff1a; 在网络应用中&#xff0c;下载速度往往是用户体验的关键。多线程下载可以显著提升下载速度&#xff0c;通过将一个文件分成多个部分并行下载&#xff0c;可以更高效地利用带宽资源。今天&#xff0c;我们将通过一个实际案例&#xff0c;学习如何用Python实…

C++ | Leetcode C++题解之第292题Nim游戏

题目&#xff1a; 题解&#xff1a; class Solution { public:bool canWinNim(int n) {return n % 4 ! 0;} };

如何使用API快速打造健康医疗系统?

在数字医疗市场&#xff0c;数据是人们经常谈及的一个话题。当前&#xff0c;消费者医疗和健康应用收集的数据越来越多&#xff0c;电子健康记录的实施也创造出了大量有关病人的电子信息。 API接口在智慧医院跨网、跨机构之间的业务协同和数据共享交换中得到数据共享。支撑了医…

鸿蒙APP架构及开发入门

1.鸿蒙系统 1.1 什么是鸿蒙 鸿蒙是一款面向万物互联时代的、全新的分布式操作系统。 在传统的单设备系统能力基础上&#xff0c;鸿蒙提出了基于同一套系统能力、适配多种终端形态的分布式理念&#xff0c;能够支持手机、平板、智能穿戴、智慧屏、车机、PC、智能音箱、耳机、…

【数学建模】权重生成与评价模型(上)

文章目录 权重生成与评价模型&#xff08;上&#xff09;1. 层次分析法1.1 层次分析法的原理构建判断矩阵权重向量计算一致性检验 1.2 层次分析法的案例1. 建立层次结构2. 构建判断矩阵3. 计算权重向量4. 一致性检验5. 计算综合权重 1.3 另一种得出综合得分的方法例子计算步骤完…

计算机实验室排课查询小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;学生管理&#xff0c;教师管理&#xff0c;实验室信息管理&#xff0c;实验室预约管理&#xff0c;取消预约管理&#xff0c;实验课程管理&#xff0c;实验报告管理&#xff0c;报修信息管理&#xff0…

Leetcode49. 字母异位词分组(java实现)

今天我来给大家分享的是leetcode49的解题思路&#xff0c;题目描述如下 如果没有做过leetcode242题目的同学&#xff0c;可以先把它做了&#xff0c;会更好理解异位词的概念。 本道题的大题思路是&#xff1a; 首先遍历strs&#xff0c;然后统计每一个数组元素出现的次数&#…

Java 基础学习第二节: Java 变量与数据类型

第二节 001.回顾 1.Java开发环境 1.Java编译运行过程 编译期:.java源文件,经过编译,生成.class字节码文件运行期:JVM加载.class文件并运行跨平台,一次编程,到处使用 2.名词解释 JVM:java虚拟机,加载并运行.classJRE:java运行环境,JVMjava系统类库JDK:java开发工具包,JRE开发j…

【计算机网络】期末实验答辩

注意事项&#xff1a; 1&#xff09;每位同学要在下面做过的实验列表中选取三个实验进行答辩准备&#xff0c;并将自己的姓名&#xff0c;学号以及三个实验序号填入共享文档"1&#xff08;2&#xff09;班答辩名单"中。 2&#xff09;在答辩当日每位同学由老师在表…

Dify 零代码 AI 应用开发:快速入门与实战

一、Dify 介绍 Dify 是一个开源的大语言模型 (LLM) 应用开发平台。它结合了后端即服务 (Backend-as-a-Service) 和 LLMOps (LLMOps) 的概念&#xff0c;使开发人员能够快速构建生产级生成式 AI (Generative AI) 应用。即使是非技术人员也可以参与 AI 应用的定义和数据操作。 …

CeoMax总裁主题最新3.8.1破解免授权版/WordPress付费资源素材下载主题

CeoMax总裁主题最新3.8.1破解免授权版&#xff0c;一套WordPress付费资源素材下载的主题&#xff0c;感觉这是做资源站唯一一个可以和ripro媲美甚至超越的模板&#xff0c;UI很美&#xff0c;功能也很强大&#xff0c;有想学习的可下载搭建学习一下&#xff0c;仅供学习研究借鉴…

C语言同时在一行声明指针和整型变量

如果这么写&#xff0c; int *f, g; 并没有声明2个指针&#xff0c;编译器自己会识别&#xff0c;f是一个指针&#xff0c;g是一个整型变量&#xff1b; void CTszbView::OnDraw(CDC* pDC) {CTszbDoc* pDoc GetDocument();ASSERT_VALID(pDoc);// TODO: add draw code for nat…

11. Hibernate 持久化对象的各种状态

1. 前言 本节课和大家聊聊持久化对象的 3 种状态。通过本节课程&#xff0c;你将了解到&#xff1a; 持久化对象的 3 种状态&#xff1b;什么是对象持久化能力。 2. 持久化对象的状态 程序运行期间的数据都是存储在内存中。内存具有临时性。程序结束、计算机挂机…… 内存中…

ElasticSearch核心之DSL查询语句实战

什么是DSL&#xff1f; Elasticsearch提供丰富且灵活的查询语言叫做DSL查询(Query DSL),它允许你构建更加复杂、强大的查询。 DSL(Domain Specific Language特定领域语言)以JSON请求体的形式出现。目前常用的框架查询方法什么的底层都是构建DSL语句实现的&#xff0c;所以你必…

linux 网络子系统

__netif_receive_skb_core 是 Linux 内核网络子系统中一个非常重要的函数&#xff0c;它负责将网络设备驱动层接收到的数据包传递到上层协议栈进行处理。以下是对该函数的一些关键点的详细解析&#xff1a; 一、函数作用 __netif_receive_skb_core 函数是处理接收到的网络数据…

HAL STM32 SPI/ABZ/PWM方式读取MT6816磁编码器数据

HAL STM32 SPI/ABZ/PWM方式读取MT6816磁编码器数据 &#x1f4da;MT6816相关资料&#xff08;来自商家的相关资料&#xff09;&#xff1a; 资料&#xff1a;https://pan.baidu.com/s/1CAbdLBRi2dmL4D7cFve1XA?pwd8888 提取码&#xff1a;8888&#x1f4cd;驱动代码编写&…