30天从入门到精通TensorFlow1.x 第七天,TensorFlow1.x 模型save()和restore()

news2025/1/18 9:03:30

文章目录

  • 一、接前一天
  • 二、TensorFlow中模型的保存和加载方式
  • 三、模型的保存和加载`Save()`类
    • 1. 创建方法以及文件组成
    • 2. Saver类的重要参数
      • 参数:
      • `重要函数参数`:
    • 3. Saver类的主要使用函数
  • 四、keras的模型保存和加载
    • 1. Keras比较简单:一般有三种选择
    • 2. 区别
    • 3. 保存完整的好处
  • 五、基于简单线性回归的实践
    • 1. 实例化对象
    • 2. 保存
    • 3.加载
    • 4. 基于线性回归完整的保存代码
    • 5. 基于线性回归完整的加载使用代码

一、接前一天

前一天讨论了tensorboard今天讨论在TensorFlow1.x中的模型保存加载使用是在实际开发中我们不会简单的将模型上传服务器,开发api接口调用的,因为将模型放在服务器上,使用标准API接口文档编写也可以实现推理功能。但是,这种方式可能会涉及到一些额外的工作,例如:设置部署环境处理并发请求处理模型版本更新等问题。这样的工作需要进行更多的编程和配置,而且可能不够灵活和可扩展。
这里涉及的技术例如: TF服务等,今天就先不讨论,明天详细讨论这里。

二、TensorFlow中模型的保存和加载方式

  1. SavedModel 格式:SavedModel 是 TensorFlow 推荐的模型格式,可以将模型保存为一个目录,其中包含了模型结构、变量和运行时信息。SavedModel 可以跨平台、跨语言地加载模型,并支持部署到 TensorFlow Serving 等服务中。

  2. Checkpoint 文件:Checkpoint 文件是 TensorFlow 训练的中间结果,保存了所有变量和权重的值。可以使用 tf.train.Saver() 将变量保存到 Checkpoint 文件中,也可以使用 tf.train.latest_checkpoint() 加载最新的 Checkpoint 文件。

  3. Keras 模型文件:Keras 是 TensorFlow 的高级 API,可以使用 keras.models.save_model() 方法将 Keras 模型保存为 HDF5SavedModel 格式的文件。

  4. Protocol Buffer 文件:TensorFlow 使用 Protocol Buffer 格式来序列化和反序列化数据,包括模型结构、变量和运行时信息。可以使用 tf.io.write_graph() 和 tf.io.read_graph() 方法将模型保存为或加载为 GraphDef Protocol Buffer 文件。

我们只讨论 第二种第三种,以适用我们在不同场景中的使用。

三、模型的保存和加载Save()

Saver类的保存方法属于第二种方法,即Checkpoint文件。 Saver是TensorFlow提供的一个用于保存和恢复模型变量的类。它可以将训练过程中的模型参数保存到硬盘上,以便在需要的时候重新加载模型参数。

1. 创建方法以及文件组成

通过tf.train.Saver 类创建 saver对象

在这里插入图片描述

元计算图:计算图的协议缓冲区定义,扩展名为.meta

检查点:各种变量的值,包含两个文件,一个是扩展名为.index的,另一个是 文件扩展名为data-0000-of-00001

检查点(Checkpoint)是保存训练过程中模型的参数值的文件。当我们训练一个模型时,可以定期保存模型参数的当前值为一个检查点文件,以便在训练过程中出现故障或需要更改超参数时,能够从上一个检查点恢复训练,而不用重新开始训练。

通过使用检查点文件,我们可以将训练过程分成若干个阶段,每个阶段保存一个检查点文件,以便在训练过程中出现意外情况(如程序意外退出、计算机断电等)时能够快速恢复模型。检查点还可以用于在训练过程中选择最佳的模型。我们可以通过比较不同的检查点文件,选择验证集分类准确率最高或损失函数最小的模型作为最终结果。

2. Saver类的重要参数

参数:

var_list: 可选参数,表示需要保存或恢复的变量列表,默认为所有可训练变量。
reshape: 可选参数,如果为True,则载入模型时会尝试自动将变量reshape成原来的形状,默认为False。
sharded: 可选参数,如果为True,则在保存数据时将其分割成多个文件进行存储(仅当save_relative_paths=True时有效),默认为False。
max_to_keep: 可选参数,表示最多保存的检查点数量,默认为5。
keep_checkpoint_every_n_hours: 可选参数,表示每隔多少小时保存一个检查点,默认为10000个步骤保存一次。
name: 可选参数,表示Saver对象的名称。

重要函数参数

  1. save(session, save_path, global_step=None, latest_filename=None, meta_graph_suffix=‘meta’, write_meta_graph=True, write_state=True, strip_default_attrs=False, save_debug_info=False):
    保存模型参数到指定路径,其中
    session:表示当前Session,
    save_path:表示保存路径,
    global_step:表示当前模型训练的全局步数,
    latest_filename:表示保存最近检查点的文件名,默认为’checkpoint’,
    meta_graph_suffix:表示元图文件的后缀,默认为’meta’,
    write_meta_graph:表示是否写入元图(即计算图)文件,默认为True,
    write_state:表示是否写入变量文件,默认为True,
    strip_default_attrs:表示是否去除默认属性,默认为False,
    save_debug_info:表示是否保存调试信息,默认为False。
  2. restore(session, save_path, var_list=None, global_step=None, latest_filename=None, meta_graph_suffix=‘meta’, write_meta_graph=True, write_state=True, strip_default_attrs=False): 从指定路径恢复模型参数,其中
    session:表示当前Session,
    save_path:表示保存路径,
    var_list:表示需要恢复的变量列表,默认为所有可训练变量,
    global_step:表示需要恢复的全局步数,
    latest_filename:表示最近检查点的文件名,默认为’checkpoint’,
    meta_graph_suffix:表示元图文件的后缀,默认为’meta’,
    write_meta_graph:表示是否写入元图(即计算图)文件,默认为True,
    write_state:表示是否写入变量文件,默认为True,
    strip_default_attrs:表示是否去除默认属性,默认为False

TensorFlow 2.x中的Saver类被tf.train.Checkpoint取代了

3. Saver类的主要使用函数

  1. tf.train.Saver()
    该函数可以创建一个Saver对象,用于保存恢复TensorFlow模型。你可以使用saver.save()方法将模型保存到磁盘上,也可以使用saver.restore()方法从磁盘上恢复模型。

  2. saver.save(sess, save_path)
    该方法可以将当前会话(Session)的所有变量保存到磁盘上,其中参数sess是一个已经打开的会话,save_path是要保存的模型文件的路径。

  3. tf.train.import_meta_graph(meta_graph_def)
    该函数可以从.meta文件中导入图结构(Graph),其中meta_graph_def是一个包含图定义信息的.meta文件。

  4. saver.restore(sess, save_path)
    该方法可以从指定的模型文件中恢复所有变量,其中参数sess是一个已经打开的会话,save_path是之前保存的模型文件的路径。

  5. tf.get_default_graph()
    该函数可以获取默认的计算图(Graph),在加载模型时需要借助该函数来获取图中的所有节点和变量。

  6. graph.get_tensor_by_name(name)
    该方法可以根据节点名称获取计算图中的张量(Tensor)。

  7. graph.get_operation_by_name(name)
    该方法可以根据节点名称获取计算图中的操作(Operation)。

四、keras的模型保存和加载

1. Keras比较简单:一般有三种选择

  1. 仅保存模型结构(Architecture):使用 model.to_json() 方法将模型结构保存为 JSON 文件格式,使用 model_from_json() 方法加载模型结构。

  2. 保存模型结构和权重(Weights):使用 model.save_weights() 方法将模型权重保存为 HDF5 文件格式,使用 load_weights() 方法加载模型权重。此外,还可以使用 model.save() 方法保存整个模型,包括模型结构和权重也是以 HDF5 文件格式进行保存

  3. 保存完整的模型对象,包括模型结构权重编译信息等:使用 tf.keras.models.save_model() 方法将整个模型保存为 SavedModel 格式或者 HDF5 文件格式,使用 tf.keras.models.load_model() 方法加载模型。如果需要保存模型的编译信息,例如优化器、损失函数、评估指标等,则需要设置 save_format=‘tf’ 参数,以保存为 SavedModel 格式。

