TensorFlow项目练手(二)——猫狗熊猫的分类任务

news2024/9/22 9:42:39

项目介绍

通过猫狗熊猫图片来对图片进行识别,分类出猫狗熊猫的概率,文章会分成两部分,从基础网络模型->利用卷积网络经典模型Vgg。

基础网络模型

基础的网络模型主要是用全连接层来分类,比较经典的方法,也是祖先最先使用的方法,目前已经在这类问题上,被卷积网络模型所替代,学习这部分是为了可以了解到最简单的分类任务的写法。

一、准备数据

  • 准备猫狗熊猫的训练数据集,各自1000张图片,分别放在/train/cats/train/dogs/train/panda
  • 准备猫狗熊猫的测试数据集,各5-10张,统一放在/test目录下,后续通过随机取出来测试

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

二、开始编写

1、获取数据

数据的获取主要包含2部分

  1. 先读取图片数据
  2. 对图片数据进行预处理
import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layers

from keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
lb = LabelBinarizer()

import matplotlib.pyplot as plt
import random
import os
import numpy as np
np.set_printoptions(threshold=10000)
import cv2
import pickle

# 遍历所有文件名
def findAllFile(base):
    for root, ds, fs in os.walk(base):
        for f in fs:
            yield f

# 数据切分
def split_train(data,label,test_ratio):
    np.random.seed(43)
    shuffled_indices=np.random.permutation(len(data))
    test_set_size=int(len(data)*test_ratio)
    test_indices =shuffled_indices[:test_set_size]
    train_indices=shuffled_indices[test_set_size:]
    return data[train_indices],data[test_indices],label[train_indices],label[test_indices]

image_dir = ("./train/cats/", "./train/dogs/", "./train/panda/")
image_path = []
data = []
labels = []

# 读取图像路径
for path in image_dir:
    for i in findAllFile(path):
        image_path.append(path+i)

# 随机化数据
random.seed(43)
random.shuffle(image_path)

# 读取图像数据,读取label文件名数据
for j in image_path:
    image = cv2.imread(j)
    image = cv2.resize(image,(32,32)).flatten()
    data.append(image)
    label = j.split("/")[-2]
    labels.append(label)

# 数据预处理:规格化数据
data = np.array(data,dtype="float") / 255.0
labels = np.array(labels)
# 数据切分
(trainX,testX,trainY,testY) = split_train(data,labels,test_ratio=0.25)
# 将cat、dog、panda规格化数据
trainY = lb.fit_transform(trainY)
testY = lb.fit_transform(testY)

# 最终数据结果
print(trainX)
print(data)
print(data.shape) # (3000, 3072)32x32x3=3072,其图片3通道被拉长成一条操作
print(lb.classes_) # ['cats' 'dogs' 'panda']

将所有图片读取,并保存他们的数据集数据和训练结果,每张图片都会被规整到32x32并且进行拉长操作flatten(),最终输出的数据是一组图片的RGB数据

  • 数据集:我们将数据进行切分,25%作为验证集,75%数据作为训练集
  • 训练结果(label):我们按照文件名上进行分割,分割出对应的名字作为label

在这里插入图片描述

2、构建网络模型

  • 网络模型:采用全连接层
  • 优化器:使用梯度下降法SGD
  • 损失函数:使用分类算法
  • 权重初始化:高斯截断分布函数
# 2、创建模型层
EPOCHS = 200
model = Sequential()
model.add(Dense(512,input_shape=(3072,),activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)))
model.add(Dropout(0.5))
model.add(Dense(256,activation="relu",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)))
model.add(Dropout(0.5))
model.add(Dense(len(lb.classes_),activation="softmax",kernel_initializer = initializers.TruncatedNormal(mean=0.0, stddev=0.05, seed=None)))
# 损失函数和优化器,正则惩罚
model.compile(loss="categorical_crossentropy", optimizer=SGD(lr=0.001),metrics=["accuracy"])
H = model.fit(trainX, trainY, validation_data=(testX, testY),epochs=EPOCHS, batch_size=32)

