Spark 3.0 - 11.ML 随机森林实现二分类实战

news2025/1/18 17:01:57

目录

一.引言

二.随机森林实战

1.数据预处理

2.随机森林 Pipeline

3.模型预测与验证

三.总结


一.引言

之前介绍了 决策树 ,而随机森林则可以看作是多颗决策树的集合。在 Spark ML 中,随机森林中的每一颗树都被分配到不同的节点上进行并行计算,或者在一些特定的条件下,单独的一颗决策树也可以并行化运算,其中每一棵决策树之间没有相关性。

随机森林在运行的时候,每当有一个新的数据传输到系统中,都会由随机森林的每一颗决策树同时进行处理,如果处理一个连续常数,就会取所有树的平均值作为结果,这里可以看做是等权重;如果是非连续结果,就选择所有决策树结果中最多的一项,类似于投票法。

Tips:

训练过程中不同决策树的随机性来源于每次迭代中对原始数据进行的二次采样 boostrap,该采样方法可以使得决策树获得不同的训练集,从而在树节点上拆分不同的随机特征子集,这也是随机森林随机性的由来。

二.随机森林实战

随机森林支持回归与分类,回归可以看做是等权重的平均,分类可以看做是少数服从多数的投票。

1.数据预处理

数据存储为 libsvm 格式,最开始为标签 label,后面为不同特征的不同取值:

    val spark = SparkSession
      .builder //创建spark会话
      .master("local")
      .appName("RandomForestClassifierExample") //设置名称
      .getOrCreate() //创建会话变量

    // 读取文件,装载数据到spark dataframe 格式中
    val data = spark.read.format("libsvm").load("/Users/xudong11/sparkV3/src/main/scala/org/example/RandomForest/sample_libsvm_data.txt")

    // 搜索标签,添加元数据到标签列
    // 对整个数据集包括索引的全部标签都要适应拟合
    val labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("indexedLabel")
      .fit(data)

    // 自动识别分类特征,并对其进行索引
    // 设置maxCategories以便大于4个不同值的特性被视为连续的。
    val featureIndexer = new VectorIndexer()
      .setInputCol("features")
      .setOutputCol("indexedFeatures")
      .setMaxCategories(4)
      .fit(data)

    // 按照7:3的比例进行拆分数据,70%作为训练集,30%作为测试集
    val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

数据的预处理分为四步:

A - spark.read.format('libsvm') 负责解析 libsvm 格式数据

B - StringIndexer 负责将标签重新匹配,出现次数最多的标签索引为 0,以此类推

C - VectorIndexer 负责将特征重新映射,根据 MaxCategories 的取值决定特征是连续还是离散

D - randomSplit 负责将原始数据按照给定 RatioArray 进行划分,返回 Array[DataSet[Row]]

其中 B 和 C 提到的两个 indexer 可能解释不够清晰,决策树实战 一文中有两个函数方法的详细解释与示例,大家可以参考。

2.随机森林 Pipeline

这一步主要结合前面数据预处理部分定义的 Transformer 并添加 RandomForest 构建 pipeline fit 数据,从而获取最终的模型。RF 模型除了设定输入输出列外,还定义了 numTrees 代表随机森林中决策树的数量,该参数至少为1,默认为20。labelConverter 负责将 labelIndexer 转换后的标签再重新映射回去。通过组装4个部件我们得到了最终的 pipeline 并应用于到数据预处理中得到的 TrainData。

    // 建立一个决策树分类器,并设置森林中含有10颗树
    val rf = new RandomForestClassifier()
      .setLabelCol("indexedLabel")
      .setFeaturesCol("indexedFeatures")
      .setNumTrees(10)

    // 将索引标签转换回原始标签
    val labelConverter = new IndexToString()
      .setInputCol("prediction")
      .setOutputCol("predictedLabel")
      .setLabels(labelIndexer.labelsArray(0))

    // 把索引和决策树链接(组合)到一个管道(工作流)之中
    val pipeline = new Pipeline()
      .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

    // 载入训练集数据正式训练模型
    val model = pipeline.fit(trainingData)

Tips 关于树的数量与深度:

通常情况下,随机森林的准确性与树的数量成正比,但是随之而来的是更多地训练与预测成本,除此之外也与自己的数据规模、特征多少有关,国外的同学在 29 个常规数据集上测试发现在 128 棵树之后随机森林的准确性不再有显著的改进。

