TensorFlow 03(Keras)

news2025/1/22 19:07:17

一、tf.keras

tf.keras是TensorFlow 2.0的高阶API接口,为TensorFlow的代码提供了新的风格和设计模式,大大提升了TF代码的简洁性和复用性,官方也推荐使用tf.keras来进行模型设计和开发。

1.1 tf.keras中常用模块

如下表所示:

1.2 常用方法

深度学习实现的主要流程:

1.数据获取,

2 数据处理,

3 模型创建与训练,

4 模型测试与评估,

5.模型预测

导入tf.keras

使用 tf.keras,首先需要在代码开始时导入tf.keras

import tensorflow as tf
from tensorflow import keras

数据输入

 对于小的数据集,可以直接使用numpy格式的数据进行训练、评估模型,对于大型数据集或者要进行跨设备训练时使用tf.data.datasets来进行数据输入。

模型构建

  • 简单模型使用Sequential进行构建
  • 复杂模型使用函数式编程来构建
  • 自定义layers

训练与评估

  • 配置训练过程
# 配置优化方法,损失函数和评价指标
model.compile(optimizer=tf.train.AdamOptimizer(0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
  • 模型训练
# 指明训练数据集,训练epoch,批次大小和验证集数据
model.fit/fit_generator(dataset, epochs=10, 
                        batch_size=3,
          validation_data=val_dataset,
          )
  • 模型评估
# 指明评估数据集和批次大小
model.evaluate(x, y, batch_size=32)
  • 模型预测
# 对新的样本进行预测
model.predict(x, batch_size=32)

回调函数(callbacks)

回调函数用在模型训练过程中,来控制模型训练行为,可以自定义回调函数,也可使用tf.keras.callbacks 内置的 callback :

ModelCheckpoint:定期保存 checkpoints。 LearningRateScheduler:动态改变学习速率。 EarlyStopping:当验证集上的性能不再提高时,终止训练。 TensorBoard:使用 TensorBoard 监测模型的状态。

模型的保存和恢复

  • 只保存参数
# 只保存模型的权重
model.save_weights('./my_model')
# 加载模型的权重
model.load_weights('my_model')
  • 保存整个模型
# 保存模型架构与权重在h5文件中
model.save('my_model.h5')
# 加载模型:包括架构和对应的权重
model = keras.models.load_model('my_model.h5')

二、keras构建模型

 

2.1 相关的库的导入

在这里使用sklearn和tf.keras完成鸢尾花分类,导入相关的工具包:

# 绘图
import seaborn as sns
# 数值计算
import numpy as np
# sklearn中的相关工具
# 划分训练集和测试集
from sklearn.model_selection import train_test_split
# 逻辑回归
from sklearn.linear_model import LogisticRegressionCV
# tf.keras中使用的相关工具
# 用于模型搭建
from tensorflow.keras.models import Sequential
# 构建模型的层和激活方法
from tensorflow.keras.layers import Dense, Activation
# 数据处理的辅助工具
from tensorflow.keras import utils

 

2.2 数据展示和划分

利用seborn导入相关的数据,iris数据以dataFrame的方式在seaborn进行存储,我们读取后并进行展示;

将数据划分为训练集和测试集:从iris dataframe中提取原始数据,将花瓣和萼片数据保存在数组X中,标签保存在相应的数组y中;

# 读取数据
iris = sns.load_dataset("iris")
# 展示数据的前五行
iris.head()

# 花瓣和花萼的数据
X = iris.values[:, :4]
# 标签值
y = iris.values[:, 4]


# 将数据集划分为训练集和测试集
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.5, test_size=0.5, random_state=0)

另外,利用seaborn中pairplot函数探索数据特征间的关系:

# 将数据之间的关系进行可视化
sns.pairplot(iris, hue='species')

2.3 sklearn实现

利用逻辑回归的分类器,并使用交叉验证的方法来选择最优的超参数,实例化LogisticRegressionCV分类器,并使用fit方法进行训练:

# 实例化分类器
lr = LogisticRegressionCV()
# 训练
lr.fit(train_X, train_y)



# 计算准确率并进行打印
print("Accuracy = {:.2f}".format(lr.score(test_X, test_y)))

Accuracy = 0.93

2.4 tf.keras实现

数据准备

在sklearn中我们只要实例化分类器并利用fit方法进行训练,最后衡量它的性能就可以了,那在tf.keras中与在sklearn非常相似,不同的是:

  • 构建分类器时需要进行模型搭建
  • 数据采集时,sklearn可以接收字符串型的标签,如:“setosa”,但是在tf.keras中需要对标签值进行热编码,如下所示:

有很多方法可以实现热编码,比如pandas中的get_dummies(),在这里我们使用tf.keras中的方法进行热编码:

# 进行热编码
def one_hot_encode_object_array(arr):
    # 去重获取全部的类别
    uniques, ids = np.unique(arr, return_inverse=True)
    # 返回热编码的结果
    return utils.to_categorical(ids, len(uniques))


#对标签值进行热编码 
# 训练集热编码
train_y_ohe = one_hot_encode_object_array(train_y)
# 测试集热编码
test_y_ohe = one_hot_encode_object_array(test_y)

 

模型搭建

在sklearn中,模型都是现成的。tf.Keras是一个神经网络库,我们需要根据数据和标签值构建神经网络。

神经网络可以发现特征与标签之间的复杂关系。

神经网络是一个高度结构化的图,其中包含一个或多个隐藏层。

每个隐藏层都包含一个或多个神经元。

神经网络有多种类别,该程序使用的是密集型神经网络,也称为全连接神经网络:一个层中的神经元将从上一层中的每个神经元获取输入连接。例如,图 2 显示了一个密集型神经网络,其中包含 1 个输入层、2 个隐藏层以及 1 个输出层,如下图所示:

上图 中的模型经过训练并馈送未标记的样本时,它会产生 3 个预测结果:相应鸢尾花属于指定品种的可能性。对于该示例,输出预测结果的总和是 1.0。该预测结果分解如下:山鸢尾为 0.02,变色鸢尾为 0.95,维吉尼亚鸢尾为 0.03。这意味着该模型预测某个无标签鸢尾花样本是变色鸢尾的概率为 95%。

TensorFlow tf.keras API 是创建模型和层的首选方式。通过该 API,您可以轻松地构建模型并进行实验,而将所有部分连接在一起的复杂工作则由 Keras 处理。

tf.keras.Sequential 模型是层的线性堆叠。该模型的构造函数会采用一系列层实例;在本示例中,采用的是 2 个密集层(分别包含 10 个节点)以及 1 个输出层(包含 3 个代表标签预测的节点)。第一个层的 input_shape 参数对应该数据集中的特征数量:

# 利用sequential方式构建模型
model = Sequential([
  # 隐藏层1,激活函数是relu,输入大小有input_shape指定
  Dense(10, activation="relu", input_shape=(4,)),  
  # 隐藏层2,激活函数是relu
  Dense(10, activation="relu"),
  # 输出层
  Dense(3,activation="softmax")
])

通过model.summary可以查看模型的架构:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 10)                50        
_________________________________________________________________
dense_1 (Dense)              (None, 10)                110       
_________________________________________________________________
dense_2 (Dense)              (None, 3)                 33        
=================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________             

