PyTorch训练Celeba

news2024/10/8 13:29:54

CelebFaces属性数据集(CelebA)是一个大规模的人脸属性数据集,包含超过20万张名人图像,每张图像都有40个属性标注。该数据集中的图像涵盖了大范围的姿态变化和复杂的背景。CelebA具有高度的多样性、大量的数据和丰富的标注信息,包括:

  • 10,177个身份,
  • 202,599张人脸图像,
  • 每张图像有5个标志点位置和40个二进制属性标注。

该数据集可用于以下计算机视觉任务的训练和测试集:人脸属性识别、人脸识别、人脸检测、标志点(或面部部位)定位以及人脸编辑与合成。

1. 安装必要的依赖

如果还没有安装PyTorch和Torchvision,可以通过pip安装:

%pip install torch torchvision

2. 加载CelebA数据集

我们可以通过torchvision.datasets模块轻松获取它。
如果没有数据集且无法通过google drive下载,可以通过这篇notebook关联的数据集下载

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 图像的预处理:缩放到 64x64 大小,并归一化
# 这个变换管道通常用于图像分类任务中的数据预处理步骤。通过调整图像大小、转换为张量和标准化,可以确保输入数据的一致性和模型训练的稳定性。
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    #将图像从 PIL 图像或 numpy 数组转换为 PyTorch 张量,并且会自动将图像的像素值从 [0, 255] 范围缩放到 [0, 1] 范围。
    transforms.ToTensor(),
    #使用给定的均值和标准差对图像进行标准化。这里的均值和标准差分别是 [0.5, 0.5, 0.5],表示对每个通道(RGB)进行标准化。标准化公式为: [ text{output} = frac{text{input} - text{mean}}{text{std}} ] 由于输入的像素值已经被 transforms.ToTensor 缩放到 [0, 1] 范围,标准化后的像素值将被调整到 [-1, 1] 范围。
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# 加载 CelebA 数据集,如果google drive可以用,download=True,可以直接下载,也可以自己下载后解压到/root/celeba目录下
root = '/root'
train_dataset = torchvision.datasets.CelebA(root=root, split='train', download=False, transform=transform)
test_dataset = torchvision.datasets.CelebA(root=root, split='test', download=False, transform=transform)

# 创建数据加载器,改大batch_size,在GPU T4下显存使用率没增加?
batch_size = 128
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 检查数据加载是否成功
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")

3. 定义CNN模型

接下来,我们定义一个简单的卷积神经网络模型,结构可以根据任务的复杂性调整。这里的例子是一个基础的CNN。

3.1 初始化

卷积层:
  • self.conv1:第一个卷积层,输入通道数为 3(RGB 图像),输出通道数为 64,卷积核大小为 3x3,步幅为 1,填充为 1。
  • self.conv2:第二个卷积层,输入通道数为 64,输出通道数为 128,卷积核大小为 3x3,步幅为 1,填充为 1。
  • self.conv3:第三个卷积层,输入通道数为 128,输出通道数为 256,卷积核大小为 3x3,步幅为 1,填充为 1。
全连接层:
  • self.fc1:第一个全连接层,输入大小为 25688,输出大小为 512。
  • self.fc2:第二个全连接层,输入大小为 512,输出大小为 40(CelebA 数据集有 40 个属性标签)。
池化层和激活函数:
  • self.pool:最大池化层,池化核大小为 2x2,步幅为 2。
  • self.relu:ReLU 激活函数。
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(256*8*8, 512)
        self.fc2 = nn.Linear(512, 40)  # CelebA 有 40 个属性标签

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = self.pool(x)
        
        x = x.view(x.size(0), -1)  # 展平操作
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x
# 实例化模型
model = SimpleCNN()

4. 定义损失函数和优化器

对于多标签分类任务,我们使用 BCEWithLogitsLoss 损失函数。优化器可以选择 Adam 或 SGD。

criterion = nn.BCEWithLogitsLoss()  # 二分类交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.001)

5. 训练模型

下面是训练模型的代码,每个 epoch 后对模型进行验证。

5.1. train 函数

train 函数用于训练模型。

参数解释
  • model:要训练的神经网络模型。
  • train_loader:训练数据的 DataLoader 对象。
  • criterion:损失函数,用于计算预测值和真实值之间的误差。
  • optimizer:优化器,用于更新模型的参数。
  • num_epochs:训练的轮数。
  • device:计算设备(如 CPU 或 GPU)。

5.2. evaluate 函数

evaluate 函数用于评估模型的性能。

