【需求实现】输入多少就输出多少的拟合任务如何实现(二):进度条简化

news2024/11/18 18:20:08

文章目录

  • 导读
  • 普通的输出方式
  • 上下求索
    • TensorBoard是个不错的切入点
    • 与Callback参数对应的Callback方法
    • 官方的内置Callback
    • 官方进度条
    • 简单的猜测与简单的验证
    • 拼图凑齐了!

导读

在训练模型的过程中往往会有日志一堆一堆的困扰。我并不想知道,因为最后我会在变量里面查询,反正训练过程中也没心思看。于是,就想把进度条简化一下。下面给出解决方案。

普通的输出方式

对于一般的训练过程,我们可能在Tensorflow中的fit方法中,将verbose置为 1 1 1,或者不设置verbose而让Tensorflow默认verbose 1 1 1。这样的话就会有如下图一样长篇大论的输出。

在这里插入图片描述
虽然不至于很烦躁,但是实在不愿意去管这些事情。又不是做生物实验,完全不需要人在这边守着嘛。趁着这时间泡杯咖啡多好。

于是呢,就想着源码里面是如何将输出显示出来的。

上下求索

TensorBoard是个不错的切入点

Tensorflow海量的源码中寻找一个输出无疑是大海捞针,对于Windows用户来说找起来超级麻烦,除非Linux用户直接用grep命令作弊。

但是呢,突然就注意到,Tensorflow还有一个Tensorboard,是在fit方法里面的callback参数中出现。既然日志能够从callback参数中获得,那么这里是有什么玄机吗?

与Callback参数对应的Callback方法

于是找到了文件中tensorflow/tensorflow/python/keras/engine/training.py,也就是官方GitHub的这一页(点击直达callbacks注释那一行),他是这么说明的:

'''
callbacks: List of `keras.callbacks.Callback` instances.
    List of callbacks to apply during training.
    See `tf.keras.callbacks`. Note `tf.keras.callbacks.ProgbarLogger`
    and `tf.keras.callbacks.History` callbacks are created automatically
    and need not be passed into `model.fit`.
    `tf.keras.callbacks.ProgbarLogger` is created or not based on
    `verbose` argument to `model.fit`.
    Callbacks with batch-level calls are currently unsupported with
    `tf.distribute.experimental.ParameterServerStrategy`, and users are
    advised to implement epoch-level calls instead with an appropriate
    `steps_per_epoch` value.
'''

这也就是说,官方已经把进度条内置到ProgbarLogger这个类里面了,并通过callbacks调用。

其中,对于callbacks是这么调用的:

callbacks.on_train_begin()

而这个on_train_begin又是属于CallbackList类中:

# in CallbackList class
def on_train_begin(self, logs=None):
  """Calls the `on_train_begin` methods of its callbacks.

  Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
  """
  logs = self._process_logs(logs)
  for callback in self.callbacks:
    callback.on_train_begin(logs)

也就是说,是遍历callbacks中的所有callback然后一一执行。

执行过程的源码在Callback类中(类名跟上一个不一样哦):

# in Callback class
@doc_controls.for_subclass_implementers
def on_train_begin(self, logs=None):
  """Called at the beginning of training.

  Subclasses should override for any actions to run.

  Args:
      logs: Dict. Currently no data is passed to this argument for this method
        but that may change in the future.
  """

但很明显,这个源码就是想让我们自定义。

刚刚好发现进度条也在这里。那接下来的事情就更明确了,去找这个类就行了。一方面是查看进度条的原理,另一方面则是按照官方的进度条仿写一个简单的进度条。

官方的内置Callback

于是就找到了tensorflow/tensorflow/python/keras/callbacks.py文件中,也就是官方GitHub的这一页(点击直达ProgbarLogger类定义的那一行)。但比较可惜的是,这个类定义的时候注释太少了,并不能很确定每个类中的各个方法都在做什么。怎么办呢?

