CGAN|生成手势图像|可控制生成

news2025/1/25 7:09:28
  •     🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:TensorFlow入门实战|第3周:天气识别
  • 🍖 原作者:K同学啊|接辅导、项目定制

CGAN(条件生成对抗网络)的原理是在原始GAN的基础上,为生成器和判别器提供 额外的条件信息。

CGAN通过将条件信息(如类别标签或其他辅助信息)加入生成器和判别器的输入中,使得生成器能够根据这些条件信息生成特定类型的数据,而判别器则负责区分真实数据和生成数据是否符合这些条件。这种方式让生成器在生成数据时有了明确的方向,从而提高了生成数据的质量与相关性。

CGAN的特点包括有监督学习、联合隐层表征、可控性、使用卷积结构等,其具体内容为:

有监督学习:CGAN通过额外信息的使用,将原本无监督的GAN转变为一种有监督的学习模式,这使得网络的训练更加目标明确,生成结果更加符合预期。
联合隐层表征:在生成模型中,噪声输入和条件信息共同构成了联合隐层表征,这有助于生成更多样化且具有特定属性的数据。
可控性:CGAN的一个关键特点是提高了生成过程的可控性,即可以通过调整条件信息来指导模型生成特定类型的数据。
使用卷积结构:CGAN可以采用卷积神经网络作为其内部结构,这在图像相关的任务中尤其有效,因为它能够捕捉到局部特征,并提高模型对细节的处理能力。

一、前期工作

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
from torchsummary import summary
import matplotlib.pyplot as plt

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 128
train_transform = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
    transforms.Normalize([0.5,0.5,0.5], [0.5,0.5,0.5])])

train_dataset = datasets.ImageFolder(root="H:/G3/rps/rps", transform=train_transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True,
                                           num_workers=6)
def show_images(dl):
    for images, _ in dl:
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(images.detach(), nrow=16).permute(1, 2, 0))
        break

show_images(train_loader)

 

二、构建模型 

latent_dim = 100
n_classes = 3
embedding_dim = 100

def weights_init(m):
    classname = m.__class__.__name__

    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim), 
            nn.Linear(embedding_dim, 16)            
        )
        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),  
            nn.LeakyReLU(0.2, inplace=True)  
        )

        self.model = nn.Sequential( 
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),  
            nn.ReLU(True),            
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),     
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),       
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),       
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()  
        )

    def forward(self, inputs):
        noise_vector, label = inputs  
        label_output = self.label_conditioned_generator(label)     
        label_output = label_output.view(-1, 1, 4, 4)        
        latent_output = self.latent(noise_vector)     
        latent_output = latent_output.view(-1, 512, 4, 4) 
        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
a = torch.ones(100)
b = torch.ones(1)
b = b.long()
a = a.to(device)
b = b.to(device)

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.label_condition_disc = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),     
            nn.Linear(embedding_dim, 3*128*128)         
        )
        
        self.model = nn.Sequential(
            nn.Conv2d(6, 64, 4, 2, 1, bias=False),      
            nn.LeakyReLU(0.2, inplace=True),             
            nn.Conv2d(64, 64*2, 4, 3, 2, bias=False),    
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),  
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*2, 64*4, 4, 3, 2, bias=False),  
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64*4, 64*8, 4, 3, 2, bias=False),  
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Flatten(),                               
            nn.Dropout(0.4),                            
            nn.Linear(4608, 1),                         
            nn.Sigmoid()                                
        )

    def forward(self, inputs):
        img, label = inputs
        
        label_output = self.label_condition_disc(label)
        label_output = label_output.view(-1, 3, 128, 128)
        
        concat = torch.cat((img, label_output), dim=1)
        
        output = self.model(concat)
        return output

a = torch.ones(2,3,128,128)
b = torch.ones(2,1)
b = b.long()
a = a.to(device)
b = b.to(device)

c = discriminator((a,b))

 三、训练模型及可视化

 这一部分主要定义初始化权重,构建鉴别器和生成器。

# 定义损失函数
adversarial_loss = nn.BCELoss() 
 
def generator_loss(fake_output, label):
    gen_loss = adversarial_loss(fake_output, label)
    return gen_loss
 
def discriminator_loss(output, label):
    disc_loss = adversarial_loss(output, label)
    return disc_loss
learning_rate = 0.0002
 
G_optimizer = optim.Adam(generator.parameters(),     lr = learning_rate, betas=(0.5, 0.999))
D_optimizer = optim.Adam(discriminator.parameters(), lr = learning_rate, betas=(0.5, 0.999))

# 设置训练的总轮数
num_epochs = 100
# 初始化用于存储每轮训练中判别器和生成器损失的列表
D_loss_plot, G_loss_plot = [], []
 
