政安晨:【Keras机器学习实践要点】(十)—— 自定义保存和序列化

news2024/11/24 19:57:36

目录

导言

涵盖的API

Setup

状态保存自定义

构建和编译保存自定义

结论


政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras机器学习实战

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

相较于上一篇文章而言,这是一篇更高级的图层和模型自定义保存指南。

导言

本文涵盖可在 Keras 保存中自定义的高级方法。对于大多数用户来说,上文的初级序列化、保存和导出文章中概述的方法已经足够

涵盖的API

我们将介绍以下应用程序接口:

  • save_assets() 和 load_assets()
  • save_own_variables() 和 load_own_variables()
  • get_build_config() 和 build_from_config()
  • get_compile_config() 和 compile_from_config()

还原模型时,这些操作将按以下顺序执行:

  • build_from_config()
  • compile_from_config()
  • load_own_variables()
  • load_assets()

Setup

import os
import numpy as np
import keras

状态保存自定义

这些方法决定了调用 model.save() 时如何保存模型图层的状态。您可以重载这些方法来完全控制状态保存过程。

save_own_variables() & load_own_variables()

这些方法分别在调用 model.save() 和 keras.models.load_model() 时保存和加载层的状态变量。默认情况下,保存和加载的状态变量是层的权重(可训练和不可训练)。

以下是 save_own_variables() 的默认实现:

def save_own_variables(self, store):
    all_vars = self._trainable_weights + self._non_trainable_weights
    for i, v in enumerate(all_vars):
        store[f"{i}"] = v.numpy()

这些方法使用的存储空间是一个字典,可以填充层变量。

下面我们来看一个自定义的示例:

示例:

@keras.utils.register_keras_serializable(package="my_custom_package")
class LayerWithCustomVariable(keras.layers.Dense):
    def __init__(self, units, **kwargs):
        super().__init__(units, **kwargs)
        self.my_variable = keras.Variable(
            np.random.random((units,)), name="my_variable", dtype="float32"
        )

    def save_own_variables(self, store):
        super().save_own_variables(store)
        # Stores the value of the variable upon saving
        store["variables"] = self.my_variable.numpy()

    def load_own_variables(self, store):
        # Assigns the value of the variable upon loading
        self.my_variable.assign(store["variables"])
        # Load the remaining weights
        for i, v in enumerate(self.weights):
            v.assign(store[f"{i}"])
        # Note: You must specify how all variables (including layer weights)
        # are loaded in `load_own_variables.`

    def call(self, inputs):
        dense_out = super().call(inputs)
        return dense_out + self.my_variable


model = keras.Sequential([LayerWithCustomVariable(1)])

ref_input = np.random.random((8, 10))
ref_output = np.random.random((8, 10))
model.compile(optimizer="adam", loss="mean_squared_error")
model.fit(ref_input, ref_output)

model.save("custom_vars_model.keras")
restored_model = keras.models.load_model("custom_vars_model.keras")

np.testing.assert_allclose(
    model.layers[0].my_variable.numpy(),
    restored_model.layers[0].my_variable.numpy(),
)

执行结果:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 101ms/step - loss: 0.2908

save_assets() 和 load_assets()


这些方法可以添加到模型类定义中,以存储和加载模型所需的任何附加信息。

例如,文本矢量化层(TextVectorization layer)和索引查找层(IndexLookup layer)等 NLP 领域层可能需要在保存时将其相关词汇(或查找表)存储到文本文件中。

让我们用一个简单的 assets.txt 文件来了解一下这个工作流程的基本情况。

示例:

@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomAssets(keras.layers.Dense):
    def __init__(self, vocab=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.vocab = vocab

    def save_assets(self, inner_path):
        # Writes the vocab (sentence) to text file at save time.
        with open(os.path.join(inner_path, "vocabulary.txt"), "w") as f:
            f.write(self.vocab)

    def load_assets(self, inner_path):
        # Reads the vocab (sentence) from text file at load time.
        with open(os.path.join(inner_path, "vocabulary.txt"), "r") as f:
            text = f.read()
        self.vocab = text.replace("<unk>", "little")


model = keras.Sequential(
    [LayerWithCustomAssets(vocab="Mary had a <unk> lamb.", units=5)]
)

x = np.random.random((10, 10))
y = model(x)

model.save("custom_assets_model.keras")
restored_model = keras.models.load_model("custom_assets_model.keras")

np.testing.assert_string_equal(
    restored_model.layers[0].vocab, "Mary had a little lamb."
)

构建和编译保存自定义

get_build_config() 和 build_from_config()

这些方法可共同保存图层的构建状态,并在加载时恢复这些状态。

默认情况下,这只包括一个包含图层输入形状的构建配置字典,但重载这些方法可用于包含更多变量和查找表,这些变量和查找表对于恢复构建的模型非常有用。

示例:

@keras.saving.register_keras_serializable(package="my_custom_package")
class LayerWithCustomBuild(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def call(self, inputs):
        return keras.ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return dict(units=self.units, **super().get_config())

    def build(self, input_shape, layer_init):
        # Note the overriding of `build()` to add an extra argument.
        # Therefore, we will need to manually call build with `layer_init` argument
        # before the first execution of `call()`.
        super().build(input_shape)
        self._input_shape = input_shape
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer=layer_init,
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=layer_init,
            trainable=True,
        )
        self.layer_init = layer_init

    def get_build_config(self):
        build_config = {
            "layer_init": self.layer_init,
            "input_shape": self._input_shape,
        }  # Stores our initializer for `build()`
        return build_config

    def build_from_config(self, config):
        # Calls `build()` with the parameters at loading time
        self.build(config["input_shape"], config["layer_init"])


custom_layer = LayerWithCustomBuild(units=16)
custom_layer.build(input_shape=(8,), layer_init="random_normal")

model = keras.Sequential(
    [
        custom_layer,
        keras.layers.Dense(1, activation="sigmoid"),
    ]
)

x = np.random.random((16, 8))
y = model(x)

model.save("custom_build_model.keras")
restored_model = keras.models.load_model("custom_build_model.keras")

np.testing.assert_equal(restored_model.layers[0].layer_init, "random_normal")
np.testing.assert_equal(restored_model.built, True)

get_compile_config() 和 compile_from_config()


这些方法可共同保存编译模型时使用的信息(优化器、损耗等),并使用这些信息恢复和重新编译模型。

重载这些方法对于用自定义优化器、自定义损耗等编译恢复后的模型非常有用,因为在调用 compile_from_config() 中的 model.compile 之前需要对这些信息进行反序列化。

下面我们来看一个例子。

@keras.saving.register_keras_serializable(package="my_custom_package")
def small_square_sum_loss(y_true, y_pred):
    loss = keras.ops.square(y_pred - y_true)
    loss = loss / 10.0
    loss = keras.ops.sum(loss, axis=1)
    return loss


@keras.saving.register_keras_serializable(package="my_custom_package")
def mean_pred(y_true, y_pred):
    return keras.ops.mean(y_pred)


@keras.saving.register_keras_serializable(package="my_custom_package")
class ModelWithCustomCompile(keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense1 = keras.layers.Dense(8, activation="relu")
        self.dense2 = keras.layers.Dense(4, activation="softmax")

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    def compile(self, optimizer, loss_fn, metrics):
        super().compile(optimizer=optimizer, loss=loss_fn, metrics=metrics)
        self.model_optimizer = optimizer
        self.loss_fn = loss_fn
        self.loss_metrics = metrics

    def get_compile_config(self):
        # These parameters will be serialized at saving time.
        return {
            "model_optimizer": self.model_optimizer,
            "loss_fn": self.loss_fn,
            "metric": self.loss_metrics,
        }

    def compile_from_config(self, config):
        # Deserializes the compile parameters (important, since many are custom)
        optimizer = keras.utils.deserialize_keras_object(config["model_optimizer"])
        loss_fn = keras.utils.deserialize_keras_object(config["loss_fn"])
        metrics = keras.utils.deserialize_keras_object(config["metric"])

        # Calls compile with the deserialized parameters
        self.compile(optimizer=optimizer, loss_fn=loss_fn, metrics=metrics)


model = ModelWithCustomCompile()
model.compile(
    optimizer="SGD", loss_fn=small_square_sum_loss, metrics=["accuracy", mean_pred]
)

x = np.random.random((4, 8))
y = np.random.random((4,))

model.fit(x, y)

model.save("custom_compile_model.keras")
restored_model = keras.models.load_model("custom_compile_model.keras")

np.testing.assert_equal(model.model_optimizer, restored_model.model_optimizer)
np.testing.assert_equal(model.loss_fn, restored_model.loss_fn)
np.testing.assert_equal(model.loss_metrics, restored_model.loss_metrics)

执行如下:

 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 79ms/step - accuracy: 0.0000e+00 - loss: 0.0627 - mean_metric_wrapper: 0.2500

结论

使用本文中学到的方法可以实现多种使用情况,保存和加载具有特殊资产和状态元素 的复杂模型。

总结一下:

save_own_variables 和 load_own_variables 决定了保存和加载状态的方式。
save_assets 和 load_assets 可用于存储和加载模型所需的任何附加信息。
get_build_config 和 build_from_config 用于保存和恢复模型的构建状态。
get_compile_config 和 compile_from_config 保存和恢复模型的编译状态。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1557857.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024年京东云主机租用价格_京东云服务器优惠价格表

2024年京东云服务器优惠价格表&#xff0c;轻量云主机优惠价格5.8元1个月、轻量云主机2C2G3M价格50元一年、196元三年&#xff0c;2C4G5M轻量云主机165元一年&#xff0c;4核8G5M云主机880元一年&#xff0c;游戏联机服务器4C16G配置26元1个月、4C32G价格65元1个月、8核32G费用…

速腾聚创上市后首份财报:冲击年销百万台,押注人形机器人

作者 |老缅 编辑 |德新 港股「激光雷达第一股」速腾聚创&#xff0c;交出了上市后的首份业绩报告。 3月27日&#xff0c;速腾聚创发布了2023年度财报。 报告期内&#xff0c;公司迎来高速的业务增长——2023年总收入达到人民币11.2亿元&#xff0c;同比增长达到111.2%。这主…

Artplayer视频JSON解析播放器源码|支持弹幕|json数据模式

全开源Artplayer播放器视频解析源码&#xff0c;支持两种返回模式&#xff1a;网页播放模式、json数据模式&#xff0c;json数据模式支持限制ip每分钟访问次数UA限制key密钥&#xff0c;也可理解为防盗链 &#xff0c;本播放器带弹幕库。 运行环境 推荐使用PHP8.0 redis扩展…

书生 浦语大模型全链路开源体系

通用大模型成为发展通用人工智能的重要途径 书生 浦语大模型的开源历程 书生 浦语 2.0体系&#xff0c;面向不同的使用需求&#xff0c;每个规格包含三个模型版本&#xff0c;&#xff08;7B、20B&#xff09;InternLM2-Base、InternLM2、InternLM2-Chat。 大模型是回归语言建…

前缀树/字典树Trie

目录 一、Trie的数据结构 二、代码示例 一、Trie的数据结构 Tire通常包括&#xff1a; 1.root节点(根节点)&#xff1a;插入、查找、删除、遍历等操作从root节点开始. 2.flag&#xff1a;结束标志true/false&#xff0c;用于表示当前节点是否为一个完整的字符串的结尾. 3.ke…

第几个幸运数字(蓝桥杯)

文章目录 第几个幸运数字题目描述答案&#xff1a;1905生成法C代码代码详细注释代码思路解释 第几个幸运数字 题目描述 本题为填空题&#xff0c;只需要算出结果后&#xff0c;在代码中使用输出语句将所填结果输出即可。 到x星球旅行的游客都被发给一个整数&#xff0c;作为…

软考高级架构师:信息安全保护等级

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

二十四种设计模式与六大设计原则(三):【装饰模式、迭代器模式、组合模式、观察者模式、责任链模式、访问者模式】的定义、举例说明、核心思想、适用场景和优缺点

接上次博客&#xff1a;二十四种设计模式与六大设计原则&#xff08;二&#xff09;&#xff1a;【门面模式、适配器模式、模板方法模式、建造者模式、桥梁模式、命令模式】的定义、举例说明、核心思想、适用场景和优缺点-CSDN博客 目录 装饰模式【Decorator Pattern】 定义…

Android MediaPlayer

MediaPlayer 类是媒体框架最重要的组成部分之一。此类的对象能够获取、解码以及播放音频和视频&#xff0c;而且只需极少量设置。它支持多种不同的媒体源&#xff0c;例如&#xff1a; • 本地资源 • 内部 URI&#xff0c;例如您可能从内容解析器那获取的 URI • 外部网址…

idea从零开发Android 安卓 (超详细)

首先把所有的要准备的说明一下 idea 2023.1 什么版本也都可以操作都是差不多的 gradle 8.7 什么版本也都可以操作都是差不多的 Android SDK 34KPI 下载地址&#xff1a; AndroidDevTools - Android开发工具 Android SDK下载 Android Studio下载 Gradle下载 SDK Tools下载 …

智慧水利中数据可视化的关键作用

在当今这个数据驱动的时代&#xff0c;数据可视化已成为转化复杂数据集为易于理解的视觉格式的关键技术&#xff0c;它在智慧水利领域的应用尤为显著。智慧水利利用现代信息技术&#xff0c;整合水资源管理的各个方面&#xff0c;旨在提高水资源的使用效率和管理效能。数据可视…

Linux基础篇:VMware虚拟机3种常用的网络模式介绍

VMware虚拟机3种常用的网络模式介绍 VMware虚拟机提供了几种不同的网络连接模式&#xff0c;以满足不同场景下的网络需求。以下是VMware虚拟机的三种主要网络模式&#xff1a; 1.桥接模式&#xff08;Bridged Mode&#xff09;网卡名称VMnet0 桥接模式允许虚拟机直接连接到物…

Linux——将云服务器作为跳板机,frp实现内网穿透

文章目录 操作步骤1. 准备工作&#xff1a;2. 配置frp服务器端&#xff1a;3. 配置frp客户端&#xff1a;4. 启动frp客户端&#xff1a;5. 测试连接&#xff1a;6. 安全注意事项&#xff1a; 云服务器性能分析阿里云具体操作步骤1. 购买&#xff1a;2. 登录&#xff1a;3. 首次…

Transformer论文阅读

Transformer论文阅读 摘要结论1 Introduction &#xff08;导言&#xff09;2 Background3 Model Architecture3.1 Encoder and Decoder StacksEncoderLayer NormDecoder 3.2 Attention3.2.1 Scaled Dot-Product Attention3.2.2 Scaled Dot-Product Attention3.2.3 Application…

HAProxy + Vitess负载均衡

一、环境搭建 Vitess环境搭建&#xff1a; 具体vitess安装不再赘述&#xff0c;主要是需要启动3个vtgate&#xff08;官方推荐vtgate和vtablet数量一致&#xff09; 操作&#xff1a; 在vitess/examples/common/scripts目录中&#xff0c;修改vtgate-up.sh文件&#xff0c;…

嵌入式Qt 布局管理器QBoxLayout

一.存在问题 二.布局管理器 三.布局接口函数的使用 TestBtn1.setText("Test Button 1"); TestBtn1.setSizePolicy(QSizePolicy::Expanding, QSizePolicy::Expanding); TestBtn1.setMinimumSize(160, 30); 使用setSizePolicy&#xff0c;那么 TestBtn1按钮 就会随着…

TypseScript再学习之类型别名和接口(10)

先看类型别名&#xff1a;使用关键字 type 声明,注意有等于号额 // 类型别名 使用关键字 type 声明,注意有等于号额 type Cat {name: string; }; let huahua: Cat {name: "花花", };type和interface不同之处在于&#xff1a;interface 是可以自动合并类型的&#…

源支付V7开源版2.99,修复各种提示错误

源支付V7开源版2.99&#xff0c;修复各种提示错误 加密说明&#xff1a;200拿来的&#xff0c;只有8.1这个文件加密&#xff0c;其他文件无任何加密&#xff0c;已修复各种提示错误 测试其他开源版安装提示错误&#xff0c;有几个文件是加密的 注&#xff1a;开发不易&#…

基于stm32的h5新建工程

目录 基于stm32的h5新建工程前言实验目的原理图部分搭建工程引脚配置界面&#xff1a;时钟配置界面工程选项卡&#xff1a; 编写代码实现点灯本文中使用的测试工程 基于stm32的h5新建工程 本文目标&#xff1a;基于stm32的基础实验 按照本文的描述&#xff0c;应该可以跑通实…

python学习16:python中的布尔类型和条件语句的学习

python中的布尔类型和条件语句的学习 1.布尔&#xff08;bool&#xff09;类型的定义&#xff1a; 布尔类型的字面量&#xff1a;True表示真&#xff08;是、肯定&#xff09; False表示假&#xff08;否、否定&#xff09; True本质上是一个数字记作1&#xff0c;False记作0 …