Tensorflow2 图像分类-Flowers数据深度学习模型保存、读取、参数查看和图像预测

news2025/2/26 9:33:32

目录

1.原文完整代码

1.1 模型运行参数总结

1.2模型训练效果

​编辑2.模型的保存

3.读取模型model

4.使用模型进行图片预测

5.补充 如何查看保存模型参数

 5.1 model_weights

 5.2 optimizer_weights


使用之前一篇代码:

  原文链接:Tensorflow2 图像分类-Flowers数据及分类代码详解

这篇文章中,经常有人问到怎么保存模型?怎么读取和应用模型进行数据预测?这里做一下详细说明。

1.原文完整代码

完整代码如下,做了少量修改:

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

"""flower_photo/
  daisy/
  dandelion/
  roses/
  sunflowers/
  tulips/"""
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
print(data_dir)
print(type(data_dir))
data_dir = pathlib.Path(data_dir)
print(data_dir)
print(type(data_dir))

image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)
roses = list(data_dir.glob('roses/*'))
img0 = PIL.Image.open(str(roses[0]))
plt.imshow(img0)
plt.show()

batch_size = 32
img_height = 180
img_width = 180

# It's good practice to use a validation split when developing your model.
# Let's use 80% of the images for training, and 20% for validation.
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

class_names = train_ds.class_names
print(class_names)

# 图片可视化
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(30):
        ax = plt.subplot(3, 10, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
plt.show()

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

normalization_layer = layers.experimental.preprocessing.Rescaling(1. / 255)
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixels values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal",
                                                     input_shape=(img_height,
                                                                  img_width,
                                                                  3)),
        layers.experimental.preprocessing.RandomRotation(0.1),
        layers.experimental.preprocessing.RandomZoom(0.1),
    ]
)
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")

num_classes = 5
model = Sequential([
    data_augmentation,
    layers.experimental.preprocessing.Rescaling(1. / 255),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(128, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.15),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(num_classes)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()

epochs = 15
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

model.save("./model/Flowers_1227.h5")   #保存模型

#读取并调用模型
pre_model = tf.keras.models.load_model("./model/Flowers_1227.h5")


acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)

img = keras.preprocessing.image.load_img(
    sunflower_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create a batch

predictions = pre_model.predict(img_array)
score = tf.nn.softmax(predictions[0])

print(
    "This image most likely belongs to {} with a {:.2f} percent confidence."
        .format(class_names[np.argmax(score)], 100 * np.max(score))
)

修改的代码包含:(1)修改了模型,增加了一个卷积层;(2)增加模型保存代码;(3)增加模型读取代码,并使用读取到的模型预测图片

1.1 模型运行参数总结

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
sequential (Sequential)      (None, 180, 180, 3)       0         
_________________________________________________________________
rescaling_1 (Rescaling)      (None, 180, 180, 3)       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 180, 180, 16)      448       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 90, 90, 16)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 90, 90, 32)        4640      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 45, 45, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 45, 45, 64)        18496     
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 22, 22, 64)        0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 22, 22, 128)       73856     
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 11, 11, 128)       0         
_________________________________________________________________
dropout (Dropout)            (None, 11, 11, 128)       0         
_________________________________________________________________
flatten (Flatten)            (None, 15488)             0         
_________________________________________________________________
dense (Dense)                (None, 128)               1982592   
_________________________________________________________________
dense_1 (Dense)              (None, 5)                 645       
=================================================================
Total params: 2,080,677
Trainable params: 2,080,677
Non-trainable params: 0

1.2模型训练效果

15次epoch有75.61%的精度,增加训练次数应该还有一定提升空间。

