G4 - 可控手势生成 CGAN

news2025/1/22 9:23:52
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

目录

  • 代码
  • 总结与心得


代码

关于CGAN的原理上节已经讲过,这次主要是编写代码加载上节训练后的模型来进行指定条件的生成

图像的生成其实只需要使用Generator模型,判别器模型是在训练过程中才用的。

# 库引入
from torch.autograd import Variable
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 超参数
latent_dim = 100
n_classes = 3
embedding_dim = 100

# 工具函数
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.zeros_(m.bias)


# 模型
class Generator(nn.Module):
    def __init__(self):
        super().__init__()

        self.label_conditioned_generator = nn.Sequential(
            nn.Embedding(n_classes, embedding_dim),
            nn.Linear(embedding_dim, 16)
        )

        self.latent = nn.Sequential(
            nn.Linear(latent_dim, 4*4*512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.model = nn.Sequential(
            nn.ConvTranspose2d(513, 64*8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*8, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*8, 64*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*4, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*4, 64*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*2, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*2, 64*1, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64*1, momentum=0.1, eps=0.8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64*1, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, inputs):
        noise_vector, label = inputs
        label_output = self.label_conditioned_generator(label)
        label_output = label_output.view(-1, 1, 4, 4)
        latent_output = self.latent(noise_vector)
        latent_output = latent_output.view(-1, 512, 4, 4)

        concat = torch.cat((latent_output, label_output), dim=1)
        image = self.model(concat)
        return image

generator = Generator().to(device)
generator.apply(weights_init)
print(generator)
Generator(
  (label_conditioned_generator): Sequential(
    (0): Embedding(3, 100)
    (1): Linear(in_features=100, out_features=16, bias=True)
  )
  (latent): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): ConvTranspose2d(513, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)
from numpy.random import randint, randn
from numpy import linspace
from matplotlib import pyplot, gridspec

# 加载训练好的权重
generator.load_state_dict(torch.load('generator_epoch_300.pth'), strict=False)
# 关闭梯度积累
generator.eval()

# 生成随机变量
interpolated = randn(100)
interpolated = torch.tensor(interpolated).to(device).type(torch.float32)

# 生成条件变量
label = 0 # 生成第0个分类的图像
labels = torch.ones(1) * label
labels = labels.to(device).unsqueeze(1).long()

# 执行生成
predictions = generator((interpolated, labels))
predictions = predictions.permute(0, 2, 3, 1).detach().cpu()

# 屏蔽警告
import warnings
warnings.filterwarnings('ignore')


# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']
# 防止负号无法显示
plt.rcParams['axes.unicode_minus']= False
# 设置图的分辨率
plt.rcParams['figure.dpi'] = 100

# 绘图
plt.figure(figsize=(8, 3))
pred = (predictions[0, :, :, :] + 1) * 127.5
pred = np.array(pred)
plt.imshow(pred.astype(np.uint8))
plt.show()

生成分类0
我们将分类修改为1重新生成一次

生成分类1

总结与心得

在本次实验的过程中,我了解了CGAN模型在训练完成后,后续如何使用的步骤:

  1. 保存训练好的生成器的权重
  2. 使用生成器加载
  3. 生成随机分布变量用于生成图像
  4. 生成指定的标签,并转换成控制向量
  5. 执行生成操作

另外关于警告和matplotlib设置中文字体的方式也是经常会用到的技巧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1719203.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

unity2020打包webGL时卡进程问题

我使用的2020.3.0f1c1,打包发布WEB版的时候会一直卡到asm2wasm.exe这个进程里,而且CPU占用率90%以上。 即使是打包一个新建项目的空场景也是同样的问题,我尝试过一直卡在这里会如何,结果还真打包成功了。只是打包一个空场景需要20…

下载HF AutoTrain 模型的配置文件

下载HF AutoTrain 模型的配置文件 一.在huggingface上创建AutoTrain项目二.通过HF用户名和autotrain项目名,拼接以下url,下载模型列表(json格式)到指定目录三.解析上面的json文件、去重、批量下载模型配置文件(权重以外的文件) 一.在huggingface上创建AutoTrain项目 二.通过HF用…

微信公众号【原子与分子模拟】: 熔化温度 + 超导电性 + 电子化合物 + 分子动力学模拟 + 第一性原理计算 + 数据处理程序

往期内容主要涵盖: 熔化温度 超导电性 电子化合物 分子动力学模拟 第一性原理计算 数据处理程序 【1】熔化温度 分子动力学 LAMMPS 相关内容 【文献分享】分子动力学模拟 LAMMPS 熔化温度 晶体缺陷 熔化方法 LAMMPS 文献:金属熔化行为的局域…

Mac安装第三方软件的命令安装方式

场景: 打开终端命令行,sudo xattr -rd com.apple.quarantine,注意最后quarantine 后面加一个空格!然后打开Finder(访达),点击左侧的 应用程序,找到相关应用,拖进终端qua…

HackTheBox-Machines--Bashed

Bashed 测试过程 1 信息收集 NMAP 80 端口 目录扫描 http://10.129.155.171/dev/phpbash.min.php http://10.129.155.171/dev/phpbash.php 半交互式 shell 转向 交互式shell python -c import socket,subprocess,os;ssocket.socket(socket.AF_INET,socket.SOCK_STREAM);s.co…

dmdts连接kingbase8报错

dmdts连接kingbase报错 环境介绍1 人大金仓jdbc配置2 dmdts 人大金仓jdbc默认配置3 dmdts 修改jdbc配置4 达梦产品学习使用列表 环境介绍 dts版本 使用dmdts连接kingbase金仓数据库报错 无效的URL 对比jdbc连接串,修改配置解决 1 人大金仓jdbc配置 配置URL模版信息等 类名…

深度学习聚类再升级!新算法实现强悍性能,准确率超98%

深度聚类不仅继承了传统聚类算法的优点,在对高维和非线性数据的处理能力,以及自适应性和抗噪性方面也具有很大优势。 具体来说,结合深度学习的聚类算法通过利用深度神经网络的强大特征提取能力,自动学习和识别数据中的复杂结构和…

【小白专用24.5.30已验证】Composer安装php框架thinkPHP6的安装教程

一、框架介绍 1、框架简介和版本选择 Thinkphp是一种基于php的开源web应用程序开发框架ThinkPHP框架,是免费开源的、轻量级的、简单快速且敏捷的php框架。你可以免费使用TP框架,甚至可以将你的项目商用; ThinkPHP8.0 是目前框架正式版的最新版…

Spring 框架:Java 企业级开发的基石

文章目录 序言Spring 框架的核心概念Spring 框架的主要模块Spring Boot:简化 Spring 开发Spring Cloud:构建微服务架构实际案例分析结论 序言 Spring 框架自 2002 年发布以来,已经成为 Java 企业级开发的标准之一。它通过提供全面的基础设施…

家政预约小程序10公众号集成

目录 1 使用测试号3 工作流配置4 配置关注事件脚本5 注册开放平台6 获取公众号access_token6 实现关注业务逻辑总结 我们本次实战项目构建的相当于一个预约平台,既有家政企业,也有家政服务人员还有用户。不同的人员需要收到不同的消息,比如用…

11- Redis 中的 SDS 数据结构

字符串在 Redis 中是很常用的,键值对中的键是字符串类型,值有时也是字符串类型。 Redis 是用 C 语言实现的,但是它没有直接使用 C 语言的 char* 字符数组来实现字符串,而是自己封装了一个名为简单动态字符串(simple d…

基于强化学习的控制率参数自主寻优

1.介绍 针对控制建模与设计场景中控制参数难以确定的普遍问题,提出了一种基于强化学习的控制律参数自主优化解决方案。该方案以客户设计的控制律模型为基础,根据自定义的控制性能指标,自主搜索并确定最优的、可状态依赖的控制参数组合。 可…

ToDesk提示会话数通道限制 - 解决方案及兑惠码分享

如果您最近在体验ToDesk这款远程操控工具时,遇到了提示信息告知“高速通道服务已到期”或“会话数受限”,这表明您本月享受的免费额度——即120小时的使用时间和最多300次的连接机会——已经耗尽。为了解锁无限制的使用时长与连接次数,建议您…

自动驾驶中的长尾问题

自动驾驶中的长尾问题 定义 长尾问题(Long-Tail Problem)是指在数据分布中,大部分的数据集中在少数类别上,而剩下的大多数类别却只有少量的数据。这种数据分布不平衡的现象在许多实际应用中广泛存在,特别是在自动驾驶…

20240531在飞凌的OK3588-C开发板上跑原厂的Buildroot测试USB摄像头

20240531在飞凌的OK3588-C开发板上跑原厂的Buildroot测试USB摄像头 2024/5/31 20:04 USB摄像头分辨率:1080p(1920x1080) 默认编译Buildroot的SDK即可点亮USB摄像头。v4l2-ctl --list-devices v4l2-ctl --list-formats-ext -d /dev/video74 …

双指针法 ( 快乐数 )

「快乐数」 定义为: 对于一个正整数,每一次将该数替换为它每个位置上的数字的平方和。然后重复这个过程直到这个数变为 1,也可能是 无限循环 但始终变不到 1。如果这个过程 结果为 1,那么这个数就是快乐数 编写一个算法来判断一个…

linux /www/server/cron内log文件占用空间过大,/www/server/cron是什么内容,/www/server/cron是否可以删除

linux服务器长期使用宝塔自带计划任务,计划任务执行记录占用服务器空间过大,导致服务器根目录爆满,需要长期排查并删除 /www/server/cron 占用空间过大问题处理 /www/server/cron是什么内容?/www/server/cron是否可以删除&#xf…

基于VGG16使用图像特征进行迁移学习的时装推荐系统

前言 系列专栏:【深度学习:算法项目实战】✨︎ 涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对…

AutoMQ 自动化持续测试平台技术内幕

01 背景 AutoMQ[1] 作为一款流系统,被广泛应用在客户的核心链路中,对可靠性的要求非常的高。所以我们需要一套模拟真实生产场景、长期运行的测试环境,在注入各种故障场景的前提下验证 SLA 的可行性,为新版本的发布和客户的使用提…