政安晨:【Keras机器学习实践要点】(三)—— 编写组件与训练数据

news2025/1/9 10:19:59

政安晨的个人主页政安晨

欢迎 👍点赞✍评论⭐收藏

收录专栏: TensorFlow与Keras实战演绎机器学习

希望政安晨的博客能够对您有所裨益,如有不足之处,欢迎在评论区提出指正!

介绍

通过 Keras,您可以编写自定义层、模型、度量指标、损失和优化器,并在同一代码库中跨 TensorFlow、JAX 和 PyTorch 运行

老规矩,咱们还是先准备环境(参考我本专栏目录中的文章,其中有搭建环境的部分):

政安晨:【TensorFlow与Keras实战演绎机器学习】专栏 —— 目录icon-default.png?t=N7T8https://blog.csdn.net/snowdenkeke/article/details/136985399

准备好环境后,咱们开始。

编写组件

让我们先来看看自定义层

{keras.ops 命名空间包含}
1. NumPy API 的实现,例如 keras.ops.stack 或 keras.ops.matmul
2. 一组 NumPy 中没有的神经网络特定操作,如 keras.ops.conv 或 keras.ops.binary_crossentropy

让我们创建一个可与所有后端配合使用的自定义密集层

class MyDense(keras.layers.Layer):
    def __init__(self, units, activation=None, name=None):
        super().__init__(name=name)
        self.units = units
        self.activation = keras.activations.get(activation)

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(
            shape=(input_dim, self.units),
            initializer=keras.initializers.GlorotNormal(),
            name="kernel",
            trainable=True,
        )

        self.b = self.add_weight(
            shape=(self.units,),
            initializer=keras.initializers.Zeros(),
            name="bias",
            trainable=True,
        )

    def call(self, inputs):
        # Use Keras ops to create backend-agnostic layers/metrics/etc.
        x = keras.ops.matmul(inputs, self.w) + self.b
        return self.activation(x)

接下来,让我们制作一个依赖于keras.random命名空间的自定义Dropout层

class MyDropout(keras.layers.Layer):
    def __init__(self, rate, name=None):
        super().__init__(name=name)
        self.rate = rate
        # Use seed_generator for managing RNG state.
        # It is a state element and its seed variable is
        # tracked as part of `layer.variables`.
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        # Use `keras.random` for random ops.
        return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)

接下来,让我们编写一个自定义子类模型,使用我们的两个自定义层:

class MyModel(keras.Model):
    def __init__(self, num_classes):
        super().__init__()
        self.conv_base = keras.Sequential(
            [
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
                keras.layers.MaxPooling2D(pool_size=(2, 2)),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
                keras.layers.GlobalAveragePooling2D(),
            ]
        )
        self.dp = MyDropout(0.5)
        self.dense = MyDense(num_classes, activation="softmax")

    def call(self, x):
        x = self.conv_base(x)
        x = self.dp(x)
        return self.dense(x)

让我们编译并适配它:

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)

model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=1,  # For speed
    validation_split=0.15,
)

现在咱们演绎如下

在本地的TensorFlow虚拟环境中,首先导入keras:

from tensorflow import keras

(可以在Jupyter Notebook中运行)

如果在演绎执行中出错,可能是Keras版本问题,使用如下命令升级keras

sudo pip install --upgrade keras

执行结果:

训练模型

在任意数据源上训练模型

所有的Keras模型都可以在各种数据来源上进行训练和评估,与您使用的后端无关。这包括:

NumPy数组 Pandas数据框 TensorFlow tf.data.Dataset对象 PyTorch DataLoader对象 Keras PyDataset对象 无论您使用TensorFlow、JAX还是PyTorch作为Keras后端,它们都可以工作。

让我们尝试使用PyTorch DataLoader:

import torch

# Create a TensorDataset
train_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_torch_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_test), torch.from_numpy(y_test)
)

# Create a DataLoader
train_dataloader = torch.utils.data.DataLoader(
    train_torch_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_torch_dataset, batch_size=batch_size, shuffle=False
)

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
model.fit(train_dataloader, epochs=1, validation_data=val_dataloader)

现在让我们尝试使用tf.data来完成这个任务

import tensorflow as tf

train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

model = MyModel(num_classes=10)
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.SparseCategoricalAccuracy(name="acc"),
    ],
)
model.fit(train_dataset, epochs=1, validation_data=test_dataset)


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1544858.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构】顺序表的定义

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:数据结构 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进…

【牛客】SQL142 对试卷得分做min-max归一化

描述 现有试卷信息表examination_info(exam_id试卷ID, tag试卷类别, difficulty试卷难度, duration考试时长, release_time发布时间): idexam_idtagdifficultydurationrelease_time19001SQLhard602020-01-01 10:00:0029002Chard802020-01-0…

SQLite使用的临时文件(二)

返回:SQLite—系列文章目录 上一篇:SQLite数据库文件损坏的可能几种情况 下一篇:SQLite数据库成为内存中数据库(三) ​ 1. 引言 SQLite的显着特点之一它是一个数据库由一个磁盘文件组成。 这简化了 SQLite 的使用…

【动态规划】Leetcode 62. 不同路径

【动态规划】Leetcode 62. 不同路径 解法 ---------------🎈🎈62. 不同路径 题目链接🎈🎈------------------- 解法 😒: 我的代码实现> 动规五部曲 ✒️确定dp数组以及下标的含义 dp[i][j] 走到i, j这个格子的…

