CART(Classification and Regression Tree)
CART(分类与回归树)是一种用于分类和回归任务的决策树算法,提出者为 Breiman 等人。它的核心思想是通过二分法递归地将数据集划分为子集,从而构建一棵树。CART 算法既可以生成分类树,也可以生成回归树。
1. CART 的特点
- 二叉树结构:CART 始终生成二叉树,每个节点只有两个分支(左子树和右子树)。
- 分裂标准不同:
- 对于分类任务,CART 使用**基尼指数(Gini Index)**作为分裂标准。
- 对于回归任务,CART 使用**最小均方误差(MSE)**作为分裂标准。
- 支持剪枝:通过后剪枝减少过拟合。
- 处理连续和离散数据:支持连续特征的划分点选择。
2. CART 的基本流程
- 输入:训练数据集 D,目标变量类型(分类或回归)。
- 递归分裂:
- 按照基尼指数(分类)或均方误差(回归)选择最佳划分点。
- 对数据集划分为两个子集,递归构造子树。
- 停止条件:
- 节点样本数量小于阈值。
- 划分后不再能显著降低误差。
- 剪枝:
- 通过校验集性能优化,剪去不显著的分支。
- 输出:最终的二叉决策树。
3. 分类树
(1) 基尼指数
基尼指数(Gini Index)用于衡量一个节点的“纯度”,越小表示越纯:
其中:
- :类别 k 的样本数量。
- K:类别的总数。
节点分裂的基尼指数计算为:
最佳划分点是使 最小的特征和对应的划分点。
(2) 示例:分类树
数据集:
天气 | 温度 | 湿度 | 风力 | 是否运动 |
---|---|---|---|---|
晴天 | 30 | 高 | 弱 | 否 |
晴天 | 32 | 高 | 强 | 否 |
阴天 | 28 | 高 | 弱 | 是 |
雨天 | 24 | 正常 | 弱 | 是 |
雨天 | 20 | 正常 | 强 | 否 |
-
计算每个特征的基尼指数:
- 对离散特征(如天气),分别计算不同类别划分后的基尼指数。
- 对连续特征(如温度),尝试所有划分点,计算每个划分点的基尼指数。
-
选择最优特征和划分点:
- 选择基尼指数最小的划分点。
-
生成子树:
- 对每个子集递归分裂,直到满足停止条件。
4. 回归树
(1) 分裂标准
对于回归任务,CART 使用**均方误差(MSE)**作为分裂标准:
其中:
- :第 i 个样本的目标值。
- :节点中所有样本目标值的均值。
节点分裂的误差计算为:
最佳划分点是使 最小的特征和对应的划分点。
(2) 示例:回归树
假设我们有如下数据集(目标值为房价):
面积(平方米) | 房价(万元) |
---|---|
50 | 150 |
60 | 180 |
70 | 210 |
80 | 240 |
90 | 270 |
-
尝试划分点:
- 例如,划分点为 656565。
- 左子集:{50,60},右子集:{70, 80, 90}。
-
计算误差:
- 左子集的均值:。
- 右子集的均值:。
- 计算分裂后的总均方误差。
-
选择最佳划分点:
- 选择误差最小的划分点,继续构造子树。
5. 剪枝
CART 使用后剪枝来防止过拟合:
-
生成完全生长的决策树。
-
计算子树的损失函数:
其中:
- :第 i 个叶子节点。
- :叶子节点的数量。
- α:正则化参数,控制树的复杂度。
-
剪去对验证集性能提升不大的分支。
6. CART 的优缺点
优点
- 生成二叉树,逻辑清晰,易于实现。
- 支持分类和回归任务。
- 支持连续特征和缺失值处理。
- 剪枝机制增强了泛化能力。
缺点
- 易受数据噪声影响,可能生成复杂的树。
- 对高维数据表现一般,无法处理稀疏特征。
- 生成的边界是轴对齐的,可能不适用于复杂分布。
7. 与其他决策树算法的比较
特点 | ID3 | C4.5 | CART |
---|---|---|---|
划分标准 | 信息增益 | 信息增益比 | 基尼指数 / MSE |
支持连续特征 | 否 | 是 | 是 |
树结构 | 多叉树 | 多叉树 | 二叉树 |
剪枝 | 无 | 后剪枝 | 后剪枝 |
应用 | 分类 | 分类 | 分类与回归 |
8. 代码实现
以下是一个简单的 CART 分类树实现:
import numpy as np
# 计算基尼指数
def gini_index(groups, classes):
total_samples = sum(len(group) for group in groups)
gini = 0.0
for group in groups:
size = len(group)
if size == 0:
continue
score = 0.0
for class_val in classes:
proportion = [row[-1] for row in group].count(class_val) / size
score += proportion ** 2
gini += (1 - score) * (size / total_samples)
return gini
# 划分数据集
def split_data(data, index, value):
left, right = [], []
for row in data:
if row[index] < value:
left.append(row)
else:
right.append(row)
return left, right
# 示例数据
dataset = [
[2.771244718, 1.784783929, 0],
[1.728571309, 1.169761413, 0],
[3.678319846, 2.81281357, 0],
[3.961043357, 2.61995032, 0],
[2.999208922, 2.209014212, 1],
]
# 计算基尼指数
split = split_data(dataset, 0, 2.5)
gini = gini_index(split, [0, 1])
print("基尼指数:", gini)
输出结果
基尼指数: 0.30000000000000004
CART 是机器学习中非常经典的算法,同时也是随机森林、梯度提升决策树等模型的基础。