激活函数可决定层中每个节点的输出形状。这些非线性关系很重要,如果没有它们,模型将等同于单个层。激活函数有很多,但隐藏层通常使用 ReLU

隐藏层和神经元的理想数量取决于问题和数据集。与机器学习的多个方面一样,选择最佳的神经网络形状需要一定的知识水平和实验基础。一般来说,增加隐藏层和神经元的数量通常会产生更强大的模型,而这需要更多数据才能有效地进行训练。

模型训练和预测

在训练和评估阶段,我们都需要计算模型的损失。这样可以衡量模型的预测结果与预期标签有多大偏差,也就是说,模型的效果有多差。我们希望尽可能减小或优化这个值,所以我们设置优化策略和损失函数,以及模型精度的计算方法:

# 设置模型的相关参数:优化器,损失函数和评价指标
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])

接下来与在sklearn中相同,分别调用fit和predict方法进行预测即可。

# 模型训练:epochs,训练样本送入到网络中的次数,batch_size:每次训练的送入到网络中的样本个数
model.fit(train_X, train_y_ohe, epochs=10, batch_size=1, verbose=1);
  1. 迭代每个epoch。通过一次数据集即为一个epoch。
  2. 在一个epoch中,遍历训练 Dataset 中的每个样本,并获取样本的特征 (x) 和标签 (y)。
  3. 根据样本的特征进行预测,并比较预测结果和标签。衡量预测结果的不准确性,并使用所得的值计算模型的损失和梯度。
  4. 使用 optimizer 更新模型的变量。
  5. 对每个epoch重复执行以上步骤,直到模型训练完成。

