深度学习训练营之CGAN生成手势图像

news2025/1/20 1:50:43

深度学习训练营之CGAN生成手势

  • 原文链接
  • CGAN简单介绍
  • 环境介绍
  • 前置工作
    • 数据
    • 导入所需的包
    • 加载数据
    • 创建数据集
    • 查看数据集
  • 模型设置
    • 初始化模型的权重
    • 定义生成器
    • 构造判别器
  • 模型训练
    • 定义损失函数
    • 设置超参数
    • 正式开始训练
  • 结果可视化

原文链接

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:365天深度学习训练营-第G3周:实现mnist手写数字识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

CGAN简单介绍

和前面所提及的文章一样,在 CGAN 中,生成器(Generator)判别器(Discriminato都接收条件信息。生成器的目标是生成与条件信息相关的合成样本,而判别器的目标是将生成的样本与真实样本区分开来。二者之间通过不断的学习,不断提高自身的判别能力以及生成能力
当生成器和判别器通过反馈循环不断地进行训练时,生成器会逐渐学会如何生成符合条件信息的样本,而判别器则会逐渐变得更加准确。在二者达到一个平衡点时就能停止训练

在这里插入图片描述

根据网络结构可知,条件信息 y y y作为额外的输入被引入对抗网络中,与生成器 G G G中的噪声 z z z合并作为隐含层表达;
而在判别器 D D D中,条件信息y则与原始数据x合并作为判别函数的输入。

环境介绍

  • 语言环境:Python3.9.12
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2

前置工作

数据

百度网盘:https://pan.baidu.com/s/1thsYtu_1bzd2yel0K97hYA?pwd=7e3q
将下载的数据集放到运行文件所在目录的data文件夹当中,data文件夹当中设置rps文件夹来放置图片
,同时设置文件夹images_GAN3为后续的训练做准备
在这里插入图片描述
运行文件生成手势图像.ipynb的目录下创建一个training_weights的文件夹
在这里插入图片描述

导入所需的包

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary
import matplotlib.pyplot as plt

加载数据

dataroot = "./data/rps"  # 数据路径
batch_size = 128  # 训练过程中的批次大小
image_size = 128   # 图像的尺寸(宽度和高度)
image_shape = (3, 128, 128)
image_dim = int(np.prod(image_shape))
latent_dim = 100
n_classes = 3     # 条件标签的总数
embedding_dim = 100

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

创建数据集

生成训练所用的数据集,并且创建数据加载器

train_dataset = datasets.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                           transforms.Resize(image_size),        # 调整图像大小
                           transforms.ToTensor(),                # 将图像转换为张量
                           transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量
                                                (0.5, 0.5, 0.5)),
                           ]))

# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, 
                                         batch_size=batch_size,  # 批量大小
                                         shuffle=True,           # 是否打乱数据集
                                         num_workers=6 # 使用多个线程加载数据的工作进程数
                                        )

查看数据集

def show_images(images):
    fig, ax = plt.subplots(figsize=(20, 20))
    ax.set_xticks([]); ax.set_yticks([])
    ax.imshow(make_grid(images.detach(), nrow=22).permute(1, 2, 0))

def show_batch(dl):
    for images, _ in dl:
        show_images(images)
        break
        
show_batch(train_loader)

请添加图片描述

模型设置

初始化模型的权重

# 自定义权重初始化函数,用于初始化生成器和判别器的权重
def weights_init(m):
    # 获取当前层的类名
    classname = m.__class__.__name__

    # 如果当前层是卷积层(类名中包含 'Conv' )
    if classname.find('Conv') != -1:
        # 使用正态分布随机初始化权重,均值为0,标准差为0.02
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    
    # 如果当前层是批归一化层(类名中包含 'BatchNorm' )
    elif classname.find('BatchNorm') != -1:
        # 使用正态分布随机初始化权重,均值为1,标准差为0.02
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        # 将偏置项初始化为全零
        torch.nn.init.zeros_(m.bias)

