【使用 TensorFlow 2】01/3 中创建和训练自定义层

news2025/1/12 13:30:09

之前我们已经看到了如何创建自定义损失函数

接下来,我写了关于使用 Lambda 层创建自定义激活函数的文章   

一、说明

        TensorFlow 2发布已经接近2年时间,不仅继承了Keras快速上手和易于使用的特性,同时还扩展了原有Keras所不支持的分布式训练的特性。3大设计原则:简化概念,海纳百川,构建生态.这是本系列的第三部分,我们将创建自定义密集层并在 TensorFlow 2 中训练它们。

二、图层介绍 

        Lambda 层是 TensorFlow 中的简单层,可用于创建一些自定义激活函数。但是 lambda 层有很多限制,尤其是在训练这些层时。因此,我们的想法是使用TensorFlow中可继承的Keras层创建可训练的自定义层 - 特别关注密集层。

        什么是图层?

图1.图层 — 密集图层表示

 

        层是一个类,它接收一些参数,通过状态和计算传递它们,并根据神经网络的要求传递输出。每个模型架构都包含多个层,无论是顺序层还是函数式 API。

        状态 — 主要是在“model.fit”期间训练的可训练特征。在密集层中,状态构成权重和偏差,如图 1 所示。这些值会更新,以便在模型训练时提供更好的结果。在某些图层中,状态还可以包含不可训练的特征。

        计算 — 计算有助于将一批输入数据转换为一批输出数据。在图层的这一部分中,将进行计算。在密集层中,计算执行以下计算 —

        Y = (w*X+c),并返回 Y。

        Y 是输出,X 是输入,w = 权重,c = 偏置。

三、创建自定义密集层 

        现在我们知道了密集层内部发生了什么,让我们看看如何创建自己的密集层并在模型中使用它。

import tensorflow as tf
from tensorflow.keras.layers import Layer

class SimpleDense(Layer):

    def __init__(self, units=32):
        '''Initializes the instance attributes'''
        super(SimpleDense, self).__init__()
        self.units = units

    def build(self, input_shape):
        '''Create the state of the layer (weights)'''
        # initialize the weights
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(name="kernel",   initial_value=w_init(shape=(input_shape[-1], self.units),
                 dtype='float32'),trainable=True)

        # initialize the biases
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(name="bias",initial_value=b_init(shape=(self.units,), dtype='float32'),trainable=True)

    def call(self, inputs):
        '''Defines the computation from inputs to outputs'''
        return tf.matmul(inputs, self.w) + self.b

        上面代码的解释 — 该类名为 SimpleDense。当我们创建自定义层时,我们必须继承 Keras 的层类。这是在“类简单密集(层)”行中完成的。

        “__init__”是类中第一个有助于初始化类的方法。 “init”接受参数并将其转换为可在类中使用的变量。这是从“Layer”类继承的,因此需要进行一些初始化。此初始化是使用“super”关键字完成的。“单位”是一个局部类变量。这类似于密度层中的单元数。默认值设置为 32,但在调用类时始终可以更改。

        “build”是类中的下一个方法。这用于指定状态。在密集层中,权重和偏差所需的两种状态是“w”和“b”。当创建密集层时,我们不只是创建网络隐藏层的一个神经元,而是一次创建多个神经元(在这种情况下将创建 32 个神经元)。层中的每个神经元都需要初始化并给出一些随机权重和偏差值。TensorFlow包含许多内置函数来初始化这些值。

        为了初始化权重,我们使用 TensorFlow 的 'random_normal_initializer' 函数,该函数将使用正态分布随机初始化权重。'self.w' 以张量变量的形式包含权重的状态。这些状态将使用“w_init”进行初始化。作为权重包含的值将采用“float_32”格式。它设置为“可训练”,这意味着每次运行后,这些初始权重将根据损失函数和优化器进行更新。添加了名称“内核”,以便以后可以轻松跟踪。

        为了初始化偏差,使用了TensorFlow的“zeros_initializer”函数。这会将所有初始偏置值设置为零。'self.b' 是一个张量,其大小与单位大小相同(此处为 32),这 32 个偏差项中的每一个最初都设置为零。这也设置为“可训练”,因此偏差项将在训练开始时更新。添加了名称“偏差”,以便以后能够追踪它。

        “调用”是执行计算的最后一种方法。在这种情况下,由于它是一个密集层,它将输入乘以权重,添加偏差,最后返回输出。“matmul”运算用作 self.w 和 self.b 是张量而不是单个数值。

# declare an instance of the class 
my_dense = SimpleDense(units=1)  
# define an input and feed into the layer 
x = tf.ones((1, 1)) 
y = my_dense(x)  
# parameters of the base Layer class like `variables` can be used 
print(my_dense.variables)

输出:

[<tf.Variable 'simple_dense/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[0.00382898]], dtype=float32)>, 
<tf.Variable 'simple_dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]

        上面代码的解释 — 第一行创建一个仅包含一个神经元的密集层(单位 = 1)。x(输入)是形状为 (1,1) 的张量,值为 1。Y = my_dense(x),有助于初始化密集层。“.variables”帮助我们查看在密集层中初始化的值(权重和偏差)。

        “my_dense.variable”的输出显示在代码块下方。它表明“simple_dense”中有两个变量,称为“内核”和“偏差”。核 'w' 初始化值 0.0038,随机正态分布值,偏差 'b' 初始化值 0。这只是图层的初始状态。训练后,这些值将相应更改。

import numpy as np
# define the dataset 
xs = np.array([-1.0,  0.0, 1.0, 2.0, 3.0, 4.0], dtype=float) 
ys = np.array([-3.0, -1.0, 1.0, 3.0, 5.0, 7.0], dtype=float)   
# use the Sequential API to build a model with our custom layer 
my_layer = SimpleDense(units=1) 
model = tf.keras.Sequential([my_layer])  
# configure and train the model 
model.compile(optimizer='sgd', loss='mean_squared_error') model.fit(xs, ys, epochs=500,verbose=0)  
# perform inference 
print(model.predict([10.0]))  
# see the updated state of the variables 
print(my_layer.variables)

        输出:

[[18.981567]]
[<tf.Variable 'sequential/simple_dense_1/kernel:0' shape=(1, 1) dtype=float32, numpy=array([[1.9973286]], dtype=float32)>, 
<tf.Variable 'sequential/simple_dense_1/bias:0' shape=(1,) dtype=float32, numpy=array([-0.99171764], dtype=float32)>]

        上面代码的解释 - 上面使用的代码是检查自定义层是否工作的非常简单的方法。设置输入和输出,使用自定义层编译模型,最后训练 500 轮。重要的是要看到,训练模型后,权重和偏差的值现在已经发生了变化。最初设置为 0.0038 的权重现在为 1.9973,最初设置为零的偏差现在为 -0.9917。

四、向自定义密集层添加激活函数 

        之前我们创建了自定义 Dense 层,但我们没有随该层添加任何激活。当然,要添加激活,我们可以将激活编写为模型中的单独行,或者将激活添加为 Lambda 层。但是我们如何在上面创建的同一自定义层中实现激活。

        答案是对自定义密集层中的“__init__”和“call”方法进行简单的调整。

class SimpleDense(Layer):

    # add an activation parameter
    def __init__(self, units=32, activation=None):
        super(SimpleDense, self).__init__()
        self.units = units
        
        # define the activation to get from the built-in activation layers in Keras
        self.activation = tf.keras.activations.get(activation)


    def build(self, input_shape):
        w_init = tf.random_normal_initializer()
        self.w = tf.Variable(name="kernel",
            initial_value=w_init(shape=(input_shape[-1], self.units),dtype='float32'),trainable=True)
        b_init = tf.zeros_initializer()
        self.b = tf.Variable(name="bias",
            initial_value=b_init(shape=(self.units,), dtype='float32'),trainable=True)
        super().build(input_shape)


    def call(self, inputs):
        
        # pass the computation to the activation layer
        return self.activation(tf.matmul(inputs, self.w) + self.b)

上面代码的解释 — 大多数代码与我们之前使用的代码完全相同。

要添加激活,我们需要在“__init__”中指定我们需要激活。可以将激活对象的字符串或实例传递到此激活中。它设置为默认值为None,因此如果未提及激活函数,则不会引发错误。接下来,我们必须将激活函数初始化为 — 'tf.keras.activations.get(activation)'。

最后的编辑是在“调用”方法中,在计算权重和偏差之前,我们需要添加self.activation 来激活计算。所以现在的回报是计算和激活。

五、自定义密集层的完整代码,在 mnist 数据集上激活 

import tensorflow as tf
from tensorflow.keras.layers import Layer
class SimpleDense(Layer):
def __init__(self, units=32, activation=None):
        super(SimpleDense, self).__init__()
        self.units = units
        
        # define the activation to get from the built-in activation layers in Keras
self.activation = tf.keras.activations.get(activation)


    def build(self, input_shape):
w_init = tf.random_normal_initializer()
        self.w = tf.Variable(name="kernel",
            initial_value=w_init(shape=(input_shape[-1], self.units),dtype='float32'),trainable=True)
b_init = tf.zeros_initializer()
        self.b = tf.Variable(name="bias",
            initial_value=b_init(shape=(self.units,), dtype='float32'),trainable=True)
        super().build(input_shape)


    def call(self, inputs):
        
        # pass the computation to the activation layer
return self.activation(tf.matmul(inputs, self.w) + self.b)
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# build the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    # our custom Dense layer with activation
    SimpleDense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])
# compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# fit the model
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

        使用我们的自定义密集层和激活来训练模型,训练准确度为 97.8%,验证准确度为 97.7%。

六、结论 

        这是在TensorFlow中创建自定义层的方法。即使我们只看到密集层的工作,它也可以很容易地被任何其他层所取代,例如执行以下计算的二次层——

        它有 3 个状态变量a、b 和 c,

计算:

将密集层替换为二次层:

import tensorflow as tf
from tensorflow.keras.layers import Layer
class SimpleQuadratic(Layer):

    def __init__(self, units=32, activation=None):
        '''Initializes the class and sets up the internal variables'''
        
        super(SimpleQuadratic,self).__init__()
        self.units=units
        self.activation=tf.keras.activations.get(activation)
    
    def build(self, input_shape):
        '''Create the state of the layer (weights)'''
        
        a_init = tf.random_normal_initializer()
        a_init_val = a_init(shape=(input_shape[-1],self.units),dtype= 'float32')
        self.a = tf.Variable(initial_value=a_init_val, trainable='true')
        
        b_init = tf.random_normal_initializer()
        b_init_val = b_init(shape=(input_shape[-1],self.units),dtype= 'float32')
        self.b = tf.Variable(initial_value=b_init_val, trainable='true')
        
        c_init= tf.zeros_initializer()
        c_init_val = c_init(shape=(self.units,),dtype='float32')
        self.c = tf.Variable(initial_value=c_init_val,trainable='true')
        
   
    def call(self, inputs):
        '''Defines the computation from inputs to outputs'''
        x_squared= tf.math.square(inputs)
        x_squared_times_a = tf.matmul(x_squared,self.a)
        x_times_b= tf.matmul(inputs,self.b)
        x2a_plus_xb_plus_c = x_squared_times_a+x_times_b+self.c
        
        return self.activation(x2a_plus_xb_plus_c)
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  SimpleQuadratic(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

        该二次层在 mnist 数据集上的验证准确率为 97.8%。

        因此,我们看到我们可以在 TensorFlow 模型中实现我们自己的层以及所需的激活,以编辑甚至提高整体精度。阿琼·萨卡尔

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1080645.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是物联网阀控水表?

物联网阀控水表是一种新型的水表&#xff0c;结合了物联网技术和传统水表的功能&#xff0c;可以实现对水的计量、控制和管理。随着人们对水资源的日益重视&#xff0c;物联网阀控水表的应用越来越广泛&#xff0c;为水资源的合理利用和管理提供了有效手段。 物联网阀控水表是由…

springboot json在线转换为实体类

json字符串映射到一个实体类。 这里有一个在线转换工具 http://www.bejson.com/json2javapojo/new/ 截图如下&#xff1a;

软件测试可以发现所有Bug吗,软件测评机构有哪些优势?

软件产品的质量要想得到保障&#xff0c;软件测试这项工作必不可少&#xff0c;主要是在测试过程中提前发现问题&#xff0c;用来促进鉴定软件的正确性、完整性、安全性和质量的过程。那软件测试可以发现所有bug吗?   我们要明确的是&#xff0c;软件测试是一项非常重要的工…

阿里云企业邮箱的替代方案有哪些?

"阿里云企业邮箱有哪些替代方案&#xff1f;替代方案有Zoho Mail、腾讯企业邮箱、网易企业邮箱、搜狐企业邮箱、新浪企业邮箱等。" 在中国的云计算市场中&#xff0c;阿里云企业邮箱无疑是一颗璀璨的明星。然而&#xff0c;市场上的需求多元化使得消费者在选择云服务…

el-select的el-option添加操作按钮插槽后实现勾选与按钮操作分离

这里我在el-option的选项文字后面添加了两个svg按钮&#xff08;编辑和删除&#xff09;&#xff1a;如图 当我们点击el-option时无法区分是勾选el-option还是点击了el-option选项文字后面的按钮&#xff0c;其实只要在后面的编辑和删除的操作按钮的click事件上面添加.native.…

速通Redis基础(三):掌握Redis的列表类型和命令

目录 Redis列表类型 Redis列表的基本命令 LPUSH LPUSHX RPUSH RPUSHX LRANGE LPOP RPOP LINDEX LINSERT LLEN 阻塞版本命令 BLPOP BRPOP Redis的列表命令小结 Redis是一种高性能、开源的NoSQL数据库&#xff0c;以其支持多种数据类型而闻名。在前两篇博客中&am…

【数据结构】线段树

算法提高课笔记 还未更新完 文章目录 原理pushupbuildmodifyquerypushdown&#xff08;懒标记 / 延迟标记&#xff09;扫描线法 原理 时间复杂度&#xff1a;O(logn) 线段树是一棵二叉树&#xff0c;把一段区间分成多个部分 类似堆的方式&#xff0c;用一维数组存整棵树 对…

远场Far-Field beamforming与近场Near-Field beamforming有何关系

这里写目录标题 UPA![在这里插入图片描述](https://img-blog.csdnimg.cn/170e1282d2d6424595263daf77707234.png)写在前面Channel Estimation for Extremely Large-Scale Massive MIMO:Far-Field, Near-Field, or Hybrid-Field?Far Field modelNear Field model UPA 写在前面…

Java内存空间(学习随笔)

1、程序运行中栈可能会出现两种错误 StackOverFlowError&#xff1a; 若栈的内存大小不允许动态扩展&#xff0c;那么当线程请求栈的深度超过当前 Java 虚拟机栈的最大深度的时候&#xff0c;就抛出 StackOverFlowError 错误。OutOfMemoryError&#xff1a; 如果栈的内存大小可…

音视频方法技术有哪些?H.265技术详解

H.265发展背景 H.264虽然是一个划时代的数字视频压缩标准&#xff0c;但是随着数字视频产业链的高速发展&#xff0c;H.264的局限性逐步显现&#xff0c;并且由于H.264标准核心压缩算法的完全固化&#xff0c;并不能够通过调整或扩充来更好地满足当前高清数字视频应用。 视频…

电子书制作软件Vellum mac中文版特点

Vellum mac是一款专业的电子书制作软件&#xff0c;它可以帮助用户将文本文件转换为高质量的电子书&#xff0c;支持多种格式&#xff0c;包括EPUB、MOBI、PDF等。Vellum具有直观的用户界面和易于使用的工具&#xff0c;可以让用户快速地创建和发布电子书。 Vellum mac软件特点…

基层医院信息管理系统源码 his系统全套成品源码带电子病历4级

基层医院his系统源码 二级医院信息管理系统源码&#xff0c;演示自主版权&#xff0c;云端SaaS服务 技术细节&#xff1a; 前端&#xff1a;AngularNginx 后台&#xff1a;JavaSpring&#xff0c;SpringBoot&#xff0c;SpringMVC&#xff0c;SpringSecurity&#xff0c;MyBa…

【MySQL】深入解析MySQL双写缓冲区

原创不易&#xff0c;注重版权。转载请注明原作者和原文链接 文章目录 为什么需要Doublewrite BufferDoublewrite Buffer原理Doublewrite Buffer和redo logDoublewrite Buffer相关参数总结 在数据库系统的世界中&#xff0c;保障数据的完整性和稳定性是至关重要的任务。为了实现…

web 基础和http 协议

一、域名 域名的概念 IP地址不易记忆&#xff0c;域名方便记住&#xff0c;以便于用户进行搜索访问 早期使用Hosts文件解析域名地址 缺点&#xff1a; ① 主机名称重复 ② 主机维护困难 DNS&#xff08;Domain Name System&#xff09;域名系统 ① 分布式 将一个大的数…

【AN-Animate教程——熟悉工作区】

【AN-Animate教程——熟悉工作区】 初始页面创建舞台主舞台界面其他常用板块 本篇内容&#xff1a;Animate用途 重点内容&#xff1a;熟悉工作区&#xff0c;以及基本操作 工 具&#xff1a;Adobe Animate 2022 初始页面 在初始页面当中&#xff0c;我们可以看到一个忍者和一个…

使用VS编译Redis源码报错

使用Redis源码版本,解压工程右键生成hiredis项目正常,编译Win32_Interop项目报下图错误(error C2039:system_error:不是std成员;error C3861: system_category:找不到标识符) 解决办法:在Win32_variadicFunctor.cpp和Win32_FDAPI.cpp添加 #include <system_error> ,再右键…

k8s 集群部署 kubesphere

一、最小化部署 kubesphere 1、在已有的 Kubernetes 集群上部署 KubeSphere&#xff0c;下载 YAML 文件: wget https://github.com/kubesphere/ks-installer/releases/download/v3.4.0/kubesphere-installer.yaml wget https://github.com/kubesphere/ks-installer/releases/…

204318-14-9,依多曲肽,DOTA-TOC

DOTA-[Tyr3]-Octreotide&#xff0c;依多曲肽,DOTA-(酪氨酸3)-奥曲肽是一种重要的多肽分子&#xff0c;其结构与奥曲肽类似&#xff0c;具有多种重要的药理作用。由于其具有大量的羧基官能团和醇羟基官能团&#xff0c;可以与各种放射性核素结合&#xff0c;因此被广泛应用于放…

基于springboot实现旅游网站管理平台系统项目【项目源码+论文说明】

基于springboot实现旅游网站平台管理系统演示 摘要 随着科学技术的飞速发展&#xff0c;网络快速发展、人民生活的快节奏都在努力与现代的先进技术接轨&#xff0c;通过科技手段来提高自身的优势&#xff0c;旅游管理系统当然也不能排除在外。旅游管理系统是以实际运用为开发背…

关于竞品分析怎么做?掌握这5点就够了!

大家好&#xff0c;我是设计师l1m0&#xff0c;今天要给大家分享的竞品分析相关知识。 在竞争激烈的市场中&#xff0c;了解竞争对手并且在产品开发和市场营销中制定明智的策略至关重要。这正是产品竞品分析的目的所在。本文将详细介绍如何进行产品竞品分析&#xff0c;以及通…