卷积神经网络实现运动鞋识别 - P5

news2024/11/14 18:26:13
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍦 参考文章:Pytorch实战 | 第P5周:运动鞋识别
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 包引用
      • 训练设备
    • 数据准备
      • 图像解压后的路径
      • 打印图像的参数
      • 展示图像
      • 图像的预处理
      • 创建数据集
      • 获取数据集的分类
      • 打乱数据的顺序,生成批次
    • 模型设计
    • 模型训练
      • 训练函数
      • 评估函数
      • 循环迭代部分
    • 模型效果展示
      • 训练过程图表展示
      • 载入最佳模式,随机选择图像进行预测
  • 总结与心得体会


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pytorch2.0.0+cu118

步骤

环境设置

包引用

import torch
import torch.nn as nn 
import torch.optim as optim # 优化器
import torch.nn.functional as F # 可以静态调用的方法

from torchvision import datasets, transforms # 数据集创建、数据预处理方法
from torch.utils.data import DataLoader # DataLoader可以将数据集封装成批次数据

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image # 加载图片预览使用的库
from torchinfo import summary # 可以打印模型实际运行时的图
import copy # 深拷贝使用的库
import pathlib, random # 文件夹遍历和随机数

训练设备

# 声明一个全局设备对象,方便后面将数据和模型拷贝到设备中
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

图像解压后的路径

train_path = 'train'
test_path = 'test'

打印图像的参数

train_pathlib = pathlib.Path(train_path)
train_image_list = list(train_pathlib.glob('*/*'))
for _ in range(5):
	print(np.array(Image.open(str(random.choice(train_image_list)))).shape)

图片的参数
重复执行了多次,返回结果都是(240, 240, 3),可以确定图像的大小统一为240,240,在数据加载的过程中可以不对图像做缩放处理。

展示图像

plt.figure(figsize=(20, 4))
for i in range(20):
	image = random.choice(train_image_list)
	plt.subplot(2, 10, i+1)
	plt.axis('off')
	plt.imshow(Image.open(str(image)))
	plt.title(image.parts[-2])

数据集预览
至此我们对数据集中的图像有了一个初步的了解。接下来就是准备训练数据。

图像的预处理

定义一些图像的预处理方法,例如将图像读取并转为pytorch的tensor对象,然后对图像的数值做归一化处理

transform = transforms.Compose([
	transforms.ToTensor(),
	transforms.Normalize(
		mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    )
])

创建数据集

train_dataset = datasets.ImageFolder(train_path, transform=transform)
test_dataset = datasets.ImageFolder(test_path, transform=transform)

获取数据集的分类

class_names = [key for key in train_dataset.class_to_idx]
print(class_names)

数据分类

打乱数据的顺序,生成批次

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

使用3x3的卷积核,最大的通道数到256,每次卷积操作后,就紧跟一个池化层,一共使用了4个卷积层和4个池化层。最后使用了三层全连接网络来做分类器。
模型结构图

