【NLP】使用 Keras 保存和加载深度学习模型

news2025/1/7 10:31:06

一、说明

        训练深度学习模型是一个耗时的过程。您可以在训练期间和训练后保存模型进度。因此,您可以从上次中断的地方继续训练模型,并克服漫长的训练挑战。

        在这篇博文中,我们将介绍如何保存模型并使用 Keras 逐步加载它。我们还将探索模型检查点回调,它通常用于模型训练。

二、加载数据集

        为了演示如何保存模型,让我们使用 MNIST 数据集。此数据集由数字图像组成。

MNIST 数据集

        在加载 MNIST 数据集之前,让我们先导入 TensorFlow 和 Keras。

import tensorflow as tf
from tensorflow import keras

现在,让我们使用 Keras 中的方法加载训练和测试数据集。load_data

(train_images,train_labels),(test_images,test_labels)=tf.keras.datasets.mnist.load_data()

        训练输入和输出数据集由 60,000 个样本组成,测试输入和输出数据集由 10,000 个样本组成。

三、数据预处理

        数据分析最重要的步骤之一是数据预处理。在深度学习中,一些数据预处理技术(如规范化和正则化)可以提高模型的性能。

        首先,让我们从这些数据集中获取前 1000 个样本,以更快地运行代码。让我们先对输出变量执行此操作。

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

我们将对训练数据执行相同的操作。数据样本由数字图像组成。这些图像是二维的。在将这些示例提供给模型之前,让我们使用该方法将它们转换为维度。此外,让我们规范化数据以提高模型的性能并使训练速度更快。reshape

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

很好,我们的数据集已经为模型做好了准备。让我们继续前进到模型构建步骤。

四、如何构建模型

        构建深度学习模型的最简单方法是 Keras 中的顺序技术。在这种技术中,层被逐个堆叠。我们现在要做的是定义一个包含模型的函数。通过这样做,我们可以更轻松地构建模型。

def create_model():
    model = tf.keras.Sequential([
      keras.layers.Dense(512, activation='relu',input_shape=(784,)),
      keras.layers.Dropout(0.2),
      keras.layers.Dense(10)
    ])
 
    model.compile(
      optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )
    return model

让我们来看看这些代码。首先,我们定义一个使用 Keras 创建和编译顺序模型的函数。

我们构建了一个包含两个密集层的模型,第一个层具有512个神经元和一个激活函数。我们还设置了一个 dropout 层,该层随机丢弃 20% 的输入单元,以帮助防止过度拟合。之后,我们编写了一个包含 10 个没有激活函数的神经元的密集层,因为它将用于 logits。relu

接下来,我们使用优化器和损失函数编译模型。作为指标,我们设置 .AdamSparseCategoricalCrossentropySparseCategoricalAccuracy

最后,我们使用语句返回已编译的模型。return

太棒了,我们已经定义了一个简单的顺序模型。现在,让我们从这个函数中获取一个名为 model 的示例对象。

model = create_model()

现在让我们看看使用摘要方法的模型的体系结构。

model.summary()

如您所见,我们的模型由输入层、辍学层和输出层组成。让我们继续探索回调。ModelCheckpoint

五、使用模型检查点回调保存模型权重

        可以保存模型以重用已训练的模型,或从上次中断的位置继续训练。

        如您所知,构建模型实际上意味着训练模型的权重,称为参数。通过回调,您可以在模型训练期间保存模型的权重。为了说明这一点,让我们从这个回调实例化一个对象。ModelCheckpoint

        首先,让我们创建模型将与 os 模块一起保存的目录。

import os
checkpoint_path = "training_1/my_checkpoints"
checkpoint_dir = os.path.dirname(checkpoint_path)

很好,我们已经创建了目录。现在让我们创建一个回调来保存模型的权重。

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
     filepath=checkpoint_path,
     # Let's save only the weights of the model
     save_weights_only=True)

        太好了,我们已经创建了回调。现在,让我们调用该方法和对此方法的回调。fitpass