Open WebUI大模型对话平台-适配Ollama

什么是Open WebUI Open WebUI是一种可扩展、功能丰富、用户友好的大模型对话平台,旨在完全离线运行。它支持各种LLM运行程序,包括与Ollama和Openai兼容的API。 功能 直观的界面:我们的聊天界面灵感来自ChatGPT,确保了用户友好的体验。响应…

(四)图像的%2线性拉伸

环境:Windows10专业版 IDEA2021.2.3 jdk11.0.1 OpenCV-460.jar 系列文章: (一)PythonGDAL实现BSQ,BIP,BIL格式的相互转换 (二)BSQ,BIL,BIP存储格式的相互转换算法 (三…

Netty剖析 - 掌握Netty 整体架构脉络

文章目录 Netty 整体结构Core 核心层Protocol Support 协议支持层Transport Service 传输服务层 Netty 逻辑架构网络通信层事件调度层服务编排层组件关系梳理 Netty 源码结构Core 核心层模块Protocol Support 协议支持层模块Transport Service 传输服务层模块 思维导图 Netty 整…

机器学习OpenNLP

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl OpenNLP概述 OpenNLP是一个基于机器学习的自然语言处理开发工具包,它是Apache软件基金会的一个开源项目。OpenNLP支持多种自然语言处理任务,如分词、…

云数据库认识

云数据库概述 说明云数据库厂商概述Amazon 云数据库产品Google 的云数据库产品Microsoft 的云数据库产品 云数据库系统架构UMP 系统概述UMP 系统架构MnesiaRabbitMQZooKeeperLVSController 服务器Proxy 服务器Agent 服务器日志分析服务器 UMP 系统功能容灾 读写分离分库分表资源…

如何利用python 把一个表格某列数据和另外一个表格某列匹配 类似Excel VLOOKUP功能

环境: python3.8.10 Excel2016 Win10专业版 问题描述: 如何利用python 把一个表格某列数据和另外一个表格某列匹配 类似Excel VLOOKUP功能 先排除两表A列空白单元格,然后匹配x1表格和x2表格他们的A列,把x1表格中A列A1-A810范围对应的B列B1-B810数据,匹配填充到x2范围…

android studio忽略文件

右键文件,然后忽略,就不会出现在commit里面了 然后提交忽略文件即可

纯前端调用本机原生Office实现Web在线编辑Word/Excel/PPT,支持私有化部署

在日常协同办公过程中,一份文件可能需要多次重复修改才能确定,如果你发送给多个人修改后再汇总,这样既效率低又容易出错,这就用到网页版协同办公软件了,不仅方便文件流转还保证不会出错。 但是目前一些在线协同Office…

6.1 LIBBPF简介

写在前面 eBPF 是一项革命性的技术! 内核对于很多开发者来说,就是一个像黑洞一样的存在。它是操作系统最核心的存在,管理者我们的整个计算机和外设。基于稳定性,性能和安全性,我们对内核的任何修改往往是慎之又慎。 但是。 eBPF的出现了,它几乎无需通过更改任何内核源…

【数据存储】TIDB和MySQL的区别

1.TIDB和MySQL对比 对比内容MySQLTiDB架构设计一个传统的单机数据库系统,采用主从复制和分区表等方式来实现水平扩展一个分布式的 NewSQL 数据库,采用分布式存储和分布式事务等技术,支持水平扩展和高可用性事务支持 InnoDB 存储引擎来支持事…

开发者的瑞士军刀:DevToys

DevToys: 一站式开发者工具箱,打造高效创意编程体验,让代码生活更加得心应手!—— 精选真开源,释放新价值。 概览 不知道大家是否在windows系统中使用过PowerToys?这是微软研发的一项免费实用的系统工具套…

Selenium 自动化 —— 定位页面元素

更多内容请关注我的 Selenium 自动化 专栏: 入门和 Hello World 实例使用WebDriverManager自动下载驱动Selenium IDE录制、回放、导出Java源码浏览器窗口操作切换浏览器窗口 使用 Selenium 做自动化,我们不仅仅是打开一个网页,这只是万里长…

高防服务器、高防IP、高防CDN的工作原理是什么

高防IP高防CDN我们先科普一下是什么是高防。“高防”,顾名思义,就犹如网络上加了类似像盾牌一样很高的防御,主要是指IDC领域的IDC机房或者线路有防御DDOS能力。 高防服务器主要是比普通服务器多了防御服务,一般都是在机房出口架设…

Oracle 控制文件详解

1、控制文件存储的数据信息 1)数据库名称和数据库唯一标识符(DBID) 2)创建数据库的时间戳 3)有关数据文件、联机重做日志文件、归档重做日志文件的信息 4)表空间信息 5)检查点信息 6)日志序列号…

SAP BTP云上一个JVM与DB Connection纠缠的案例

前言 最近在CF (Cloud Foundry) 云平台上遇到一个比较经典的案例。因为牵扯到JVM (app进程)与数据库连接两大块,稍有不慎,很容易引起不快。 在云环境下,有时候相互扯皮的事蛮多。如果是DB的问题,就会找DB…

Reactor设计模式和Reactor模型

Reactor设计模式 翻译过来就是反应堆,所以Reactor设计模式本质是基于事件驱动。 角色 Handle(事件)EventHandler(事件处理器)ConcreteEventHandler(具体事件处理器)Synchronous Event Demult…