CART 决策树
CART决策树(Classification And Regression Tree),可以做为分类树也可以作为回归树。
什么是回归树?
在分类树中我们可以处理离散的数据(数据种类有限的数据)它输出的数据样本是数据的类别,而回归树可以对于连续的数值进行预测,也就是预测数据在那些区间内进行一个取值,他输出的是一个数值。
CART决策树原理
首先我们知道ID3算法是基于信息增益进行判断,而C4.5算法是在ID3基础上进行了改进,提出了信息增益率概念,其实CART与C4.5类似,只不过属性上采用了基尼系数。
基尼系数:经济学中的一个指标,是衡量国家收入的差异的常用指标,当基尼系数大于0.4的时候,说明财富差异悬殊。基尼系数在0.2-0.4之间说明财富差距不大。
基尼系数
t为节点,基尼系数计算公式则为:
p(Ck|t)表示的是节点t属于Ck的概率,节点t的基尼系数为1减Ck的概率平方和。
例如:
定义一个集合A: 6个人都吃饭
定义一个集合B: 3个人吃饭,3个人不吃饭.
集合A的基尼系数为:p(ck|t) = 1 , GINI(t)=1-1=0
集合B的基尼系数为:p(c1|t) = 0.5 p(c2|t)=0.5 , GINI(t)=1-(0.50.5+0.50.5)=0.5
这个两个基尼系数可以看出,集合一的比较稳定。
CART 算法来创建分类树
# encoding=utf-8
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris
# 准备数据集
iris=load_iris()
# 获取特征集和分类标识
features = iris.data
labels = iris.target
# 随机抽取33%的数据作为测试集,其余为训练集
train_features, test_features, train_labels, test_labels = train_test_split(features, labels, test_size=0.33, random_state=0)
# 创建CART分类树
clf = DecisionTreeClassifier(criterion='gini')
# 拟合构造CART分类树
clf = clf.fit(train_features, train_labels)
# 用CART分类树做预测
test_predict = clf.predict(test_features)
# 预测结果与测试集结果作比对
score = accuracy_score(test_labels, test_predict)
print("CART分类树准确率 %.4lf" % score)
如何使用 CART 回归树做预测
# encoding=utf-8
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston
from sklearn.metrics import r2_score,mean_absolute_error,mean_squared_error
from sklearn.tree import DecisionTreeRegressor
# 准备数据集
boston=load_boston()
# 探索数据
print(boston.feature_names)
# 获取特征集和房价
features = boston.data
prices = boston.target
# 随机抽取33%的数据作为测试集,其余为训练集
train_features, test_features, train_price, test_price = train_test_split(features, prices, test_size=0.33)
# 创建CART回归树
dtr=DecisionTreeRegressor()
# 拟合构造CART回归树
dtr.fit(train_features, train_price)
# 预测测试集中的房价
predict_price = dtr.predict(test_features)
# 测试集的结果评价
print('回归树二乘偏差均值:', mean_squared_error(test_price, predict_price))
print('回归树绝对值偏差均值:', mean_absolute_error(test_price, predict_price))