P6周人脸识别

news2024/11/17 17:29:50

一、前期准备

1.设置GPU

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os,PIL,pathlib,warnings

warnings.filterwarnings("ignore")             #忽略警告信息

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

2.导入数据

import os,PIL,random,pathlib

data_dir = './6-data/'
data_dir = pathlib.Path(data_dir)

data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

3.数据处理

 图片处理

# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([
    transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸
    # transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间
    transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])

total_data = datasets.ImageFolder("./6-data/",transform=train_transforms)
total_data

划分数据集(4:1)

total_data.class_to_idx
得到映射如下
{'Angelina Jolie': 0,
 'Brad Pitt': 1,
 'Denzel Washington': 2,
 'Hugh Jackman': 3,
 'Jennifer Lawrence': 4,
 'Johnny Depp': 5,
 'Kate Winslet': 6,
 'Leonardo DiCaprio': 7,
 'Megan Fox': 8,
 'Natalie Portman': 9,
 'Nicole Kidman': 10,
 'Robert Downey Jr': 11,
 'Sandra Bullock': 12,
 'Scarlett Johansson': 13,
 'Tom Cruise': 14,
 'Tom Hanks': 15,
 'Will Smith': 16
}
t
划分数据集
rain_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset
加入dataloder中
batch_size = 32

train_dl = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=1)
for X, y in test_dl:
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

二、构建网络

这次直接调用官方VGG

from torchvision.models import vgg16

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
    
# 加载预训练模型,并且对模型进行微调
model = vgg16(pretrained = True).to(device) # 加载预训练的vgg16模型

for param in model.parameters():
    param.requires_grad = False # 冻结模型的参数,这样子在训练的时候只训练最后一层的参数

# 修改classifier模块的第6层(即:(6): Linear(in_features=4096, out_features=2, bias=True))
# 注意查看我们下方打印出来的模型
model.classifier._modules['6'] = nn.Linear(4096,len(classeNames)) # 修改vgg16模型中最后一层全连接层,输出目标类别个数
model.to(device)  
model

三、训练模型

1.设置学习率

# def adjust_learning_rate(optimizer, epoch, start_lr):
#     # 每 2 个epoch衰减到原来的 0.98
#     lr = start_lr * (0.92 ** (epoch // 2))
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr

learn_rate = 1e-4 # 初始学习率
optimizer  = torch.optim.SGD(model.parameters(), lr=learn_rate)

2.编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小
    num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss

3.编写测试函数

def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小
    num_batches = len(dataloader)          # 批次数目, (size/batch_size,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss

4.正式训练

import copy

loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
epochs     = 40

train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []

best_acc = 0    # 设置一个最佳准确率,作为最佳模型的判别指标

for epoch in range(epochs):
    # 更新学习率(使用自定义学习率时使用)
    # adjust_learning_rate(optimizer, epoch, learn_rate)
    
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    # 保存最佳模型到 best_model
    if epoch_test_acc > best_acc:
        best_acc   = epoch_test_acc
        best_model = copy.deepcopy(model)
    
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    # 获取当前的学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, 
                          epoch_test_acc*100, epoch_test_loss, lr))
    
# 保存最佳模型到文件中
PATH = './best_model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)

print('Done')

四、结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

from PIL import Image 

classes = list(total_data.class_to_idx)

def predict_one_image(image_path, model, transform, classes):
    
    test_img = Image.open(image_path).convert('RGB')
    plt.imshow(test_img)  # 展示预测的图片

    test_img = transform(test_img)
    img = test_img.to(device).unsqueeze(0)
    
    model.eval()
    output = model(img)

    _,pred = torch.max(output,1)
    pred_class = classes[pred]
    print(f'预测结果是:{pred_class}')

# 预测训练集中的某张照片
predict_one_image(image_path='./6-data/Angelina Jolie/001_fe3347c0.jpg', 
                  model=model, 
                  transform=train_transforms, 
                  classes=classes)

五、模型评估

best_model.eval()
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
epoch_test_acc, epoch_test_loss
# 查看是否与我们记录的最高准确率一致
epoch_test_acc

六,总结

构建数据集中

这次构建数据集还是直接调用的,但是遇到了一个问题就是

classeNames = [str(path).split("\\")[1] for path in data_paths]

替换为
classeNames = [os.path.basename(str(path)) for path in data_paths]
这两行代码的目的都是从一系列路径中提取某些信息,但它们使用不同的方法来实现这一目的。让我们来详细比较一下:

