365天深度学习训练营 第P6周:好莱坞明星识别

news2024/11/28 11:36:19
  • 🍨 本文为🔗365天深度学习训练营 内部限免文章(版权归 K同学啊 所有)
  • 🍦 参考文章地址: 🔗第P6周:好莱坞明星识别 | 365天深度学习训练营
  • 🍖 作者:K同学啊 | 接辅导、程序定制

文章目录

  • 我的环境:
  • 一、前期工作
    • 1. 设置 GPU
    • 2. 导入数据
    • 3. 划分数据集
  • 二、调用vgg-16模型
  • 三、训练模型
    • 1. 设置超参数
    • 2. 编写训练函数
    • 3. 编写测试函数
    • 4. 正式训练
  • 四、结果可视化
    • 1.Loss 与 Accuracy 图

我的环境:

  • 语言环境:Python 3.6.8
  • 编译器:jupyter notebook
  • 深度学习环境:
    • torch==0.13.1、cuda==11.3
    • torchvision==1.12.1、cuda==11.3

一、前期工作

1. 设置 GPU

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision import transforms, datasets

if __name__=='__main__':
    ''' 设置GPU '''
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using {} device\n".format(device))
Using cuda device

2. 导入数据

import os, PIL, pathlib
data_dir = 'D:/jupyter notebook/DL-100-days/datasets/48-data/'
data_dir = pathlib.Path(data_dir)
data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[5] for path in data_paths]
print(classeNames)
['Angelina Jolie',
 'Brad Pitt',
 'Denzel Washington',
 'Hugh Jackman',
 'Jennifer Lawrence',
 'Johnny Depp',
 'Kate Winslet',
 'Leonardo DiCaprio',
 'Megan Fox',
 'Natalie Portman',
 'Nicole Kidman',
 'Robert Downey Jr',
 'Sandra Bullock',
 'Scarlett Johansson',
 'Tom Cruise',
 'Tom Hanks',
 'Will Smith']
train_transforms = transforms.Compose([
    transforms.Resize([224,224]),# resize输入图片
    transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换成tensor
    transforms.Normalize(
        mean = [0.485, 0.456, 0.406],
        std = [0.229,0.224,0.225]) # 从数据集中随机抽样计算得到
])
 
total_data = datasets.ImageFolder(data_dir,transform=train_transforms)
total_data
Dataset ImageFolder
    Number of datapoints: 1800
    Root location: hlw
    StandardTransform
Transform: Compose(
               Resize(size=[224, 224], interpolation=PIL.Image.BILINEAR)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

3. 划分数据集

train_size = int(0.8*len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data,[train_size,test_size])
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=1)

二、调用vgg-16模型

from torchvision.models import vgg16
    
model = vgg16(pretrained = True).to(device)
for param in model.parameters():
    param.requires_grad = False
 
model.classifier._modules['6'] = nn.Linear(4096,len(classNames))
 
model.to(device)
# 查看要训练的层
params_to_update = model.parameters()
# params_to_update = []
for name,param in model.named_parameters():
    if param.requires_grad == True:
#         params_to_update.append(param)
        print("\t",name)

三、训练模型

1. 设置超参数

# 优化器设置
optimizer = torch.optim.Adam(params_to_update, lr=1e-4)#要训练什么参数/
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.92)#学习率每5个epoch衰减成原来的1/10
loss_fn = nn.CrossEntropyLoss()

2. 编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共900张图片
    num_batches = len(dataloader)   # 批次数目,29(900/32)
 
    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches
 
    return train_acc, train_loss

3. 编写测试函数

def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)          # 批次数目,8(255/32=8,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()
 
    test_acc  /= size
    test_loss /= num_batches
 
    return test_acc, test_loss

4. 正式训练

epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []
best_acc = 0
filename='checkpoint.pth'
 
for epoch in range(epochs):
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    
    scheduler.step()#学习率衰减
    
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    # 保存最优模型
    if epoch_test_acc > best_acc:
        best_acc = epoch_train_acc
        state = {
            'state_dict': model.state_dict(),#字典里key就是各层的名字,值就是训练好的权重
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }
        torch.save(state, filename)
        
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')
print('best_acc:',best_acc)

Epoch: 1, Train_acc:12.2%, Train_loss:2.701, Test_acc:13.9%,Test_loss:2.544
Epoch: 2, Train_acc:20.8%, Train_loss:2.386, Test_acc:20.6%,Test_loss:2.377
Epoch: 3, Train_acc:26.1%, Train_loss:2.228, Test_acc:22.5%,Test_loss:2.274

Epoch:19, Train_acc:51.6%, Train_loss:1.528, Test_acc:35.8%,Test_loss:1.864
Epoch:20, Train_acc:53.9%, Train_loss:1.499, Test_acc:35.3%,Test_loss:1.852
Done
best_acc: 0.37430555555555556

四、结果可视化

1.Loss 与 Accuracy 图

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率
 
epochs_range = range(epochs)
 
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
 
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
 
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/335705.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

扩散模型diffusion model用于图像恢复任务详细原理 (去雨,去雾等皆可),附实现代码

文章目录1. 去噪扩散概率模型2. 前向扩散3. 反向采样3. 图像条件扩散模型4. 可以考虑改进的点5. 实现代码1. 去噪扩散概率模型 扩散模型是一类生成模型, 和生成对抗网络GAN 、变分自动编码器VAE和标准化流模型NFM等生成网络不同的是, 扩散模型在前向扩散过程中对图像逐步施加噪…

如何为自己的应用选择数据库?有这些考虑因素

节选翻译自 Michal Toiba 的博客 微软前不久宣布推出分布式关系数据库 Azure Cosmos DB for PostgreSQL,使 Azure 成为第一个在单一数据库服务中同时支持关系和 NoSQL(非关系)数据的云平台。这意味着 Azure Cosmos DB 开发者在构建云原生应用…

只需三步,完成ChatGPT微信机器人搭建

大家好,我是可乐。 这两天 Chatgpt 又爆火了,去年12月份刚出来的时候,我写了两篇文章: ①、如何注册Chatgpt? ②、如何将 chatgpt接入微信? 然后沉寂了一个月,没想现在到又火了。本篇文章我…

计算机组成与体系结构 性能设计 William Stallings 第2章 性能问题

2.1 优化性能设计例如,当前需要微处理器强大功能的桌面应用程序包括:图像处理、三维渲染、语音识别、视频会议、多媒体创作、文件的声音和视频注释、仿真建模从计算机组成与体系结构的角度来看,一方面,现代计算机的基本组成与50多…

强大的ChatGpt为企业营销推广提供了全方位的加持

chatgpt,一个火出圈的“聊天机器人”。从写作文,到写代码,似乎没有什么是它干不了的。 ChatGpt在工业中的应用场景有哪些? 在工业领域,它可以用于提高生产效率,缩短生产周期,并帮助工人解决生产过程中的问…

优先级队列(PriorityQueue 和 Top-K问题)

一、PriorityQueue java中提供了两种优先级队列:PriorityQueue 和 PriorityBlockingQueue。其中 PriorityQueue 是线程不安全的,PriorityBolckingQueue 是线程安全的。 PriorityQueue 使用的是堆,且默认情况下是小堆——每次获取到的元素都是…

LeetCode刷题------字符串

