基于CNN的多种类蝴蝶图像分类

news2025/3/19 3:00:03

基于CNN的多种类蝴蝶图像分类🦋

基于卷积神经网络对6499+2786张图像,75种不同类别的蝴蝶进行可视化分析、模型训练及分类展示

导入库

import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras import regularizers
import warnings

warnings.filterwarnings("ignore", category=UserWarning, message=r"Your `PyDataset` class should call `super().__init__\(\*\*kwargs\)`")

数据分析及可视化

df = pd.read_csv("/home/mw/input/btfl7333/btfl/btfl/Training_set.csv")
df.head(10)

在这里插入图片描述

print("查看数据信息")
print(df.describe())
print("查看空值")
print(df.isnull().sum())

在这里插入图片描述
查看各个类别包含的数据量

labelcounts = df['label'].value_counts().sort_index()
plt.figure(figsize=(14, 8))
sns.barplot(x=labelcounts.index, y=labelcounts.values, palette='viridis')
plt.title('蝴蝶类型数目详细信息')
plt.xlabel('蝴蝶类型')
plt.ylabel('类别数量')
plt.xticks(rotation=90)
plt.tight_layout()
plt.show()

在这里插入图片描述
随机查看部分图片及其对应的标签

image_dir = "/home/mw/input/btfl7333/btfl/btfl/train"
sample_images = df.sample(12, random_state=43)
fig, axes = plt.subplots(4, 3, figsize=(15, 15))

for i, (index, row) in enumerate(sample_images.iterrows()):
    img_path = os.path.join(image_dir, row['filename'])
    img = load_img(img_path, target_size=(150, 150))
    img_array = img_to_array(img) / 255.0  
    
    ax = axes[i // 3, i % 3]
    ax.imshow(img_array)
    ax.set_title(f"类别: {row['label']}")
    ax.axis('off')

plt.tight_layout()
plt.show()

在这里插入图片描述

数据预处理

为图像分类任务准备训练和验证数据
使用train_test_split将数据集按照80%的比例划分为训练集 (train_df) 和验证集 (val_df)。
创建训练集的数据生成器,对训练数据进行数据增强,同时将标签转换为独热编码形式
创建验证集的数据生成器,对测试数据进行像素归一化

train_df, val_df = train_test_split(df, test_size=0.2, random_state=42)

train_datagen = ImageDataGenerator(
    rescale=1./255, # 将像素值归一化到 [0, 1] 范围
    rotation_range=40, # 随机旋转图片,范围为0到40度
    width_shift_range=0.2, # 随机水平和垂直平移图片,范围为20%
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2, # 随机缩放图片
    horizontal_flip=True,
    fill_mode='nearest' # 在变换时填充空白区域,使用最近邻插值
)

val_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df,
    directory=image_dir,
    x_col='filename',
    y_col='label',
    target_size=(150, 150),
    batch_size=32,
    class_mode='categorical' # 将标签转换为独热编码形式
)

val_generator = val_datagen.flow_from_dataframe(
    dataframe=val_df,
    directory=image_dir,
    x_col='filename',
    y_col='label',
    target_size=(150, 150),
    batch_size=32,
    class_mode='categorical'
)

在这里插入图片描述

展示部分处理后的数据

上一步已经对标签进行了编码

images, labels = next(train_generator)

# 设置绘图参数
plt.figure(figsize=(12, 8))

# 显示前10张图片及其标签
for i in range(10):
    plt.subplot(5, 2, i + 1)
    plt.imshow(images[i])  # 显示图片
    plt.title(f'Label: {labels[i]}')  # 显示标签
    plt.axis('off')  # 不显示坐标轴

plt.tight_layout()
plt.show()

在这里插入图片描述

构建模型

构建的是卷积神经网络CNN的模型,如下
输入层: 形状为 (150, 150, 3) 的图像输入。
卷积层 1: 32 个卷积核,尺寸为 (3, 3),激活函数为 ReLU。
池化层 1: 最大池化层,池化窗口为 (2, 2)。
卷积层 2: 64 个卷积核,尺寸为 (3, 3),激活函数为 ReLU。
池化层 2: 最大池化层,池化窗口为 (2, 2)。
卷积层 3: 128 个卷积核,尺寸为 (3, 3),激活函数为 ReLU。
池化层 3: 最大池化层,池化窗口为 (2, 2)。
展平层: 将多维特征图展平为一维。
全连接层 1: 128 个节点,激活函数为 ReLU。
dropout 层: 以减少过拟合,丢弃率为 0.5。
全连接层 2(输出层): 节点数与类别数相同,激活函数为 softmax

