决策树
1.概述
1.1 决策树是如何工作
决策树能够从一系列有特征和标签的数据中总结出决策规则,并且使用树状图的结构来表现,从而达到解决回归和分类问题。通俗的说,就是我们只需要问一系列问题就可以对数据进行分类。
核心要解决的问题:
1. 如何从数据中找出最佳节点和最佳分支。
2. 如何让决策树停止生长,防止过拟合。
1.2 sklearn中的决策树(分类树)
sklearn官网:http://scikit-learn.org/stable/index.html
1.2.1 sklearn中决策树的类都在”tree“这个模块之下。
简明介绍:
- 高随机版本分类树
- 高随机版本回归树
- 将生成的决策树导出为DOT格式,画图专用
1.2.2 sklearn的基本建模流程
from sklearn import tree #导入需要的模块
clf = tree.DecisionTreeClassifier() #实例化
clf = clf.fit(X_train,y_train) #用训练集数据训练模型
result = clf.score(X_test,y_test) #导入测试集,从接口中调用需要的信息
1.2.3 sklearn中重要参数的介绍
criterion
对于分类树来说,衡量一棵树的好坏的指标叫做‘不纯度’。通常来说,不纯度越低,这颗树就越好。
不纯度的计算方式有两种:
(1)信息熵(Entropy)
(2)基尼系数(Gini Impurity)
其中t代表节点,i代表标签的任意分类,p(i|t)代表标签i在节点t上的比例。当使用信息熵时,sklearn实际计算的是基于信息熵的信息增益(Information Gain),即父节点的信息熵和子节点的信息熵之差。
决策树的基本流程概括为:
(1)计算全部特征的不纯度指标
(2)选取不纯度指标最优的特征来进行分支
(3)除被选之外的特征中继续计算不纯度指标
(4)选取不纯度最低的特征来进行分支
splitter
用来控制决策树中的随机选项的。
(1)best:决策树在分枝时虽然随机,但是还是会优先选择更重要的特征进行分枝。
(2)random:决策树在分枝时会更加随机。
random_state
random_state用来设置分枝中的随机模式的参数。输入任意整数,使得长出相同的树,使得模型更加稳定。
1.2.4 剪枝参数的介绍
max_depth
限制树的最大深度,超过设定深度的树枝全部剪掉。
min_samples_leaf
一个节点在分支后的每一个字节点都必须要包含至少min_samples_leaf训练样本,否则分支不会发生。
min_samples_split
min_samples_split限定,一个节点必须要包含至少min_samples_split个训练样本,这个节点才允许被分枝,否则分枝就不会发生。
max_features & min_impurity_decrease
max_features限制分枝时考虑的特征个数,超过限制个数的特征都会被舍弃。
min_impurity_decrease限制信息增益的大小,信息增益小于设定数值的分枝不会发生。
class_weight
完成样本标签平衡的参数。
min_weight_fraction_leaf
有了权重之后,样本量就不再是单纯地记录数目,而是受输入的权重影响了,因此这时候剪枝,就需要搭配min_weight_fraction_leaf这个基于权重的剪枝参数来使用。
1.2.4 重要属性和接口
feature_importances_
能够查看各个特征对模型的重要性。
1.3 sklearn中的决策树(回归树)
1.3.1 重要参数,属性以及接口
回归树衡量分支的指标有三种:
(1)输入"mse"使用均方误差mean squared error(MSE。L2损失),父节点和叶子节点之间的均方误差的差额将被用来作为特征选择的标准。这种指标使用叶子节点的平均值来最小化L1损失。
(2)输入“friedman_mse”使用费尔德曼均方误差。
(3)输入"mae"使用绝对平均误差MAE(mean absolute error),这种指标使用叶子节点的中值来最小化L1损失。
在回归树中,MSE不只是我们的分枝质量衡量指标,也是我们最常用的衡量回归树回归质量的指标,同时交叉验证的指标也是MSE。但是回归树的借口score是R平方。
虽然均方误差永远为正,但是sklearn当中使用均方误差作为评判标准时,却是计算”负均方误差“(neg_mean_squared_error)。这是因为sklearn在计算模型评估指标的时候,会考虑指标本身的性质,均方误差本身是一种误差,所以被sklearn划分为模型的一种损失(loss),因此在sklearn当中,都以负数表示。真正的均方误差MSE的数值,其实就是neg_mean_squared_error
去掉负号的数字。