【使用 TensorFlow 2】03/3 创建自定义损失函数

news2024/11/12 13:44:30

一、说明

        TensorFlow 2发布已经接近5年时间,不仅继承了Keras快速上手和易于使用的特性,同时还扩展了原有Keras所不支持的分布式训练的特性。3大设计原则:简化概念,海纳百川,构建生态.这是本系列的第三部分,我们将创建代价函数并在 TensorFlow 2 中使用它们。
 

图 1:实际应用中的梯度下降算法

二、关于代价函数

        神经网络学习将训练数据中的一组输入映射到一组输出。它通过使用某种形式的优化算法来实现这一点,例如梯度下降、随机梯度下降、AdaGrad、AdaDelta 或一些最近的算法,例如 Adam、Nadam 或 RMSProp。梯度下降中的“梯度”指的是误差梯度。每次迭代后,网络都会将其预测输出与实际输出进行比较,然后计算“误差”。通常,对于神经网络,我们寻求最小化错误。因此,用于最小化误差的目标函数通常称为成本函数或损失函数,并且由“损失函数”计算的值简称为“损失”。各种问题中使用的典型损失函数 –

A。均方误差

b. 均方对数误差

C。二元交叉熵

d. 分类交叉熵

e. 稀疏分类交叉熵

        在Tensorflow中,这些损失函数已经包含在内,我们可以如下所示调用它们。

        1 损失函数作为字符串

model.compile(损失='binary_crossentropy',优化器='adam',指标=['准确性'])

        或者,

        2. 损失函数作为对象

从tensorflow.keras.losses导入mean_squared_error

model.compile(损失=mean_squared_error,优化器='sgd')

        将损失函数作为对象调用的优点是我们可以在损失函数旁边传递参数,例如阈值。

从tensorflow.keras.losses导入mean_squared_error

model.compile(损失=均方误差(参数=值),优化器='sgd')

三、使用函数创建自定义损失:

        为了使用函数创建损失,我们需要首先命名损失函数,它将接受两个参数,y_true(真实标签/输出)和y_pred(预测标签/输出)。

def loss_function(y_true, y_pred):

***一些计算***

回波损耗

四、创建均方根误差损失 (RMSE):

        损失函数名称 — my_rmse

        目的是返回目标 (y_true) 和预测 (y_pred) 之间的均方根误差。

        RMSE 公式:

  • 误差:真实标签和预测标签之间的差异。
  • sqr_error:误差的平方。
  • mean_sqr_error:误差平方的平均值
  • sqrt_mean_sqr_error:误差平方均值的平方根(均方根误差)。
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import backend as K
#defining the loss function
def my_rmse(y_true, y_pred):
    #difference between true label and predicted label
    error = y_true-y_pred    
    #square of the error
    sqr_error = K.square(error)
    #mean of the square of the error
    mean_sqr_error = K.mean(sqr_error)
    #square root of the mean of the square of the error
    sqrt_mean_sqr_error = K.sqrt(mean_sqr_error)
    #return the error
    return sqrt_mean_sqr_error
#applying the loss function
model.compile (optimizer = 'sgd', loss = my_rmse)

五、创建 Huber 损失 

图 2:Huber 损失(绿色)和平方误差损失(蓝色)作为 y — f(x) 的函数

         Huber损失的公式:

        这里,

        δ是阈值,

        a是误差(我们将计算 a ,标签和预测之间的差异)

        所以,当|a| ≤δ,损失= 1/2*(a)²

        当|a|>δ 时,损失 = δ(|a| — (1/2)*δ)

        代码:

# creating the Conv-Batch Norm block

def conv_bn(x, filters, kernel_size, strides=1):
    
    x = Conv2D(filters=filters, 
               kernel_size = kernel_size, 
               strides=strides, 
               padding = 'same', 
               use_bias = False)(x)
    x = BatchNormalization()(x)
