tf.Keras (tf-1.15)使用记录4-model.fit方法及其callbacks参数

news2025/2/2 18:14:35

model.fit() 方法是 TensorFlow Keras 中用于训练模型的核心方法。
其中里面的callbacks参数是实现模型保存、监控、以及和tensorboard联动的重要API

1 model.fit() 方法的参数及使用

必需参数

  • x: 训练数据的输入。可以是 NumPy 数组、TensorFlow tf.data.Dataset、Python 生成器或 keras.utils.Sequence 实例。
  • y: 训练数据的目标(标签)。与输入 x 相对应,应该是 NumPy 数组或 TensorFlow tf.data.Dataset。当 xtf.data.Dataset、生成器或 Sequence 实例时,y 应该不被提供,因为 x 已经包含了输入和目标。

常用可选参数

  • batch_size: 整数,指定进行梯度更新时每个批次的样本数。默认值为 32。注意,当使用 tf.data.Dataset、生成器或 Sequence 作为输入时,不应指定 batch_size,因为这些数据结构已经定义了批次大小。
  • epochs: 整数,训练模型的轮数,即整个数据集的前向和反向传播次数。
  • verbose: 整数,日志显示模式。0 = 不在标准输出流中输出日志信息,1 = 进度条(默认),2 = 每轮一行。
  • callbacks: keras.callbacks.Callback 实例的列表。一系列在训练过程中会被调用的回调函数,用于查看训练过程中内部状态和统计信息。
  • validation_split: 浮点数,0 到 1 之间,用来指定一定比例的训练数据作为验证数据的比例。模型会在这些数据上评估损失和任何模型指标,但这些数据不会用于训练。
  • validation_data: 用作验证的数据。格式可以是 (X_val, y_val) 的元组,或者是 tf.data.Dataset。如果提供此参数,则不会根据 validation_split 从训练数据中分割验证数据。
  • shuffle: 布尔值或字符串,表示是否在每轮训练前打乱数据。默认为 True。当设置为 False 时,不会打乱数据。当输入为 tf.data.Dataset、生成器或 Sequence 实例时,此参数无效,因为这些数据结构可能已经定义了自己的打乱数据的方式。
  • initial_epoch: 用于恢复之前的训练。从该轮次开始训练,之前的轮次被视为已经训练过。

高级参数

  • steps_per_epoch: 整数,当使用生成器或 Sequence 实例作为输入时定义一个 epoch 完成并开始下一个 epoch 的总步数(批次数)。通常,应该等于数据集的样本数除以批次大小。
  • validation_steps: 当 validation_data 是生成器或 Sequence 实例时,此参数指定在停止前验证集的总步数(批次数)。
  • validation_batch_size: 整数,仅当 validation_data 是 NumPy 数组时有效。指定验证批次的大小。
  • validation_freq: 指定验证的频率。可以是整数,也可以是 'epoch' 或列表。如果是整数,则表示每多少个 epoch 验证一次。如果是列表,则列表中的元素指定了需要进行验证的 epoch。

使用示例

基本用法:

model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.2)

使用验证数据:

model.fit(x_train, y_train, epochs=10, validation_data=(x_val, y_val))

使用回调函数:

from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(monitor='val_loss', patience=3)
model.fit(x_train, y_train, epochs=10, validation_split=0.2, callbacks=[early_stopping])

使用 tf.data.Dataset

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(32)
model.fit(train_dataset, epochs=10, validation_data=val_dataset)

model.fit() 方法提供了灵活的方式来训练模型,通过合理设置参数,可以有效地控制训练过程和评估模型性能。

2 callbacks参数使用

callbacks 参数是 model.fit() 方法中一个重要参数,属于keras的高级用法,它允许在训练的不同阶段(如训练开始、训练结束、每个 epoch 开始/结束时等)执行特定的操作。

