Pytorch入门实战 P06-调用vgg16模型,进行人脸预测

news2024/10/7 12:22:55

目录

1、本文内容:

1、内容:

2、简单介绍下VGG16:

3、相关其他模型也可以调用:

2、代码展示:

3、训练结果:

1、不同优化器:

①【使用SGD优化器】

②【使用Adam优化器】

③Adam + 动态学习率ExponentialLR

④Adam + 动态学习率ExponentialLR+ 降低初始学习率(lr=0.001)

⑤Adam+动态学习率LinearLR

⑥Adam+动态学习率LinearLR+ 降低c初始学习率(lr=0.001)

4、总结


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

1、本文内容:

1、内容:

这篇文章,主要是通过调用现有VGG16的模型,来完成人脸的预测。

这篇文章的亮点主要是提高测试集的精确度

2、简单介绍下VGG16:

VGG-16的主要特点:
        1、深度:VGG-16 = 16个卷积层+3个全连接层组成 ,因此具有相对较深的网络结构。这种深度有助于网络学习到更加抽象和复杂的特征。
        2、卷积层的设计:VGG-16的卷积层全部采用3x3的卷积核和步长为1的卷积操作,同时在卷积层之后都有接ReLU激活函数。这种设计的好处在于,通过堆叠多个较小的卷积核,可以提高网络的非线性建模能力,同时减少了参数数量, 从而降低了过拟合的风险。
        3、池化层:在卷积层之后,VGG-16使用最大池化层来减少特征图的空间尺寸,帮助提取更加显著的特征并减少计算量。
        4、全连接层:VGG-16在卷积层之后接有3个全连接层,最后一个全连接层输出与类别数相对应的向量,用于进行分类。

VGG-16结构说明:
        13个卷积层,分别用blockX-convX表示
        3个全连接层,用classifier表示
        5个池化层。

3、相关其他模型也可以调用:

 Pytorch官网链接地址

2、代码展示:

import copy
import pathlib
import warnings

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn
from torchvision import datasets
from torchvision.models import vgg16, VGG16_Weights
from torchvision.transforms import transforms
import matplotlib as mpl

mpl.use('Agg')  # 在服务器上运行的时候,打开注释

# 检查设备
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# 1、导入数据
data_dir = './data'
data_dir = pathlib.Path(data_dir)
print(data_dir)

data_paths = list(data_dir.glob('*'))
classNames = [str(path).split('/')[1] for path in data_paths]

# 图像预处理
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

total_data = datasets.ImageFolder('./data', transform=train_transforms)

# 划分数据集
train_size = int(0.8 * len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
print(train_size, test_size)  # 1440  360

# 数据加载
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

for X, y in test_dl:
    print('Shape of X [N,C,H,W]:', X.shape)  # [32, 3, 224, 224]
    print('Shape of y:', y.shape, y.dtype)  # torch.Size([32]) torch.int64
    break

# 调用官方的VGG-16模型
"""
    VGG-16的主要特点:
        1、深度:VGG-16 = 16个卷积层+3个全连接层组成 ,因此具有相对较深的网络结构。这种深度有助于网络学习到更加抽象和复杂的特征。
        2、卷积层的设计:VGG-16的卷积层全部采用3x3的卷积核和步长为1的卷积操作,同时在卷积层之后都有接ReLU激活函数。
                      这种设计的好处在于,通过堆叠多个较小的卷积核,可以提高网络的非线性建模能力,同时减少了参数数量,
                      从而降低了过拟合的风险。
        3、池化层:在卷积层之后,VGG-16使用最大池化层来减少特征图的空间尺寸,帮助提取更加显著的特征并减少计算量。
        4、全连接层:VGG-16在卷积层之后接有3个全连接层,最后一个全连接层输出与类别数相对应的向量,用于进行分类。

    VGG-16结构说明:
        13个卷积层,分别用blockX-convX表示
        3个全连接层,用classifier表示
        5个池化层。
"""
# 加载预训练模型,并且对模型进行微调。
model = vgg16(weights=VGG16_Weights.DEFAULT).to(device)

for param in model.parameters():
    param.requires_grad = False  # 冻结模型参数,这样在训练的时候只训练最后一层的参数。

# print("原始模型:",model)
# 修改classifier模块的第6层。即:(6): Linear(in_features=4096, out_features=1000, bias=True)
model.classifier[6] = nn.Linear(4096, len(classNames))  # 修改vgg16 模型中最后一层全连接层,输出目标类别个数
model.to(device)


# print("改完后的模型:",model)


# 编写训练函数
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    train_loss, train_acc = 0, 0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # 计算预测误差
        pred = model(X)
        loss = loss_fn(pred, y)

        # 反向传播
        optimizer.zero_grad()  # grad梯度归零
        loss.backward()  # 反向传播
        optimizer.step()  # 每一步自动更新

        # 记录acc与loss
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    train_acc /= size
    train_loss /= num_batches

    return train_acc, train_loss


# 编写测试函数
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    test_acc, test_loss = 0, 0

    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)

            # 计算loss
            target_pred = model(imgs)
            loss = loss_fn(target_pred, target)

            test_loss += loss.item()
            test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss


# 设置动态学习率
learn_rate = 1e-4  # 初始学习率

# 调用官方动态学习率接口时使用:
lambda1 = lambda epoch: 0.92 ** (epoch // 4)
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)  # 选定调整学习率的方法

# 正式训练
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数
epochs = 40
train_loss = []
train_acc = []
test_loss = []
test_acc = []

best_acc = 0  # 设置一个最佳准确率,作为最佳模型的判别指标。
for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    scheduler.step()  # 用于更新学习率(调用官网动态学习率接口的时候,在这里使用)

    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # 保存最佳模型到best_model
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']

    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f},|||| Test_acc:{:.1f}%,Test_loss:{:.3f}, Lr:{:.2E}')
    print(
        template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss, lr))

# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)
print('Done')

# 结果可视化
warnings.filterwarnings('ignore')
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100  # 分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label="Training Accuracy")
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title("Training and Validation Accuracy")

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validataion Loss')
plt.savefig("/data/jupyter/deepinglearning_train_folder/p06_vgg16/resultImg.jpg")  # 保存图片在服务器的位置
plt.show()

# 指定图片进行预测
classes = list(total_data.class_to_idx)


def predict_one_image(image_path, model, transform, classes):
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # 展示预测图片

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)

    model.eval()
    output = model(img)

    _, pred = torch.max(output, 1)
    print(_, pred)
    pred_class = classes[pred]
    print(f'预测结果:{pred_class}')


# 预测训练集中的某张照片
predict_one_image(image_path='./data/Angelina Jolie/001_fe3347c0.jpg', model=model, transform=train_transforms,
                  classes=classes)

# 评估模型
best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
print(f'模型评估:测试acc:{epoch_test_acc}-----测试Loss:{epoch_test_loss}')

3、训练结果:

1、不同优化器:

对比下,目前主流的两个优化器:SGD和Adam优化器。

①【使用SGD优化器】

测试精确度达到18%。

②【使用Adam优化器】

测试精确达到39%。

③Adam + 动态学习率ExponentialLR

测试精确度达到43%。

④Adam + 动态学习率ExponentialLR+ 降低初始学习率(lr=0.001)

测试精确度达到48%

⑤Adam+动态学习率LinearLR

测试精确度达到43%。

⑥Adam+动态学习率LinearLR+ 降低c初始学习率(lr=0.001)

测试精确度达到48%,最高可达到51%。

4、总结

①将SGD优化器换成Adam优化器,精确度会提升1倍。

②使用Adam+动态学习率(即:③、⑤,精度会再次提升。)

