TensorFlow 模型中的回调函数与损失函数

news2024/9/24 23:25:54

回调函数

tf.keras 的回调函数实际上是一个类,一般是在 model.fit 时作为参数指定,用于控制在训练过程开始或者在训练过程结束,在每个 epoch 训练开始或者训练结束,在每个 batch 训练开始或者训练结束时执行一些操作,例如收集一些日志信息,改变学习率等超参数,提前终止训练过程等等。同样地,针对 model.evaluate 或者 model.predict 也可以指定 callbacks 参数,用于控制在评估或预测开始或者结束时,在每个 batch 开始或者结束时执行一些操作,但这种用法相对少见。

大部分时候,keras.callbacks 子模块中定义的回调函数类已经足够使用了,如果有特定的需要,我们也可以通过对 keras.callbacks.Callbacks 实施子类化构造自定义的回调函数。所有回调函数都继承至 keras.callbacks.Callbacks 基类,拥有 params 和 model 这两个属性。

其中 params 是一个 dict,记录了训练相关参数 (例如 verbosity, batch size, number of epochs 等等)。model 即当前关联的模型的引用。此外,对于回调类中的一些方法如 on_epoch_begin,on_batch_end,还会有一个输入参数 logs, 提供有关当前 epoch 或者 batch 的一些信息,并能够记录计算结果,如果 model.fit 指定了多个回调函数类,这些 logs 变量将在这些回调函数类的同名函数间依顺序传递。

内置回调函数

  • BaseLogger:收集每个 epoch 上 metrics 在各个 batch 上的平均值,对 stateful_metrics 参数中的带中间状态的指标直接拿最终值无需对各个 batch 平均,指标均值结果将添加到 logs 变量中。该回调函数被所有模型默认添加,且是第一个被添加的。
  • History:将 BaseLogger 计算的各个 epoch 的 metrics 结果记录到 history 这个 dict 变量中,并作为 model.fit 的返回值。该回调函数被所有模型默认添加,在 BaseLogger 之后被添加。
  • EarlyStopping:当被监控指标在设定的若干个 epoch 后没有提升,则提前终止训练。
  • TensorBoard:为 Tensorboard 可视化保存日志信息。支持评估指标,计算图,模型参数等的可视化。
  • ModelCheckpoint:在每个 epoch 后保存模型。
  • ReduceLROnPlateau:如果监控指标在设定的若干个 epoch 后没有提升,则以一定的因子减少学习率。
  • TerminateOnNaN:如果遇到 loss 为 NaN,提前终止训练。
  • LearningRateScheduler:学习率控制器。给定学习率 lr 和 epoch 的函数关系,根据该函数关系在每个 epoch 前调整学习率。
  • CSVLogger:将每个 epoch 后的 logs 结果记录到 CSV 文件中。
  • ProgbarLogger:将每个 epoch 后的 logs 结果打印到标准输出流中。

自定义回调函数

可以使用 callbacks.LambdaCallback 编写较为简单的回调函数,也可以通过对 callbacks.Callback 子类化编写更加复杂的回调函数逻辑。

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers,models,losses,metrics,callbacks
import tensorflow.keras.backend as K

# 示范使用LambdaCallback编写较为简单的回调函数
import json
json_log = open('./data/keras_log.json', mode='wt', buffering=1)
json_logging_callback = callbacks.LambdaCallback(
    on_epoch_end=lambda epoch, logs: json_log.write(
        json.dumps(dict(epoch = epoch,**logs)) + '\n'),
    on_train_end=lambda logs: json_log.close()
)

