【使用 TensorFlow 2】02/3 使用 Lambda 层创建自定义激活函数

news2025/1/23 17:28:54

一、说明

        TensorFlow 2发布已经接近2年时间,不仅继承了Keras快速上手和易于使用的特性,同时还扩展了原有Keras所不支持的分布式训练的特性。3大设计原则:简化概念,海纳百川,构建生态.这是本系列的第三部分,我们将创建激活层并在 TensorFlow 2 中训练它们。
        之前我们已经了解了如何创建自定义损失函数 -使用 TensorFlow 2 创建自定义损失函数

自定义ReLU函数(来源:作者创建的图片)

二、介绍

        在本文中,我们将了解如何创建自定义激活函数。虽然 TensorFlow 已经包含一堆内置激活函数,但有多种方法可以创建您自己的自定义激活函数或编辑现有激活函数。

        ReLU(修正线性单元)仍然是任何神经网络架构的隐藏层中最常用的激活函数。ReLU 也可以表示为函数 f(x),其中,

        f(x) = 0, 当 x < 0 时,

        并且,当 x ≥ 0 时,f(x) = x。

        因此,该函数仅考虑正部分,并写为:

        f(x) = 最大值(0,x)

        或在代码表示中,

if input > 0:
   return input
else:
   return 0

        但这个ReLU函数是预定义的。如果我们想自定义此函数或创建我们自己的 ReLU 激活该怎么办?在 TensorFlow 中有一种非常简单的方法可以做到这一点——我们只需使用Lambda 层

ReLU 和 GeLU 激活函数

如何使用 lambda 层?

tf.keras.layers.Lambda(lambda x: tf.abs(x))

Lambda 只是可以在 TensorFlow 中直接调用的另一层。在 lambda 层中,首先指定参数。在上面的代码片段中,该值为“x”(lambda x)。在本例中,我们想要求 x 的绝对值,因此我们使用 tf.abs(x)。因此,如果 x 的值为 -1,则该 lambda 层会将 x 的值更改为 1。

如何使用 lambda 层创建自定义 ReLU?

def custom_relu(x):
    return K.maximum(0.0,x)
model = tf.keras.models.Sequential([
     tf.keras.layers.Flatten(input_shape=(128,128)),
     tf.keras.layers.Dense(512),
     tf.keras.layers.Lambda(custom_relu),
     tf.keras.layers.Dense(5, activation = 'softmax')
])

上面的代码片段展示了如何在 TensorFlow 模型中实现自定义 ReLU。我们创建一个函数 custom_relu 并返回 0 或 x 的最大值(与 ReLU 函数相同)。

在下面的顺序模型中,在 Dense 层之后,我们创建一个 Lambda 层并将其传递到自定义激活函数中。但这段代码仍然没有做任何与 ReLU 激活函数不同的事情。

当我们开始研究自定义函数的返回值时,乐趣就开始了。假设我们取 0.5 和 x 中的最大值,而不是 0 和 x。我们已经有了自己定制的 ReLU。然后可以根据需要更改这些值。

def custom_relu(x):
返回 K.maximum(0.5,x)

def custom_relu(x):
    return K.maximum(0.5,x)
model = tf.keras.models.Sequential([
     tf.keras.layers.Flatten(input_shape=(128,128)),
     tf.keras.layers.Dense(512),
     tf.keras.layers.Lambda(custom_relu),
     tf.keras.layers.Dense(5, activation = 'softmax')
])

在 mnist 数据集上使用 lambda 激活的示例

#using absolute value (Lambda layer example 1)
import tensorflow as tf
from tensorflow.keras import backend as K
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128),
  tf.keras.layers.Lambda(lambda x: tf.abs(x)), 
  tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

将 ReLU 激活替换为 mnist 数据集上的绝对值,测试精度为 97.384%。

#using custom ReLU activation (Lambda layer example 2)
import tensorflow as tf
from tensorflow.keras import backend as K
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
def my_relu(x):
    return K.maximum(-0.1, x)
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128),
    tf.keras.layers.Lambda(my_relu), 
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_train, y_train, epochs=5)
model.evaluate(x_test, y_test)

        将 ReLU 激活替换为自定义 ReLU 激活,在 mnist 数据集上取最大值 -0.1 或 x,测试精度为 97.778%。

