使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

news2024/11/25 6:50:28

文章目录

  • 1 数据集描述
  • 2 GPU设置
  • 3 设置Dataset类
  • 4 设置辨别器类
  • 5 辅助函数与辅助类

1 数据集描述

此项目使用的是著名的celebA(CelebFaces Attribute)数据集。其包含10,177个名人身份的202,599张人脸图片,每张图片都做好了特征标记,包含人脸bbox标注框、5个人脸特征点坐标以及40个属性标记,数据由香港中文大学开放提供(不包含商业用途的使用)。
在这里插入图片描述

在实际训练前,已经将数据处理成了HDF5的数据集格式。使用h5py处理HDF5数据集可以提供很多方便,使得数据处理更加高效、灵活、可扩展,显著提升训练过程的文件读取速度。可以使用h5py包自行对数据进行处理,也可直接下载我已经处理好的HDF5数据格式。
如需了解更多h5py相关知识,可以查看HDF5补充知识。

2 GPU设置

前面几篇博客的内容,都是对手写数字这个数据集的处理,CPU还能吃得消。这次数据输入明显增加,需要使用GPU处理数据。如电脑无NAVIDIA独显,建议使用Google Colab执行代码,Colab提供了免费的GPU算力。

if torch.cuda.is_available():
  torch.set_default_tensor_type(torch.cuda.FloatTensor)
  print("using cuda:", torch.cuda.get_device_name(0))
  pass

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

这段代码的作用是,如果当前设备有可用的CUDA,则将默认的张量类型设置为CUDA浮点张量并输出使用的CUDA设备的名称。然后,它将设备设置为CUDA设备(如果有)或CPU。

3 设置Dataset类

基于面向对象编程的基本原则,我们建立一个Dataset类,使类具有数据读取、获取指定索引的数据与绘制指定索引的图像,具体代码如下:

class CelebADataset(Dataset):

    def __init__(self, file):
        self.file_object = h5py.File(file, 'r')
        self.dataset = self.file_object['img_align_celeba']
        pass

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        if index >= len(self.dataset):
            raise IndexError()
        img = numpy.array(self.dataset[str(index) + '.jpg'])
        return torch.cuda.FloatTensor(img) / 255.0

    def plot_image(self, index):
        plt.imshow(numpy.array(self.dataset[str(index) + '.jpg']), interpolation='nearest')
        plt.show()

在获取指定索引对应的数据时,如果指定数大于索引的最大值,我们命令程序返回一个IndexError()错误,以便于快速查找问题所在。
为了理解这一数据类,我们对类进行使用:

celeba_dataset = CelebADataset('文件地址.h5py')

这里创建了一个名为celeba_datasetCelebADataset类,并传入了文件的所在路径file。在__init__中,使用h5py.File方法读取路经所在的文件。

celeba_dataset.plot_image(66)

绘制数据集中66.jpg图形。如果前面代码正确,此处将绘制出数据集中的人脸头像。如果为绘制出图形并产生报错,考虑路径是否有误以及数据格式是否正确。
在这里插入图片描述

4 设置辨别器类

本项目的核心类为鉴别器类与生成器类,下面开始编写鉴别器类。首先建立神经网络框架:

class Discriminator(nn.Module):

    def __init__(self):
        # 父类继承
        super().__init__()

        # 神经网络定义
        self.model = nn.Sequential(
            View(218 * 178 * 3),

            nn.Linear(3 * 218 * 178, 100),
            nn.LeakyReLU(),

            nn.LayerNorm(100),

            nn.Linear(100, 1),
            nn.Sigmoid()
        )

        # 创建损失函数
        self.loss_function = nn.BCELoss()

        # 创建优化器
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        # 初始化计数器
        self.counter = 0
        self.progress = []

这段代码定义了一个名为Discriminator的类,继承了PyTorch中nn.Module类。在__init__函数中,通过nn.Sequential定义了一个神经网络模型,包括三个线性层,两个激活函数,一个归一化层。一开始的View(218*178*3)是新代码。它的作用是将大小为(218, 178, 3) 的三维图像张量重塑成一个长度为218×178×3的一维张量。基于自上而下的编程习惯,我们会在后面对View进行定义。
在此基础上,定义了损失函数nn.BCELoss()和优化器Adam,并定义了一个计数器和一个存储进度的列表。