③使用③里的的配置,仅改变学习率(lr=1e-4→lr=1e-3),测试精度会再次提升,见④。

④使用⑤里的的配置,仅改变学习率(lr=1e-4→lr=1e-3),测试精度会再次提升,见⑥。

总结上述,测试精确度的提升,最大是优化器的改变、动态学习率、初始学习率的降低。

这些都会影响到模型的精确度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1601097.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python数据可视化:无向网络图

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 Python数据可视化: 无向网络图 [太阳]选择题 关于以下代码输出结果的说法中正确的是? import networkx as nx import matplotlib.pyplot as plt a [(A, B), (B, C), (B, D)] …

zabbix 自动发现与自动注册 部署 zabbix 代理服务器

zabbix 自动发现(对于 agent2 是被动模式) zabbix server 主动的去发现所有的客户端,然后将客户端的信息登记在服务端上。 缺点是如果定义的网段中的主机数量多,zabbix server 登记耗时较久,且压力会较大。1.确保客户端…

淘宝API商品详情数据在数据分析行业中具有不可忽视的重要性

淘宝商品详情数据在数据分析行业中具有不可忽视的重要性。这些数据为商家、市场分析师以及数据科学家提供了丰富的信息,有助于他们更深入地理解市场动态、消费者行为以及商品竞争态势。以下是淘宝商品详情数据在数据分析行业中的重要性体现: 请求示例&a…

Customizable Ghosts Pack

“可定制的幽灵包”为游戏开发商快速将幽灵角色融入游戏提供了坚实的基础。鬼角色的标准解决方案。 Customizable Ghost Pack: “可自定义的幽灵包”为游戏开发商快速将幽灵怪物集成到游戏中提供了坚实的基础。鬼角色的标准解决方案。 关键功能 ⭐怪物创造者工具。 ⭐完全…

浅谈Spring的Bean生命周期

在Spring框架中,Bean(即Java对象)的生命周期涵盖了从创建到销毁的全过程,主要包含以下几个阶段: 实例化(Instantiation): 当Spring IoC容器需要创建一个Bean时,首先会通过…

HCIA--综合实验(超详细)

要求: 1. 使用172.16.0.0/16划分网络 2.使用ospf协议合理规划区域保证更新安全 3.加快收敛速度 4. r1为DR没有BDR 5.PC2,3,4,5自动获取IP地址;PC1为外网,PC要求可用互相访问 6.r7为运营商,只能配…

[沫忘录]MySQL索引

[沫忘录]MySQL索引 索引概述 优点 提高数据检索效率,降低数据库IO成本通过索引对数据进行排序,降低数据排序成本,降低CPU消耗 缺点 索引会占用一定空间当更新数据时,也需更新索引数据,这会降低数据的更新效率 索引…

Adobe AE(After Effects)2023下载地址及安装教程

Adobe After Effects是一款专业级别的视觉效果和动态图形处理软件,由Adobe Systems开发。它被广泛用于电影、电视节目、广告和其他多媒体项目的制作。 After Effects提供了强大的合成和特效功能,可以让用户创建出令人惊艳的动态图形和视觉效果。用户可以…

LabVIEW变速箱自动测试系统

LabVIEW变速箱自动测试系统 在农业生产中,采棉机作为重要的农用机械,其高效稳定的运行对提高采棉效率具有重要意义。然而,传统的采棉机变速箱测试方法存在测试效率低、成本高、对设备可能产生损害等问题。为了解决这些问题,开发了…

[docker] 镜像部分补充

[docker] 镜像部分补充 这里补充一下比较少用的&#xff0c;关于镜像的内容 检查镜像 ❯ docker images REPOSITORY TAG IMAGE ID CREATED SIZE <none> <none> ca61c1748170 2 hours ago 1.11GB node latest 5212d…

数据中心IP代理VS住宅代理IP,区别详解

