一、CatBoost模型简介
1.1适用范围
CatBoost(Categorical Boosting)是一种基于梯度提升的机器学习算法,特别适用于处理具有类别特征的数据集。它可以用于分类、回归和排序任务,并且在处理具有大量类别特征的数据时表现优异。典型应用包括但不限于:
- 电子商务中的推荐系统
- 客户行为分析
- 财务风险评估
- 医疗数据分析
1.2原理
CatBoost使用梯度提升决策树(GBDT)作为其核心算法。其主要特点包括:
- 处理类别特征:CatBoost原生支持类别特征,并在内部使用目标编码(target encoding)来处理它们,从而减少了类别变量处理的复杂性。
- 顺序增强(Ordered Boosting):在构建每棵树时,CatBoost通过引入一种新的顺序提升方法来避免传统梯度提升中的预测偏差问题。
- 随机分片:为了进一步减少过拟合,CatBoost在每次树构建时随机分割数据集。
1.3优点
- 高效处理类别特征:无需复杂的预处理步骤。
- 减少过拟合:通过顺序增强和随机分片技术。
- 易于使用:内置了许多默认的优化参数,适合初学者和快速原型开发。
- 高性能:在许多实际应用中表现优于其他GBDT算法(如XGBoost和LightGBM)。
1.4缺点
- 模型训练时间较长:尽管有许多优化,训练时间可能比其他简单模型更长。
- 内存占用较高:在处理大规模数据时,内存需求较大。
二、实现CatBoost模型的Python代码
下面是一个使用CatBoost进行分类任务的完整Python代码示例,包含详细注释。
2.1导入必要的包和测试数据
import pandas as pd
from catboost import CatBoostClassifier, Pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
# 加载Titanic数据集
url = 'https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv'
data = pd.read_csv(url)
# 查看数据集的列名
print("Columns in the dataset:", data.columns)
2.2简单的数据预处理
# 简单的数据预处理
# 填充缺失值
# data['Age'].fillna(data['Age'].median(), inplace=True)
# data['Embarked'].fillna(data['Embarked'].mode()[0], inplace=True)
# 将Sex和Embarked转换为类别型特征
data['Sex'] = data['Sex'].astype('category')
# data['Pclass'] = data['Pclass'].astype('Pclass')
# 选择特征和目标
features = ['Pclass', 'Sex', 'Age', 'Siblings/Spouses Aboard', 'Parents/Children Aboard', 'Fare']
target = 'Survived'
X = data[features]
y = data[target]
2.3构建CatBoost模型
# 分割数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建CatBoost数据池
categorical_features = ['Sex', 'Pclass']
train_pool = Pool(X_train, y_train, cat_features=categorical_features)
test_pool = Pool(X_test, y_test, cat_features=categorical_features)
# 初始化并训练CatBoost分类器
model = CatBoostClassifier(
iterations=1000,
learning_rate=0.1,
depth=6,
loss_function='Logloss', # 二分类任务使用'Logloss'
verbose=100 # 每100次迭代打印一次信息
)
# 训练模型
model.fit(train_pool)
# 在测试集上进行预测
y_pred = model.predict(test_pool)
y_pred_proba = model.predict_proba(test_pool)[:, 1]
2.4模型评估
# 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
print(classification_report(y_test, y_pred))
模型评估输出结果如下 :
0: learn: 0.6538633 total: 159ms remaining: 2m 39s
100: learn: 0.2814504 total: 891ms remaining: 7.93s
200: learn: 0.2007734 total: 1.68s remaining: 6.68s
300: learn: 0.1536222 total: 2.45s remaining: 5.69s
400: learn: 0.1220845 total: 3.19s remaining: 4.77s
500: learn: 0.0961718 total: 3.95s remaining: 3.93s
600: learn: 0.0810769 total: 4.7s remaining: 3.12s
700: learn: 0.0694396 total: 5.45s remaining: 2.33s
800: learn: 0.0598153 total: 6.2s remaining: 1.54s
900: learn: 0.0527771 total: 6.93s remaining: 761ms
999: learn: 0.0474017 total: 7.67s remaining: 0us
Accuracy: 0.8033707865168539
precision recall f1-score support
0 0.84 0.85 0.84 111
1 0.74 0.73 0.74 67
accuracy 0.80 178
macro avg 0.79 0.79 0.79 178
weighted avg 0.80 0.80 0.80 178
Feature: Pclass, Importance: 16.480181005946406
Feature: Sex, Importance: 24.322199798316337
Feature: Age, Importance: 27.28642174968946
Feature: Siblings/Spouses Aboard, Importance: 5.125530737270014
Feature: Parents/Children Aboard, Importance: 3.006729091175773
Feature: Fare, Importance: 23.77893761760206
2.5可视化特征重要性(可选)
# 可视化特征重要性(可选)
plt.figure(figsize=(10, 6))
plt.barh(X.columns, feature_importances)
plt.xlabel('Feature Importance')
plt.title('CatBoost Feature Importances')
plt.show()
特征重要性输出结果如下:
2.6绘制混淆矩阵
# 绘制混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
绘制混淆矩阵输出结果如下:
2.7绘制ROC曲线
# 绘制ROC曲线
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='blue', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], color='gray', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve')
plt.legend(loc='lower right')
plt.show()
绘制ROC曲线输出结果如下: