keras学习之回调函数的使用

news2025/1/16 5:56:31

回调函数

  • 回调函数是一个对象(实现了特定方法的类实例),它在调用fit()时被传入模型,并在训练过程中的不同时间点被模型调用
  • 可以访问关于模型状态与模型性能的所有可用数据
  • 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前状态。
  • 提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中的最佳模型)。
  • 在训练过程中动态调节某些参数值:比如调节优化器的学习率。
  • 在训练过程中记录训练指标和验证指标,或者将模型学到的表示可视化(这些表示在不断更新):fit()进度条实际上就是一个回调函数。

fit()方法中使用callbacks参数


# 这里有两个callback函数:早停和模型检查点
callbacks_list=[
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",#监控指标
        patience=2 #两轮内不再改善中断训练
    ),
    keras.callbacks.ModelCheckpoint(
        filepath="checkpoint_path",
        monitor="val_loss",
        save_best_only=True
    )
]
#模型获取
model=get_minist_model()
model.compile(optimizer="rmsprop",
             loss="sparse_categorical_crossentropy",
             metrics=["accuracy"])

model.fit(train_images,train_labels,
         epochs=10,callbacks=callbacks_list, #该参数使用回调函数
         validation_data=(val_images,val_labels))

test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标
predictions=model.predict(test_images)#计算模型在新数据上的分类概率

训练结果

模型的保存和加载

#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。
#重新加载模型
model_new=keras.models.load_model("checkpoint_path.keras")

通过对Callback类子类化来创建自定义回调函数

on_epoch_begin(epoch, logs) ←----在每轮开始时被调用
on_epoch_end(epoch, logs) ←----在每轮结束时被调用
on_batch_begin(batch, logs) ←----在处理每个批量之前被调用
on_batch_end(batch, logs) ←----在处理每个批量之后被调用
on_train_begin(logs) ←----在训练开始时被调用
on_train_end(logs ←----在训练结束时被调用

from matplotlib import pyplot as plt
# 实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图
class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs):
        self.per_batch_losses = []

    def on_batch_end(self, batch, logs):
        self.per_batch_losses.append(logs.get("loss"))

    def on_epoch_end(self, epoch, logs):
        plt.clf()
        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,
                 label="Training loss for each batch")
        plt.xlabel(f"Batch (epoch {epoch})")
        plt.ylabel("Loss")
        plt.legend()
        plt.savefig(f"plot_at_epoch_{epoch}")
        self.per_batch_losses = [] #清空,方便下一轮的技术