别忘了官方给的提示:【内置了ProgbarLogger类与History类】。我们既然需要了解如何在Callback里面调用,那么就需要了解这两个东西是如何插入进去的。但是这个类注释实在是太少了,很多东西看得不明不白的,该怎么办呢?

当然是【贪心搜索】了呀,按照【命名】与【个人经验】去判断哪一个最可能是我们想要的方法。虽然很不靠谱,但是万一运气好撞上了呢?于是呢,就找到了_add_default_callbacks方法,也就是默认插入的一些东西。他是这么写的:

  def _add_default_callbacks(self, add_history, add_progbar):
    """Adds `Callback`s that are always present."""
    self._progbar = None
    self._history = None

    for cb in self.callbacks:
      if isinstance(cb, ProgbarLogger):
        self._progbar = cb
      elif isinstance(cb, History):
        self._history = cb

    if self._progbar is None and add_progbar:
      self._progbar = ProgbarLogger(count_mode='steps')
      self.callbacks.insert(0, self._progbar)

    if self._history is None and add_history:
      self._history = History()
      self.callbacks.append(self._history)

总之就是一些判断,如果空就创建。

看来插入就是insert方法与append方法了。不难猜测,也不用猜测,callbacks将是一个数组。

官方进度条

既然知道了进度条是如何被调用的,那么接下来就是得了解官方的进度条是怎么添加的。

当然,还是在ProgbarLogger类中,为了方便点击这里就能传送。这里面明显的给出了输出Epoch,正好就是我们需要找到的输出。源码是这样的:

def on_epoch_begin(self, epoch, logs=None):
  self._reset_progbar()
  self._maybe_init_progbar()
  if self.verbose and self.epochs > 1:
    print('Epoch %d/%d' % (epoch + 1, self.epochs))

而且条件是需要verbose 0 0 0。怪不得我们把verbose置为 0 0 0就什么都没有了。

简单的猜测与简单的验证

既然官方这么设计能够输出,那么我们也就简单的提出一个想法:

我们首先需要定义一个类A,然后像这个ProgbarLogger类一样继承自Callback,然后再自定义on_epoch_begin或者on_epoch_end方法。这个方法需要具有三个参数:

  • 首先是on_epoch_begin方法作为A类一个成员的self
  • 其次是与ProgbarLoggeron_epoch_end方法一样传入epoch变量,从而获取到当前学习过程进行到哪一个epoch中了
  • 最后就是一个暂时是Nonelogs变量

那么,如何去验证呢?如果我们有一定的Java基础的话,那么我们其实大概可以猜出来,所有的东西都是有一个接口可以实现,或者一个抽象类可以继承。那么Tensorflow这种超大体量的框架也大概需要借鉴这种设计思想,否则很多东西都会乱糟糟的,没有一个统一的规范。所以,我们寻找一下有没有这类东西。

当然,我最终也是找到了:其实就是Callback类,他的注释是:

Abstract base class used to build new callbacks.

其中的on_epoch_beginon_epoch_endon_train_beginon_train_end等方法都是可以让子类实现的。

当然,在这里官方也很贴心的给了一个例子:

'''
Example:

  >>> training_finished = False
  >>> class MyCallback(tf.keras.callbacks.Callback):
  ...   def on_train_end(self, logs=None):
  ...     global training_finished
  ...     training_finished = True
  >>> model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
  >>> model.compile(loss='mean_squared_error')
  >>> model.fit(tf.constant([[1.0]]), tf.constant([[1.0]]),
  ...           callbacks=[MyCallback()])
  >>> assert training_finished == True
'''

看来我们的猜想是正确的。

这么一想我这找了这么久都是没用的吗

拼图凑齐了!

找了这么久,我们所需要了解的一切就都明白了。

那就自定义一个进度条:

import tensorflow as tf
class TensorflowProgressBar(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs = None):
    print('\r', f'Now Processing: {epoch}, Progress: {round(epoch / EPOCHS * 100, 2)}%',
                end = '', flush = True)

