Keras实战之图像分类识别

news2024/11/29 10:50:41

文章目录

    • 整体流程
      • 数据加载与预处理
      • 搭建网络模型
      • 优化网络模型
        • 学习率
        • Drop-out操作
        • 权重初始化方法对比
        • 正则化
        • 加载模型进行测试

实战:利用Keras框架搭建神经网络模型实现基本图像分类识别,使用自己的数据集进行训练测试。

问:为什么选择Keras?
答:使用Keras便捷快速。用起来简单,入门容易,上手快。没有tensorflow那么复杂的规范。

整体流程

  1. 读取数据
  2. 数据预处理
  3. 切分数据集(分为训练集和测试集)
  4. 搭建网络模型(初始化参数)
  5. 训练网络模型
  6. 评估测试模型(通过对比不同参数下损失函数不断优化模型)
  7. 保存模型到本地

(1)手动配置参数,设置数据存储路径、模型保存路径、图片保存路径

# 输入参数,手动设置数据存储路径、模型保存路径、图片保存路径等
ap = argparse.ArgumentParser()
ap.add_argument("-d", "--dataset", required=True,
	help="path to input dataset of images")
ap.add_argument("-m", "--model", required=True,
	help="path to output trained model")
ap.add_argument("-l", "--label-bin", required=True,
	help="path to output label binarizer")
ap.add_argument("-p", "--plot", required=True,
	help="path to output accuracy/loss plot")
args = vars(ap.parse_args())

在这里插入图片描述

数据加载与预处理

# 拿到图像数据路径,方便后续读取
imagePaths = sorted(list(utils_paths.list_images(args["dataset"])))
random.seed(42)
random.shuffle(imagePaths)
# 数据洗牌前设置随机种子确保后面调参过程中训练数据集一样

# 遍历读取数据
for imagePath in imagePaths:
	# 读取图像数据,由于使用神经网络,需要输入数据给定成一维
	image = cv2.imread(imagePath)
	# 而最初获取的图像数据是三维的,则需要将三维数据进行拉长
	image = cv2.resize(image, (32, 32)).flatten()
	data.append(image)

	# 读取标签,通过读取数据存储位置文件夹来判断图片标签
	label = imagePath.split(os.path.sep)[-2]
	labels.append(label)

# scale图像数据,归一化
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)

# 转换标签,one-hot格式
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.transform(testY)

数据预处理:①通过数据除以255进行数据归一化;②对数据标签进行格式转换。

搭建网络模型

  1. 创建序列结构
model = Sequential()
  1. 添加全连接层
  • 第一层全连接层Dense设计512个神经元,当前输入特征个数(输入神经元个数)为3072,设置激活函数为"relu";
  • 第二层设计256个神经元;
  • 第三层设计类别数个神经元(即3个),并作softmax操作得到最终分类类别。
# 第一层
model.add(Dense(512, input_shape=(3072,),activation="relu"))
# 第二层
model.add(Dense(256, activation="relu",))
# 第三层
model.add(Dense(len(lb.classes_), activation="softmax",))
  1. 初始化参数
# 学习率
INIT_LR = 0.01
# 迭代次数
EPOCHS = 200
  1. 训练网络模型
# 给定损失函数和评估方法
opt = SGD(lr=INIT_LR) # 指定优化器为梯度下降的优化器
model.compile(loss="categorical_crossentropy", optimizer=opt,
	metrics=["accuracy"])

# 训练网络模型
H = model.fit(trainX, trainY, validation_data=(testX, testY),
	epochs=EPOCHS, batch_size=32)
  1. 测试网络模型

使用上面训练所得网络模型对测试集进行预测,并对比预测解国和数据集真实结果打印结果报告(包括准确率、recall、f1-score),并将损失函数以折线图的效果直观展示出来

predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),
	predictions.argmax(axis=1), target_names=lb.classes_))
  1. 评估结果
    在这里插入图片描述
    在这里插入图片描述
    从损失函数图像中可看出,模型出现明显过拟合现象,故而该初始参数所构建的模型效果较差,需要通过调参优化模型。

优化网络模型

学习率

对比学习率为0.01和0.001的损失函数图像。

在这里插入图片描述

train_loss与val_loss之间差异仍然存在,但是可看出学习率越大,过拟合现象越明显。

Drop-out操作

Dropout操作:在搭建网络模型中,通过设置一0到1范围内的参数从而防止过拟合。
在这里插入图片描述

权重初始化方法对比

(1)RandomNormal随机高斯初始化

kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)
model.add(Dense(512, input_shape=(3072,),activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer =initializers.random_normal(mean=0.0,stddev=0.05)))

在这里插入图片描述
图中可看出,添加RandomNormal初始化后,过拟合现象减弱了一丢丢。

(2)TruncatedNormal截断

kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)

相比于正常高斯分布截断了两边,只取小于2倍stddev的值

model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05)))

在这里插入图片描述