Epoch 1/15
92/92 [==============================] - 99s 1s/step - loss: 1.3126 - accuracy: 0.4087 - val_loss: 1.0708 - val_accuracy: 0.5477
Epoch 2/15
92/92 [==============================] - 88s 957ms/step - loss: 1.0561 - accuracy: 0.5562 - val_loss: 0.9844 - val_accuracy: 0.5872
Epoch 3/15
92/92 [==============================] - 89s 966ms/step - loss: 0.9517 - accuracy: 0.6117 - val_loss: 1.0068 - val_accuracy: 0.6035
Epoch 4/15
92/92 [==============================] - 84s 913ms/step - loss: 0.8743 - accuracy: 0.6550 - val_loss: 0.8538 - val_accuracy: 0.6580
Epoch 5/15
92/92 [==============================] - 82s 891ms/step - loss: 0.8065 - accuracy: 0.6809 - val_loss: 0.8371 - val_accuracy: 0.6703
Epoch 6/15
92/92 [==============================] - 82s 892ms/step - loss: 0.7623 - accuracy: 0.7115 - val_loss: 0.8203 - val_accuracy: 0.7016
Epoch 7/15
92/92 [==============================] - 94s 1s/step - loss: 0.7309 - accuracy: 0.7245 - val_loss: 0.7539 - val_accuracy: 0.7057
Epoch 8/15
92/92 [==============================] - 90s 982ms/step - loss: 0.6928 - accuracy: 0.7262 - val_loss: 0.7811 - val_accuracy: 0.7166
Epoch 9/15
92/92 [==============================] - 88s 955ms/step - loss: 0.6840 - accuracy: 0.7333 - val_loss: 0.8314 - val_accuracy: 0.6703
Epoch 10/15
92/92 [==============================] - 81s 877ms/step - loss: 0.6591 - accuracy: 0.7565 - val_loss: 0.7585 - val_accuracy: 0.7153
Epoch 11/15
92/92 [==============================] - 83s 899ms/step - loss: 0.6195 - accuracy: 0.7633 - val_loss: 0.7600 - val_accuracy: 0.7125
Epoch 12/15
92/92 [==============================] - 86s 934ms/step - loss: 0.6006 - accuracy: 0.7657 - val_loss: 0.6871 - val_accuracy: 0.7262
Epoch 13/15
92/92 [==============================] - 86s 934ms/step - loss: 0.5736 - accuracy: 0.7762 - val_loss: 0.6955 - val_accuracy: 0.7452
Epoch 14/15
92/92 [==============================] - 82s 897ms/step - loss: 0.5523 - accuracy: 0.7871 - val_loss: 0.7513 - val_accuracy: 0.7234
Epoch 15/15
92/92 [==============================] - 86s 935ms/step - loss: 0.5379 - accuracy: 0.7956 - val_loss: 0.6591 - val_accuracy: 0.7561

  图片数据增强后的效果图:

2.模型的保存

训练模型的保存实际上只需一行代码就行,在模型训练完成之后,我们将模型保存到指定的路径并给模型命名。模型保存的格式是.h5后缀的格式,这种文件是hdf5格式的数据,我们可以使用专门的软件打开查看模型相关参数。

在model.fit()训练完模型之后,保存模型到model文件夹下:

model.save("./model/Flowers_1227.h5")   #保存模型

运行完成之后在项目文件下可以看到model文件夹,文件中可以看到我们保存的模型:

 模型大小有45.7M.

3.读取模型model

读取代码也只需要一行,如下:

#读取并调用模型
pre_model = tf.keras.models.load_model("./model/Flowers_1227.h5")

4.使用模型进行图片预测

根据上面读取到的模型直接进行图片预测。

可以省去前面的数据训练部分,直接使用后面的部分代码,读取模型然后进行图片预测。

继续运行上文中的代码后面部分,即最后面的部分是预测一张图片属于什么类型的。

运行结果是:

This image most likely belongs to sunflowers with a 98.13 percent confidence.

读取模型进行预测的代码如下:

import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

"""flower_photo/
  daisy/
  dandelion/
  roses/
  sunflowers/
  tulips/"""
import pathlib

batch_size = 32
img_height = 180
img_width = 180

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)
class_names = train_ds.class_names
print(class_names)

#读取并调用模型
pre_model = tf.keras.models.load_model("./model/Flowers_1227.h5")

sunflower_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
sunflower_path = tf.keras.utils.get_file('Red_sunflower', origin=sunflower_url)

img = keras.preprocessing.image.load_img(
    sunflower_path, target_size=(img_height, img_width)
)
img_array = keras.preprocessing.image.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # Create a batch

predictions = pre_model.predict(img_array)
score = tf.nn.softmax(predictions[0])
print(score)

print(
    "This image most likely belongs to {} with a {:.2f} percent confidence."
        .format(class_names[np.argmax(score)], 100 * np.max(score))
)

运行结果:

Found 3670 files belonging to 5 classes.
Using 2936 files for training.
2022-12-27 22:34:13.075000: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
2022-12-27 22:34:14.205000: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
tf.Tensor([1.1012822e-04 5.2587932e-04 4.4165729e-03 9.8130155e-01 1.3645952e-02], shape=(5,), dtype=float32)
This image most likely belongs to sunflowers with a 98.13 percent confidence.

我们可以看到预测结果中,score = tf.nn.softmax(predictions[0]),代表的是该图片属于每种类型的概率大小,第四个是最大的,即第四个对应的是sunflowers类别,因此可以认为预测的结果就是sunflowers。

这里预测的图片是使用在线下载的一张图片进行预测的,实际上我们可以读取我们本地路径下的文件进行大批量的预测,并将每张图片预测结果保存到文本文件中用于后续的分析。

5.补充 如何查看保存模型参数

使用HDFView软件查看.H5后缀的文件。下载链接:HDFViiew-2.11.0-win64.exe-桌面系统文档类资源-CSDN下载

网上其他地方也有免费下载的,之前是在国外网站下载的,有时间查找的同学可以花点时间去找一下。 

可以看到,该模型主要有两部分,model_weights和optimizer_weights.即模型权重系数和优化器权重系数参数。

 我们点击展开这两个文件夹,我们可以看到里面的文件层次和我们的模型层次是一致的。

 卷积层中的参数可以查看到如下:

 5.1 model_weights

model_weigths参数展开如下,从dropout开始是空的。

 5.2 optimizer_weights

optimizer_weights参数如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/118441.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

English Learning - L1-7 介词 2022.12.26 周一

