tensorflow深度神经网络实现鸢尾花分类

news2024/10/5 13:16:02

tensorflow深度神经网络实现鸢尾花分类

本文目录

    • tensorflow深度神经网络实现鸢尾花分类
      • 获取数据集
      • 相关库的导入
      • 数据展示和划分
      • 对标签值进行热编码
      • 模型搭建
      • 使用Sequential模型
        • 搭建模型
        • 模型训练
        • 对训练好的模型进行评估
      • 使用model模型
        • 搭建模型
        • 对训练好的模型进行评估
      • 损失函数
      • 优化方法
      • 正则化

获取数据集

下载鸢尾花数据集:https://gitcode.net/mirrors/mwaskom/seaborn-data

相关库的导入

import tensorflow as tf
# 绘图
import seaborn as sns
# 数值计算
import numpy as np
# sklearn中的相关工具
# 划分训练集和测试集
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, classification_report

数据展示和划分

利用seaborn导入相关的数据,iris数据以dataFrame的方式在seaborn进行存储,我们读取后并进行展示

# 读取数据
# iris = sns.load_dataset("iris")
iris = sns.load_dataset("iris", cache=True, data_home='./seaborn-data/')
# 展示数据的前五行
print(iris.head())
   sepal_length  sepal_width  petal_length  petal_width species
0           5.1          3.5           1.4          0.2  setosa
1           4.9          3.0           1.4          0.2  setosa
2           4.7          3.2           1.3          0.2  setosa
3           4.6          3.1           1.5          0.2  setosa
4           5.0          3.6           1.4          0.2  setosa

利用seabornpairplot函数探索数据特征间的关系

# 将数据之间的关系进行可视化
sns.pairplot(iris, hue='species')
plt.show()

iris dataframe中提取原始数据,将花瓣和萼片数据保存在数组X中,标签保存在相应的数组y中

# 花瓣和花萼的数据
X = iris.values[:, :4]
# 标签值
y = iris.values[:, 4]

利用train_test_split将数据划分为训练集测试集

# 将数据集划分为训练集和测试集
train_X, test_X, train_y, test_y = train_test_split(X, y, train_size=0.7, test_size=0.3, random_state=0)s

对标签值进行热编码

独热编码:将分类变量转换为一组二进制变量的过程,其中每个二进制变量对应一个分类变量的值。这些二进制变量中的一个包含“1”,其余变量都包含“0”。这些二进制变量可以被视为一组互斥的指示器,因此独热编码也称为“一位有效编码”(One-of-N Encoding)

# 进行独热编码
def one_hot_encode_object_array(arr):
    # 去重获取全部的类别
    uniques, ids = np.unique(arr, return_inverse=True)
    # 返回热编码的结果
    return tf.keras.utils.to_categorical(ids, len(uniques))


# 训练集热编码
train_y_ohe = one_hot_encode_object_array(train_y)
# 测试集热编码
test_y_ohe = one_hot_encode_object_array(test_y)

模型搭建

tf.Keras是一个神经网络库,我们需要根据数据和标签值构建神经网络。神经网络可以发现特征与标签之间的复杂关系。神经网络是一个高度结构化的图,其中包含一个或多个隐藏层。每个隐藏层都包含一个或多个神经元。神经网络有多种类别,该程序使用的是密集型神经网络,也称为全连接神经网络:一个层中的神经元将从上一层中的每个神经元获取输入连接。

一个密集型神经网络,其中包含 1 个输入层、2 个隐藏层以及 1 个输出层,如下图所示:

上图 中的模型经过训练并馈送未标记的样本时,它会产生 3 个预测结果:相应鸢尾花属于指定品种的可能性。对于该示例,输出预测结果的总和是 1.0。该预测结果分解如下:山鸢尾为 0.02,变色鸢尾为 0.95,维吉尼亚鸢尾为 0.03。这意味着该模型预测某个无标签鸢尾花样本是变色鸢尾的概率为 95%

使用Sequential模型

搭建模型

tf.keras.Sequential 模型是层的线性堆叠,采用的是 2 个密集层(分别包含 10 个节点)以及 1 个输出层(包含 3 个代表标签预测的节点)。第一个层的 input_shape 参数对应该数据集中的特征数量