return x

        解释:

        首先我们定义一个函数 - my huber loss,它接受 y_true 和 y_pred

        接下来我们设置阈值 = 1

        接下来我们计算误差 a = y_true-y_pred

        接下来我们检查误差的绝对值是否小于或等于阈值。is_small_error返回一个布尔值(True 或 False)。

        我们知道,当|a| ≤δ,loss = 1/2*(a)²,因此我们将small_error_loss计算为误差的平方除以2 

        否则,当|a| >δ,则损失等于 δ(|a| — (1/2)*δ)。我们在big_error_loss中计算这一点。

        最后,在return语句中,我们首先检查is_small_error是true还是false,如果是true,函数返回small_error_loss,否则返回big_error_loss。这是使用 tf.where 完成的。

        然后我们可以使用下面的代码编译模型,

model.compile(optimizer='sgd', loss=my_huber_loss)

在前面的代码中,我们始终使用阈值1。

但是,如果我们想要调整超参数(阈值)并在编译期间添加新的阈值,该怎么办?然后我们必须使用函数包装,即将损失函数包装在另一个外部函数周围。我们需要一个包装函数,因为默认情况下任何损失函数只能接受 y_true 和 y_pred 值,并且我们不能向原始损失函数添加任何其他参数。

5.1 使用包装函数的 Huber 损失

        包装函数代码如下所示:

import tensorflow as tf
#wrapper function which accepts the threshold parameter
def my_huber_loss_with_threshold(threshold):
   def my_huber_loss(y_true, y_pred):   
  
       error = y_true - y_pred     
       is_small_error = tf.abs(error) <= threshold     
       small_error_loss = tf.square(error) / 2     
       big_error_loss = threshold * (tf.abs(error) - (0.5 * threshold))
       return tf.where(is_small_error, small_error_loss, big_error_loss)
    return my_huber_loss

在这种情况下,阈值不是硬编码的。相反,我们可以在模型编译期间通过阈值。

model.compile(optimizer='sgd', loss=my_huber_loss_with_threshold(threshold=1.5))

5.2 使用类的 Huber 损失 (OOP)

import tensorflow as tf
from tensorflow.keras.losses import Loss

class MyHuberLoss(Loss): #inherit parent class
  
    #class attribute
    threshold = 1
  
    #initialize instance attributes
    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    #compute loss
    def call(self, y_true, y_pred):
        error = y_true - y_pred
        is_small_error = tf.abs(error) <= self.threshold
        small_error_loss = tf.square(error) / 2
        big_error_loss = self.threshold * (tf.abs(error) - (0.5 * self.threshold))
        return tf.where(is_small_error, small_error_loss, big_error_loss)

        MyHuberLoss是类名。在类名之后,我们从tensorflow.keras.losses继承父类'Loss'。所以MyHuberLoss继承为Loss。这允许我们使用 MyHuberLoss 作为损失函数。

        __init__从类中初始化对象。

        从类实例化对象时执行的调用函数

        init 函数获取阈值,call 函数获取我们之前出售的 y_true 和 y_pred 参数。因此,我们将阈值声明为类变量,这允许我们给它一个初始值。

        在 __init__ 函数中,我们将阈值设置为 self.threshold。

        在调用函数中,所有阈值类变量将由 self.threshold 引用。

        以下是我们如何在 model.compile 中使用这个损失函数。

model.compile(optimizer='sgd', loss=MyHuberLoss(threshold=1.9))

六、创建对比损失(用于暹罗网络):

        连体网络比较两个图像是否相似。对比损失是暹罗网络中使用的损失函数。

        在上面的公式中,

        Y_true 是图像相似度细节的张量。如果图像相似,它们就是 1,如果不相似,它们就是 0。

        D 是图像对之间的欧几里德距离的张量。

        边距是一个常量,我们可以用它来强制它们之间的最小距离,以便将它们视为相似或不同。

        如果Y_true =1,则方程的第一部分变为 D²,第二部分变为零。因此,当 Y_true 接近 1 时,D² 项具有更大的权重。

        如果Y_true = 0,则方程的第一部分变为零,第二部分产生一些结果。这为最大项赋予了更大的权重,而为 D 平方项赋予了更少的权重,因此最大项在损失的计算中占主导地位。

使用包装函数的对比损失