一、什么是数据中心/机房IP代理&#xff1f; 数据中心/机房IP代理是使用数据中心拥有并进行分配和管理的IP的代理&#xff0c;俗称机房IP代理。 二、数据中心/机房IP代理的特点 与住宅代理通过使用ISP拥有和分配的IP地址的设备路由请求的情况不同&#xff0c;数据中心代理利…

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之一 简单人脸识别

Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之一 简单人脸识别 目录 Python 基于 OpenCV 视觉图像处理实战 之 OpenCV 简单人脸检测/识别实战案例 之一 简单人脸识别 一、简单介绍 二、简单人脸识别实现原理 三、简单人脸识别案例实现简单步…

陇剑杯 流量分析 webshell CTF writeup

陇剑杯 流量分析 链接&#xff1a;https://pan.baidu.com/s/1KSSXOVNPC5hu_Mf60uKM2A?pwdhaek 提取码&#xff1a;haek目录结构 LearnCTF ├───LogAnalize │ ├───linux简单日志分析 │ │ linux-log_2.zip │ │ │ ├───misc日志分析 │ │ …

【蓝桥杯嵌入式】串口通信与RTC时钟

【蓝桥杯嵌入式】串口通信与RTC时钟 串口通信cubemx配置串口通信程序设计 RTC时钟cubemx配置程序设计 串口通信 cubemx配置 打开串口通信&#xff0c;并配置波特率为9600 打开串口中断 重定义串口接收与发送引脚&#xff0c;默认是PC4&#xff0c;PC5&#xff0c;需要改为P…

2024 CKA 基础操作教程(十四)

题目内容 设置配置环境&#xff1a; [candidatenode-1] $ kubectl config use-context mk8s Task 现有的 Kubernetes 集群正在运行版本 1.29.0。仅将 master 节点上的所有 Kubernetes 控制平面和节点组件升级到版本 1.29.1。 确保在升级之前 drain master 节点&#xff0c…

强强联手|AI赋能智能工业化,探索AI在工业领域的应用

随着人工智能&#xff08;AI&#xff09;技术的不断发展和应用&#xff0c;AI在各个领域展现出了巨大的潜力和价值。在工业领域&#xff0c;AI的应用也越来越受到关注。AI具备了丰富的功能和强大的性能&#xff0c;为工业领域的发展带来了巨大的机遇和挑战。 YesPMP是专业的互联…

IAM 统一身份认证与访问管理服务

即统一身份认证与访问管理服务&#xff0c;是云服务商提供的一套云上身份管理解决方案&#xff0c;可帮助企业安全地管理云上资源的访问权限。 在当今云计算时代&#xff0c;企业越来越依赖云服务来存储和处理敏感数据。然而&#xff0c;这也带来了新的安全挑战&#xff0c;即…

1 GBDT:梯度提升决策树

1 前言 前面简单梳理的基本的决策树算法&#xff0c;那么如何更好的使用这个基础算法模型去优化我们的结果是本节要探索的主要内容。 梯度提升决策树&#xff08;Gradient Boosting Decision Trees&#xff09;是一种集成学习方法&#xff0c;通常用于解决回归和分类问题。它通…

v-for中涉及的key

一、为什么要用key&#xff1f; key可以标识列表中每个元素的唯一性&#xff0c;方便Vue高效地更新虚拟DOM&#xff1b;key主要用于dom diff算法&#xff0c;diff算法是同级比较&#xff0c;比较当前标签上的key和标签名&#xff0c;如果都一样&#xff0c;就只移动元素&#…

【原创教程】海为PLC与RS-WS-ETH-6传感器的MUDBUS_TCP通讯

一、关于RS-WS-ETH-6传感器的准备工作 要完成MODBUS_TCP通讯,我们必须要知道设备的IP地址如何分配,只有PLC和设备的IP在同一网段上,才能建立通讯。然后还要选择TCP的工作模式,来建立设备端和PC端的端口号。接下来了解设备的报文格式,方便之后发送报文完成数据交互。 1、…