文章目录
- 什么是决策树?
- 决策树的优缺点
- 决策树示例——鸢尾花分类
什么是决策树?
决策树及其集成是分类和回归机器学习任务的流行方法。决策树被广泛使用,因为它们易于解释,处理分类特征,扩展到多类分类设置,不需要特征缩放,并且能够捕获非线性和特征相互作用。随机森林和增强算法等树集成算法在分类和回归任务中表现最佳。
常应用于以下类型的场景:
- 预测用户贷款是否能够按时还款;
- 预测邮件是否是垃圾邮件;
- 预测用户是否会购买某件商品等等
官网:分类和回归
决策树的优缺点
优点:
-
决策树算法易理解,机理解释起来简单。
-
决策树算法可以用于小数据集。
-
决策树算法的时间复杂度较小,为用于训练决策树的数据点的对数。
-
相比于其他算法智能分析一种类型变量,决策树算法可处理数字和数据的类别。
-
能够处理多输出的问题。
-
对缺失值不敏感。
-
可以处理不相关特征数据。
-
效率高,决策树只需要一次构建,反复使用,每一次预测的最大计算次数不超过决策树的深度。
缺点:
-
对连续性的字段比较难预测。
-
容易出现过拟合。
-
当类别太多时,错误可能就会增加的比较快。
-
在处理特征关联性比较强的数据时表现得不是太好。
-
对于各类别样本数量不一致的数据,在决策树当中,信息增益的结果偏向于那些具有更多数值的特征。
参考博客:决策树算法优缺点
决策树示例——鸢尾花分类
数据集下载:
链接:
https://pan.baidu.com/s/1AshgNxx1wOWhLgKxgjrZww?pwd=lz3l
提取码:
lz3l
数据集介绍:
iris.data
数据集中共有五个字段,逗号分隔,前四个为特征字段,最后一个为标签字段。
标签字段列一共有三种值,分别是:Iris-setosa
、Iris-versicolor
、Iris-virginica
。
将数据集中的随机百分之70作为训练集,剩余的作为测试集。
需求实现:
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
object Iris {
// TODO 鸢尾花种类判断
def main(args: Array[String]): Unit = {
val sc: SparkSession = SparkSession
.builder()
.appName("Iris")
.master("local[*]")
.getOrCreate()
// 1.加载鸢尾花数据
val train_data: RDD[String] = sc
.read
.textFile("iris.data")
.rdd
// 2.将随机百分之70的数据设置为训练集,其余为测试集
val data: Array[RDD[String]] = train_data.randomSplit(Array(0.7, 0.3))
// 3.向量转换
import sc.implicits._
val trainDF: DataFrame = data(0).map(lines => {
val arr: Array[String] = lines.split(",")
LabeledPoint(
if (arr(4).equals("Iris-setosa")) {
1D
} else if (arr(4).equals("Iris-versicolor")) {
2D
} else {
3D
},
Vectors.dense(arr.take(4).map(_.toDouble))
)
}).toDF("label", "features")
// 4.创建决策树对象
val classifier = new DecisionTreeClassifier()
// 设置最大深度、分支、质量、特征列
classifier.setMaxDepth(5).setMaxBins(32).setImpurity("gini").setFeaturesCol("features")
// 5.训练模型
val model: DecisionTreeClassificationModel = classifier.fit(trainDF)
// 打印模型
println(model.toDebugString)
// 6.将测试集转换成向量
val testDF: DataFrame = data(1).map(lines => {
val arr: Array[String] = lines.split(",")
LabeledPoint(
if (arr(4).equals("Iris-setosa")) {
1D
} else if (arr(4).equals("Iris-versicolor")) {
2D
} else {
3D
},
Vectors.dense(arr.take(4).map(_.toDouble))
)
}).toDF("label", "features")
// 7.模型预测
val result: DataFrame = model.transform(testDF.select("label", "features"))
// 8.模型预测评估
result.select("label", "features","prediction").show(100)
// 9.计算错误率
val error: Double = result.where("label = prediction").count.toDouble/result.count
println("错误率为:"+(1-error))
}
}