22- estimater使用 (TensorFlow系列) (深度学习)

news2024/9/29 13:19:18

知识要点

  • estimater 有点没理解透

  • 数据集是泰坦尼克号人员幸存数据.

  • 读取数据:train_df = pd.read_csv('./data/titanic/train.csv')

  • 显示数据特征:train_df.info()

  • 显示开头部分数据:train_df.head()

  • 提取目标特征:y_train = train_df.pop('survived')

  • 显示数据分布:train_df.describe()

  • 柱状图显示:train_df.age.hist(bins = 20)

  • 横向柱状图: train_df.sex.value_counts().plot(kind = 'barh')

  • pd.concat([train_df, y_train], axis = 1).groupby('sex').survived.mean().plot(kind = 'barh')  # 根据幸存率查看各类型的均值

  • 提取不同特征的统计: train_df.embark_town.value_counts()

  • 提取特征: vocab = train_df[categorical_column].unique()

  • tf.feature_column.indicator_column(tf.feature_column.categorical_column_with_vocabulary_list(categorical_column, vocab))   # one_hot 编码

  • dataset批次设置: dataset = dataset.repeat(epochs).batch(batch_size) 


1 导包

from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

2 数据导入

train_df = pd.read_csv('./data/titanic/train.csv')
eval_df = pd.read_csv('./data/titanic/eval.csv')  # eval 评估   # 数据
print(train_df.info())
print(eval_df.info())

train_df.head()

 3 目标值获取

y_train = train_df.pop('survived')
y_eval = eval_df.pop('survived')

print(train_df.head())
print(eval_df.head())
print(y_train.head())
print(y_eval.head())

4 特征处理

train_df.describe()

# 观察年龄的数据分布
train_df.age.hist(bins = 20)

# 观察男女比例, 性别数量对比
train_df.sex.value_counts().plot(kind = 'barh')

# 仓位对比, 船舱类型
train_df['class'].value_counts().plot(kind = 'barh')

# 看港口人数
train_df['embark_town'].value_counts().plot(kind = 'barh')

pd.concat([train_df, y_train], axis = 1).groupby('sex').survived.mean().plot(kind = 'barh')

train_df.embark_town.value_counts()
'''Southampton    450
Cherbourg      123
Queenstown      53
unknown          1
Name: embark_town, dtype: int64'''
# 区分离散特征和连续特征
categorical_columns = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck', 'embark_town', 'alone']  # 离散特征
numeric_columns = ['age', 'fare']

# 接受特征
feature_columns = []
for categorical_column in categorical_columns:
    vocab = train_df[categorical_column].unique()  # 取出特征值
    print(vocab)
    # print(tf.feature_column.categorical_column_with_vocabulary_list(categorical_column, vocab))  # 创建vocabulary 的API
    # 将离散特征转换为one_hot形式的编码
    num = tf.feature_column.indicator_column(tf.feature_column.categorical_column_with_vocabulary_list(categorical_column, vocab))
    feature_columns.append(num)

# 数据类型转换
for numeric_column in numeric_columns:
    feature_columns.append(tf.feature_column.numeric_column(numeric_column, dtype = tf.float32))

5 dataset

# 创建生成dataset的方法
def make_dataset(data_df, label_df, epochs = 10, shuffle = True, batch_size = 32):
    dataset = tf.data.Dataset.from_tensor_slices((dict(data_df), label_df))
    if shuffle:
        dataset = dataset.shuffle(10000)  # 打乱, 洗牌
    dataset = dataset.repeat(epochs).batch(batch_size)
    return dataset
train_dataset = make_dataset(train_df, y_train, batch_size = 5)
# baseline_model
import os
output_dir = 'baseline_model'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
    
baseline_estimator = tf.compat.v1.estimator.BaselineClassifier(model_dir = output_dir, n_classes= 2)
# input_fn要求没有输入参数, 要求返回元组(x, y)或者可以返回(x, y)的dataset
baseline_estimator.train(input_fn = lambda : make_dataset(train_df, y_train, epochs = 100))
# baseline 是随机参数, 所以结果很差
baseline_estimator.evaluate(input_fn = lambda : make_dataset(eval_df, y_eval, epochs = 1,
                                                             shuffle = False, batch_size = 20))
# linear_model
linear_output_dir = 'linear_model'
if not os.path.exists(linear_output_dir):
    os.mkdir(linear_output_dir)
    