class Network(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.bn1  = nn.BatchNorm2d(64)
        
        self.conv2 = nn.Conv2d(64, 128, 3)
        self.bn2 = nn.BatchNorm2d(128)
        
        self.conv3 = nn.Conv2d(128, 256, 3)
        self.bn3 = nn.BatchNorm2d(256)
        
        self.conv4 = nn.Conv2d(256, 256, 3)
        self.bn4 = nn.BatchNorm2d(256)
        
        self.maxpool = nn.MaxPool2d(2)
        
        self.fc1 = nn.Linear(13*13*256, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, num_classes)
        
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        # 240 -> 238
        x = F.relu(self.bn1(self.conv1(x)))
        # 238 -> 119
        x = self.maxpool(x)
        
        # 119 -> 117
        x = F.relu(self.bn2(self.conv2(x)))
        # 117 -> 58
        x = self.maxpool(x)
        
        # 58 -> 56
        x = F.relu(self.bn3(self.conv3(x)))
        # 56 -> 28
        x = self.maxpool(x)
        
        # 28 -> 26
        x = F.relu(self.bn4(self.conv4(x)))
        # 26 -> 13
        x = self.maxpool(x)
        
        x = x.view(x.size(0), -1)
        x = self.dropout(x)
        x = F.relu(self.dropout(self.fc1(x)))
        x = F.relu(self.dropout(self.fc2(x)))
        x = self.fc3(x)
        
        return x
model = Network(len(class_names)).to(device)

summary(model, input_size=(32, 3, 240, 240))

模型结构图

模型训练

模型训练过程中,每个epoch都会对全部的训练集进行一次完整的遍历,所以可以封装一些训练和评估方法,将业务逻辑和循环分开

训练函数

def train(train_loader, model, loss_fn, optimizer):
	size = len(train_loader.dataset)
	num_batches = len(train_loader)

	train_loss, train_acc = 0, 0
	for x, y in train_loader:
		x, y = x.to(device), y.to(device)
		
		pred = model(x)
		loss = loss_fn(pred, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

	train_loss /= num_batches
	train_acc /= size

	return train_loss, train_acc

评估函数

def test(test_loader, model, loss_fn):
	size = len(test_loader.dataset)
	num_batches = len(test_loader)

	test_loss, test_acc = 0, 0
	for x, y in test_loader:
		x, y = x.to(device), y.to(device)

		pred = model(x)
		loss = loss_fn(pred, y)

		test_loss += loss.item()
		test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

	test_loss /= num_batches
	test_acc /= size

	return test_loss, test_acc

循环迭代部分

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch:0.92**(epoch //2)) 
# 创建学习率的衰减
epochs = 50

train_loss, train_acc = [], []
test_loss, test_acc = [], []
best_acc = 0
for epoch in range(epochs):
    model.train()
    epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
    model.eval()
    with torch.no_grad():
        epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)
        
    scheduler.step() # 每次迭代调用一次,自动做学习率衰减
    
    # 如果当前评估的学习率更好,就保存当前模型
    if best_acc < epoch_test_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
    
    # 记录历史记录
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
    
    # 打印每个迭代的数据
    print(f"Epoch:{epoch+1}, TrainLoss: {epoch_train_loss:.3f}, TrainAcc: {epoch_train_acc*100:.1f}, TestLoss: {epoch_test_loss:.3f}, TestAcc: {epoch_test_acc*100:.1f}")

# 打印本次训练的最佳正确率
print(f'training finished, best_acc is {best_acc*100:.1f}')

# 将最佳模型保存到文件中
torch.save(model.state_dict(), 'best_model.pth')

模型训练过程

模型效果展示

训练过程图表展示

画一个拆线图,观察训练过程中损失函数和正确率的变化趋势

plt.figure(figsize=(20,5))

epoch_range = range(epochs)

plt.subplot(1,2, 1)
plt.plot(epoch_range, train_loss, label='train loss')
plt.plot(epoch_range, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')

plt.subplot(1,2,2)
plt.plot(epoch_range, train_acc, label='train accuracy')
plt.plot(epoch_range, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练过程图示
可以看出模型在最后基本已经收敛,最佳准确率是88.2%,满足了挑战任务。

载入最佳模式,随机选择图像进行预测

model.load_state_dict(torch.load('best_model.pth'))
model = model.to(device)

test_pathlib = pathlib.Path(test_path)

image_list = list(test_pathlib.glob('*/*'))

image_path = random.choice(image_list)
image = transform(Image.open(str(image_path)))
image = image.unsqueeze(0)
image = image.to(device)

pred = model(image)

plt.figure(figsize=(5,5))
plt.axis('off')
plt.imshow(Image.open(str(image_path)))
plt.title(f'real: {image_path.parts[-2]}, predict: {class_names[pred.argmax(1).item()]}')

预测结果
上次运行上面的预测任务,发现正确率还不错。

总结与心得体会

  1. 整个模型设计的思路其实是模仿了vgg16模型,在卷积层的数量和通道上做了简化。轻量级的任务可以首先试着减少池化层间的卷积次数,减少模型中最大的特征图的通道数
  2. 对图像的归一化操作很重要。在没有归一化前,模型的最佳正确率只能达到80%,推测可能是因为未做归一化的图像值域范围太大,不方便收敛,归一化后,原始图像中的输入特征值范围变成0~1,模型的权重变化更易作用到特征上。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/959208.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【git】从一个git仓库迁移到另外一个git仓库

在远端服务器创建一个新的仓库 用界面创建&#xff0c;当然也可以用命令创建 拉去源仓库 git clone --bare git192.168.10.10:java/common.gitgit clone --bare <旧仓库地址>拉去成功以后会出现 进入到文件夹内部 出现下面信息&#xff1a; 推送到新的远端仓库 git …

Sharding-JDBC分片策略

Sharding-JDBC分片策略 包含分片键和分片算法&#xff0c;由于分片算法的独立性&#xff0c;将其独立抽离。真正可用于分片操作的是分片键 分片算法&#xff0c;也就是分片策略。目前提供5种分片策略。 一个好的分片策略好的分片键好的的分片算法 1. 标准分片策略 对应Stan…

3D数字孪生技术助力港口全新升级,提供实时数据进行智能调度

港口3D数字孪生平台是一种基于数字技术的虚拟模型&#xff0c;它可以模拟真实的港口环境&#xff0c;并对港口的运营、管理、安全等方面进行实时监控和优化。该平台带来了许多智能化提升&#xff0c;包括以下几个方面&#xff1a; 一、自动化操作和智能调度 数字孪生平台可以通…

ssm珠宝首饰交易平台源码和论文

ssm珠宝首饰交易平台源码和论文101 开发工具&#xff1a;idea 数据库mysql5.7 数据库链接工具&#xff1a;navcat,小海豚等 技术&#xff1a;ssm 摘 要 随着科学技术的飞速发展&#xff0c;各行各业都在努力与现代先进技术接轨&#xff0c;通过科技手段提高自身的优势&a…

在SOLIDWORKS的Toolbox中,表达轴承的承受能力与寿命计算器

轴承作为机械设计中最常用的标准件之一&#xff0c;在选型时需要对其进行严格的能力和寿命的计算。手工计算涉及到诸多的公式和参数&#xff0c;需要较多的精力去完成计算的工作。 在SOLIDWORKS的Toolbox中便包含了针对于轴承的计算器&#xff0c;通过该工具可以非常快速的计算…

MySQL主从复制案例

主从复制与读写分离 在实际的生产环境中&#xff0c;对数据库的读和写都在同一个数据库服务器中&#xff0c;是不能满足实际需求的。无论是在安全性、高可用性还是高并发等各个方面都是完全不能满足实际需求的。因此&#xff0c;通过主从复制的方式来同步数据&#xff0c;再通…

Apinto 网关 V0.14 版本发布,6 大插件更新!

大家好&#xff01; 距离上次更新已经过去一段时间了&#xff0c;这段日子里我们一直在酝酿新的功能&#xff0c;本次的迭代将给大家带来 6 大插件的更新~一起来看看有哪些变化吧&#xff01; 新特性 1. 新增 额外参数v2 插件&#xff0c;支持对转发参数进行加密、拼接等操作…

对话出海企业:2023亚马逊云科技出海日圆桌论坛

在全球经济亟待复苏的今天&#xff0c;持续对外开放是中国未来经济发展重要的“两条腿”之一。在愈发饱和的国内市场&#xff0c;中国企业需要对外寻找全新机遇才能在未来不确定的市场博弈下生存下去。“出海”&#xff0c;也成为近几年最炙手可热的词汇之一&#xff0c;大量中…

ArcGIS Maps SDK for JS(一):概述与使用

文章目录 1 概述2 如何使用ArcGIS Maps SDK for JavaScript2.1 AMD 模块与 ES 模块2.2 AMD 模块和 ES 模块比较 3 几种安装方式3.1 通过 ArcGIS CDN 获取 AMD 模块3.2 通过 NPM 运行 ES 模块3.3 通过 CDN 获取 ES 模块3.4 本地构建 ES3.5 本地构建 AMD 3 VSCode下载与安装2.1 下…

JDK源码解析-LinkedList

1. LinkedList类 1.1 LinkedList类定义&数据结构 定义 LinkedList是一种可以在任何位置进行高效地插入和移除操作的有序序列&#xff0c;它是基于双向链表实现的。 数据结构 基础知识补充 单向链表&#xff1a; element&#xff1a;用来存放元素 next&#xff1a;用来…

Leetcode Top 100 Liked Questions(序号141~189)

​ 141. Linked List Cycle ​ 题意&#xff1a;给你一个链表&#xff0c;判断链表有没有环 我的思路 两个指针&#xff0c;一个每次走两步&#xff0c;一个每次走一步&#xff0c;如果走两步的那个走到了NULL&#xff0c;那说明没有环&#xff0c;如果两个指针指向相等&…

取一个整数各偶数位上的数构成一个新的数字

1 题目 我可太难了&#xff0c;这题我的思路有点复杂&#xff0c;遇到的困难很多&#xff0c;总是值传递搞不清楚&#xff0c;地址传递总是写错。 从低位开始取出一个整数s的各奇数位上的数&#xff0c;剩下的偶数位的数依次构成一个新数t。 例如&#xff1a; 输入s&#xff…

VB:水仙花数问题

VB&#xff1a;水仙花数问题 Private Sub Command1_Click()Rem 水仙花数问题Dim x%, a%, b%, c%, z%n 0For x 100 To 999a Fix(x / 100) Fix函数是去尾的作用&#xff0c;只保留整数部分&#xff0c;当然也可以直接用整除(\)b Fix((x - a * 100) / 10)c x Mod 10z a ^ 3…

C语言中的分支和循环语句:从入门到精通

分支和循环语句 1. 前言2. 预备知识2.1 getchar函数2.2 putchar函数2.3 计算数组的元素个数2.4 清屏2.5 程序的暂停2.6 字符串的比较 3. 结构化3.1 顺序结构3.2 分支结构3.3 循环结构 4. 真假性5. 分支语句&#xff08;选择结构&#xff09;5.1 if语句5.1.1 语法形式5.1.2 else…

人气总冠军-商艺馨 | 第11季中国好猫步辽宁总决赛

第十一季中国好猫步 辽宁总决赛 中国好猫步少儿模特赛事活动属于CCAC大满贯赛事中的大师赛&#xff0c;一直以来&#xff0c;以华丽的舞美、创意丰富的赛制、贴心的服务、丰厚的奖励和众多媒体曝光优势&#xff0c;成为无数少儿模特梦寐以求登上的舞台&#xff01;并且多次登上…

CG MAGIC进行实体渲染后!分析渲染器CR和VR的区别之处!

新手小白来说&#xff0c;如何选择渲染器&#xff0c;都会提出疑问&#xff1f; 渲染效果图究竟用CR渲染器还是VR渲染器呢&#xff1f; 今天&#xff0c;CG MAGIC小编通过一个真实的项目场景&#xff0c;实例渲染之后&#xff0c;CR渲染器和VR渲染器区别有哪几点&#xff1f; 1…

Java里面单向链表实现

Java里面单向链表实现 说明代码 说明 这里记录下单向链表实现。并在类里面写了一些新增和删除方法。 代码 package com.example.test;//单向链表类 public class Node {//结点的值public int val;//当前结点的后继结点,当 next null 时&#xff0c;代表这个结点是所在链表的…

VB:求1000以内的质数

VB&#xff1a;求1000以内的质数 Private Sub Command1_Click() Dim m%, i%, p%, k%, n% For m 2 To 1000 求1000以内的质数&#xff0c;2是最小的质数p 1k Int(Sqr(m))For i 2 To kIf m Mod i 0 Thenp 0Exit ForEnd IfNext iIf p 1 ThenPrint Tab((n Mod 10) * 5 2);…

20. python从入门到精通——Flask框架

目录 安装虚拟环境和Flask 第一个Flask程序 Flask的调试模式 路由 变量规则&#xff1a;当在页面中输出变量的时候就需要遵循变量的规则 构造URL 在route函数中设置http方法 获取静态文件路径 蓝图 模板 Web表单 CSRF 安装虚拟环境和Flask Flask框架主要依赖两个库…

Java“牵手”1688图片识别商品接口数据,图片地址识别商品接口,图片识别相似商品接口,1688API申请指南

1688商城是一个网上购物平台&#xff0c;售卖各类商品&#xff0c;包括服装、鞋类、家居用品、美妆产品、电子产品等。要通过图片地址识别获取1688商品列表和商品详情页面数据&#xff0c;您可以通过开放平台的接口或者直接访问1688商城的网页来获取商品详情信息。以下是两种常…