LeetCode:344.反转字符串 定义两个指针(也可以说是索引下标),一个从字符串前面,一个从字符串后面,两个指针同时向中间移动,并交换元素。 var reverseString function(s) {let l -1, r s.len…

Windows下LuaBridge2.8的环境配置及简单应用

Windows下LuaBridge2.8的环境配置及简单应用 LuaBridge2.8下载链接: https://github.com/vinniefalco/LuaBridge/tags 关于Lua的环境配置可参考以下链接(这里不做简述): https://ufgnix0802.blog.csdn.net/article/details/125341…

什么是抗压能力?抗压能力的重要性及提高方法

在现实生活中每个人都面临着或大或小的压力。不论是学习还是工作,没有压力就没有进步和提升,压力可以转化为动力,而扛不住压力那就会导致躺平的结果。在职业的道路上,压力一直都存在,尤其是站在企业人力资源管理的角度…

开学季,关于校园防诈骗宣传,如何组织一场微信线上答题考试

开学季,关于校园防诈骗宣传,如何组织一场微信线上答题考试如何组织一场微信线上答题考试在线考试是一种非常节约成本的考试方式,考生通过微信扫码即可参加培训考试,不受时间、空间的限制,近几年越来越受企事业单位以及…

Navicat远程连接mysql教程及Navicat报错10061解决办法

文章目录想要远程连接的前提是数据库要配置密码,配置密码参考如下进入要连接的数据库看一下原有的配置,host 和 user修改配置授权root用户进行远程登录配置修改完成后,打开navicat,新建连接点击测试链接,显示连接后点击…

10个自动化测试框架,测试工程师用起来

软件行业正迈向自主、快速、高效的未来。为了跟上这个高速前进的生态系统的步伐,必须加快应用程序的交付时间,但不能以牺牲质量为代价。快速实现质量是必要的,因此质量保证得到了很多关注。为了满足卓越的质量和更快的上市时间的需求&#xf…

ThinkPad笔记本如何拆卸及安装电池

注意:外置电池可以自行拆卸,内置电池如需拆卸建议联系服务商,由工程师协助操作。 外置电池的拆卸方法: 方式一.部分机型只有一个锁扣,将锁扣移动到解锁状态,然后将电池向机身外侧抽出(以T420为例&#xf…

用Qt开发的ffmpeg流媒体播放器,支持截图、录像,支持音视频播放,支持本地文件播放、网络流播放

前言 本工程qt用的版本是5.8-32位,ffmpeg用的版本是较新的5.1版本。它支持TCP或UDP方式拉取实时流,实时流我采用的是监控摄像头的RTSP流。音频播放采用的是QAudioOutput,视频经ffmpeg解码并由YUV转RGB后是在QOpenGLWidget下进行渲染显示。本…

如何使用ModelScope训练自有的远场语音唤醒模型?

就像人和人交流时先会喊对方的名字一样,关键词就好比智能设备的"名字",而关键词检测模块则相当于交互流程的触发开关。 本文介绍魔搭社区中远场语音增强与唤醒一体化的语音唤醒模型的构成、体验方式,以及如何基于开发者自有数据进…

MySQL数据库07——高级条件查询

前面一章介绍了基础的一个条件的查询,如果多条件,涉及到逻辑运算,and or 之类的。就是高级一点的条件查询。本章来介绍复杂的条件搜索表达式。 AND运算符 AND运算符只有当两边操作数均为True时,最后结果才为True。人们使用AND描述…

高性能IO模型:为什么单线程Redis能那么快?

我们通常说Redis是单线程,主要是指Redis的网络IO和键值对读写是由一个线程来完成的。这也是Redis对外提供键值存储服务的主要流程。 但redis的其他功能,比如持久化、异步删除、集群数据同步等,其实是由额外的线程执行的。 Redis为什么用单线…

探讨MySQL事务特性和实现原理

一、概念 事务 一般指的是逻辑上的一组操作,或者作为单个逻辑单元执行的一系列操作,一个事务中的所有操作会被封装成一个不可分割的执行单元,这个单元的所有操作要么全部执行成功,要么全部执行失败,只要其中任意一个操…

《Terraform 101 从入门到实践》 第四章 States状态管理

《Terraform 101 从入门到实践》这本小册在南瓜慢说官方网站和GitHub两个地方同步更新,书中的示例代码也是放在GitHub上,方便大家参考查看。 军书十二卷,卷卷有爷名。 为什么需要状态管理 Terraform的主要作用是管理云平台上的资源&#xff…

个人学习系列 - 解决拦截器操作请求参数后台无法获取

由于项目需要使用拦截器对请求参数进行操作,可是请求流只能操作一次,导致后面方法不能再获取流了。 新建SpringBoot项目 1. 新建拦截器WebConfig.java /*** date: 2023/2/6 11:21* author: zhouzhaodong* description:*/ Configuration public class W…