1. XGBoost 算法
1.1 XGBoost 的介绍
XGBoost是2016年由华盛顿大学陈天奇老师带领开发的一个可扩展机器学习系统。严格意义上讲XGBoost并不是一种模型,而是一个可供用户轻松解决分类、回归或排序问题的软件包。它内部实现了梯度提升树(GBDT)模型,并对模型中的算法进行了诸多优化,在取得高精度的同时又保持了极快的速度,在一段时间内成为了国内外数据挖掘、机器学习领域中的大规模杀伤性武器。
更重要的是,XGBoost在系统优化和机器学习原理方面都进行了深入的考虑。毫不夸张的讲,XGBoost提供的可扩展性,可移植性与准确性推动了机器学习计算限制的上限,该系统在单台机器上运行速度比当时流行解决方案快十倍以上,甚至在分布式系统中可以处理十亿级的数据。
XGBoost的主要优点:
- 简单易用。相对其他机器学习库,用户可以轻松使用XGBoost并获得相当不错的效果。
- 高效可扩展。在处理大规模数据集时速度快效果好,对内存等硬件资源要求不高。
- 鲁棒性强。相对于深度学习模型不需要精细调参便能取得接近的效果。
- XGBoost内部实现提升树模型,可以自动处理缺失值。
XGBoost的主要缺点:
- 相对于深度学习模型无法对时空位置建模,不能很好地捕获图像、语音、文本等高维数据。
- 在拥有海量训练数据,并能找到合适的深度学习模型时,深度学习的精度可以遥遥领先XGBoost。
1.2 XGboost的应用
XGBoost在机器学习与数据挖掘领域有着极为广泛的应用。据统计在2015年Kaggle平台上29个获奖方案中,17只队伍使用了XGBoost;在2015年KDD-Cup中,前十名的队伍均使用了XGBoost,且集成其他模型比不上调节XGBoost的参数所带来的提升。这些实实在在的例子都表明,XGBoost在各种问题上都可以取得非常好的效果。
同时,XGBoost还被成功应用在工业界与学术界的各种问题中。例如商店销售额预测、高能物理事件分类、web文本分类;用户行为预测、运动检测、广告点击率预测、恶意软件分类、灾害风险预测、在线课程退学率预测。虽然领域相关的数据分析和特性工程在这些解决方案中也发挥了重要作用,但学习者与实践者对XGBoost的一致选择表明了这一软件包的影响力与重要性。
2. 基于天气数据集的 XGBoost 分类实战
在实践的最开始,我们首先需要导入一些基础的函数库包括:numpy (Python进行科学计算的基础软件包),pandas(pandas是一种快速,强大,灵活且易于使用的开源数据分析和处理工具),matplotlib和seaborn绘图。
#导入需要用到的数据集
!wget https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/7XGBoost/train.csv
'wget' 不是内部或外部命令,也不是可运行的程序
或批处理文件。
2.1 函数库导入
## 基础函数库
import numpy as np
import pandas as pd
## 绘图函数库
import matplotlib.pyplot as plt
import seaborn as sns
本次我们选择天气数据集进行方法的尝试训练,现在有一些由气象站提供的每日降雨数据,我们需要根据历史降雨数据来预测明天会下雨的概率。样例涉及到的测试集数据test.csv与train.csv的格式完全相同,但其RainTomorrow未给出,为预测变量。
数据的各个特征描述如下:
2.2 数据读取/载入
## 我们利用Pandas自带的read_csv函数读取并转化为DataFrame格式
data = pd.read_csv('train.csv')
2.3 数据信息简单查看
## 利用.info()查看数据的整体信息
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 106644 entries, 0 to 106643
Data columns (total 23 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Date 106644 non-null object
1 Location 106644 non-null object
2 MinTemp 106183 non-null float64
3 MaxTemp 106413 non-null float64
4 Rainfall 105610 non-null float64
5 Evaporation 60974 non-null float64
6 Sunshine 55718 non-null float64
7 WindGustDir 99660 non-null object
8 WindGustSpeed 99702 non-null float64
9 WindDir9am 99166 non-null object
10 WindDir3pm 103788 non-null object
11 WindSpeed9am 105643 non-null float64
12 WindSpeed3pm 104653 non-null float64
13 Humidity9am 105327 non-null float64
14 Humidity3pm 103932 non-null float64
15 Pressure9am 96107 non-null float64
16 Pressure3pm 96123 non-null float64
17 Cloud9am 66303 non-null float64
18 Cloud3pm 63691 non-null float64
19 Temp9am 105983 non-null float64
20 Temp3pm 104599 non-null float64
21 RainToday 105610 non-null object
22 RainTomorrow 106644 non-null object
dtypes: float64(16), object(7)
memory usage: 18.7+ MB
## 进行简单的数据查看,我们可以利用 .head() 头部.tail()尾部
data.head()
Date | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2012/1/19 | MountGinini | 12.1 | 23.1 | 0.0 | NaN | NaN | W | 30.0 | N | ... | 60.0 | 54.0 | NaN | NaN | NaN | NaN | 17.0 | 22.0 | No | No |
1 | 2015/4/13 | Nhil | 10.2 | 24.7 | 0.0 | NaN | NaN | E | 39.0 | E | ... | 63.0 | 33.0 | 1021.9 | 1017.9 | NaN | NaN | 12.5 | 23.7 | No | Yes |
2 | 2010/8/5 | Nuriootpa | -0.4 | 11.0 | 3.6 | 0.4 | 1.6 | W | 28.0 | N | ... | 97.0 | 78.0 | 1025.9 | 1025.3 | 7.0 | 8.0 | 3.9 | 9.0 | Yes | No |
3 | 2013/3/18 | Adelaide | 13.2 | 22.6 | 0.0 | 15.4 | 11.0 | SE | 44.0 | E | ... | 47.0 | 34.0 | 1025.0 | 1022.2 | NaN | NaN | 15.2 | 21.7 | No | No |
4 | 2011/2/16 | Sale | 14.1 | 28.6 | 0.0 | 6.6 | 6.7 | E | 28.0 | NE | ... | 92.0 | 42.0 | 1018.0 | 1014.1 | 4.0 | 7.0 | 19.1 | 28.2 | No | No |
5 rows × 23 columns
这里我们发现数据集中存在NaN,一般的我们认为NaN在数据集中代表了缺失值,可能是数据采集或处理时产生的一种错误。这里我们采用-1将缺失值进行填补,还有其他例如“中位数填补、平均数填补”的缺失值处理方法有兴趣的同学也可以尝试。
data = data.fillna(-1)
data.tail()
Date | Location | MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustDir | WindGustSpeed | WindDir9am | ... | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | RainToday | RainTomorrow | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
106639 | 2011/5/23 | Launceston | 10.1 | 16.1 | 15.8 | -1.0 | -1.0 | SE | 31.0 | NNW | ... | 99.0 | 86.0 | 999.2 | 995.2 | -1.0 | -1.0 | 13.0 | 15.6 | Yes | Yes |
106640 | 2014/12/9 | GoldCoast | 19.3 | 31.7 | 36.0 | -1.0 | -1.0 | SE | 80.0 | NNW | ... | 75.0 | 76.0 | 1013.8 | 1010.0 | -1.0 | -1.0 | 26.0 | 25.8 | Yes | Yes |
106641 | 2014/10/7 | Wollongong | 17.5 | 22.2 | 1.2 | -1.0 | -1.0 | WNW | 65.0 | WNW | ... | 61.0 | 56.0 | 1008.2 | 1008.2 | -1.0 | -1.0 | 17.8 | 21.4 | Yes | No |
106642 | 2012/1/16 | Newcastle | 17.6 | 27.0 | 3.0 | -1.0 | -1.0 | -1 | -1.0 | NE | ... | 68.0 | 88.0 | -1.0 | -1.0 | 6.0 | 5.0 | 22.6 | 26.4 | Yes | No |
106643 | 2014/10/21 | AliceSprings | 16.3 | 37.9 | 0.0 | 14.2 | 12.2 | ESE | 41.0 | NNE | ... | 8.0 | 6.0 | 1017.9 | 1014.0 | 0.0 | 1.0 | 32.2 | 35.7 | No | No |
5 rows × 23 columns
## 利用value_counts函数查看训练集标签的数量
pd.Series(data['RainTomorrow']).value_counts()
RainTomorrow
No 82786
Yes 23858
Name: count, dtype: int64
我们发现数据集中的负样本数量远大于正样本数量,这种常见的问题叫做“数据不平衡”问题,在某些情况下需要进行一些特殊处理。
## 对于特征进行一些统计描述
data.describe()
MinTemp | MaxTemp | Rainfall | Evaporation | Sunshine | WindGustSpeed | WindSpeed9am | WindSpeed3pm | Humidity9am | Humidity3pm | Pressure9am | Pressure3pm | Cloud9am | Cloud3pm | Temp9am | Temp3pm | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
count | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 | 106644.000000 |
mean | 12.129147 | 23.183398 | 2.313912 | 2.704798 | 3.509008 | 37.305137 | 13.852200 | 18.265378 | 67.940353 | 50.104657 | 917.003689 | 914.995385 | 2.381231 | 2.285670 | 16.877842 | 21.257600 |
std | 6.444358 | 7.208596 | 8.379145 | 4.519172 | 5.105696 | 16.585310 | 8.949659 | 9.118835 | 20.481579 | 22.136917 | 304.042528 | 303.120731 | 3.483751 | 3.419658 | 6.629811 | 7.549532 |
min | -8.500000 | -4.800000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -1.000000 | -7.200000 | -5.400000 |
25% | 7.500000 | 17.900000 | 0.000000 | -1.000000 | -1.000000 | 30.000000 | 7.000000 | 11.000000 | 56.000000 | 35.000000 | 1011.000000 | 1008.500000 | -1.000000 | -1.000000 | 12.200000 | 16.300000 |
50% | 12.000000 | 22.600000 | 0.000000 | 1.600000 | 0.200000 | 37.000000 | 13.000000 | 17.000000 | 70.000000 | 51.000000 | 1016.700000 | 1014.200000 | 1.000000 | 1.000000 | 16.700000 | 20.900000 |
75% | 16.800000 | 28.300000 | 0.600000 | 5.400000 | 8.700000 | 46.000000 | 19.000000 | 24.000000 | 83.000000 | 65.000000 | 1021.800000 | 1019.400000 | 6.000000 | 6.000000 | 21.500000 | 26.300000 |
max | 31.900000 | 48.100000 | 268.600000 | 145.000000 | 14.500000 | 135.000000 | 130.000000 | 87.000000 | 100.000000 | 100.000000 | 1041.000000 | 1039.600000 | 9.000000 | 9.000000 | 39.400000 | 46.200000 |
2.4 可视化描述
为了方便,我们先纪录数字特征与非数字特征:
numerical_features = [x for x in data.columns if data[x].dtype == np.float64]
category_features = [x for x in data.columns if data[x].dtype != np.float64 and x != 'RainTomorrow']
## 选取三个特征与标签组合的散点可视化
sns.pairplot(data=data[['Rainfall',
'Evaporation',
'Sunshine'] + ['RainTomorrow']], diag_kind='hist', hue= 'RainTomorrow')
plt.show()
从上图可以发现,在2D情况下不同的特征组合对于第二天下雨与不下雨的散点分布,以及大概的区分能力。相对的Sunshine与其他特征的组合更具有区分能力
for col in data[numerical_features].columns: