《机器学习公式推导与代码实现》学习笔记,记录一下自己的学习过程,详细的内容请大家购买作者的书籍查阅。
CatBoost
CatBoost是俄罗斯搜索引擎巨头Yandex于2017年开源的一款GBDT计算框架,因能够高效处理数据中的类别特征而取名为CatBoost
(Categorical Boosting
)。
1 机器学习中类别特征的处理方法
CatBoost通过对常规的目标变量统计方法添加先验项来改进它们。除此之外CatBoost还考虑使用类别特征的不同组合来增加数据集特征维度。
对于特征取值数目较多的类别特征,一种折中的方法就是将类别数目重新归类,使其降到较少数目再进行one-hot编码。另一种常用的方法是目标变量统计
(target statistics, TS
),TS计算每个类别对于目标变量的期望值并将类别特征转换为新的数值特征。CatBoost在常规TS方法上做了改进。
2 CatBoost理论基础
CatBoost
算法框架的自身理论特色,包括用于处理类别变量的目标变量统计
、特征组合
和排序提升算法
。
2.1 目标变量统计
CatBoost
算法的设计初衷是为了更好的处理GBDT特征中的categorical features
。在处理 GBDT特征中的categorical features的时候,最简单的方法是用 categorical feature 对应的标签的平均值来替换。在决策树中,标签平均值将作为节点分裂的标准。这种方法被称为 Greedy Target-based Statistics
, 简称 Greedy TS
,用公式来表达就是:
x
^
k
i
=
∑
j
=
1
n
[
x
j
,
k
=
x
i
,
k
]
Y
i
∑
j
=
1
n
[
x
j
,
k
=
x
i
,
k
]
\hat{x}_{k}^{i} =\frac{\sum_{j=1}^{n}\left [ x_{j,k} =x_{i,k} \right ]Y_{i}}{\sum_{j=1}^{n} \left [ x_{j,k} =x_{i,k} \right ]}
x^ki=∑j=1n[xj,k=xi,k]∑j=1n[xj,k=xi,k]Yi
这种方法有一个显而易见的缺陷,就是通常特征比标签包含更多的信息,如果强行用标签的平均值来表示特征的话,当训练数据集和测试数据集数据结构和分布不一样的时候会出条件偏移问题。
一个标准的改进 Greedy TS的方式是添加先验分布项,这样可以减少噪声和低频率类别型数据对于数据分布的影响:
x
^
k
i
=
∑
j
=
1
p
−
1
[
x
σ
j
,
k
=
x
σ
p
,
k
]
Y
σ
j
+
α
p
∑
j
=
1
p
−
1
[
x
σ
j
,
k
=
x
σ
p
,
k
]
+
α
\hat{x}_{k}^{i} =\frac{\sum_{j=1}^{p-1}\left [ x_{\sigma _{j,k} } =x_{\sigma _{p,k} } \right ]Y_{\sigma _{j}} + \alpha p}{\sum_{j=1}^{p-1} \left [ x_{\sigma _{j,k} } =x_{\sigma _{p,k} } \right ]+\alpha }
x^ki=∑j=1p−1[xσj,k=xσp,k]+α∑j=1p−1[xσj,k=xσp,k]Yσj+αp
其中p是添加的先验项,α通常是大于0的权重系数。添加先验项是一个普遍做法,针对类别数较少的特征,它可以减少噪声数据。对于回归问题,一般情况下,先验项可取数据集label的均值。对于二分类,先验项是正例的先验概率。利用多个数据集排列也是有效的,但是,如果直接计算可能导致过拟合。
CatBoost利用了一个比较新颖的计算叶子节点值的方法,这种方式(oblivious trees
,对称树)可以避免多个数据集排列中直接计算会出现过拟合的问题。
2.2 特征组合
值得注意的是几个类别型特征的任意组合都可视为新的特征。例如,在音乐推荐应用中,我们有两个类别型特征:用户ID和音乐流派。如果有些用户更喜欢摇滚乐,将用户ID和音乐流派转换为数字特征时,根据上述这些信息就会丢失。
结合这两个特征就可以解决这个问题,并且可以得到一个新的强大的特征。然而,组合的数量会随着数据集中类别型特征的数量成指数增长,因此不可能在算法中考虑所有组合。
为当前树构造新的分割点时,CatBoost会采用贪婪的策略考虑组合。对于树的第一次分割,不考虑任何组合。对于下一个分割,CatBoost将当前树的所有组合、类别型特征与数据集中的所有类别型特征相结合,并将新的组合类别型特征动态地转换为数值型特征。
2.3 排序提升算法
对于学习预测偏移的内容,我提出了两个问题:
- 什么是预测偏移?
- 用什么办法解决预测偏移问题?
预测偏移(Prediction shift
)是由梯度偏差造成的。在GDBT的每一步迭代中, 损失函数使用相同的数据集求得当前模型的梯度, 然后训练得到基学习器, 但这会导致梯度估计偏差, 进而导致模型产生过拟合的问题。
CatBoost通过采用排序提升 (Ordered boosting
) 的方式替换传统算法中梯度估计方法,进而减轻梯度估计的偏差,提高模型的泛化能力。
CatBoost
采用对称树作为基分类器,对称意味着在树的同一层,分裂标准相同。对称树具有平衡、不易过拟合、能够大大缩短测试时间的特点。
3 CatBoost算法实现
作为与XGBoost和LightGBM齐名的Boosting算法,CatBoost有足够优秀的性能指标,尤其是对类别特征的处理。
import pandas as pd
data = pd.read_csv('./adult.data', header=None)
data
data.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race',
'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income'] # 变量重命名
data['income']
0 <=50K
1 <=50K
2 <=50K
3 <=50K
4 <=50K
...
32556 <=50K
32557 >50K
32558 <=50K
32559 <=50K
32560 >50K
Name: income, Length: 32561, dtype: object
data['income'] = data['income'].astype('category').cat.codes
data['income'].unique()
array([0, 1], dtype=int8)
from sklearn.model_selection import train_test_split
import catboost as cb
from sklearn.metrics import accuracy_score
X_train, X_test, y_train, y_test = train_test_split(data.drop(['income'], axis=1), data['income'], random_state=10, test_size=0.3)
clf = cb.CatBoostClassifier(eval_metric='AUC', depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.1)
cat_features_index = [1, 3, 5, 6, 7, 8, 9, 13] # 设置分类特征的索引,以便 CatBoost 能够正确地识别这些特征
clf.fit(X_train, y_train, cat_features=cat_features_index)
y_pred = clf.predict(X_test)
print(accuracy_score(y_pred, y_test))
0: total: 274ms remaining: 2m 16s
1: total: 337ms remaining: 1m 23s
2: total: 384ms remaining: 1m 3s
3: total: 434ms remaining: 53.8s
4: total: 485ms remaining: 48s
5: total: 558ms remaining: 45.9s
6: total: 596ms remaining: 41.9s
7: total: 642ms remaining: 39.5s
8: total: 676ms remaining: 36.9s
9: total: 712ms remaining: 34.9s
10: total: 748ms remaining: 33.3s
11: total: 782ms remaining: 31.8s
12: total: 816ms remaining: 30.6s
13: total: 854ms remaining: 29.6s
14: total: 896ms remaining: 29s
15: total: 941ms remaining: 28.4s
16: total: 981ms remaining: 27.9s
17: total: 1.02s remaining: 27.3s
18: total: 1.06s remaining: 26.8s
19: total: 1.1s remaining: 26.4s
20: total: 1.14s remaining: 26s
21: total: 1.18s remaining: 25.6s
22: total: 1.22s remaining: 25.2s
23: total: 1.25s remaining: 24.8s
24: total: 1.28s remaining: 24.4s
...
497: total: 18s remaining: 72.4ms
498: total: 18.1s remaining: 36.2ms
499: total: 18.1s remaining: 0us
0.8721465861398301
笔记本_Github地址