pytorch实现运动鞋分类

news2025/1/24 11:45:11
  •  🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章地址: 365天深度学习训练营-第P5周:运动鞋分类
  • 🍖 作者:K同学啊

一、前期准备

1. 设置GPU

import torch
from torch import nn
import torchvision
from torchvision import transforms,datasets,models
import matplotlib.pyplot as plt
import os,PIL,pathlib
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')

2. 导入数据

data_dir = './data/'
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*/*.jpg')))
print("图片总数为:",image_count)
图片总数为: 578
classNames = [str(path).split('\\')[2] for path in data_dir.glob('train/*/')]
classNames
['adidas', 'nike']
roses= list(data_dir.glob('train/nike/*.jpg'))
PIL.Image.open(str(roses[0]))

 

3. 数据增强 解决过拟合

train_transforms = transforms.Compose([
        transforms.Resize([224, 224]),
       transforms.RandomRotation(45),#随机旋转,-45到45度之间随机选
#       transforms.CenterCrop(224),#从中心开始裁剪
        transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率概率
#         transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
#         transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
#         transforms.RandomGrayscale(p=0.025),#概率转换成灰度率,3通道就是R=G=B
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])#均值,标准差
    ])

test_transforms = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

batch_size = 32

train_dataset = datasets.ImageFolder('./data/train/', transform = train_transforms)
test_dataset = datasets.ImageFolder('./data/test/', transform = test_transforms)


train_dl = torch.utils.data.DataLoader(train_dataset, 
                                       batch_size=batch_size, 
                                       shuffle=True, 
                                       num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,
                                      batch_size=batch_size, 
                                      shuffle=True, 
                                      num_workers=1)
classNames = train_dataset.classes

train_dataset.class_to_idx
{'adidas': 0, 'nike': 1}

4. 数据可视化

imgs, labels = next(iter(train_dl))
imgs.shape
import numpy as np

 # 指定图片大小,图像大小为20宽、5高的绘图(单位为英寸inch)
plt.figure(figsize=(20, 5)) 
for i, imgs in enumerate(imgs[:20]):
    npimg = imgs.numpy().transpose((1,2,0))
    npimg = npimg * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    npimg = npimg.clip(0, 1)
    # 将整个figure分成2行10列,绘制第i+1个子图。
    plt.subplot(2, 10, i+1)
    plt.imshow(npimg)
    plt.axis('off')

for X,y in test_dl:
    print('Shape of X [N, C, H, W]:', X.shape)
    print('Shape of y:', y.shape)
    break
Shape of X [N, C, H, W]: torch.Size([32, 3, 224, 224])
Shape of y: torch.Size([32])

二、构建CNN网络

2.1 搭建简单网络

搭建简单网络后发现由于数据量少导致过拟合,数据增强后最高准确率84%,说明模型不够好,选择改用Resnet18+迁移学习:

2.2 迁移学习

2.2.1 调用resnet18和预训练模型、冻结参数 

feature_extract = True

# 冻结参数
def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


# 修改输出层
def initialize_model(num_classes, feature_extract, use_pretrained=True):
    
    model_ft = models.resnet18(pretrained=use_pretrained)
    set_parameter_requires_grad(model_ft, feature_extract)
    
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, num_classes)
                            
    input_size = 32

    return model_ft, input_size
model_ft, input_size = initialize_model(2, feature_extract, use_pretrained=True)
model_ft = model_ft.to(device)
model_ft

2.2.2 取出输出层参数 

取出输出层参数  后面用于训练更新 

# 设置训练哪些层
params_to_update = model_ft.parameters()
print("Params to learn:")
if feature_extract: # 自己只训练输出层
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            print("\t",name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            print("\t",name)
Params to learn:
	 fc.weight
	 fc.bias

三、训练模型

3.1 设置超参数

动态学习率 

# 优化器设置
optimizer = torch.optim.Adam(params_to_update, lr=1e-4)#要训练啥参数,你来定
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.92)#学习率每7个epoch衰减成原来的1/10
loss_fn = nn.CrossEntropyLoss()
# def adjust_learning_rate(optimizer, epoch, start_lr):
#     # 每2个 epoch衰减到原来的0.98
#     lr = start_lr * (0.92 ** (epoch //2))
#     for param_group in optimizer.param_groups:
#         param_group['lr'] = lr
        
