24- 深度学习的模型保存和加载 (TensorFlow系列) (深度学习)

news2025/1/11 23:47:00

知识要点

keras 保存成hdf5文件, 1.保存模型和参数, 2.只保存参数

  • 1.保存模型和参数
    • save_model
    • callback ModelCheckpoint
  • 2. 只保存参数
    • save_weights
    • callback ModelCheckpoint save_weights_only = True

保存模型:

  • 案例数据: Fashion-MNIST总共有十个类别的图像
  • model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5'))      # 保存参数的方法
  • 加载参数: model.load_weights(os.path.join(logdir, 'fashion_mnist_weight.h5'))
  • 保存模型: model.save(os.path.join(logdir, 'fashion_mnist_model.h5'))
  • 加载模型: model2 = keras.models.load_model(os.path.join(logdir, 'fashion_mnist_model.h5'))
  • 把keras模型保存成savedmodel格式: tf.saved_model.save(model, './keras_saved_model')


一 模型保存和部署

  • TFLite是为了将深度学习模型部署在移动端和嵌入式设备的工具包,可以把训练好的TF模型通过转化、部署和优化三个步骤,达到提升运算速度,减少内存、显存占用的效果。
  • TFlite主要由Converter和Interpreter组成。Converter负责把TensorFlow训练好的模型转化,并输出为.tflite文件(FlatBuffer格式)。转化的同时,还完成了对网络的优化,如量化。Interpreter则负责把.tflite部署到移动端,嵌入式(embedded linux device)和microcontroller,并高效地执行推理过程,同时提供API接口给Python,Objective-C,Swift,Java等多种语言。简单来说,Converter负责打包优化模型,Interpreter负责高效易用地执行推理
  • Fashion-MNIST总共有十个类别的图像。每一个类别由训练数据集6000张图像和测试数据集1000张图像。所以训练集和测试集分别包含60000张和10000张。测试训练集用于评估模型的性能。

  • 每一个输入图像的高度和宽度均为28像素。数据集由灰度图像组成。Fashion-MNIST,中包含十个类别,分别是t-shirt,trouser,pillover,dress,coat,sandal,shirt,sneaker,bag,ankle boot。

1.1 模型创建

  • 导包
# 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
  • 时尚数据导入
# 时尚数据导入
fashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]
  • 标准化
# 标准化
from sklearn.preprocessing import StandardScaler    # preprocessing 预处理
scaler = StandardScaler()

x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 784))
x_valid_scaled = scaler.fit_transform(x_valid.astype(np.float32).reshape(-1, 784))
x_test_scaled = scaler.fit_transform(x_test.astype(np.float32).reshape(-1, 784))
  • 创建模型
# 创建模型
model = keras.models.Sequential([keras.layers.Dense(512, activation = 'relu', input_shape = (784, )),
                                keras.layers.Dense(256, activation = 'relu'),
                                keras.layers.Dense(128, activation = 'relu'),
                                keras.layers.Dense(10, activation = 'softmax')])

model.compile(loss = 'sparse_categorical_crossentropy',
              optimizer = 'adam',
              metrics = ['accuracy'])

1.2 保存模型

# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):
    os.mkdir(logdir)
    
output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
callbacks = [keras.callbacks.TensorBoard(logdir),  # 保存地址
             # 保存效果最好的模型: save_best_only
             keras.callbacks.ModelCheckpoint(output_model_file,    
                                             save_best_only = True, 
                                             save_weights_only = True),
             keras.callbacks.EarlyStopping(patience = 5, min_delta = 1e-3)]   

history = model.fit(x_train_scaled, y_train, epochs = 10,
                    validation_data= (x_valid_scaled, y_valid),
                    callbacks = callbacks)

  • 保存模型
# 保存模型
output_model_file2 = os.path.join(logdir, 'fashion_mnist_model.h5')
model.save(output_model_file2)
  •  保存参数

# 另一种保存参数的方法
model.save_weights(os.path.join(logdir, 'fashion_mnist_weights_2.h5'))
  • 模型评估
# evaluate 评估
model.evaluate(x_valid_scaled, y_valid)   # [0.35909169912338257, 0.88919997215271]
  • 模型加载
# 加载模型
model2 = keras.models.load_model(output_model_file2)
model2.evaluate(x_valid_scaled, y_valid)  # [0.35909169912338257, 0.88919997215271]

二 保存模型为savemodel格式

# 把keras模型保存成savedmodel格式
tf.saved_model.save(model, './keras_saved_model')
  •  读取模型
# 加载savedmodel模型
loaded_saved_model = tf.saved_model.load('./keras_saved_model')
loaded_saved_model

 2.1 另一种保存

# 保存模型
import os
logdir = './graph_def_and_weights'
if not os.path.exists(logdir):
    os.mkdir(logdir)
    
output_model_file = os.path.join(logdir, 'fashion_mnist_weight.h5')
model.load_weights(output_model_file)

三 tflite_interpreter 的使用

  • 导包
from tensorflow import keras
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import os
with open('./tflite_models/concrete_func_tf_lite', 'rb') as f:
    concrete_func_tflite = f.read()
  • 创建interpreter
# 创建interpreter
interpreter = tf.lite.Interpreter(model_content = concrete_func_tflite)
# 分配内存
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
  • 预测数值
input_data = tf.constant(np.ones(input_details[0]['shape'], dtype = np.float32))
# 传入预测数据
interpreter.set_tensor(input_details[0]['index'], input_data)

# 执行预测
interpreter.invoke()

# 获取输出
output_results = interpreter.get_tensor(output_details[0]['index'])
print(output_results)

四 to_concrete_function

  • 加载文件
# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model(np.ones((1, 784)))

  • 把keras模型转化为concrete function
# 把keras模型转化为concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,
                                              loaded_keras_model.inputs[0].dtype))
# 使用
keras_concrete_func(tf.constant(np.ones((1, 784), dtype = np.float32)))

五 to_quantized_tflite

5.1 keras to tflite

# 从文件加载
loaded_keras_model = keras.models.load_model('./graph_def_and_weights/fashion_mnist_model.h5')
loaded_keras_model
# lite 精简版模型   # 创建转化器
keras_to_tflite_converter = tf.lite.TFLiteConverter.from_keras_model(loaded_keras_model)
keras_to_tflite_converter
# 给converter添加量化的优化  # 把32位的浮点数变成8位整数
keras_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]

# 执行转化
keras_tflite = keras_to_tflite_converter.convert()
# 写入指定文件
import os
if not os.path.exists('./tflite_models'):
    os.mkdir('./tflite_models')
    
with open('./tflite_models/quantized_keras_tflite', 'wb') as f:
    f.write(keras_tflite)

5.2 concrete function to tflite

# 把keras模型转化成concrete function
run_model = tf.function(lambda x: loaded_keras_model(x))
keras_concrete_func = run_model.get_concrete_function(tf.TensorSpec(loaded_keras_model.inputs[0].shape,
                                                                    loaded_keras_model.inputs[0].dtype))
concrete_func_to_tflite_converter = tf.lite.TFLiteConverter.from_concrete_functions([keras_concrete_func])
concrete_func_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
concrete_func_tflite = concrete_func_to_tflite_converter.convert()
with open('./tflite_models/quantized_concrete_func_tf_lite', 'wb') as f:
    f.write(concrete_func_tflite)

5.3 saved_model to tflite

saved_model_to_tflite_converter = tf.lite.TFLiteConverter.from_saved_model('./keras_saved_model/')
saved_model_to_tflite_converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
saved_model_tflite = saved_model_to_tflite_converter.convert()
with open('./tflite_models/quantized_saved_model_tflite', 'wb') as f:
    f.write(saved_model_tflite)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/383317.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spark graph基础(一)

1 overView 1.1 图的构成 图由节点和边组成,其中VertexRDD[VD] 和EdgeRDD[ED] 继承和优化了 RDD[(VertexId, VD)] 和RDD[Edge[ED]] 。 class Graph[VD, ED] {val vertices: VertexRDD[VD]val edges: EdgeRDD[ED] }1.2 图使用示例 如下图所示,使用spa…

Typro使用以及安装教程来啦

Typora是一款轻便简洁的Markdown编辑器,支持即时渲染技术,这也是与其他Markdown编辑器最显著的区别。即时渲染使得你写Markdown就想是写Word文档一样流畅自如,不像其他编辑器的有编辑栏和显示栏。今天为大家分享下有关Typroa的安装以及使用&a…

TryHackMe-黑我杯

黑我杯 相信我们大家在TryHackMe的日积月累都学到了不少东西,从纯萌新到oscp再到更高 我很高兴能将国内各thm玩家聚集到一起,构建一个更好的学习环境和氛围 本次娱乐分两场: Offensive Pentesting — 中等难度Junior Penetration — 容易难…

@Autowired和@Resource到底有什么区别

Autowired 和 Resource 都是 Spring/Spring Boot 项目中,用来进行依赖注入的注解。它们都提供了将依赖对象注入到当前对象的功能,但二者却有众多不同,并且这也是常见的面试题之一,所以我们今天就来盘它。 Autowired 和 Resource 的…

Linux 确认 NTP 是否同步成功

NTP 即Network Time Protocol,它通过网络同步计算机系统之间的时钟。NTP 客户端会将其时钟与 NTP 服务器同步。NTP同步状态可以通过以下三个命令查询:ntpq:ntpq 是标准的 NTP 查询程序。ntpstat:显示网络时间同步的状态。timedate…

【房间墙上凿个洞,看你在干嘛~】安全攻防内网渗透-绕过防火墙和安全检测,搭建DNS隐蔽隧道

作者:Eason_LYC 悲观者预言失败,十言九中。 乐观者创造奇迹,一次即可。 一个人的价值,在于他所拥有的。所以可以不学无术,但不能一无所有! 技术领域:WEB安全、网络攻防 关注WEB安全、网络攻防。…

