政安晨:【Keras机器学习实践要点】(五)—— 通过子类化创建新层和模型

news2025/1/23 15:09:34

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras实战演绎机器学习

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

介绍

本文将涵盖构建自己的子类化层和模型所需的所有知识。

您将了解以下功能

层类

add_weight()方法

可训练和不可训练的权重

build()方法

确保您的层可以与任何后端一起使用

add_loss()方法

call()中的训练参数

call()中的掩码参数

确保您的层可以序列化

让我们开始吧。


安装

import numpy as np
from tensorflow import keras
from keras import ops
from keras import layers

层级:状态(权重)与某些计算的组合

Keras 的核心抽象之一是层类。

层封装了状态(层的 "权重")和从输入到输出的转换("调用",层的前向传递)。

下面是一个密集连接的层。它有两个状态变量:变量 w 和 b。

class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super().__init__()
        self.w = self.add_weight(
            shape=(input_dim, units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

你可以通过调用一个层函数来使用它,就像调用Python函数一样,传入一些张量输入。

x = ops.ones((2, 2))
linear_layer = Linear(4, 2)
y = linear_layer(x)
print(y)

请注意,权重 w 和 b 在被设置为层属性后,会被层自动跟踪

assert linear_layer.weights == [linear_layer.w, linear_layer.b]

层可以有不可训练的重量

除了可训练权重外,还可以向层添加不可训练权重。在反向传播过程中,这些权重在训练层时不会被考虑在内。

下面介绍如何添加和使用不可训练权重

class ComputeSum(keras.layers.Layer):
    def __init__(self, input_dim):
        super().__init__()
        self.total = self.add_weight(
            initializer="zeros", shape=(input_dim,), trainable=False
        )

    def call(self, inputs):
        self.total.assign_add(ops.sum(inputs, axis=0))
        return self.total


x = ops.ones((2, 2))
my_sum = ComputeSum(2)
y = my_sum(x)
print(y.numpy())
y = my_sum(x)
print(y.numpy())

它是层权重的一部分,但被归类为不可训练的权重。

print("weights:", len(my_sum.weights))
print("non-trainable weights:", len(my_sum.non_trainable_weights))

# It's not included in the trainable weights:
print("trainable_weights:", my_sum.trainable_weights)

最佳实践推迟权重的创建,直到输入的形状已知

我们上面的线性层接受了一个input_dim参数,该参数用于在__init__()函数中计算权重w和偏置b的形状。

class Linear(keras.layers.Layer):
    def __init__(self, units=32, input_dim=32):
        super().__init__()
        self.w = self.add_weight(
            shape=(input_dim, units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

在许多情况下,您可能无法预先知道输入的大小,并且希望在实例化图层后的某个时间,当该值变为已知时,才懒惰地创建权重。

在Keras API中,我们建议在您的图层的build(self, inputs_shape)方法中创建图层权重。

像这样:

class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

您的层的__call__()方法在第一次被调用时会自动运行build。

现在您有一个懒惰的层,因此更容易使用。

# At instantiation, we don't know on what inputs this is going to get called
linear_layer = Linear(32)

# The layer's weights are created dynamically the first time the layer is called
y = linear_layer(x)

将build()单独实现如上所示,很好地将只创建权重一次与在每次调用中使用权重进行了分离。


层可以递归组合

如果你将一个 Layer 实例分配为另一个 Layer 的属性,外层将开始跟踪内层创建的权重。

我们建议在 init() 方法中创建这样的子层,并在第一个 call() 中触发构建它们的权重。

class MLPBlock(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.linear_1 = Linear(32)
        self.linear_2 = Linear(32)
        self.linear_3 = Linear(1)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = keras.activations.relu(x)
        x = self.linear_2(x)
        x = keras.activations.relu(x)
        return self.linear_3(x)


mlp = MLPBlock()
y = mlp(ops.ones(shape=(3, 64)))  # The first call to the `mlp` will create the weights
print("weights:", len(mlp.weights))
print("trainable weights:", len(mlp.trainable_weights))

后端不可知层和特定后端层

只要一个层只使用 keras.ops 命名空间的 API(或者其他 Keras 命名空间,例如 keras.activations、keras.random 或 keras.layers),那么它就可以与任何后端一起使用——TensorFlow、JAX 或 PyTorch。

到目前为止,在本指南中看到的所有层都适用于所有Keras后端。

keras.ops命名空间提供了以下功能

NumPy API,例如ops.matmul,ops.sum,ops.reshape,ops.stack等。

神经网络特定的API,例如ops.softmax,ops.conv,ops.binary_crossentropy,ops.relu等。

您还可以在层中使用本机后端API(例如tf.nn函数),但是如果这样做,您的层只能与特定的后端一起使用。

例如,您可以使用jax.numpy编写以下特定于JAX的层:

import jax

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return jax.numpy.matmul(inputs, self.w) + self.b

这将是等效的TensorFlow特定层:

import tensorflow as tf

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

这将是等效的PyTorch特定层:

import torch

class Linear(keras.layers.Layer):
    ...

    def call(self, inputs):
        return torch.matmul(inputs, self.w) + self.b

由于跨后端兼容性是一种非常有用的特性,我们强烈建议您始终通过仅使用Keras APIs来使您的层与后端无关。


add_loss()方法

在编写层的call()方法时,您可以创建损失张量,以便在编写训练循环时稍后使用。通过调用self.add_loss(value)可以实现这一点。

# A layer that creates an activity regularization loss
class ActivityRegularizationLayer(keras.layers.Layer):
    def __init__(self, rate=1e-2):
        super().__init__()
        self.rate = rate

    def call(self, inputs):
        self.add_loss(self.rate * ops.mean(inputs))
        return inputs

这些损耗(包括任何内层创建的损耗)可以通过 layer.losses 检索到。

该属性在每次调用顶层层的 __call__() 开始时重置,因此 layer.losses 总是包含上次向前传递时创建的损耗值。

class OuterLayer(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.activity_reg = ActivityRegularizationLayer(1e-2)

    def call(self, inputs):
        return self.activity_reg(inputs)


layer = OuterLayer()
assert len(layer.losses) == 0  # No losses yet since the layer has never been called

_ = layer(ops.zeros((1, 1)))
assert len(layer.losses) == 1  # We created one loss value

# `layer.losses` gets reset at the start of each __call__
_ = layer(ops.zeros((1, 1)))
assert len(layer.losses) == 1  # This is the loss created during the call above

此外,损失属性还包含为任何内层权重创建的正则化损失

class OuterLayerWithKernelRegularizer(keras.layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(
            32, kernel_regularizer=keras.regularizers.l2(1e-3)
        )

    def call(self, inputs):
        return self.dense(inputs)


layer = OuterLayerWithKernelRegularizer()
_ = layer(ops.zeros((1, 1)))

# This is `1e-3 * sum(layer.dense.kernel ** 2)`,
# created by the `kernel_regularizer` above.
print(layer.losses)

打印结果:

[Array(0.00217911, dtype=float32)]

在编写自定义训练循环时,应考虑到这些损失。

它们也可以与fit()方法无缝配合使用(如果有的话,会自动将它们求和并添加到主要损失中):

inputs = keras.Input(shape=(3,))
outputs = ActivityRegularizationLayer()(inputs)
model = keras.Model(inputs, outputs)

# If there is a loss passed in `compile`, the regularization
# losses get added to it
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# It's also possible not to pass any loss in `compile`,
# since the model already has a loss to minimize, via the `add_loss`
# call during the forward pass!
model.compile(optimizer="adam")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

可以选择在您的层上启用序列化功能

如果需要将自定义层作为功能模型的一部分进行序列化,可以选择实现 get_config() 方法

class Linear(keras.layers.Layer):
    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        return {"units": self.units}


# Now you can recreate the layer from its config:
layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
{'units': 64}

请注意,基本的Layer类的__init__()方法接受一些关键字参数,特别是name和dtype。

在__init__()中将这些参数传递给父类,并将它们包含在层的配置中是一个好的做法。

class Linear(keras.layers.Layer):
    def __init__(self, units=32, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="random_normal", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b

    def get_config(self):
        config = super().get_config()
        config.update({"units": self.units})
        return config


layer = Linear(64)
config = layer.get_config()
print(config)
new_layer = Linear.from_config(config)
{'name': 'linear_7', 'trainable': True, 'dtype': 'float32', 'units': 64}

如果在从配置反序列化层时需要更大的灵活性,也可以覆盖 from_config() 类方法。

这是 from_config() 的基本实现:

def from_config(cls, config):
    return cls(**config)

顺序模型主题咱们就到这里。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1547457.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用Spark单机版环境

在Spark单机版环境中,可通过多种方式进行实战操作。首先,可使用特定算法或数学软件计算圆周率π,并通过SparkPi工具验证结果。其次,在交互式Scala版或Python版Spark Shell中,可以进行简单的计算、打印九九表等操作&…

ABAP - 上传文件模板到SMW0,并从SMW0上下载模板

upload file template to SMW0 and download the template from it 首先上传文件到tcode SMW0 选择新建后,输入文件名和描述,再选择想要上传的文件 上传完成后: 在表WWWPARAMS, WWWDATA里就会有信息存进去 然后就可以程序里写代码了: 屏幕上的效果:

jupyter notebook导出含中文的pdf(LaTex安装和Pandoc、MiKTex安装)

用jupyter notebook导出pdf时,因为报错信息,需要用到Tex nbconvert failed: xelatex not found on PATH, if you have not installed xelatex you may need to do so. Find further instructions at https://nbconvert.readthedocs.io/en/latest/install…

【数据分享】1929-2023年全球站点的逐年平均露点(Shp\Excel\免费获取)

气象数据是在各项研究中都经常使用的数据,气象指标包括气温、风速、降水、能见度等指标,说到气象数据,最详细的气象数据是具体到气象监测站点的数据! 有关气象指标的监测站点数据,之前我们分享过1929-2023年全球气象站…

界面控件DevExpress WinForms/WPF v23.2 - 电子表格支持表单控件

DevExpress WinForm拥有180组件和UI库,能为Windows Forms平台创建具有影响力的业务解决方案。DevExpress WinForm能完美构建流畅、美观且易于使用的应用程序,无论是Office风格的界面,还是分析处理大批量的业务数据,它都能轻松胜任…

IDEA编辑国际化.properties文件没有Resource Bundle怎么办?

问题描述 最近在做SpringBoot国际化,IDEA添加了messages.properties、messages_en_US.properties、messages_zh_CN.properties国际化文件后,在编辑页面底部没有Resource Bundle,这使得我在写keyvalue的时候在每个properties文件都要拷贝一次…

【Spring源码】Bean采用什么数据结构进行存储

一、前瞻 经过上篇源码阅读博客的实践,发现按模块阅读也能获得不少收获,而且能更加系统地阅读源码。 今天的阅读方式还是按模块阅读的方式,以下是Spring各个模块的组成。 那今天就挑Beans这个模块来阅读,先思考下本次阅读的阅读…

中间件学习--InfluxDB部署(docker)及springboot代码集成实例

一、需要了解的概念 1、时序数据 时序数据是以时间为维度的一组数据。如温度随着时间变化趋势图,CPU随着时间的使用占比图等等。通常使用曲线图、柱状图等形式去展现时序数据,也就是我们常常听到的“数据可视化”。 2、时序数据库 非关系型数据库&#…

gin语言基础学习--会话控制(下)

练习 模拟实现权限验证中间件 有2个路由,/cookie和/home/cookie用于设置cookiehome是访问查看信息的请求在请求home之前,先跑中间件代码,检验是否存在cookie 访问home,会显示错误,因为权限校验未通过 package mainim…

阿里云安全产品简介,Web应用防火墙与云防火墙产品各自作用介绍

在阿里云的安全类云产品中,Web应用防火墙与云防火墙是用户比较关注的安全类云产品,二则在作用上并不是完全一样的,Web应用防火墙是一款网站Web应用安全的防护产品,云防火墙是一款公共云环境下的SaaS化防火墙,本文为大家…

canal: 连接kafka (docker)

一、确保mysql binlog开启并使用ROW作为日志格式 docker 启动mysql 5.7配置文件 my.cnf [mysqld] log-binmysql-bin # 开启 binlog binlog-formatROW # 选择 ROW 模式 server-id1一定要确保上述两个值一个为ROW,一个为ON 二、下载canal的run.sh https://github.c…

【Java】LinkedList vs. ArrayList:Java中的数据结构选择

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 目录 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌…

Kindling the Darkness:A Practical Low-light Image Enhancer

Abstract 在弱光条件下拍摄的图像通常会出现(部分)可见度较差的情况。,除了令人不满意的照明之外,多种类型的退化也隐藏在黑暗中,例如由于相机质量有限而导致的噪点和颜色失真。,换句话说,仅仅调高黑暗区域的亮度将不…

R语言随机抽取数据,并作两组数据间t检验,并保存抽取的数据,并绘制boxplot

前提:接着上述R脚本输出的seed结果来选择应该使用哪个seed比较合理,上个R脚本名字: “5utr_计算ABD中Ge1和Lt1的个数和均值以及按照TE个数小的进行随机100次抽样.R” 1.输入数据:“5utr-5d做ABD中有RG4和没有RG4的TE之间的T检验.c…

String类(三)

文章目录 string类(三)string类的模拟实现:1.默认成员变量和函数2.string的长度和下表引用3.字符串拷贝构造4. 赋值拷贝5.字符串比较6.字符串的增添操作7.insert插入操作8.遍历字符 string类(三) string类的模拟实现&…

jupyter lab使用虚拟环境

python -m ipykernel install --name 虚拟环境名 --display-name 虚拟环境名然后再启动jupyter lab就行了

【Unity】调整Player Settings的Resolution设置无效

【背景】 Build时修改了Player Settings下的Resolution设置,但是再次Building时仍然不生效。 【分析】 明显是沿用了之前的分辨率设定,所以盲猜解决办法是Build相关的缓存文件,或者修改打包名称。 【解决】 实测修改版本号无效&#xf…

IDEA使用常用的设置

一、IDEA常用设置 可参考:IDEA这样配置太香了_哔哩哔哩_bilibili 波波老师 二、插件 可参考:IDEA好用插件,强烈推荐_哔哩哔哩_bilibili 波波老师 三、其他 学会用点“.” IDEA弹窗Servers certificate is not trusted怎么禁止&#xf…

基于SSM作业提交与批改

基于SSM作业提交与批改的设计与实现 摘要 社会的进步导致人们对于学习的追求永不止境,那么追求学习的方式也从单一的书本教程变成了多样化的学习方式。多样化的学习方式不仅仅是需要人们智慧的依靠,还需要能够通过软件的加持进行信息化的价值体现。软件…

uniapp开发小程序遇到的问题,持续更新中

一、uniapp引入全局scss 在App.vue中引入uni.scss <style lang"scss">/* #ifndef APP-NVUE */import "uni.scss";/* #endif */ </style>注意&#xff1a;nvue页面的样式在编译时&#xff0c;有很多样式写法被限制了&#xff0c;容易报错。所…