callbacks 是一个 tf.keras.callbacks.Callback 实例的列表,每个实例都能够访问到模型的内部状态和统计信息。TensorFlow Keras 提供了多种内置的回调函数,同时也支持自定义回调。
以下是callbacks类的全部方法类(https://keras.io/api/callbacks/):
在这里插入图片描述

  1. ModelCheckpoint: 在训练过程中保存模型或模型权重。

    • filepath: 保存模型的路径。
    • monitor: 被监视的数据。
    • verbose: 详细信息模式。
    • save_best_only: 若为 True,则只保存在验证集上性能最好的模型。
    • save_weights_only: 若为 True,则只保存模型的权重。
    • mode: {auto, min, max} 中的一个。决定监视的数据是应该最大化还是最小化。
    • save_freq: 保存模型的频率。
    checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath='model.h5', save_best_only=True, monitor='val_loss', mode='min')
    
  2. EarlyStopping: 当被监视的数据不再提升,则停止训练。

    • monitor: 被监视的数据。
    • min_delta: 改进的最小变化量,小于这个量的改进将被忽略。
    • patience: 没有进步的训练轮数,在这之后训练将被停止。
    • verbose: 详细信息模式。
    • mode: {auto, min, max} 中的一个。
    early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10)
    
  3. ReduceLROnPlateau: 当学习停滞时,减少学习率。

    • monitor: 被监视的数据。
    • factor: 学习率将以这个因子减少。新的学习率 = 学习率 * 因子。
    • patience: 没有进步的训练轮数,在这之后学习率将被减少。
    • verbose: 详细信息模式。
    • mode: {auto, min, max} 中的一个。
    • min_lr: 学习率的下限。
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001)
    
  4. TensorBoard: 为 TensorFlow 提供的可视化工具。

    • log_dir: 用来保存日志文件的路径,TensorBoard 将读取这个路径下的日志。
    • histogram_freq: 对于模型层的激活和权重直方图的计算频率(每个 epoch)。
    • write_graph: 是否在 TensorBoard 中可视化图形。如果 write_graph 被打开,日志文件会变得非常大。
    tensorboard = tf.keras.callbacks.TensorBoard(log_dir='./logs')
    

使用示例

callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10),
    tf.keras.callbacks.ModelCheckpoint(filepath='model.h5', save_best_only=True, monitor='val_loss', mode='min'),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=0.001),
    tf.keras.callbacks.TensorBoard(log_dir='./logs')
]

model.fit(x_train, y_train, validation_split=0.2, epochs=50, callbacks=callbacks)

自定义回调

也可以通过继承 tf.keras.callbacks.Callback 类来创建自定义回调,允许在训练的不同阶段执行自定义的逻辑。

class CustomCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        # 每个 epoch 结束时执行
        keys = list(logs.keys())
        print(f"结束 epoch {epoch},损失 = {logs['loss']}, 验证损失 = {logs['val_loss']}")

model.fit(x_train, y_train, validation_split=0.2, epochs=50, callbacks=[CustomCallback()])

回调提供了一种灵活的方式来嵌入训练过程,使得你可以在不改变模型代码的情况下,监控模型的训练、保存模型、调整学习率等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2290883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Easy系列PLC尺寸测量功能块ST代码(激光微距仪应用)

激光微距仪可以测量短距离内的产品尺寸,产品规格书的测量 精度可以到0.001mm。具体需要看不同的型号。 1、激光微距仪 2、尺寸测量应用 下面我们以测量高度为例子,设计一个高度测量功能块,同时给出测量数据和合格不合格指标。 3、高度测量功能块 4、复位完成信号 5、功能…

996引擎 -地图-添加安全区

996引擎 -地图-添加安全区 文件位置配置 cfg_startpoint.xls特效效果1345参考资料文件位置 文件位置服务端D:\996M2-lua\MirServer-lua\Mir200客户端D:\996M2-lua\996M2_debug\dev配置 cfg_startpoint.xls 服务端\Mir200\Envir\DATA\cfg_startpoint.xls 填歪了也有可能只画一…

[Collection与数据结构] B树与B+树

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…

redex快速体验

第一步: 2.回调函数在每次state发生变化时候自动执行

【VM】VirtualBox安装CentOS8虚拟机

阅读本文前,请先根据 VirtualBox软件安装教程 安装VirtualBox虚拟机软件。 1. 下载centos8系统iso镜像 可以去两个地方下载,推荐跟随本文的操作用阿里云的镜像 centos官网:https://www.centos.org/download/阿里云镜像:http://…

电子电气架构 --- 汽车电子拓扑架构的演进过程

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 简单,单纯,喜欢独处,独来独往,不易合同频过着接地气的生活…

自动驾驶---苏箐对智驾产品的思考

1 前言 对于更高级别的自动驾驶,很多人都有不同的思考,方案也好,产品也罢。最近在圈内一位知名的自动驾驶专家苏箐发表了他自己对于自动驾驶未来的思考。 苏箐是地平线的副总裁兼首席架构师,同时也是高阶智能驾驶解决方案SuperDri…

90,【6】攻防世界 WEB Web_php_unserialize

进入靶场 进入靶场 <?php // 定义一个名为 Demo 的类 class Demo { // 定义一个私有属性 $file&#xff0c;默认值为 index.phpprivate $file index.php;// 构造函数&#xff0c;当创建类的实例时会自动调用// 接收一个参数 $file&#xff0c;用于初始化对象的 $file 属…

【数据分析】案例04:豆瓣电影Top250的数据分析与Web网页可视化(numpy+pandas+matplotlib+flask)

豆瓣电影Top250的数据分析与Web网页可视化(numpy+pandas+matplotlib+flask) 豆瓣电影Top250官网:https://movie.douban.com/top250写在前面 实验目的:实现豆瓣电影Top250详情的数据分析与Web网页可视化。电脑系统:Windows使用软件:PyCharm、NavicatPython版本:Python 3.…

