【需求实现】Tensorflow2的曲线拟合(二):进度条简化

news2024/9/20 8:05:22

文章目录

  • 导读
  • 普通的输出方式
  • 上下求索
    • TensorBoard是个不错的切入点
    • 与Callback参数对应的Callback方法
    • 官方的内置Callback
    • 官方进度条
    • 简单的猜测与简单的验证
    • 拼图凑齐了!

导读

在训练模型的过程中往往会有日志一堆一堆的困扰。我并不想知道,因为最后我会在变量里面查询,反正训练过程中也没心思看。于是,就想把进度条简化一下。下面给出解决方案。

普通的输出方式

对于一般的训练过程,我们可能在Tensorflow中的fit方法中,将verbose置为 1 1 1,或者不设置verbose而让Tensorflow默认verbose 1 1 1。这样的话就会有如下图一样长篇大论的输出。

在这里插入图片描述
虽然不至于很烦躁,但是实在不愿意去管这些事情。又不是做生物实验,完全不需要人在这边守着嘛。趁着这时间泡杯咖啡多好。

于是呢,就想着源码里面是如何将输出显示出来的。

上下求索

TensorBoard是个不错的切入点

Tensorflow海量的源码中寻找一个输出无疑是大海捞针,对于Windows用户来说找起来超级麻烦,除非Linux用户直接用grep命令作弊。

但是呢,突然就注意到,Tensorflow还有一个Tensorboard,是在fit方法里面的callback参数中出现。既然日志能够从callback参数中获得,那么这里是有什么玄机吗?

与Callback参数对应的Callback方法

于是找到了文件中tensorflow/tensorflow/python/keras/engine/training.py,也就是官方GitHub的这一页(点击直达callbacks注释那一行),他是这么说明的:

'''
callbacks: List of `keras.callbacks.Callback` instances.
    List of callbacks to apply during training.
    See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
    and `tf.keras.callbacks.History` callbacks are created automatically
    and need not be passed into `model.fit`.
    `tf.keras.callbacks.ProgbarLogger` is created or not based on
    `verbose` argument to `model.fit`.
    Callbacks with batch-level calls are currently unsupported with
    `tf.distribute.experimental.ParameterServerStrategy`, and users are
    advised to implement epoch-level calls instead with an appropriate
    `steps_per_epoch` value.
'''

这也就是说,官方已经把进度条内置到ProgbarLogger这个类里面了,并通过callbacks调用。

其中,对于callbacks是这么调用的:

callbacks.on_train_begin()

而这个on_train_begin又是属于CallbackList类中:

# in CallbackList class
def on_train_begin(self, logs=None):
  """Calls the `on_train_begin` methods of its callbacks.

  Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
  """
  logs = self._process_logs(logs)
  for callback in self.callbacks:
    callback.on_train_begin(logs)

也就是说,是遍历callbacks中的所有callback然后一一执行。

执行过程的源码在Callback类中(类名跟上一个不一样哦):

# in Callback class
@doc_controls.for_subclass_implementers
def on_train_begin(self, logs=None):
  """Called at the beginning of training.

  Subclasses should override for any actions to run.

  Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
  """

但很明显,这个源码就是想让我们自定义。

刚刚好发现进度条也在这里。那接下来的事情就更明确了,去找这个类就行了。一方面是查看进度条的原理,另一方面则是按照官方的进度条仿写一个简单的进度条。

官方的内置Callback

于是就找到了tensorflow/tensorflow/python/keras/callbacks.py文件中,也就是官方GitHub的这一页(点击直达ProgbarLogger类定义的那一行)。但比较可惜的是,这个类定义的时候注释太少了,并不能很确定每个类中的各个方法都在做什么。怎么办呢?

别忘了官方给的提示:【内置了ProgbarLogger类与History类】。我们既然需要了解如何在Callback里面调用,那么就需要了解这两个东西是如何插入进去的。但是这个类注释实在是太少了,很多东西看得不明不白的,该怎么办呢?

当然是【贪心搜索】了呀,按照【命名】与【个人经验】去判断哪一个最可能是我们想要的方法。虽然很不靠谱,但是万一运气好撞上了呢?于是呢,就找到了_add_default_callbacks方法,也就是默认插入的一些东西。他是这么写的:

  def _add_default_callbacks(self, add_history, add_progbar):
    """Adds `Callback`s that are always present."""
    self._progbar = None
    self._history = None

    for cb in self.callbacks:
      if isinstance(cb, ProgbarLogger):
        self._progbar = cb
      elif isinstance(cb, History):
        self._history = cb

    if self._progbar is None and add_progbar:
      self._progbar = ProgbarLogger(count_mode='steps')
      self.callbacks.insert(0, self._progbar)

    if self._history is None and add_history:
      self._history = History()
      self.callbacks.append(self._history)

总之就是一些判断,如果空就创建。

看来插入就是insert方法与append方法了。不难猜测,也不用猜测,callbacks将是一个数组。

官方进度条

既然知道了进度条是如何被调用的,那么接下来就是得了解官方的进度条是怎么添加的。

