Keras 3.0发布:全面拥抱 PyTorch!

news2025/1/15 6:27:59

 

 

Keras 3.0 介绍

https://keras.io/keras_3/

Keras 3.0 升级是对 Keras 的全面重写,引入了一系列令人振奋的新特性,为深度学习领域带来了全新的可能性。

多框架支持

Keras 3.0 的最大亮点之一是支持多框架。Keras 3 实现了完整的 Keras API,并使其可用于 TensorFlow、JAX 和 PyTorch —— 包括一百多个层、数十种度量标准、损失函数、优化器和回调函数,以及 Keras 的训练和评估循环,以及 Keras 的保存和序列化基础设施。所有您熟悉和喜爱的 API 都在这里。

大规模模型训练和部署

新版本的 Keras 为大规模模型训练和部署提供了全新的能力。借助优化的算法和性能改进,现在您可以处理更大规模、更复杂的深度学习模型,而无需担心性能问题。

使用任何来源的数据管道。

Keras 3 的 fit()/evaluate()/predict()例程兼容 tf.data.Dataset 对象、PyTorch 的 DataLoader 对象、NumPy 数组和 Pandas 数据框,无论您使用的是哪个后端。您可以在 PyTorch 的 DataLoader 上训练 Keras 3 + TensorFlow 模型,或者在 tf.data.Dataset 上训练 Keras 3 + PyTorch 模型。

 

案例1:搭配Pytorch训练

https://keras.io/guides/custom_train_step_in_torch/

  • 导入环境

import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
from keras import layers
import numpy as np
  • 定义模型

在 train_step() 方法的主体中,实现了一个常规的训练更新,类似于您已经熟悉的内容。重要的是,我们通过 self.compute_loss() 计算损失,它包装了传递给 compile() 的损失函数。

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        # Call torch.nn.Module.zero_grad() to clear the leftover gradients
        # for the weights from the previous train step.
        self.zero_grad()

        # Compute loss
        y_pred = self(x, training=True)  # Forward pass
        loss = self.compute_loss(y=y, y_pred=y_pred)

        # Call torch.Tensor.backward() on the loss to compute gradients
        # for the weights.
        loss.backward()

        trainable_weights = [v for v in self.trainable_weights]
        gradients = [v.value.grad for v in trainable_weights]

        # Update weights
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)

        # Return a dict mapping metric names to current value
        # Note that it will include the loss (tracked in self.metrics).
        return {m.name: m.result() for m in self.metrics}
  • 训练模型

# Construct and compile an instance of CustomModel
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

# Just use `fit` as usual
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)

案例2:自定义Pytorch流程

https://keras.io/guides/writing_a_custom_training_loop_in_torch/

  • 导入环境

import os

# This guide can only be run with the torch backend.
os.environ["KERAS_BACKEND"] = "torch"

import torch
import keras
from keras import layers
import numpy as np
  • 定义模型、加载数据集

# Let's consider a simple MNIST model
def get_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu")(inputs)
    x2 = keras.layers.Dense(64, activation="relu")(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model


# Create load up the MNIST dataset and put it in a torch DataLoader
# Prepare the training dataset.
batch_size = 32
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784)).astype("float32")
x_test = np.reshape(x_test, (-1, 784)).astype("float32")
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

# Reserve 10,000 samples for validation.
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Create torch Datasets
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_train), torch.from_numpy(y_train)
)
val_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(x_val), torch.from_numpy(y_val)
)

# Create DataLoaders for the Datasets
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)
  • 定义优化器

# Instantiate a torch optimizer
model = get_model()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Instantiate a torch loss function
loss_fn = torch.nn.CrossEntropyLoss()
  • 训练模型

epochs = 3
for epoch in range(epochs):
    for step, (inputs, targets) in enumerate(train_dataloader):
        # Forward pass
        logits = model(inputs)
        loss = loss_fn(logits, targets)

        # Backward pass
        model.zero_grad()
        loss.backward()

        # Optimizer variable updates
        optimizer.step()

        # Log every 100 batches.
        if step % 100 == 0:
            print(
                f"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}"
            )
            print(f"Seen so far: {(step + 1) * batch_size} samples")

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1476421.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