model = get_mnist_model()
model.compile(optimizer="rmsprop",
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
model.fit(train_images, train_labels,
          epochs=10,
          callbacks=[LossHistory()],
          validation_data=(val_images, val_labels))

在这里插入图片描述

【其他】模型的定义 和 数据加载

def get_minist_model():
    inputs=keras.Input(shape=(28*28,))
    features=layers.Dense(512,activation="relu")(inputs)
    features=layers.Dropout(0.5)(features)
    outputs=layers.Dense(10,activation="softmax")(features)
    model=keras.Model(inputs,outputs)
    return model
    
#datset
from tensorflow.keras.datasets import mnist
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
train_images=train_images.reshape((60000,28*28)).astype("float32")/255
test_images=test_images.reshape((10000,28*28)).astype("float32")/255
train_images,val_images=train_images[10000:],train_images[:10000]
train_labels,val_labels=train_labels[10000:],train_labels[:10000]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/402258.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SAP PO】X-DOC:SAP PO 接口配置 REST 服务对接填坑记

X-DOC:SAP PO 接口配置 REST 服务对接填坑记1、背景2、PO SLD配置3、PO https证书导入1、背景 (1)需求背景: SAP中BOM频繁变更,技术人员在对BOM进行变更后,希望及时通知到相关使用人员 (2&…

配天智造自主原创数字工厂:百余名员工人均创收122万

配天智造(832223)2022年度报告显示,报告期内公司实现营业收入1.3亿元,同比增长52%,归属于挂牌公司股东的净利润3867万元,同比增长28.11%。而这家公司全部在职员工仅有107人,人均创收约为122万。…

计算机科学导论笔记(七)

目录 九、程序设计语言 9.1 演化 9.1.1 机器语言 9.1.2 汇编语言 9.1.3 高级语言 9.2 翻译 9.2.1 编译 9.2.2 解释 9.2.3 翻译过程 9.3 编程模式 9.3.1 面向过程模式 9.3.2 面向对象模式 9.3.3 函数式模式 9.3.4 声明式模式 9.4 共同概念 九、程序设计语言 9.1 …

Spring Cloud Alibaba全家桶(六)——微服务组件Sentinel介绍与使用

前言 本文小新为大家带来 微服务组件Sentinel介绍与使用 相关知识,具体内容包括分布式系统存在的问题,分布式系统问题的解决方案,Sentinel介绍,Sentinel快速开始(包括:API实现Sentinel资源保护,…

ABAQUS免费培训 Abaqus成型 焊接 疲劳多工况课程

一、详解Abaqus多工况分析在工程中,多工况的情况是普遍存在的情况,而单工况孤立存在是十分理想状态下的假设。例如我们在进行强度分析时,都是假设其本身是不存在应力的,然后基于这种无初始应力下的计算,使得我们不得不…

aop实现接口访问频率限制

引言 项目开发中我们有时会用到一些第三方付费的接口,这些接口的每次调用都会产生一些费用,有时会有别有用心之人恶意调用我们的接口,造成经济损失;或者有时需要对一些执行时间比较长的的接口进行频率限制,这里我就简…

OpenGL超级宝典学习笔记:纹理

前言 本篇在讲什么 本篇章记录对OpenGL中纹理使用的学习 本篇适合什么 适合初学OpenGL的小白 本篇需要什么 对C语法有简单认知 对OpenGL有简单认知 最好是有OpenGL超级宝典蓝宝书 依赖Visual Studio编辑器 本篇的特色 具有全流程的图文教学 重实践,轻理…

MP4文件播放不了是什么原因?原因及解决办法分享!

为什么mp4文件播放不了?常见的有三种原因,可能是由于视频流或音频流不兼容导致,可能是由于视频文件损坏,也可能是因为电脑上缺乏编解码器。下面小编根据mp4文件无法播放的三种可能进行针对性解答。 原因一:视频流或音频…

基于SSM的学生竞赛模拟系统

基于SSM的学生竞赛模拟系统 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取项目下载方式🍅 一、项目背景介绍&#x…

DPU54国产全速USB1.1HUB控制器芯片替代AU9254

目录DPU54简介结构框图DPU54主要特性性能特点典型应用领域DPU54简介 DPU54是高性能、低功耗4口全速 USB1.1 HUB 控制器芯片,上行端口兼容全速 12MHz 模式,4 个下行端口兼容全速 12MHz、低速 1.5MHz 两种模式。 DPU54采用状态机单事务处理架构&#xff0…

windows 11系统,通过ip地址远程连接连接ubuntu 22.04系统(共同局域网下,另一台主机不需要联网)

windows 11系统,通过ip地址远程连接连接ubuntu 22.04系统(不需要联网)问题来源问题分析解决方案问题来源 自己搭建了一台ubuntu系统作为深度学习的机器,但是学校的网络问题,一个账号只能同时登录3台设备。通过远程连接…

C#完全掌握控件之-combbox

无论是QT还是VC,这些可视化编程的工具,掌握好控件的用法是第一步,C#的控件也不例外,尤其这些常用的控件。常见控件中较难的往往是这些与数据源打交道的,比如CombBox、ListBox、ListView、TreeView、DataGridView. 文章…

JUC并发编程之HashMap(jdk1.7版本)-底层源码探究

目录 JUC并发编程之HashMap(jdk1.7版本)-底层源码探究 HashMap底层源码 - jdk1.7 基本概念 -采取层层递进,问答式 存储Key-Value的结构 常量和成员变量 构造方法 put方法 inflateTable方法 hash方法 indexFor方法 addEntry方法 resize方法 createEntry…

JVM 运行时数据区(数据区组成表述,程序计数器,java虚拟机栈,本地方法栈)

JVM 运行时数据区JVM 运行时数据区3.1运行时的数据区组成概述3.1.1程度计数器3.1.2java虚拟机栈3.1.3本地方法栈3.1.4java堆3.1.5方法区3.2程序计数器3.3java虚拟机栈3.4本地方法栈JVM 运行时数据区 堆,方法区(元空间) 主要用来存放数据 是线程共享的. 程序计数器,本地方法栈…

Leetcode.1590 使数组和能被 P 整除

题目链接 Leetcode.1590 使数组和能被 P 整除 Rating : 2039 题目描述 给你一个正整数数组 nums,请你移除 最短 子数组(可以为 空),使得剩余元素的 和 能被 p整除。 不允许 将整个数组都移除。 请你返回你需要移除的…

Java中IO流中字节流(FileInputStream(read、close)、FileOutputStream(write、close、换行写、续写))

IO流:存储和读取数据的解决方案 纯文本文件:Windows自带的记事本打开能读懂 IO流体系: FileInputStream:操作本地文件的字节输入流,可以把本地文件中的数据读取到程序中来 书写步骤:①创建字节输入流对象 …

cento7安装docker

1.环境说明 root用户,centos7内核版本:3.10.0-1160.88.1.el7.x86_64 可通过一下命令查看当前内核版本 [rootlocalhost ~]# uname -r 3.10.0-1160.88.1.el7.x86_64 这里内核版本为3.10,Linux版本为centos7。 2.使用root命令更新yum包 注意​ …

Redis高频面试题汇总(中)

目录 1.什么是redis事务? 2.如何使用 Redis 事务? 3.Redis 事务为什么不支持原子性 4.Redis 事务支持持久性吗 5.Redis事务基于lua脚本的实现 6.Redis集群的主从复制模型是怎样的? 7.Redis集群中,主从复制的数据同步的步骤 …

有没有好用的设备管理系统推荐?不妨看看这6款

有没有好用的设备管理系统推荐?不妨看看这6款! 在现代社会中,软件已经成为了企业信息化、设备管理等方面必不可少的工具。而设备管理系统是将信息化了设备技术信息与现代化管理相结合,是实现研究级管理信息化的先导。 对于设备管…

p79 Python 开发-sqlmapapiTamperPocsuite

数据来源​​​​​​本文仅用于信息安全学习,请遵守相关法律法规,严禁用于非法途径。若观众因此作出任何危害网络安全的行为,后果自负,与本人无关。 # 知识点: Request 爬虫技术,Sqlmap 深入分析&#x…