linear_estimator = tf.estimator.LinearClassifier(feature_columns = feature_columns,
                                                 model_dir = linear_output_dir)
linear_estimator.train(input_fn = lambda :make_dataset(train_df, y_train, epochs = 100))
# baseline 是随机参数, 所以结果很差
linear_estimator.evaluate(input_fn = lambda : make_dataset(eval_df, y_eval, epochs = 1, shuffle = False,
                                                           batch_size = 20))
dnn_output_dir = './dnn_model'
if not os.path.exists(dnn_output_dir):
    os.mkdir(dnn_output_dir)
    
dnn_estimator = tf.estimator.DNNClassifier(model_dir = dnn_output_dir,  # 存储地址
                                           n_classes= 2,  # 二分类
                                           feature_columns = feature_columns, 
                                           hidden_units = [128, 128],   # 隐藏层
                                           activation_fn = tf.nn.relu,  # 算法
                                           optimizer = 'Adam')  # 损失函数, 优化:optimizer
# dnn_estimator.train(input_fn = lambda : make_dataset(train_df, y_train, epochs = 100))

dnn_estimator.train(input_fn = lambda :make_dataset(train_df, y_train, epochs = 100))
dnn_estimator.evaluate(input_fn = lambda : make_dataset(eval_df, y_eval, epochs = 1,
                                                        shuffle = False, batch_size = 20))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/381320.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Web前端:四大Web应用开发趋势和技术

就像其他行业一样,web应用程序开发每年都会经历巨大的变化。就像人们说的,变化是技术中唯一不变的东西。因此,我们这里有一些你可以期待的市场变化。Web应用开发趋势和技术1.市场对聊天机器人和人工智能寄予厚望已经说过很多次,也…

java 面试

面试目录概述需求:设计思路实现思路分析1.面试概要参考资料和推荐阅读Survive by day and develop by night. talk for import biz , show your perfect code,full busy,skip hardness,make a better result,wait for change,challenge Survive. happy f…

JSTL核心库的简单使用

JSTL核心库的简单使用 7.1考试重点 7.1.1c:out输出数据 考试重点就是c的相关的 jar包下载地址:Apache Tomcat - Apache Taglibs Downloads 看会典型应用就可以<% page contentType"text/html;charsetUTF-8" language"java" %> <% taglib uri"…

DolphinDB 通过 Telegraf + Grafana 实现设备指标的采集监控和展示

基于原始数据采集的可视化监控是企业确保设备正常运行和安全生产的重要措施。本文详细介绍了如何从DolphinDB 出发&#xff0c;借助 Telegraf 对设备进行原始数据采集&#xff0c;并通过 Grafana 实现数据的可视化&#xff0c;从而实现设备指标的实时监控。1. 概览Telegraf 是 …

Mybatis-plus逻辑删除更新字段

MybatisPlus版本 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-boot-starter</artifactId><version>3.4.2</version> </dependency> <dependency><groupId>com.baomidou</groupId&g…

优思学院|DFMEA是全球制造业的必修课!

DFMEA&#xff08;Design Failure Mode and Effects Analysis&#xff09;是一种分析技术&#xff0c;在产品设计的早期阶段识别和解决潜在的失效问题。它通过分析设计的各个方面&#xff0c;识别潜在的失效模式和影响&#xff0c;并提出相应的改进措施&#xff0c;以减少失效的…

服装企业 采购系统

技术&#xff1a;Java、JSP等摘要&#xff1a;随着我国市场经济的不断发展,企业之间的竞争越来越激烈,只有对企业库存物资资源全面掌握,充分发挥闲置资源的利用,对资源进行优化配置,才能使企业效益达到最大化。只有通过规范科学的物资管理手段,才能节省物资采购成本,提高工作效…

Java——面向对象

目录 前言 一、什么是面向对象&#xff1f; 面向过程 & 面向对象 面向对象 二、回顾方法的定义和调用 方法的定义 方法的调用 三、类与对象的创建 类和对象的关系 创建与初始化对象 四、构造器详解 五、创建对象内存分析 六、封装详解 七、什么是继承&#x…

Unity TextMeshPro

Unity TextMeshPro 简介 TextMeshPro(也简称为TMP)号称是Unity的终极文本解决方案&#xff0c;它是Unity 的 UI 文本和旧版文本网格体的完美替代品。 功能强大且易于使用&#xff0c;使用高级文本渲染技术以及一组自定义着色器;提供实质性的视觉质量改进&#xff0c;同时在文…

