卷积神经网络实现咖啡豆分类 - P7

news2025/1/11 21:53:37
  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制
  • 🚀 文章来源:K同学的学习圈子

目录

  • 环境
  • 步骤
    • 环境设置
      • 包引用
      • 全局设备对象
    • 数据准备
      • 查看图像的信息
      • 制作数据集
    • 模型设计
      • 手动搭建的vgg16网络
      • 精简后的咖啡豆识别网络
    • 模型训练
      • 编写训练函数
      • 编写测试函数
      • 开始训练
      • 展示训练过程
    • 模型效果展示
  • 总结与心得体会


环境

  • 系统: Linux
  • 语言: Python3.8.10
  • 深度学习框架: Pytorch2.0.0+cu118
  • 显卡:A5000 24G

步骤

环境设置

包引用

import torch
import torch.nn as nn # 网络
import torch.optim as optim # 优化器
from torch.utils.data import DataLoader, random_split # 数据集划分
from torchvision import datasets, transforms # 数据集加载,转换

import pathlib, random, copy # 文件夹遍历,实现模型深拷贝
from PIL import Image # python自带的图像类
import matplotlib.pyplot as plt # 图表
import numpy as np 
from torchinfo import summary # 打印模型参数

全局设备对象

方便将模型和数据统一拷贝到目标设备中

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

数据准备

查看图像的信息

data_path = 'coffee_data'
data_lib = pathlib.Path(data_path)
coffee_images = list(data_lib.glob('*/*'))

# 打印5张图像的信息
for _ in range(5):
	image = random.choice(coffee_images)
	print(np.array(Image.open(str(image))).shape)

图像信息
通过打印的信息,可以看出图像的尺寸都是224x224的,这是一个CV经常使用的图像大小,所以后面我就不使用Resize来缩放图像了。

# 打印20张图像粗略的看一下
plt.figure(figsize=(20, 4))
for i in range(20):
	plt.subplot(2, 10, i+1)
	plt.axis('off')
	image = random.choice(coffee_images) # 随机选出一个图像
	plt.title(image.parts[-2]) # 通过glob对象取出它的文件夹名称,也就是分类名
	plt.imshow(Image.open(str(image))) # 展示

数据集预览
通过展示,对数据集内的图像有个大概的了解

制作数据集

先编写数据的预处理过程,用来使用pytorch的api加载文件夹中的图像

transform = transforms.Compose([
	transforms.ToTensor(), # 先把图像转成张量
	transforms.Normalize( # 对像素值做归一化,将数据范围弄到-1,1
       mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225],
	),
])

加载文件夹

dataset = datasets.ImageFolder(data_path, transform=transform)

从数据中取所有的分类名

class_names = [k for k in dataset.class_to_idx]
print(class_names)

图像分类名
将数据集划分出训练集和验证集

train_size = int(len(dataset) * 0.8)
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

将数据集按划分成批次,以便使用小批量梯度下降

batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

模型设计

在一开始时,直接手动创建了Vgg-16网络,发现少数几个迭代后模型就收敛了,于是开始精简模型。

手动搭建的vgg16网络