面试笔记系列八之JVM基础知识点整理及常见面试题

类实例化加载顺序 加载:当程序访问某个类时,JVM会首先检查该类是否已经加载到内存中。如果尚未加载,则会进行加载操作。加载操作将类的字节码文件加载到内存,并为其创建一个Class对象。 连接(验证、准备、解析&#x…

Qt程序设计-指南针自定义控件实例

本文讲解Qt指南针自定义控件实例。 效果演示 创建指南针类 #ifndef COMPASS_H #define COMPASS_H#include <QWidget> #include <QWidget> #include <QTimer> #include <QPainter> #include <QPen> #include <QDebug> #include <QtMat…

【MATLAB】REMD_ MFE_SVM_LSTM 神经网络时序预测算法

有意向获取代码&#xff0c;请转文末观看代码获取方式~也可转原文链接获取~ 1 基本定义 REMD_MFE_SVM_LSTM神经网络时序预测算法是一种结合了REMD&#xff08;Reservoir Enhanced Multi-scale Deep Learning&#xff09;算法、多尺度特征提取&#xff08;MFE&#xff09;、支持…

许多主要新闻媒体正屏蔽 OpenAI 爬虫

自OpenAI的内容生成式人工智能模型面世以来&#xff0c;大量互联网数据成为了不断训练和优化模型的“饵料”&#xff0c;但据路透社研究所的一项调查&#xff0c;有越来越多的新闻媒体已对OpenAI的数据爬取说“不”&#xff0c;在传统媒体领域&#xff0c;这一比例甚至超过了50…

数仓模型设计方法论

在当今大数据时代&#xff0c;数据已经成为企业最重要的资产之一。而数据仓库作为企业数据管理和分析的核心基础设施&#xff0c;其设计方法论对于企业的数据治理和决策分析至关重要。本文将探索数仓模型设计的方法论&#xff0c;帮助读者更好地理解和应用数仓模型设计。 一、…

仿牛客网项目---社区首页的开发实现

从今天开始我们来写一个新项目&#xff0c;这个项目是一个完整的校园论坛的项目。主要功能模块&#xff1a;用户登录注册&#xff0c;帖子发布和热帖排行&#xff0c;点赞关注&#xff0c;发送私信&#xff0c;消息通知&#xff0c;社区搜索等。这篇文章我们先试着写一下用户的…

EAP-TLS实验之Ubuntu20.04环境搭建配置(FreeRADIUS3.0)(四)

该篇主要介绍了利用配置ca.cnf、server.cnf、client.cnf在certs路径下生成证书文件&#xff08;非执行bootstrap脚本&#xff0c;网上也有很多直接通过openssl命令方式生成的文章&#xff09;&#xff0c;主要参考&#xff08;概括中心思想&#xff09;官方手册&#xff0c;以及…

2024年阿里云2核4G配置服务器测评_ECS和轻量性能测评

阿里云2核4G服务器多少钱一年&#xff1f;2核4G服务器1个月费用多少&#xff1f;2核4G服务器30元3个月、85元一年&#xff0c;轻量应用服务器2核4G4M带宽165元一年&#xff0c;企业用户2核4G5M带宽199元一年。本文阿里云服务器网整理的2核4G参加活动的主机是ECS经济型e实例和u1…

安卓之ContentProvider的应用场景以及优劣分析

摘要 本文旨在对Android开发中的ContentProvider进行深入探讨。ContentProvider是Android系统中四大组件之一&#xff0c;主要用于在不同的应用程序之间共享数据。本文首先对ContentProvider进行概述&#xff0c;然后分析其应用场景&#xff0c;接着对其优势和劣势进行分析&…

简单模板2(HTML)

紧接上回&#xff0c;简单模板2又来了&#xff0c;喜欢赶紧点个赞吧&#xff0c;希望大家喜欢&#xff01; 效果图&#xff1a; CODE&#xff1a; <!DOCTYPE html> <html> <head><title>我的第一个网页</title> </head> <body><…

微信小程序订阅消息前后端示例

微信小程序的订阅消息&#xff0c; 必须是由弹框&#xff0c;弹框&#xff0c;弹框来调起了&#xff0c;单纯的在页面上调用 wx.requestSubscribeMessage是没有效果的 小程序端的代码 <view class"sub" bindtap"dinyuxiaoxi">订阅消息</view>…

【深度学习】SDXL-Lightning 体验,gradio教程,SDXL-Lightning 论文

文章目录 资源SDXL-Lightning 论文 资源 SDXL-Lightning论文&#xff1a;https://arxiv.org/abs/2402.13929 gradio教程&#xff1a;https://blog.csdn.net/qq_21201267/article/details/131989242 SDXL-Lightning &#xff1a;https://huggingface.co/ByteDance/SDXL-Light…

SpringCloud Eureka(注册中心)

一、spring cloud简介 spring cloud 为开发人员提供了快速构建分布式系统的一些工具&#xff0c;包括配置管理、服务发现、断路器、路由、微代理、事件总线、全局锁、决策竞选、分布式会话等等。它运行环境简单&#xff0c;可以在开发人员的电脑上跑。另外说明spring cloud是基…

sql基本语法+实验实践

sql语法 注释&#xff1a; 单行 --注释内容# 注释内容多行 /* 注释内容 */数据定义语言DDL 查询所有数据库 show databases;注意是databases而不是database。 查询当前数据库 select database();创建数据库 create database [if not exists] 数据库名 [default charset 字符…

Qt中关于信号与槽函数的思考

信号与槽函数的思考 以pushbutton控件为例&#xff0c;在主界面上放置一个pushbutton控件&#xff0c;点击右键选择关联槽函数&#xff0c;关联一个click函数&#xff0c;如下图所示&#xff1a; 在该函数中&#xff0c;实现了一个点击pushbutton按钮后&#xff0c;弹出一个窗…

复制策略深入探讨

在之前的博客中&#xff0c;我们讨论了复制最佳实践和不同类型的复制&#xff0c;例如批量、站点和存储桶。但是&#xff0c;随着所有这些不同类型的复制类型的出现&#xff0c;人们不得不想知道在哪里使用哪种复制策略&#xff1f;从现有 S3 兼容数据存储迁移数据时&#xff0…

部署PhotoMaker通过堆叠 ID 嵌入自定义逼真的人物照片

PhotoMaker只需要一张人脸照片就可以生成不同风格的人物照片&#xff0c;可以快速出图&#xff0c;无需额外的LoRA培训。 安装环境 python 3.10gitVisual Studio 2022 安装依赖库 git clone https://github.com/bmaltais/PhotoMaker.git cd PhotoMaker python -m venv venv…

借助 Aspose.Words,使用 C#、Java、Python 和 C++ 创建 Word 文档

以编程方式创建和操作 Word 文档是许多应用程序的常见要求。幸运的是&#xff0c;有各种编程语言的强大库可以简化此任务。Aspose.Words 就是此类多功能解决方案之一&#xff0c;它是强大的 API&#xff0c;使开发人员能够无缝生成、修改和转换 Word 文件。在这篇博文中&#x…

tomcat下载安装配置教程

tomcat下载安装配置教程 我是使用tomcat下载安装及配置教程_tomcat安装-CSDN博客 此贴来进行安装配置&#xff0c;原文21年已经有些许不同。 下载tomcat 官网&#xff1a;http://tomcat.apache.org/ 我们老师让安装8.5以上&#xff0c;所以我直接选择版本9 点击9页面之后…

定制开发一款家政小程序,应知应会

引言 在这个快节奏的现代生活中&#xff0c;人们对高效、便捷的家政服务的需求日益增加。随着社会结构的变化和职业生活的繁忙&#xff0c;许多家庭面临着时间不足、精力不济的挑战。在这种情况下&#xff0c;家政服务成为解决问题的有效途径。然而&#xff0c;传统的家政服务…