# optimizer = torch.optim.Adam(params_to_update,lr=1e-4)

3.2 编写训练函数

# 训练循环
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集的大小,一共900张图片
    num_batches = len(dataloader)   # 批次数目,29(900/32)

    train_loss, train_acc = 0, 0  # 初始化训练损失和正确率
    
    for X, y in dataloader:  # 获取图片及其标签
        X, y = X.to(device), y.to(device)
        
        # 计算预测误差
        pred = model(X)          # 网络输出
        loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失
        
        # 反向传播
        optimizer.zero_grad()  # grad属性归零
        loss.backward()        # 反向传播
        optimizer.step()       # 每一步自动更新
        
        # 记录acc与loss
        train_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
            
    train_acc  /= size
    train_loss /= num_batches

    return train_acc, train_loss

3.3 编写测试函数

def test (dataloader, model, loss_fn):
    size        = len(dataloader.dataset)  # 测试集的大小,一共10000张图片
    num_batches = len(dataloader)          # 批次数目,8(255/32=8,向上取整)
    test_loss, test_acc = 0, 0
    
    # 当不进行训练时,停止梯度更新,节省计算内存消耗
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            
            # 计算loss
            target_pred = model(imgs)
            loss        = loss_fn(target_pred, target)
            
            test_loss += loss.item()
            test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()

    test_acc  /= size
    test_loss /= num_batches

    return test_acc, test_loss

3.4 正式训练

3.4.1 训练输出层

epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []
best_acc = 0
filename='checkpoint.pth'


for epoch in range(epochs):
    model_ft.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model_ft, loss_fn, optimizer)
    
    scheduler.step()#学习率衰减
    
    model_ft.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model_ft, loss_fn)
    
    # 保存最优模型
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        state = {
            'state_dict': model_ft.state_dict(),#字典里key就是各层的名字,值就是训练好的权重
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }
        torch.save(state, filename)
        
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')
print('best_acc:',best_acc)
Epoch:17, Train_acc:65.9%, Train_loss:0.630, Test_acc:67.1%,Test_loss:0.615
Epoch:18, Train_acc:66.1%, Train_loss:0.613, Test_acc:64.5%,Test_loss:0.599
Epoch:19, Train_acc:63.7%, Train_loss:0.636, Test_acc:65.8%,Test_loss:0.579
Epoch:20, Train_acc:66.3%, Train_loss:0.612, Test_acc:65.8%,Test_loss:0.583
Done
best_acc: 0.6593625498007968

3.4.2 训练所有层

for param in model_ft.parameters():
    param.requires_grad = True

# 再继续训练所有的参数,学习率调小一点
optimizer = torch.optim.Adam(model_ft.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.92)

# 损失函数
criterion = nn.CrossEntropyLoss()
# 加载之前训练好的权重参数
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
epochs     = 20
train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []
best_acc = 0
filename='best_resnet18.pth'


for epoch in range(epochs):
    model_ft.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model_ft, loss_fn, optimizer)
    
    scheduler.step()#学习率衰减
    
    model_ft.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model_ft, loss_fn)
    
    # 保存最优模型
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        state = {
            'state_dict': model_ft.state_dict(),#字典里key就是各层的名字,值就是训练好的权重
            'best_acc': best_acc,
            'optimizer' : optimizer.state_dict(),
        }
        torch.save(state, filename)
        
        
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    
    template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done')
print('best_acc:',best_acc)
Epoch:18, Train_acc:99.8%, Train_loss:0.010, Test_acc:86.8%,Test_loss:0.398
Epoch:19, Train_acc:99.0%, Train_loss:0.031, Test_acc:93.4%,Test_loss:0.203
Epoch:20, Train_acc:99.2%, Train_loss:0.019, Test_acc:93.4%,Test_loss:0.184
Done
best_acc: 0.9342105263157895

四、结果可视化

加载训练好的模型

model_ft, input_size = initialize_model(2, feature_extract, use_pretrained=True)

# GPU模式
model_ft = model_ft.to(device)

# 保存文件的名字
filename='best_resnet18.pth'