其次关于树的深度,这个其实和单棵决策树是相同的,在特征较多的情况下,如果存在过拟合的情况需要通过剪枝解决。

除此之外,可以通过随机选择特征结合 out-of-bag [OOB] 袋外误差率评估模型效果,如果将某个特征换为随机值,OOB-Error 没有明显增加则代表当前特征不显著,反之特征显著。

3.模型预测与验证

这一步利用上一步 Pipeline fit 得到的 Transformer Model 进行 transform 对测试数据转换,随后使用 Evaluator 进行评估。

    // 使用测试集作预测
    val predictions = model.transform(testData)

    // 选择一些样例进行显示
    predictions.select("predictedLabel", "label", "features").show(5)

    // 计算测试误差
    val evaluator = new MulticlassClassificationEvaluator()
      .setLabelCol("indexedLabel")
      .setPredictionCol("prediction")
      .setMetricName("accuracy")
    val accuracy = evaluator.evaluate(predictions)
    println(s"Test Error = ${1.0 - accuracy}")

    val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
    println(s"Learned classification forest model:\n ${rfModel.toDebugString}")

    spark.stop()

最后调用 asInstanceOf 将 pipeline 中的 RF Model 转换出来,打印 toDebugString 获取当前的随机森林情况。

...... 此处忽略 Tree2 - Tree7 .....

numTrees = 10,分别为 Tree0 -> Tree9,可以看到大家的 weight 均为 1.0,如果是 GBDT 的情况下,每棵树的权重也不同。每棵决策树的 If else 可以看做是分界点、其中分界条件即为特征划分选择。最终将 10 棵树的 Predict 进行投票法,选取大部分决策树都认同的结果作为随机森林的预测结果。

三.总结

随机森林的本质就是建立多颗决策树,然后取得所有决策树的平均值或者以投票的方式分类。随机森林是用于分类和回归最成功的机器学习模型之一,其结合多颗决策树,以降低过拟合的风险。与决策树相同,随机森林可以实现特征的自动选择,上述 IF ELSE 的每一个决策节点都可以看做是区分度较大的特征。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/84970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Crack:Aspose.3D for .NET 22.11.X