Banana JS,一个严格子集 JavaScript 的解释器

项目地址&#xff1a;https://github.com/shajunxing/banana-js 特色 我的目标是剔除我在实践中总结的JavaScript语言的没用的和模棱两可的部分&#xff0c;只保留我喜欢和需要的&#xff0c;创建一个最小的语法解释器。只支持 JSON 兼容的数据类型和函数&#xff0c;函数是第…

2025.2.1——四、php_rce RCE漏洞|PHP框架

题目来源&#xff1a;攻防世界 php_rce 目录 一、打开靶机&#xff0c;整理信息 二、解题思路 step 1&#xff1a;PHP框架漏洞以及RCE漏洞信息 1.PHP常用框架 2.RCE远程命令执行 step 2&#xff1a;根据靶机提示&#xff0c;寻找版本漏洞 step 3&#xff1a;进行攻击…

对比DeepSeek、ChatGPT和Kimi的学术写作撰写引言能力

引言 引言部分引入研究主题&#xff0c;明确研究背景、问题陈述&#xff0c;并提出研究的目的和重要性&#xff0c;最后&#xff0c;概述研究方法和论文结构。 下面我们使用DeepSeek、ChatGPT4以及Kimi辅助引言撰写。 提示词&#xff1a; 你现在是一名[计算机理论专家]&#…

【C++篇】哈希表

目录 一&#xff0c;哈希概念 1.1&#xff0c;直接定址法 1.2&#xff0c;哈希冲突 1.3&#xff0c;负载因子 二&#xff0c;哈希函数 2.1&#xff0c;除法散列法 /除留余数法 2.2&#xff0c;乘法散列法 2.3&#xff0c;全域散列法 三&#xff0c;处理哈希冲突 3.1&…

数据密码解锁之DeepSeek 和其他 AI 大模型对比的神秘面纱

本篇将揭露DeepSeek 和其他 AI 大模型差异所在。 目录 ​编辑 一本篇背景&#xff1a; 二性能对比&#xff1a; 2.1训练效率&#xff1a; 2.2推理速度&#xff1a; 三语言理解与生成能力对比&#xff1a; 3.1语言理解&#xff1a; 3.2语言生成&#xff1a; 四本篇小结…

知识管理系统推动企业知识创新与人才培养的有效途径分析

内容概要 本文旨在深入探讨知识管理系统在现代企业中的应用及其对于知识创新与人才培养的重要性。通过分析知识管理系统的概念&#xff0c;企业可以认识到它不仅仅是信息管理的一种工具&#xff0c;更是提升整体创新能力的战略性资产。知识管理系统通过集成企业内部信息资源&a…

nth_element函数——C++快速选择函数

目录 1. 函数原型 2. 功能描述 3. 算法原理 4. 时间复杂度 5. 空间复杂度 6. 使用示例 8. 注意事项 9. 自定义比较函数 11. 总结 nth_element 是 C 标准库中提供的一个算法&#xff0c;位于 <algorithm> 头文件中&#xff0c;用于部分排序序列。它的主要功能是将…

Hot100之双指针

283移动零 题目 思路解析 那我们就把不为0的数字都放在数组前面&#xff0c;然后数组后面的数字都为0就行了 代码 class Solution {public void moveZeroes(int[] nums) {int left 0;for (int num : nums) {if (num ! 0) {nums[left] num;// left最后会变成数组中不为0的数…

DeepSeek-R1论文研读:通过强化学习激励LLM中的推理能力

DeepSeek在朋友圈&#xff0c;媒体&#xff0c;霸屏了好长时间&#xff0c;春节期间&#xff0c;研读一下论文算是时下的回应。论文原址&#xff1a;[2501.12948] DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning 摘要&#xff1a; 我们…

群晖Alist套件无法挂载到群晖webdav,报错【连接被服务器拒绝】

声明&#xff1a;我不是用docker安装的 在套件中心安装矿神的Alist套件后&#xff0c;想把夸克挂载到群晖上&#xff0c;方便复制文件的&#xff0c;哪知道一直报错&#xff0c;最后发现问题出在两个地方&#xff1a; 1&#xff09;挂载的路径中&#xff0c;直接填 dav &…

three.js+WebGL踩坑经验合集(6.2):负缩放,负定矩阵和行列式的关系(3D版本)

本篇将紧接上篇的2D版本对3D版的负缩放矩阵进行解读。 (6.1):负缩放&#xff0c;负定矩阵和行列式的关系&#xff08;2D版本&#xff09; 既然three.js对3D版的负缩放也使用行列式进行判断&#xff0c;那么&#xff0c;2D版的结论用到3D上其实是没毛病的&#xff0c;THREE.Li…