def contrastive_loss_with_margin(margin):
    def contrastive_loss(y_true, y_pred):
        
        square_pred = K.square(y_pred)
        margin_square = K.square(K.maximum(margin - y_pred, 0))
        return K.mean(y_true * square_pred + (1 - y_true) * margin_square)
    return contrastive_loss

七、结论 

        Tensorflow 中不可用的任何损失函数都可以使用函数、包装函数或以类似的方式使用类来创建。阿琼·萨卡

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1079385.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux之open/close/read/write/lseek记录

一、文件权限 这里不做过多描述&#xff0c;只是简单的记录&#xff0c;因为下面的命令会涉及到。linux下一切皆是文件包括文本、硬件设备、管道、数据库、socket等。通过ls -l 命令可以查看到以下信息 drwxrwxrwx 1 root root 0 Oct 10 17:06 open -rwxrwxrwx 1 root roo…

js 滚动鼠标滑轮放大缩小图片

<div><h1>原图</h1><imgsrc"https://so.360tres.com/dmsmty/160_160_100/t01b29734b9604fb7aa.webp"/><h1>可放大缩小的图</h1><div class"imgView"><imgsrc"https://so.360tres.com/dmsmty/160_160_10…

建设一个完整的企业经营管理体系是什么样的

建设一个完整的企业经营管理体系是确保企业高效、可持续发展的基础。该体系包括组织架构、战略规划、运营管理、人力资源管理、财务管理等多个要素&#xff0c;下面将逐一进行详细介绍。 一、组织架构&#xff1a; 组织架构是企业内部各个部门、职能和层级之间的关系和分工。…

代码随想录算法训练营第五十三天 |1143.最长公共子序列、1035.不相交的线、53. 最大子序和动态规划

一、1143.最长公共子序列 题目链接/文章讲解&#xff1a;代码随想录 视频讲解&#xff1a;动态规划子序列问题经典题目 | LeetCode&#xff1a;1143.最长公共子序列_哔哩哔哩_bilibili 思考&#xff1a; 1.确定dp数组&#xff08;dp table&#xff09;以及下标的含义 dp[i][j]…

关键词搜索天猫商品列表数据,天猫商品列表数据接口

在网页抓取方面&#xff0c;可以使用 Python、Java 等编程语言编写程序&#xff0c;通过模拟 HTTP 请求&#xff0c;获取天猫网站上的商品页面。在数据提取方面&#xff0c;可以使用正则表达式、XPath 等方式从 HTML 代码中提取出有用的信息。值得注意的是&#xff0c;天猫网站…

pytest + yaml 框架 -56. 输出日志优化+allure报告优化

前言 v1.4.8 版本优化接口请求和响应输出日志&#xff0c;生成的allure报告也按步骤优化request和response详情日志 优化日志 用例 test_log1: -name: log1request:url: http://127.0.0.1:8000/api/test/demomethod: GETvalidate:- eq: [status_code, 200]- eq: ["cod…

从 0 到 1 打造企业数字化运营闭环

打造企业数字化运营闭环是现代企业发展的必然趋势。它涉及到信息技术、数据分析、流程优化等多个方面&#xff0c;通过有效整合和运用这些资源&#xff0c;可以实现从0到1的全面数字化转型。 下面是一个详细的介绍&#xff0c;包括步骤、关键要素和实施策略。 一、了解需求和…

吉客云对接打通金蝶云星空销售单查询接口与销售出库新增接口

吉客云对接打通金蝶云星空销售单查询接口与销售出库新增接口 接通系统&#xff1a;吉客云 吉客云是基于“网店管家”十五年电商ERP行业和技术积累基础上顺应产业发展需求&#xff0c;重新定位、全新设计推出的换代产品&#xff0c;从业务数字化和组织数字化两个方向出发&#x…

图扑 HT for Web 风格属性手册教程

图扑软件明星产品 HT for Web 是一套纯国产化独立自主研发的 2D 和 3D 图形界面可视化引擎。HT for Web&#xff08;以下简称 HT&#xff09;图元的样式由其 Style 属性控制&#xff0c;并且不同类型图元的 Style 属性各不相同。为了方便查询和理解图元的 Style 属性&#xff0…

第二证券:5.5G时代将至 算力基建迎政策助力

