【再学Tensorflow2】TensorFlow2的建模流程:Titanic生存预测

news2025/1/20 10:48:44

TensorFlow2的建模流程

  • 1. 使用Tensorflow实现神经网络模型的一般流程
  • 2. Titanic生存预测问题
    • 2.1 数据准备
    • 2.2 定义模型
    • 2.3 训练模型
    • 2.4 模型评估
    • 2.5 使用模型
    • 2.6 保存模型
  • 参考资料

在机器学习和深度学习领域,通常使用TensorFlow来实现机器学习模型,尤其常用于实现神经网络模型。从原理上说可以使用张量构建计算图来定义神经网络,并通过自动微分机制训练模型。自从Tensorflow2发布之后,大大降低了Tensorflow的使用门槛。这里为简洁起见,一般推荐使用TensorFlow的高层次keras接口来实现神经网络网模型。

1. 使用Tensorflow实现神经网络模型的一般流程

  1. 准备数据
  2. 定义模型
  3. 训练模型
  4. 评估模型
  5. 使用模型
  6. 保存模型

对新手来说,其中最困难的部分实际上是准备数据过程。在实践中通常会遇到的数据类型包括结构化数据图片数据文本数据时间序列数据。这里,我们将分别以Titanic生存预测问题,CIFAR2图片分类问题,IMBD电影评论分类问题,国内新冠疫情结束时间预测问题为例,演示应用Tensorflow对这四类数据的建模方法。

2. Titanic生存预测问题

2.1 数据准备

Titanic数据集的目标是根据乘客信息预测他们在Titanic号撞击冰山沉没后能否生存。结构化数据一般会使用Pandas中的DataFrame进行预处理。这些历史数据已经分为训练集和测试集。
导入依赖包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow.keras import models, layers

导入数据

dftrain_raw = pd.read_csv('../Data/Titanic/train.csv')
dftest_raw = pd.read_csv('../Data/Titanic/test.csv')
dftrain_raw.head(10)

训练数据集
字段说明

  • Survived:0代表死亡,1代表存活【y标签】
  • Pclass:乘客所持票类,有三种值(1,2,3) 【转换成onehot编码】
  • Name:乘客姓名 【舍去】
  • Sex:乘客性别 【转换成bool特征】
  • Age:乘客年龄(有缺失) 【数值特征,添加“年龄是否缺失”作为辅助特征】
  • SibSp:乘客兄弟姐妹/配偶的个数(整数值) 【数值特征】
  • Parch:乘客父母/孩子的个数(整数值)【数值特征】
  • Ticket:票号(字符串)【舍去】
  • Fare:乘客所持票的价格(浮点数,0-500不等) 【数值特征】
  • Cabin:乘客所在船舱(有缺失) 【添加“所在船舱是否缺失”作为辅助特征】
  • Embarked:乘客登船港口:S、C、Q(有缺失)【转换成onehot编码,四维度 S,C,Q,nan】

探索数据

%matplotlib inline
%config InlineBackend.figure_format = 'png'
# label分布情况
ax = dftrain_raw['Survived'].value_counts().plot(kind='bar', figsize= (12,8), fontsize=15, rot=0)
ax.set_ylabel('Counts', fontsize=15)
ax.set_xlabel('Survived', fontsize=15)
plt.show()

label分布情况

%matplotlib inline
%config InlineBackend.figure_format = 'png'
# 年龄分布情况
ax = dftrain_raw['Age'].plot(kind='hist', bins=20, color='purple', figsize=(12, 8), fontsize=15)
ax.set_ylabel('Frequency', fontsize = 15)
plt.show()

年龄分布情况

%matplotlib inline
%config InlineBackend.figure_format = "png"
# 年龄和label的相关性
ax = dftrain_raw.query('Survived == 0')['Age'].plot(kind='density', figsize=(12,8),fontsize=15)
dftrain_raw.query('Survived == 1')['Age'].plot(kind='density', figsize=(12,8), fontsize=15)
ax.legend(['Survived==0','Survived==1'], fontsize=12)
ax.set_ylabel('Density', fontsize=15)
ax.set_xlabel('Age', fontsize=15)
plt.show()

年龄与label的相关性
数据预处理