在这里插入图片描述

3、模型评估

模型训练后之后,对模型进行评估,可以看到当前的分类情况

# 3、模型评估
predictions = model.predict(testX, batch_size=32)
print(classification_report(testY.argmax(axis=1),predictions.argmax(axis=1), target_names=lb.classes_))

在这里插入图片描述

4、数据可视化

将数据绘制在图上,看看其训练和预测的准确率情况,并将其保存起来

# 4、数据可视化
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy (Simple NN)")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig("./plot.png")

在这里插入图片描述

5、保存模型

# 5、保存模型到本地
model.save("./model")
f = open("./label.pickle", "wb")
f.write(pickle.dumps(lb))
f.close()

6、结果输出预测

随机获取测试集中的图片,对数据进行预处理后,进行预测,将结果显示出来

# 6、测试模型
test_image_dir =  "./test/"
test_image_path = []
for i in findAllFile(test_image_dir):
    test_image_path.append(test_image_dir+i)
test_image = random.sample(test_image_path, 1)[0]

# 数据预处理
image = cv2.imread(test_image)
output = image.copy()
image = image.astype("float") / 255.0
image = cv2.resize(image,(32,32)).flatten()
image = image.reshape((1, image.shape[0]))

# 加载模型
model = load_model("./model")
lb = pickle.loads(open("./label.pickle", "rb").read())
# 开始预测
preds = model.predict(image)

