政安晨:【Keras机器学习实践要点】(二)—— 给首次接触Keras 3 的朋友

news2024/9/8 23:42:23

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras实战演绎机器学习

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

介绍

Keras 3是一个深度学习框架,可以与TensorFlow、JAX和PyTorch互换使用。本文将介绍Keras 3的关键工作流程。

从一个常用实验开始

一个MNIST卷积神经网络

让我们再次从机器学习的Hello World开始:训练一个卷积神经网络来对MNIST数字进行分类。

如果您是第一次看到这篇文章,还没有准备好环境,参考我本专栏的目录文章中关于搭建环境的部分:

政安晨:【TensorFlow与Keras实战演绎机器学习】专栏 —— 目录icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136985399

小伙伴们准备好环境,就跟我一起探索机器学习领域的奇妙吧。

当然,有时我为了能够随时开展机器学习实验,也会使用Kaggle等在线平台进行训练,我的文章以实用为主,大家跟着我的实战过程,应该会有所受益。

导入数据

先执行代码,下面代码获取数据

import keras
import numpy as np

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

构建模型

下面构建我们需要的模型

Keras 提供的不同模型构建选项包括:
- 顺序 API(我们在下文中使用的 API)
- 功能 API(最典型)
- 通过子类化编写自己的模型(用于高x

# Model parameters
num_classes = 10
input_shape = (28, 28, 1)

model = keras.Sequential(
    [
        keras.layers.Input(shape=input_shape),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)

可以查看模型摘要:

model.summary()

执行如下:

编译与训练

我们使用编译()方法来指定优化器、损失函数和要监控的指标。请注意,对于 JAX 和 TensorFlow 后端,XLA 编译是默认开启的。

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)

让我们来训练和评估模型。在训练过程中,我们将留出 15% 的验证数据,以监控未见数据的泛化情况。

batch_size = 128
epochs = 20

callbacks = [
    keras.callbacks.ModelCheckpoint(filepath="model_at_epoch_{epoch}.keras"),
    keras.callbacks.EarlyStopping(monitor="val_loss", patience=2),
]

model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.15,
    callbacks=callbacks,
)
score = model.evaluate(x_test, y_test, verbose=0)

实战演绎如下

在训练过程中,我们会在每个周期结束时保存模型

您也可以像这样保存模型的最新状态:

model.save("final_model.keras")

然后像这样重新加载:

model = keras.saving.load_model("final_model.keras")

接下来,您可以使用 predict() 查询类概率的预测结果:

predictions = model.predict(x_test)

演绎如下:

对于首次接触的小伙伴们而言,这就是全部的基本知识啦。


大家只要能跑起来,我们以后循序渐进,包括智能硬件系统的训练、传感器的感知交互训练等等。

大厦真不是一天建成的,小伙伴们一起加油!


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1544904.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

牛客网BC-33 统计成绩(数组排序思想)

题目如下 --------------------------------------------------------------------------------------------------------------------------------- 思路:以数组形式输入,并将数组顺序(或者逆序)排序,最后输出最大值最…

Mysql数据库:日志管理、备份与恢复

目录 前言 一、MySQL日志管理 1、存放日志和数据文件的目录 2、日志的分类 2.1 错误日志 2.2 通用查询日志 2.3 二进制日志 2.4 慢查询日志 2.5 中继日志 3、日志综合配置 4、查询日志是否开启 二、数据备份概述 1、数据备份的重要性 2、备份类型 2.1 从物理与…

2.6 IDE(集成开发环境)是什么

IDE(集成开发环境)是什么 IDE 是 Integrated Development Environment 的缩写,中文称为集成开发环境,用来表示辅助程序员开发的应用软件,是它们的一个总称。 通过前面章节的学习我们知道,运行 C 语言&…

蓝桥杯单片机快速开发笔记——利用定时器计数器设置定时器

一、基本原理 参考本栏http://t.csdnimg.cn/iPHN0 二、具体步骤 三、主要事项 如果使用中断功能记得打开总中断EA 四、示例代码 void Timer0_Isr(void) interrupt 1 { }void Timer0_Init(void) //10毫秒12.000MHz {AUXR & 0x7F; //定时器时钟12T模式TMOD & 0xF0;…

微服务day06 -- Elasticsearch的数据搜索功能。分别使用DSL和RestClient实现搜索

1.DSL查询文档 elasticsearch的查询依然是基于JSON风格的DSL来实现的。 1.1.DSL查询分类 Elasticsearch提供了基于JSON的DSL(Domain Specific Language)来定义查询。常见的查询类型包括: 查询所有:查询出所有数据,一…

iMazing2024功能强大的iPhone和iPad管理工具

iMazing是一款功能强大的iPhone和iPad管理工具,确实可以作为iTunes的替代品进行数据备份。以下是一些关于iMazing的主要特点和功能: 设备备份:iMazing可以备份iOS设备上的所有数据,包括照片、视频、音乐、应用程序等。与iTunes相比…

