基于卷积神经网络的猫种类的识别

news2024/11/28 4:48:42

1.介绍

图像分类是计算机视觉中的一个关键任务,而猫种类识别作为一个有趣且实用的应用场景,通过卷积神经网络(CNN)的模型能够识别猫的不同品种。在这篇博客中,将详细介绍如何利用深度学习技术构建模型,从而实现猫种类的自动识别。将探讨数据集的获取和预处理、模型的构建与训练,以及模型的评估和应用。

2.数据集准备与预处理

使用的数据集包含布偶猫、橘猫、蓝猫和虎斑猫等四种猫的图片。这个数据集包含了多个角度和不同环境下的猫的图像。在数据预处理阶段,加载了图像数据,并进行了多方面的处理:图片大小调整、灰度处理、归一化和标签编码等。致力于创建一个高质量、多样性和均衡性的数据集,以保证模型的有效性和泛化能力。

2.1 数据集准备

百度图片搜索猫咪图片 ——放到对应文件夹下

2.2 预处理 统一图片大小及后缀名

预处理是图像处理中至关重要的步骤之一,特别是在图像分类任务中。在进行猫种类识别之前,我们需要对图像进行预处理,以确保数据的一致性和适应模型的需求。在这个任务中,我们执行了两个关键的预处理步骤:统一图片大小和统一图片后缀名。

统一图片大小: 图像大小的不一致会对模型训练产生影响。因此,我们将所有的图像调整为相同的尺寸,例如将所有图像调整为100x100像素大小。通过这样的处理,我们确保了训练数据的一致性,使模型能够更好地学习和识别特征。

统一图片后缀名: 在数据收集过程中,不同来源的图像可能具有不同的后缀名或格式。为了方便统一处理,我们将所有图像的后缀名或格式转换为相同的格式,例如将所有图像统一保存为.jpg格式。这样做有利于数据加载和处理,并减少了数据集混乱性和不一致性所带来的问题。

在这里插入图片描述

import os
from PIL import Image
import random
import shutil

# 输入目录和输出目录
input_directory = '/kaggle/input/mycatdataset/'
output_directory = '/kaggle/working/data/'

# 创建输出目录
train_directory = os.path.join(output_directory, 'train')
test_directory = os.path.join(output_directory, 'test')
os.makedirs(train_directory, exist_ok=True)
os.makedirs(test_directory, exist_ok=True)

# 定义一个函数来处理文件名和图片大小,并分配到训练集和测试集中
def process_images(category_folder, prefix, train_output_directory, test_output_directory):
    category_path = os.path.join(input_directory, category_folder)
    
    files = os.listdir(category_path)
    random.shuffle(files)  # 打乱文件顺序

    # 将前80%的图片分配到训练集,剩余的20%分配到测试集
    num_train = int(len(files) * 0.8)
    train_files = files[:num_train]
    test_files = files[num_train:]

    for index, file in enumerate(train_files):
        image_path = os.path.join(category_path, file)
        img = Image.open(image_path)
        img = img.resize((100, 100))
        output_file_name = f"{prefix}{index}.jpg"
        output_path = os.path.join(train_output_directory, output_file_name)
        img.save(output_path)

    for index, file in enumerate(test_files):
        image_path = os.path.join(category_path, file)
        img = Image.open(image_path)
        img = img.resize((100, 100))
        output_file_name = f"{prefix}{index}.jpg"
        output_path = os.path.join(test_output_directory, output_file_name)
        img.save(output_path)

# 处理不同的猫的文件夹
process_images('布偶猫', '0_', train_directory, test_directory)
process_images('橘猫', '1_', train_directory, test_directory)
process_images('蓝猫', '2_', train_directory, test_directory)
process_images('虎斑猫', '3_', train_directory, test_directory)

3.训练

参考采用VGG模型

设计了一个深度卷积神经网络来处理猫种类的分类任务。该模型由卷积层、池化层和全连接层组成,具有良好的特征提取和抽象能力。使用了ReLU激活函数、Dropout层和Softmax输出层,并且选择了合适的损失函数和优化器。通过对模型进行反复调整和优化,确保了模型在处理猫图像时的有效性和鲁棒性。

import os
import numpy as np
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt

# 获取训练集和测试集的图片路径列表
train_path = '/kaggle/working/data/train'
test_path = '/kaggle/working/data/test'
train_images = os.listdir(train_path)
test_images = os.listdir(test_path)

# 读取图像并转换为Numpy数组的函数
def load_images(path, image_list):
    images = []
    labels = []
    for img in image_list:
        image = Image.open(os.path.join(path, img)).convert('RGB')
        image = image.resize((100, 100))  # 调整图像大小为100x100
        images.append(np.array(image))
        label = int(img.split('_')[0])  # 根据文件名提取标签
        labels.append(label)
    return np.array(images), np.array(labels)

