Tensorflow神经网络模型-鲜花种类识别

news2024/11/16 5:54:41

在这里插入图片描述必应壁纸供图

Tensorflow神经网络模型-鲜花种类识别

数据集:https://download.csdn.net/download/weixin_53742691/87982215

导入相关依赖

import warnings
import re
from IPython.display import clear_output, display
from tkinter import Tk, filedialog
from ipywidgets import Button
import cv2
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import tensorflow as tf
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


warnings.filterwarnings("ignore")

数据探索

flower_category = "flowers"
categorys = 0
categorys_list = []
for category in os.listdir(flower_category):
    categorys += 1
    categorys_list.append(category)
print("种类总数为:%d" % categorys)
print(categorys_list)
种类总数为:5
['daisy', 'dandelion', 'rose', 'sunflower', 'tulip']
file_path = "flowers/sunflower/"
file_count = 0
for file in os.listdir(file_path):
    if re.match(r'\S*\.?[jpg,png,jpeg]', file):
        file_count += 1
print("文件总数是:%d" % file_count)
文件总数是:733

图片处理器

def img_deal(img_path):
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
    img = cv2.resize(img, (224, 224))
    return img

图片预览

sample_list = []
num = 0
for sample in os.listdir(file_path):
    num += 1
    sample = "flowers/sunflower/"+sample
    sample_list.append(sample)
    if num == 5:
        break
print(sample_list)
['flowers/sunflower/1008566138_6927679c8a.jpg', 'flowers/sunflower/1022552002_2b93faf9e7_n.jpg', 'flowers/sunflower/1022552036_67d33d5bd8_n.jpg', 'flowers/sunflower/10386503264_e05387e1f7_m.jpg', 'flowers/sunflower/10386522775_4f8c616999_m.jpg']
plt.figure(figsize=(20, 20))
for i in range(5):
    plt.subplot(1, 5, i+1)
    img = img_deal(sample_list[i])
    plt.imshow(img)
    plt.xlabel("sunflower "+str(i+1))
    plt.xticks([])
    plt.yticks([])
plt.show()

png

数据预处理

# 输入图片大小
img_size = (224, 224)
# 图像数据生成
gen = tf.keras.preprocessing.image.ImageDataGenerator(
    img_size,
    validation_split=0.25,
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input
)

设置训练集

train_generator = gen.flow_from_directory(
    # 设置图片加载路径
    "flowers/",
    # 设置加载图片大小
    img_size,
    # 设置批次大小
    batch_size=32,
    class_mode="categorical",
    subset="training"
)
Found 3238 images belonging to 5 classes.

设置验证集

validation_generator = gen.flow_from_directory(
    "flowers/",
    img_size,
    batch_size=32,
    class_mode="categorical",
    subset="validation"
)
Found 1079 images belonging to 5 classes.

处理后的图片预览shuffle

plt.figure(figsize=(26, 10))
for i in range(32):
    plt.subplot(4, 8, i+1)
    sample = train_generator[0][0][i]
    # 设置图片色彩通道最小值
    sample = np.maximum(sample, 0)
    # 设置图片标签
    plt.imshow(sample)
    plt.xlabel(i)
    plt.xticks([])
    plt.yticks([])
plt.show()

png

模型搭建和训练

# 基础模型
base_model = tf.keras.applications.MobileNetV2(
    weights="imagenet",
    include_top=False,
    input_shape=(224, 224, 3)
)
# 锁定其他节点
for layers in base_model.layers:
    layers.trainable = False
# 重建模型
model = tf.keras.Sequential([
    base_model,
    # 展平
    tf.keras.layers.Flatten(),
    # 添加神经元
    tf.keras.layers.Dense(units=128, activation="relu"),
    tf.keras.layers.Dense(units=64, activation="relu"),
    tf.keras.layers.Dense(units=5, activation="softmax")
])
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 mobilenetv2_1.00_224 (Funct  (None, 7, 7, 1280)       2257984   
 ional)                                                          
                                                                 
 flatten (Flatten)           (None, 62720)             0         
                                                                 
 dense (Dense)               (None, 128)               8028288   
                                                                 
 dense_1 (Dense)             (None, 64)                8256      
                                                                 
 dense_2 (Dense)             (None, 5)                 325       
                                                                 