# 获取类别数量
num_classes = len(train_generator.class_indices)

# 构建模型
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))  # 使用 num_classes

model.summary()

在这里插入图片描述

# 编译模型
model.compile(optimizer='adam', 
              loss='categorical_crossentropy', 
              metrics=['accuracy'])
# 训练模型
history = model.fit(train_generator, 
                    steps_per_epoch=train_generator.n // train_generator.batch_size, 
                    validation_data=val_generator, 
                    validation_steps=val_generator.n // val_generator.batch_size, 
                    epochs=40)

在这里插入图片描述

模型评估

plt.plot(history.history['acc'], label='Train Accuracy')
plt.plot(history.history['val_acc'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.show()

在这里插入图片描述

plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()

在这里插入图片描述

# 保存模型
model.save('butterfly_classifier.h5')

使用模型进行预测展示

# 加载之前保存的模型
model = load_model('butterfly_classifier.h5')

val_images, val_labels = next(val_generator)

# 进行预测
predictions = model.predict(val_images)
pred_labels = np.argmax(predictions, axis=1)
true_labels = np.argmax(val_labels, axis=1)

# 获取类别映射
class_indices = val_generator.class_indices
class_names = {v: k for k, v in class_indices.items()}

# 定义显示图像的函数
def display_images(images, true_labels, pred_labels, class_names, num_images=9):
    plt.figure(figsize=(15, 15))
    for i in range(num_images):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i])
        true_label = class_names[int(true_labels[i])]
        pred_label = class_names[int(pred_labels[i])]
        plt.title(f"True: {true_label}\nPred: {pred_label}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# 调用显示函数
display_images(val_images, true_labels, pred_labels, class_names, num_images=9)

在这里插入图片描述

总结

这次这个基于cnn的图像分类,获得了高于 70% 的准确率。可以加载我保存好的模型进行预测试试,感兴趣的还可以继续调参训练

# 若需要完整数据集以及代码请点击以下链接
https://mbd.pub/o/bread/aJaVmJ9s

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2317537.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Unity插件-适用于画面传输的FMETP STREAM使用方法(三)基础使用

目录 一、插件介绍 二、组件介绍 三、Game View Streaming 1、使用 FM Network UDP 的基本设置 Server Scene Client Scene 2、使用使用 FM WebSocket 的基本设置 四、Audio Streaming 五、Microphone Streaming 一、插件介绍 ​​​​​​Unity插件-适用于画面传输的…

微信小程序wx.request接口报错(errno: 600001, errMsg: “request:fail -2:net::ERR_FAILED“)

来看看报错 报错如下: 请求发送部分,代码如下: uni.request({url: self.serverUrl "/getRealName",method: GET,data: {"code": self.info.code,},header: {"Authorization": uni.getStorageSync(tokenHead) uni.getStorageSync(token)}}…

基于Python+MySQL编写的(WinForm)图书管理系统

一、项目需求分析 1.1 项目介绍 项目背景 图书馆管理系统是一些单位不可缺少的部分,书籍是人类不可缺少的精神食粮,尤其对于学校来说,尤其重要。所以图书馆管理系统应该能够为用户提供充足的信息和快捷的查询手段。但一直以来人们使用传统人工的方式管…

[贪心算法] 摆动序列

1.解析 这里我们的贪心体现在,这里我们只需要找到每一个拐点位置的数字即可, 证明: 当我们在A点时,我们下一步的选择有四种 A到D这个线段内的数字(不包括D)选择D点D到F的点F之后的点 对于A到D来说&#xf…

WPF未来展望:紧跟技术发展趋势,探索新的可能性

WPF未来展望:紧跟技术发展趋势,探索新的可能性 一、前言二、WPF 与.NET 技术的融合发展2.1 拥抱.NET Core2.2 利用.NET 5 及后续版本的新特性 三、WPF 在新兴技术领域的应用拓展3.1 与云计算的结合3.2 融入物联网生态 四、WPF 在用户体验和设计方面的创新…

低空经济腾飞:无人机送货、空中通勤,未来已来

近年来,低空经济逐渐成为社会关注的焦点。从无人机送货到“空中的士”,再到飞行培训的火热进行,低空经济正迎来前所未有的发展机遇。随着技术进步和政策支持,这一曾经看似遥远的未来场景,正逐步变为现实。 低空经济如何…

QT编译器mingw与msvc区别及环境配置

一.QT编译器mingw与msvc主要区别 二.QT开发环境配置 1. MinGW 配置 安装步骤: 通过 Qt 官方安装器 安装时勾选 MinGW 组件(如 Qt 6.7.0 MinGW 64-bit)。 确保系统环境变量包含 MinGW 的 bin 目录(如 C:\Qt\Tools\mingw1120_64…

【css酷炫效果】纯CSS实现进度条加载动画

【css酷炫效果】纯CSS实现进度条加载动画 缘创作背景html结构css样式完整代码基础版进阶版 效果图 通过CSS渐变与背景位移动画,无需JavaScript即可创建流体动态进度条。 想直接拿走的老板,链接放在这里:https://download.csdn.net/download/u…

Feedback-Guided Autonomous Driving

Feedback-Guided Autonomous Driving idea 问题设定:基于 CARLA 的目标驱动导航任务,通过知识蒸馏,利用特权智能体的丰富监督信息训练学生传感器运动策略函数 基于 LLM 的端到端驱动模型:采用 LLaVA 架构并添加航点预测头&#…

图解AUTOSAR_CP_WatchdogDriver

AUTOSAR WatchdogDriver模块详解 AUTOSAR MCAL层看门狗驱动模块详细解析 目录 1. 模块概述2. 架构位置 2.1. 组件架构 3. 主要功能4. API接口5. 配置参数 5.1. 配置模型 6. 错误代码7. 状态管理 7.1. 状态机 8. 处理流程 8.1. 活动流程 9. 操作序列 9.1. 典型操作序列 10. 硬件…

大数据学习(65)- Hue详解

🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言📝支持一…

C语言学习笔记(第三部份)

说明:由于所有内容放在一个md文件中会非常卡顿,本文件将接续C_1.md文件的第三部分 整型存储和大小端 引例: int main(void) {// printf("%d\n", SnAdda(2, 5));// PrintDaffodilNum(10000);// PrintRhombus(3);int i 0;int arr[…

深入理解蒸馏、Function Call、React、Prompt 与 Agent

AI基础概念与实操 一、什么是蒸馏二、如何理解Function Call、React、Prompt与Agent(一)Function Call与Agent(二)Agent中的React概念(三)Prompt与Agent的关联 实操演练function callprompt 一、什么是蒸馏…

记录一个SQL自动执行的html页面

在实际工作场景中,需要运用到大量SQL语句更新业务逻辑,对程序员本身,写好的sql语句执行没有多大问题(图1),但是对于普通用户来说还是有操作难度的。因此我们需要构建一个HTML页面(图2&#xff0…

qt介绍图表 charts 一

qt chartsj基于Q的Graphics View框架,其核心组件是QChartView和QChart.QChartView是一个显示图表的独立部件,基类为QGraphicsView.QChar类管理图表的序列,图例和轴示意图。 绘制一个cos和sin曲线图,效果如下 实现代码 #include…

Transformer:GPT背后的造脑工程全解析(含手搓过程)

Transformer:GPT背后的"造脑工程"全解析(含手搓过程) Transformer 是人工智能领域的革命性架构,通过自注意力机制让模型像人类一样"全局理解"上下文关系。它摒弃传统循环结构,采用并行计算实现高…

S32K144入门笔记(十):TRGMUX的初始化

目录 1. 概述 2. 代码配置 1. 概述 书接上回,TRGMUX本质上是一个多路选择开关,根据用户手册中的描述,它可以实现多个输入的选择输出,本篇文章将验证如何通过配置工具来生成初始化配置代码。 2. 代码配置 笔者通过配置TRGMUX实现…

有了大模型为何还需要Agent智能体

一、什么是Agent? Agent(智能体) 是一种能感知环境、自主决策、执行动作的智能实体,当它与大语言模型(如通义千问QWen、GPT)结合时,形成一种**“增强型AI系统”**。其核心架构如下:…

DNS主从服务器

1.1环境准备 作用系统IP主机名web 服务器redhat9.5192.168.33.8webDNS 主服务器redhat9.5192.168.33.18dns1DNS 从服务器redhat9.5192.168.33.28dns2客户端redhat9.5192.168.33.7client 1.2修改主机名和IP地址 web服务器 [rootweb-8 ~]# hostnamectl hostname web [rootweb-8…

Flume详解——介绍、部署与使用

1. Flume 简介 Apache Flume 是一个专门用于高效地 收集、聚合、传输 大量日志数据的 分布式、可靠 的系统。它特别擅长将数据从各种数据源(如日志文件、消息队列等)传输到 HDFS、HBase、Kafka 等大数据存储系统。 特点: 可扩展&#xff1…