【神经网络】基于CNN(卷积神经网络)构建猫狗分类模型

news2024/12/28 19:56:37

文章目录

    • 解决问题
    • 数据集
    • 探索性数据分析
    • 数据预处理
      • 数据集分割
      • 数据预处理
    • 构建模型并训练
      • 构建模型
      • 训练模型
    • 结果分析与评估
    • 模型保存
    • 结果预测
    • 经验总结

解决问题

针对经典猫狗数据集,基于卷积神经网络,构建猫狗二元分类模型,使用数据集进行参数训练,模型评估,然后使用模型进行分类预测,最后对模型进行保存,供后续使用。

数据集

数据集来源

猫狗数据集

探索性数据分析

查看待训练识别图片

from matplotlib import pyplot as plt
import os
import random

# 获取文件名
_,_,cat_images = next(os.walk('../../dataset/kagglecatsanddogs_5340/PetImages/Cat'))

# 准备3*3 图表
fig, ax = plt.subplots(3, 3, figsize=(20, 10))
# 随机选择一幅图像并绘制
for idx, img in enumerate(random.sample(cat_images, 9)):
    img_read = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Cat/' + img)
    ax[int(idx / 3), idx % 3].imshow(img_read)
    ax[int(idx / 3), idx % 3].set_title('cat/' + img)
    ax[int(idx / 3), idx % 3].axis('off')
plt.show()

查看狗图片类似,将Cat目录换成Dog即可

image-20240614224011275

数据预处理

数据集分割

由于下载的图片猫和狗各在一个文件夹内,如下:

image-20240613223710122

需要将数据按80%:20%进行分割,分为训练集和测试集。目录结构如下:

image-20240613223517479

下面进行数据拆分,核心代码(以猫图片为例)如下:

# 训练数据集80% 测试数据集20%
train_size = 0.8
# 获取猫图像数量
_, _, cat_images = next(os.walk(src_folder+'Cat/'))
num_cat_images = len(cat_images)
num_cat_images_train = int(train_size * num_cat_images)
num_cat_images_test = num_cat_images - num_cat_images_train
# 分割猫图像
cat_train_images = random.sample(cat_images, num_cat_images_train)
for img in cat_train_images:
	shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Train/Cat/')
cat_test_images  = [img for img in cat_images if img not in cat_train_images]
for img in cat_test_images:
	shutil.copy(src=src_folder+'Cat/'+img, dst=src_folder+'Test/Cat/')
	

数据预处理

这一步要将分割后的数据集转成和模型结构匹配的数据类型。使用keras提供的ImageDataGenerator类和flow_from_directory()方法

ImageDataGenerator类:图像增强类,可以进行图像旋转、图像平移、水平翻转、图像缩放等操作;

flow_from_directory()方法:ImageDataGenerator类的方法,支持以图像路径为输入,按批次加载图像到内存,防止训练数据量过大,机器内存不足问题;还支持对图像进行预处理操作,例如尺寸缩放和图像增强

# 训练数据预处理
training_data_generator = ImageDataGenerator(rescale=1./255)
training_set = training_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/train/',target_size=(32, 32),batch_size=16,class_mode='binary')

# 测试数据预处理
testing_data_generator = ImageDataGenerator(rescale= 1./255)
testing_set = testing_data_generator.flow_from_directory('../../dataset/kagglecatsanddogs_5340/PetImages/test/',target_size=(32, 32),batch_size=16, class_mode='binary')

构建模型并训练

构建模型