当然,还是在ProgbarLogger类中,为了方便点击这里就能传送。这里面明显的给出了输出Epoch,正好就是我们需要找到的输出。源码是这样的:

def on_epoch_begin(self, epoch, logs=None):
  self._reset_progbar()
  self._maybe_init_progbar()
  if self.verbose and self.epochs > 1:
    print('Epoch %d/%d' % (epoch + 1, self.epochs))

而且条件是需要verbose 0 0 0。怪不得我们把verbose置为 0 0 0就什么都没有了。

简单的猜测与简单的验证

既然官方这么设计能够输出,那么我们也就简单的提出一个想法:

我们首先需要定义一个类A,然后像这个ProgbarLogger类一样继承自Callback,然后再自定义on_epoch_begin或者on_epoch_end方法。这个方法需要具有三个参数:

  • 首先是on_epoch_begin方法作为A类一个成员的self
  • 其次是与ProgbarLoggeron_epoch_end方法一样传入epoch变量,从而获取到当前学习过程进行到哪一个epoch中了
  • 最后就是一个暂时是Nonelogs变量

那么,如何去验证呢?如果我们有一定的Java基础的话,那么我们其实大概可以猜出来,所有的东西都是有一个接口可以实现,或者一个抽象类可以继承。那么Tensorflow这种超大体量的框架也大概需要借鉴这种设计思想,否则很多东西都会乱糟糟的,没有一个统一的规范。所以,我们寻找一下有没有这类东西。

当然,我最终也是找到了:其实就是Callback类,他的注释是:

Abstract base class used to build new callbacks.

其中的on_epoch_beginon_epoch_endon_train_beginon_train_end等方法都是可以让子类实现的。

当然,在这里官方也很贴心的给了一个例子:

'''
Example:

  >>> training_finished = False
  >>> class MyCallback(tf.keras.callbacks.Callback):
  ...   def on_train_end(self, logs=None):
  ...     global training_finished
  ...     training_finished = True
  >>> model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
  >>> model.compile(loss='mean_squared_error')
  >>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
  ...           callbacks=[MyCallback()])
  >>> assert training_finished == True
'''

看来我们的猜想是正确的。

这么一想我这找了这么久都是没用的吗

拼图凑齐了!

找了这么久,我们所需要了解的一切就都明白了。

那就自定义一个进度条:

import tensorflow as tf
class TensorflowProgressBar(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs = None):
    print('\r', f'Now Processing: {epoch}, Progress: {round(epoch / EPOCHS * 100, 2)}%',
                end = '', flush = True)

这样的话,每当一个epoch结束的时候,就会显示当前是第几个epoch,并计算出当前进度的百分比。

这样的话就能在等结果的过程中做点别的事情,顺便时不时抬头看一眼进度。当老板问到的时候,随便看一下百分比就能报告,非常方便。

最后,在fit方法里调用一下:

record = model.fit(X_train, y_train,
                   batch_size = BATCH_SIZE, epochs = EPOCHS,
                   callbacks=[TensorflowProgressBar()], verbose = 0)

其中,verbose置为 1 1 1的话,Tensorflow还会继续输出大量的进度,这是我们并不想看到的。所以为了让他只输出我们想要看到的进度,就必须将verbose置为 0 0 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/705164.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C# Excel表列名称

168 Excel表列名称 给你一个整数 columnNumber ,返回它在 Excel 表中相对应的列名称。 例如: A -> 1 B -> 2 C -> 3 … Z -> 26 AA -> 27 AB -> 28 … 示例 1: 输入:columnNumber 1 输出:“A”…

Unity与Android交互(4)——接入SDK

【前言】 unity接入Android SDK有两种方式,一种是把Unity的工程导出google project的形式进行接入,另一种是通过把Android的工程做成Plugins的形式进行接入。我们接入SDK基本都是将SDK作为插件的形式接入的。 对我们接入SDK的人来说,SDK也是…

一文了解PoseiSwap的质押系统

PoseiSwap 正在向订单簿 DEX 领域深度的布局,并有望成为订单簿 DEX 领域的早期开创者。

jmeter发送请求的request body乱码问题解决

JMeter的Put请求,响应结果中文出现乱码的解决方法 原文地址: http://www.taodudu.cc/news/show-808374.html?actiononClick

【云原生丶Kubernetes】从应用部署的发展看Kubernetes的前世今生

在了解Kubernetes之前,我们十分有必要先了解一下应用程序部署的发展历程,下面让我们一起来看看! 应用部署的发展历程 我们先来看看应用程序部署的3个阶段:从物理机部署到虚拟机部署,再到容器化部署,他们之…

Jenkins服务器连接JMeter分布式中的test-master

Jenkins想要连接test-master就要通过代理 将下载好的agent.jar传输到test-master机器上的/usr/local(实际上任何目录都可以)下 然后我们在/usr/local目录下输入: (这个是在Jenkins页面自己生成的命令) java -jar ag…

SQL频率低但笔试会遇到: 触发器、索引、外键约束

