基于tensorflow深度学习的猫狗分类识别

news2024/11/21 0:28:10

 3f6a7ab0347a4af1a75e6ebadee63fc1.gif

🤵‍♂️ 个人主页:@艾派森的个人主页

✍🏻作者简介:Python学习者
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+


目录

实验背景

实验目的

实验环境

实验过程

1.加载数据

2.数据预处理

3.构建模型

4.训练模型

5.模型评估

源代码


实验背景

        近年来,深度学习在计算机视觉领域取得了巨大的成功,尤其是在图像分类任务上。图像分类是计算机视觉领域的基本问题之一,而猫狗分类作为图像分类中的经典问题,吸引了广泛的研究兴趣。猫狗分类问题具有很高的实际应用价值。在现实世界中,人们经常需要对动物进行分类,如在宠物识别、动物行为分析和动物保护等领域。传统的图像分类方法通常需要手工设计特征提取器和分类器,这在处理复杂的图像数据时面临着挑战。

        深度学习通过学习端到端的特征提取和分类模型,不需要手动设计特征提取器,因此在猫狗分类问题上具有巨大的潜力。卷积神经网络(Convolutional Neural Networks,简称CNN)是深度学习中最常用的模型之一,特别适用于图像数据的处理。猫狗分类问题的研究可以帮助我们深入理解深度学习在图像分类任务中的应用,并且可以为其他图像分类问题的研究提供经验和指导。此外,研究人员还可以通过比较不同深度学习模型的性能和对比传统方法的效果,评估深度学习在猫狗分类问题上的优势和局限性。

        此外,随着深度学习模型的不断发展和算力的提升,研究人员可以尝试更复杂的模型架构、数据增强技术和迁移学习方法,以进一步提高猫狗分类任务的准确性和鲁棒性。因此,基于深度学习的猫狗分类实验具有重要的研究价值,可以推动深度学习在图像分类领域的发展,同时为实际应用场景提供更好的解决方案。

实验目的

        本实验的目的是基于深度学习方法进行猫狗分类,通过设计和训练深度神经网络模型,实现对输入图像进行准确的猫狗分类。具体目标包括:

        1.建立一个高性能的猫狗分类模型:通过深度学习技术,构建一个能够从原始图像数据中自动学习到猫狗分类特征的神经网络模型。该模型能够准确地对输入图像进行分类,具备较高的分类准确率和泛化能力。

        2.探索不同深度学习模型的性能差异:比较不同深度学习模型(如卷积神经网络、残差网络等)在猫狗分类任务上的性能表现,评估它们的准确率、召回率、精确率等指标,并分析其优势和不足之处。

        3.优化模型性能:通过调整模型的超参数、网络结构以及训练策略等,进一步提高猫狗分类模型的性能。例如,可以尝试不同的激活函数、优化器、学习率调度等,以提高模型的收敛速度和泛化能力。

        4.数据增强和处理:应用数据增强技术,如随机裁剪、旋转、翻转等,扩充训练数据集的多样性,提高模型对于各种场景和变化的鲁棒性。同时,对原始图像数据进行预处理,如图像归一化、均衡化等,以便更好地适应模型输入要求。

        5.评估模型性能:使用独立的测试数据集对训练好的模型进行评估,计算分类准确率、混淆矩阵等指标,评估模型的性能。同时,可以与其他传统方法进行比较,验证基于深度学习的方法在猫狗分类问题上的优越性。

实验环境

Python3.9

Jupyter notebook

实验过程

1.加载数据

首先导入本次实验用到的第三方库

efd0993698ff41559c35b628cd72996f.png

 接着定义我们数据集的路径

821554dd344e4e84a8167bcbfc84883f.png

定义训练集、测试集、验证集生成器

88f0b8a8ddd8405e974afdfe2ab1d65e.png

 将生成器连接到文件夹中的数据a8b15d355bf9441d986e3e193352b175.png

 可视化一些数据图片,来个九宫格展示

f99d7afa18b64ddca0f33b91123197b0.png

 02d2e066cba24c14a33141bcc97ce52a.png