class Vgg16(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		
		self.block1 = nn.Sequential(
			nn.Conv2d(3, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.Conv2d(64, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.block2 = nn.Sequential(
			nn.Conv2d(64, 128, 3, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.Conv2d(128, 128, 3, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.block3 = nn.Sequential(
			nn.Conv2d(128, 256, 3, padding=1),
			nn.BatchNorm2d(256),
			nn.ReLU(),
			nn.Conv2d(256, 256, 3, padding=1),
			nn.BatchNorm2d(256),
			nn.ReLU(),
			nn.Conv2d(256, 256, 3, padding=1),
			nn.BatchNorm2d(256),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.block4 = nn.Sequential(
			nn.Conv2d(256, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.block5 = nn.Sequential(
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.Conv2d(512, 512, 3, padding=1),
			nn.BatchNorm2d(512),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		self.pool = nn.AdaptiveAvgPool2d(7)
		self.classifier = nn.Sequential(
			nn.Linear(7*7*512, 4096),
			nn.Dropout(0.5),
			nn.ReLU(),
			nn.Linear(4096, 4096),
			nn.Dropout(0.5),
			nn.ReLU(),
			nn.Linear(4096, num_classes),
		)

	def forward(self, x):
		x = self.block1(x)
		x = self.block2(x)
		x = self.block3(x)
		x = self.block4(x)
		x = self.block5(x)
		x = self.pool(x)
		x = x.view(x.size(0),-1)
		x = self.classifier(x)
		return x
vgg = Vgg16(len(class_names)).to(device)
summary(vgg, input_size=(32, 3, 224, 224))

VGG16模型
通过模型结构的打印可以发现,VGG-16网络共有134285380个可训练参数(我加了BatchNorm,和官方的比会稍微多出一些),参数量非常巨大,对于咖啡豆识别这种小场景,这么多可训练参数肯定浪费,于是对原始的VGG-16网络结构进行精简。

精简后的咖啡豆识别网络

class Network(nn.Module):
	def __init__(self, num_classes):
		super().__init__()
		
		self.block1 = nn.Sequential(
			nn.Conv2d(3, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.Conv2d(64, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)

		self.block2 = nn.Sequential(
			nn.Conv2d(64, 128, 3, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.Conv2d(128, 128, 3, padding=1),
			nn.BatchNorm2d(128),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)
		
		self.block3 = nn.Sequential(
			nn.Conv2d(128, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.Conv2d(64, 64, 3, padding=1),
			nn.BatchNorm2d(64),
			nn.ReLU(),
			nn.MaxPool2d(2),
		)

		self.pool = nn.AdaptiveAvgPool2d(7),

		self.classifier = nn.Sequential(
			nn.Linear(7*7*64, 64),
			nn.Dropout(0.4),
			nn.ReLU(),
			nn.Linear(64, num_classes)
		)
	
	def forward(self, x):
		x = self.block1(x)
		x = self.block2(x)
		x = self.block3(x)
		x = self.pool(x)
		x = x.view(x.size(0), -1)
		x = self.classifier(x)
		return x
model = Network(len(class_names)).to(device)
summary(model, input_size=(32, 3, 224, 224))

精简后的模型
可以看到精简后的网络模型参数量还不到原来的1/10,但是其在测试集上的正确率依然能够达到100%!

模型训练

编写训练函数

def train(train_loader, model, loss_fn, optimizer):
	size = len(train_loader.dataset)
	num_batches = len(train_loader)

	train_loss, train_acc = 0, 0
	for x, y in train_loader:
		x, y = x.to(device), y.to(device)
	
		pred = model(x)
		loss = loss_fn(pred, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()

		train_loss += loss.item()
		train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

	train_loss /= num_batches
	train_acc /= size

	return train_loss, train_acc

编写测试函数

def test(test_loader, model, loss_fn):
	size = len(test_loader.dataset)
	num_batches = len(test_loader)
	
	test_loss, test_acc = 0, 0
	for x, y in test_loader:
		x, y = x.to(device), y.to(device)

		pred = model(x)
		loss = loss_fn(pred, y)

		test_loss += loss.item()
		test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

	test_loss /= num_batches
	test_acc /= size

	return test_loss, test_acc

开始训练

首先定义损失函数,优化器设置学习率,这里我们再弄一个学习率的衰减,再加上总的迭代次数,最佳模型的保存位置

epochs = 30
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.92**(epoch//2))
best_model_path = 'best_coffee_model.pth'

然后编写训练+测试的循环,并记录训练过程的数据

best_acc = 0
train_loss, train_acc = [], []
test_loss, test_acc = [], []
for epoch in epochs:
	model.train()
	epoch_train_loss, epoch_train_acc = train(train_loader, model, loss_fn, optimizer)
	scheduler.step()

	model.eval()
	with torch.no_grad();
		epoch_test_loss, epoch_test_acc = test(test_loader, model, loss_fn)
	
	train_loss.append(epoch_train_loss)
	train_acc.append(epoch_train_acc)
	test_loss.append(epoch_test_loss)
	test_acc.append(epoch_test_acc)

	lr = optimizer.state_dict()['param_groups'][0]['lr']

	if best_acc < epoch_test_acc:
		best_acc = epoch_test_acc
		best_model = copy.deepcopy(model)

	print(f"Epoch: {epoch+1}, Lr:{lr}, TrainAcc: {epoch_train_acc*100:.1f}, TrainLoss: {epoch_train_loss:.3f}, TestAcc: {epoch_test_acc*100:.1f}, TestLoss: {epoch_test_loss:.3f}")

print(f"Saving Best Model with Accuracy: {best_acc*100:.1f} to {best_model_path}")
torch.save(best_model.state_dict(), best_model_path)
print('Done')

训练过程日志
可以看出,模型在测试集上的正确率最高达到了100%

展示训练过程

epoch_ranges = range(epochs)
plt.figure(figsize=(20,6))
plt.subplot(121)
plt.plot(epoch_ranges, train_loss, label='train loss')
plt.plot(epoch_ranges, test_loss, label='validation loss')
plt.legend(loc='upper right')
plt.title('Loss')

plt.figure(figsize=(20,6))
plt.subplot(122)
plt.plot(epoch_ranges, train_acc, label='train accuracy')
plt.plot(epoch_ranges, test_acc, label='validation accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy')

训练过程参数

模型效果展示

model.load_state_dict(torch.load(best_model_path))
model.to(device)
model.eval()

plt.figure(figsize=(20,4))
for i in range(20):
	plt.subplot(2, 10, i+1)
	plt.axis('off')
	image = random.choice(coffee_images)
	input = transform(Image.open(str(image))).to(device).unsqueeze(0)
	pred = model(input)
	plt.title(f'T:{image.parts[-2]}, P:{class_names[pred.argmax()]}')
	plt.imshow(Image.open(str(image)))

模型效果展示
通过结果可以看出,确实是所有的咖啡豆都正确的识别了。

总结与心得体会

  • 因为目前网络还是很快就收敛到一个很高的水平,所以应该还有很大的精简的空间,但是可能会稍微牺牲一些正确率。
  • 模型的选取要根据实际任务来确定,像咖啡豆种类识别这种任务,使用VGG-16太浪费了。
  • 在精简的过程中,没有感觉到训练速度有明显的变化 ,说明参数量和训练速度并没有直接的相关关系。
  • 连续多层参数一样的卷积操作好像比只用一层效果要好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1013614.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FL Studio21.2国内中文语言最新试用版新增功能介绍

FL studio是一款音乐软件国内试用版仅更新至21.0&#xff0c;国外已更新至21.2了&#xff01;初始自带的鼓组音色也从土嗨鼓变成了beat作者喜闻乐见的808&#xff0c;KSHMR直接原地失业。注意安装后上方的菜单可能会出现布局杂乱的情况&#xff0c;右键菜单空白处选择default即…

Linux的常见指令

目录 pwd命令ls 指令mkdir指令touch指令cd 指令rmdir指令 && rm 指令man指令nanocp指令mv指令cat指令more指令less指令head指令tail指令grep指令热键zip/unzip指令tar指令uname –r指令输出重定向 图形化界面和命令行操作本质都是对操作系统进行直接或间接的操作 pwd命…

x86平台运行arm64平台docker 镜像

本文介绍在x86服务器上安装qemu-aarch64-statick仿真器&#xff0c;以实现x86服务器可以运行docker或docker-compose镜像。 报错信息&#xff1a; x86服务器默认不能运行ARM平台镜像&#xff0c;会提示如下错误&#xff1a; WARNING: The requested images platform (linux/ar…

Mock数据:单元测试中的心灵鸡汤

在当今的软件开发领域&#xff0c;质量控制已经成为了一个不可或缺的环节。为了确保软件的稳定性和可靠性&#xff0c;开发者们投入了大量的时间和精力进行各种测试。其中&#xff0c;单元测试作为最基础的测试方法&#xff0c;其重要性不言而喻。然而&#xff0c;单元测试中的…

GDB之源码与汇编映射对应关系(十五)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

排序算法-----归并排序

目录 前言&#xff1a; 归并排序 1. 定义 2.算法过程讲解 2.1大致思路 2.2图解示例 拆分合成步骤 ​编辑 相关动态图 3.代码实现&#xff08;C语言&#xff09; 4.算法分析 4.1时间复杂度 4.2空间复杂度 4.3稳定性 前言&#xff1a; 今天我们就开始学习新的排序算法…

数值类型表示二——定点和浮点格式

目录 目录 定点小数与定点整数 定点小数原反补的转换 定点小数与定点整数的取值范围 位数扩展的区别 浮点数的格式 浮点数的规格化 规格化处理举例 例1&#xff1a; 例2&#xff1a; 特例&#xff1a; 知识点总结&#xff1a; 浮点数的IEEE754标准 移码的回顾&…

通过UltraSync减轻主节点负担、提升业务系统性能,AntDB如何做到?

众所周知&#xff0c;数据库在多中心场景下&#xff0c;主中心主库不仅要承担原本业务的压力&#xff0c;而且还要将redo日志传输到不同的备库端&#xff0c;这样对主库将产生很大的性能影响。通常情况下&#xff0c;备中心和主中心不在同⼀机房&#xff0c;为了保证业务响应速…

Linux内核 6.6版本将遏制NVIDIA驱动的不正当行为

导读Linux 内核开发团队日前宣布&#xff0c;即将发布的 Linux 6.6 版本将增强内核模块机制&#xff0c;以更好地防御 NVIDIA 闭源驱动的不正当行为。 Linux 内核开发团队日前宣布&#xff0c;即将发布的 Linux 6.6 版本将增强内核模块机制&#xff0c;以更好地防御 NVIDIA 闭…

【HR】胜任力相关资料--20230915

0_建模技术介绍 传统的两种胜任力词典 光辉合益LOMINGER 67项能力检核表 海氏 DDI胜任力词典2.0 北森GENE建模技术 三种建模的方法 A公司 建模及应用 素质模型的组合 建模的选择 工具&#xff1a;光辉领导力素质卡片【38条素质】 素质模型示例 素质模型的应用及意义 1_能力素…

创建UI账号密码登录界面

头文件 #ifndef MYWND_H #define MYWND_H#include <QPushButton> #include <QMainWindow>class MyWnd : public QMainWindow {Q_OBJECTpublic:MyWnd(QWidget *parent nullptr);~MyWnd(); }; #endif // MYWND_H 源文件 #include "mywnd.h" #include &…

链动2+1模式:让中小企业家轻松实现社交电商

社交电商是一种利用社交网络和社群平台&#xff0c;通过人与人之间的互动和分享&#xff0c;实现商品或服务的销售和推广的电商模式。社交电商具有低成本、高效率、高转化率、高忠诚度等优势&#xff0c;是当下最火热的电商趋势之一。 然而&#xff0c;对于中小企业家来说&…

HarmonyOS学习路之方舟开发框架—学习ArkTS语言(状态管理 七)

PersistentStorage&#xff1a;持久化存储UI状态 前两个小节介绍的LocalStorage和AppStorage都是运行时的内存&#xff0c;但是在应用退出再次启动后&#xff0c;依然能保存选定的结果&#xff0c;是应用开发中十分常见的现象&#xff0c;这就需要用到PersistentStorage。 Pe…

MongoDB的搭建 和crud操作

MongoDB docker 下载 docker run --restartalways -d --name mongo -v /docker/mongodb/data:/data/db -p 27017:27017 mongo:4.0.6使用navcat工具使用MongoDB Crud操作 jar包 <dependency><groupId>org.projectlombok</groupId><artifactId>lom…

three.js 入门 初识

基本步骤&#xff1a; 初始设置创建场景创建相机创建可见对象创建渲染器渲染场景 安装 npm install three 引入 import * as THREE from "three"; 一、three三要素&#xff1a;场景、相机、渲染 1.场景&#xff1a; //创建场景 const scenenew THREE.Scene()…

python中not的用法

前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 话不多说&#xff0c;直接开搞&#xff0c;如果有什么疑惑/资料需要的可以点击文章末尾名片领取源码 python中的not具体表示是什么: 在python中not是逻辑判断词&#xff0c;用于布尔型True和False&#xff0c; not True为F…

【Unity插件】实现多人在线游戏——Mirror插件的使用介绍

文章目录 前言导入Mirror插件 简单介绍一、RPC调用二、错误注意 基本使用一、创建场景的网络管理器二、创建一个玩家三、添加玩家初始生成位置四、玩家控制五、同步摄像机六、同步不同角色的名字和颜色修改七、同步动画八、同步子弹方法一方法二 九、聊天功能十、场景同步切换十…

torch.where()两种用法

参考官方文档。 官方文档中只给了第一种用法。根据条件condition&#xff0c;从input,other中选择元素f返回。如果满足条件&#xff0c;则返回input元素。若不满足&#xff0c;返回other元素。 还有一种用法是通过where返回张量中满足条件condition的坐标&#xff0c;以二维张…

【Unity每日一记】资源加载相关和检测相关

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;uni…

中兴协力NB-IoT部署实验(含复杂项目)

这个实验要求每个人都完成一遍&#xff0c;并且不同学生的部分操作内容也不同&#xff0c;个别班级最后也被要求基于此完成复杂项目&#xff0c;黑字部分是必要操作&#xff0c;紫字部分是辅助完成操作或复杂项目的讲解 进入实验室&#xff0c;选择模拟器&#xff08;同一台模…