model.fit(train_images,
   train_labels,
   epochs=10,
   validation_data=(test_images, test_labels),
   callbacks=[checkpoint_cb])

        因此,我们将模型权重保存在目录中。我们使用的回调在每个纪元结束时更新检查点文件。让我们使用 os 模块查看目录中的文件。

os.listdir(checkpoint_dir)

# Output
['my_checkpoints.index', 'my_checkpoints.data-00000-of-00001', 'checkpoint']

如您所见,权重是在最后一个纪元之后保存的。让我们继续看看如何加载重量。

六、装载权重

        保存权重后,可以将其加载到模型中。请注意,您只能将保存的权重用于具有相同体系结构的模型。

        让我们实例化一个对象来演示这一点。

model = create_model()

        请注意,我们尚未训练此模型的权重。这些权重是随机生成的。现在让我们看看这个未经训练的模型在测试数据上的准确性分数。evaluate

loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Untrained model, accuracy: {100 * acc:5.2f}%")

# Output:
Untrained model, accuracy: 10.70%

        如您所见,未经训练的模型在测试数据上的准确率约为 10%。这是一个相当糟糕的分数,对吧?

        现在,让我们加载之前使用该方法保存的权重,然后查看此模型在测试数据上的准确性得分。load_weights

model.load_weights(checkpoint_path)

        太棒了,我们加载了重量。现在,让我们检查此模型在测试集上的性能。

loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print(f"Untrained model, accuracy: {100 * acc:5.2f}%")

#Output:
Untrained model, accuracy: 87.40%

        如您所见,事实证明该模型的准确率约为90%。

        在本节中,我们已经了解了如何保存模型权重以及如何加载它们。现在让我们继续探索如何保存整个模型。

七、保存整个模型

q        训练模型后,可能需要部署该模型。若要将模型的体系结构、权重和训练配置保存在单个文件中,可以使用该方法。save

        您可以将模型保存为两种不同的格式,以及 .请记住,在 Keras 中,默认情况下使用该格式。让我们保存最终模型。让我们为它创建一个目录。SaveModelHDF5SavedModel

mkdir saved_model

        现在让我们将模型保存在此文件中。

model.save('saved_model/my_model')

        太好了,我们保存了我们的模型。让我们看一下此目录中的文件。

ls saved_model/my_model

# Output:
assets fingerprint.pb keras_metadata.pb saved_model.pb variables

        在这里,您可以看到文件和子目录。不需要模型的源代码即可将模型投入生产。 足以进行部署。让我们仔细看看这些文件。SavedModel

        该文件包含模型的体系结构和计算图形。saved_model.pd

        该文件包含 Keras 所需的额外信息。keras_metadata.pb

        子目录包含权重和偏差等参数值。variables

        子目录包含额外的文件,例如属性和类的名称。assets

        很好,我们看到了如何保存整个模型。现在让我们看看如何加载模型。

八、加载模型

        您可以使用该方法加载保存的模型。为此,让我们首先创建模型体系结构,然后加载模型。load_model

new_model = create_model()
new_model = tf.keras.models.load_model('saved_model/my_model')

太棒了,我们已经加载了模型。让我们看一下这个模型的架构。

new_model.summary()

        请注意,此模型是使用与原始模型相同的参数编译的。让我们看看这个模型在测试数据上的准确性。

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print(f'Restored model, accuracy: {100 * acc:5.2f}%')

# Output:
Restored model, accuracy: 87.40%

        如您所见,我们保存在测试数据上的模型的准确性得分为 87%。

        您还可以以格式保存模型。但是大多数TensorFlow部署工具都需要这种格式。h5SavedModel

九、总结

        训练模型时,可以保存模型以从上次中断的地方继续。通过保存模型,您还可以共享您的模型并允许其他人重新创建您的工作。

        在这篇博文中,我们介绍了如何保存和加载深度学习模型。首先,我们学习了如何使用回调保存模型权重。接下来,我们看到了保存和加载整个模型以部署模型。ModelCheckpoint

