深度学习:tf.keras实现模型搭建、模型训练和预测

news2024/11/24 15:35:35

在sklearn中,模型都是现成的。tf.Keras是一个神经网络库,我们需要根据数据和标签值构建神经网络。神经网络可以发现特征与标签之间的复杂关系。神经网络是一个高度结构化的图,其中包含一个或多个隐藏层。每个隐藏层都包含一个或多个神经元。神经网络有多种类别,该程序使用的是密集型神经网络,也称为全连接神经网络:一个层中的神经元将从上一层中的每个神经元获取输入连接。例如,图 2 显示了一个密集型神经网络,其中包含 1 个输入层、2 个隐藏层以及 1 个输出层,如下图所示:

神经网络

上图 中的模型经过训练并馈送未标记的样本时,它会产生 3 个预测结果:相应鸢尾花属于指定品种的可能性。对于该示例,输出预测结果的总和是 1.0。该预测结果分解如下:山鸢尾为 0.02,变色鸢尾为 0.95,维吉尼亚鸢尾为 0.03。这意味着该模型预测某个无标签鸢尾花样本是变色鸢尾的概率为 95%。

TensorFlow tf.keras API 是创建模型和层的首选方式。通过该 API,您可以轻松地构建模型并进行实验,而将所有部分连接在一起的复杂工作则由 Keras 处理。

tf.keras.Sequential 模型是层的线性堆叠。该模型的构造函数会采用一系列层实例;在本示例中,采用的是 2 个密集层(分别包含 10 个节点)以及 1 个输出层(包含 3 个代表标签预测的节点)。第一个层的 input_shape 参数对应该数据集中的特征数量:

# 利用sequential方式构建模型model = Sequential([
  # 隐藏层1,激活函数是relu,输入大小有input_shape指定
  Dense(10, activation="relu", input_shape=(4,)),  
  # 隐藏层2,激活函数是relu
  Dense(10, activation="relu"),
  # 输出层
  Dense(3,activation="softmax")])

通过model.summary可以查看模型的架构:

激活函数可决定层中每个节点的输出形状。这些非线性关系很重要,如果没有它们,模型将等同于单个层。激活函数有很多,但隐藏层通常使用 ReLU。

隐藏层和神经元的理想数量取决于问题和数据集。与机器学习的多个方面一样,选择最佳的神经网络形状需要一定的知识水平和实验基础。一般来说,增加隐藏层和神经元的数量通常会产生更强大的模型,而这需要更多数据才能有效地进行训练。

模型训练和预测

在训练和评估阶段,我们都需要计算模型的损失。这样可以衡量模型的预测结果与预期标签有多大偏差,也就是说,模型的效果有多差。我们希望尽可能减小或优化这个值,所以我们设置优化策略和损失函数,以及模型精度的计算方法:

# 设置模型的相关参数:优化器,损失函数和评价指标mode
l.compile(optimizer='adam', loss='categorical_crossentropy', metrics=["accuracy"])

接下来与在sklearn中相同,分别调用fit和predict方法进行预测即可。

# 模型训练:epochs,训练样本送入到网络中的次数,batch_size:每次训练的送入到网络中的样本个数
model.fit(train_X, train_y_ohe, epochs=10, batch_size=1, verbose=1);

上述代码完成的是:

  1. 迭代每个epoch。通过一次数据集即为一个epoch。

  2. 在一个epoch中,遍历训练 Dataset 中的每个样本,并获取样本的特征 (x) 和标签 (y)。

  3. 根据样本的特征进行预测,并比较预测结果和标签。衡量预测结果的不准确性,并使用所得的值计算模型的损失和梯度。

  4. 使用 optimizer 更新模型的变量。

  5. 对每个epoch重复执行以上步骤,直到模型训练完成。

训练过程展示如下:

Epoch 1/10
75/75 [==============================] - 0s 616us/step - loss: 0.0585 - accuracy: 0.9733
Epoch 2/10
75/75 [==============================] - 0s 535us/step - loss: 0.0541 - accuracy: 0.9867
Epoch 3/10
75/75 [==============================] - 0s 545us/step - loss: 0.0650 - accuracy: 0.9733
Epoch 4/10
75/75 [==============================] - 0s 542us/step - loss: 0.0865 - accuracy: 0.9733
Epoch 5/10
75/75 [==============================] - 0s 510us/step - loss: 0.0607 - accuracy: 0.9733
Epoch 6/10
75/75 [==============================] - 0s 659us/step - loss: 0.0735 - accuracy: 0.9733
Epoch 7/10
75/75 [==============================] - 0s 497us/step - loss: 0.0691 - accuracy: 0.9600
Epoch 8/10
75/75 [==============================] - 0s 497us/step - loss: 0.0724 - accuracy: 0.9733
Epoch 9/10
75/75 [==============================] - 0s 493us/step - loss: 0.0645 - accuracy: 0.9600
Epoch 10/10
75/75 [==============================] - 0s 482us/step - loss: 0.0660 - accuracy: 0.9867

与sklearn中不同,对训练好的模型进行评估时,与sklearn.score方法对应的是tf.keras.evaluate()方法,返回的是损失函数和在compile模型时要求的指标:

# 计算模型的损失和准确率
loss, accuracy = model.evaluate(test_X, test_y_ohe, verbose=1)
print("Accuracy = {:.2f}".format(accuracy))

分类器的准确率为:

3/3 [==============================] - 0s 591us/step - loss: 0.1031 - accuracy: 0.9733
Accuracy = 0.97

到此我们对tf.kears的使用有了一个基本的认知,在接下来的课程中会给大家解释神经网络以及在计算机视觉中的常用的CNN的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/781234.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

echarts3d饼图实现

