文章目录
- 什么是随机森林?
- 随机森林的优缺点
- 随机森林示例——鸢尾花分类
什么是随机森林?
随机森林算法是机器学习、计算机视觉等领域内应用极为广泛的一个算法,它不仅可以用来做分类,也可用来做回归即预测,随机森林机由多个决策树构成,相比于单个决策树算法,它分类、预测效果更好,不容易出现过度拟合的情况。
常应用于以下类型的场景:
- 预测用户贷款是否能够按时还款;
- 预测用户是否会购买某件商品等等
官网:分类和回归
随机森林的优缺点
优点:
-
可以处理高纬度的数据;
-
训练之前不需要特意的做特征选择;
-
建立很多树,预防了过拟合风险;
缺点:
-
计算量相对于决策树很大,性能开销很大。
-
可能会导致有些数据集没有训练到,但这种几率很小。
-
分裂的时候,偏向于选择取值较多的特征。
随机森林示例——鸢尾花分类
数据集下载:
链接:
https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l
提取码:
lz3l
数据集介绍:
iris.scale.txt
是 libsvm
格式的鸢尾花数据集,共有五个字段。第一个为标签字段,后四个为特征字段。
libsvm
格式参考:机器学习:libsvm数据格式
将数据集中的随机百分之70作为训练集,剩余的作为测试集。
使用 SparkSQL 的方式读取 libsvm
格式的文件会自动生成 label
和 features
结构的数据,如下所示:
val data: DataFrame = spark.read.format("libsvm").load("iris.scale.txt")
data.show()
需求实现:
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, SparkSession}
object Iris {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder().appName("Iris").master("local[*]").getOrCreate()
// 加载 libsvm 格式文件的数据
val data: DataFrame = spark.read.format("libsvm").load("C:\\Users\\Administrator\\Desktop\\iris.scale.txt")
data.show()
// 1.构建标签列转换对象
val labelIndexer: StringIndexerModel = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(data)
// 2.构建特征列转换对象,设置特征列数量
val featureIndexer: VectorIndexerModel = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(4)
.fit(data)
// 3.将随机百分之70作为训练集,其余为测试集
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
// 4.创建随机森林对象,设置标签列与特征列以及决策树的个数
val rf: RandomForestClassifier = new RandomForestClassifier()
.setLabelCol("indexedLabel")
.setFeaturesCol("indexedFeatures")
.setNumTrees(10)
// 5.设置预测列标签
val labelConverter: IndexToString = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labelsArray(0))
// 6.管道组装
val pipeline: Pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))
// 7.模型训练
val model: PipelineModel = pipeline.fit(trainingData)
// 8.模型预测
val predictions: DataFrame = model.transform(testData)
// 9.模型评估
predictions.select("predictedLabel", "label", "features").show()
// 10.创建错误率的计算对象
val evaluator: MulticlassClassificationEvaluator = new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy")
// 11.计算错误率
val accuracy: Double = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")
// 12.打印随机森林模型
val rfModel: RandomForestClassificationModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
println(s"Learned classification forest model:\n ${rfModel.toDebugString}")
}
}