TensorFlow高阶API和低阶API

news2024/10/6 20:25:51

TensorFlow提供了众多的API,简单地可以分类为高阶API和低阶API. API太多太乱也是TensorFlow被诟病的重点之一,可能因为Google的工程师太多了,社区太活跃了~当然后来Google也意识到这个问题,在TensorFlow 2.0中有了很大的改善。本文就简要介绍一下TensorFlow的高阶API和低阶API使用,提供推荐的使用方式。

高阶API(For beginners)

The best place to start is with the user-friendly Keras sequential API. Build models by plugging together building blocks.

TensorFlow推荐使用Keras的sequence函数作为高阶API的入口进行模型的构建,就像堆积木一样:

# 导入TensorFlow, 以及下面的常用Keras层
import tensorflow as tf  
from tensorflow.keras.layers import Flatten, Dense, Dropout
 

# 加载并准备好MNIST数据集
mnist = tf.keras.datasets.mnist
 (x_train, y_train), (x_test, y_test) = mnist.load_data()
 
 # 将样本从0~255的整数转换为0~1的浮点数
 x_train, x_test = x_train / 255.0, x_test / 255.0
 
 # 将模型的各层堆叠起来,以搭建 tf.keras.Sequential 模型
model = tf.keras.models.Sequential([
 Flatten(input_shape=(28, 28)),
 Dense(128, activation='relu'),
 Dropout(0.5),
 Dense(10, activation='softmax')
 ])
 
 
# 为训练选择优化器和损失函数
 model.compile(optimizer='adam',
               loss='sparse_categorical_crossentropy',
               metrics=['accuracy'])
# 训练并验证模型
model.fit(x_train, y_train, epochs=5)
 model.evaluate(x_test,  y_test, verbose=2)

输出的日志:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 4s 72us/sample - loss: 0.2919 - accuracy: 0.9156
Epoch 2/5
60000/60000 [==============================] - 4s 58us/sample - loss: 0.1439 - accuracy: 0.9568
Epoch 3/5
60000/60000 [==============================] - 4s 58us/sample - loss: 0.1080 - accuracy: 0.9671
Epoch 4/5
60000/60000 [==============================] - 4s 59us/sample - loss: 0.0875 - accuracy: 0.9731
Epoch 5/5
60000/60000 [==============================] - 3s 58us/sample - loss: 0.0744 - accuracy: 0.9766
10000/1 - 1s - loss: 0.0383 - accuracy: 0.9765
[0.07581, 0.9765]

日志的最后一行有两个数 [0.07581, 0.9765],0.07581是最终的loss值,也就是交叉熵;0.9765是测试集的accuracy结果,这个数字手写体模型的精度已经将近98%.

低阶API(For experts)

The Keras functional and subclassing APIs provide a define-by-run interface for customization and advanced research. Build your model, then write the forward and backward pass. Create custom layers, activations, and training loops.

说到TensorFlow低阶API,最先想到的肯定是tf.Session和著名的sess.run,但随着TensorFlow的发展,tf.Session最后出现在TensorFlow 1.15中,TensorFlow 2.0已经取消了这个API,如果非要使用的话只能使用兼容版本的tf.compat.v1.Session. 当然,还是推荐使用新版的API,这里也是用Keras,但是用的是subclass的相关API以及GradientTape. 下面会详细介绍。

# 导入TensorFlow, 以及下面的常用Keras层
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2Dfrom tensorflow.keras import Model
 

# 加载并准备好MNIST数据集
 mnist = tf.keras.datasets.mnist
 (x_train, y_train), (x_test, y_test) = mnist.load_data()
 
 # 将样本从0~255的整数转换为0~1的浮点数
 x_train, x_test = x_train / 255.0, x_test / 255.0
# 使用 tf.data 来将数据集切分为 batch 以及混淆数据集

batch_size = 32
train_ds = tf.data.Dataset.from_tensor_slices(
     (x_train, y_train)).shuffle(10000).batch(batch_size)
 test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
# 使用 Keras 模型子类化(model subclassing) API 构建 tf.keras 模型
class MyModel(Model):
 def __init__(self):
 super(MyModel, self).__init__()
 self.flatten = Flatten()
 self.d1 = Dense(128, activation='relu')
 self.dropout = Dropout(0.5)
 self.d2 = Dense(10, activation='softmax')
 
 def call(self, x):
     x = self.flatten(x)
     x = self.d1(x)
     x = self.dropout(x)
 return self.d2(x)
 
 model = MyModel()
 
 # 为训练选择优化器和损失函数
 loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
 optimizer = tf.keras.optimizers.Adam()
 
 # 选择衡量指标来度量模型的损失值(loss)和准确率(accuracy)。这些指标在 epoch 上累积值,然后打印出整体结果
 train_loss = tf.keras.metrics.Mean(name='train_loss')
 train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
 
 test_loss = tf.keras.metrics.Mean(name='test_loss')
 test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

 # 使用 tf.GradientTape 来训练模型
 @tf.functiondef train_step(images, labels):
 with tf.GradientTape() as tape:
     predictions = model(images)
     loss = loss_object(labels, predictions)
   gradients = tape.gradient(loss, model.trainable_variables)
   optimizer.apply_gradients(zip(gradients, model.trainable_variables))
 
   train_loss(loss)
   train_accuracy(labels, predictions)

 # 使用 tf.GradientTape 来训练模型
 @tf.functiondef train_step(images, labels):
 with tf.GradientTape() as tape:
     predictions = model(images)
     loss = loss_object(labels, predictions)
   gradients = tape.gradient(loss, model.trainable_variables)
   optimizer.apply_gradients(zip(gradients, model.trainable_variables))
 
   train_loss(loss)
   train_accuracy(labels, predictions)

 # 测试模型
@tf.functiondef test_step(images, labels):
   predictions = model(images)
   t_loss = loss_object(labels, predictions)
 
   test_loss(t_loss)
   test_accuracy(labels, predictions)
 EPOCHS = 5
 
 for epoch in range(EPOCHS):
   for images, labels in train_ds:
     train_step(images, labels)
 
   for test_images, test_labels in test_ds:
     test_step(test_images, test_labels)
 
   template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
   print (template.format(epoch+1,
                          train_loss.result(),
                          train_accuracy.result()*100,
                          test_loss.result(),
                          test_accuracy.result()*100))

输出:

Epoch 1, Loss: 0.13822732865810394, Accuracy: 95.84833526611328, Test Loss: 0.07067110389471054, Test Accuracy: 97.75
Epoch 2, Loss: 0.09080979228019714, Accuracy: 97.25, Test Loss: 0.06446609646081924, Test Accuracy: 97.95999908447266
Epoch 3, Loss: 0.06777264922857285, Accuracy: 97.93944549560547, Test Loss: 0.06325332075357437, Test Accuracy: 98.04000091552734
Epoch 4, Loss: 0.054447807371616364, Accuracy: 98.33999633789062, Test Loss: 0.06611879169940948, Test Accuracy: 98.00749969482422
Epoch 5, Loss: 0.04556874558329582, Accuracy: 98.60433197021484, Test Loss: 0.06510476022958755, Test Accuracy: 98.10400390625

可以看出,低阶API把整个训练的过程都暴露出来了,包括数据的shuffle(每个epoch重新排序数据使得训练数据随机化,避免周期性重复带来的影响)及组成训练batch,组建模型的数据通路,具体定义各种评估指标(loss, accuracy),计算梯度,更新梯度(这两步尤为重要)。如果用户需要对梯度或者中间过程做处理,甚至打印等,使用低阶API可以完全进行完全的控制。

如何选择

从上面的标题也可以看出,对于初学者来说,建议使用高阶API,简单清晰,可以迅速入门。对于专家学者们,建议使用低阶API,可以随心所欲地对具体细节进行改造和加工。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/463883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Markdown常用数学公式

1 行内公式 在两个美元符号中输入公式即可。 如$Em\times c^2$ 效果: E m c 2 Em\times c^2 Emc2(注:\times是乘的意思) 2 整行公式 在四个美元符号中输入公式,如果想要给公式后面添加编号,那么在公式…

windows安装mongodb6.x并设置用户名密码

安装教程 下载安装设置账号密码利用连接工具设置配置文件重新连接 下载 官网下载地址:点击去下载 安装 这工具很好用的,页面美观,设置账号密码也必不可少,推荐勾选。 设置账号密码 利用连接工具设置 必须选择一个库 use adm…

史上最全Maven教程(三)

文章目录 🔥Maven工程测试_Junit使用步骤🔥Maven工程测试_Junit结果判定🔥Maven工程测试_Before、After🔥依赖冲突调解_最短路径优先原则🔥依赖冲突调解_最先声明原则🔥依赖冲突调解_排除依赖、锁定版本 &a…

onnx手动操作001:onnx.helper

使用onnx.helper可以进行onnx的制造组装操作: 对象描述ValueInfoProto 对象张量名、张量的基本数据类型、张量形状算子节点信息 NodeProto算子名称(可选)、算子类型、输入和输出列表(列表元素为数值元素)GraphProto对象用张量节点和算子节点组成的计算图对象ModelP…

2023年测试岗,自动化测试我该如何进阶?卷出方向...

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 先编程语言打好基…

BPMN2.0 网关