Python基础教程(入门教程),初学者学Python编程如何快速入门?

【导语】Python是一种跨平台的计算机程序设计语言&#xff0c;通过Python编程&#xff0c;我们能够解决现实生活中的很多困难&#xff0c;现如今&#xff0c;我们工作中的许多工作都需要通过编写计算机软件来完成&#xff0c;那么初学者学Python编程如何快速入门呢?下面就来给…

【用Group整理目录结构 Objective-C语言】

一、接下来,我们看另外一个知识点,怎么用Group把这一堆乱七八糟的文件给它整理一下,也算是封装一下吧, 1.这一堆杂乱无章的文件: 那么,哪些类是属于模型呢,哪些类是属于视图呢,哪些类是属于控制器呢, 我们接下来通过Group的方式,来给它们分一下类, 这样看起来就好…

虚拟机安装ubuntu窗口自适应问题以及软件窗口显示不全解决方法

这部分查了很多博客&#xff0c;首先感谢前人栽树。 直接上我在安装过程中的有效解决步骤&#xff0c; 文后会描述遇到的非有效解决步骤&#xff0c;以供遇到相同问题的同学参考。 打开终端窗口 (ctrlaltt),当然肯定是一条一条的执行。 sudo apt-get update sudo apt-get upg…

第三章-OpenCV基础-7-形态学

前置 形态学主要是从图像中提取分量信息&#xff0c;该分量信息通常是图像理解时所使用的最本质的形状特征,对于表达和描绘图像的形状有重要意义。 大体就是通过一系列操作让图像信息中的关键信息更加凸出。同时&#xff0c;形态学的操作都是基于灰度图进行。 相关操作最主要…

Filebeat处理多行换行的问题

问题&#xff1a;在使用filebeatelabscience或者filebeatelk 又或者其他桥接器的时候&#xff0c;因为filbeat默认使用单行显示的原因&#xff0c;但日志出现堆栈错误或其他多行日志时会出现如下错误处理办法&#xff1a;1.固定日志格式 这里不展开说明2.匹配日志 找到你的file…

【Flutter入门到进阶】Flutter基础篇---布局

1 GridView网格布局组件 1.1 说明 1.1.1 图例 1.1.2 说明 GridView网格布局在实际项目中用的也是非常多的&#xff0c;当我们想让可以滚动的元素使用矩阵方式排列的时 候。此时我们可以用网格列表组件GridView实现布局 GridView创建网格列表主要有下面三种方式 1、可以通过Gr…

纳睿雷达在科创板上市:总市值达93亿元,2022年营收约2亿元

3月1日&#xff0c;广东纳睿雷达科技股份有限公司&#xff08;下称“纳睿雷达”&#xff0c;SH:688522&#xff09;在科创板上市。本次上市&#xff0c;纳睿雷达的发行价为46.68元/股&#xff0c;发行数量为3866.68万股&#xff0c;募资总额约为18.05亿元。 上市首日&#xff…

关于“腺样体面容”的两大认知误区,你需要了解一下

仅供医学专业人士阅读参考看完不要再中招了&#xff01;随着父母越来越重视孩子的外表和健康成长&#xff0c;“腺样脸”几乎成为聚会上不可避免的热门话题。在各种交流和讨论中&#xff0c;你经常听到朋友焦虑有点高兴地说&#xff1a;“虽然我的孩子总是张嘴睡觉&#xff0c;…

pandas: 三种算法实现递归分析Excel中各列相关性

目录 前言 目的 思路 代码实现 1. 循环遍历整个SDGs列&#xff0c;两两拿到数据 2. 调用pandas库函数直接进行分析 完整源码 运行效果 总结 前言 博主之前刚刚被学弟邀请参与了2023美赛&#xff0c;这也是第一次正式接触数学建模竞赛&#xff0c;现在已经提交等待结果…

【自动化测试】一位自动化测试工程师居然不会封装框架?神秘自动化测试框架......

目录&#xff1a;导读前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09;前言 自动化测试框架 自…

02 Android基础--service

02 Android基础--service什么是service&#xff1f;service的demo使用Service的种类前台service的使用背景什么是service&#xff1f; Service(服务)是一个一种可以在后台执行长时间运行操作而没有用户界面的应用组件。 服务分为两种形式&#xff1a;非绑定状态与绑定状态。 非…