# 读取训练集和测试集图像数据
x_train, y_train = load_images(train_path, train_images)
x_test, y_test = load_images(test_path, test_images)

# 数据预处理
x_train = x_train.astype('float32') / 255.0  # 将图像像素值缩放到[0, 1]范围内
x_test = x_test.astype('float32') / 255.0

y_train = tf.keras.utils.to_categorical(y_train)  # 将训练集标签进行one-hot编码
y_test = tf.keras.utils.to_categorical(y_test)

# 构建卷积神经网络模型
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(100, 100, 3)),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),
    
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(4, activation='softmax')
])

# 编译模型
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# 训练模型
history = model.fit(x_train, y_train, batch_size=20, epochs=50, validation_data=(x_test, y_test))

# 保存模型
model.save('/kaggle/working/cat_model.h5')

# 评估模型
score = model.evaluate(x_test, y_test)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# 可视化测试结果
plt.figure(figsize=(12, 6))

# 绘制训练集和验证集的损失
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

# 绘制训练集和验证集的准确率
plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

在训练过程中,将数据集划分为训练集和验证集,并对模型进行了多轮次的训练。记录了训练和验证集上的损失值和准确率,以便了解模型的训练情况和性能表现。通过使用验证集进行评估,检查了模型的泛化能力和过拟合情况。

训练结果:
在这里插入图片描述

4.图像识别与预测

最后,使用训练好的模型对新的猫图像进行了预测。通过加载模型并使用其对待识别的图像进行预处理和推理,得到了预测的猫种类标签和模型对该标签的置信度。这项实验展示了模型在实际应用中的效果和可靠性。

import os
import numpy as np
from PIL import Image
import tensorflow as tf

# 加载模型
model = tf.keras.models.load_model('/kaggle/working/cat_model.h5')

# 定义类别标签
class_labels = {0: '布偶猫', 1: '橘猫', 2: '蓝猫', 3: '虎斑猫'}

# 准备待分类的图像
def load_image(file_path):
    img = Image.open(file_path).convert('RGB')
    img = img.resize((100, 100))
    img_array = np.array(img)
    img_array = img_array.astype('float32') / 255.0
    img_array = np.expand_dims(img_array, axis=0)  # 增加维度以符合模型输入格式
    return img_array

# 用模型进行图像分类
def predict_image(image_path):
    img = load_image(image_path)
    predictions = model.predict(img)
    predicted_class = np.argmax(predictions)
    predicted_label = class_labels[predicted_class]
    confidence = predictions[0][predicted_class]
    return predicted_label, confidence

# 指定待分类的图像路径
image_path_to_classify = '/kaggle/input/cat-image/85d7dcff30734677667a1c8f3aa860a5.jpeg'

# 进行图像分类预测
predicted_label, confidence = predict_image(image_path_to_classify)
print(f'Predicted Label: {predicted_label}')
print(f'Confidence: {confidence * 100:.2f}%')

在这里插入图片描述

模型下载:传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1216357.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

gd32 USB HOST 接口

接口 CPU引脚 复用 DM PB14 USBHS_DM AF12 DP PB15 USBHS_DP AF12

互联网上门预约洗衣洗鞋店小程序;

拽牛科技干洗店洗鞋店软件,方便快捷,让你轻松洗衣。只需在线预约洗衣洗鞋服务,附近的门店立即上门取送,省心省力。轻松了解品牌线下门店,通过列表形式展示周围门店信息,自动选择最近门店为你服务。简单填写…

【Linux专题】SFTP 用户配置 ChrootDirectory

【赠送】IT技术视频教程,白拿不谢!思科、华为、红帽、数据库、云计算等等https://xmws-it.blog.csdn.net/article/details/117297837?spm1001.2014.3001.5502 红帽认证 认证课程介绍:红帽RHCE9.0学什么内容,新版有什么变化-CSDN…

【带头学C++】----- 七、链表 ---- 7.1 链表的概述

目录 七、链表 7.1 链表的是什么? 7.2数组和链表的优点和缺点 7.3 链表概述 ​编辑 7.4 设计静态链表 7.4.1 定义一个结点(结构体) 7.4.2 使用头结点构建一个单向链表 七、链表 7.1 链表的是什么? C链表是一种数据结构&a…

如何构建风险矩阵?3大注意事项

风险矩阵法(RMA)是确定威胁优先级别的最有效工具之一,可以帮助项目团队识别和评估项目中的风险,帮助项目团队对风险进行排序,清晰地展示风险的可能性和严重性,为项目团队制定风险管理策略提供依据。 如果没…