class Discriminator(nn.Module):
    def forward(self, inputs):
        # simply run model
        return self.model(inputs)

    def train(self, inputs, targets):
        # calculate the output of the network
        outputs = self.forward(inputs)

        # calculate loss
        loss = self.loss_function(outputs, targets)

        # increase counter and accumulate error every 10
        self.counter += 1
        if (self.counter % 10 == 0):
            self.progress.append(loss.item())

        if (self.counter % 1000 == 0):
            print("counter = ", self.counter)

        # 梯度归零,向后传递,优化执行
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()


    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5, 1.0, 5.0))

接下来定义forward功能,train功能,plot_progress功能。在forward()函数中,它只是让模型对输入数据进行前向传播并返回网络的输出。在train()函数中,它使用输入数据和目标数据来计算网络的损失,并使用优化器来更新网络的参数。最后,plot_progress()函数可以用来绘制训练进度。以上类方法与手写字体识别博文中的定义完全相同,如有需要可找到对应博文查看。

5 辅助函数与辅助类

class View(nn.Module):
    def __init__(self, shape):
        super().__init__()
        self.shape = shape, # 逗号不是多打的

    def forward(self, x):
        return x.view(*self.shape)

在前面定义鉴别器类时,我们已经使用了View,此处对View进行补充定义。在 forward 方法中,它对输入的 x 应用了 view 方法,并将 shape 属性作为参数传入。这个模型的作用是将输入的张量的形状调整为 shape 属性所指定的形状。

def generate_random_image(size):
    random_data = torch.rand(size)
    return random_data


def generate_random_seed(size):
    random_data = torch.randn(size)
    return random_data

以上两个随机张量生成器,其作用与手写数字识别中的作用完全相同,在此不再赘述。后续在使用时也会再进行介绍。

截至目前,我们已经建立好了模型所必需的鉴别器类与Dataset类。下一篇会讲解最重要的鉴别器类以及对模型的训练与使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/168188.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构】顺序表的原理及实现

1.什么是顺序表 顺序表是用一段物理地址连续的存储单元进行存储元素的线性结构,通常是以数组进行存储。通过数据元素物理存储的相邻关系来反映数据元素之间逻辑上的相邻关系。 2.顺序表的实现 判断顺序表是否为空表public boolean isEmpty()判断顺序表是否满publi…

复旦MBA海外短期课程 | 善用ESG金融,共创可持续未来

2022年,世界在颠簸中向前迈进:全球气候变化、能源危机、大国博弈……在这样的背景下,近年来备受瞩目的ESG价值、“双碳”目标、可持续发展、责任投资等话题愈发成为焦点。今年复旦MBA培养体系全面升级之际,在新增的“未来发展模块…

Pandas CSV 文件

Pandas CSV 文件CSV(Comma-Separated Values,逗号分隔值,有时也称为字符分隔值,因为分隔字符也可以不是逗号),其文件以纯文本形式存储表格数据(数字和文本)。CSV 是一种通用的、相对…

【沐风老师】3dMax一键生成中央空调排风口插件使用教程

3dMax一键创建中央空调排风口插件,快捷生成矩形或菱形两种形状的排风口。 【版本要求】 不确定。3dmax2020环境测试可用,其他版本自行测试,欢迎反馈测试结果。 【安装方法】 方法一:拖动插件文件到3dMax窗口打开。 方法二&#x…

ESP32相关知识点

1.vs code代码回退到上一步: 方法1:在Windows中可以使用快捷键“Alt←”实现 方法2:利用vs code界面操作,Go-Back 2.设置vs code下ESP-IDF Monitor Device的波特率。 步骤Manage-New Code update available------Command Palette 弹出对话框 搜索或选择:ESP-IDF:SD…

基于Androidstudio的打车拼车app

需求信息: 客户端: 1:注册和登录:用户信息录入与登录。 2:修改密码:用户忘记密码可在此处找回。 3:地图功能:提供城市地图,能被自由移动、放大和缩小。 4:自身定位:获取用户位置信息&#xff0c…

STM-PWM采集

输入捕获分为两种方式进行捕获 1、pwm输入捕获:精度高,每个定时器只能采集一个pwm,且只能使用通道1、通道2。 2、通用输入捕获:相对比较精确,每个定时器可以采集多个pwm, 1、pwm输入捕获使用教程如下: 参…

Vue 快速入门(一)

1、介绍 Vue(读音/vju/,类似view),是中国的大神尤雨溪开发的,为数不多的国人开发的世界顶级开源软件。是一套用于构建用户界面的渐进式框架,Vue 被设计为可以自底向上逐层应用。MVVM响应式编程模型,避免直接操作DOM&am…

