TensorFlow和Keras应如何选择?

news2024/11/15 13:38:23

前些年,深度学习领域的研究人员、开发人员和工程师必须经常做出一些选择:

  1. 我应该选择易于使用但自定义困难的 Keras 库?
  2. 还是应该使用难度更大的 TensorFlow API,编写大量代码?(更不用说一个不那么容易使用的 API 了。)

如果你陷于“应该使用 Keras 还是 TensorFlow”这样的问题,可以退一步再看,其实这是一个错误的问题,因为完成可以选择同时使用两个。

如果你是 TensorFlow 用户,那也应该开始考虑 Keras API 了,因为:

  1. 它是基于 TensorFlow 创建的
  2. 它更易于使用
  3. 当你需要用纯 TensorFlow 实现特定性能或功能时,它可以直接用于你的 Keras。

关于tf.keras基础知识点

  • tf.keras的基本流程
  • tf.keras实现模型构建的方法
  • tf.keras中模型训练验证的相关方法

今天通过鸢尾花分类案例,来给大家介绍tf.keras的基本使用流程。tf.keras使用tensorflow中的高级接口,我们调用它即可完成:

  1. 导入和解析数据集
  2. 构建模型
  3. 使用样本数据训练该模型
  4. 评估模型的效果。

由于与scikit -learn的相似性,接下来我们将通过将Keras与scikit -learn进行比较,介绍tf.Keras的相关使用方法。

1.相关的库的导入

在这里使用sklearn和tf.keras完成鸢尾花分类,导入相关的工具包:

# 绘图
import seaborn as sns
# 数值计算
import numpy as np
# sklearn中的相关工具
# 划分训练集和测试集
from sklearn.model_selection import train_test_split
# 逻辑回归
from sklearn.linear_model import LogisticRegressionCV
# tf.keras中使用的相关工具
# 用于模型搭建
from tensorflow.keras.models import Sequential
# 构建模型的层和激活方法
from tensorflow.keras.layers import Dense, Activation
# 数据处理的辅助工具
from tensorflow.keras import utils

2.数据展示和划分

利用seborn导入相关的数据,iris数据以dataFrame的方式在seaborn进行存储,我们读取后并进行展示:

# 读取数据
iris = sns.load_dataset("iris")
# 展示数据的前五行
iris.head()

另外,利用seaborn中pairplot函数探索数据特征间的关系:

# 将数据之间的关系进行可视化
sns.pairplot(iris, hue='species')

​将数据划分为训练集和测试集:从iris dataframe中提取原始数据,将花瓣和萼片数据保存在数组X中,标签保存在相应的数组y中:

# 花瓣和花萼的数据
X = iris.values[:, :4]
# 标签值
y = iris.values[:, 4]

利用train_test_split完成数据集划分:

# 将数据集划分为训练集和测试集
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.5, test_size=0.5, random_state=0)

接下来,我们就可以使用sklearn和tf.keras来完成预测

3.sklearn实现

利用逻辑回归的分类器,并使用交叉验证的方法来选择最优的超参数,实例化LogisticRegressionCV分类器,并使用fit方法进行训练:

# 实例化分类器
lr = LogisticRegressionCV()
# 训练
lr.fit(train_X, train_y)

利用训练好的分类器进行预测,并计算准确率:

# 计算准确率并进行打印
print("Accuracy = {:.2f}".format(lr.score(test_X, test_y)))

逻辑回归的准确率为:

Accuracy = 0.93

4.tf.keras实现

在sklearn中我们只要实例化分类器并利用fit方法进行训练,最后衡量它的性能就可以了,那在tf.keras中与在sklearn非常相似,不同的是:

  • 构建分类器时需要进行模型搭建
  • 数据采集时,sklearn可以接收字符串型的标签,如:“setosa”,但是在tf.keras中需要对标签值进行热编码,如下所示:

​有很多方法可以实现热编码,比如pandas中的get_dummies(),在这里我们使用tf.keras中的方法进行热编码:

# 进行热编码
def one_hot_encode_object_array(arr):
    # 去重获取全部的类别
    uniques, ids = np.unique(arr, return_inverse=True)
    # 返回热编码的结果
    return utils.to_categorical(ids, len(uniques))

4.1 数据处理

接下来对标签值进行热编码:

# 训练集热编码
train_y_ohe = one_hot_encode_object_array(train_y)
# 测试集热编码
test_y_ohe = one_hot_encode_object_array(test_y)