=================================================================
Total params: 10,294,853
Trainable params: 8,036,869
Non-trainable params: 2,257,984
_________________________________________________________________
model.compile(loss="categorical_crossentropy",
              optimizer="adam", metrics=['accuracy'])
 history = model.fit(train_generator,
                     epochs=5,
                     validation_data=validation_generator)
Epoch 1/5
102/102 [==============================] - 49s 320ms/step - loss: 1.1481 - accuracy: 0.7712 - val_loss: 0.5897 - val_accuracy: 0.8360
Epoch 2/5
102/102 [==============================] - 35s 343ms/step - loss: 0.1766 - accuracy: 0.9469 - val_loss: 0.6906 - val_accuracy: 0.8573
Epoch 3/5
102/102 [==============================] - 30s 289ms/step - loss: 0.0371 - accuracy: 0.9864 - val_loss: 0.6850 - val_accuracy: 0.8703
Epoch 4/5
102/102 [==============================] - 28s 273ms/step - loss: 0.0144 - accuracy: 0.9957 - val_loss: 0.7199 - val_accuracy: 0.8703
Epoch 5/5
102/102 [==============================] - 29s 282ms/step - loss: 0.0013 - accuracy: 1.0000 - val_loss: 0.6943 - val_accuracy: 0.8749
model.save("models/flower_model.h5")

自主测试

model = tf.keras.models.load_model("models/flower_model.h5")
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 mobilenetv2_1.00_224 (Funct  (None, 7, 7, 1280)       2257984   
 ional)                                                          
                                                                 
 flatten (Flatten)           (None, 62720)             0         
                                                                 
 dense (Dense)               (None, 128)               8028288   
                                                                 
 dense_1 (Dense)             (None, 64)                8256      
                                                                 
 dense_2 (Dense)             (None, 5)                 325       
                                                                 
=================================================================
Total params: 10,294,853
Trainable params: 8,036,869
Non-trainable params: 2,257,984
_________________________________________________________________
def select_file(b):
    clear_output()
    root = Tk()
    root.withdraw()
    root.call('wm', 'attributes', '.', '-topmost', True)
    b.files = filedialog.askopenfilename(multiple=True)
    print(b.files)


fileselect = Button(description="选择文件")
fileselect.on_click(select_file)
display(fileselect)
Button(description='选择文件', style=ButtonStyle())
len(fileselect.files)
25
plt.figure(figsize=(20, 20))
for i in range(25):
    plt.subplot(5, 5, i+1)
    img = img_deal(fileselect.files[i])
    plt.imshow(img)
    plt.xlabel(i+1)
    plt.xticks([])
    plt.yticks([])
plt.show()

png

# 图片进行打包
from tensorflow.keras.applications.densenet import preprocess_input
test_img = []
for i in range(25):
    img = img_deal(fileselect.files[i])
    test_img.append(img)
test_img = np.asarray(test_img)
test_pre_image = preprocess_input(test_img)
test_pre_image.shape
(25, 224, 224, 3)
decoder_dict = dict(zip(train_generator.class_indices.values(),
                    train_generator.class_indices.keys()))
decoder_dict
{0: 'daisy', 1: 'dandelion', 2: 'rose', 3: 'sunflower', 4: 'tulip'}
predictions = model.predict(test_pre_image)
for prediction in predictions:
    print(decoder_dict[prediction.argmax()], end=" ")
sunflower tulip tulip rose rose rose rose tulip daisy sunflower dandelion rose daisy dandelion dandelion rose tulip tulip tulip daisy daisy sunflower dandelion dandelion rose 

整体输出可视化测试

font = {
    "size": "22",
    "color": "red"
}
plt.figure(figsize=(20, 20))
for i in range(25):
    plt.subplot(5, 5, i+1)
    img = img_deal(fileselect.files[i])
    plt.imshow(img)
    img = preprocess_input(img)
    img = np.expand_dims(img, 0)
    result = model.predict(img)
    label = decoder_dict[result.argmax()]
    plt.xlabel(label, font)
    plt.xticks([])
    plt.yticks([])
plt.show()

png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/707854.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

无序中的秩序之美:集合数据为编程世界增添新的维度

文章目录 集合数据简介集合数据特点常见的集合数据类型1. 列表(List)2. 元组(Tuple)3. 集合(Set)4. 字典(Dictionary) 集合数据简介 集合数据是指将多个元素组合在一起的数据结构。…

【数据库原理】MyShop 商城数据库设计(SQL server)

MyShop 商城数据库设计 项目背景定义课程设计要求概念结构设计逻辑结构设计数据结构的描述用户信息数据结构的描述地址信息数据结构的描述商品类别数据结构的描述商品数据结构的描述购物车数据结构的描述订单数据结构的描述订单项数据结构的描述 物理结构设计用户表结构地址表结…

机器视觉检测中的图像预处理方法:平滑模糊处理,锐化

一、平滑模糊处理 以Dalsa sherlock软件为例,一起来了解一下视觉检测中平滑模糊的图像处理方法。 1、观察灰度分布来描述一幅图像称为空间域,观察图像变化的频率被称为频域。 2、频域分析:低频对应区域的图像强度变化缓慢,高频对应的变化快。低通滤波器去除了图像的高频部…

【支付系统】java springboot 生成二维码,二维码中文乱码

支付系统必不可少的就是生成二维码,有时我们会需要将支付链接转换为二维码.用户通过移动设备扫描二维码调起支付. 该篇文章主要使用的是hutool自带的二维码生成功能. 1. 引入依赖(hutool 可以按需引入这里就直接使用all了) <dependency><groupId>com.google.zxing&…

conda环境里用不了电脑系统环境里的应用ffmpeg;ffmpeg调用本地windows麦克风读取

1、 ffmpegzai conda环境里执行不了&#xff0c;在系统可以运行 import ffmpegstream ffmpeg.input(rD:\sound\222.mp4) stream ffmpeg.filter(stream, fps, fps25, roundup) stream ffmpeg.output(stream, rD:\sound\dummy2.mp4) ffmpeg.run(stream)会报错&#xff1a; Fi…

分布式配置 Config

一、Config 简介 1、Config的组成 Server&#xff1a;分布式配置中心&#xff0c;是独立运行的微服务应用&#xff0c;连接配置仓库(Git、SVN、本地化文件等)并为客户端提供获取配置信息、加密信息和解密信息的访问接口。Client&#xff1a;各个微服务&#xff0c;通过 Serve…

ChatGLM2-6B发布,位居C-Eval榜首

ChatGLM-6B自2023年3月发布以来&#xff0c;就已经爆火&#xff0c;如今6月25日&#xff0c;清华二代发布&#xff08;ChatGLM2-6B&#xff09;&#xff0c;位居C-Eval榜单的榜首&#xff01; 项目地址&#xff1a;https://github.com/THUDM/ChatGLM2-6B HuggingFace&#xf…

java开发必备技能之Java泛型

简介 泛型的优点 1、泛型的本质是为了参数化类型&#xff0c;也就是在在不创建新的类型的情况下&#xff0c;通过泛型指定的不同类型来控制形参具体限制的类型&#xff0c;很明显这种方法提高了代码的复用性(比如List因为使用泛型可以添加任意类型的对象&#xff0c;而不需要…

《PyTorch深度学习实践》第八讲 加载数据集

b站刘二大人《PyTorch深度学习实践》课程第八讲加载数据集笔记与代码&#xff1a;https://www.bilibili.com/video/BV1Y7411d7Ys?p8&vd_sourceb17f113d28933824d753a0915d5e3a90 Dataset用于构造数据集&#xff0c;该数据集能够支持索引 DataLoader用于从数据集中拿出一个…

深入理解计算机系统(3)_计算机指令

深入理解计算机系统系列文章目录 第一章 计算机的基本组成 1. 内容概述 2. 计算机基本组成 第二章 计算机的指令和运算 3. 计算机指令 4. 程序的机器级表示 5. 计算机运算 6. 信息表示与处理 第三章 处理器设计 7. CPU 8. 其他处理器 第四章 存储器和IO系统 9. 存储器的层次…