# 利用sequential方式构建模型
model = tf.keras.models.Sequential([
    # 隐藏层1,激活函数是relu,输入大小有input_shape指定
    tf.keras.layers.Dense(10, activation="relu", input_shape=(4,)),
    # 隐藏层2,激活函数是relu
    tf.keras.layers.Dense(10, activation="relu"),
    # 输出层
    tf.keras.layers.Dense(3, activation="softmax")
])

查看模型的架构

model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 10)                50        
                                                                 
 dense_1 (Dense)             (None, 10)                110       
                                                                 
 dense_2 (Dense)             (None, 3)                 33        
                                                                 
=================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________

模型训练

设置优化策略和损失函数,以及模型精度的计算方法

# 设置模型的相关参数:优化器,损失函数和评价指标
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'],
)
  1. 迭代每个epoch。通过一次数据集即为一个epoch。
  2. 在一个epoch中,遍历训练 Dataset 中的每个样本,并获取样本的特征 (x) 和标签 (y)。
  3. 根据样本的特征进行预测,并比较预测结果和标签。衡量预测结果的不准确性,并使用所得的值计算模型的损失和梯度。
  4. 使用 optimizer 更新模型的变量。
  5. 对每个epoch重复执行以上步骤,直到模型训练完成。
# 模型训练:epochs,训练样本送入到网络中的次数,batch_size:每次训练的送入到网络中的样本个数
history = model.fit(train_X, train_y_ohe, epochs=100, batch_size=1, verbose=1, validation_data=(test_X, test_y_ohe))

训练过程

...
105/105 [==============================] - 1s 7ms/sample - loss: 0.0561 - acc: 0.9619 - val_loss: 0.1246 - val_acc: 0.9778
Epoch 97/100
105/105 [==============================] - 1s 6ms/sample - loss: 0.0732 - acc: 0.9524 - val_loss: 0.0941 - val_acc: 0.9778
Epoch 98/100
105/105 [==============================] - 1s 6ms/sample - loss: 0.0566 - acc: 0.9714 - val_loss: 0.1209 - val_acc: 0.9778
Epoch 99/100
105/105 [==============================] - 1s 6ms/sample - loss: 0.0575 - acc: 0.9810 - val_loss: 0.1297 - val_acc: 0.9556
Epoch 100/100
105/105 [==============================] - 1s 6ms/sample - loss: 0.0609 - acc: 0.9714 - val_loss: 0.0802 - val_acc: 0.9778

对训练好的模型进行评估

计算损失和准确率

# 输出模型评估报告
y_pred = model.predict(test_X)
print('Accuracy score:', accuracy_score(test_y_ohe.argmax(axis=1), y_pred.argmax(axis=1)))
print(classification_report(test_y_ohe.argmax(axis=1), y_pred.argmax(axis=1)))
Accuracy score: 0.9777777777777777
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        16
           1       1.00      0.94      0.97        18
           2       0.92      1.00      0.96        11

    accuracy                           0.98        45
   macro avg       0.97      0.98      0.98        45