Aspose.3D for .NETAspose.3D for .NET 是一个功能丰富的游戏软件和计算机辅助设计 (CAD) API,无需任何 3D 建模和渲染软件依赖即可操作文档。API 支持 Discreet3DS、WavefrontOBJ、FBX(ASCII、二进制)、STL(ASCII、二进制&#x…

[附源码]Nodejs计算机毕业设计基于web的家教管理系统Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置: Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分…

linux信号:SIGINT、SIGKILL、SIGSTOP、SIGCONT

目录 1. SIGINT 2. SIGKILL 3. SIGSTOP与SIGCONT 简介 SignalStandardActionCommentSIGINTP1990TermInterrupt form keybordSIGKILLP1990TermKill signalSIGSTOPP1990TermInterrupt form keybord1. SIGINT 我们在shell交互式进程中常用的ctrl c 就是对当前运行的程序进行…

[附源码]Node.js计算机毕业设计电影票网上订票系统Express

项目运行 环境配置: Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境:最好是Nodejs最新版,我…

2022年山东食品安全管理员模拟试题及答案

百分百题库提供食品安全管理员考试试题、食品安全管理员考试预测题、食品安全管理员考试真题、食品安全管理员证考试题库等,提供在线做题刷题,在线模拟考试,助你考试轻松过关。 一、单选题 1.下列哪项措施与保证食品安全无关? A…

2022gwb_web3

可以通过反序列化出一个 Webclome 类从而任意构造原生类,但只能调用 getSize 方法获取文件或目录 的大小,试了试直接拿根目录的 /flag 就别想了,先看看网站目录有没有藏什么东西(扫目录什么也扫 不出来),EX…

再学C语言2:概览

重新把C语言梳理一遍,学习在VSCode中进行C语言编程 一、C语言起源 1972年,贝尔实验室的Dennis Ritchie在C语言的基础上设计出一种新的语言,即C语言 C是作为从事实际编程工作的程序员的一种工具儿出现,是为编程人员开发的语言 二…

【脚本项目源码】Python制作桌面宠物,这么可爱的萌宠你不想拥有吗?

前言 本文给大家分享的是如何通过利用Python制作桌面宠物,废话不多直接开整~ 开发工具 Python版本: 3.6 相关模块: random模块 os模块 cfg模块 sys模块 PyQt5模块 环境搭建 安装Python并添加到环境变量,pip安装需要的相…

skynet设计原理和使用

skynet设计原理一、多核并发编程方式二、skynet2.1、skynet简介2.2、环境准备2.3、编译安装2.4、Actor 模型2.5、消息队列2.6、actor公平调度三、skynet的使用3.1、第一个skynet程序3.2、skynet网络消息3.3、skynet定时消息3.4、skynet actor间消息四、vscode调试skynet总结后言…

Python比较难的知识点: 迭代器与生成器

迭代器与生成器是Python比较难的知识点, 在学Python之前, 我已经有了多年的C语言与MATLAB的使用经验了, 但是学这些知识点, 还是有一定的困难, 总觉得是一知半解的. 现在, 经过一段时间的学习和梳理, 感觉是搞懂了, 写下这篇文章与大家分享. 学习具体概念技术之前, 得知道这些…

c++ - 第17节 - AVL树和红黑树

1.AVL树 1.1.AVL树的概念 二叉搜索树虽可以缩短查找的效率,但如果数据有序或接近有序二叉搜索树将退化为单支树,查找元素相当于在顺序表中搜索元素,效率低下。因此,两位俄罗斯的数学家G.M.Adelson-Velskii和E.M.Landis在1962年发…

Mybatis:MyBatis的逆向工程(10)

Mybaits笔记框架:https://blog.csdn.net/qq_43751200/article/details/128154837 Mybatis中文官方文档: https://mybatis.org/mybatis-3/zh/index.html Mybati的逆向工程1. 正向工程 VS 逆向工程2. 创建逆向工程的步骤(MyBatis3Simple清新简洁…

Go 1.20要来了,看看都有哪些变化-第1篇

前言 Go官方团队在2022.12.08发布了Go 1.20 rc1(release candidate)版本,Go 1.20的正式release版本预计会在2023年2月份发布。 让我们先睹为快,看看Go 1.20给我们带来了哪些变化。(文末有彩蛋!) 安装方法: $ go install golan…

[附源码]Nodejs计算机毕业设计基于web的火车订票管理系统Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置: Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分…

阅读器消退之际:文石造产品,掌阅塑生态

配图来自Canva可画 阅读器正在变得小众,似乎自Kindle以来营造的那种“阅读的生活方式”,已经被证明是一个伪命题:会阅读的人,无论如何都会去阅读;不会阅读的人,阅读器的归宿终究是一个“泡面盖”。于是&am…

Android原生项目接入flutter_boost4.0

折腾了好几天,经验思维导致的,记录一下踩坑。 官方接入步骤官方文档 接入原生,就3个步骤,我都能折腾好久,浪费时间。 flutter部分很简单,按文档配置就行,在pubspec.yaml依赖就好了。 &#…

Ceph性能瓶颈分析与优化(混合盘篇)

原文链接: Ceph性能瓶颈分析与优化(混合盘篇) - 知乎背景ssdhdd的混合盘场景在各个存储厂商中算是一种典型应用场景。 但是经过测试(4k随机写)发现,加了nvme ssd做ceph的wal和db后,性能提升仅一倍以内且nvme盘性能余量较大。所以希望通过对问题瓶颈进行…

目标检测数据标注案例-高清地图中障碍物(汽车)标注

计算机视觉在无人机中领域中有何作用? 无人机能够在空中识别、分类和追踪目标。无人机的摄像头和感应器可以捕获数据并进行分析,以提取重要信息。 AI可以自动提取视觉数据信息,准确识别、说明和追踪图像和视频中的目标。例如高空检测工作,…

Nacos 配置中心之长轮询--客户端

先来看下长轮询调用的链路 客户端 入口 在 NacosConfigService 初始化的时候,会初始化两个组件 一是网络组件,也就是http数据处理的 (起作用的是 ServerHttpAgent)二是客户端的长轮询ClientWorker public NacosConfigService(Properties properties) throws NacosException…

本地连接docker mysql

1.拉取镜像 docker pull mysql 2.启动mysql实例容器 docker run --name mysql -p 3307:3306 -e MYSQL_ROOT_PASSWORDmysql_pw -d mysql --name 为mysql的实例设置别名。 -p 3307为对外暴露的端口。3306是内部端口 -e MYSQL_ROOT_PASSWORD 设置mysql登录密码 -d 以守…