# 示范通过Callback子类化编写回调函数(LearningRateScheduler的源代码)
class LearningRateScheduler(callbacks.Callback):

    def __init__(self, schedule, verbose=0):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        try:
            lr = float(K.get_value(self.model.optimizer.lr))
            lr = self.schedule(epoch, lr)
        except TypeError:  # Support for old API for backward compatibility
            lr = self.schedule(epoch)
        if not isinstance(lr, (tf.Tensor, float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        if isinstance(lr, ops.Tensor) and not lr.dtype.is_floating:
            raise ValueError('The dtype of Tensor should be float')
        K.set_value(self.model.optimizer.lr, K.get_value(lr))
        if self.verbose > 0:
            print('\nEpoch %05d: LearningRateScheduler reducing learning '
                 'rate to %s.' % (epoch + 1, lr))

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['lr'] = K.get_value(self.model.optimizer.lr)

损失函数

一般来说,监督学习的目标函数由损失函数和正则化项组成。(Objective = Loss + Regularization),对于 keras 模型,目标函数中的正则化项一般在各层中指定,例如使用 Dense 的 kernel_regularizer 和 bias_regularizer 等参数指定权重使用 l1 或者 l2 正则化项,此外还可以用 kernel_constraint 和 bias_constraint 等参数约束权重的取值范围,这也是一种正则化手段。

损失函数在模型编译时候指定。对于回归模型,通常使用的损失函数是均方损失函数 mean_squared_error。对于二分类模型,通常使用的是二元交叉熵损失函数 binary_crossentropy。

对于多分类模型,如果 label 是 one-hot 编码的,则使用类别交叉熵损失函数 categorical_crossentropy。如果 label 是类别序号编码的,则需要使用稀疏类别交叉熵损失函数 sparse_categorical_crossentropy。如果有需要,也可以自定义损失函数,自定义损失函数需要接收两个张量 y_true,y_pred 作为输入参数,并输出一个标量作为损失函数值。

损失函数和正则化项

tf.keras.backend.clear_session()

model = models.Sequential()
model.add(layers.Dense(64, input_dim=64,
                kernel_regularizer=regularizers.l2(0.01),
                activity_regularizer=regularizers.l1(0.01),
                kernel_constraint = constraints.MaxNorm(max_value=2, axis=0)))
model.add(layers.Dense(10,
        kernel_regularizer=regularizers.l1_l2(0.01,0.01),activation = "sigmoid"))
model.compile(optimizer = "rmsprop",
        loss = "binary_crossentropy",metrics = ["AUC"])
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 64)                4160
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650
=================================================================
Total params: 4,810
Trainable params: 4,810
Non-trainable params: 0
_________________________________________________________________

内置损失函数

内置的损失函数一般有类的实现和函数的实现两种形式。如:CategoricalCrossentropy 和 categorical_crossentropy 都是类别交叉熵损失函数,前者是类的实现形式,后者是函数的实现形式。常用的一些内置损失函数说明如下。

  • mean_squared_error(均方误差损失,用于回归,简写为 mse, 类与函数实现形式分别为 MeanSquaredError 和 MSE)
  • mean_absolute_error (平均绝对值误差损失,用于回归,简写为 mae, 类与函数实现形式分别为 MeanAbsoluteError 和 MAE)
  • mean_absolute_percentage_error (平均百分比误差损失,用于回归,简写为 mape, 类与函数实现形式分别为 MeanAbsolutePercentageError 和 MAPE)
  • Huber(Huber 损失,只有类实现形式,用于回归,介于 mse 和 mae 之间,对异常值比较鲁棒,相对 mse 有一定的优势)
  • binary_crossentropy(二元交叉熵,用于二分类,类实现形式为 BinaryCrossentropy)
  • categorical_crossentropy(类别交叉熵,用于多分类,要求 label 为 onehot 编码,类实现形式为 CategoricalCrossentropy)
  • sparse_categorical_crossentropy(稀疏类别交叉熵,用于多分类,要求 label 为序号编码形式,类实现形式为 SparseCategoricalCrossentropy)
  • hinge(合页损失函数,用于二分类,最著名的应用是作为支持向量机 SVM 的损失函数,类实现形式为 Hinge)
  • kld(相对熵损失,也叫 KL 散度,常用于最大期望算法 EM 的损失函数,两个概率分布差异的一种信息度量。类与函数实现形式分别为 KLDivergence 或 KLD)
  • cosine_similarity(余弦相似度,可用于多分类,类实现形式为 CosineSimilarity)

自定义损失函数

自定义损失函数接收两个张量 y_true, y_pred 作为输入参数,并输出一个标量作为损失函数值。也可以对 tf.keras.loss.Loss 进行子类化,重写 call 方法实现损失的计算逻辑,从而得到损失函数的类的实现。

下面是一个 focal Loss 的自定义实现示范。focal Loss 是一种对 binary_crossentropy 的改进损失函数形式。它在样本不均衡和存在较多易分类的样本时相比 binary_crossentropy 具有明显的优势。

它有两个可调参数,alpha 参数和 gamma 参数。其中 alpha 参数主要用于衰减负样本的权重,gamma 参数主要用于衰减容易训练样本的权重。从而让模型更加聚焦在正样本和困难样本上。这就是为什么这个损失函数叫做 focal Loss。

def focal_loss(gamma=2., alpha=0.75):

    def focal_loss_fixed(y_true, y_pred):
        bce = tf.losses.binary_crossentropy(y_true, y_pred)
        p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
        modulating_factor = tf.pow(1.0 - p_t, gamma)
        loss = tf.reduce_sum(alpha_factor * modulating_factor * bce,axis = -1 )
        return loss
    return focal_loss_fixed

class FocalLoss(tf.keras.losses.Loss):

    def __init__(self,gamma=2.0,alpha=0.75,name = "focal_loss"):
        self.gamma = gamma
        self.alpha = alpha

    def call(self,y_true,y_pred):
        bce = tf.losses.binary_crossentropy(y_true, y_pred)
        p_t = (y_true * y_pred) + ((1 - y_true) * (1 - y_pred))
        alpha_factor = y_true * self.alpha + (1 - y_true) * (1 - self.alpha)
        modulating_factor = tf.pow(1.0 - p_t, self.gamma)
        loss = tf.reduce_sum(alpha_factor * modulating_factor * bce,axis = -1 )
        return loss

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1334892.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何使用PatchaPalooza对微软每月的安全更新进行全面深入的分析

关于PatchaPalooza PatchaPalooza是一款针对微软每月安全更新的强大分析工具,广大研究人员可以直接使用该工具来对微软每月定期推送的安全更新代码进行详细、全面且深入的安全分析。 PatchaPalooza使用了微软MSRC CVRF API的强大功能来获取、存储和分析安全更新数…

大语言模型说明书

在浩瀚的信息宇宙中,大语言模型如同一颗璀璨的星星正在熠熠生辉。21世纪以来,人工智能可谓是飞速发展,从简单的神经网络到大语言模型、生成式AI,这并非仅仅是一种技术的进步,更是人类智慧的飞跃。大语言模型不仅仅是语…

Wireshark网络工具来了

Wireshark是网络包分析工具。网络包分析工具的主要作用是尝试捕获网络包,并尝试显示包的尽可能详细的情况。 Wireshark是一个免费开源软件,不需要付费,免费使用,可以直接登陆到Wireshark的官网下载安装。 在windows环境中&#x…

【强化学习】PPO:近端策略优化算法

近端策略优化算法 《Proximal Policy Optimization Algorithms》 论文地址:https://arxiv.org/pdf/1707.06347.pdf 一、 置信域方法(Trust Region Methods) ​ 设 π θ o l d \pi_{\theta_{old}} πθold​​是先前参数为 θ o l d \theta_{old} θold​的策略网…

JavaScript:DOM-事件

JavaScript:DOM - 事件 事件监听什么是事件监听事件监听的方式事件类型点击事件鼠标事件键盘事件焦点事件文本框输入事件 事件对象什么是事件对象获取事件对象事件对象常用属性事件解绑 环境对象 this事件流事件捕获事件冒泡事件捕获与事件冒泡的影响阻止冒泡事件委…

CentOS7安装Java11

文章目录 Java11下载地址卸载OpenJDK查询原系统安装的 JDK根据原系统安装的 JDK 进行卸载命令修改 安装JDK生成JRE Java11下载地址 https://www.oracle.com/java/technologies/javase/jdk11-archive-downloads.html 卸载OpenJDK 查询原系统安装的 JDK java -version yum l…

如何将本地websocket发布至公网并实现远程访问服务端

文章目录 1. Java 服务端demo环境2. 在pom文件引入第三包封装的netty框架maven坐标3. 创建服务端,以接口模式调用,方便外部调用4. 启动服务,出现以下信息表示启动成功,暴露端口默认99995. 创建隧道映射内网端口6. 查看状态->在线隧道,复制所创建隧道的公网地址加端口号7. 以…

深入探讨多模态模型和计算机视觉

近年来,机器学习领域在从图像识别到自然语言处理的不同问题类型上取得了显着进展。然而,这些模型中的大多数都对来自单一模态的数据进行操作,例如图像、文本或语音。相比之下,现实世界的数据通常来自多种模态,例如图像…

前端---html 的介绍

1. 网页效果图 --CSDN 2. html的定义 HTML 的全称为&#xff1a;HyperText Mark-up Language, 指的是超文本标记语言。 标记&#xff1a;就是标签, <标签名称> </标签名称>, 比如: <html></html>、<h1></h1> 等&#xff0c;标签大多数都是…

Echarts随机生成颜色

Echarts生成随机颜色&#xff0c;并且不要黑色、灰色、棕色等难看的颜色&#xff0c;暖色系并且颜色亮丽&#xff0c; 可以通过修改saturation 和lightness 的随机数值&#xff0c;提高颜色饱和度和亮度 function generateWarmColor() {let hue Math.floor(Math.random() * 3…

【ctf】whireshark流量分析之tcp_杂篇

目录 简介 常考 图片类 提取png.pcap&#xff08;常规&#xff09; 异常的流量分析&#xff08;*&#xff0c;特殊&#xff09; john-in-the-middle&#xff08;特殊&#xff09; ​编辑 zip类 1.pcap&#xff08;常规&#xff09; 方法1&#xff08;常规提取压缩包&…

IAP在编程升级

以STM32F103ZET6为例讲解&#xff0c; FLASH 512KB,SRAM64KB. 让APP程序加载在FLASH里运行&#xff0c;在SRAM运行的先不讲解。 IAP执行流程 当加入 IAP 程序之后&#xff0c;程序运行流程如图。 APP程序的生成步骤 1.APP 程序起始地址设置方法 我们设置起始地址&#xff…

浮点数的转换--IEEE 754

IEEE754标准是一种浮点数表示标准&#xff0c;一般分为 单精度&#xff08;32位的二进制数&#xff09;&#xff1b;双精度&#xff08;64位的二进制数&#xff09; 根据国际标准IEEE754&#xff0c;任意一个二进制浮点数V可以表示为下面形式&#xff1a; V (-1)^s *&#…

JavaWeb后门(webshell)基础

0x00 基础 JSP JSP全称为JavaServer Pages&#xff0c;是一种用于开发支持动态内容的Web页面的技术。它有助于开发人员通过使用特殊的JSP标记在HTML页面中插入Java代码&#xff0c;其中大多数以<&#xff05;开头&#xff0c;以&#xff05;>结尾。Java是一种通用的计算…

互联网+建筑工地:技术革新引领建筑行业的未来

随着科技的飞速发展&#xff0c;互联网正日益渗透到建筑工地的方方面面。从设计、施工到管理&#xff0c;互联网建筑工地的深度融合不仅推动了建筑行业的数字化转型&#xff0c;还为工地管理、信息交流、安全监控等带来了全新的解决方案。本文将介绍互联网建筑工地的几个关键技…

【c++】入门2

函数重载 函数重载&#xff1a;是函数的一种特殊情况&#xff0c;C允许在同一作用域中声明几个功能类似的同名函数&#xff0c;这 些同名函数的形参列表(参数个数 或 类型 或 类型顺序)不同&#xff0c;常用来处理实现功能类似数据类型 不同的问题。 c区分重载函数是根据参数…

golang的jwt学习笔记

文章目录 初始化项目加密一步一步编写程序另一个参数--加密方式关于StandardClaims 解密解析出来的怎么用关于`MapClaims`上面使用结构体的全代码实战项目关于验证这个项目的前端初始化项目 自然第一步是暗转jwt-go的依赖啦 #go get github.com/golang-jwt/jwt/v5 go get githu…

AAAI 2024录用论文合集,包含图神经网络、时间序列、多模态、异常检测等热门研究方向

AAAI是国际顶级人工智能学术会议&#xff0c;属于CCF A类&#xff0c;在人工智能领域享有盛誉。今年的AAAI 会议投稿量突破了历史记录&#xff0c;共有12100篇投稿&#xff08;主赛道&#xff09;&#xff0c;最终录用2342篇&#xff0c;录用率为23.75%。对比前几年有了很大的提…

人工智能_机器学习073_SVM支持向量机_人脸识别模型建模_预测可视化_网格搜索交叉验证最优化参数对比---人工智能工作笔记0113

接着上一节来说,可以看到我们已经找到了合适的参数,然后 我们可以看一下这里 gc.best_params_ 就可以打印出最合适的参数 然后我们把最合适串按说填入到代码中,然后进行计算,看看得分 可以看到得分,训练数据是1.0 然后测试数据得分是0.7857...对吧

unity HoloLens2开发,使用Vuforia识别实体 触发交互(二)(有dome)

提示&#xff1a;文章有错误的地方&#xff0c;还望诸位大神不吝指教&#xff01; 文章目录 前言一、打包到HoloLens二、Vuforia相关1.配置识别框2.制作一个半透明识别框&#xff1a;3.设置如下4.问题 四 HoloLens2 问题总结 前言 我使用的utniy 版本&#xff1a;Unity 2021.3…