weighted avg       0.98      0.98      0.98        45
# 获取模型训练过程的准确率以及损失率的变化
accuracy = history.history['acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
val_accuracy = history.history['val_acc']
epochs = range(len(accuracy))
plt.plot(epochs, accuracy, 'b', label='Training accuracy')
plt.plot(epochs, val_accuracy, 'orange', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()
plt.show()

plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'orange', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()

使用model模型

搭建模型

tf.keras 提供了 Functional API,建立更为复杂的模型,使用方法是将层作为可调用的对象并返回张量,并将输入向量和输出向量提供给 tf.keras.Modelinputsoutputs 参数

inputs = tf.keras.Input(shape=(4,))
x = tf.keras.layers.Dense(10, activation="relu")(inputs)
x = tf.keras.layers.Dense(10, activation="relu")(x)
outputs = tf.keras.layers.Dense(3, activation="softmax")(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

可以优化,正则化通过对算法的修改来减少泛化误差

x = tf.keras.layers.Dense(10, activation="relu", kernel_regularizer=tf.keras.regularizers.l2(0.01))(inputs)
x = tf.keras.layers.Dense(10, activation="relu", kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 4)]               0         
                                                                 
 dense (Dense)               (None, 10)                50        
                                                                 
 dense_1 (Dense)             (None, 10)                110       
                                                                 
 dense_2 (Dense)             (None, 3)                 33        
                                                                 
=================================================================
Total params: 193
Trainable params: 193
Non-trainable params: 0
_________________________________________________________________

对训练好的模型进行评估

Accuracy score: 0.9777777777777777
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        14
           1       1.00      0.94      0.97        17
           2       0.93      1.00      0.97        14

    accuracy                           0.98        45
   macro avg       0.98      0.98      0.98        45
weighted avg       0.98      0.98      0.98        45

损失函数

  • 分类任务

    • 多分类任务:softmax
    • 二分类:sigmoid
  • 回归任务

    • MAE:Mean absolute loss(MAE)也被称为L1 Loss,以绝对误差作为距离

    • MSE:Mean Squared Loss/ Quadratic Loss(MSE loss)也被称为L2 loss,或欧氏距离,以误差的平方和作为距离

    • smooth L1

优化方法

  • 梯度下降

  • 反向传播算法(BP算法)

  • 梯度下降优化方法

    • 动量算法(Momentum)
    • AdaGrad
    • RMSprop
    • Adam
  • 学习率退火

正则化

  • L1正则化
  • L2正则化
  • L1L2
  • Dropout正则化
  • 提前停止

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/419269.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用golang连接kafka

1 下载,配置,启动 kafka 下载链接 配置修改 在config目录下的server文件和zookeeper文件,其中分别修改kafka的日志保存路径和zookeeper的数据保存路径。 启动kafka 先启动kafka自带的zookeeper,在kafka的根目录下打开终端&a…

百模大战,谁是下一个ChatGPT?

“不敢下手,现在中国还没跑出来一家绝对有优势的大模型,上层应用没法投,担心押错宝。”投资人Jucy(化名)向光锥智能表示,AI项目看得多、投的少是这段时间的VC常态。 ChatGPT点燃AI大爆炸2个月中&#xff0…

为什么工控行业生意越来越难做了?

前段时间跟几个做工业品销售的朋友聚了一下,大家都说去年一年挺难的,有些甚至想把小店关了。为什么现在工业品领域越来越难做了呢?今天也想给大家说一说我的一些看法。 以前的工控生意相对现在来说较为有限和封闭,技术上也没有现今…

Android 大图检测插件的落地

作者:layz4android 在实际的项目开发中,引入图片的方式基本可以分为两种:本地图片和云端图片,对于云端图片来说,可以动态地配置图片的大小,如果服务端的伙伴下发的图片很大导致程序异常,那么可以…

前端视角-https总结

1.http存在的问题 1.1可能被窃听 HTTP 本身不具备加密的功能,HTTP 报文使用明文方式发送互联网是由联通世界各个地方的网络设施组成,所有发送和接收经过某些设备的数据都可能被截获或窥视。(例如TCP/IP抓包工具:Wireshark),即使经过加密处理,也会被窥视是通信内容,只是可能很…

在 Flutter 多人视频通话中实现虚拟背景、美颜与空间音效

前言 在之前的「基于声网 Flutter SDK 实现多人视频通话」里,我们通过 Flutter 声网 SDK 完美实现了跨平台和多人视频通话的效果,那么本篇我们将在之前例子的基础上进阶介绍一些常用的特效功能,包括虚拟背景、色彩增强、空间音频、基础变声…

HBase高手之路4-Shell操作

文章目录HBase高手之路3—HBase的shell操作一、hbase的shell命令汇总二、需求三、表的操作1.进入shell命令行2.创建表3.查看表的定义4.列出所有的表5.删除表1)禁用表2)启用表3)删除表四、数据的操作1.添加数…

TensorFlow 深度学习实战指南:1~5 全

原文:Hands-on Deep Learning with TensorFlow 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自【ApacheCN 深度学习 译文集】,采用译后编辑(MTPE)流程来尽可能提升效率。 不要担心自己的形象,只关心如…

【通义千问】继ChatGPT爆火后,阿里云的大模型“通义千问”它终于来了

通义千问一、通义千问名字的由来二、通义千问和ChatGPT有什么区别呢?三、如何申请体验通义千问呢?四、未来通义千问能称为中国版的ChatGPT吗?五、通义千问什么时候正式发布呢?一、通义千问名字的由来 通义千问顾名思义&#xff0…

作物杂交——蓝桥杯20年省赛(JAVA)