4.2 模型搭建

在sklearn中,模型都是现成的。tf.Keras是一个神经网络库,我们需要根据数据和标签值构建神经网络。神经网络可以发现特征与标签之间的复杂关系。神经网络是一个高度结构化的图,其中包含一个或多个隐藏层。每个隐藏层都包含一个或多个神经元。神经网络有多种类别,该程序使用的是密集型神经网络,也称为全连接神经网络:一个层中的神经元将从上一层中的每个神经元获取输入连接。例如,图 2 显示了一个密集型神经网络,其中包含 1 个输入层、2 个隐藏层以及 1 个输出层,如下图所示:

​上图 中的模型经过训练并馈送未标记的样本时,它会产生 3 个预测结果:相应鸢尾花属于指定品种的可能性。对于该示例,输出预测结果的总和是 1.0。该预测结果分解如下:山鸢尾为 0.02,变色鸢尾为 0.95,维吉尼亚鸢尾为 0.03。这意味着该模型预测某个无标签鸢尾花样本是变色鸢尾的概率为 95%。

TensorFlow tf.keras API 是创建模型和层的首选方式。通过该 API,您可以轻松地构建模型并进行实验,而将所有部分连接在一起的复杂工作则由 Keras 处理。

tf.keras.Sequential 模型是层的线性堆叠。该模型的构造函数会采用一系列层实例;在本示例中,采用的是 2 个密集层(分别包含 10 个节点)以及 1 个输出层(包含 3 个代表标签预测的节点)。第一个层的 input_shape 参数对应该数据集中的特征数量:

# 利用sequential方式构建模型
model = Sequential([
  # 隐藏层1,激活函数是relu,输入大小有input_shape指定
  Dense(10, activation="relu", input_shape=(4,)),  
  # 隐藏层2,激活函数是relu
  Dense(10, activation="relu"),
  # 输出层
  Dense(3,activation="softmax")
])

通过model.summary可以查看模型的架构:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 10)                50        
_________________________________________________________________
dense_1 (Dense)              (None, 10)                110       
_________________________________________________________________
dense_2 (Dense)              (None, 3)                 33        
=================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________

激活函数可决定层中每个节点的输出形状。这些非线性关系很重要,如果没有它们,模型将等同于单个层。激活函数有很多,但隐藏层通常使用 ReLU。

隐藏层和神经元的理想数量取决于问题和数据集。与机器学习的多个方面一样,选择最佳的神经网络形状需要一定的知识水平和实验基础。一般来说,增加隐藏层和神经元的数量通常会产生更强大的模型,而这需要更多数据才能有效地进行训练。

4.3 模型训练和预测

在训练和评估阶段,我们都需要计算模型的损失。这样可以衡量模型的预测结果与预期标签有多大偏差,也就是说,模型的效果有多差。我们希望尽可能减小或优化这个值,所以我们设置优化策略和损失函数,以及模型精度的计算方法:

# 设置模型的相关参数:优化器,损失函数和评价指标
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])

接下来与在sklearn中相同,分别调用fit和predict方法进行预测即可。

# 模型训练:epochs,训练样本送入到网络中的次数,batch_size:每次训练的送入到网络中的样本个数
model.fit(train_X, train_y_ohe, epochs=10, batch_size=1, verbose=1);

上述代码完成的是:

  1. 迭代每个epoch。通过一次数据集即为一个epoch。
  2. 在一个epoch中,遍历训练 Dataset 中的每个样本,并获取样本的特征 (x) 和标签 (y)。
  3. 根据样本的特征进行预测,并比较预测结果和标签。衡量预测结果的不准确性,并使用所得的值计算模型的损失和梯度。
  4. 使用 optimizer 更新模型的变量。
  5. 对每个epoch重复执行以上步骤,直到模型训练完成。

训练过程展示如下:

Epoch 1/10
75/75 [==============================] - 0s 616us/step - loss: 0.0585 - accuracy: 0.9733
Epoch 2/10
75/75 [==============================] - 0s 535us/step - loss: 0.0541 - accuracy: 0.9867
Epoch 3/10
75/75 [==============================] - 0s 545us/step - loss: 0.0650 - accuracy: 0.9733
Epoch 4/10
75/75 [==============================] - 0s 542us/step - loss: 0.0865 - accuracy: 0.9733
Epoch 5/10
75/75 [==============================] - 0s 510us/step - loss: 0.0607 - accuracy: 0.9733
Epoch 6/10
75/75 [==============================] - 0s 659us/step - loss: 0.0735 - accuracy: 0.9733
Epoch 7/10
75/75 [==============================] - 0s 497us/step - loss: 0.0691 - accuracy: 0.9600
Epoch 8/10
75/75 [==============================] - 0s 497us/step - loss: 0.0724 - accuracy: 0.9733
Epoch 9/10
75/75 [==============================] - 0s 493us/step - loss: 0.0645 - accuracy: 0.9600
Epoch 10/10
75/75 [==============================] - 0s 482us/step - loss: 0.0660 - accuracy: 0.9867

与sklearn中不同,对训练好的模型进行评估时,与sklearn.score方法对应的是tf.keras.evaluate()方法,返回的是损失函数和在compile模型时要求的指标:

# 计算模型的损失和准确率
loss, accuracy = model.evaluate(test_X, test_y_ohe, verbose=1)
print("Accuracy = {:.2f}".format(accuracy))

分类器的准确率为:

3/3 [==============================] - 0s 591us/step - loss: 0.1031 - accuracy: 0.9733
Accuracy = 0.97

到此我们对tf.kears的使用有了一个基本的认知,在接下来的课程中会给大家解释神经网络以及在计算机视觉中的常用的CNN的使用。

总结

1.使用tf.keras进行分类时的主要流程:

数据处理-构建模型-模型训练-模型验证

2.tf.keras中构建模型可通过squential()来实现并利用.fit()方法进行训练

3.使用evaluate()方法计算损失函数和准确率
 


人工智能自学站:

  • 人工智能开发_人工智能工程师_AI人工智能
  • 人工智能之python编程零基础入门
  • 4天快速入门Python数据挖掘
  • 简单快速入门Python机器学习
  • AI深度学习自然语言处理NLP零基础入门
  • AI-OpenCV图像处理10小时零基础入门
  • 3天带你玩转Python深度学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/108344.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

pyspark之sparksql数据交互

在pyspark中,使用sparksql进行mysql数据的读写处理,将程序保存为test.py #-*- coding: UTF-8 -*- # 设置python的默认编码 import sys reload(sys) sys.setdefaultencoding(utf-8) # Spark 初始化 from pyspark.sql import SQLContext, SparkSession, …

【推荐】DDD领域驱动设计和中台实践资料合集