# 循环进行训练
for epoch in range(1, num_epochs + 1):
    
    # 初始化每轮训练中判别器和生成器损失的临时列表
    D_loss_list, G_loss_list = [], []
    
    # 遍历训练数据加载器中的数据
    for index, (real_images, labels) in enumerate(train_loader):
        # 清空判别器的梯度缓存
        D_optimizer.zero_grad()
        # 将真实图像数据和标签转移到GPU(如果可用)
        real_images = real_images.to(device)
        labels      = labels.to(device)
        
        # 将标签的形状从一维向量转换为二维张量(用于后续计算)
        labels = labels.unsqueeze(1).long()
        # 创建真实目标和虚假目标的张量(用于判别器损失函数)
        real_target = Variable(torch.ones(real_images.size(0), 1).to(device))
        fake_target = Variable(torch.zeros(real_images.size(0), 1).to(device))
 
        # 计算判别器对真实图像的损失
        D_real_loss = discriminator_loss(discriminator((real_images, labels)), real_target)
        
        # 从噪声向量中生成假图像(生成器的输入)
        noise_vector = torch.randn(real_images.size(0), latent_dim, device=device)
        noise_vector = noise_vector.to(device)
        generated_image = generator((noise_vector, labels))
        
        # 计算判别器对假图像的损失(注意detach()函数用于分离生成器梯度计算图)
        output = discriminator((generated_image.detach(), labels))
        D_fake_loss = discriminator_loss(output, fake_target)
 
        # 计算判别器总体损失(真实图像损失和假图像损失的平均值)
        D_total_loss = (D_real_loss + D_fake_loss) / 2
        D_loss_list.append(D_total_loss)
 
        # 反向传播更新判别器的参数
        D_total_loss.backward()
        D_optimizer.step()
 
        # 清空生成器的梯度缓存
        G_optimizer.zero_grad()
        # 计算生成器的损失
        G_loss = generator_loss(discriminator((generated_image, labels)), real_target)
        G_loss_list.append(G_loss)
 
        # 反向传播更新生成器的参数
        G_loss.backward()
        G_optimizer.step()
 
    # 打印当前轮次的判别器和生成器的平均损失
    print('Epoch: [%d/%d]: D_loss: %.3f, G_loss: %.3f' % (
            (epoch), num_epochs, torch.mean(torch.FloatTensor(D_loss_list)), 
            torch.mean(torch.FloatTensor(G_loss_list))))
    
    # 将当前轮次的判别器和生成器的平均损失保存到列表中
    D_loss_plot.append(torch.mean(torch.FloatTensor(D_loss_list)))
    G_loss_plot.append(torch.mean(torch.FloatTensor(G_loss_list)))
 
    if epoch%10 == 0:
        # 将生成的假图像保存为图片文件
        save_image(generated_image.data[:50], './sample_%d' % epoch + '.png', nrow=5, normalize=True)
        # 将当前轮次的生成器和判别器的权重保存到文件
        torch.save(generator.state_dict(), './generator_epoch_%d.pth' % (epoch))
        torch.save(discriminator.state_dict(), './discriminator_epoch_%d.pth' % (epoch))

 
 
from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot as plt, gridspec
import numpy as np
 
# Assuming 'generator' and 'device' are defined earlier in your code
 
generator.load_state_dict(torch.load('./generator_epoch_100.pth'), strict=False)
generator.eval()
 
interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)
 
label = 0
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()
 
predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()
 
import warnings
warnings.filterwarnings("ignore")
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100
 
plt.figure(figsize=(8, 3))
 
pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

 代码中的操作将预测结果的值加1(这样所有的值都变为非负数),然后乘以127.5,最终得到的值就在0到255之间。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1689660.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

项目十二:简单的python基础爬虫训练

许久未见,甚是想念,今日好运,为你带好运。ok,废话不多说,希望这门案例能带你直接快速了解并运用。🎁💖 基础流程 第一步:安装需要用到的requests库,命令如下 pip inst…

Vue3实战笔记(41)—自己封装一个计时器Hooks

文章目录 前言计时器钩子总结 前言 在Vue项目中,封装一个计时器挂钩(Hook)是一种实用的技术,它允许你在组件中方便地管理定时任务,如倒计时、计时器等,而无需在每个使用场景重复编写相同的逻辑代码。 计时…

uniapp如何使用自定义的图标

http://t.csdnimg.cn/8KenC 以上是原文章,下面内容是从这篇文章转发的 一、导入 1.在官方(iconfont-阿里巴巴矢量图标库)选择自己想要的图标,加入购物车 2. 在点击购物车下载代码 3.解压文件夹 并更改名字 4.将文件夹(iconfont&…

Xline社区会议Call Up|在 CURP 算法中实现联合共识的安全性

为了更全面地向大家介绍Xline的进展,同时促进Xline社区的发展,我们将于2024年5月31日北京时间11:00 p.m.召开Xline社区会议。 欢迎您届时登陆zoom观看直播,或点击“阅读原文”链接加入会议: 会议号: 832 1086 6737 密码: 41125…

软件开发成本估算 5大注意事项

一般来说,软件开发成本估算分为:软件规模估算、工作量估算、成本估算和确定软件开发成本等四个过程,其估算基本流程如下: 软件开发成本估算流程 为了进一步确保估算的准确性,提高资源规划和分配效率,确保软…

Redis篇 在linux系统上安装Redis

