使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

news2024/9/22 7:39:59

文章目录

  • 1 测试鉴别器
  • 2 建立生成器
  • 3 测试生成器
  • 4 训练生成器
  • 5 使用生成器
  • 6 内存查看

上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类。
使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上
接下来,我们测试一下鉴别器是否可以正常工作,并建立生成器。

1 测试鉴别器

# 数据类建立
celeba_dataset = CelebADataset(r'F:\学习\AI\对抗网络\face-data\celeba_aligned_small.h5py')
celeba_dataset.plot_image(66)

# 鉴别器类建立
D = Discriminator()
D.to(device)

for image_data_tensor in celeba_dataset:
    # real data
    D.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))
    # fake data
    D.train(generate_random_image((218,178,3)), torch.cuda.FloatTensor([0.0]))

此处我们调用了两个类,一个是celeba_dataset(Dataset)类,一个是D(Discriminator)类。两个类在博文的上篇中完成了定义。此处分别使用real数据与fake数据对模型进行训练。fake数据使用的是随机生成的不规则像素点,real数据使用的是真是人脸数据。
在使用GPU的情况,此处预计会消耗5分钟左右。
训练完成后,可以绘制损失值的变化以查看训练效果。

D.plot_progress()
plt.show()

在这里插入图片描述

2 建立生成器

生成器与鉴别器高度类似,仅网络的结构和训练部分略有不同。
网格结构选取的是输入层为100个节点,中间层为单层结构,包含3*10*10个节点,输出层为3 * 218 * 178。输出层是完全根据照片的像素格式来确定的,输入层与中间层可以根据经验进行修改与优化。各层之间均采用全连接的连接方式。相关部分的代码如下:

class Generator(nn.Module):

    def __init__(self):
        # 父类继承
        super().__init__()

        # 定义神经网络
        self.model = nn.Sequential(
            nn.Linear(100, 3 * 10 * 10),
            nn.LeakyReLU(),

            nn.LayerNorm(3 * 10 * 10),

            nn.Linear(3 * 10 * 10, 3 * 218 * 178),

            nn.Sigmoid(),
            View((218, 178, 3))
        )

