案例:DNN进行分类

news2025/1/11 12:42:09

7.2 案例:DNN进行分类

学习目标

  • 目标
    • 知道tf.data.Dataset的API使用
    • 知道tf.feature_columnAPI使用
    • 知道tf.estimatorAPI使用
  • 应用

7.2.1 数据集介绍

对鸢尾花进行分类:概览

本文档中的示例程序构建并测试了一个模型,此模型根据鸢尾花的花萼花瓣大小将其分为三种不同的品种。

 

 

从左到右:山鸢尾(提供者:Radomil,依据 CC BY-SA 3.0 使用)、变色鸢尾(提供者:Dlanglois,依据 CC BY-SA 3.0 使用)和维吉尼亚鸢尾(提供者:Frank Mayfield,依据 CC BY-SA 2.0 使用)。

数据集

鸢尾花数据集包含四个特征和一个目标值。这四个特征确定了单株鸢尾花的下列植物学特征:

  • 花萼长度
  • 花萼宽度
  • 花瓣长度
  • 花瓣宽度

该标签确定了鸢尾花品种,品种必须是下列任意一种:

  • 山鸢尾 (0)
  • 变色鸢尾 (1)
  • 维吉尼亚鸢尾 (2)

我们的模型会将该标签表示为 int32 分类数据。

下表显示了数据集中的三个样本:

花萼长度花萼宽度花瓣长度花瓣宽度品种(标签)
5.13.31.70.50(山鸢尾)
5.02.33.31.01(变色鸢尾)
6.42.85.62.22(维吉尼亚鸢尾)

7.2.2 构建模型

该程序会训练一个具有以下拓扑结构的深度神经网络分类器模型:

  • 2 个隐藏层。
  • 每个隐藏层包含 10 个节点。

下图展示了特征、隐藏层和预测(并未显示隐藏层中的所有节点):

 

 

Estimator 是从 tf.estimator.Estimator衍生而来的任何类。要根据预创建的 Estimator 编写 TensorFlow 程序,您必须执行下列任务:

  • 1、 创建一个或多个输入函数。
  • 2、定义模型的特征列。
  • 3、实例化 Estimator,指定特征列和各种超参数。
  • 5、在 Estimator 对象上调用一个或多个方法,传递适当的输入函数作为数据的来源。

7.2.3 数据获取与输入函数

在iris_data文件中通过

# 获取数据
def load_data(y_name='Species'):
    """获取训练测试数据"""
    train_path, test_path = maybe_download()

    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
    train_x, train_y = train, train.pop(y_name)

    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
    test_x, test_y = test, test.pop(y_name)

    return (train_x, train_y), (test_x, test_y)

(train_x, train_y), (test_x, test_y) = iris_data.load_data()

输入函数是返回 tf.data.Dataset 对象的函数,此对象会输出下列含有两个元素的元组:

  • features:Python 字典,其中:
    • 每个键都是特征的名称。
    • 每个值都是包含此特征所有值的数组。
  • label - 包含每个样本的标签值的数组。
def input_evaluation_set():
    features = {'SepalLength': np.array([6.4, 5.0]),
                'SepalWidth':  np.array([2.8, 2.3]),
                'PetalLength': np.array([5.6, 3.3]),
                'PetalWidth':  np.array([2.2, 1.0])}
    labels = np.array([2, 1])
    return features, labels

输入函数可以以您需要的任何方式生成 features 字典和 label 列表。不过,我们建议使用 TensorFlow 的 Dataset API,它可以解析各种数据。概括来讲,Dataset API 包含下列类:

 

 

  • Dataset - 包含创建和转换数据集的方法的基类。您还可以通过该类从内存中的数据或 Python 生成器初始化数据集。
  • TextLineDataset - 从文本文件中读取行。
  • TFRecordDataset - 从 TFRecord 文件中读取记录。
  • FixedLengthRecordDataset - 从二进制文件中读取具有固定大小的记录。
  • Iterator - 提供一次访问一个数据集元素的方法。