与sklearn中不同,对训练好的模型进行评估时,与sklearn.score方法对应的是tf.keras.evaluate()方法,返回的是损失函数和在compile模型时要求的指标: 

# 计算模型的损失和准确率
loss, accuracy = model.evaluate(test_X, test_y_ohe, verbose=1)
print("Accuracy = {:.2f}".format(accuracy))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1003550.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

机器学习——协同过滤算法(CF)

机器学习——协同过滤算法(CF) 文章目录 前言一、基于用户的协同过滤1.1. 原理1.2. 算法步骤1.3. 代码实现 二、基于物品的协同过滤2.1. 原理2.2. 算法步骤2.3. 代码实现 三、比较与总结四、实例解析总结 前言 协同过滤算法是一种常用的推荐系统算法&am…

清理 Ubuntu 系统的 4 个简单步骤

清理 Ubuntu 系统的 4 个简单步骤 现在,试试看这 4 个简单的步骤,来清理你的 Ubuntu 系统吧。 这份精简指南将告诉你如何清理 Ubuntu 系统以及如何释放一些磁盘空间。 如果你的 Ubuntu 系统已经运行了至少一年,尽管系统是最新的,…

2003-2022年黄河流域TCI、VCI、VHI、TVDI逐年1km分辨率数据集

摘要 黄河流域大部分属于干旱、半干旱气候,先天水资源条件不足,是中国各大流域中受干旱影响最为严重的流域。随着全球环境和气候变化,黄河流域的干旱愈加频繁,对黄河流域的干旱监测研究已经成为当下的热点。本数据集基于MODIS植被和地表温度产品,通过对逐年数据进行去云、…

Mendix使用Upload image新增修改账户头像

学习Mendix中级文档,其中有个管理我的账号功能,确保账号主任可以修改其头像,接下来记录如何实现账户头像的上传和修改。根据文档的步骤实现功能~~ 新建GeneralExtentions模块,给GeneralExtentions添加两个模…

MapTR v2文章研读

MapTR v2论文来了,本文仅介绍v2相较于v1有什么改进之处,如果想了解v1版本的论文细节,可见链接。 相较于maptr,maptr v2改进之处: 在分层query机制中引进解耦自注意力机制,有效降低了内存消耗;…

Spring中如何解决循环依赖问题

一、什么是循环依赖 循环依赖也叫循环引用,是指bean之间形成相互依赖的关系,由此,bean对象在属性注入时便会产生循环。这种循环依赖会导致编译器无法编译代码,从而无法运行程序。为了避免循环依赖,我们在开发过程中需…

视频号视频下载工具有那些?我们怎么下载视频号里面的视频

本篇文章给大家谈谈视频号视频下载工具,以及视频号视频如何下载?对应的知识点,希望对各位有所帮助。 视频号里面的视频可以下载吗? 视频号官方首先是不提供下载功能的,但是很多第三方可以提供视频号的视频下载功能。 早期版本视…

【力扣每日一题】2023.9.12 课程表Ⅳ

目录 题目: 示例: 分析: 代码: 题目: 示例: 分析: 今天是课程表系列题目的最后一题,因为我在题库里找不到课程表5了,所以今天的每日一题就是最后一个课程表了。 题…

小节5:Python列表list常用操作

1、对列表的基本认知: 列表list,是可变类型。比如,append()函数会直接改变列表本身,往列表里卖弄添加元素。所以,list_a list_a.append(123)就是错误的。如果想删除列表中的元素,可以用remove()函数&…

基于微信小程序的宠物寄养平台,附源码、数据库

1. 简介 本文正是基于微信小程序开发平台,针对宠物寄养的需求,本文设计出一个包含寄养家庭分类、寄养服务管理、宠物档案、交流论坛的微信小程序,以此帮助宠物寄养的实现,促进宠物寄养工作的进展。 2 开发技术 微信小程序的运行环境分为渲染层和逻辑层&#xff0…

仿照Everything实现的文件搜索工具--SearchEverything

一、项目介绍 项目名称:SearchEverything 项目简介:SearchEverything是仿照Everything实现的一款桌面级的文件搜索软件,它是Everything的增强版,支持跨平台的使用。 项目功能: 1.选择文件夹后,多线程扫描文件夹下的…

学会这个技能,写字楼立马高级起来!

在当今现代化社会中,写字楼已成为商业和行政活动的中心。成千上万的人们每天涌入这些高楼大厦,从事各种各样的工作,以实现公司和组织的目标。然而,与这种繁忙的办公环境一样,也带来了一系列的安全挑战和管理难题。 随着…

【大数据之Kafka】十一、Kafka消费者及消费者组案例

1 独立消费者案例(订阅主题) (1)需求:创建一个独立消费者,消费 first 主题中数据。 (2)分析: 注意:在消费者 API 代码中必须配置消费者组 id。命令行启动消…

算法通关村第13关【青铜】| 数字与数学基础问题

数字统计专题 1.数组元素积的符号 思路&#xff1a;每回碰到负数就取反 class Solution {public int arraySign(int[] nums) {int res nums[0];if(nums[0]>0){res 1;}else if(nums[0]<0){res -1;}else{return res;}for(int i 1;i<nums.length;i){if(nums[i]<…

Linux基本认识

一、Linux基本概念 Linux 内核最初只是由芬兰人林纳斯托瓦兹&#xff08;Linus Torvalds&#xff09;在赫尔辛基大学上学时出于个人爱好而编写的。 Linux 是一套免费使用和自由传播的类 Unix 操作系统&#xff0c;是一个基于 POSIX 和 UNIX 的多用户、多任务、支持多线程和多…

地下管网实时水位监测用什么设备好?

地下排水管网是城市重要基础设施生命线之一&#xff0c;主要用于排放雨水、地表水和废水&#xff0c;以维护城市的安全运行。然而&#xff0c;在极端天气事件发生时&#xff0c;排水系统可能会面临压力巨大&#xff0c;导致排水不畅引发城市内涝。通过对管网水位实时监测&#…

Java集合大总结——Collection集合

Collection集合的整理 1、List&#xff0c;Set&#xff0c;Queue&#xff0c;Map四者的区别集合底层数据结构梳理2、关于集合的的选用2.1 为什么使用集合3、List接口3.1 ArrayList 和 Array&#xff08;数组&#xff09;的区别&#xff1f;3.1 LinkedList 为什么不能实现Random…

基于python+txt的学生成绩管理系统

基于pythontxt的学生成绩管理系统 一、系统介绍二、效果展示三、其他系统实现四、获取源码 一、系统介绍 录入学生信息查找学生信息删除学生信息修改学生信息排序统计学生信息显示所有学生信息 基于python的学生成绩管理系统&#xff0c;具备基本的增删改查功能&#xff0c;包…

2023-9-12 完全背包问题

题目链接&#xff1a;完全背包问题 初版(时间复杂度拉满) #include <iostream> #include <algorithm>using namespace std;const int N 1010;int n, m; int v[N], w[N]; int f[N][N];int main() {cin >> n >> m;for(int i 1; i < n; i ) cin >…

AntDB数据库参加ACDU中国行杭州站,分享数据库运维实践与经验

关于ACDU 和中国行: ACDU是由墨天轮社区举办的中国数据库联盟的品牌活动之一&#xff0c;在线下汇集数据库领域的行业知名人士&#xff0c;共同探讨数据库前沿技术及其应用&#xff0c;促进行业发展和创新的平台&#xff0c;也为开发者们提供友好交流的机会。 AntDB作为具有技术…