Domain Driven Design(简称 DDD),又称为领域驱动设计,起源于杰出软件建模专家Eric Evans在2003年发表的书籍《DOMAIN-DRINEN DESIGN —TACKLING COMPLEXITY IN THE HEART OF SOFTWARE》(中文译名《领域驱动设计—软件核…

卓海科技冲刺创业板:拟募资5.47亿 相宇阳控制52.9%股权

雷递网 雷建平 12月20日无锡卓海科技股份有限公司(简称:“卓海科技”)日前递交招股书,准备在深交所创业板上市。卓海科技计划募资5.47亿元,其中,1.04亿元用于半导体前道量检测设备扩产项目,1.84…

ev_api_server:大事件node接口项目开发

Headline 大事件后台 API 项目,API 接口文档请参考 https://www.showdoc.cc/escook?page_id3707158761215217 1. 初始化 1.1 创建项目 新建 api_server 文件夹作为项目根目录,并在项目根目录中运行如下的命令,初始化包管理配置文件&#x…

尚医通-前端Vue学习(十)

目录: (1)node.js介绍 (2)npm包管理工具 (3)es6模块化 (4)babel转码器 (5)webpack打包工具 (1)node.js介绍 浏览器的…

Python使用pandas导入xlsx格式的excel文件内容

Python使用pandas导入xlsx格式的excel文件内容1. 基本导入2. 列标题与数据对齐3. 指定导入某个sheet4. 指定行索引5. 指定列索引6. 指定导入列7. 指定导入的行数8. 更多的参数1. 基本导入 在 Python中使用pandas导入.xlsx文件的方法是read_excel()。 # codingutf-8 import pa…

Windows环境下在VScode中运行开源运动规划库(zhm-real / PathPlanning)的方法

本文主要介绍Windows环境下,在Vscode中运行zhm-real发布的开源运动规划库PathPlanning的实现方法,包括环境配置及运行开源包时常见错误解决方法。    一、环境配置 (1)VScode 下载及安装,官网如下: http…

Flowable工作流进阶使用教程(监听器+流程变量+网关)

一、任务分配和流程变量 1.任务分配 1.1 固定分配 固定分配就是我们前面介绍的,在绘制流程图或者直接在流程文件中通过Assignee来指定的方式 1.2 表达式分配 Flowable使用UEL进行表达式解析。UEL代表Unified Expression Language,是EE6规范的一部分…

【Python】用python将html转化为pdf

其实早在去年就有做过,一直没有写,先简单记录下 1、主要用到的工具【wkhtmltopdf】 【下载地址】wkhtmltopdf 根据系统选择安装包,速度有点慢,先挂着 2、下载Python库 pip install pdfkit pip install wkhtmltopdf 3、简单代码…

CAD教程:CAD自定义之基础设置的操作技巧

在使用国产CAD软件绘制CAD图纸的过程中,有些时候会需要CAD自定义设置,那么你知道浩辰CAD建筑软件中CAD自定义之基础设置怎么使用吗?不知道也没关系,接下来的CAD教程就让小编来给大家介绍一下国产CAD软件——浩辰CAD建筑软件中CAD自…

【1799. N 次操作后的最大分数和】

来源:力扣(LeetCode) 描述: 给你 nums ,它是一个大小为 2 * n 的正整数数组。你必须对这个数组执行 n 次操作。 在第 i 次操作时(操作编号从 1 开始),你需要: 选择两个…

实验一 逻辑回归

一、实验目的 (1)学习并掌握常见的机器学习方法; (2)能够结合所学的python知识实现机器学习算法; (3)能够用所学的机器学习算法解决实际问题。 二、实验内容与要求 &#xff08…

设计模式之备忘录模式

Memento design pattern 备忘录模式的概念、备忘录模式的结构、备忘录模式的优缺点、备忘录模式的使用场景、备忘录模式的实现示例、备忘录模式的源码分析 1、备忘录模式的概念 备忘录模式,又称快照模式,即在不破坏封装的前提下,获取并保存一…

【数电】Simulation Test 模拟测试

一、 选择题:(共20分,每小题2分) 1、逻辑函数的所有最小项之和等于多少? A. 0 B. 1 C. 0或1 D. 任意值 2、与非门的多余输入端应如何处理?…

MySQL面试常问问题(基础) —— 赶快收藏

目录 1. 什么是内连接、外连接、交叉连接、笛卡尔积呢? 2. 那MySQL 的内连接、左连接、右连接有有什么区别? 3.说一下数据库的三大范式? 4.varchar与char的区别? 5.blob和text有什么区别? 6.DATETIME和TIMESTAMP…

SCSS学习笔记

文章目录1.安装scss2.选择器嵌套3.属性嵌套4.父选择器&5.变量5.1变量的规范5.2变量的作用域5.3给变量设置默认值(!default)6数据类型7.运算符8.插值语法9.流程控制语句9.1 条件语句9.2循环语句9.2.1for9.2.2each9.2.3while10import10.1引入scss不编译10.2嵌套引入scss11.mi…

【软件测试】概念篇

目录 一、需求 1.1用户需求 1.2软件需求 1.3需求的重要性 二、测试用例 三、BUG 3.1什么是BUG 3.2如何描述一个BUG 4.3BUG优先级 四、软件开发模型 4.1软件生命周期 4.2开发模型 定义:软件测试就是一系列活动,这些活动是为了评估一个程序或者…

新店速递 | IU酒店带您领略“东方古罗马”

淄博,位处鲁中,是黄河三角洲生态经济和蓝色经济区的交汇处。四季分明的气候造就了这座齐国故都的生态多样性,南高北低的地理位置使其峻岭平原兼具,鲁中的位置又赋予他交通枢纽的重要性。这里历史气息浓厚,社会文化自由…

中文语法纠错全国大赛获奖分享:基于多轮机制的中文语法纠错

中文语法纠错任务旨在对文本中存在的拼写、语法等错误进行自动检测和纠正,是自然语言处理领域一项重要的任务。同时该任务在公文、新闻和教育等领域都有着落地的应用价值。但由于中文具有的文法和句法规则比较复杂,基于深度学习的中文文本纠错在实际落地…

新能源汽车市场渗透率不断提高,锂电设备需求空间较大

根据观研报告网发布的《中国新能源汽车行业发展深度研究与投资趋势调研报告(2022-2029年)》显示,近年来,随着各国开启能源转型,在汽车领域,由于电动汽车具有高效节能、零排放等优点,已逐渐成为汽…