这样的话,每当一个epoch结束的时候,就会显示当前是第几个epoch,并计算出当前进度的百分比。

这样的话就能在等结果的过程中做点别的事情,顺便时不时抬头看一眼进度。当老板问到的时候,随便看一下百分比就能报告,非常方便。

最后,在fit方法里调用一下:

record = model.fit(X_train, y_train,
                   batch_size = BATCH_SIZE, epochs = EPOCHS,
                   callbacks=[TensorflowProgressBar()], verbose = 0)

其中,verbose置为 1 1 1的话,Tensorflow还会继续输出大量的进度,这是我们并不想看到的。所以为了让他只输出我们想要看到的进度,就必须将verbose置为 0 0 0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/699934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

今天给大家分享几款好用的卸载神器

在日常使用电脑的过程中,我们经常需要安装和卸载各种软件。然而,有时候使用操作系统自带的卸载程序可能无法完全清除程序及其相关文件和注册表项,导致系统出现垃圾文件和残留问题。为了解决这个困扰,今天我将向大家分享几款好用的…

【Spring】基于注解方式存取JavaBean:Spring有几种注入方式?有什么区别?

前言 Hello,我是小黄。众所周知,Spring是一个开源的Java应用程序框架,其中包括许多通过注解实现依赖注入的功能。Spring提供了多种注入方式,可以满足不同的需求和场景。常见的注入方式包括构造函数注入、Setter方法注入和属性注入…

【Redis】Redis的高可用与持久化

文章目录 一、Redis 高可用1. 概念2. 高可用技术以及作用2.1 持久化2.2 主从复制2.3 哨兵2.4 集群 二、Redis 持久化1. 持久化的功能2. Redis 持久化方式 三、RDB 持久化1. 概述2. 触发条件2.1 手动触发2.2 自动触发2.3 其他自动发机制 3. 执行流程4. 启动时加载 四、AOF 持久化…

Modin 入门学习

Modin 是一个 Python 第三方库,用于加速 Pandas 的 API 执行速度。原始的 Pandas 是单线程执行的,而 Modin 则重新打包了 Pandas 里面的 API,使其同时在多个内核中运行,提高硬件性能的利用率。 使用方法很简单,安装 M…

2.9C++多态

C 继承扩展 C继承在实际开发中它可以帮助我们实现代码重用,减少代码冗余,提高代码的可维护性和可扩展性。 通过继承,我们可以从已有的类中派生出新的类,新的类可以继承父类的属性和方法,并且可以添加自己的属性和方法…

selenium元素定位---ElementClickInterceptedException(元素点击交互异常)解决方法

目录 前言: 1、异常原因 2、解决方法: 前言: 当使用Selenium进行元素定位和交互时,可能会遇到ElementClickInterceptedException(元素点击交互异常)的异常。这通常是由于页面上存在其他元素或弹出窗口遮…

ROS学习之基础包创建的详细流程:包括rosnode, rostopic, rosrun,roslaunch等使用

0 引言 本文旨在学习ROS基础包的从零开始创建,包括如何创建一个发布消息节点,一个接收消息节点,还有如何使用roslaunch同时启动多个节点,如何编译ROS工程包等操作。 默认已在Ubuntu系统中安装ROS机器人系统,比如Ubun…

AOP--拦截器

AOP应用--拦截器Spring拦截器拦截器执行流程前缀的添加统一异常处理统一数据返回格式返回String类型 AOP应用–拦截器 AOP的作用:统一功能处理;我们将以三个内容作为学习的掌握点;而这三点也是我们非常迫切需要的 1:用户登录权限…

Windows系统分区大小

Microsoft Reserved(MSR)——保留分区——16MB左右 EFI System Partition(ESP)——系统分区——100MB左右 Recovery Partition(自起名字REP)——恢复分区——450MB左右 其他分区——剩余

对rabbitmq进行压测

添加rabbitmq依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-amqp</artifactId></dependency><dependency><groupId>org.springframework.boot</groupId><artifactI…