2.数据预处理

5bcd5225624c4ab3b56bdd3a5397160c.png

3.构建模型

构建模型、定义优化器

aec104b621444d7692be3c8ecf2c23a5.png

 保存模型

f47c6eb09b024b6a9f6b4647fac41818.png

4.训练模型

ae585ea527a54d319af96f54786679f3.png

5.模型评估

将模型训练和验证的损失可视化出来、以及训练和验证的准确率

44b1487c2e5940ec84c40667c66e1c0d.png

a88ce5d207424f51b77dd3304f317550.png

对验证数据集进行评估 

cbfe6ac76d774a7f9378c160c8c88d8c.png

对测试数据集进行评估 

 8073c95d0959412487d0bc957e92b63f.png

 将模型的混淆矩阵一热力图的形式展示4a5bdae21e8a42b8865523e83614d45c.png

946454010fc940f385af9ceaf036e80e.png

源代码

import numpy as np
import random
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.metrics import confusion_matrix
import seaborn as sns
sns.set(style='darkgrid', font_scale=1.4)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.inception_resnet_v2 import InceptionResNetV2, preprocess_input
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# 数据集路径
train_dir = './train'
test_dir = './test'
CFG = dict(
    seed = 77,
    batch_size = 16,
    img_size = (299,299),
    epochs = 5,
    patience = 5
)
train_data_generator = ImageDataGenerator(
        validation_split=0.15,
        rotation_range=15,
        width_shift_range=0.1,
        height_shift_range=0.1,
        preprocessing_function=preprocess_input,
        shear_range=0.1,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

val_data_generator = ImageDataGenerator(preprocessing_function=preprocess_input, validation_split=0.15)
test_data_generator = ImageDataGenerator(preprocessing_function=preprocess_input)
# 将生成器连接到文件夹中的数据
train_generator = train_data_generator.flow_from_directory(train_dir, target_size=CFG['img_size'], shuffle=True, seed=CFG['seed'], class_mode='categorical', batch_size=CFG['batch_size'], subset="training")
validation_generator = val_data_generator.flow_from_directory(train_dir, target_size=CFG['img_size'], shuffle=False, seed=CFG['seed'], class_mode='categorical', batch_size=CFG['batch_size'], subset="validation")
test_generator = test_data_generator.flow_from_directory(test_dir, target_size=CFG['img_size'], shuffle=False, seed=CFG['seed'], class_mode='categorical', batch_size=CFG['batch_size'])

# 样本和类的数量
nb_train_samples = train_generator.samples
nb_validation_samples = validation_generator.samples
nb_test_samples = test_generator.samples
classes = list(train_generator.class_indices.keys())
print('Classes:'+str(classes))
num_classes = len(classes)
# 可视化一些例子
plt.figure(figsize=(15,15))
for i in range(9):
    ax = plt.subplot(3,3,i+1)
    ax.grid(False)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    batch = train_generator.next()
    imgs = (batch[0] + 1) * 127.5
    label = int(batch[1][0][0])
    image = imgs[0].astype('uint8')
    plt.imshow(image)
    plt.title('cat' if label==1 else 'dog')
plt.show()
base_model = InceptionResNetV2(weights='imagenet', include_top=False, input_shape=(CFG['img_size'][0], CFG['img_size'][1], 3))
x = base_model.output
x = Flatten()(x)
x = Dense(100, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax', kernel_initializer='random_uniform')(x)

# 构建模型
model = Model(inputs=base_model.input, outputs=predictions)

for layer in base_model.layers:
    layer.trainable = False
    
# 定义优化器
optimizer = Adam()
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# 保存模型
save_checkpoint = keras.callbacks.ModelCheckpoint(filepath='model.h5', monitor='val_loss', save_best_only=True, verbose=1)
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=CFG['patience'], verbose=True)
# 训练模型
history = model.fit(
        train_generator,
        steps_per_epoch=nb_train_samples // CFG['batch_size'],
        epochs=CFG['epochs'],
        callbacks=[save_checkpoint,early_stopping],
        validation_data=validation_generator,
        verbose=True,
        validation_steps=nb_validation_samples // CFG['batch_size'])
history_dict = history.history
loss_values = history_dict['loss']
val_loss_values = history_dict['val_loss']
epochs_x = range(1, len(loss_values) + 1)
plt.figure(figsize=(10,10))
plt.subplot(2,1,1)
plt.plot(epochs_x, loss_values, 'b-o', label='Training loss')
plt.plot(epochs_x, val_loss_values, 'r-o', label='Validation loss')
plt.title('Training and validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')

# Accuracy
plt.subplot(2,1,2)
acc_values = history_dict['accuracy']
val_acc_values = history_dict['val_accuracy']
plt.plot(epochs_x, acc_values, 'b-o', label='Training acc')
plt.plot(epochs_x, val_acc_values, 'r-o', label='Validation acc')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Acc')
plt.legend()
plt.tight_layout()
plt.show()
# 对验证数据集进行评估
score = model.evaluate(validation_generator, verbose=False)
print('Val loss:', score[0])
print('Val accuracy:', score[1])
# 对测试数据集进行评估
score = model.evaluate(test_generator, verbose=False)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# 混淆矩阵
y_pred = np.argmax(model.predict(test_generator), axis=1)
cm = confusion_matrix(test_generator.classes, y_pred)

# 热力图
plt.figure(figsize=(8,6))
sns.heatmap(cm, annot=True, fmt='d', cbar=True, cmap='Blues',xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion matrix')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/669212.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Triton教程 --- 速率限制器

Triton教程 — 速率限制器 Triton系列教程: 快速开始利用Triton部署你自己的模型Triton架构模型仓库存储代理模型设置优化动态批处理 速率限制器 速率限制器管理 Triton 在模型实例上调度请求的速率。 速率限制器在 Triton 中加载的所有模型上运行,以允许跨模型优…

带你用Python制作7个程序,让你感受到端午节的快乐

名字:阿玥的小东东 学习:Python、C/C 主页链接:阿玥的小东东的博客_CSDN博客-python&&c高级知识,过年必备,C/C知识讲解领域博主 目录 前言 程序1:制作粽子 程序2:龙舟比赛 程序3:艾草挂 程序4…

基于Java高校共享单车管理系统设计实现(源码+lw+部署文档+讲解等)

博主介绍: ✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战 ✌ 🍅 文末获取源码联系 🍅 👇🏻 精…

《网络安全0-100》网络安全的未来趋势

网络安全的未来趋势 网络安全是一个永恒的话题,随着技术的发展 和应用,网络安全也面临着新的挑战和威胁。 以下是网络安全未来的趋势: 人工智能和机器学习:人工智能和机器学习已 经成为网络安全领域的热门技术。未来&#xff…

编译原理笔记11:自上而下语法分析(1)基础概念、左递归和公共左因子处理、递归下降分析(咕咕咕)

目录 自上而下分析的一般方法用推导的方法分析输入序列左递归问题及其消除(消除左递归)消除直接左递归消除间接左递归左递归消除算法 公共左因子问题及其消除(提取左因子)提取左因子 递归下降分析 词法分析,是把源程序…

基于物联网及云平台的光伏运维系统

系统结构 在光伏变电站安装逆变器、以及多功能电力计量仪表,通过网关将采集的数据上传至服务器,并将数据进行集中存储管理。用户可以通过PC访问平台,及时获取分布式光伏电站的运行情况以及各逆变器运行状况。平台整体结构如图所示。 光伏背景…

Cortext-M3系列:调试组件(9)

1、调试组件简介 在 CM3 中有很多调试组件,使用它们可以执行各种调试功能:断点、数据观察点、闪存地址重载以及各种跟踪等。软件开发人员也许永远无需了解调试组 的细节,因为它们通常只是由调试器及其周边工具使用的。 本文对每种调试组件做一…

基于Java学生公寓管理中心系统设计实现(源码+lw+部署文档+讲解等)

博主介绍: ✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战 ✌ 🍅 文末获取源码联系 🍅 👇🏻 精…

IDEA2022.3.3支持Jrebel and Xrebel教程

目录 前言 思路 步骤 1、下载服务并启动 推荐下载windows环境的exe文件,直接点开就行。 如果用linux 需要安装go环境: 下载好后启动 2、idea安装Jrebel and XRebel插件 3、激活插件 前言 由于服务平台限制,只支持darwin、linux和windows环境。这…

(转载)无监督学习神经网络的分类(matlab实现)

对于监督学习神经网络,事先需要知道与输入相对应的期望输出,根据期望输出与网络输出间的偏差来调整网络的权值和阈值。然而,在大多数情况下,由于人们认知能力以及环境的限制,往往无法或者很难获得期望的输出&#xff0…

AbstractQueuedSynchronizer源码

介绍 基于队列的抽象同步器,它是jdk中所有显示的线程同步工具的基础,像ReentrantLock/DelayQueue/CountdownLatch等等,都是借助AQS实现的。 public abstract class AbstractQueuedSynchronizerextends AbstractOwnableSynchronizerimplemen…

Camera 基础知识点

和你一起终身学习,这里是程序员Android 经典好文推荐,通过阅读本文,您将收获以下知识点: 1.1 Camera 工作原理1.2 Camera 模组组成1.3 Camera 常见缩写解释1.4 Camera 部分名词解释1.5 参考文献 一、Camera 基础知识 1.1 Camera 工作原理 外部…

[进阶]Java:线程安全问题、取钱模拟

什么是线程安全问题? 多个线程,同时操作同一个共享资源的时候,可能会出现业务安全问题。 线程安全问题出现的原因? 存在多个线程在同时执行同时访问一个共享资源存在修改该共享资源 代码演示如下: 账户类&#xff…

深蓝学院C++基础与深度解析笔记 第 5 章 语句

1. 语句基础 ● 语句的常见类别 – 表达式语句:表达式后加分号,对表达式求值后丢弃,可能产生副作用 – 空语句:仅包含一个分号的语句,可能与循环一起工作 – 复合语句(语句体):由大…

软考A计划-系统集成项目管理工程师-信息系统集成及服务管理体系

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例点击跳转>软考全系列 👉关于作者 专注于Android/Unity和各种游戏开发技巧&#xff…

Flutter Dart 变量和内置类型

目录 一、变量 1.1 var 1.2 Object 1.3 dynamic 1.4 final与const 二、内置类型 2.1 num(数值) 2.2 Strings(字符串) 2.3 bool(布尔值) 2.4 List(列表) 2.5 Map(映射集…

Android apk 反编译后打包(含签名)

想分析某些app源码时,遇到烦人弹框,现在想反编译看看具体实现。 用到的工具: GDA4.06 apk反编译工具 apktool apk 打包工具 jdk 环境 一、反编译分析 将apk反编译打开 找到入口代码 弹框代码如图 二、解包、打包 使用apktool解包 ps: apktool工具…

unity游戏架构设计

1.unity架构的3个等级 EmptyGO 所有功能写一个脚本挂载object上面,没有单列manager。 Simple GameManager 写一个公用的管理器,方便调用 Manager of Managers 不同的类型的东西用不同的管理器【声音管理器,关卡管理器,】 2…

chatgpt赋能python:Python搜索快捷键

Python搜索快捷键 介绍 Python作为一门广泛应用在各个领域的编程语言,其强大的搜索功能也得到了广泛的应用和赞誉。但是,在日常的使用中,有时我们需要进行大量的搜索和筛选操作,这时候掌握一些Python搜索快捷键将能够极大地提高…

java入门2(运算符)

目录 运算符和C语言基本一样 算术运算符 单目运算符:自增自减运算符 比较运算符 逻辑运算符 位运算符(C语言好像没有) 优先级 交换算法 运算符和C语言基本一样 算术运算符 比如拆分一个三位数 public class java练习代码 {public…