# 定义超参数
# 特征滤波器尺寸
FILTER_SIZE = 3
# 特征滤波器数量
FILTER_NUM = 32
# 图片输入尺寸
INPUT_SIZE = 32
# 最大池化尺寸
MAXPOOL_SIZE = 2
# 批量处理图片的大小
BATCH_SIZE = 16
STEPS_PER_EPOCH = 20000 // BATCH_SIZE
# 训练轮次
EPOCHS = 10
# 定义模型
model = Sequential()
# 添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 再添加卷积、池化层 提取特征
model.add(Conv2D(FILTER_NUM, (FILTER_SIZE, FILTER_SIZE), input_shape=(INPUT_SIZE, INPUT_SIZE, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(MAXPOOL_SIZE,MAXPOOL_SIZE)))
# 对输出结果进行降维处理,转成一维张量
model.add(Flatten())
# 添加全链接层,根据特征进行分类预测
model.add(Dense(units=128, activation='relu'))
# 添加dropout层,随机将一部分输入设置为0,防止模型复杂,出现过拟合现象
model.add(Dropout(0.5))
# 添加输出层,一个节点
model.add(Dense(units=1, activation='sigmoid'))

该模型结构分为,卷积池化层,卷积池化层,Flatten层,全链接层1,全链接层2(输出层)如下:

image-20240614221747896

其中,第一列是神经网络的层,第二列是每层的输出形状,第三层是每层训练的参数

可以看到,该模型图像输入尺寸是(32,32),经过一层卷积(32个特征过滤器)输出为(30,30,32),经过一层最大池化层,输出为(15,15,32);其中特征滤波器尺寸为3*3,所以滤波后的尺寸会是32-(3-1)=30,经过最大池化(2x2)尺寸减半,为15。

训练模型

# 模型训练
model.fit(training_set, steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS, verbose=1)
image-20240614123217943

结果分析与评估

model.evaluate(testing_set,steps=len(testing_set),verbose=1)
image-20240614222619817

准确度达到了0.7856

模型保存

from joblib import dump, load

# 模型持久化 到磁盘
dump(model, './猫狗分类.onnx')

结果预测

引入保存模型,随机选取一张图片进行预测分类

from matplotlib import pyplot as plt
fig, ax = plt.subplots()
img = plt.imread('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg')
ax.imshow(img)
plt.show()
image-20240614223543720
from joblib import dump, load
model = load('./猫狗分类.onnx')

from tensorflow.keras.preprocessing.image import img_to_array,load_img

img = load_img('../../dataset/kagglecatsanddogs_5340/PetImages/Dog/6.jpg',target_size=(32,32))
img = img_to_array(img)
img /= 255
import numpy as np
img_array = np.expand_dims(img, axis=0)
print(img_array.shape)
model.predict(img_array)

在这里插入图片描述

由于是二元分类,0和1分别表示猫狗,输出概率接近表示是狗,接近0表示是猫狗。但具体为啥0表示猫1表示狗而不是反过来表示,还待研究。

经验总结

1 在使用next()加载图像时,要确保路径正确,否则会报StopIteration错误,原因是路径错误,找不到可迭代的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1835773.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

我主编的电子技术实验手册(08)——串联电阻分压

本专栏是笔者主编教材(图0所示)的电子版,依托简易的元器件和仪表安排了30多个实验,主要面向经费不太充足的中高职院校。每个实验都安排了必不可少的【预习知识】,精心设计的【实验步骤】,全面丰富的【思考习…

由于bug造成truncate table卡住问题

客户反应truncate table卡主,检查awr发现多个truncate在awr报告期内一直没执行完,如下: 检查ash,truncate table表的等待事件都是“enq: RO - fast object reuse”和“local write wait” 查找“enq: RO - fast object reuse”&am…

qmt量化交易策略小白学习笔记第35期【qmt编程之指数数据--如何获取指数行情数据】

qmt编程之获取沪深指数数据 qmt更加详细的教程方法,会持续慢慢梳理。 也可找寻博主的历史文章,搜索关键词查看解决方案 ! 感谢关注,咨询免费开通量化回测与获取实盘权限,欢迎和博主联系! 获取指数行情数…

机器学习笔记——无监督学习下的k均值聚类

k均值聚类算法原理 目标是将样本分类 原理:首先随机选择k何点作为中心,然后计算每一个点到中心的聚类,然后计算到每个中心的距离,选择到中心最短距离的那个中心所在的类进行归类,然后更新中心点,一直重复。…

TikTok带货崛起:从短视频平台到电商平台的转型

近年来,随着社交媒体的迅速发展,TikTok以其独特的短视频形式迅速在全球范围内风靡,不仅吸引了海量用户,还成功地抓住了年轻一代的注意力。随着用户量的激增和内容生态的丰富,TikTok也开始探索新的商业模式,…

看完这篇文章你才能了解什么是大模型

引言 近年来,人工智能(AI)技术迅速崛起,成为全球科技领域的热门话题。大模型(Large Language Model)技术以其庞大的参数和复杂的结构,为AI提供了强大的计算和学习能力,推动着AI技术…

Python热涨落流体力学求解算法和英伟达人工智能核评估模型

🎯要点 🎯平流扩散简单离散微分算子 | 🎯相场模拟:简单旋节线分解、枝晶凝固的 | 🎯求解二维波动方程,离散化时间导数 🎯英伟达 A100 人工智能核性能评估模型 | 🎯热涨落流体动力学…

算法基础精选题单 枚举 (合适的枚举顺序+合适的枚举内容+前缀和和差分) (个人题解)

前言: 今日第一份题解,题目主要是于枚举有关,枚举算是算法题中较为简单的部分了(对我来说还是有些难想的),话不多说,见下。 正文: 题单:237题】算法基础精选题单_ACM竞赛…

6.18 作业

qt中用定时器实现闹钟 头文件 #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include<QTime>//时间类 #include<QTimer>//时间事件类 #include<QtTextToSpeech/QTextToSpeech>//文本转语音类QT_BEGIN_NAMESPACE namespace Ui { class Widge…

【Linux环境下Hadoop部署—Xshell6】解决“要继续使用此程序,您必须应用最新的更新或使用新版本。”

问题描述 打开xshell使用&#xff0c;弹出&#xff1a; 解决方案&#xff1a; 修改安装目录下面的 nsilense.dll 文件 用二进制编辑器&#xff08;notepad的HEX-Editor插件&#xff09;打开Xshell/Xftp安装目录下的 nslicense.dll。 HexEdit插件安装&#xff1a; 1.下载HexEdi…

leetcode21 合并两个有序单链表

将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 示例 1&#xff1a; 输入&#xff1a;l1 [1,2,4], l2 [1,3,4] 输出&#xff1a;[1,1,2,3,4,4]示例 2&#xff1a; 输入&#xff1a;l1 [], l2 [] 输出&#xff1a;[]示例…

Java23种设计模式(二)

1、单例模式 单例模式&#xff08;Singleton Pattern&#xff09;是 Java 中最简单的设计模式之一。这种类型的设计模式属于创建型模式&#xff0c;它提供了一种创建对象的最佳方式。 这种模式涉及到一个单一的类&#xff0c;该类负责创建自己的对象&#xff0c;同时确保只有…

MySQL日志——redolog

redo log&#xff08;重做日志&#xff09; 为什么需要redo log&#xff1f; 在mysql提交一个事务后&#xff0c;这个事务所作的数据修改并不会直接保存到磁盘文件中&#xff0c;而是先保存在buffer pool缓冲区中&#xff0c;在需要读取数据时&#xff0c;先从缓冲区中找&…

【MySQL进阶之路 | 高级篇】SQL执行过程

1. 客户端与服务器的连接 运行中的服务器程序与客户端程序本质上都是计算机的一个进程&#xff0c;所以客户端进程向服务器端进程发送请求并得到相应过程的本质就是一个进程间通信的过程. 我们可以使用TCP/IP网络通信协议&#xff0c;命名管道和共享内存等方式&#xff0c;实…

胡说八道(24.6.17)——STM32以及通信杂谈

之前的文章中咱们谈到了STM32的时钟&#xff0c;今天我们来联系实际&#xff0c;来看看内部时钟下和外部时钟下的两种不同时钟的电平翻转。本次终于有硬件了&#xff0c;是最基础的STM32F103C8T6。 首先是&#xff0c;内部时钟的配置操作。 系统的内部时钟是72MHz&#xff0c;由…

IPython 使用技巧整理

IPython 是一个强大的交互式 Python shell&#xff0c;广泛用于数据分析、科学计算和开发工作。本文将整理一些 IPython 的实用技巧&#xff0c;帮助你更高效地使用 IPython。 目录 快速启动和退出魔法命令高效的代码编写变量和对象信息历史命令IPython 扩展错误调试与 Jupy…

30v-180V降3.3V100mA恒压WT5107

30v-180V降3.3V100mA恒压WT5107 WT5107是一款恒压单片机供电芯片&#xff0c;它可以30V-180V直流电转换成稳定的3.3V直流电&#xff08;最大输出电流300mA&#xff09;&#xff0c;为各种单片机供电。WT5107的应用也非常广泛。它可以用于智能家居、LED照明、电子玩具等领域。比…

关于glibc-all-in-one下载libc2.35以上报错问题

./download libc版本 下载2.35时报错&#xff1a;原因是缺少解压工具zstd sudo apt-get install zstd 下载后重新输命令就可以了 附加xclibc命令 xclibc -x ./pwn ./libc-版本 ldd pwn文件 xclibc -c libc版本

rds2212控制台+license-server4.5版本控制台无法获取验证码的解决方案(by lqw)

这两个的控制台的日志信息报错如下&#xff1a; 原因&#xff1a; 使用的jdk不支持awt的字体 解决方案&#xff1a; 更换jdk&#xff0c;重新配置jdk环境变量&#xff0c;或者安装fontconfig组件 yum install -y fontconfig

逆向分析-Ollydbg动态跟踪Ransomware.exe恶意锁机程序

1.认识Ollydbg Ollydbg是一个新的动态追踪工具&#xff0c;将IDA与SoftICE结合起来的思想&#xff0c;Ring 3级调试器&#xff0c;非常容易上手&#xff0c;己代替SoftICE成为当今最为流行的调试解密工具了。同时还支持插件扩展功能&#xff0c;是目前最强大的调试工具。 Oll…