一. 前言 在SQL面笔试中,对于表的连接方式,过滤条件,窗口函数等肯定是考察的重中之重,但是有一些偶尔会出现,频率比较低但是至少几乎会遇见一两次的题目,就比如触发器,索引和外键约束&#xff0…

C++ 教程

C 教程 C 是一种高级语言,它是由 Bjarne Stroustrup 于 1979 年在贝尔实验室开始设计开发的。C 进一步扩充和完善了 C 语言,是一种面向对象的程序设计语言。C 可运行于多种平台上,如 Windows、MAC 操作系统以及 UNIX 的各种版本。 本教程通过…

Stanford点云公开数据集:S3DIS

S3DIS (Stanford Large-Scale 3D Indoor Spaces Dataset) 是斯坦福大学提供的大场景室内3D点云数据集,包含6个教学和办公Area,总共有695,878,620个带有色彩信息以及语义标签的3D点。 该数据集目前已经被包含在一个更大的Full 2D-3D-S Dataset当中&#x…

深入探讨Seata RPC模块的设计与实现

在Seata中,TM,RM与TC都需要进行跨进程的网络调用,通常来说就会需要RPC来支持远程调用,而Seata内部就有自身基于Netty的RPC实现,这里我们就来看下Seata是如何进行RPC设计与实现的 RPC整体设计 抽象基类AbstractNettyRemoting 该类是…

Dinky:问题总结

一、启动时指定flink版本,因为dinky本身也集成了部分flink ./auto.sh start 1.12 二、数据源管理新增mysql时的url jdbc:mysql://ip:3306/dinky?useUnicodetrue&characterEncodingutf8&useSSLfalse&autoReconnecttrue&failOverReadOnlyfalse 不要…

2、JAVA 分支结构 switch结构 for循环

1 分支结构 1.1 概述 顺序结构的程序虽然能解决计算、输出等问题 但不能做判断再选择。对于要先做判断再选择的问题就要使用分支结构 1.2 形式 1.3.1 练习:商品打折案例 创建包: cn.tedu.basic 创建类: TestDiscount.java 需求: 接收用户输入的原价。满1000打9折…

网工内推 | 运营商招网工,3年以上网安经验,CISP/CCIE认证优先

01 微算互联 招聘岗位:网络工程师 职责描述: 1、负责生产系统的网络管理、网络监控告警、网络设备数据资料备份/恢复容灾管理; 2、海外、tob 、私有云网络搭建及运维7 X 24小时值班; 3、负责业务系统工程网络层面规划、建设、升级…

Java之Javac、JIT、AOT之间的关系

Javac:javac 是java语言编程编译器。全称java compiler。但这时候还是不能直接执行的,因为机器只能读懂汇编,也就是二进制,因此还需要进一步把.class文件编译成二进制文件。 Java的执行过程 详细流程 结论:javac编译后…

使用Python调用aapt命令查看APK文件信息

以下演示内容在window的操作系统 1、下载 aapt 下载完成后注意配置环境变量!!! 地址:https://www.mediafire.com/file/e8ww8wbgcowbti4/aapt 2、代码实现 import os import re import subprocess#获取当前操作系统 current_os o…

S型平滑函数功能块(CODESYS ST完整源代码)

S型平滑函数在多段曲线控温上的应用。完整算法介绍请参看下面文章博客: 带平滑功能的斜坡函数(多段曲线控温纯S型曲线SCL源代码+完整算法分析)_RXXW_Dor的博客-CSDN博客PLC运动控制基础系列之梯形速度曲线,可以参看下面这篇博客:PLC运动控制基础系列之梯形速度曲线_RXXW_…

RTL8720CF烧录工具

一、环境准备 1、 解压烧录工具包 AmebaZII_PGTool_v1.2.39.zip 2、 日志串口接串口调试助手 3、 模组A0脚接高电平 二、烧录 1、模组重新上电,串口有Download Image over …… 输出,表示模组已进入烧录模式,如图: 2、打开上…

springboot+vue校园一卡通管理系统_q7e7o

近些年来,随着科技的飞速发展,互联网的普及逐渐延伸到各行各业中,给人们生活带来了十分的便利,校园一卡通利用计算机网络实现信息化管理,使整个校园一卡通管理的发展和服务水平有显著提升。 本文拟采用java技术和Sprin…

【算法题】动态规划中级阶段之买卖股票的最佳时机、三角形最小路径和

动态规划中级阶段 前言一、三角形最小路径和1.1、思路1.2、代码实现 二、买卖股票的最佳时机 II2.1、思路2.2、代码实现 总结 前言 动态规划(Dynamic Programming,简称 DP)是一种解决多阶段决策过程最优化问题的方法。它是一种将复杂问题分解…

计算机网络————运输层

文章目录 概述UDPTCP首部格式 连接管理连接建立连接释放 概述 从IP层看,通信双方是两个主机。 但真正进行通信的实体是在主机中的进程,是这个主机中的一个进程和另一个主机中的一个进程在交换数据。 所以严格的讲,两个主机进行通信就是两个…