定义生成器

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        # 定义条件标签的生成器部分,用于将标签映射到嵌入空间中
        # n_classes:条件标签的总数
        # embedding_dim:嵌入空间的维度
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),  # 使用Embedding层将条件标签映射为稠密向量
            nn.Linear(embedding_dim, 16)             # 使用线性层将稠密向量转换为更高维度
        )

        # 定义潜在向量的生成器部分,用于将噪声向量映射到图像空间中
        # latent_dim:潜在向量的维度
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),  # 使用线性层将潜在向量转换为更高维度
            nn.LeakyReLU(0.2, inplace=True)  # 使用LeakyReLU激活函数进行非线性映射
        )

        # 定义生成器的主要结构,将条件标签和潜在向量合并成生成的图像
        self.model = nn.Sequential(
            # 反卷积层1:将合并后的向量映射为64x8x8的特征图
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  # 批标准化
            nn.ReLU(True),  # ReLU激活函数
            # 反卷积层2:将64x8x8的特征图映射为64x4x4的特征图
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # 反卷积层3:将64x4x4的特征图映射为64x2x2的特征图
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # 反卷积层4:将64x2x2的特征图映射为64x1x1的特征图
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            # 反卷积层5:将64x1x1的特征图映射为3x64x64的RGB图像
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()  # 使用Tanh激活函数将生成的图像像素值映射到[-1, 1]范围内
        )

    def forward(self, inputs):
        noise_vector, label = inputs
        # 通过条件标签生成器将标签映射为嵌入向量
        label_output = self.label_conditioned_generator(label)
        # 将嵌入向量的形状变为(batch_size, 1, 4, 4),以便与潜在向量进行合并
        label_output = label_output.view(-1, 1, 4, 4)
        # 通过潜在向量生成器将噪声向量映射为潜在向量
        latent_output = self.latent(noise_vector)
        # 将潜在向量的形状变为(batch_size, 512, 4, 4),以便与条件标签进行合并
        latent_output = latent_output.view(-1, 512, 4, 4)
        
        # 将条件标签和潜在向量在通道维度上进行合并,得到合并后的特征图
        concat = torch.cat((latent_output, label_output), dim=1)
        # 通过生成器的主要结构将合并后的特征图生成为RGB图像
        image = self.model(concat)
        return image
    
generator = Generator().to(device)
generator.apply(weights_init)
print(generator)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

构造判别器

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        # 定义一个条件标签的嵌入层,用于将类别标签转换为特征向量
        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),     # 嵌入层将类别标签编码为固定长度的向量
            nn.Linear(embedding_dim, 3*128*128)         # 线性层将嵌入的向量转换为与图像尺寸相匹配的特征张量
        )
        
        # 定义主要的鉴别器模型
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),       # 输入通道为6(包含图像和标签的通道数),输出通道为64,4x4的卷积核,步长为2,padding为1
            nn.LeakyReLU(0.2, inplace=True),             # LeakyReLU激活函数,带有负斜率,增加模型对输入中的负值的感知能力
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    # 输入通道为64,输出通道为64*2,4x4的卷积核,步长为3,padding为2
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  # 批量归一化层,有利于训练稳定性和收敛速度
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  # 输入通道为64*2,输出通道为64*4,4x4的卷积核,步长为3,padding为2
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  # 输入通道为64*4,输出通道为64*8,4x4的卷积核,步长为3,padding为2
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),                               # 将特征图展平为一维向量,用于后续全连接层处理
            nn.Dropout(0.4),                            # 随机失活层,用于减少过拟合风险
            nn.Linear(4608, 1),                         # 全连接层,将特征向量映射到输出维度为1的向量
            nn.Sigmoid()                                # Sigmoid激活函数,用于输出范围限制在0到1之间的概率值
        )

    def forward(self, inputs):
        img, label = inputs
        
        # 将类别标签转换为特征向量
        label_output = self.label_condition_disc(label)
        # 重塑特征向量为与图像尺寸相匹配的特征张量
        label_output = label_output.view(-1, 3, 128, 128)
        
        # 将图像特征和标签特征拼接在一起作为鉴别器的输入
        concat = torch.cat((img, label_output), dim=1)
        
        # 将拼接后的输入通过鉴别器模型进行前向传播,得到输出结果
        output = self.model(concat)
        return output
    
discriminator = Discriminator().to(device)
discriminator.apply(weights_init)
print(discriminator)

在这里插入图片描述

模型训练

定义损失函数

使用的是BCELoss()
L o s s = − w ∗ [ p ∗ l o g ( q ) + ( 1 − p ) ∗ l o g ( 1 − q ) ] Loss = -w * [p * log(q) + (1-p) * log(1-q)] Loss=w[plog(q)+(1p)log(1q)],其中 p p p q q q分别为理论标签、实际预测值, w w w为权重。这里的 l o g log log对应数学上的 l n ln ln

torch.nn.BCELoss(weight=None, size_average=None, reduce=None, reduction=‘mean’)
计算目标值和预测值之间的二进制交叉熵损失函数。

有四个可选参数:weight、size_average、reduce、reduction
adversarial_loss = nn.BCELoss() 

def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss

def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss

设置超参数

设置学习率以及优化器

learning_rate = 0.0002

G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