对比stddev取不同值时的loss函数图可得,TruncatedNormal中stddev值越小,过拟合风险越低,模型效果越好。TruncatedNormal消除过拟合的效果RandomNormal好。

正则化
kernel_regularizer=regularizers.l2(0.01)

正则化后,损失函数loss = 初始loss + aR(W)。正则化惩罚W,让稳定的W减少过拟合。

model.add(Dense(512, input_shape=(3072,), activation="relu" ,kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(256, activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))
model.add(Dense(len(lb.classes_), activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None),kernel_regularizer=regularizers.l2(0.01)))

对比正则化前后取迭代150到200的loss波动图,可发现正则化后虽然开始时loss值较大,但后期过拟合现象有明显减弱
在这里插入图片描述
再对比正则化参数l2 = 0.01和0.05的结果可得,l2越大,W的惩罚力度越大,过拟合风险越小
在这里插入图片描述

加载模型进行测试
# 导入所需工具包
from keras.models import load_model
import argparse
import pickle
import cv2

# 设置输入参数
ap = argparse.ArgumentParser()
ap.add_argument("-i", "--image", required=True,
	help="path to input image we are going to classify")
ap.add_argument("-m", "--model", required=True,
	help="path to trained Keras model")
ap.add_argument("-l", "--label-bin", required=True,
	help="path to label binarizer")
ap.add_argument("-w", "--width", type=int, default=28,
	help="target spatial dimension width")
ap.add_argument("-e", "--height", type=int, default=28,
	help="target spatial dimension height")
ap.add_argument("-f", "--flatten", type=int, default=-1,
	help="whether or not we should flatten the image")
args = vars(ap.parse_args())

# 加载测试数据并进行相同预处理操作
image = cv2.imread(args["image"])
output = image.copy()
image = cv2.resize(image, (args["width"], args["height"]))

# scale the pixel values to [0, 1]
image = image.astype("float") / 255.0

# 对图像进行拉平操作
image = image.flatten()
image = image.reshape((1, image.shape[0]))

# 读取模型和标签
print("[INFO] loading network and label binarizer...")
model = load_model(args["model"])
lb = pickle.loads(open(args["label_bin"], "rb").read())

# 预测
preds = model.predict(image)

# 得到预测结果以及其对应的标签
i = preds.argmax(axis=1)[0]
label = lb.classes_[i]

# 在图像中把结果画出来
text = "{}: {:.2f}%".format(label, preds[0][i] * 100)
cv2.putText(output, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,
	(0, 0, 255), 2)

# 绘图
cv2.imshow("Image", output)
cv2.waitKey(0)

分类结果:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
通过预测结果可得:该模型在预测猫上存在较大误差,在预测熊猫上较为准确。或许改进增加迭代次数可进一步优化模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1902586.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全网最详细的Appium自动化测试框架(一)环境搭建

一、环境搭建 1、安装python3 2、安装appium-destop 3 、安装python虚拟环境 ,安装依赖库 : pip install Appium-Python-Client pip install pytest 4、安装java brew install java 配置好环境变量 5、安装 android-platform-tools (也可以用android sdk 工…

数据库概念题总结

1、 2、简述数据库设计过程中,每个设计阶段的任务 需求分析阶段:从现实业务中获取数据表单,报表等分析系统的数据特征,数据类型,数据约束描述系统的数据关系,数据处理要求建立系统的数据字典数据库设计…

C++11|包装器

目录 引入 一、function包装器 1.1包装器使用 1.2包装器解决类型复杂 二、bind包装器 引入 在我们学过的回调中,函数指针,仿函数,lambda都可以完成,但他们都有一个缺点,就是类型的推导复杂性,从而会…

【TORCH】绘制权重分布直方图,权重torch.fmod对torch.normal生成的随机数进行取模运算

要绘制上述代码中权重初始化的分布,可以分别展示每一层初始化权重的直方图。我们将用 torch.fmod 对 torch.normal 生成的随机数进行取模运算,确保权重值在 -2 到 2 之间。 含义解释 torch.normal(0, init_sd, size...):生成服从均值为 0、…

编译Open Cascade(OCC)并使用C#进行开发

说明: VS版本:Visual Studio Community 2022系统:Windows 11 专业版23H2Open CASCADE:v7.7.0(链接:https://pan.baidu.com/s/1-o1s4z3cjpYf5XkwhSDspQ?pwdp9i5提取码:p9i5) 下载和…

【Kafka】Kafka生产者开启幂等性后报错:Cluster authorization failed.

文章目录 背景解决服务端配置ACL增加授权 背景 用户业务需求,需要开启生产者的幂等性,生产者加了配置:enable.idempotence true用户使用的集群开启了ACL认证:SASL_PLAINTEXT/SCRAM-SHA-512用户生产消息时报错:org.ap…

[笔记] 卷积 - 02 滤波器在时域的等效形式

1.讨论 这里主要对时域和频域的卷积运算的特征做了讨论,特别是狄拉克函数的物理意义。 关于狄拉克函数,参考这个帖子:https://zhuanlan.zhihu.com/p/345809392 1.狄拉克函数提到的好函数的基本特征是能够快速衰减,对吧&#xf…

VBA提取word表格内容到excel

这是一段提取word表格中部分内容的vb代码。 Sub 提取word表格() mypath ThisWorkbook.Path & "\"myname Dir(mypath & "*.doc*")n 4 index of rowsRange("A1:F1") Array("课程代码", "课程名称", "专业&…

云服务器在 Web 应用程序中作用

云服务器在Web应用程序中扮演着至关重要的角色,它不仅是现代Web应用程序的基石,还是推动业务发展和提升用户体验的关键技术之一。下面将详细探讨云服务器在Web应用程序中的重要作用及其优势。 首先,云服务器为Web应用程序提供了高度可扩展的…

蜂窝物联粮仓环境在线监测系统,确保粮食安全

在金黄的麦田里,每一粒小麦都承载着农民的辛勤与期待。为了保证这些宝贵粮食的品质与安全,储存环节显得尤为重要。传统的粮仓管理方式已难以满足现代粮食储存的需求,因此,引入智慧粮仓环境监控系统成为了必然的选择。 一、为何需…

谷粒商城 - 树形菜单递归流查询、三级分类数据查询性能优化、Jmter 性能压测

目录 树形分类菜单(递归查询,强扩展) 1)需求 2)数据库表设计 3)实现 4)关于 asSequence 优化 性能压测 1)Jmeter 安装使用说明 2)中间件对性能的影响 三级分类数…