安装Redis 在Ubuntu上安装Redis 在Ubuntu上安装Redis 在linux系统中,我们安装Redis,必须先使它有root权限. 那么在linux中,如何切换到root用户权限呢? sudo su 就可切换到用户权限了. 在切换到用户权限后,我们需要用一条命令来搜索Redis相关的软件包 apt search redis 会出现非…

Labelme自定义数据集COCO格式【实例分割】

参考博客 labelme标注自定义数据集COCO类型_labelme标注coco-CSDN博客 LabelMe使用_labelme中所有的create的作用解释-CSDN博客 1制作自己的数据集 1.1labelme安装 自己的数据和上面数据的区别就在于没有.json标签文件,所以训练自己的数据关键步骤就是获取标签文…

x264 码率控制原理:rate_estimate_qscale 函数

rate_estimate_qscale 函数 原理 函数功能:根据目前使用的实际比特数更新一帧的qscale;是一个复杂的决策过程,需要考虑多种因素,如帧类型、编码的复杂度、目标比特率、缓冲区大小等,以确保视频质量和文件大小之间的平衡。函数参数分析:x264_t *h :编码器上下文信息结构…

UDEV规则配置usb摄像头

参考自:【linux】linux下摄像头设置固定的设备名-udev_linux 摄像头的设备文件名-CSDN博客 UDEV规则 在Linux系统中,UDEV(Userspace DEV)是一个用于管理设备节点和/dev目录下设备文件的动态设备管理器。当你连接USB摄像头或其他USB设备时&am…

【ai】chatgpt的plugin已经废弃

发现找不到按钮,原来是要申请: https://openai.com/index/chatgpt-plugins/ 发现申请已经跳转了,好像是废弃了? 不接受新插件了,但是openai的api 是可以继续用的。 https://openai.com/waitlist/plugins/We are no longer accepting new Plugins, builders can now create…

react中的useEffect()的使用

useEffect()是react中的hook函数,作用是用于创建由渲染本身引起的操作,而不是事件的触发,比如Ajax请求,DOM的更改 首先useEffect()是个函数,接受两个参数,第一个参数是一个方法,第二个参数是数…

Generic Segmentation Offload(GSO)

Generic Segmentation Offload汉语意思是啥? Generic Segmentation Offload(GSO)的汉语意思是“通用分段卸载”。在网络通信中,GSO 是一种技术,用于在网络栈中将较大的传输单元分段为更小的单元,以提高网络…

自然语言处理实战项目29-深度上下文相关的词嵌入语言模型ELMo的搭建与NLP任务的实战

大家好,我是微学AI,今天给大家介绍一下自然语言处理实战项目29-深度上下文相关的词嵌入语言模型ELMo的搭建与NLP任务的实战,ELMo(Embeddings from Language Models)是一种深度上下文相关的词嵌入语言模型,它采用了多层双向LSTM编码器构建语言模型,并通过各层LSTM的隐藏状…

顶顶通实时质检系统新增一大功能:黑名单功能介绍

文章目录 前言联系我们功能介绍配置方案 前言 顶顶通实时质检系统新增黑名单一大功能。该功能可通过调用质检系统的黑名单接口,对被叫号码进行检测。如果被检测的号码符合所设定的拦截规则,就会对当前呼叫进行拦截,取消呼叫。 联系我们 有意…

数据库系列之MySQL数据库中内存使用分析

在实际系统环境中,MySQL实例的内存使用随着业务的增长缓慢增长,有些时候并没有及时的释放。本文简要介绍下MySQL数据库中和内存相关的配置,以及分析内存的实际使用情况,以进行应急和调优处理。 1、MySQL内存结构 在MySQL中内存的…

运营抖音小店,这件事情每天都需要去做!一个都不能少!

大家好,我是电商小V 咱们的店铺开好之后,然后运营自己的店铺每天需要做好什么事情呢?这个问题是很多新手小伙伴开通抖店之后最关心的问题,咱们今天就来详细的说一下运营抖音小店每天需要做什么呢? 第一点:奖…

Transformer详解常见面试问题

文章目录 1. 各模块解决1.1 输入部分1.2 多头注意力(作者使用8个头)1.3 残差和LayerNorm1.4 Decoder部分 2.Transformer经典问题2.1 tranformer为何使用多头注意力机制?2.2 Transformer相比CNN的优缺点2.3 Encoder和decoder的区别&#xff1f…

03-ArcGIS For JavaScript结合ThreeJS功能

ArcGIS For JavaScript结合ThreeJS功能 概述three.js中功能实现externalRenderers(4.28及以下版本)RenderNode(4.29版本) 概述 ArcGIS For Javacript提供了一些对象可以支持加载webgl上下文信息,这里包括webgl编程的代…

【Crypto】Url编码

文章目录 Url编码解题感悟 Url编码 Url编码 搞定 小小flag,拿下! 解题感悟 有点饿了…

Edge浏览器:重新定义现代网页浏览

引言 - Edge的起源与重生 Edge浏览器,作为Microsoft Windows标志性的互联网窗口,源起于1995年的Internet Explorer。在网络发展的浪潮中,IE曾是无可争议的霸主,但随着技术革新与用户需求的演变,它面临的竞争日益激烈。…