2. 区别

这三种方式的主要区别在于保存的内容不同。第一种方式只保存了模型的结构,无法直接使用;第二种方式保存了模型的权重,可以方便地用于继续训练或者对新数据进行预测;第三种方式保存了完整的模型对象,可以方便地复现模型,并且包含了模型的结构、权重和编译信息。

3. 保存完整的好处

第三种方式还支持保存为 SavedModel 格式,这是 TensorFlow 官方推荐的模型保存格式,可以方便地部署TensorFlow Serving 或者其他平台上。而 HDF5 文件格式则更加通用,可以在不同框架之间进行转换和共享。

五、基于简单线性回归的实践

1. 实例化对象

save = tf.train.Saver()

2. 保存

with tf.Session() as sess:
	...
	save.save(sess,'./')
'''
参数一:会话对象
参数二:保存的路径以及名字
'''

3.加载

with tf.Session() as sess:
	...
	save.restore(sess,'./')

4. 基于线性回归完整的保存代码

import tensorflow as tf
from tensorflow.python.client import device_lib
import os
print(device_lib.list_local_devices())




tf.reset_default_graph()

with tf.device("/device:cpu:0"):
    # create w and b init 0.0
    w = tf.Variable(0.0, name='weight')
    b = tf.Variable(0.0, name='bias')

    # create input and out
    x = tf.placeholder(dtype=tf.float32, shape=[None])
    out = tf.placeholder(dtype=tf.float32, shape=[None])

    # create loss and opt
    y = w * x + b

    loss = tf.reduce_mean(tf.square(y - out))
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss)

config = tf.ConfigProto()
config.log_device_placement=True

output_scalar = tf.reduce_mean(y)  # 将 y 变量转换为标量
loss_scalar = tf.reduce_mean(loss)
# 记录标量数据(输出结果和损失值)
tf.summary.scalar('output', output_scalar)
tf.summary.scalar('loss', loss_scalar)

# 记录张量数据(权重和偏置项)
tf.summary.histogram('weight', w)
tf.summary.histogram('bias', b)

# 记录网络结构
with tf.name_scope('hidden'):
    h = tf.nn.sigmoid(y)
    tf.summary.histogram('activation', h)

#实例化模型保存对象
save = tf.train.Saver()

# train model
with tf.Session(config=config) as sess:

    # 合并所有的摘要操作
    merged = tf.summary.merge_all()
    # 创建摘要写入器

    writer = tf.summary.FileWriter('./logss', sess.graph)

	#注意!!!在保存模型的时候 需要进行变量初始化
    tf.global_variables_initializer().run(session=sess)
    for i in range(100000):
        summary,_, loss_val, w_val, b_val = sess.run(
            [merged,train_op, loss, w, b],
            feed_dict={x: [1, 23, 4, 5, 7, 5, 7], out: [3, 5, 7, 9, 11, 13, 15]}
        ) #注意 输入的数据 形状要一致,避免输出与预测值得形状不一致问题
        if i % 100 == 0:
            print('Step {}: loss = {}, w = {}, b = {}'.format(i, loss_val, w_val, b_val))

            writer.add_summary(summary, i)


		#保存模型
    save_model_file = save.save(sess,'./model.ckpt')

5. 基于线性回归完整的加载使用代码

其实:加载模型进行预测就是使用训练的模型变量对会话进行初始化。一般我们训练的时候会对变量进行全局初始化,在加载模型的时候改为 训练好的模型变量。

import tensorflow as tf
from tensorflow.python.client import device_lib
import os
print(device_lib.list_local_devices())




tf.reset_default_graph()

with tf.device("/device:cpu:0"):
    # create w and b init 0.0
    w = tf.Variable(0.0, name='weight')
    b = tf.Variable(0.0, name='bias')

    # create input and out
    x = tf.placeholder(dtype=tf.float32, shape=[None])
    out = tf.placeholder(dtype=tf.float32, shape=[None])

    # create loss and opt
    y = w * x + b

    # loss = tf.reduce_mean(tf.square(y - out))
    # optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    # train_op = optimizer.minimize(loss)

config = tf.ConfigProto()
config.log_device_placement=True