昨日&#xff0c;A股全线低开&#xff0c;三大股指盘中均跌超1%&#xff0c;盘中冲高回落&#xff0c;午后逐渐止跌。到收盘&#xff0c;沪指跌0.44%报3096.92点&#xff0c;深成指微跌0.03%报10106.96点&#xff0c;创业板指跌0.26%报1998.61点&#xff0c;两市算计成交7700元…

腾讯系数藏停摆一年 玩家被甩在维权路上

暴雷、维权、清退是过去一年数藏行业的常态。小平台跑了&#xff0c;腾讯这样的大厂以关停、退款终止运营数藏业务时&#xff0c;吃相也不好看。 在黑猫投诉平台上&#xff0c;幻核被投诉退款缓慢&#xff0c;曾经发行过数字藏品的QQ音乐被投诉违背发行时承诺的“持有356天后可…

嵌入式学习笔记(52)ADC的引入

11.1.1什么是ADC (1)ADC:analog digital converter,AD转换&#xff0c;模数转换&#xff08;也就是模拟转数字&#xff09; (2)CPU本身是数字的&#xff0c;而外部世界变量&#xff08;如电压、温度、高度、压力&#xff09;都是模拟的&#xff0c;所以需要用CPU来处理这些外…

【办公自动化】在Excel中按条件筛选数据并存入新的表2.0(文末送书)

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

多测师肖sir_高级金牌讲师_python之 字符、索引、切片、列表、集合004

python之 字符、索引、切片、列表、集合 一、索引 索引在公司中一般叫下标或角标 定义&#xff1a;我们可以直接使用索引来访问序列中的元素&#xff0c;同时索引可分为正向和负向两种&#xff0c;而切片也会用到索引&#xff0c;下面放上一个图&#xff0c;有助于大 家理解正…

Docker 网络访问原理解密

How Container Networking Works: Practical Explanation 这篇文章讲得非常棒&#xff0c;把docker network讲得非常清晰。 分为三个部分&#xff1a; 1&#xff09;docker 内部容器互联。 2&#xff09;docker 容器 访问 外部root 网络空间。 3&#xff09;外部网络空间…

MySQL数据库技术笔记(6)

新建两张表&#xff0c;班级表和学生表&#xff0c;因为班级表与学生表之间是 1 对多的关系&#xff0c;需要将少的表的主键放置多的 表中称为外键。 添加班级信息 添加学生信息并关联对应的班级 连接查询&#xff0c;表示查询的时候关联多张表进行查询 最终两张表的交叉连接…

面试算法24:反转链表

题目 定义一个函数&#xff0c;输入一个链表的头节点&#xff0c;反转该链表并输出反转后链表的头节点。例如&#xff0c;把图4.8&#xff08;a&#xff09;中的链表反转之后得到的链表如图4.8&#xff08;b&#xff09;所示。 分析 由于节点j的next指针指向了它的前一个节…

别用==比较包装类

前两天工作把代码合并到主分支时&#xff0c;被公司的安全监测机制拦截了&#xff0c;一看是因为用了来比较Integer类型。 在阿里开发手册中&#xff0c;有这样一条&#xff1a;在包装类进行比较的时候&#xff0c;要用equals方法&#xff0c;而不是。 具体的原因下面也讲解的…

提升爬虫IP时效:解决被封IP的难题

在进行数据采集时&#xff0c;经常会遇到被目标网站封禁IP的情况&#xff0c;这给爬虫系统带来了困扰。本文将介绍如何提升爬虫IP的时效&#xff0c;解决被封IP的难题&#xff0c;帮助您顺利进行数据采集&#xff0c;不再受限于IP封禁。 第一步&#xff1a;使用爬虫IP 使用爬虫…

飞书应用机器人文件上传

背景&#xff1a; 接上一篇 flask_apscheduler实现定时推送飞书消息&#xff0c;当检查出的异常结果比较多的时候&#xff0c;群里会有很多推送消息&#xff0c;一条条检查工作量会比较大&#xff0c;且容易出现遗漏。   现在需要将定时任务执行的结果记录到文件&#xff0c;…