正式开始训练

每10轮就进行对生成图像的保存

# 设置训练的总轮数
num_epochs = 100
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []

# 循环进行训练
for epoch in range(1, num_epochs + 1):
    
    # 初始化每轮训练中判别器和生成器损失的临时列表
    D_loss_list, G_loss_list = [], []
    
    # 遍历训练数据加载器中的数据
    for index, (real_images, labels) in enumerate(train_loader):
        # 清空判别器的梯度缓存
        D_optimizer.zero_grad()
        # 将真实图像数据和标签转移到GPU(如果可用)
        real_images = real_images.to(device)
        labels      = labels.to(device)
        
        # 将标签的形状从一维向量转换为二维张量(用于后续计算)
        labels = labels.unsqueeze(1).long()
        # 创建真实目标和虚假目标的张量(用于判别器损失函数)
        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))

        # 计算判别器对真实图像的损失
        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
        
        # 从噪声向量中生成假图像(生成器的输入)
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))
        
        # 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)

        # 计算判别器总体损失(真实图像损失和假图像损失的平均值)
        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)

        # 反向传播更新判别器的参数
        D_total_loss.backward()
        D_optimizer.step()

        # 清空生成器的梯度缓存
        G_optimizer.zero_grad()
        # 计算生成器的损失
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)

        # 反向传播更新生成器的参数
        G_loss.backward()
        G_optimizer.step()

    # 打印当前轮次的判别器和生成器的平均损失
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), 
            torch.mean(torch.FloatTensor(G_loss_list))))
    
    # 将当前轮次的判别器和生成器的平均损失保存到列表中
    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))

    if epoch%10 == 0:
        # 将生成的假图像保存为图片文件
        save_image(generated_image.data[:50], './data/images_GAN3/sample_%d' % epoch + '.png', nrow=5, normalize=True)
        # 将当前轮次的生成器和判别器的权重保存到文件
        torch.save(generator.state_dict(), './training_weights/generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './training_weights/discriminator_epoch_%d.pth' % (epoch))

在这里插入图片描述

结果可视化

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_loss_plot,label="G")
plt.plot(D_loss_plot,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

这是我训练出来的结果
请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/837385.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

leetcode 763. 划分字母区间

2023.8.3 本题的关键是要确保同一字母需要在同一片段中,而这就需要关注到每个字母最后一次出现的位置。 思路:用一个哈希表保存每个字母(26个)最后一次出现的位置。然后从头遍历,不断更新最右边界,直到当前…

LLVM笔记1

参考:https://www.bilibili.com/video/BV1D84y1y73v/?share_sourcecopy_web&vd_sourcefc187607fc6ec6bbd2c74a3d0d7484cf 文章目录 零、入门名词解释1. Compiler & Interpreter2. AOT静态编译和JIT动态解释的编译方式3. Pass4. Intermediate Representatio…

Eureka增加账号密码认证登录

一、业务背景 注册中心Eureka在微服务开发中经常使用到,用来管理发布的微服务,供前端或者外部调用。但是如果放到生产环境,我们直接通过URL访问的话,这显然是不安全的。 所以需要给注册中心加上登录认证。 通过账号和密码认证进行…

openGauss学习笔记-30 openGauss 高级数据管理-别名

文章目录 openGauss学习笔记-30 openGauss 高级数据管理-别名30.1 语法格式30.1.1 列别名语法30.1.2 表别名语法 30.2 参数说明30.3 示例 openGauss学习笔记-30 openGauss 高级数据管理-别名 SQL可以重命名一张表或者一个字段的名称,这个名称为该表或该字段的别名。…

Spring:JDBCTemplate

JDBCTemplate 概述 概述 JDBC(Java DataBase Connectivity,Java 数据库连接), 一 种用于执行 SQL 语句的 Java API(Application Programming Interface , 应用程序设计接口 ),可以为…

【每日一题】—— C. Mocha and Hiking(Codeforces Round 738 (Div. 2))

🌏博客主页:PH_modest的博客主页 🚩当前专栏:每日一题 💌其他专栏: 🔴 每日反刍 🟡 C跬步积累 🟢 C语言跬步积累 🌈座右铭:广积粮,缓称…

IO流【笔记】

1. IO概述 1.1 什么是IO 生活中,你肯定经历过这样的场景。当你编辑一个文本文件,忘记了ctrls ,可能文件就白白编辑了。当你电脑上插入一个U盘,可以把一个视频,拷贝到你的电脑硬盘里。那么数据都是在哪些设备上的呢&a…

软件开发过程中前后端联调相关问题

一、接口调用三步曲 1. uniapp接口调用 data中定义 onload中调用 例如:this.getSwiperList()//调用获取轮播图数据的方法 method中定义获取方法 2. 微信小程序接口调用 reques.js中接口封装 如:ScenicspotInfo (data)> re…

高级web前端开发工程师的岗位职责最新(合集)

高级web前端开发工程师的岗位职责最新1 职责: 1.主导公司前端开发的技术方向,指导其他前端开发人员工作 2.负责产品的Web前端开发及用户交互体验设计; 3.基于HTML、CSS、JavaScript标准进行页面制作,编写界面组件; 4.协同后台开发工程师&…

8.4 day05软件学习

文章目录 微服务的概念微服务的原则微服务的特征:集群介绍 spring aop 在家学习效率真不高,下午好兄弟喊出去玩,一直到晚上才回来,赶紧总结一下早上学习的内容。 继续看java基础进阶的思想,之前学的很多都忘了。 微服…

网工内推 | 实施、售后工程师,厂商认证优先

01 安井食品集团股份有限公司 招聘岗位:网络工程师 职责描述: 1.负责集团组网的网络规划、实施、维护工作; 2.负责公司局域网的网络规划、实施、维护工作; 3.负责公司企业安全系统规划、实施、维护工作; 4、负责公…

百度UEditor编辑器如何关闭抓取远程图片功能

百度UEditor编辑器如何关闭抓取远程图片功能 这个坑娘的功能,开始时居然不知道如何触发,以为有个按钮,点击一下触发,翻阅了文档,没有发现,然后再网络上看到原来是复制粘贴非白名单内的图片到编辑框时触发&a…

LT6711A 是一款HDMI 2.0转DP 1.2/EDP 1.4的芯片,实用于AR或者PC以及PAD

LT6711A 1.概述: Lontium LT6711A是HDMI2.0到DP1.2转换器,内部有c型替代模式开关和PD控制器。对于HDMI输入,LT6711A具有一个HDMI2.0接收器,有1个时钟通道和3个数据通道,每个数据通道最大运行6Gb/s,最大输…

DAY02_Spring第三方资源配置管理Spring容器Spring注解开发Spring整合Mybatis和Junit

目录 一 第三方资源配置管理1 管理DataSource连接池对象问题导入1.1 管理Druid连接池1.2 管理c3p0连接池 2 加载properties属性文件问题导入2.1 基本用法2.2 配置不加载系统属性2.3 加载properties文件写法 二 Spring容器1 Spring核心容器介绍问题导入1.1 创建容器1.2 获取bean…

智能汽车驾驶演进:虚拟ECU种类与优劣分析

现代汽车更安全、更舒适、更智能的代价是车载ECU(Electronic Control Unit)数量的迅速增长,与之相对应的是ECU上规模软件越来越大、软件开发成本在整车制造成本中的占比越来越高。车企可以从规则与方法两个角度入手来解决上述问题&#xff1a…

ES6新增的语法

ES6实际上是一个泛指,泛指 ES2015 及后续的版本 1,let用于声明变量的关键字 let 声明的变量只在所处于的代码块内有效 if (true) { let a 10 } console.log(a) // a is not defined 2, let 不存在变量提升 console.log(a) // a is not deined let a 1…

元素2D转3D 椭圆形旋转实现

椭圆旋转功能展示 transform-style: preserve-3d;(主要css代码) gif示例(背景图可插入透明以此实现边框线的旋转) 导致的无法点击遮挡问题可以参考我的另一个文章 穿透属性-----------------------css穿透属性 实时代码展示

如何与 Boot Barn 建立 EDI 连接?

Boot Barn 专注于提供各种高品质的靴子、鞋类和西部服饰。其经营范围广泛,为广大顾客提供最新潮流和经典款式的选择。 Boot Barn 的使命是成为顾客在西部风格时尚领域的首选购物地点。多年来,Boot Barn 凭借卓越的服务和优质的产品赢得了众多客户的信赖和…

LabVIEW使用DSA技术从X射线图像测量肺气容量

LabVIEW使用DSA技术从X射线图像测量肺气容量 相衬X射线(PCX)成像技术利用相邻介质之间折射率的微小差异来增强传统X射线成像通常不可见的物体的边界。事实证明,这一进展在一系列生物医学和材料科学中非常有益于材料表征、疾病检测以及解剖形…

Hive 中把一行记录拆分为多行记录

背景 业务场景:统计每个小时视频同时在线观看人数,因后台的业务数据是汇总之后的,只有开始时间、结束时间,没有每小时的详细日志数据,无法直接进行统计,所以需要对每条业务数据进行拆分,来统计…