# 加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

结果可视化

import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

测试模型

train_on_gpu = True

# 得到一个batch的测试数据
imgs, labels = next(iter(train_dl))

# 进行预测
model_ft.eval()

if train_on_gpu:
    output = model_ft(imgs.cuda())
else:
    output = model_ft(imgs)


# 获得预测结果(概率最大的)
_, preds_tensor = torch.max(output, 1)

preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())
preds
array([0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1,
       0, 0, 0, 0, 1, 0, 1, 0, 0, 1], dtype=int64)
import numpy as np

 # 指定图片大小,图像大小为20宽、5高的绘图(单位为英寸inch)
plt.figure(figsize=(20, 10)) 
for idx, imgs in enumerate(imgs[:10]):
    #ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    npimg = imgs.numpy().transpose((1,2,0))
    npimg = npimg * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
    npimg = npimg.clip(0, 1)
    # 将整个figure分成2行10列,绘制第i+1个子图。
    ax = plt.subplot(2, 5, idx+1)
    ax.set_title("{} ({})".format(classNames[preds[idx]], classNames[labels[idx]]),
                 color=("green" if classNames[preds[idx]]==classNames[labels[idx]] else "red"))
    plt.imshow(npimg)
    plt.axis('off')

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/91363.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Batch 批处理-作业监听器

引言 接着上篇:Spring Batch 批处理-作业增量参数,了解作业参数增量器后,本篇就来了解一下Spirng Batch 作业监听器,看能玩出啥花样。 作业监听器 Spring Batch 步骤/作业的设计延续Spring传统设计模式,加入生命周期…

「Electron|快速开始」来写个Hello World桌面应用吧

本文主要介绍如何快速使用Electron生成一个Hello World应用 文章目录主要步骤一、准备工作创建项目安装electron二、编写electron应用所需的基本内容首先,我们需要给electron应用一个入口创建窗口往窗口里面放一个HTML界面,写上"Hello World!"…

计算机网络技术-常见网络命令

文档下载:https://download.csdn.net/download/weixin_57836618/87294136 实验2 常见网络命令 1. 实验目的与意义 ① 通过实验熟悉与网络相关的组件的含义和用途。 ② 了解系统网络命令的含义、用途和操作方法。 ③ 能够查看网络的状态,对网络进行简…

DEiT实战:使用DEiT实现图像分类任务(一)

DEiT实战摘要安装包安装timm数据增强Cutout和MixupEMA项目结构计算mean和std生成数据集摘要 DEiT是FaceBook在2020年提出的一篇Transformer模型。该模型解决了Transformer难以训练的问题,三天内使用4块GPU,完成了ImageNet的训练,并且没有使用…

mPEG-N3;mPEG-Azide;甲氧基聚乙二醇叠氮CAS:89485-61-0

叠氮化物/叠氮基官能化的甲氧基聚乙二醇(mPEG-N3)是一种单官能PEG衍生物,可用于修饰蛋白质,肽和其他材料。 叠氮化物基团可以在铜催化的水溶液中与炔烃反应。 也可以容易地还原成胺基。 名称 甲氧基聚乙二醇叠氮 mPEG-N3 别称 甲…

周志华 《机器学习初步》模型评估与选择

周志华 《机器学习初步》模型评估与选择 Datawhale2022年12月组队学习 ✌ 文章目录周志华 《机器学习初步》模型评估与选择一.泛化能力二.过拟合和欠拟合泛化误差 VS 经验误差过拟合 VS 欠拟合三.模型选择的三大问题如何获得测试结果:评估方法如何评估性能优劣&…

工厂设备管理中经常会遇到哪些问题?

我调查过上百家企业的设备管理问题,发现大家认为所有设备管理问题中,最典型的问题主要包括以下五个方面: 1)领导不重视管理 “生产量是最重要的”、“销售额是最重要”、“重ERP,轻现场管理”……等管理理念是企业中的…

镜像法的理解——工程电磁场 P9

模型一:无限大导体平面 此处有几点理解需要格外谈一下 1. 只有在有电力线的地方,才会产生电场的作用 2.对于下平面的分析,下平面如果存在电荷的话,必然存在电力线,那么从无穷远处做功到此处,必然会存在电…