一、vue中使用3d饼图 效果图: 二、使用步骤 1.引入库 安装echarts 在package.json文件中添加 "dependencies": {"echarts": "^5.1.2""echarts-gl": "^1.1.0",// "echarts-gl": "^2.0.8&quo…

基于AutoEncoder自编码器的MNIST手写数字数据库识别matlab仿真

目录 1.算法理论概述 2.部分核心程序 3.算法运行软件版本 4.算法运行效果图预览 5.算法完整程序工程 1.算法理论概述 MNIST手写数字数据库是机器学习中常用的数据集,包含了0到9这10个数字的手写图片。本文介绍一种基于AutoEncoder自编码器的MNIST手写数字识别算…

高校大数据教材推荐-Hadoop大数据开发基础(第2版)(微课版)

Hadoop大数据开发基础(第2版)(微课版)是“十四五”职业教育国家规划教材,是大数据应用开发“1X”职业技能等级证书配套系列教材,也是“以纸质教材为核心、以互联网为载体”的新形态教材,配套39个微课视频(二维码随扫随学&#xff…

DASCTF 2023 0X401七月暑期挑战赛 Reverse部分题解

文章目录 controlflow1. 异或0x4012. 加i*i3. 异或i*(i1)4. 减i5. 乘36. swap7. judge解题脚本 webserver1.关键函数2. 求约束条件3.Z3求解 controlflow 动态调试观察执行情况 1. 异或0x401 2. 加i*i 3. 异或i*(i1) 注意这里是从data[10i]开始 4. 减i 5. 乘3 6. swap 注意…

redis的简单入门

文章目录 一、前言1.1、什么是Redis? 二、简介三、Redis下载与安装四、Redis服务启动与停止五、Redis设置密码进行远程连接5.1、设置密码5.2、远程连接 六、Redis数据类型七、Redis常用命令7.1、字符串String命令7.2、哈希hash操作命令7.3、列表list操作命令7.4、集合set操作命…

机器人SLAM导航学习-All in one

参考引用 张虎,机器人SLAM导航核心技术与实战[M]. 机械工业出版社,2022.本博客未详尽之处可自行查阅上述书籍 一、编程基础篇 1. ROS 入门必备知识 ROS学习笔记(文章链接汇总) 2. C 编程范式 《21天学通C》读书笔记&#xff0…

leetcode743. 网络延迟时间 DJ

https://leetcode.cn/problems/network-delay-time/ 有 n 个网络节点,标记为 1 到 n。 给你一个列表 times,表示信号经过 有向 边的传递时间。 times[i] (ui, vi, wi),其中 ui 是源节点,vi 是目标节点, wi 是一个信…

python_day13

reduceByKey算子,聚合 列表中存放二元元组,元组中第一个为key,此算子按key聚合,传入计算逻辑 from pyspark import SparkConf, SparkContext import osos.environ["PYSPARK_PYTHON"] "D:/dev/python/python3.10…

Geriit使用出错记录

拉取服务器代码(clone ) 1、执行:git clone ssh:xxxxxx && scp -p -P 29418 xxxxxxxxx 1、报错:Unable to negotiate with XX.XX.XX.XX port XX:: no matching key exchange method found. Their offer: diffie-hellman…

第五章 编程之免交互

免交互:不需要人为控制就可以完成的自动化操作(自动化运维) shell脚本和面交互是一个概念,但是两种写法 shell:默认解释器是bash 使用i/o(输入/输出)重定向的方式,将命令的列表提供…

pytest钩子函数(二):初始化钩子

前言 pytest这个框架提供了非常多的钩子。通过这些钩子我们可以对pytest 用例收集、用例执行、报告输出等各个阶段进行干预,根据需求去开发对应的插件,以满足自己的使用场景。 01 什么是钩子函数? 钩子函数在pytest称之为Hook函数,它pytes…

【头歌】二叉树的二叉链表存储及基本操作

第1关:先序遍历创建二叉链表存储的二叉树及遍历操作 任务描述 本关任务:以二叉链表作存储结构存储二叉树,利用先序递归遍历创建二叉树,并依次进行二叉树的前序、中序、后序递归遍历。 相关知识 在顺序存储结构中,利用数组下标表示元素的位置及元素之间孩子或双亲的关系…

失去中国市场的三星继续称霸全球,中国手机的份额反而进一步下降了

市调机构canalys公布了二季度全球手机市场的数据,数据显示三星、苹果的市场份额保持稳定并位居全球前二,三星的表现显然让人称奇,一直被唱衰,却一直都稳稳占据全球手机市场第一名。 从Canalys公布的数据可以看到,三星以…

【Go】 map 精髓理解

map go map 的底层结构 hmap,的四个元素 然后再讲一下 buckets 的元素,讲一下 hash 冲突,和解决方法 再讲一下,增量扩容和等量扩容 再讲一下增删改查的过程,就查询过程 map 基础 向值为 nil 的 map 添加元素会发生 pa…

青枫壁纸小程序V1.4.0(后端SpringBoot)

引言 那么距离上次的更新已经过去了5个多月,期间因为忙着毕业设计的原因,更新的速度变缓了许多。所以,这次的更新无论是界面UI、用户功能、后台功能都有了非常大的区别。希望这次更新可以给用户带来更加好的使用体验 因为热爱,更…

【湍流介质的三维传播模拟器】全衍射3-D传播模拟器,用于在具有随机和背景结构的介质中传播无线电和光传播(Matlab代码实现)

目录 💥1 概述 📚2 运行结果 🎉3 参考文献 🌈4 Matlab代码实现 💥1 概述 全衍射3-D传播模拟器是一种用于模拟在具有随机和背景结构的介质中传播无线电和光的工具。它可以帮助研究人员和工程师理解和预测无线电波和光波…

多重感知机MLP:Mnist

文章目录 网络结构代码common_utils.pynetwork.pyprovider.pytrain.pytest.pyvisual.py 实验训练结果测试结果可视化 网络结构 输入过程输出28*28Flatten784784Linear300300Linear100100Linear10 代码 文件结构: common_utils.py 用来输出日志文件 # common_…

基于扩展(EKF)和无迹卡尔曼滤波(UKF)的电力系统动态状态估计

1 主要内容 该程序对应文章《Power System Dynamic State Estimation Using Extended and Unscented Kalman Filters》,电力系统状态的准确估计对于提高电力系统的可靠性、弹性、安全性和稳定性具有重要意义,虽然近年来测量设备和传输技术的发展大大降低…

Linux常用嗅探工具(1):fping命令

fping的优点: 可以一次ping多个主机可以从主机列表文件ping结果清晰 便于脚本处理速度快 fping的安装: 前置安装cgg编译器 : yum -y install gcc 下载fping: wget http://fping.org/dist/fping-4.0.tar.gz 解压: …

力扣 -- 918. 环形子数组的最大和

一、题目: 题目链接:918. 环形子数组的最大和 - 力扣(LeetCode) 二、解题步骤: 下面是用动态规划的思想解决这道题的过程,相信各位小伙伴都能看懂并且掌握这道经典的动规题目滴。 三、参考代码&#xff1…