金融基础知识(三):期权

1.认购期权与认沽期权 认购期权和认沽期权都是交易所常见的期权合约。 认购期权&#xff08;Call Option&#xff09;是一种给予持有人以在未来某个时间或特定事件发生时购买底层标的资产的权利。认购期权的持有人在行权日&#xff08;Expiration Date&#xff09;可以按照期…

B/S架构的C#云检验系统源码 实验室信息管理系统源码

科技的飞速发展为实验室信息管理带来了新机遇&#xff0c;云计算技术的应用更是为实验室信息管理打开了新的大门。云 LIS 实验室信息管理系统&#xff0c;作为一种新型的信息化管理方案&#xff0c;已经在多个实验室的信息化管理中得到应用&#xff0c;并且具有广阔的应用前景。…

Python3 命名空间和作用域 | 菜鸟教程(十七)

目录 一、命名空间 &#xff08;一&#xff09;简介 1、命名空间(Namespace)是从名称到对象的映射&#xff0c;大部分的命名空间都是通过 Python 字典来实现的。 2、命名空间提供了在项目中避免名字冲突的一种方法。 3、各个命名空间是独立的&#xff0c;没有任何关系的&a…

访问者模式(Vistor)

定义 访问者是一种行为设计模式&#xff0c;它能将算法与其所作用的对象隔离开来。 前言 1. 问题 假如你的团队开发了一款能够使用巨型图像中地理信息的应用程序。图像中的每个节点既能代表复杂实体&#xff08;例如一座城市&#xff09;&#xff0c; 也能代表更精细的对象…

Nginx【Docker(安装Nginx、Nginx服务启停控制、全局块、events块、HTTP块)】(二)-全面详解(学习总结---从入门到深化)

目录 Docker安装Nginx Nginx服务启停控制 Nginx配置指令详解_全局块 Nginx配置指令详解_events块 Nginx配置指令详解_HTTP块 Docker安装Nginx 拉取官方的Nginx镜像 [rootlocalhost ~]# docker pull nginx 以下命令使用 Nginx 默认的配置来启动一个 Nginx 容器实例&#xf…

小驰私房菜_28_Qcom Camx相关名词

(Qcom 7325平台) CSID = Camera Serial Interface Decoder module IPE = Image Processing Engine IFE (x3) = Image Front End IFE_lite (x2) BPS = Bayer processing segment (for Snapshot) IPE = Image Processing Engine VPU = Video Processing Unit (CODEC) DP…

matplotlib布局模式

栅格布局 import matplotlib.pyplot as plt import numpy as np plt.figure("OBJ")x np.linspace(-np.pi, np.pi, 1000) cosy np.cos(x) siny np.sin(x) y x * 0.5 timesy x ** 2 # 创建九宫格 gs plt.GridSpec(3, 3) # 第0-1行&#xff0c;第2列 plt.subplot…

Eclipse中有用的快捷键

Eclipse中有的快捷键自己记不清楚&#xff0c;但用起来又很方便&#xff0c;遇到了就放在这边备忘。 【CtrlO】快速定位某个类中的属性、方法 有时候&#xff0c;一个类中的属性、方法比较多&#xff0c;想用快捷键快速查找&#xff0c;提升效率。 举例&#xff1a;我想查找…

MYSQL-聚合函数及分组查询

常用聚合函数 COUNT() 求有多少行 SUM() 求和 AVG() 求平均值 MIN() 求最小值 MAX() 求最大值 举个栗子 SELECT AVG(price) FROM products WHERE price_id > 10; 这行代码就是在求id大于10的价格的平均值 AVG(price)表示求price列的平均值 执行逻辑为 先由WHERE…

Mock在接口测试中的实际应用

关于Mock测试 01、含义和目的 1、 什么是mock测试&#xff1f; Mock 测试就是在测试过程中&#xff0c;对于某些不容易构造&#xff08;如 HttpServletRequest 必须在Servlet 容器中才能构造出来&#xff09;或者不容易获取的比较复杂的对象&#xff08;如 JDBC 中的ResultSe…