def train_input_fn(features, labels, batch_size):
    """
    """
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    return dataset


def eval_input_fn(features, labels, batch_size):
    """
    """
    features = dict(features)
    if labels is None:
        inputs = features
    else:
        inputs = (features, labels)

    dataset = tf.data.Dataset.from_tensor_slices(inputs)

    dataset = dataset.batch(batch_size)

    return dataset


dataset.make_one_shot_iterator().get_next()

7.2.3 特征处理tf.feature_colum

特征处理tf.feature_column

Estimator 的 feature_columns 参数来指定模型的输入。特征列在输入数据(由input_fn返回)与模型之间架起了桥梁。要创建特征列,请调用 tf.feature_column 模块的函数。本文档介绍了该模块中的 9 个函数。如下图所示,除了 bucketized_column 外的函数要么返回一个 Categorical Column 对象,要么返回一个 Dense Column 对象。

 

 

要创建特征列,请调用 tf.feature_column 模块的函数。本文档介绍了该模块中的 9 个函数。如下图所示,除了 bucketized_column 外的函数要么返回一个 Categorical Column 对象,要么返回一个 Dense Column 对象。

 

 

  • Numeric column(数值列)

Iris 分类器对所有输入特徵调用 tf.feature_column.numeric_column 函数:SepalLength、SepalWidth、PetalLength、PetalWidth

tf.feature_column 有许多可选参数。如果不指定可选参数,将默认指定该特征列的数值类型为 tf.float32

numeric_feature_column = tf.feature_column.numeric_column(key="SepalLength")
  • Bucketized column(分桶列)

通常,我们不直接将一个数值直接传给模型,而是根据数值范围将其值分为不同的 categories。上述功能可以通过 tf.feature_column.bucketized_column 实现。以表示房屋建造年份的原始数据为例。我们并非以标量数值列表示年份,而是将年份分成下列四个分桶:

# 首先,将原始输入转换为一个numeric column
numeric_feature_column = tf.feature_column.numeric_column("Year")

# 然后,按照边界[1960,1980,2000]将numeric column进行bucket
bucketized_feature_column = tf.feature_column.bucketized_column(
    source_column = numeric_feature_column,
    boundaries = [1960, 1980, 2000])
  • Categorical identity column(类别标识列)

输入的列数据就是为固定的离散值,假设您想要表示整数范围 [0, 4)。在这种情况下,分类标识映射如下所示:

 

 

identity_feature_column = tf.feature_column.categorical_column_with_identity(
    key='my_feature_b',
    num_buckets=4) # Values [0, 4)
  • Categorical vocabulary column(类别词汇表)

我们不能直接向模型中输入字符串。我们必须首先将字符串映射为数值或类别值。Categorical vocabulary column 可以将字符串表示为one_hot格式的向量。

 

 

vocabulary_feature_column =
    tf.feature_column.categorical_column_with_vocabulary_list(
        key=feature_name_from_input_fn,
        vocabulary_list=["kitchenware", "electronics", "sports"])
  • Hashed Column(哈希列)

处理的示例都包含很少的类别。但当类别的数量特别大时,我们不可能为每个词汇或整数设置单独的类别,因为这将会消耗非常大的内存。对于此类情况,我们可以反问自己:“我愿意为我的输入设置多少类别?

hashed_feature_column =
    tf.feature_column.categorical_column_with_hash_bucket(
        key = "some_feature",
        hash_bucket_size = 100) # The number of categories

 

 

  • 其它列处理
    • Crossed column(组合列)

7.2.4 实例化 Estimator

鸢尾花分类本身问题不复杂,一般算法也能够解决。这里我么你选用DNN做测试tf.estimator.DNNClassifier 我们将如下所示地实例化此 Estimator:

我们已经有一个 Estimator 对象,现在可以调用方法来执行下列操作:

  • 训练模型。
  • 评估经过训练的模型。
  • 使用经过训练的模型进行预测。