1. `classeNames = [str(path).split("\\")[1] for path in data_paths]`

   这行代码尝试对每个 `path` 对象执行以下操作:
   - 将 `path` 对象转换为字符串。
   - 使用反斜杠 `\` 作为分隔符来分割字符串。
   - 尝试访问分割后的列表中索引为 `1` 的元素,即第二个元素。

   这种方法假设所有路径都使用反斜杠 `\` 作为分隔符,并且每个路径至少包含两个由 `\` 分隔的元素。如果路径中没有足够的分隔符来产生至少两个元素,或者路径使用的是正斜杠 `/` 作为分隔符(在Unix-like系统中很常见),这将导致 `IndexError`。

2. `classeNames = [os.path.basename(str(path)) for path in data_paths]`

   这行代码使用 `os.path.basename` 函数来获取每个 `path` 对象的基名称(即路径的最后一部分)。`os.path.basename` 是一个平台独立的函数,它自动处理不同操作系统的路径分隔符。这意味着无论您在什么操作系统上运行代码,`os.path.basename` 都能正确地返回路径的最后一部分。

   这种方法不依赖于路径中分隔符的数量,因此不会产生 `IndexError`。即使路径中只有一个元素或没有使用反斜杠 `\` 作为分隔符,`os.path.basename` 也会返回正确的结果。

### 为什么后者不报错?

- **平台独立性**:`os.path.basename` 函数适用于所有操作系统,自动处理不同的路径分隔符。
- **错误处理**:`os.path.basename` 函数设计得更加健壮,即使路径只有一个元素或没有元素,它也能正确地返回空字符串或路径本身。
- **语义清晰**:`os.path.basename` 的用途非常明确,即获取路径的基名称,这使得代码更易于理解和维护。

### 总结

使用 `os.path.basename` 是更安全、更清晰的方法,特别是在处理来自不同操作系统的路径时。它减少了因路径分隔符引起的错误,并使代码更加健壮和易于理解。因此,推荐在处理文件路径时使用 `os.path` 模块中的函数,而不是手动分割字符串。

训练模型中

静态learn_rate = 1e-4,epoch=40

静态learn_rate = 1e-2,epoch=40

动态 learn_rate = 1e-2,epoch=30

结论,在想同epoch---learnrate下,使用动态学习了吧有提升

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1642860.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SQL:NOT IN与NOT EXISTS不等价

在对SQL语句进行性能优化时,经常用到一个技巧是将IN改写成EXISTS,这是等价改写,并没有什么问题。问题在于,将NOT IN改写成NOT EXISTS时,结果未必一样。 目录 一、举例验证二、三值逻辑简述三、附录:用到的S…

SpringBoot自定义定时任务

通常,在我们的项目中需要定时给前台发送一些提示性消息或者我们想要的定时信息,这个时候就需要使用定时任务来实现这一功能,实现也很简单,接下来具体来看看吧~ 简单定时任务 首先,你需要在你的启动类上加上开启定时任…

【人工智能Ⅱ】实验5:自然语言处理实践(情感分类)

实验5:自然语言处理实践(情感分类) 一:实验目的与要求 1:掌握RNN、LSTM、GRU的原理。 2:学习用RNN、LSTM、GRU网络建立训练模型,并对模型进行评估。 3:学习用RNN、LSTM、GRU网络做…

Photoshop中图像编辑的基本操作

Photoshop中图像编辑的基本操作 Photoshop中调整图像窗口大小Photoshop中辅助工具的使用网格的使用标尺的使用注释工具的使用 Photoshop中置入嵌入式对象Photoshop中图像与画布的调整画布大小的修改画布的旋转图像尺寸的修改 Photoshop中撤销与还原采用快捷键进行撤销与还原采用…

基于MQTT通信开发的失物招领小程序

项目架构设计 这个项目采用前后端分离的方式,重新设计了两条链路来支撑程序的信息获取和传递 前端的小程序页面再启动页面渲染时,直接通过DBAPI从后端数据库获取信息,直接渲染在小程序中项目中给DBAPI的定位是快速从后端获取信息&#xff0…

信息系统项目管理师0087:组织系统(6项目管理概论—6.2项目基本要素—6.2.6组织系统)

点击查看专栏目录 文章目录 6.2.6组织系统1.治理框架2.管理要素3.组织结构类型6.2.6组织系统 项目运行时会受到项目所在的组织结构和治理框架的影响与制约。为有效且高效地开展项目,项目经理需要了解组织内的组织机构及职责分配情况,帮助自己有效地利用其权力、影响力、能力、…

5分钟速通大语言模型(LLM)的发展与基础知识

✍️ 作者:哈哥撩编程(视频号同名) 博客专家全国博客之星第四名超级个体COC上海社区主理人特约讲师谷歌亚马逊演讲嘉宾科技博主极星会首批签约作者 🏆 推荐专栏: 🏅 程序员:职场关键角色通识宝…

【编程题-错题集】chika 和蜜柑(排序 / topK)

牛客对于题目链接&#xff1a;chika和蜜柑 (nowcoder.com) 一、分析题目 排序 &#xff1a;将每个橘⼦按照甜度由高到低排序&#xff0c;相同甜度的橘子按照酸度由低到高排序&#xff0c; 然后提取排序后的前 k 个橘子就好了。 二、代码 1、看题解之前AC的代码 #include <…

ttkbootstrap界面美化系列之PanedWindow(七)

在界面设计中经常用PanedWindow控件来对整个界面进行切割布局&#xff0c;让整个界面看上去有层次感&#xff0c;不至于说杂乱无章。在我之前的文章中有对tkinter的该控件做了详细的介绍&#xff0c;链接如下基于Tkinter的PanedWindow组件进行窗口布局-CSDN博客 本文主要是介绍…

Python-VBA函数之旅-oct函数

目录 一、oct函数的常见应用场景 二、oct函数使用注意事项 三、如何用好oct函数&#xff1f; 1、oct函数&#xff1a; 1-1、Python&#xff1a; 1-2、VBA&#xff1a; 2、推荐阅读&#xff1a; 个人主页&#xff1a;神奇夜光杯-CSDN博客 一、oct函数的常见应用场景 oc…

批量抓取某电影网站的下载链接

思路&#xff1a; 进入电影天堂首页&#xff0c;提取到主页面中的每一个电影的背后的那个urL地址 a. 拿到“2024必看热片”那一块的HTML代码 b. 从刚才拿到的HTML代码中提取到href的值访问子页面&#xff0c;提取到电影的名称以及下载地址 a. 拿到子页面的页面源代码 b. 数据提…

Linux理解文件操作 文件描述符fd 理解重定向 dup2 缓冲区 C语言实现自己的shell

文章目录 前言一、文件相关概念与操作1.1 open()1.2 close()1.3 write()1.4 read()1.4 写入的时候先清空文件内容再写入1.5 追加&#xff08;a && a&#xff09; 二、文件描述符2.1 文件描述符 fd 0 1 2 的理解2.2 FILE结构体&#xff1a;的源代码 三、深入理解文件描述…

Vue 开发中的一些问题简单记录,Cannot find module ‘webpack/lib/RuleSet‘

您好&#xff0c;我是码农飞哥&#xff08;wei158556&#xff09;&#xff0c;感谢您阅读本文&#xff0c;欢迎一键三连哦。 &#x1f4aa;&#x1f3fb; 1. Python基础专栏&#xff0c;基础知识一网打尽&#xff0c;9.9元买不了吃亏&#xff0c;买不了上当。 Python从入门到精…

【C语言】项目实践-贪吃蛇小游戏(Windows环境的控制台下)

一.游戏要实现基本的功能&#xff1a; • 贪吃蛇地图绘制 • 蛇吃食物的功能 &#xff08;上、下、左、右方向键控制蛇的动作&#xff09; • 蛇撞墙死亡 • 蛇撞自身死亡 • 计算得分 • 蛇身加速、减速 • 暂停游戏 二.技术要点 C语言函数、枚举、结构体、动态内存管…

【Java探索之旅】内部类 静态、实例、局部、匿名内部类全面解析

文章目录 &#x1f4d1;前言一、内部类1.1 概念1.2 静态内部类1.3 实例内部类1.4 局部内部类1.5 匿名内部类 &#x1f324;️全篇总结 &#x1f4d1;前言 在Java编程中&#xff0c;内部类是一种强大的特性&#xff0c;允许在一个类的内部定义另一个类&#xff0c;从而实现更好的…

Rust web简单实战

一、使用async搭建简单的web服务 1、修改cargo.toml文件添加依赖 [dependencies] futures "0.3" tokio { version "1", features ["full"] } [dependencies.async-std] version "1.6" features ["attributes"]2、搭…

网络攻击(Cyber Attacks)

目录 1.概念 2.分类 3.总结 1.概念 网络攻击&#xff08;Cyber Attacks&#xff0c;也称赛博攻击&#xff09;是指针对计算机信息系统、基础设施、计算机网络或个人计算机设备的&#xff0c;任何类型的进攻动作。对于计算机和计算机网络来说&#xff0c;破坏、揭露、修改、使…

【C++】STL — List的接口讲解 +详细模拟实现

前言&#xff1a; 本章我们将学习STL中另一个重要的类模板list… list是可以在常数范围内在任意位置进行插入和删除的序列式容器&#xff0c;并且该容器可以前后双向迭代。list的底层是带头双向循环链表结构&#xff0c;双向链表中每个元素存储在互不相关的独立节点中&#xf…

从零到屎山系列-游戏开发(Day2)

简介 这次就来一个比较简单的小游戏贪吃蛇 贪吃蛇 游戏规则就是一串珠子不断的移动&#xff0c;碰到场景里面的食物变长一点&#xff0c;碰到墙壁游戏结束。 开始动手 设计绘制设备 首先我计划从一个控制台游戏开始&#xff0c;需要一个控制台下的绘图机制&#xff0c;希…

基于Sen+MK的多站点不同季节和年尺度的SPEI趋势分析.md

再大的风浪&#xff0c;不过只短暂喧哗。 文章目录 前言1. 概述2.1 问题情景2.2 说明 2. 版本2.1 天津&#xff0c;2024年5月4日&#xff0c;Version1 3. 微信公众号GISRSGeography 一、数据1. 输入数据2. 输出数据 二、程序代码三、参考资料 前言 1. 概述 2.1 问题情景 假…