Spark 广播/累加

Spark 广播/累加广播变量普通变量广播分布式数据集广播克制 Shuffle强制广播配置项Join Hintsbroadcast累加器Spark 提供了两类共享变量:广播变量(Broadcast variables)/累加器(Accumulators) 广播变量 创建广播变量…

快速上手配置firewalld

firewalld使用firewall-cmd命令配置策略。 查看当前firewalld当前服务运行状态 firewall-cmd --state firewalld防火墙状态还用使用如下命令查看状态 systemctl status firewalld 查看所有打开运行的端口 firewall-cmd --zonepublic --list-ports 查看区域信息情况 firewall…

qml学习之qwidget与qml结合使用并调用信号槽交互

学习qml系列之一说明: 学习qml系列之qwiget和qml信号槽的交互使用,并在qwidget中显示qml界面 在qml中发送信号到qwidget里 在qwidget里发送信号给qml 在qwidget里面调用qml界面方式 方式一:使用QQuickView 这个是Qt5.0中提供的一个类&…

小白量化《穿云箭集群量化》(5)抄底雷达策略

小白量化《穿云箭集群量化》(5)抄底雷达策略 雷达能够提前发现远处敌我动向。雷达是现代战争不可或缺的装备。 证券市场中分三类人,先知先觉者,后知后觉者,不知不觉者。先知先觉者往往是市场主力,他们拥有信…

Feign踩坑源码分析 -- 请求参数分号变逗号

一.案例 1.1.Post请求: http://localhost:8250/xx/task/test json格式参数: {"string": "a;b;c;d" } 1.2.controller代码: AutowiredDataSourceClientService dataSourceClientService;RequestMapping("/test"…

《计算机原理》——HelloWorld.cpp如何运行的

学校《计算机原理》开课啦!特此开辟专栏,将一些知识作为笔记,记录下来。 前言 本篇博客知识点来源于educoder的相关题目 1. 相关知识 1.1 计算机语言 计算机语言是人与计算机之间通讯的语言,计算机语言包括编写计算机程序的字符…

[MatLab]图像绘制

一、绘制二维图像 1.一张图上绘制一条线 绘制代码如下面所示: x 0:0.01:2*pi; y sin(x); figure %建立幕布 plot(x,y) %绘制图像 %设置图像属性 title(ysin(x)) xlabel(x) ylabel(y)xlim([0 2*pi]) %限制x轴的值域 自定义图线的颜色…

GB28181协议--SIP协议介绍

1、SIP协议简介 SIP(Session Initiation Protocol,会话初始协议)是一个用于建立、更改和终止多媒体会话的应用层控制协议,其中的会话可以是IP电话、多媒体会话或多媒体会议(GB28181安防使用的是SIP协议)。S…

lab备考第二步:HCIE-Cloud-Compute-第一题:FusionCompute

第一题 FusionCompute 一、题目介绍 1.1. 扩容CAN节点与对接共享存储(必选) 题目及【考生提醒关键点】 扩容一台CNA节点,配置管理地址设置为:192.168.100.212。密码设置为:Cloud12#$。【输入之前确认自己的大小写是否…

任务类风险漏洞挖掘思路

任务类风险定义: 大部分游戏都离不开任务,游戏往往也会借助任务,来引导玩家上手,了解游戏背景,增加游戏玩法,提升游戏趣味性。任务就像线索,将游戏的各个章节,各种玩法,…

docker上安装nacos

文章目录一、docker安装nacos简单版1.拉取镜像2、挂载目录,用于映射到容器,目录按自己的情况创建3、mysql新建nacos-config的数据库,并执行脚本 sql脚本地址如下:4、修改配置文件custom.properties5、启动容器6、访问二、docker安…

错误:PermissionError: [WinError 32] 另一个程序正在使用此文件,进程无法访问。“+文件路径“的解决方案

最近在使用python进行筛选图片的时候,想到用python里面的os库进行图片的删除。 具体筛选方法就是,删除掉图片长度或宽度小于100像素的图片,示例代码如下所示: for file in os.listdir(img_path):if file .split( . )[ - 1 ] j…

深度强化学习DLR

1 强化学习基础知识 强化学习过程:⾸先环境(Env)会给智能体(Agent)⼀个状态(State),智能体接收到环境给的观测值之后会做出⼀个动作(Action),环境接收到智能体给的动作之后会做出⼀系列的反应,例如对这个动作给予⼀个奖励(Reward…

射频功率放大器基于纵向导波的杆状构件腐蚀诊断方法的研究

实验名称:基于纵向导波的杆状构件腐蚀诊断方法研究方向:无损探伤测试设备:信号号发生器、安泰ATA-8202功率放大器、数据采集卡、直流电源、超声探头、钢杆、前置放大器。实验过程:图:试验装置试验装置如图3.2所示。监测…