Java网络多线程——UDP编程

UDP编程通信 基本介绍 类DatagramSocket和DatagramPacket【数据包/数据报】实现了基于UDP协议网络程序。UDP数据报通过数据报套接字DatagramSocket发送和接收,系统不保证UDP数据报一定能安全送到目的地,也不确信什么时候可以抵达。DatagramPacket对象封…

从「堆叠」到「降本」,智能汽车传感器颠覆性革命即将到来!

随着汽车智能化的演进,传感器的堆叠造成了整车成本的急剧上升。尤其是多传感器融合(摄像头、毫米波雷达和激光雷达)技术作为当下的主流趋势之一,焦点依然回到成本层面。 同时,传统的整车电子架构和计算能力的限制&…

Flutter 小技巧之快速理解手势逻辑

又到了小技巧系列更新时间,今天我们主要分享 Flutter 里的手势触摸逻辑,其实在很久之前我就写过 《面深入触摸和滑动原理》相关的源码分析文章,但是最近有人说源码分析看不懂,有没有简要好理解的,那么本篇就用更简单的…

[附源码]Node.js计算机毕业设计高校图书馆网站Express

项目运行 环境配置: Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境:最好是Nodejs最新版,我…

高通平台 5G RF调试总结

目录: 1.QRCT4的使用 2.RFC配置 3.5G CA 配置概括 4.RFPD 运行及错误分析 5.CA吞吐率问题分析 最新的5G HImalyaa平台RFC的配置方法和之前的平台发生了根本性的变化,主要体现在使用QRCT4工具来配置RFC XML文件,然后根据XML文件编译生成s…

MobileNetV3原理说明及实践落地

本文参考: pytorch实现并训练MobileNetV3 - 灰信网(软件开发博客聚合) 【神经网络】(16) MobileNetV3 代码复现,网络解析,附Tensorflow完整代码 - 代码天地 1 MobileNetV3与V1、V2对比 (1)Mob…

【LeetCode每日一题:1945. 字符串转化后的各位数字之和~~~模拟】

题目描述 给你一个由小写字母组成的字符串 s ,以及一个整数 k 。 首先,用字母在字母表中的位置替换该字母,将 s 转化 为一个整数(也就是,‘a’ 用 1 替换,‘b’ 用 2 替换,… ‘z’ 用 26 替换…

匿名浏览器是什么?为什么联盟营销需要借助匿名浏览器?

这段时间小伙伴们都对联盟营销很感兴趣,东哥也是陆陆续续出了两三篇相关的科普文章,今天继续给大家介绍匿名浏览器在联盟营销上的帮助,毕竟互联网时代,学会如何借助工具高效工作是很重要的。关于联盟营销的概念科普文章大家可以看…

学不会的python之通过某几个关键字排序、分组一个字典列表(列表中嵌套字典)

通过某个关键字排序、分组一个字典列表排序问题描述解决方案1.operator 模块的 itemgetter 函数2.lambda 表达式引申分组问题描述解决方案1.itertools.groupby() 函数2.defaultdict() 构建多值字典排序 问题描述 现在你有一个字典列表(列表中嵌套字典),你想要根据…

web 向 unity 传输文件流 blob 记录

场景:web 与unity 通信,向 unity 传输文件 二进制流。 由 unity 转换并下载文件。 流程: web 端将缓存的 blob 数据流读取为 base64 编码的数据 → 传给 unity, →unity 解码转换 base64 数据并下载。 web 端: 1、 将数据转换成…

【Axure教程】自定义审批流原型模板

审批流即审批流程,是对某项工作的审批活动的一系列有序组合。审批流在业务系统中担当者非常重要的角色,所以今天作者就教大家制作一个通用的自定也审批流的原型模板,方便大家日后的工作。 一、效果展示 1、可以根据业务需要添加多个审批节点…

QT学习笔记(中)

QT学习笔记(中) 文章目录QT学习笔记(中)P21 消息对话框P22 其他标准对话框P23 登录窗口界面和布局P24 控件 按钮组P25 QListWidget控件P26 QTreeWidget控件的使用P27 tableWidgetP28 其他常用控件介绍P30 自定义控件P31 QEventP32…