参数解释
  • model:要评估的神经网络模型。
  • test_loader:测试数据的 DataLoader 对象。
  • criterion:损失函数,用于计算预测值和真实值之间的误差。
  • device:计算设备(如 CPU 或 GPU)。
def train(model, train_loader, criterion, optimizer, num_epochs, device):
    model.train() #设置模型为训练模式
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            # 将数据移动到选定的设备
            images, labels = images.to(device), labels.to(device).float()
            
            # 前向传播
            outputs = model(images)
            loss = criterion(outputs, labels)
            
            # 反向传播和优化
            optimizer.zero_grad()#清零梯度
            loss.backward()#计算梯度
            optimizer.step()#更新参数
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
                
def evaluate(model, test_loader, criterion, device):
    model.eval() # 设置模型为评估模式
    with torch.no_grad(): # 禁用梯度计算
        total_loss = 0
        correct = 0
        total = 0
        for images, labels in test_loader:
            # 将数据移动到选定的设备
            images, labels = images.to(device), labels.to(device).float()
            
            # 前向传播
            outputs = model(images)
            # 计算损失,并累加到 total_loss
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            
            # 计算准确率
            predicted = (outputs > 0.5).float()
            total += labels.size(0) * labels.size(1)  # 总标签数
            correct += (predicted == labels).sum().item()
        # 打印平均损失和准确率
        print(f'Average loss: {total_loss / len(test_loader):.4f}, Accuracy: {100 * correct / total:.2f}%')
  1. 开始训练

通过调用train函数进行训练,完成后用evaluate函数进行验证。

# 检查 CUDA 是否可用,并选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 将模型移动到选定的设备
model = model.to(device)

# 训练模型
train(model, train_loader, criterion, optimizer, num_epochs=10, device=device)

# 验证模型
evaluate(model, test_loader, criterion, device=device)
  1. 保存和加载模型

训练完成后,可以保存模型以便之后使用。

# 保存模型
torch.save(model.state_dict(), 'celeba_cnn.pth')