训练模型

通过调用 Estimator 的 train 方法训练模型,如下所示:

# Train the Model.
classifier.train(
    input_fn=lambda:iris_data.train_input_fn(train_x, train_y, args.batch_size),
    steps=args.train_steps)

我们将 input_fn 调用封装在 lambda 中以获取参数,同时提供一个不采用任何参数的输入函数,正如 Estimator 预计的那样。steps 参数告知方法在训练多步后停止训练。

评估经过训练的模型

模型已经过训练,现在我们可以获取一些关于其效果的统计信息。以下代码块会评估经过训练的模型对测试数据进行预测的准确率:

# Evaluate the model.
eval_result = classifier.evaluate(
    input_fn=lambda:iris_data.eval_input_fn(test_x, test_y, args.batch_size))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

与我们对 train 方法的调用不同,我们没有传递 steps 参数来进行评估。我们的 eval_input_fn 只生成一个周期的数据。

运行此代码会生成以下输出(或类似输出):

Test set accuracy: 0.967

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/196376.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《MFC编程》:MFC程序执行流程

《MFC编程》:MFC程序启动《MFC编程》:MFC程序启动入口函数执行流程CWinApp的成员视频链接《MFC编程》:MFC程序启动 入口函数 MFC程序的入口函数与win32程序一样,都是从WinMain入口。 但是MFC库已经实现了WinMain函数&#xff0…

ORA-65096: invalid common user or role 解决方法

问题描述 oracle 12C 创建数据库时报错: ORA-65096: invalid common user or role name例如: SQL> create user rui identified by oracle;create user rui identified by oracle* ERROR at line 1: ORA-65096: invalid common user or role name原…

记录1-两数之和

给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。你可以按任意顺序返回答案…

分享111个JS图片切换特效,总有一款适合您

分享111个图片切换,总有一款适合您 下面是文件的名字,我放了一些图片,文章里不是所有的图主要是放不下..., 111个图片切换下载链接:https://pan.baidu.com/s/1iGzOzU3WZbjBF21dZzoH9w?pwdqi5u 提取码:qi…

Linux常见指令大全(二)

🌹作者:云小逸 📝个人主页:云小逸的主页 📝Github:云小逸的Github 🤟motto:要敢于一个人默默的面对自己,强大自己才是核心。不要等到什么都没有了,才下定决心去做。种一颗树,最好的时间是十年前…

2023年第一篇来谈谈效率

晚上临走的时候和同事聊了聊关于效率的问题,暂且称呼为A同学。借着和A同学的这次畅谈记录下这段时间的所负责的数据迁移过程。 数据迁移的整体内容并不复杂。主要内容如下 我们在做事情的时候总会遇到这件事情所关联的其他问题。 不要带着情绪去工作 书写脚本的时候…

【HBase高级】6. HBase数据结构(下)——LSM树数据结构、布隆过滤器、StoreFiles(HFile)结构

5.3 LSM树数据结构 1、简介 传统关系型数据库,一般都选择使用B树作为索引结构,而在大数据场景下,HBase、Kudu这些存储引擎选择的是LSM树。LSM树,即日志结构合并树(Log-Structured Merge-Tree)。 LSM树主要目标是快速建立索引B树…

redis加锁的几种方法

1. redis加锁分类 redis能用的的加锁命令分表是INCR、SETNX、SET 2. 第一种锁命令INCR 这种加锁的思路是, key 不存在,那么 key 的值会先被初始化为 0 ,然后再执行 INCR 操作进行加一。 然后其它用户在执行 INCR 操作进行加一时,…

3.4 内部类

文章目录1.概述2.特点3.内部类入门案例4.成员内部类4.1 被private修饰4.2 被static修饰5.局部内部类6.匿名内部类1.概述 如果一个类存在的意义就是为指定的另一个类,可以把这个类放入另一个类的内部。 就是把类定义在类的内部的情况就可以形成内部类的形式。 A类中…