C# PaddleOCR标注工具

基于以下开源项目改造的 https://gitee.com/BaoJianQiang/FastOCRLabel 效果 Demo下载

CMake使用gRPC(Protobuf) 的c++ demo

gRPC的命令参数里&#xff0c; 1. 如果要用pacakge&#xff0c;需要--proto_path的参数&#xff0c; 例如helloworld.proto的绝对路径是 /home/user/grpc_demo_ws/grpc_demo/hello_world/proto/helloworld.proto 在helloworld.proto里面的pacakge是 package grpc_demo.hello_w…

同步(通信原理)

同步原理&#xff1a; 在通信系统中&#xff0c;同步是指发送端和接收端在时间上保持一致&#xff0c;使得接收端能够正确地解析和还原发送端发送的信号。同步的原理可以根据具体的通信系统和协议来区分&#xff0c;下面是几种常见的同步原理&#xff1a; 1. 时钟同步&#x…

uniapp 配置chooseLocation微信小程序腾讯地图选点

uniapp 配置chooseLocation微信小程序腾讯地图选点 场景 在uniapp中使用地图选点 搜索功能&#xff0c;回显功能&#xff0c;移动选点功能 使用到的API是uni.chooseLocation 详细看一下都有哪些属性 latitude &#xff1a;目标地纬度 Number longitude &#xff1a;目标地经度…

论文阅读: (CVPR2023 SDT )基于书写者风格和字符风格解耦的手写文字生成及源码对应

目录 引言SDT整体结构介绍代码与论文对应搭建模型部分数据集部分 总结 引言 许久不认真看论文了&#xff0c;这不赶紧捡起来。这也是自己看的第一篇用到Transformer结构的CV论文。之所以选择这篇文章来看&#xff0c;是考虑到之前做过手写字体生成的项目。这个工作可以用来合成…

浅析基于物联网技术的校园能耗智慧监控平台的设计及应用

摘 要&#xff1a;为打造低碳绿色校园&#xff0c;营造良好的学习环境&#xff0c;针对目前校园建筑能耗大&#xff0c;特别是空调节能困难等问题&#xff0c;特采用物联网技术构建校园建筑能耗智慧监控平台。通过设计空调监控子系统&#xff0c;搭建空调监控模型实现了空调等智…

在 Jetpack Compose 中使用 Snackbar

Jetpack Compose 是 Android 的现代 UI 工具库&#xff0c;提供了丰富的组件和功能来构建漂亮、交互丰富的用户界面。在本文中&#xff0c;我们将学习如何在 Jetpack Compose 中使用 Snackbar 组件来显示临时消息或操作反馈。 什么是 Snackbar&#xff1f; Snackbar 是一种用于…

基于Layui实现管理页面

基于Layui实现的后台管理页面&#xff08;仅前端&#xff09; 注&#xff1a;这是博主在帮朋友实现的一个简单的系统前端框架&#xff08;无后端&#xff09;&#xff0c;跟大家分享出来&#xff0c;可以直接将对应菜单跟html文件链接起来&#xff0c;页面使用标签页方式存在&…

面试了一个前阿里P7,Java八股文与架构核心知识简直背得炉火纯青

前几天&#xff0c;跟个老朋友吃饭&#xff0c;他最近想跳槽去大厂&#xff0c;觉得压力很大&#xff0c;问我能不能分享些所谓的经验套路。 每次有这类请求&#xff0c;都觉得有些有趣&#xff0c;不知道你发现没有大家身边真的有很多人不知道怎么面试&#xff0c;也不知道怎…

赛效:如何将PDF文件免费转换成Word文档

1&#xff1a;在网页上打开wdashi&#xff0c;默认进入PDF转Word页面&#xff0c;点击中间的上传文件图标。 2&#xff1a;将PDF文件添加上去之后&#xff0c;点击右下角的“开始转换”。 3&#xff1a;稍等片刻转换成功后&#xff0c;点击绿色的“立即下载”按钮&#xff0c;将…