感谢您的阅读。您可以在此处找到笔记的链接。

参考资料:

如何使用 Keras 保存和加载深度学习模型? |迈向人工智能 (towardsai.net)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/781720.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

虹科活动 | 虹科ADAS自动驾驶研讨会

​​虹科ADAS/自动驾驶研讨会将于8月7日在上海闵行展开——加快ADAS/AD开发步伐! 期待您的参与!

Day45: 300.最长递增子序列,674. 最长连续递增序列,718. 最长重复子数组

目录 300.最长递增子序列 思路 674. 最长连续递增序列 思路 718. 最长重复子数组 思路 300.最长递增子序列 300. 最长递增子序列 - 力扣(LeetCode) 思路 1. 确定dp数组及其下标含义 dp[i]表示i之前包括i的以nums[i]结尾的最长递增子序列…

【每日运维】判断服务器时间同步是否正常

对于 ntpd 服务 ntpq -premote:时间同步源的 IP 地址或域名refid:参考 ID,它是一个代表时间源的唯一标识符st:层级,表示时间同步源的层级关系。较低的层级意味着更接近原子钟的时间源t:状态,表…

uni-app:script中设置的data,在界面的显示(包含图片src为data中的数据该如何展示),以及控制台的输出

样式&#xff1a; 两个图标的区别&#xff1a; 第一个图标是图片文件直接在文件夹static中展示 前台代码展示&#xff1a; <image class"logo" src"/static/logo.png"></image> 第二个图标是从服务器端进行的引用 在script中的data中进行的设…

【C++修炼之路】stl 中的容器适配器

&#x1f451;作者主页&#xff1a;安 度 因 &#x1f3e0;学习社区&#xff1a;StackFrame &#x1f4d6;专栏链接&#xff1a;C修炼之路 文章目录 一、stack二、queue三、deque四、priority_queue1、仿函数2、实现 如果无聊的话&#xff0c;就来逛逛 我的博客栈 吧! &#x1…

从新手到大师:优雅的Vim熟练之旅(万文详解)

从新手到大师&#xff1a;优雅的Vim熟练之旅 博主简介一、前言1.1、Vim编辑器的重要性和流行性1.2、目标 二、Vim简介2.1、什么是Vim2.2、历史和背景简介2.3、Vim的优势和适用场景 三、安装和设置Vim3.1、下载和安装Vim编辑器3.2、基本配置&#xff1a;.vimrc文件的重要性和常用…

Spinger ESE独立出版|2023年第二届能源与环境工程国际会议(CFEEE 2023)

会议简介 Brief Introduction 2023年第二届能源与环境工程国际会议(CFEEE 2023) 会议时间&#xff1a;2023年9月1日-3日 召开地点&#xff1a;中国三亚 大会官网&#xff1a;CFEEE 2023-2023 International Conference on Frontiers of Energy and Environment Engineering 由I…

leetcode 491. 递增子序列

2023.7.23 本题本质上也是要选取递归树中的满足条件的所有节点&#xff0c;而不是选取叶子节点。 故在将符合条件的path数组放入ans数组后&#xff0c;不要执行return。 还一点就是这个数组不是有序的&#xff0c;并且也不能将它有序化&#xff0c;所以这里的去重操作不能和之前…

MyBatis框架提供的分页助手插件pagehelper

使用MyBatis框架提供的分页助手插件可以很方便地实现分页查询。以下是一个基于MyBatis分页助手插件完成分页查询的示例&#xff1a; 1.首先&#xff0c;确保在项目的依赖中添加了MyBatis分页助手插件的依赖&#xff0c;例如&#xff1a; <dependency><groupId>co…

【C语言】getchar和putchar函数详解:字符输入输出的利器

目录 &#x1f4cc;getchar函数 ▪️ 函数原型&#xff1a; ▪️ 目的&#xff1a; ▪️ 返回值&#xff1a; ▪️ 用法&#xff1a; &#x1f4cc;putchar函数 ▪️ 函数原型&#xff1a; ▪️ 目的&#xff1a; ▪️ 参数&#xff1a; ▪️ 返回值&#xff1a; ▪…