【内网安全】——CS操作指南(二)

作者名:白昼安全主页面链接: 主页传送门创作初心: 一切为了她座右铭: 不要让时代的悲哀成为你的悲哀专研方向: web安全,后渗透技术每日emo:关心和细节吗?注意:我这里的cs…

Android MVI框架的使用

AndroidMviFrame AndroidMviFrame 是一个Android简单易用的项目框架 文档下面会对框架中所使用的一些核心技术进行阐述。该框架作为技术积累的产物,会一直更新维护,如果有技术方面的谈论或者框架中的错误点,可以在 GitHub 上提 Issues&…

DAMA认证(CDGA/CDGP)证书好考吗

随着数字化经济的不断发展,企业对数据重视程度越来越高,致使越来越多得数字人关注到DAMA认证。很多小伙伴都会有这样的疑问,DAMA认证(CDGA/CDGP认证)好考吗?通过率怎么样?今天小编就在这里做一下简单的说明…

UniRx之操作符详解-Linq语法

前言 UniRx中由很多操作符,注意要分为三类 Linq操作符,和Linq语法风格一致Rx操作符,从Rx.Net库继承下来的操作符。UniRx操作符,UniRx针对Unity的独有操作符。 Rx和Linq Linq是微软的一项技术,新增一种自然查询的SQ…

时间序列预测

问题简介 简单来说,时间序列是按照时间顺序,按照一定的时间间隔取得的一系列观测值,比如我们上边提到的国内生产总值,消费者物价指数,利率,汇率,股票价格等等。时间间隔可以是日,周…

数字IC设计 Synopsys EDA Tools的安装补充

数字IC Synopsys 七件套的Ubuntu安装步骤 推荐大佬的安装教程,本人亲测可用,在这里表示十分感谢! 数字IC设计的第一步——Synopsys EDA Tools的安装 跟着大佬的教程仔细点可以一步到位的! 在这里备忘本人遇到的几个粗心导致的问…

浅谈Spring IoC容器

目录 1.IoC容器 2.依赖注入 1.IoC容器 IOC: Inversion of Control,是一种设计思想。 在spring框架中,Spring 通过IoC容器进行管理所有Java对象的实例化和初始化,控制对象与对象之间的依赖关系。 IoC管理的对象称为Bean,它与使…

“华为杯”研究生数学建模竞赛2005年-【华为杯】A题:行车时间估计和最优路线选择(附获奖论文)

赛题描述 A: Highway Traveling time Estimate and Optimal Routing Ⅰ Highway traveling time estimate is crucial to travelers. Hence, detectors are mounted on some of the US highways. For instance, detectors are mounted on every two-way six-lane highways o…

MySQL 百万级数据,如何做分页查询?

随着业务的增长,数据库的数据也呈指数级增长,拿订单表为例,之前的订单表每天只有几千个,一个月下来不超过十万。而现在每天的订单大概就是2w,目前订单表的数据已经达到了700w。这带来了各种各样的问题,今天…

国产ETL工具/ETL 产品 (BeeDI ) 集团财务 双向同步 审核平台

项目需求核心 实时同步、双向同步、部分同步、日志解析同步、断点续传 项目需求概要 35分公司财务数据实时同步汇总中心平台 🔛 中心平台财务数据实时同步分发35分公司 项目需求内容 35分公司数据中部分表数据同步到中心库对应表,10张表分公司表年数…

【MyBatis】mybatis缓存机制

1. 缓存基础知识:缓存: cache缓存的作用: 通过减少IO的方式, 来提高程序的执行效率mybatis缓存包括:一级缓存: 讲话查询的数据存储到SqlSession中二级缓存: 将查询的数据存储到SqlSessionFactory中或者集成第三方的缓存: 比如EhCache...mybatis缓存只针对DQL语句, 也就是说缓存…