#实例化保存对象
save = tf.train.Saver()

# train model
with tf.Session(config=config) as sess:


    save.restore(sess,'./')
    #加载模型的时候不需要进行变量初始化
    # tf.global_variables_initializer().run(session=sess)

    outs = sess.run(
        [y],
        feed_dict={x:[1,2,3]}
    ) #注意 输入的数据 形状要一致,避免输出与预测值得形状不一致问题

    print('outs;',outs)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/619662.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PLSQL - Months_Between的理解和使用

Months_Between是一个内置的日期操纵函数,用于计算两个日期相隔的月份数。Oracle文档的介绍如下: MONTHS_BETWEEN returns number of months between dates date1 and date2. The month and the last day of the month are defined by the parameter NL…

jenkins主从节点安装及pipeline构建

一、背景 通过Jenkins主节点配置的pipeline下发给从节点执行,从而兼容容器化执行 二、安装主节点 docker-compose.yml jenkins:user: rootrestart: alwaysimage: jenkinsci/blueoceancontainer_name: jenkins# network_mode: hostports:- "8081:8080"-…

BOS EDI Excel 方案简介

BOS EDI & Excel 方案简介 本文将继续分享BOS示例工作流:使用Excel端口和Email端口生成一系列文件,完成与BOS的EDI通信。 下载工作流 下载示例文件 BOS EDI到Excel示例流具有预配置的端口,用于从BOS的EDI集成规范转换以下交易集&…

2022年国赛高教杯数学建模D题气象报文信息卫星通信传输解题全过程文档及程序

2022年国赛高教杯数学建模 D题 气象报文信息卫星通信传输 原题再现 在某些紧急救援任务中,需要进行物资空投。在地面通信系统瘫痪的情形下,为了更好地获得准确完整的地面气象观测信息,通常对任务区域的重要目标点采用派遣气象分队的方式来获…

如果让你来设计CPU之内存篇

哈喽,我是子牙,一个很卷的硬核男人。深入研究Windows内核、Linux内核、Hotspot源码…聚焦做那些大家想学没地方学的课程:手写操作系统、手写虚拟机、手写模拟器、手写编程语言… 目前已经做了两个成熟的课程:手写JVM、手写OS&…

PPP认证协议详解

PPP认证协议详解 1. 引言 PPP(Point-to-Point Protocol)认证协议在计算机网络中扮演着重要的角色。它是一种用于建立和认证网络连接的协议,广泛应用于各种网络环境,包括互联网接入、虚拟专用网络(VPN)和远…

【头歌】试的学习

1.基本路径测试 2.画出程序控制流图 3.计算流图的环形复杂度 4.确定线性独立路径的基本集合 5.设计测试用例 基本路径测试 除了逻辑覆盖,还有一种常用的白盒测试的测试方法:基本路径测试。基本路径测试是 Tom McCabe提出的一种白盒测试技术。使用这种技…

c++学习——继承

继承 **继承****继承的案例****继承的三种方式方式&#xff1a;****继承中的对象类型****继承中的构造和析构顺序****继承中同名成员的处理****同名静态成员处理****多继承语法****菱形继承** 继承 普通的输出 #define _CRT_SECURE_NO_WARNINGS #include <iostream> us…

8. 让java性能提升的JIT深度解剖

JVM性能调优 1. C1、C2与Graal编译器1.1 C1编译器1.2 C2编译器1.3 分层编译 2. 热点代码3. 热点探测4. 方法调用计数器5. 回边计数器6. 编译优化技术6.1 方法内联 7. 锁消除8. 栈上分配9. 逃逸分析技术10. 标量替换 本文是按照自己的理解进行笔记总结&#xff0c;如有不正确的地…

【LeetCode热题100】打卡第14天:下一个排列最长有效括号

文章目录 【LeetCode热题100】打卡第14天&#xff1a;下一个排列&最长有效括号下一个排列⛅前言&#x1f512;题目&#x1f511;题解 最长有效括号&#x1f512;题目&#x1f511;题解 【LeetCode热题100】打卡第14天&#xff1a;下一个排列&最长有效括号 下一个排列 …

如何入门挖掘SRC?