网关(gateway)用于控制执行的流向(或者按BPMN 2.0的用词:执行的“标志(token)”)。网关可以消费(consuming)与生成(generating)标志。 网关用其中带有图标的菱形表示。 排他网关 排他网关(exclusive gateway)(也叫异或网关 XOR gateway,或者更专业的,基于数据…

怎么把pdf压缩的小一点 这3种方式都很简单

在日常工作中,我们常常会遇到PDF文件太大无法上传的情况。这种情况在通过聊天工具传输工作PDF文件资料或在发送附件邮件时尤为常见。如果你也遇到过类似的问题,那么怎么把pdf压缩的小一点?网上的方法虽然很多但是查找起来也是非常费心费力的。…

【正点原子Linux连载】第三章 RKMedia编译和使用 摘自【正点原子】ATK-DLRV1126系统开发手册

第三章 RKMedia编译和使用 5.1 RKMedia编译 Rkmedia是RK官方封装一层简易的API,把RGA、MPP、RKNN等等这些接口封装成高级的接口。在SDK官方的源码目录下,运行以下命令进行跳转: cd external/rkmedia/examples/ ls 运行命令结果如下所示&…

激光雷达“进阶战”:谁在引领新风向?

激光雷达正进入新的发展阶段。 高工智能汽车注意到,伴随激光雷达在2022年第一波小规模前装导入,市场正尝试向中端车型渗透,以逐步迈向快速增长期。在这一阶段,谁能解决成本可控、性能提升的难题,同时帮车企用好激光雷…

鸟哥的Linux私房菜——基础学习篇(第三版) (6-10章)

基础学习篇 第六章 :Linux 的档案权限与目录配置第七章 :档案与目录管理第八章 :Linux 磁盘与文件系统管理第九章 :文件与文件系统的压缩与打包第十章 :Vim程序编辑器 第六章 :Linux 的档案权限与目录配置 …

浅述 国产仪器仪表 6121A 音频分析仪

6121A是具有音频信号产生和音频信号分析功能的测试仪器,适用于语音性能测试和音频功放测试等领域,满足电台、移动通信、音响设备和水声通信设备对频响、谐波失真和信噪比等指标的测试需求,是音频信号性能测试的常备仪器。 6121A音频分析仪具…

Hadoop2.x集群搭建(centos7、VMware、finalshell)

第一章 Hadoop集群安装 1.1 集群规划 集群规划规划操作系统Mac、Windows虚拟软件Parallels Desktop(Mac)、VMWare(Windows)虚拟机主机名: c1, IP地址: 192.168.10.101主机名: c2, IP地址: 192.168.10.102主机名: c3, IP地址: 192.168.10.103软件包上传路径/root/softwares软件…

持续集成下接口自动化测试实践

目录:导读 引言 接口自动化测试工具介绍 接口自动化测试在持续集成中的运用 小结 引言 目前很多持续集成项目都需要执行接口层的测试,当你了解其基本概念,理解了接口协议、如何传参、测试原理后,无需 掌握程序语言&#xff0…

【MySQL高级】——目录结构数据库和文件系统的关系

一、目录结构 <1> 主要目录结构 find / -name mysql<2> 数据库文件目录 目录&#xff1a;/var/lib/mysql/ 配置方式&#xff1a;show variables like ‘datadir’; <3> 相关命令目录 目录&#xff1a;/usr/bin&#xff08;mysqladmin、mysqlbinlog、my…

软件著作权申请流程待发放多久就能到已发放拿到纸质证书?

软件著作权申请一般有两种途径 1、代理 代理机构有加急通道&#xff0c;软件著作权交件后最快20-30工作日内出&#xff0c;待发放到已发放只要3工作日拿到就可以邮寄纸质证书给你了。 2、版权中心官网自己登记 流程比较缓慢&#xff0c;而且最要命的是&#xff0c;证书是用邮…

数据划分方法简述:数据离散化和均值标准差分级法(含python代码)

文章目录 1 问题缘起2. 数据离散化等距离散等频离散聚类离散其他 3. 均值标准差分级 1 问题缘起 在数学建模中&#xff0c;我经常遇到这样一个问题&#xff1a; 在某一步中&#xff0c;需要把数据分成好几个类别或者是按照数据大小分级划分。 放到一维数据中形象一点解释就是…

InstructGPT原理讲解及ChatGPT类开源项目

InstructGPT原理讲解及ChatGPT类开源项目 Generative Pre-Trained Transformer&#xff08;GPT&#xff09; 是OpenAI的提出的生成式预训练语言模型&#xff0c;目前已经发布了GPT-1、GPT-2、GPT-3和GPT-4&#xff0c;未来也将发布GPT-5。 最近非常火的ChatGPT是基于Instruct…

【ChatGPT】稳定性好响应速度快可部署到国内服务器的ChatGPT 强力推荐!

朋友们&#xff0c;大家好&#xff0c;我是 jonssonyan。今天分享一个免费开源的 ChatGPT 项目&#xff0c;它的表现无论是响应速度还是稳定性都比 ChatGPT Plus 还要优秀&#xff0c;只需要有个 Access Token 或者使用热心网友提供的共享账号 就可以免费在线体验&#xff0c;也…

Vicuna-13B量化模型单GPU可跑

链接在这&#xff08;需要科学上网&#xff09; Vicuna-13B: Best Free ChatGPT Alternative According to GPT-4 &#x1f92f; | Tutorial (GPU) 有人在B站转了人家的视频 ChatGPT&#xff1a;在你的本地电脑上运行Vicuna-13B &#x1f92f;|教程 (GPU) 下面就是部署的步骤…

023 - C++ 继承

本期我们学习 C 面向对象编程中的继承。 面向对象编程是一个巨大的编程范式&#xff0c;类之间的继承是它的一个基本面&#xff0c;它是我们可以实际利用的最强大的特性之一。 先了解这些 继承允许我们有一个相互关联的类的层次结构。展开来说&#xff0c;它允许我们有一个包…