SecureCRT\\FX:打造安全可靠的终端模拟器和FTP客户端

在现代的工作环境中,远程连接和文件传输是不可或缺的任务。而SecureCRT\\FX作为一款安全可靠的终端模拟器和FTP客户端,将帮助您高效管理远程连接和文件传输。 SecureCRT\\FX提供了强大的终端模拟功能,支持SSH、Telnet、RDP等多种协议&#x…

92.Linux的僵死进程以及处理方法

目录 1.什么是僵死进程? 2.代码演示僵死进程 3.解决办法 1.什么是僵死进程? 僵死进程是指一个子进程在父进程之前结束,但父进程没有正确地等待(使用 wait 或 waitpid 等系统调用)来获取子进程的退出状态。当一个进…

流程图怎么画,用什么软件做?一文弄懂流程图:从流程图的定义、流程图各种图形的含义到流程图制作,一步到位!

流程图,也被称为过程流程图或流程图,是一种表达工作或过程中步骤之间逻辑关系的可视化工具。它主要由不同形状和符号的框以及指向这些框的箭头组成。每个形状或符号都有特定的含义,它们代表了工作流程中的一种特定类型的步骤或动作。 使用流…

视频集中存储/云存储平台EasyCVR级联下级平台的详细步骤

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快,可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等,以及支持厂家私有协议与SDK接入,包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…

JVM bash:jmap:未找到命令 解决

如果我们在使用JVM的jmap命令时遇到了"bash: jmap: 未找到命令"的错误,这可能是因为jmap命令没有在系统的可执行路径中。 要解决这个问题,可以尝试以下几种方法: 1. 检查Java安装:确保您已正确安装了Java Development …

【Android】导入三方jar包/系统的framework.jar

1.Android.mk导包 1).jar包位置 与res和src同一级的libs中(没有就新建) 2).Android.mk文件 LOCAL_STATIC_ANDROID_LIBRARIES:android静态库,经常用于一些support的导包 LOCAL_JAVA_LIBRARIES:依赖的java库,一般为系统的jar…

[修改Linux下ssh端口号]解决无法修改sshd_config无法修改

前言:写本文的前因是本人的阿里云服务器经常被黑客暴力破解ssh的22端口号。再网络上搜索解决都是说使用root权限进行修改,但本人在root下也无法成功进行修改sshd_config文件。所以在大量搜索下终于找到了解决方案,现在分享出来给有需要的人使…

一个集成了AI和BI报表功能的新一代数据库管理系统神器--Chat2DB

世人皆知Navicate,无人识我Chat2DB 📖 简介 Chat2DB 是一款开源免费的多数据库客户端工具,支持多平台和主流数据库。 集成了AI的能力,能进行自然语言转SQL、SQL解释、SQL优化、SQL转换 ✨ 好处 1、AIGC和数据库客户端的联动&am…

广州华锐互动:办税服务厅税务登记VR仿真体验让税务办理更加灵活高效

在数字化世界的今天,我们正在见证各种业务过程的转型,而税务办理也不例外。最近,一种全新的交互方式正在改变我们处理税务的方式:虚拟现实(VR)。 首先,用户需要戴上虚拟现实头显,然后…

怎么调监控清晰度,监控画面不清晰怎么修复?

监控画面不清晰怎么修复,通过调整视频的分辨率可以达到使视频更清晰的目的,另外就是如果是室外的环境下,视频的监控镜头会积累灰尘,擦一下镜头有可能会使得拍摄的视频更清晰一些。另外就是可以通过一些软件将视频分辨率提高&#…

互联网医院系统:数字化时代中医疗服务的未来

随着数字化时代的发展,互联网医院系统在医疗服务中的作用日益凸显。本文将讨论互联网医院系统的一些关键技术方面,探讨这些技术如何推动医疗服务进入数字化时代。 1. 数据智能与个性化服务 互联网医院系统依赖于大数据分析和人工智能技术,…

网络运维Day19

文章目录 环境准备数据备份为什么要备份什么是备份备份到哪里什么时候备份如何备份 完整备份物理备份逻辑备份测试恢复所有库 构建MySQL服务xtrabackup完全备份与恢复完全备份完全恢复增量备份增量恢复 总结 环境准备 IP地址采用自动分配,以自己的为准 可以将之前的…

【LeetCode刷题-滑动窗口】--1004.最大连续1的个数III

1004.最大连续1的个数III 方法&#xff1a;滑动窗口 class Solution {public int longestOnes(int[] nums, int k) {int left 0,right 0,zero 0,res 0;while(right < nums.length){if(nums[right]0){zero;}while(zero > k){if(nums[left] 0){zero--;}left;}res Ma…