20 QTreeWidget控件

代码&#xff1a; //treeWidget树控件//1&#xff1a;设置头部标签 QStringList()匿名对象创建ui->treeWidget->setHeaderLabels(QStringList()<<"英雄"<<"英雄介绍");//2&#xff1a;设置itemQTreeWidgetItem * liItem new QTreeWidg…

刘铁猛C#教程笔记——操作符

C#语言中的操作符 表中位于同一行的操作符优先级相同&#xff0c;从上到下优先级依次减弱&#xff1b; 操作符的用法举例 成员访问运算符——“.”&#xff1a;用于访问类中的成员或者访问位于某个名空间中的类&#xff0c;如&#xff1a; using System; using System.Collec…

Unity进阶--fsm状态机的使用笔记

文章目录 Unity进阶--fsm状态机的使用笔记第一种用基础的if播放实现动画控制switch--case实现状态机使用状态机 Unity进阶–fsm状态机的使用笔记 第一种用基础的if播放实现动画控制 朴实无华&#xff0c;简单易懂&#xff0c;但是耦合性太差。 switch–case实现状态机 写对应…

【JAVA】云HIS系统功能菜单知识(二)

随着医疗信息化和互联网技术的不断发展&#xff0c;云HIS在大数据管理和应用的优势日益凸显。对于医疗机构而言&#xff0c;云HIS平台可以帮助其实现更高效的医疗服务管理&#xff0c;并提高医疗服务的整体水平和效率。 一、系统管理 1.医院信息 基本信息、法人代表、主要负责…

IntelliJ IDEA2023中利用maven-archetype-quickstart模板创建项目无src文件夹及maven插件下载过慢问题的解决

目录 介绍问题之解决问题2的解决问题1的解决 介绍 昨天下载并安装了IntelliJ IDEA 2023的最新版&#xff08;以下简称为IDEA 2023&#xff09;&#xff0c;学习利用该IDE编写Java项目及将其与maven结合构建项目。我所安装的maven是去年暑假安装的&#xff0c;版本为Apache Mav…

【架构基础】架构概念

软件架构产生的背景 1972年图灵奖获得者、荷兰计算机科学家Edsger Wybe Dijkstra早在20世纪60年代就开始涉及软件架构概念了。 20世纪60年代第一次软件危机引出了结构化编程&#xff0c;创造了模块的概念。 20世纪80年代第二次软件危机引出了面向对象编程&#xff0c;创造了…

Flask的send file和send_from_directory的区别

可以自行查看flask 文档。 send file高效&#xff1b; send from directory安全&#xff0c;且适用于静态资源交互。 都是实现相同的功能的。 send_file send_from_directory

所有docker命令无效,解决办法

目录 ■前言 今天使用docker时&#xff0c;所有命令无效 ■解决办法如下 1.停止docker服务 2.查看状态 3.删除之前的docker相关的文件 4.再次查看状态 5.使用相关命令 &#xff08;好用了&#xff09; 6.重新下载镜像 ■前言 今天使用docker时&#xff0c;所有命令无…

MySQL 8.0 OCP (1Z0-908) 考点精析-备份与恢复考点1:MySQL Enterprise Backup概要

文章目录 MySQL 8.0 OCP (1Z0-908) 考点精析-备份与恢复考点1&#xff1a;MySQL Enterprise Backup概要MySQL Enterprise Backup下载与安装MySQL Enterprise Backup的备份过程MySQL Enterprise Backup的优势mysqlbackup 客户端例题例题1 &#xff1a; MySQL Enterprise Backup概…

idea的插件FastRequest,比postman更好用

1.安装插件Restful Fast Request 在插件plugin中直接搜索Restful Fast Request,然后点击install安装 2.使用插件 插件位置在右面&#xff0c;点开后呈现以下页面 配置项目名和环境 选择配置好的项目名和环境 启动项目后可以看到接口的小火箭&#xff0c;点击小火箭 3.…