# 加载模型
model.load_state_dict(torch.load('celeba_cnn.pth'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2196459.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

进入猛增模式后,小米股价还剩下多少上涨空间?

猛兽财经核心观点: (1)小米集团的股价已经上涨到了2022年以来的最高点。 (2)股价从2023年的最低点上涨了185%以上。 (3)随着智能手机的需求反弹和电动汽车利润率的增长,猛兽财经认为…

YOLOv10改进策略【注意力机制篇】| NAM 即插即用模块,重新优化通道和空间注意力(含二次创新)

一、本文介绍 本文记录的是基于NAM模块的YOLOv10目标检测改进方法研究。 许多先前的研究专注于通过注意力操作捕获显著特征,但缺乏对权重贡献因素的考虑,而这些因素能够进一步抑制不重要的通道或像素。而本文利用NAM改进YOLOv10,通过权重的贡…

数字人直播违规被“封”,一文助你彻底解决!

随着数字人直播的日渐兴起,与之相关的各类消息逐渐进入到人们的视野之中,并开始成为众多企业、创业者以及技术爱好者所重点关注的对象。就目前的讨论情况来看,热度最高且讨论次数最多的便是数字人直播违规吗这一话题。 的确,从数字…

一个three三维 文字 粒子 着色器的作品用来感谢大家对github点星

一个three三维 文字 粒子 着色器的作品用来感谢大家对github点星 源链接:https://z2586300277.github.io/three-cesium-examples/#/codeMirror?navigationThreeJS&classifyshader&idtextStarShader 国内站点预览:http://threehub.cn github地…

CVE-2024-9014 pgAdmin4 OAuth2 client ID与secret敏感信息泄漏漏洞

文章目录 免责声明漏洞描述搜索语法漏洞复现nuclei修复建议 免责声明 本文章仅供学习与交流,请勿用于非法用途,均由使用者本人负责,文章作者不为此承担任何责任 漏洞描述 pgAdmin4 是开源数据库 PostgreSQL 的图形管理工具攻击者可构造恶意…

向量数据库!AI 时代的变革者还是泡沫?

向量数据库!AI 时代的变革者还是泡沫? 前言一、向量数据库的基本概念和原理二、向量数据库在AI中的应用场景三、向量数据库的优势和挑战四、向量数据库的发展现状和未来趋势五、向量数据库对AI发展的影响 前言 数据是 AI 的核心,而向量则是数…

一个设备不知道ip地址怎么办?应对策略来袭

在数字化时代,设备连接网络已成常态,IP地址作为设备的网络身份证,其重要性不言而喻。然而,面对设备IP地址遗失的困境,我们往往感到束手无策。 那么,一个设备不知道IP地址怎么办?本文将为你提供一…

中国通信技术革命史

文章目录 引言I 中国通信技术革命史电报中国卫星通信的历史固定电话寻呼机(BP机)大哥大(手机)制定自己的移动通信网络技术体系5G未来科技发展的总趋势:用更少的能量,传输、处理和存储更多的信息II 知识扩展通信史(单位能量的信息传输率越来越高,网络地不断融合。)超级智能…

秒杀系统的原则和注意项

做任何技术方案都需要结合当时的业务场景、资金情况、用户体量等维度综合考虑,没有最好的技术方案,只有最合适的技术方案。 做秒杀方案亦是如此,秒杀活动经常会引发高并发、系统宕机和库存超卖的棘手问题,作为开发者,我…

火情监测摄像机:守护生命与财产安全的“眼睛”

随着城市化进程的加快,火灾隐患日益增多。为了有效预防和及时应对火灾事故,火情监测摄像机应运而生,成为现代消防安全的重要组成部分。这种高科技设备不仅能够实时监控火灾发生,还能为救援提供宝贵的信息支持。火情监测摄像机主要…

vulnhub-THE PLANETS-EARTH靶机

下载并导入靶机至VMWare,设置网络模式为NAT,开机 开启攻击机(kali),也设置为Nat模式,与靶机处于同一网段 扫描靶机ip Nmap 192.168.114.0/24 扫描网段内活跃的主机 可以推断靶机ip为192.168.114.129 扫描…

什么是源代码加密?十种方法教你软件开发源代码加密

什么是源代码加密 源代码加密是一种安全措施,它通过加密技术对软件的源代码进行保护,以防止未授权的访问、泄露、篡改或逆向工程。源代码是软件程序的原始代码,通常由程序员编写,然后编译成可执行程序。由于源代码包含了软件的设…

攻防世界---->工业协议分析2

前言:做题笔记。 下载 PCAPNG 说明是一个网络数据包文件。 那么直接用Wireshark查看分析。 调整一下长度显示: 可以看到 ARP协议: UDP 进行通信。 长度都是58,我们去找变动点。 转: flag{7FoM2StkhePz} 题外话&…

画质修复哪个软件好?提升老旧照片画质的黑科技分享

朝霞好看?拍它!落日好看?拍它! 回头一翻相册才发现,只有那一小部分的光影好看,那就把它放大裁出来! 放大了画面,画质降低画面模糊了,反而没有肉眼看的画面好看了咋办&a…

COSPLAY大赛静态HTML网页模板源码

源码名称:COSPLAY大赛静态HTML网页模板 源码介绍:一款cosplay大赛HTML网页模板源码,过往参赛选手会自动从腾讯大赛获取,可用于cosplay大赛,漫展等。 需求环境:H5 下载地址: https://www.5188…

SpringBoot框架下旅游管理系统的创新设计与实现

第二章 相关技术简介 2.1 JAVA技术 本次系统开发采用的是面向对象的Java作为软件编程语言,Java表面上很像C,但是Java仅仅是继承了C的某些优点,程序员很少使用的C语言的特征在Java设计中去掉了。Java编程语言并没有什么结构,它把数…

Linux:进程调度算法和进程地址空间

✨✨✨学习的道路很枯燥,希望我们能并肩走下来! 文章目录 目录 文章目录 前言 一 进程调度算法 1.1 进程队列数据结构 1.2 优先级 ​编辑 1.3 活动队列 ​编辑 1.4 过期队列 1.5 active指针和expired指针 1.6 进程连接 二 进程地址空间 2.1 …

uniapp 游戏 - 使用 uniapp 实现的扫雷游戏

0. 思路 1. 效果图 2. 游戏规则 扫雷的规则很简单。盘面上有许多方格,方格中随机分布着一些雷。你的目标是避开雷,打开其他所有格子。一个非雷格中的数字表示其相邻 8 格子中的雷数,你可以利用这个信息推导出安全格和雷的位置。你可以用右键在你认为是雷的地方插旗(称为标…

AI赋能新质生产力医院管理项目成功举办

2024年9月27日,为进一步贯彻实施《2024年全国卫生健康工作会议》精神,提升医学诊断准确性,优化医院服务流程,并降低医疗成本,清华大学智慧医疗研究院联合北京整合医学学会,在郑州大学第一附属医院东院区成功…

Java实体对象转换利器MapStruct详解

概述 现在的JAVA项目多数采用分层结构,参考《阿里巴巴JAVA开发手册》。 分层之后,每一层都有自己的领域模型,即不同类型的 Bean:  DO ( Data Object ) :与数据库表结构一一对应,…