题目链接: 用户登录https://www.lanqiao.cn/problems/506/learning/?page2&first_category_id1&sortstudents_count 题目描述 作物杂交是作物栽培中重要的一步。已知有 N 种作物 (编号 1 至 N ),第 i 种作物从播种到成熟的时间为 Ti​。作物…

少儿编程 电子学会图形化 scratch编程等级考试四级真题答案解析(判断题)2022年12月

2022年12月scratch编程等级考试四级真题 判断题(共10题,每题2分,共20分) 16、点击绿旗,反复按下空格键,可以使变量a的值在0和1之间反复变化 答案:对 考点分析:考查积木综合使用,重点考查变量积木的使用,按一下空格键,a变量值会改变5次,0-1-0-1-0-1,按第二下…

budibase <2.4.3 存在 ssrf 漏洞(CVE-2023-29010)

漏洞描述 budibase 是一个开源的低代码平台,元数据端点(metadata endpoint)是Budibase提供的一个REST API端点,用于访问应用程序的元数据信息。 budibase 2.4.3之前版本中存在 ssrf 漏洞,该漏洞可能影响 Budibase 自主托管的用户&#xff0…

安利安利-向大家推荐一个超级牛的etcd管理工具-EtcdKeeperFyne

etcd介绍 关于etcd的介绍大家可以看下这篇文章 etcd 开源仓库地址:EtcdKeeperFyne EtcdKeeperFyne 今天主要是向大家推荐一款使用起来特别方便的Etcd管理工具 EtcdKeeperFyne,具体运行起来的界面如下: 推荐原因 使用简单安装简单&…

卷积层输出尺寸计算 / 感受野尺寸计算

卷积层输出尺寸计算 输入图像a*a, 卷积核大小b*b, stride c, padding d 输出图像的尺寸:[(a - b 2d) // c] 1 (a - b 2d) 表示在输入图像两侧填充 d 个像素后,窗口在输入图像上最多能移动的距离,再加上 1 表示最后一个窗口的右侧边界…

博客文章效果

学习风宇blog md文档转html&#xff08;markdown-it的使用&#xff09;语法高亮、行号、一键复制toc生成目录sticky粘性定位 <style lang"scss"> import url(//at.alicdn.com/t/c/font_4004562_9v94jccafmc.css); import url(https://fonts.font.im/css?fam…

DFIG控制8: 不平衡电网下的网侧变换器控制

DFIG控制8&#xff1a; 不平衡电网下的网侧变换器控制。主要是添加网侧变换器的负序分量控制器。 本文基于教程的第8部分&#xff1a;DFIM Tutorial 8 - Asymmetrical Voltage Dips Analysis in DFIG based WT: Grid Side Converter Control 控制策略简介 来自&#xff1a;G…

过滤器(Filter)与拦截器(Interceptor)区别

1 过滤器&#xff08;Filter&#xff09; Servlet 中的过滤器 Filter 实现了 javax.servlet.Filter 接口的服务器端程序&#xff0c;主要用途是设置字符集&#xff08;CharacterEncodingFilter&#xff09;、控制权限、控制转向、用户是否已经登陆、有没有权限访问该页面等。 …

springboot配置跨域问题

近期自己搭建项目时&#xff0c;遇到一个跨域问题。我们以前项目解决跨域是在controller上加一个跨域注解CrossOrigin(allowCredentials "true")&#xff0c;很方便。但是在我自己搭建的项目中&#xff0c;启动时竟然报错了&#xff0c;错误如下&#xff1a; When …

图的传递闭包

给定一个有向图,对于给定图中的所有顶点对(i, j),找出一个顶点j是否可从另一个顶点i到达。这里的可达性是指从顶点i到j有一条路径。可达性矩阵称为图的传递闭包。 例如,考虑下面的图表 上述图的传递闭包为 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 1 该图以邻接矩阵的形式给出,…

抛弃 TCP 和 QUIC 的 HTTP

下班路上发了一则朋友圈&#xff1a; 周四听了斯坦福老教授 John Ousterhout 关于 Homa 的分享&#xff0c;基本重复了此前那篇 It’s Time To Rep… 的格调&#xff0c;花了一多半时间喷 TCP… Ousterhout 关于 Homa 和 TCP 之间的论争和论证&#xff0c;诸多反复回执&…