English Learning - L1-7 介词 2022.12.26 周一7 介词7.1 介词功能 1 - 表示动作的方向,范围和程度7.2 介词功能 2 - 胶水词,链接不同的名词7.3 介词功能 3 - 与 be 动词连用代替动词7.4 江南四大介词on核心意义:在。。。上; 在(某…

融云 x OHLA:「社交+游戏」双轮驱动,逐鹿中东陌生人社交

完整报告,关注公众号文章限免下载 走过十多年的出海历程,中国创业者面临的机遇和挑战正在发生根本性变化。TikTok、SHEIN 在全球大获全胜的背后,不仅有中国产业链成熟、工程师红利的厚积薄发,也有一代代出海人布局全球商业路径的思…

Ubuntu20.04部署KVM并安装Ubuntu Server 20.04

kvm虚拟化技术 KVM介绍 KVM是Linux开源社区大力支持的虚拟化技术,基于Intel和AMD的硬件虚拟化技术。KVM(Kernel-bashdVirtual Machine,即基于内核的虚拟机),它是用于Linux内核中的虚拟化环境设施,是Linux…

python:什么?你听MP3居然还要付费?看我一键......

前言 大家早好、午好、晚好吖 ❤ ~ 在我们上班空闲\游玩\散步的时候,总会习惯的拿出手机放首音乐来听一听 但是吧,有时候我们听一首歌起劲的时候,它会你提醒你 这时候怎么办呢?通常我们是下一首,或者充值 但是手头不宽裕但又想听怎么办? …

JavaEE-Spring(IoC控制反转,DI依赖注入,Spring项目创建和基本使用,ApplicationContext和BeanFactory的区别)

文章目录1. IoCDI2. Spring项目创建和使用ApplicationContext和BeanFactory的区别1. IoC Spring是一个包含多个工具方法的IoC容器 tomcat是web容器 List/Map是数据存储容器 IoC:Inversion of Control(控制反转) 将对象的控制权交给Spring&…

RK3399+PCIe+FPGA 在高速AD采样中的应用

一、需求 要实现高速AD/DA的数据采集,并发送到高性能arm核进行数据处理; 方案RK3399pcieFPGAAD/DA。 二、器件介绍 一、RK3399 RK3399是一款低功耗、高性能处理器,用于计算、个人移动互联网设备和其他智能设备应用。基于Big.Little架构&…

计算机发展史之查尔斯·巴贝奇

查尔斯巴贝奇(Charles Babbage,1791年12月26日—1871年10月18日)是一名英国数学家、发明家、科学家,科学管理的先驱者,出生于一个富有的银行家的家庭,曾就读于剑桥大学三一学院。 他在24岁时就被选为英国皇…

智慧医院数据可视化(数据大屏)

本次分享的作品是用软件Axure8.0(兼容9和10)制作的针对智慧医院设计的数据可视化大屏,其作品内容主要是对医院的运营情况、门诊、住院、手术、药品、医务、医疗设备、卫生耗材以及医疗质量数据进行综合可视分析。 运营情况:对医院的整体数据…

左神算法学习:第一天-------位运算

前言 位运算是在算法设计中的一种非常重要和高效的方法,常见的有与运算,非运算,异或运算。我们常用的比较多的可能就是异或运算,又叫无进位相加。 1.1 取非运算----(~) 取非运算其实就是和我们的无符号数…

cadence SPB17.4 - 用元件管理器来更新原理图中的元件属性信息

文章目录cadence SPB17.4 - 用元件管理器来更新原理图中的元件信息概述笔记修正原理图库修正CIS库的元件登记表ENDcadence SPB17.4 - 用元件管理器来更新原理图中的元件信息 概述 画好图后, 出了BOM. 同学指出BOM中有些元件型号信息不合适, 影响元件购买, 想改一下. 更新了原…

设计模式-桥接、职责链、中介

前言 本文为datawhale2022年12月组队学习《大话设计模式》task6打卡学习。 【教程地址】https://github.com/datawhalechina/sweetalk-design-pattern 一、桥接模式 1.1 基本定义 桥接模式(Bridge Pattern)又称为柄体(Handle and Body)模式或接口(In…

第十二讲:生成树概念及STP技术应用

在传统的交换网络中,设备通过单条链路进行连接,当某一个点或是某一个链路发生故障时可能导致网络无法访问,解决这种问题的办法是在网络中提供冗余链路,但是交换机网络中的冗余链路会产生广播风暴、MAC地址失效等现象,最…

StarRocks 统计信息和 Cost 估算

导读:欢迎来到 StarRocks 源码解析系列文章,我们将为你全方位揭晓 StarRocks 背后的技术原理和实践细节,助你逐步了解这款明星开源数据库产品。本期 StarRocks 技术内幕将主要介绍 StarRocks 统计信息和 Cost 估算。 1.背景 在学习本文之前&…

mysql搭建主从复制

Mysql主从复制搭建过程: 主从需同步时间,主开启ntpd(ntp网络时间协议,它的端口号udp123)服务-----修改配置,从通过/usr/sbin/ntpdate 主ip(ntpdate包需要提前安装);主:开启中继二进…

整数划分问题(Java递归)

整数划分问题(Java递归) 文章目录整数划分问题(Java递归)0、 问题描述1、递归式2、代码3、参考0、 问题描述 整数划分问题 将正整数n表示成一系列正整数之和:nn1n2…nk,其中n1≥n2≥…≥nk≥1,k…

数字校园建设方案技术建议书

【版权声明】本资料来源网络,仅用于行业知识分享,供个人学习参考,请勿商用。【侵删致歉】如有侵权请联系小编,将在收到信息后第一时间进行删除!完整资料领取见文末,部分资料内容: 1.1 华为数字化…

“设计”小哥转行5G网络优化工程师!从零开始,三个月实现逆风翻盘~

5G网络优化,一个陌生的领域,对于一个毫无经验的小白来说,选择转行必定是需要勇气和决心的。好在,在决定选择5G网络优化的这一段时间里,老师给予了我最大的帮助和支持,包括从授课,到练习&#xf…

【Linux】基础IO(open、文件描述符、缓冲区)

文章目录1、从文件操作开始1.1 文件操作的系统调用接口1.2 文件描述符2、重定向3、缓冲区1、从文件操作开始 在C语言阶段,接触了很多库函数,如fopen、fclose、fread和fwrte,这些函数帮助了程序实现了内存与磁盘的输入输出功能。 不过之前都…

轻松搭建MQTT服务器,开发流程全透明

1、使用场景 MQTT服务器适用场景就不多介绍了,基本上实在IOT圈发光发热,所以说是特定领域的一个服务端软件,我们是用在车联网的环境里,用来发布消息。 2、选型 最早说需要使用mqtt服务器,然后我以为需要自己开发服务…

专利代理机构代理专利流程

代理申请专利流程是怎么样的? (一)咨询 1、 确定发明创造的内容是否属于可以申请专利的内容。 (二)技术交底 1、申请人向专利代理人提供有关发明创造的背景资料或委托检索有关内容; 2、申请人详细介绍发明创造的内容,帮助专利代理人充分理解发明创造…