def preprocessing(dfdata):
    dfresult = pd.DataFrame()
    # Pclass 乘客所持票类,有三种值(1,2,3),转换成onehot编码
    dfPclass = pd.get_dummies(dfdata['Pclass'])
    dfPclass.columns = ['Pclass_' + str(x) for x in dfPclass.columns]
    dfresult = pd.concat([dfresult, dfPclass], axis=1)

    # Sex 乘客性别,转换成bool特征
    dfSex = pd.get_dummies(dfdata['Sex'])
    dfresult = pd.concat([dfresult, dfSex], axis=1)

    # Age 乘客年龄(有缺失)[数值特征,添加“年龄是否缺失”作为辅助特征]
    dfresult['Age'] = dfdata['Age'].fillna(0)
    dfresult['Age_null'] = pd.isna(dfdata['Age']).astype('int32')

    # SibSp, Parch, Fare: 乘客的兄弟姐妹/配偶的个数(整数值)[数值特征];乘客父母/孩子的个数;乘客所持票的价格[浮点数,0-500不等]
    dfresult['SibSp'] = dfdata['SibSp']
    dfresult['Parch'] = dfdata['Parch']
    dfresult['Fare'] = dfdata['Fare']

    # Carbin
    dfresult['Cabin_null'] = pd.isna(dfdata['Cabin']).astype('int32')

    # Embarked 乘客登船港口:S、C、Q(有缺失),转换成onehot编码,四维度 S,C,Q,nan
    dfEmbarked = pd.get_dummies(dfdata['Embarked'], dummy_na=True)
    dfEmbarked.columns = ['Embarked_' + str(x) for x in dfEmbarked.columns]
    dfresult = pd.concat([dfresult, dfEmbarked], axis=1)

    return dfresult

调用预处理函数

x_train = preprocessing(dftrain_raw)
y_train = dftrain_raw['Survived'].values

x_test = preprocessing(dftest_raw)
print('x_train.shape = ', x_train.shape)
print('x_test.shape = ' , x_test.shape)

输出结果:
x_train.shape = (891, 15)
x_test.shape = (418, 15)

2.2 定义模型

使用keras接口有一下三种方式构建模型:

  1. 使用Sequential按层顺序构建模型
  2. 使用函数式API构建任意结构模型
  3. 集成Model基类构建自定义模型

这里,我们使用最简单的Sequential,按层顺序模型。

tf.keras.backend.clear_session()

model = models.Sequential()
model.add(layers.Dense(20, activation='relu', input_shape=(15, )))
model.add(layers.Dense(10, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 20)                320       
_________________________________________________________________
dense_1 (Dense)              (None, 10)                210       
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 11        
=================================================================
Total params: 541
Trainable params: 541
Non-trainable params: 0
_________________________________________________________________

2.3 训练模型

训练模型通常也有3种方法:内置fit方法,内置train_on_batch方法,以及自定义训练循环。此处我们选择最常用也是最简单的内置fit方法。

# 二分类问题选择二元交叉熵损失函数
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['AUC'])
history = model.fit(x_train, y_train, batch_size=64, epochs=50, validation_split=0.2)

运行过程:

Epoch 46/50
12/12 [==============================] - 0s 2ms/step - loss: 0.4421 - auc: 0.8540 - val_loss: 0.3939 - val_auc: 0.8757
Epoch 47/50
12/12 [==============================] - 0s 2ms/step - loss: 0.4427 - auc: 0.8567 - val_loss: 0.4168 - val_auc: 0.8657
Epoch 48/50
12/12 [==============================] - 0s 2ms/step - loss: 0.4631 - auc: 0.8470 - val_loss: 0.4044 - val_auc: 0.8693
Epoch 49/50
12/12 [==============================] - 0s 2ms/step - loss: 0.4602 - auc: 0.8451 - val_loss: 0.3830 - val_auc: 0.8832
Epoch 50/50
12/12 [==============================] - 0s 2ms/step - loss: 0.4450 - auc: 0.8574 - val_loss: 0.3892 - val_auc: 0.8776

2.4 模型评估

首先评估一下模型在训练集和验证集上的效果。

%matplotlib inline
%config InlineBackend.figure_format = 'png'

import matplotlib.pyplot as plt

