1 决策树算法介绍
决策树(Decision Tree,又称为判定树)算法是机器学习中常见的一类算法,是一种以树结构形式表达的预测分析模型。决策树属于监督学习(Supervised learning),根据处理数据类型的不同,决策树又为分类决策树与回归决策树。最早的的决策树算法是由Hunt等人于1966年提出,Hunt算法是许多决策树算法的基础,包括ID3、C4.5和CART等。
1.1 决策树原理
决策树的主要思想是根据已知数据构建一棵树,通过对待分类或回归的样本进行逐步的特征判断,最终将其分类或回归至叶子节点。
决策树学习的关键在于如何选择最优的划分属性,所谓的最优划分属性,对于二元分类而言,就是尽量使划分的样本属于同一类别,即“纯度”最高的属性。那么如何来度量特征(features)的纯度,这时候就要用到“信息熵(information entropy)”。先来看看信息熵的定义:假如当前样本集D中第k类样本所占的比例为Pk(k=1,2,3,......|y|),|y|为类别的总数(对于二元分类来说,|y|=2)。则样本集的信息熵为:
Ent(D)的值越小,则D的纯度越高。再来看一个概念信息增益(information gain),假定离散属性有V个可能的取值{a1,a2,......av},如果使用特征a来对数据集D进行划分,则会产生V个分支结点, 其中第v(小v)个结点包含了数据集D中所有在特征a上取值为av的样本总数,记为Dv。因此可以根据上面信息熵的公式计算出信息熵,再考虑到不同的分支结点所包含的样本数量不同,给分支节点赋予权重|Dv|/|D|,从这个公式能够发现包含的样本数越多的分支结点的影响越大(这样是信息增益的一个缺点,即信息增益对可取值数目多的特征有偏好(即该属性能取得值越多,信息增益,越偏向这个),待会后面会举例子说明,这个缺点可以用信息增益率来解决)。因此,能够计算出特征a对样本集D进行划分所获得的“信息增益”:
一般而言,信息增益越大,则表示使用特征a对数据集划分所获得的“纯度提升”越大。所以信息增益可以用于决策树划分属性的选择,其实就是选择信息增益最大的属性,ID3算法就是采用的信息增益来划分属性。
1.2 决策树中的关键概念
- 节点:决策树由许多节点组成,其中分为两种类型:内部节点和叶子节点。内部节点表示某个特征,而叶子节点表示某个类别或回归值。
- 分支:每个内部节点连接着一些分支,每个分支代表着某个特征取值,表示样本在该特征上的取值。
- 根节点:决策树的根节点是整棵树的起点,即第一个判断特征的节点。
- 决策规则:决策树的每个节点表示一条决策规则,即对某个特征的判断,其决策规则是由已知数据集计算而得的。
- 剪枝:决策树为了避免过度拟合(Overfitting),通常会在构建好树之后进行剪枝操作,即去掉一些决策规则,以避免模型过于复杂。
1.3 决策树的实现流程
- 数据预处理:将原始数据转换成决策树可处理的数据格式。对于分类问题,需要将类别变量编码为数字;对于连续变量,需要进行离散化处理。
- 特征选择:从所有可用的特征中选择一个最佳的特征,用于划分数据集。常用的特征选择算法有信息增益、信息增益比和基尼指数等。
- 决策树的构建:将数据集递归地划分成越来越小的子集,直到数据集不能再被划分,或者达到预定的停止条件。在每个节点上选择最佳的特征,将数据集划分成两个子集,并在子集上递归地执行此过程,直到子集不能再被划分。
- 剪枝:为了避免过拟合,可以在构建完整棵树之后,对决策树进行剪枝。剪枝分为预剪枝和后剪枝两种方法。预剪枝是在决策树构建过程中,通过限制树的深度、节点数或其他条件,避免过拟合。后剪枝是在决策树构建完成之后,通过对子树进行剪枝,来减少决策树的复杂度。
- 决策树的评估:通过测试集或交叉验证等方法,对决策树的性能进行评估。评估指标包括准确率、精确率、召回率、F1分数等。
- 预测:将待预测的样本依次从根节点开始进行判断,按照决策规则向下移动,直到到达某个叶子节点,将样本归类于该叶子节点所代表的类别。
1.4 决策树的划分选择
熵:物理意义是体系混乱程度的度量。
信息熵:表示事物不确定性的度量标准,可以根据数学中的概率计算,出现的概率就大,出现的机会就多,不确定性就小(信息熵小)。
(1)信息增益(ID3使用的划分方式)
假设训练数据集D和特征A,根据如下步骤计算信息增益:
第一步:计算数据集D的经验熵:
其中,|Ck|为第k类样本的数目,|D|为数据集D的数目。
第二步:计算特征A对数据集D的经验条件熵H(D|A):
第三步:计算信息增益:
一般而言,信息增益越大,则意味着使用属性A来进行划分所获得的“纯度提升” 越大。因此,我们可使用信息增益来进行决策树的划分属性选择。ID3决策树学习算法就是以信息增益为准则来选择划分属性的。
(2)信息增益率(C4.5所用划分准则)
特征A对于数据集D的信息增益比定义为:
其中,
称为数据集D关于A的取值熵。
增益率准则就可取值数目较少的属性有所偏好,因此,C4.5算法并不是直接选择增益率最大的候选划分属性,而是使用了一个启发式:先从候选划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。
(3)基尼指数
分类问题中,假设有K个类,样本点属于k的概率Pk,则概率分布的基尼指数:
二分类问题:
对给定的样本集合D,基尼指数:
CART决策树使用“基尼指数”来选择划分属性。数据集D的纯度可用基尼值来度量,Gini(D)越小,则数据集的纯度越高。CART生成的是二叉树,计算量相对来说不是很大,可以处理连续和离散变量,能够对缺失值进行处理。
1.5 典型的决策树算法
- ID3算法:ID3(Iterative Dichotomiser 3)算法是决策树算法中最早的一种,使用信息增益来选择最优特征。ID3算法基于贪心思想,一直选择当前最优的特征进行分割,直到数据集分割完成或没有特征可分割为止。
- C4.5算法:C4.5算法是ID3算法的改进版,使用信息增益比来选择最优特征。C4.5算法对ID3算法中存在的问题进行了优化,包括处理缺失值、处理连续值等。
- CART算法:CART(Classification and Regression Trees)算法是一种基于基尼不纯度的二叉树结构分类算法,用于解决二分类和回归问题。CART算法可以处理连续值和离散值的特征,能够生成二叉树结构,具有较好的可解释性。
- CHAID算法:CHAID(Chi-square Automatic Interaction Detection)算法是一种基于卡方检验的决策树算法,用于处理分类问题。CHAID算法能够处理多分类问题,不需要预先对特征进行处理。
- MARS算法:MARS(Multivariate Adaptive Regression Splines)算法是一种基于样条插值的决策树算法,用于回归问题。MARS算法能够处理连续值和离散值的特征,能够生成非二叉树结构,具有较好的拟合能力。
分类决策树可用于处理离散型数据,回归决策树可用于处理连续型数据。最常见的决策树算法包含ID3算法、C4.5算法以及CART算法。
1.6 决策树的剪枝
剪枝:顾名思义就是给决策树 "去掉" 一些判断分支,同时在剩下的树结构下仍然能得到不错的结果。之所以进行剪枝,是为了防止或减少 "过拟合现象" 的发生,是决策树具有更好的泛化能力。
具体做法:去掉过于细分的叶节点,使其回退到父节点,甚至更高的节点,然后将父节点或更高的叶节点改为新的叶节点。
剪枝的两种方法:
- 预剪枝:在决策树构造时就进行剪枝。在决策树构造过程中,对节点进行评估,如果对其划分并不能再验证集中提高准确性,那么该节点就不要继续王下划分。这时就会把当前节点作为叶节点。
- 后剪枝:在生成决策树之后再剪枝。通常会从决策树的叶节点开始,逐层向上对每个节点进行评估。如果剪掉该节点,带来的验证集中准确性差别不大或有明显提升,则可以对它进行剪枝,用叶子节点来代填该节点。
注意:决策树的生成只考虑局部最优,相对地,决策树的剪枝则考虑全局最优。
2 决策树算法的优缺点
2.1 决策树算法的优点:
-
简单易于理解和解释:决策树算法的模型可以直观地表示为树形结构,易于理解和解释,可以帮助人们做出可靠的决策。
-
适用于多类型的数据:决策树算法可以处理包含分类和数值型特征的数据,也可以用于处理多分类问题。
-
可以处理缺失值和异常值:决策树算法可以自动处理缺失值和异常值,而无需进行额外的数据预处理。
-
能够处理高维数据:决策树算法可以有效地处理具有大量特征的数据集,而不需要进行特征选择或降维。
-
计算复杂度较低:决策树算法的训练和预测的计算复杂度相对较低,尤其适用于大规模数据集。
2.2 决策树算法的缺点:
-
容易过拟合:决策树算法容易在训练数据上过拟合,导致在新数据上的泛化能力较差。可以通过剪枝等技术来减少过拟合的风险。
-
对输入数据的微小变化敏感:决策树算法对输入数据的微小变化非常敏感,这可能导致不稳定的预测结果。
-
忽略属性之间的相关性:决策树算法假设每个特征在判断类别时都是独立的,忽略了特征之间的相关性。这在某些情况下可能导致模型的性能下降。
-
可能产生不一致的分类结果:决策树算法对于某些数据集,可能会生成不一致的分类结果,这是由于训练数据的微小变化可能导致完全不同的树结构。
-
难以处理连续性特征:决策树算法对于连续性特征的处理相对困难,需要进行离散化或采用其他处理方法。
决策树算法的优缺点在实际应用中可能会受到具体问题和数据集的影响,因此在选择算法时需要根据具体情况进行综合考虑。
3 决策树算法的应用场景
决策树算法在机器学习和数据挖掘领域有广泛的应用场景,包括但不限于以下几个方面:
-
分类问题:决策树算法可以用于解决分类问题,例如根据一组特征将实例分为不同的类别。它在垃圾邮件过滤、疾病诊断、客户分类等领域都有应用。
-
回归问题:决策树算法也可以用于解决回归问题,例如根据一组特征预测一个连续的数值。它在房价预测、销量预测等领域可以发挥作用。
-
特征选择:决策树算法可以通过特征选择的方式确定哪些特征对于解决问题是最重要的。通过分析决策树的结构和特征重要性,可以帮助我们理解数据集并选择最相关的特征。
-
缺失值处理:决策树算法可以有效地处理带有缺失值的数据集。在决策树的构建过程中,可以根据可用的特征进行分割,从而处理缺失值问题。
-
异常检测:决策树算法可以用于检测异常值。通过构建决策树,可以将异常值划分为与正常值不同的分支,从而帮助我们发现异常情况。
-
推荐系统:决策树算法可以用于构建个性化的推荐系统。通过分析用户的特征和历史行为,决策树可以推断出用户的兴趣和偏好,进而提供个性化的推荐内容。
需要注意的是,决策树算法适用于特征具有离散值或连续值的数据集。同时,决策树算法也有其局限性,对于特征关联性较强的数据集可能表现不佳。因此,在选择算法时需要综合考虑数据集的特点和问题的需求。
4 决策树的代码实现
4.1 使用决策树实现iris数据集分类
(1)数据集介绍
Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)三种中的哪一品种。
数据内容:
sepal_len sepal_wid petal_len petal_wid label
0 5.1 3.5 1.4 0.2 0
1 4.9 3.0 1.4 0.2 0
2 4.7 3.2 1.3 0.2 0
3 4.6 3.1 1.5 0.2 0
4 5.0 3.6 1.4 0.2 0
.. ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2
146 6.3 2.5 5.0 1.9 2
147 6.5 3.0 5.2 2.0 2
148 6.2 3.4 5.4 2.3 2
149 5.9 3.0 5.1 1.8 2
(2)代码实现
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 构建决策树
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)
# 预测
y_pred = clf.predict(X_test)
# 评估性能
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy: ", accuracy)
(3)分类结果
Accuracy: 1.0
Process finished with exit code 0
达到了100%的准确率
4.2 使用决策树对带噪声的正选曲线进行回归
(1)代码实现
import numpy as np
from sklearn.tree import DecisionTreeRegressor
import matplotlib.pyplot as plt
rng = np.random.RandomState(1)
X = np.sort(5*rng.rand(80,1),axis = 0)
y = np.sin(X).ravel()
y[::5] += 3 * (0.5 - rng.rand(16))
# Fit regression model
regr_1 = DecisionTreeRegressor(max_depth=2)
regr_2 = DecisionTreeRegressor(max_depth=5)
regr_1.fit(X, y)
regr_2.fit(X, y)
# Predict
X_test = np.arange(0.0, 5.0, 0.01)[:, np.newaxis]
y_1 = regr_1.predict(X_test)
y_2 = regr_2.predict(X_test)
# Plot the results
plt.figure()
plt.scatter(X, y, s=20, edgecolor="black",
c="darkorange", label="data")
plt.plot(X_test, y_1, color="cornflowerblue",
label="max_depth=2", linewidth=2)
plt.plot(X_test, y_2, color="yellowgreen", label="max_depth=5", linewidth=2)
plt.xlabel("data")
plt.ylabel("target")
plt.title("Decision Tree Regression")
plt.legend()
plt.show()
(2)结果展示