在进行损失计算时,我们将鉴别器的返回值作为实际输出,将torch.cuda.FloatTensor([1.0]作为目标输出,来计算损失。相关比分的代码如下:

class Generator(nn.Module):
    def train(self, D, inputs, targets):
        # 计算输出
        g_output = self.forward(inputs)

        # 将输出传至鉴别器
        d_output = D.forward(g_output)

        # 计算损失
        loss = D.loss_function(d_output, targets)

对于生成器的完整代码,也将在文末进行提供。

3 测试生成器

未经训练的生成器,应该具备生成类似雪花马赛克的随机图像能力。下面建立了一个生成器类,并用未经训练的生成器直接输出图像。

G = Generator()
G.to(device)

output = G.forward(generate_random_seed(100))
img = output.detach().cpu().numpy()
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()

如果代码运行正常,应得到类似下面的图象。
在这里插入图片描述

4 训练生成器

训练时,对数据集进行遍历,并且依次执行下面三步:

  1. 使用真实照片数据,对鉴别器进行训练,期望的鉴别器输出值为1;
  2. 使用生成器输出的fake数据,对鉴别器进行训练,期望的鉴别器输出值为0;
  3. 使用鉴别器的返回值,训练生成器,生成器所希望的鉴别器输出为1
    具体代码如下:
for image_data_tensor in celeba_dataset:
    # train discriminator on true
    D.train(image_data_tensor, torch.cuda.FloatTensor([1.0]))

    # train discriminator on false
    # use detach() so gradients in G are not calculated
    D.train(G.forward(generate_random_seed(100)).detach(), torch.cuda.FloatTensor([0.0]))

    # train generator
    G.train(D, generate_random_seed(100), torch.cuda.FloatTensor([1.0]))

在训练后,可以分别查看鉴别器与生成器的损失变化曲线。

D.plot_progress()
G.plot_progress()

下图为鉴别器损失值变化曲线
在这里插入图片描述
下图为生成器损失值变化曲线
在这里插入图片描述

5 使用生成器

在这里插入图片描述

6 内存查看

最后可以查看一下本次训练的内存使用情况
(1)分配给张量的当前内存(输出单位是GB)

torch.cuda.memory_allocated(device) / (1024*1024*1024)

我的输出结果为:0.6999950408935547
(2)分配给张量的总内存(输出单位是GB)

torch.cuda.max_memory_allocated(device) / (1024*1024*1024)

我的输出结果为:0.962151050567627
(3)内存消耗汇总

print(torch.cuda.memory_summary(device, abbreviated=True))

输出结果如下:

|===========================================================================|
|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|===========================================================================|
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  733998 KB |     985 MB |   14018 GB |   14017 GB |
|---------------------------------------------------------------------------|
| Active memory         |  733998 KB |     985 MB |   14018 GB |   14017 GB |
|---------------------------------------------------------------------------|
| GPU reserved memory   |    1086 MB |    1086 MB |    1086 MB |       0 B  |
|---------------------------------------------------------------------------|
| Non-releasable memory |    9426 KB |   12685 KB |  353393 MB |  353383 MB |
|---------------------------------------------------------------------------|
| Allocations           |      68    |      87    |    2580 K  |    2580 K  |
|---------------------------------------------------------------------------|
| Active allocs         |      68    |      87    |    2580 K  |    2580 K  |
|---------------------------------------------------------------------------|
| GPU reserved segments |      15    |      15    |      15    |       0    |
|---------------------------------------------------------------------------|
| Non-releasable allocs |      11    |      14    |    1410 K  |    1410 K  |
|---------------------------------------------------------------------------|
| Oversize allocations  |       0    |       0    |       0    |       0    |
|---------------------------------------------------------------------------|
| Oversize GPU segments |       0    |       0    |       0    |       0    |
|===========================================================================|

代码文件:博客附件代码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/188565.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Source Insight基本使用

据说阅读Linux源码经常使用此工具;先看一下基本使用; 新建一个工程; OK以后出现下图;这是insight项目的目录; 把要阅读的源码工程加进来; 如下2个选项选中,OK; 如果下图右侧的内容没…

在Windows中操作系统下,检查Python脚本是否已运行

在Windows中操作系统下,检查Python脚本是否已运行 作者:虚坏叔叔 博客:https://xuhss.com 早餐店不会开到晚上,想吃的人早就来了!😄 一、原理 用一个空的虚拟文件。 在进程开始时,检查文件是…

ruby 给钉钉群发消息

给钉钉群发一条工作消息用途如下: Ipa, apk,打包完成了, 可以用作测试群表格导出成功了, 一般的群消息比如后台日志报警等等 步骤如下 群设置 - 智能群助手 - 添加机器人 - 选择 - 自定义 机器人设置里面要设置一个自定义关键词, 比如这里面 我起个名字 summerxx 上篇说到我…

javaweb高校大学毕业生就业跟踪系统ssm idea maven

系统所要实现的功能分析,对于现在网络方便的管理,系统要实现毕业生可以直接在平台上进行查看所有数据信息,根据需求可以进行在线添加,删除或修改企业信息、问卷调查、问卷提交、招聘信息、投递简历、企业评价、就业调查、就业表提…

(十三)devops持续集成开发——jenkins流水线发布一个sonar qube质量检查项目

前言 在前面的内容中我们已经介绍过如何在jenkins中集成质量检查工具sonar qube,以及sonar qube服务的安装。本节内容我们通过使用jenkins构建一个包含sonar qube质量检查的流水线项目,从而实现项目部署发布上线前的代码质量检查。从而保证系统的稳定性…

带约束进化算法问题分析Constrained Evolutionary Algorithms

经典论文《Evolutionary Algorithms for Constrained Parameter Optimization Problems》对带约束的进化算法进行了综述,本文不涉及其内容的翻译,主要为个人对论文理解和思考。 1. 进化算法定义Evolutionary Algorithms 论文中所讨论的进化算法主要为以…

java泛型5

泛型类 Java泛型不仅允许在使用通配符形参时设定上限,而且可以在定义泛型形参时设定上限,用于表示传给该泛型形参的实际类型要么是该上限类型,要么是该上限类型的子类。 上面程序定义了一个Apple泛型类,该Apple类的泛型形参的上限…

免安装PortableGit配置 + TortoiseGit安装

文章目录官网/安装Git将git命令添加到Path环境变量添加GitHub登录账号下载安装TortoiseGit官网/安装Git Git官网:https://git-scm.com/ 国内用户,建议通过淘宝镜像网站下载安装文件: https://registry.npmmirror.com/binary.html?pathgit…

18.异常

目录 一.异常 1.1 什么是异常 1.2 为什么要学习异常 1.3 异常的体系 1.5 编译时异常 1.5.1 什么是编译时异常 1.5.2 编译时异常的作用 1.5.3 常见编译时异常 1.6 运行时异常 1.6.1 什么是运行时异常 1.6.2 常见运行时异常 1.6 异常的默认处理流程(RunTim…

web3:区块链常见的几大共识机制及优缺点

web3相关学习一并收录至该博客:web3学习博客目录大全 胡歌看了都得给我一键三连吧! 目录什么是共识?什么是共识机制?共识机制的目标为什么需要共识机制?如何评价一个共识机制的优劣:共识机制分类PoW( Proof of Work)工作量证明&a…

2023年最新!北京Java培训机构排行榜新鲜出炉!

北京作为中国的首都,其人才的需求的体量之大是其他城市不可比的。那么在北京学习Java,到底该怎么选择Java培训机构哪?怎么在众多的机构里面选择出最适合自己的哪?下面是小编根据口碑和实力整理出的北京Java培训机构排行榜单,仅供…

【Effective_Objective-C_6 块block】

文章目录前言GCD和块的简介37.理解块的概念块的基础知识块可以捕获变量内联块的用法块的内部结构全局块,栈块,堆块堆块全局块要点38.为常用的块类型创建typedef要点39.用handler块降低代码分散程度协议传值实现异步块实现异步回调操作里的块要点40.用块引…

说说redux的实现原理是什么,写出核心代码?

目录标题一、redux三大基本原则是:二、实现原理:三、如何使用一、redux三大基本原则是: 单一数据源state是只读的使用纯函数来执行修改 注意的是,redux并不是只应用在react中,还与其他界面库一起使用,如V…

3.26 haas506 2.0开发教程-example- 串口控制ESP32-CAM OV2640拍照

haas506串口控制ESP32-CAM OV2640拍照介绍ESP32-CAM开发板硬件连接代码流程代码ESP32-CAM开发板代码HaaS506开发板代码测试ESP32-CAM开发板测试介绍 通过HaaS506串口发送指令,控制ESP32-CAM进行拍照,并将照片储存在SD卡中。ESP32-CAM需要5V供电才能正常…

小程序开发常见问题总结(超实用)

小程序开发常见问题总结(超实用) 文章目录小程序开发常见问题总结(超实用)1.小程序user agent stylesheet问题。2.this.setData is not function错误3.flex布局3.1flex布局原理3.2flex父项属性3.3flex布局子项元素4.自定义组件1.在…

白银k线图基础知识梳理:包覆形态

伦敦银价格走势是国际市场上所有参与者多方合力的结果,这些参与者包括银行、白银商、期货交易商、对冲基金等金融机构、各种法人机构以及个人投资者。一根简单的K线,能够把所有市场参与者博弈的结果展示出来,并且反映出银价运行和变化的各个细…

node后端接收到axios的post请求体为空

node后端接收到axios的post请求体为空??? 使用axios发送post请求,传入了Object格式的参数,在node后端req.body接收到的参数为空,但是网页上抓包检查时,发现请求的body确实是携带了参数的&#x…

【工具】2023开年利器,重写收藏逻辑和内置白板应用的Arc浏览器

目录一、为什么你需要一款新的浏览器?二、重写的收藏夹逻辑三、自带笔记和白板的浏览器四、如何获得Arc浏览器一、为什么你需要一款新的浏览器? 人生漫漫,三年混乱。在经历了这些起伏之后,你一定有一个不断进取的决心。 工欲善其…

如何设置将SAP红灯报错改为黄灯(OBA5 更改消息控制 )

在SAP的业务操作中或者后台配置经常遇到SAP校验报红灯的错误导致业务进行不下去。可以通过OBA5 更改消息控制事务修改消息报错类型,例如把红灯报错改为黄灯,这样业务就可以进行下去了。 举两个例子来说明一下如何配置。 目录 例子1:固定资…

【stl -- 内建函数对象】

目录:前言一、仿函数二、算数仿函数三、关系仿函数四、逻辑仿函数总结前言 概念 stl内建了一些仿函数 分类 算数仿函数、 关系仿函数、 逻辑仿函数 用法 这些仿函数所产生的对象,用法和普通函数完全一样; 使用内建仿函数需要包含头文件 一、…