docker部署gitlab 报错的问题!!!

1、什么是gitlab? Gitlab是一个用于仓库管理系统的开源项目,使用git作为代码管理工具,并在此基础上搭建起来的web服务。Gitlab有乌克兰程序员DmitriyZaporozhets和ValerySizov开发,它由Ruby写成。后来,一些部分用Go语…

leetcode 225.用队列实现栈 JAVA

题目 思路 1.一种是用双端队列(Deque),直接就可以调用很多现成的方法,非常方便。 2.另一种是用普通的队列(Queue),要实现栈的先入后出,可以将最后一个元素的前面所有元素出队,然后…

tcp和udp分别是什么?udp和tcp的区别

TCP和UDP是计算机网络中常见的两种传输层协议,它们在实际应用中具有不同的特点和用途。本文将对TCP和UDP进行介绍,并分析它们之间的区别。 TCP和UDP分别是什么? TCP(Transmission Control Protocol) TCP是一种面向连…

基于TensorFlow的花卉识别(算能杯)%%%

Anaconda Prompt 激活 TensorFlow CPU版本 conda activate tensorflow_cpu //配合PyCharm环境 直接使用TensorFlow1.数据分析 此次设计的主题为花卉识别,数据为TensorFlow的官方数据集flower_photos,包括5种花卉(雏菊、蒲公英、玫瑰、向日葵…

论文汇总:A Closer Look at Few-shot Classification Again

文章汇总 文章是在总体上再一次地观察如何小样本领域存在的问题,并且发现了较为有趣的规律 1.测试误差随训练类别的数量而下降,而不是随每个类别的训练样本数量而下降。 2.训练算法(me:预训练模型)和自适应算法(me:预训练之后的…

适用于 Mac 电脑的 10 个最佳数据恢复工具

对于任何依对于依赖计算机获取重要文件(无论是个人照片还是重要商业文档)的人来说,数据丢失可能是一场噩梦。值得庆幸的是,有多种专门为 Mac 用户提供的数据恢复工具,可以帮助检索丢失或意外删除的文件。在本文中&…

SpringBoot 文件上传(三)

之前讲解了如何接收文件以及如何保存到服务端的本地磁盘中: SpringBoot 文件上传(一)-CSDN博客 SpringBoot 文件上传(二)-CSDN博客 这节讲解如何利用阿里云提供的OSS(Object Storage Service)对象存储服务保存文件。…

SORT/deepSORT_3.8

目标: SORT算法的原理DeepSORT算法的原理 SORT(Simple Online and Realtime Tracking)和DeepSORT是两种广泛应用在计算机视觉领域的多目标跟踪算法。 SORT: SORT算法由Martin Danelljan等人于2016年提出,是一种基于关联的在线实…

Visual Studio项目编译和运行依赖第三方库的项目

1.创建项目,这里创建的项目是依赖于.sln的项目,非CMake项目 2.添加第三方库依赖的头文件和库文件路劲 3.添加第三方依赖库文件 4.项目配置有2个,一个是Debug,一个是Release,如果你只配置了Debug,编译和运行…

TCP重传机制详解——04FACK

文章目录 TCP重传机制详解——04FACK什么是FACKFACK的发展为什么要引入FACK实战抓包讲解开启FACK场景,且达到dup ACK门限值开启FACK场景,未达到dup ACK门限值 为什么要淘汰FACK总结REF TCP重传机制详解——04FACK 什么是FACK FACK的全称是forward ackn…

Spring Aop 源码解析(下)

ProxyFactory选择cglib或jdk动态代理原理 ProxyFactory在生成代理对象之前需要决定到底是使用JDK动态代理还是CGLIB技术: config就是ProxyFactory对象,把自己传进来了,因为ProxyFactory继承了很多类,其中一个父类就是ProxyConfig // config就是ProxyFactory对象// 是不是…

蓝桥杯刷题(十四)

1.小平方 代码 n int(input()) count0 def f(x)->bool: # 判断条件return True if x**2%n<n/2 else False for i in range(1,n): # 遍历[1,n-1]&#xff0c;符合题意计数加一if f(i):count1 print(count)2.3的倍数 代码 a int(input()) b int(input()) c int(input…

HashMap---数据结构

目录 一、基本数据结构 二、树化与退化 三、索引计算 四、put方法和扩容 五、并发问题 六、key的设计 一、基本数据结构 在jdk1.7版本的时候&#xff0c;hashmap结构主要是使用数组 链表的格式&#xff0c;而在jdk1.8版本中&#xff0c;hashmap的数据结构增加了一种“红黑…

小红书矩阵批量发布工具,一键发布笔记软件

昨日&#xff0c;我收到了一条充满渴望与期待的私信&#xff0c;来自一位小红书的矩阵账号博主。他手握多个账号&#xff0c;渴望寻找一款能够助力他批量发布笔记的神器&#xff0c;每日能够轻松达到百篇的发布量。这份迫切的需求&#xff0c;我深感体会&#xff0c;因为这正是…