挖洞其实算是web渗透中第一个明确的关卡 越过这个坎&#xff0c;从此天高任鸟飞&#xff0c;海阔凭鱼跃。越不过&#xff0c;就永远越不过。 先说平台&#xff1a; 漏洞响应平台&#xff1a;实战渗透测试&#xff0c;同时能获得一些外快。 补天漏洞响应平台&#xff1a;http…

Netty核心技术五--Netty高性能架构设计

1. 线程模型基本介绍 不同的线程模式&#xff0c;对程序的性能有很大影响&#xff0c;为了搞清Netty 线程模式&#xff0c;我们来系统的讲解下 各个线程模式&#xff0c; 最后看看Netty 线程模型有什么优越性.目前存在的线程模型有: 传统阻塞 I/O 服务模型Reactor 模式 根据 R…

郭光灿团队实现低温集成量子纠缠光源

中国科大郭光灿院士团队在集成化量子光源制备研究中取得重要进展。该团队任希锋研究组基于低温集成自发四波混频过程&#xff0c;展示了低温条件下集成量子纠缠光源的制备&#xff0c;相关成果于6月2日发表在光学知名学术期刊Optica上。 “利用低温综合四波混合技术产生纠缠现象…

Mapbox表达式详细解读

初学mapbox 的小伙伴们一定会被表达式给弄的晕头转向的。明明条件判断或者回调函数能解决的问题。mapbox里非得让你用表达式。这确实比较ex。 不过我们既然遇到了,也不要怕,这篇文章我就带着大家一点一点的搞明白这个所谓的表达式。 首先从宏观上讲,要知道为什么使用表达式…

【面试高频】cookie、session、token?看完再也不担心被问了

在以往的面试记录里&#xff0c;我又看到了一个多次被问到的知识点&#xff0c;那就是 cookie、session、token 的区别有哪些&#xff1f;如果现在来问你&#xff0c;不知道你能否说清楚呢&#xff1f; 今天不仅仅是整理出这三者的区别&#xff0c;更重要的是能够真正去理解这三…

Python | print写入日志

Python | print写入日志 有时我们需要将屏幕上打印的消息保存到一个文件中&#xff0c;如果每条信息都通过调用写入函数来实现&#xff0c;就太麻烦了 这里自己定义1个日志类&#xff0c;然后将 sys.stdout 设置为该类即可&#xff0c;非常方便 sys.stdout Logger(fileName …

卡尔曼滤波与组合导航原理(八)遗忘滤波

函数模型 { X k Φ k l k − 1 X k − 1 Γ k − 1 W k − 1 Z k H k X k V k \left\{\begin{array}{l} \boldsymbol{X}_{k}\boldsymbol{\Phi}_{k l k-1} \boldsymbol{X}_{k-1}\boldsymbol{\Gamma}_{k-1} \boldsymbol{W}_{k-1} \\ \boldsymbol{Z}_{k}\boldsymbol{H}_{k} \…

C语言:使用 普通方法 和 二分查找算法(折半查找算法) 在一个有序数组中查找具体的某个数字n

题目&#xff1a; 从键盘输入数字n&#xff0c;在一个 有序数组 中查找具体的某个数字n。 思路一&#xff1a;普通方法 &#xff08;逻辑简单&#xff0c;在无序数组中也可以使用&#xff0c;但效率较低&#xff0c;需要逐个查找&#xff09; 总体思路&#xff1a; &#xff…

日常培训管理-参训名单/BootstrapTable获取表数据 / js 删除两个数组中id相同的对象/

---2022.11.9 1、 现在有一个功能是从下面待选名单中选中&#xff0c;再点击这个添加按钮&#xff0c;就会将这些人添加到上面这个参训名单&#xff0c;然后再给其中每个人手动打分。分打完 BootstrapTable中有两组数据&#xff0c;在下面待选名单数据条目前面中打钩选中&am…

从零开始学习CTF——CTF基本概念

这一系列把自己学习的CTF的过程详细写出来&#xff0c;方便大家学习时可以参考。 一、CTF简介 01」简介 中文一般译作夺旗赛&#xff08;对大部分新手也可以叫签到赛&#xff09;&#xff0c;在网络安全领域中指的是网络安全技术人员之间进行技术竞技的一种比赛形式。 CTF…