def plot_metric(history, metric):
    train_metrics = history.history[metric]
    val_metrics = history.history['val_'+metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.plot(epochs, val_metrics, 'ro-')
    plt.title('Training and validation '+ metric)
  png  plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_"+metric, 'val_'+metric])
    plt.show()

查看loss变化

plot_metric(history,"loss")

损失值
查看AUC曲线

plot_metric(history,'auc')

AUC曲线
模型在训练数据集上的效果值:

model.evaluate(x=x_train, y = y_train) # [0.48629626631736755, 0.8456763625144958]

2.5 使用模型

预测概率

model.predict(x_test)

预测概率
预测类别

model.predict_classes(x_test)

预测类别

2.6 保存模型

可以使用Keras方式保存模型,也可以使用TensorFlow原生方式保存。前者仅仅适合使用Python环境恢复模型,后者则可以跨平台进行模型部署。推荐使用后一种方式进行保存。

  1. 使用Keras方式保存模型
# 保存模型结构及权重
model.save('../Output/keras_model_titanic.h5')

删除现有模型,加载保存的模型

del model # 删除现有模型
model = models.load_model('../Output/keras_model_titanic.h5')
model.evaluate(x_train, y_train)

加载保存的mo'x够
其他:
其他保存方式

  1. 使用Tensorflow原生方式保存
    仅仅保存权重张量:
# 保存权重,该方式仅仅保存权重张量
model.save_weights('../Output/tf_model_titanic_weights.ckpt', save_format='tf')

保存结构参数

# 保存模型结构与模型参数到文件,该方式保存的模型具有跨平台性便于部署
model.save('../Output/tf_model_titanic_saved', save_format='tf')
print('export saved model.')

加载模型:

model_loaded = tf.keras.models.load_model('../Output/tf_model_titanic_saved')
model_loaded.evaluate(x_train, y_train)

模型使用

参考资料

  1. 30天吃掉那只 TensorFlow2:[结构化数据建模流程范例]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/105998.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

03Python算数运算符及变量基本使用

算数运算符 算数运算符 是完成基本的算术运算使用的符号,用来处理四则运算 运算符描述实例加10 20 30-减10 - 20 -10*乘10 * 20 200/除10 / 20 0.5//取整除返回除法的整数部分(商) 9 // 2 输出结果 4%取余数返回除法的余数 9 % 2 1*…

干货 | 云原生时代的灰度发布有几种“姿势”?

随着企业数字化转型进程不断发展,云原生时代的来临,企业应用越来越多,不得不面对应用程序升级的巨大挑战。传统的停机发布方式,新旧版本应用切换少则停机30分钟,多则停机10小时以上,愈发无法满足业务端的需…

java入门及环境配置

java三大版本 JavaSE: 标准版(桌面程序,控制台开发........) JavaEE: 嵌入式开发(手机,小家电.....) JavaEE: E企业级开发(web端,服务器开发...) JDK、JRE、JVM: Java安装开发环境&a…

信息化时代企业数据防泄露工作该怎么做

场景描述 信息化时代发展迅速,数据防泄露一词也频繁的出现在我们身边。无论企业或政府单位,无纸化办公场景越来越多,数据泄露的时间也层出不穷。例如:世界最大职业中介网站Monster遭到黑客大规模攻击,黑客窃取在网站注…

计算机毕设Python+Vue药品销售平台(程序+LW+部署)

项目运行 环境配置: Jdk1.8 Tomcat7.0 Mysql HBuilderX(Webstorm也行) Eclispe(IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持)。 项目技术: SSM mybatis Maven Vue 等等组成,B/S模式 M…

模板二(基础算法)

目录 快速排序 归并排序 二分 整数二分 浮点数二分 前缀和 一维前缀和 二维前缀和 差分 一维差分 二维差分 双指针 位运算 离散化 区间合并 快速排序 方法一:定义两个新数组,a[ ],b[ ],每次将大于x的放到a中,小于x的放到b中&…

【矩阵论】6.范数理论——范数估计——许尔估计谱估计

6.3 许尔估计 任意方阵 A(aij)nnA(a_{ij})_{n\times n}A(aij​)nn​ ,全体根 λ(A){λ1,⋯,λn}\lambda(A)\{\lambda_1,\cdots,\lambda_n\}λ(A){λ1​,⋯,λn​} ,满足 ∣λ1∣2⋯∣λn∣2≤∑∣aij∣2\vert \lambda_1\vert^2\cdots\vert \lambda_n\ve…

数据存储格式

文章目录数据存储格式1 行列存储比较2 ORC文件格式2.1 文件级2.1.1 Post scripts2.1.2 File Footer2.1.3 File MetaData2.2 Stripe级2.2.1 Stripe Footer2.2.2 Row Data2.2.3 Index Data3 Parquet文件格式3.1 Header3.2 Data3.2.1 Row Group3.2.2 Column Chunk3.2.3 Page3.3 Fo…

正则表达式判断数字

判断 正负整数,正负小数 表达式: ^[-]?([0]{1,1}|[1-9]{1,1}[0-9]*?)[.]?[\\d]{1,}$ import java.util.Scanner; import java.util.regex.Pattern; public static void main(String[] args) { Pattern pattern Pattern.compile("^[-]?([0]{1,1}|[1-9]{1,1}…

3dtiles数据解析

1.解析json文件 2.解析b3dm模型 (1)b3dm模型文件时二进制文件,其中包含glTF文件: 当使用tiny_gltf库解析glTF时,需要减去(28byte featuretable的byte batchTable的byte ): bool TinyGLTF::ExtractGltfFromMemory(Model *model,std::string…

JVM - 内存区域划分 类加载机制 垃圾回收机制

目录 1. 内存区域划分 2. 类加载 2.1 双亲委派模型 3. 垃圾回收机制 (GC) 3.1 如何判断一个对象是否为 "垃圾" 3.1 可达性分析 3.2 垃圾回收算法 1. 内存区域划分 JVM 作本质上是一个 Java 进程, 它启动的时候, 就会从操作系统申请一大块内存, 并且把这一大块…

CSS学习(七):盒子模型,圆角边框,盒子阴影和文字阴影

原文链接:CSS学习(七):盒子模型,圆角边框,盒子阴影和文字阴影 1. 盒子模型 页面布局要学习三大核心:盒子模型,浮动和定位。学习好盒子模型能非常好的帮助我们页面布局。 1.1 看透…

肽基脯氨酰异构酶底物:1926163-51-0,WFY-pSer-PR-AMC

WFYpSPR-AMC, Pin1底物类似显色底物H- trp - phi - tir - ser (PO₃H₂)-Pro-Arg-pNA。 磷酸肽在生命过程中发挥重要作用,磷酸化的位置在多肽上的Tyr、Ser,Thr,。目前磷酸肽合成一般都采用磷酸化氨基酸,目前使用的都是单苄基磷酸化…

Kafka Producer - 分区机制实战

Kafka Producer - 分区机制实战 上一篇介绍了kafka Producer 生产者发送数据的程序代码,以及对生产者分区机制的相关介绍,今天继续深入的了解下分区机制的原理、测试验证、自定义分区。 在学习之前先在本地机器搭建一个单机版的双节点集群环境&#xf…

80.【Spring5】

Spring《解耦》(一)、Spring 简介1.历史:2.Spring 目的3.Spring 引入4.优点5.Spring 七大模块组成:6.扩展(约定大于配置)(二)、IOC理论推导(Inversion of Contro)1.以前的三层分级2.现在对三层架构的更新3.什么是IOC(三)、HelloSpring1.怎么使用Spring?…

技术分享 | 缓存穿透 - Redis Module 之布隆过滤器

作者:贲绍华 爱可生研发中心工程师,负责项目的需求与维护工作。其他身份:柯基铲屎官。 本文来源:原创投稿 *爱可生开源社区出品,原创内容未经授权不得随意使用,转载请联系小编并注明来源。 一、场景案例 假…

设计模式-抽象工厂模式

1、什么是抽象工厂模式 抽象工厂(AbstractFactory)模式的定义:是一种为访问类提供一个创建一组相关或相互依赖对象的接口,且访问类无须指定所要产品的具体类就能得到同族的不同等级的产品的模式结构。抽象工厂模式是工厂方法模式的…

Tiny ImageNet 数据集分享

ImageNet官网上的数据集,动辄就100G,真的是太大了。 有需要Tiny Image Net 数据集的小伙伴可以点击这个下载链接: http://cs231n.stanford.edu/tiny-imagenet-200.zip数据集简介: Tiny ImageNet Challenge 来源于斯坦福 CS231N …

uwb无线定位系统的原理和介绍

uwb无线定位系统是在 uwb平台上部署的定位基站,通过发射无线信号,将 uwb定位系统部署在需要安装的位置,同时结合定位基站所支持工作环境条件(如:温度、湿度、光照等)和定位算法,实现在不同的地理…

使用elesticsearch-7.10.0版本连接elasticsearch-head

背景: 由于esasticsearch-5.5.1中没有登录,登出的安全校验,在安全测评时,经常被检查到高危漏洞,因此项目经常要升级到es7版本。 问题一:jdk版本不满足要求,提示如下 future versions of Elasti…