Python内存优化的实战技巧详解

概要 Python是一种高级编程语言,以其易读性和强大的功能而广受欢迎。然而,由于其动态类型和自动内存管理,Python在处理大量数据或高性能计算时,内存使用效率可能不如一些低级语言。本文将介绍几种Python内存优化的技巧,并提供相应的示例代码,帮助在开发中更高效地管理内…

uniapp启动安卓模拟器mumu

mumu模拟器下载 ADB: android debug bridge , 安卓调试桥,是一个多功能的命令行工具,他使你能够与连接的安卓设备进行交互 # adb连接安卓模拟器 adb connect 127.0.0.1:port # 查看adb设备 adb deviceshubuilderx 有内置的adb&a…

【鸿蒙学习笔记】@Link装饰器:父子双向同步

官方文档:Link装饰器:父子双向同步 目录标题 [Q&A] Link装饰器作用 [Q&A] Link装饰器特点样例:简单类型样例:数组类型样例:Map类型样例:Set类型样例:联合类型 [Q&A] Link装饰器作用…

锂电池寿命预测 | Matlab基于改进的遗传算法优化BP神经网络的锂离子电池健康状态SOH估计

目录 预测效果基本介绍程序设计参考资料 预测效果 基本介绍 主要流程如下: 1、首先提取“放电截止电压时间”作为锂电池间接健康因子; 2、然后引入改进的遗传算法对BP神经网络的模型参数进行优化。 3、最后 NASA 卓越预测中心的锂电池数据集 B0005、B0006、B0007对…

VSCode设置字体大小

方法1:Ctrl 和 Ctrl -,可以控制整个VSCode界面的整体缩放,但是不会调整字体大小 方法2:该方法只能设置编辑器界面的字号,无法改变窗口界面的字号。 (1)点开左下角如下图标,进入…

【JVM基础篇】Java垃圾回收器介绍

垃圾回收器(垃圾回收算法实现) 垃圾回收器是垃圾回收算法的具体实现。由于垃圾回收器分为年轻代和老年代,除了G1(既能管控新生代,也可以管控老年代)之外,新生代、老年代的垃圾回收器必须按照ho…

【Python】组合数据类型:序列,列表,元组,字典,集合

个人主页:【😊个人主页】 系列专栏:【❤️Python】 文章目录 前言组合数据类型序列类型序列常见的操作符列表列表操作len()append()insert()remove()index()sort()reverse()count() 元组三种序列类型的区别 集合类型四种操作符集合setfrozens…

tongweb 部署软航流版签一体化应用示例 提示跨域错误CORS ERROR

目录 问题现象与描述 解决办法 原理解析 什么是CORS 浏览器跨域请求限制 跨域问题解决方法 跨域请求流程 浏览器请求分类解析 http请求方法简介 问题现象与描述 重庆软航科技有限公司提供了一套针对针对word、excel等流式文件转换成PDF版式文件并进行版式文件在线签章…

什么是 DDoS 攻击及如何防护DDOS攻击

自进入互联网时代,网络安全问题就一直困扰着用户,尤其是DDOS攻击,一直威胁着用户的业务安全。而高防IP被广泛用于增强网络防护能力。今天我们就来了解下关于DDOS攻击,以及可以防护DDOS攻击的高防IP该如何正确选择使用。 一、什么是…