# 查看预测结果
text1 = "{}: {:.2f}% ".format(lb.classes_[0], preds[0][0] * 100)
text2 = "{}: {:.2f}% ".format(lb.classes_[1], preds[0][1] * 100)
text3 = "{}: {:.2f}% ".format(lb.classes_[2], preds[0][2] * 100)
cv2.putText(output, text1, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.putText(output, text2, (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.putText(output, text3, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.imshow("Image", output)
cv2.waitKey(0)

通过我们的测试集合看出来,其准确率还是有点不尽如意,主要是数据集较小,且训练次数不足的原因导致

在这里插入图片描述

vgg模型

在后续的技术迭代中,卷积神经网络基本上已经覆盖了图像识别技术,使用卷积神经网络结合vgg的架构,可以更准确地提高准确率

一、准备数据

跟上面基础模型一样,所有数据都是一样的

二、开始编写

1、获取数据

获取数据的代码跟基础网络模型完全一致,唯一区别在于# image = cv2.resize(image,(32,32)).flatten() # 将图片resize到64,且去掉拉长操作

import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras import regularizers
from tensorflow.keras import layers

from keras.models import load_model
from keras.models import Sequential
from keras.layers import Dropout
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
lb = LabelBinarizer()

import matplotlib.pyplot as plt
import random
import os
import numpy as np
np.set_printoptions(threshold=10000)
import cv2
import pickle

# 遍历所有文件名
def findAllFile(base):
    for root, ds, fs in os.walk(base):
        for f in fs:
            yield f

# 数据切分
def split_train(data,label,test_ratio):
    np.random.seed(43)
    shuffled_indices=np.random.permutation(len(data))
    test_set_size=int(len(data)*test_ratio)
    test_indices =shuffled_indices[:test_set_size]
    train_indices=shuffled_indices[test_set_size:]
    return data[train_indices],data[test_indices],label[train_indices],label[test_indices]

image_dir = ("./train/cats/", "./train/dogs/", "./train/panda/")
image_path = []
data = []
labels = []

# 1、数据预处理
# 读取图像路径
for path in image_dir:
    for i in findAllFile(path):
        image_path.append(path+i)

# 随机化数据
random.seed(43)
random.shuffle(image_path)

# 读取图像数据,读取label文件名数据
for j in image_path:
    image = cv2.imread(j)
    image = cv2.resize(image,(64,64))
    # image = cv2.resize(image,(32,32)).flatten() # 将图片resize到64,且去掉拉长操作
    data.append(image)
    label = j.split("/")[-2]
    labels.append(label)

# 规格化数据
data = np.array(data,dtype="float") / 255.0
labels = np.array(labels)
# 数据切分
(trainX,testX,trainY,testY) = split_train(data,labels,test_ratio=0.25)
# 将cat、dog、panda规格化数据
trainY = lb.fit_transform(trainY)
testY = lb.fit_transform(testY)

# 最终数据结果
print(trainX)
print(data)
print(data.shape) # (3000, 3072)32x32x3=3072,其图片3通道被拉长成一条操作
print(lb.classes_) # ['cats' 'dogs' 'panda']

2、构建网络模型

采用vgg的框架,搭建最简单的vgg层数的网络模型

from keras.models import Sequential
from keras.layers.normalization.batch_normalization_v1 import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.initializers import TruncatedNormal
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dropout
from keras.layers.core import Dense


model = Sequential()
chanDim = 1
inputShape = (64, 64, 3)

model.add(Conv2D(32, (3, 3), padding="same",input_shape=inputShape))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.25))

# (CONV => RELU) * 2 => POOL 
model.add(Conv2D(64, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(64, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.25))

# (CONV => RELU) * 3 => POOL 
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
#model.add(Dropout(0.25))

# FC层
model.add(Flatten())
model.add(Dense(512))
model.add(Activation("relu"))
model.add(BatchNormalization())
#model.add(Dropout(0.6))

# softmax 分类,kernel_initializer=TruncatedNormal(mean=0.0, stddev=0.01)
model.add(Dense(len(lb.classes_)))
model.add(Activation("softmax"))

# 损失函数和优化器,正则惩罚
EPOCHS = 200
model.compile(loss="categorical_crossentropy", optimizer=SGD(lr=0.001),metrics=["accuracy"])
H = model.fit(trainX, trainY, validation_data=(testX, testY),epochs=EPOCHS, batch_size=32)

在这里插入图片描述

3、数据可视化

同样的操作,将数据绘制在图上,看看其训练和预测的准确率情况,并将其保存起来

# 4、数据可视化
N = np.arange(0, EPOCHS)
plt.style.use("ggplot")
plt.figure()
plt.plot(N, H.history["loss"], label="train_loss")
plt.plot(N, H.history["val_loss"], label="val_loss")
plt.plot(N, H.history["accuracy"], label="train_acc")
plt.plot(N, H.history["val_accuracy"], label="val_acc")
plt.title("Training Loss and Accuracy (Simple NN)")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig("./plot.png")

在这里插入图片描述

4、保存模型

# 5、保存模型到本地
model.save("./model")
f = open("./label.pickle", "wb")
f.write(pickle.dumps(lb))
f.close()

5、结果输出预测

同样的操作,唯一的区别在于

  • # image = cv2.resize(image,(32,32)).flatten() #不拉平,且改为64x64
  • # image = image.reshape((1, image.shape[0])) #数据改为数组
# 6、测试模型
test_image_dir =  "./test/"
test_image_path = []
for i in findAllFile(test_image_dir):
    test_image_path.append(test_image_dir+i)
test_image = random.sample(test_image_path, 1)[0]

# 数据预处理
image = cv2.imread(test_image)
output = image.copy()
image = image.astype("float") / 255.0
# image = cv2.resize(image,(32,32)).flatten() #不拉平,且改为64x64
# image = image.reshape((1, image.shape[0])) #数据改为数组
image = cv2.resize(image,(64,64))
image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))

# 加载模型
model = load_model("./model")
lb = pickle.loads(open("./label.pickle", "rb").read())
# 开始预测
preds = model.predict(image)

# 查看预测结果
text1 = "{}: {:.2f}% ".format(lb.classes_[0], preds[0][0] * 100)
text2 = "{}: {:.2f}% ".format(lb.classes_[1], preds[0][1] * 100)
text3 = "{}: {:.2f}% ".format(lb.classes_[2], preds[0][2] * 100)
cv2.putText(output, text1, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.putText(output, text2, (10, 80), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.putText(output, text3, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 0.7,(0, 0, 255), 2)
cv2.imshow("Image", output)
cv2.waitKey(0)

通过我们的结果展示,可以发现分类的精准度可以达到80-90%以上,说明这个模型是比基础网络模型50%左右的准确度好很多,但是,也会精准的分类错误,原因在于我们只有1000的数据集,比较容易分类错误,当然你的数据量越大,就可以解决当前的问题。
在这里插入图片描述

源代码

  • 源码查看

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/714018.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(c语言)给定两个数,求这两个数的最大公约数

目录 方法一 方法二&#xff1a;辗转相除法 方法一 找出两个数中的较小值&#xff0c;从较小值减至两个数%这个数0即可。 //给定两个数&#xff0c;求这两个数的最大公约数 #include <stdio.h>int main() {int a 0;int b 0;scanf("%d %d", &a, &…

【力扣】145、二叉树的后序遍历

145、二叉树的后序遍历 注&#xff1a;二叉树的后序遍历&#xff1a;左右根&#xff1b; // 递归 var postorderTraversal function (root){const arr [];//新建一个数组&#xff1b;const fun (node) >{if(node){fun(node.left);fun(node.right);arr.push(node.val)}}f…

STM32单片机蓝牙APP自动量程万用表电流电压电阻表LCD1602

实践制作DIY- GC0149---蓝牙APP自动量程万用表 基于STM32单片机设计---蓝牙APP自动量程万用表 二、功能介绍&#xff1a; STM32F103C系列最小系统板LCD1602显示器模拟开关信号选择电路电压采集电路电流测量电路&#xff08;康铜丝采样&#xff09;电阻测量电路1个黑色公共端子…

html掉落本地图片效果

实现一个加载本地图片并掉落的html页面。 说明 将DuanWu.html与zongzi_1.png, zongzi_2.png, zongzi_3.png, yadan.png4张图片放在同一个目录下&#xff0c;然后双击打开DuanWu.html即可。 使用Chrome或Microsoft Edge浏览器打开 若使用IE浏览器打开&#xff0c;下方会出现In…

Java框架之springboot starter

写在前面 本文一起看下springboot starter相关的内容。 1&#xff1a;官方提供的starter 在spring-boot-autocongure包中定义了官方提供的一百多个starter&#xff0c;如下&#xff1a; 2&#xff1a;框架是如何定义starter的&#xff1f; 因为springboot的普及度逐步提高&…

Unity Sponza(斯蓬扎宫)场景-BuildIn-URP-HDRP

Sponza&#xff08;斯蓬扎宫&#xff09;场景 &#x1f354;URP &#x1f354;URP 资源下载

Linux: hang: 线程太多,导致的一个例子

今天遇到了一个Linux系统hang住的情况&#xff0c;从vmcore里看bt&#xff0c;没有看到明显的crash、lockup等信息&#xff1b; 而且从vmcore里也不能看具体的当时CPU事情情况。 不过还是怀疑&#xff0c;是因为线程太多&#xff0c;导致资源占用比较严重&#xff0c;从而导致一…

GeForce RTX 40系列显卡哪个更好?这个避坑测评攻略快收下

自2022年底以来&#xff0c;Nvidia一直在推出基于Ada Lovelace架构的GeForce RTX 40系列消费级GPU&#xff0c;旨在取代之前基于Ampere架构的GeForce RTX 30系列和基于Turing架构的GeForce RTX 20系列。 Nvidia称其RTX 40系列GPU的性能比前代产品有了显着提升&#xff0c;许多…

写给新手程序员的一封信

为什么写这篇文章 我是一名毕业四年的后端开发&#xff08;可能会很多人来说&#xff0c;工作时间也没多长嘛&#xff09;&#xff0c;但是在这四年里&#xff0c;我写过PHP、Go、vue、做了两年多的敏捷团队管理&#xff0c;也设计过一些系统的架构。也算是有着相对较丰富的项…

数据结构-手撕单链表+代码详解

⭐️ 往期相关文章 ✨ 链接1&#xff1a;数据结构-手撕顺序表(动态版)代码详解 ✨ 链接2&#xff1a;数据结构和算法的概念以及时间复杂度空间复杂度详解 ⭐️ 链表 &#x1f320; 什么是链表&#xff1f; 链表是一种物理存储结构上非连续、非顺序的存储结构&#xff0c;数据…

OPLS-DA分析,组间差异 图形详解

OPLS-DA分析&#xff0c;组间差异 在上一场小工具讲解中&#xff0c;小姐姐给大家介绍了PLS-DA的原理及用途&#xff0c;而在代谢组学数据分析中&#xff0c;除去PLS-DA以外&#xff0c;OPLS-DA分析也是非常常见的&#xff0c;仅一个字母之差&#xff0c;那二者到底有何差别&am…

HTML select 用法及常用事件

前言 用于记录开发中常用到的&#xff0c;快捷开发 简单实例 <select><option value"volvo">Volvo</option><option value"saab">Saab</option><option value"mercedes">Mercedes</option><opt…

【WSN定位】基于浣熊优化算法的多通信半径和跳距加权Dvhop定位算法【Matlab代码#46】

文章目录 【可更换其他算法&#xff0c;获取资源请见文章第6节&#xff1a;资源获取】1. Dvhop定位算法2. 原始浣熊优化算法2.1 开发阶段2.2 探索阶段 3. 多通信半径和跳距加权策略3.1 多通信半径策略3.2 跳距加权策略 4. 部分代码展示5. 仿真结果展示6. 资源获取 【可更换其他…

超细,设计一个“完美“的测试用例,用户登录模块实例...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 好的测试用例一定…

ad18学习笔记七:drc检查在线和批量的区别

Altium Designer 22 DRC规则检查解析 - 哔哩哔哩 硬件工程师基本功&#xff1a;DRC设置要点详解-凡亿课堂 AD中PCB检查设计错误规则设置&#xff08;DRC检查配置&#xff09;_ad怎么检查pcb有没有错误_没有价值的生命的博客-CSDN博客 Altium Designer之DRC检查学习笔记_ad d…

【Linux】基于环形队列的生产者消费者模型的实现

文章目录 前言一、基于环形队列的生产者消费者模型的实现 前言 上一篇文章我们讲了信号量的几个接口和基于环形队列的生产者消费者模型&#xff0c;下面我们就快速来实现。 一、基于环形队列的生产者消费者模型的实现 首先我们创建三个文件&#xff0c;分别是makefile&#x…

【C++】医学影像归档和通信系统-PACS

一、PACS是通过DICOM3.0国际标准接口&#xff0c;将CT、超声、放射检查(DR)、核磁、磁共振&#xff08;MR)等多种医学影像以数字化的形式保存&#xff0c;提供授权方式查看和调回&#xff0c;并提供一些辅助诊断管理功能的系统。 二、PACS系统是HIS系统的重要组成部分&#xff…

强化学习路径优化:基于Q-learning算法的机器人路径优化(MATLAB)

一、强化学习之Q-learning算法 Q-learning算法是强化学习算法中的一种&#xff0c;该算法主要包含&#xff1a;Agent、状态、动作、环境、回报和惩罚。Q-learning算法通过机器人与环境不断地交换信息&#xff0c;来实现自我学习。Q-learning算法中的Q表是机器人与环境交互后的…

打造自己的分布式MinIO对象存储

MinIO是一个对象存储解决方案&#xff0c;它提供了一个与Amazon Web Services S3兼容的API&#xff0c;并支持所有核心S3特性。MinIO旨在部署在任何地方——公共云或私有云、裸机基础架构、协调环境和边缘基础架构。 分布式MinIO如何工作 Server Pool由多个Minio服务节点与其附…

OPNET Modeler 怎么修改背景颜色

OPNET Modeler 软件中除了顶层的网络模型&#xff0c;节点模型和进程模型中的默认背景色都是灰色的。 节点模型背景颜色如下图所示。 进程模型背景颜色如下图所示。 使用时间长了发现这个灰色背景对眼睛保护还真不错&#xff0c;而且在这种灰色背景下&#xff0c;你添加包流线…