虹科分享 | HPC调度解决方案:HK-Adaptive在数字卫星图像领域的应用

2011年3月11日,日本海岸附近发生了9.0级地震。这次地震引发了强大的海啸,并向内陆传播了6英里,不仅使地球的轴心偏移了大约10到25厘米,还导致福岛核电站发生核紧急情况。 为了减少这场灾害的损失,不同国家的不同组织以…

浅谈虚拟电厂与企业微电网数字化建设

安科瑞 华楠摘要:2023年1月8日,微信公众号鱼眼看电改(作者俞庆)发表了文章《虚拟电厂与负荷侧数字化》,原文如下:“虚拟电厂是电力数字化的一个应用方向,准确的说,是负荷侧数字化的发展方向。所以负荷侧数字…

NetSuite Classifications的注意事项

Classifications,我们将之称为“分析维度”吧,这样更能体现其真正的用途。在NetSuite中标准的分析维度是“LDC(Location、Department、Class)”,这是NetSuite的特色。有些注意事项我们今天分享下。 1. LDC的用途区别 …

Django REST framework--Swagger API文档生成器

Django REST framework--Swagger API文档生成器swagger在线接口文档drf-yasg安装与配置安装drf-yasg配置drf-yasg互动模式文档模式定制化用法(viewset模式)修饰视图装饰器api_view修饰视图集swagger在线接口文档 目前为止,接口开发到了一定的…

Spring AOP【统一异常处理与统一数据格式封装】

Spring AOP【统一异常处理与统一格式封装】🍎一.统一异常处理🍒1.1 实现一个异常方法🍒1.2 统一处理异常代码的实现🍒1.3 统一处理所有异常🍎二.统一格式封装🍒2.1 实现一个返回数据方法🍒2.2 统…

回收租赁商城系统功能拆解09讲-会员管理

回收租赁系统适用于物品回收、物品租赁、二手买卖交易等三大场景。 可以快速帮助企业搭建类似闲鱼回收/爱回收/爱租机/人人租等回收租赁商城。 回收租赁系统支持智能评估回收价格,后台调整最终回收价,用户同意回收后系统即刻放款,用户微信零…

echarts省市区id(区域编码)实现地图下钻点击(data赋值自定义属性值,geojson信息获取)

致新的一年:不知不觉已经是2023年,祝新的一年大展宏图(兔),前途(兔)似锦,今年梦想实现! 正文: 接触echarts也有很长一段时间了,最近有个很常见的…

UE4 反射 学习笔记

想让类具有反射机制&#xff0c;必须要有这四个要素&#xff1a; 1、generated.h文件 2、UCLASS()宏 3、继承自UObject 4、GENERATED_BODY() void ABlurCharacter::BeginPlay() {Super::BeginPlay();ABlurCharacter *BlurCharacter NewObject<ABlurCharacter>();UCl…

直线检测算法汇总分析

直线检测算法汇总 1、场景需求 在计算机视觉领域&#xff0c;我们经常需要做一些特殊的任务&#xff0c;而这些任务中经常会用到直线检测算法&#xff0c;比如车道线检测、长度测量等。尽管直线检测的任务看起来比较简单&#xff0c;但是在具体的应用过程中&#xff0c;你会发…

MySQL50题

四张表&#xff1a; 1.学生表 Student&#xff08;s_id,s_name,s_birth,s_sex) 2.课程表Course(c_id,c_name,t_id) 3.教师表Teacher&#xff08;t_id&#xff0c;t_name) 4.成绩表Score(s_id,c_id,s_score) 建表语句&#xff1a; 创建学生表并且往表中插入语句 CREATE TABL…

如何下载通达信接口 费用如何?

之前我分享了自编的一些通达信指标公式。经粉丝咨询&#xff0c;我发现自己疏忽了一个问题&#xff1a;许多人不知道如何下载/使用通达信接口软件&#xff01; 通达信软件PC版&#xff0c;有以下两种形态&#xff1a; 第一种形态是官方版。 官方版的软件下载链接在这里&…

C语言—文件操作(学好文件操作,再也不用担心数据丢失)

专栏&#xff1a;C语言 个人主页&#xff1a;HaiFan. 专栏简介&#xff1a;本专栏主要更新一些C语言的基础知识&#xff0c;也会实现一些小游戏和通讯录&#xff0c;学时管理系统之类的&#xff0c;有兴趣的朋友可以关注一下。 文件操作前言一、为什么使用文件二、什么是文件1.…