三、结论

        尽管 lambda 层使用起来非常简单,但它们有很多限制。在下一篇文章中,我将介绍如何在 TensorFlow 中创建可训练的完全自定义层。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1078318.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【pycharm】控制台报错:终端无法加载文件\venv\Scripts\activate.ps1

目录 一、在pycharm控制台输入 二、在windows的power shell &#xff08;以管理员方式打开&#xff09; 三、 在pycharm控制台输入 四、重新打开pycharm即可 前言&#xff1a;安装pycharm2022-03版本出现的终端打开报错 一、在pycharm控制台输入 get-executionpolicy …

单目标分割标签图叠加代码

本代码只适合两个图片合并&#xff0c;如果出现三个图片合并&#xff0c;就将第三个图删除&#xff0c;先合并一次&#xff0c;然后再将图片加入&#xff0c;再合并一次 1. 问题背景 有的时候标签有多个&#xff0c;需要将两张或者是多张图象叠加在一起&#xff0c;成为以下情…

Redis学习6——新数据类型

Bitmaps bitfield HyperLog Geographic Stream 重点来了 redis各个数据类型的操作命令可以到:http://www.redis.cn/commands.html查看

面试经典 150 题 20 —(数组 / 字符串)— 151. 反转字符串中的单词

151. 反转字符串中的单词 方法一 class Solution { public:string reverseWords(string s) {istringstream instr(s);vector<string> words{};string word;while(instr>>word){words.push_back(word);}int length words.size();string result words[length-1];f…

怒刷LeetCode的第26天(Java版)

第一题 题目来源 64. 最小路径和 - 力扣&#xff08;LeetCode&#xff09; 题目内容 解决方法 方法一&#xff1a;动态规划 可以使用动态规划来解决这个问题。 首先创建一个与网格大小相同的二维数组dp&#xff0c;用于存储从起点到每个位置的最小路径和。然后初始化dp[0…

shiro反序列化和log4j

文章目录 安装环境shiro漏洞验证log4j 安装环境 进入vulhb目录下的weblogic&#xff0c;复现CVE-2018-2894漏洞&#xff1a; cd /vulhub/shiro/CVE-2010-3863查看docker-compose的配置文件&#xff1a; cat docker-compose.yml如图&#xff0c;里面有一个镜像文件的信息和服…

136.【JUC并发编程_02】

JUC并发编程 (四)、共享模型之管程1.wait notify(1).小故事_为什么需要wait(2).wait notify 的工作原理(3).API介绍 2.wait notify 的正确使用步骤 ⭐(1).sleep 和 wait 的区别(2).步骤1_产生的问题(3).步骤2_wait notify 改进产生问题(4).步骤3_产生叫错人问题 (虚假唤醒)(5).…

opencv安装成功之后运行代码还是出错

错误提示 Traceback (most recent call last): File "F:\download\55957_人工智能基础与应用&#xff08;微课版&#xff09;_源代码\OpenCV\camera.py", line 4, in <module> import cv2 File "F:\software\anaconda\envs\tensorflow\cv2\__init__.py&q…

Unity 热更新技术 | (一) 热更新的基本概念原理及主流热更新方案介绍

&#x1f3ac; 博客主页&#xff1a;https://xiaoy.blog.csdn.net &#x1f3a5; 本文由 呆呆敲代码的小Y 原创&#xff0c;首发于 CSDN&#x1f649; &#x1f384; 学习专栏推荐&#xff1a;Unity系统学习专栏 &#x1f332; 游戏制作专栏推荐&#xff1a;游戏制作 &…

什么是强缓存、协商缓存?

为了减少资源请求次数,加快资源访问速度,浏览器会对资源文件如图片、css文件、js文件等进行缓存,而浏览器缓存策略又分为强缓存和协商缓存,什么是强缓存?什么是协商缓存?两者之间的区别又是什么?接下来本文就带大家深入了解这方面的知识。 强缓存 所谓强缓存,可以理解…

声音生成评价项目AudioLDM_eval项目配置过程

文章目录 引言正文问题一&#xff1a;模型下载不了问题二 TypeError: pad_center() takes 1 positional argument but 2 were given问题三 AttributeError: module numpy has no attribute complex. 结果 引言 对于生成的声音&#xff0c;如何进行评价&#xff0c;一般是通过计…

Matlab之查询子字符串在字符串中的起始位置函数strfind

一、功能 strfind函数用于在一个字符串中查找指定的子字符串&#xff0c;并返回子字符串在字符串中的起始位置。 二、语法 indices strfind(str, pattern) 其中&#xff0c;str是要进行查找的字符串&#xff0c;pattern是要查找的子字符串。 函数会返回一个由子字符串在字…

网络与信息安全基础知识 (软件设计师笔记)

&#x1f600;前言 在当今世界&#xff0c;我们见证了科技&#xff0c;特别是网络技术的繁荣发展&#xff0c;这种发展不仅让我们的生活变得更加便捷&#xff0c;但也带来了一系列的安全问题。网络安全不仅关系到每一个上网的个人&#xff0c;更是关乎到国家的安全和社会的稳定…

大数据要怎么样学才可以到企业级实战

大数据在企业级实战中扮演着重要角色&#xff0c;因此掌握大数据技术和应用是非常有价值的。下面将详细介绍学习大数据并达到企业级实战水平的步骤和方法。 一、基础知识准备 1. 数据基础知识&#xff1a;了解数据的概念、类型、结构等基本概念&#xff0c;并熟悉常见的数据处…

Edge 无法登录/同步问题【一招搞定】

目录 前言 一、打开 Edge 浏览器显示未同步&#xff0c;点击同步无效 二、Edge 登录报错 0x801901f4 或 0x80190001 解决方法 2.1 报错 0x801901f4 解决方法 2.1.0 Edge 登陆报错图示 2.1.1 添加 Edge 推荐的 DNS 地址 2.1.2 重新登录 Edge 账号成功 2.2 报错 0x801…

第四章 树和二叉树

第四章 树和二叉树 树的基本概念树的概念树的相关术语 二叉树二叉树基本概念二叉树的性质 二叉树的存储结构二叉树的顺序存储结构二叉树的链式存储结构 二叉树的遍历二叉树遍历的递归实现二叉树的层次遍历二叉树遍历的非递归实现 树和森林树的存储结构树、森林与二叉树的关系树…

【低代码开发】:低代码开发助力应用创新

低代码开发&#xff1a;加速应用开发的未来趋势 引言什么是低代码以及功能特点&#xff1f;什么是低代码开发&#xff1f;低代码平台的特点和功能低代码平台的应用场景和优势低代码的优点低代码的缺点低代码平台项目开发流程选择和实施低代码平台 低代码未来的发展趋势低代码平…

Java基础(变量篇)

变量是Java程序中基本的存储单元&#xff0c;变量名有三个基本要素&#xff1a;数据类型、变量名和值。变量名是一块内存单元的名称&#xff0c;就像门牌号一样&#xff0c;通过变量可以找到它表示的内存单元&#xff0c;并对这块内存单元进行操作。在Java中变量必须声明后使用…

英国/法国/意大利/德国/西班牙,电动交通设备配件等相关政策更新

产品安全 合规政策更新&#xff01; 首先请看邮件内容 尊敬的卖家&#xff1a; 您好&#xff01; 我们此次联系您是因为您正在销售需要审批流程的商品。为此&#xff0c;亚马逊正在实施审批流程&#xff0c;以确认我们网站上提供的商品类型须符合指定的认证标准。要在亚马逊…

Cesium小技巧:快速打开API文档

学习Cesium.js的人&#xff0c;肯定经常看官方示例&#xff0c; 网址如下&#xff1a; https://sandcastle.cesium.com/ 有个小技巧&#xff0c;可以快速打开具体类的API文档 在示例中&#xff0c;双击具体类名或方法名&#xff0c;会出现一个提示框 单击或右键菜单-在新标…