一文上手决策树:从理论到实战

news2024/9/22 5:35:34

一、基础概念

决策树是一类极为常用的机器学习方法,尤其是在分类场景。决策树通过树形结构来递归地将样本分割到不同的叶子结点中去,并根据每个叶子结点中的样本构成对该结点中的样本进行分类。

我们可以从两个视角来理解决策树模型。

  • 第一种视角是将决策树视为一组规则的集合。对一棵完整的决策树来说,从根节点到每一个叶子结点都对应了一条规则,不同的规则之间互斥且完备。
  • 第二种视角是从条件概率的角度来理解决策树。我们对每个叶子结点的分类,都是依据该结点包含的样本集合分属于不同分类的概率来决定的。从这个角度来看,决策树本质上也是一种概率模型。

与其他机器学习方法一样,使用决策树进行预测时,我们的目标是尽可能地在新样本上预测得更准确。那么,一方面我们要在训练集上得到尽可能高的预测精度,另一方面,我们要通过正则化参数来保证模型没有过拟合。

假设树 T T T的叶子结点个数为 ∣ T ∣ |T| T t t t为树 T T T的叶子结点,每个叶子结点有 N t N_t Nt个样本,假设 k k k类的样本有 N t k N_{tk} Ntk个,其中 k = 1 , 2 , ⋯   , K k=1,2,\cdots,K k=1,2,,K H t ( T ) H_t(T) Ht(T)为叶子结点上的经验熵(empirical entropy), α ≥ 0 \alpha \ge 0 α0为正则化参数,那么决策树学习的损失函数可以表示为:

L α ( T ) = ∑ t = 1 ∣ T ∣ N t H t ( T ) + α ∣ T ∣ L_{\alpha}(T) = \sum_{t=1}^{|T|}{N_t H_t(T) +\alpha |T|} Lα(T)=t=1TNtHt(T)+αT

决策树学习的目标就是最小化上述函数,该函数无法使用常规的梯度下降来直接求解,因此我们一般使用启发式的方法来寻找最优决策树,具体来说,就是递归地选择最优特征来分割数据集。如果某次分割后的子集可以完全正确划分到某一类,那么该子集可以归到一个叶子结点;否则,继续从这些子集中选择最优特征进行下一次划分,直到所有子集都能被正确分类。

以上思路会构建一棵完整的树,但是正如前文所述,我们还需要保证模型没有过拟合,因此我们需要对决策树进行剪枝。决策树剪枝通常有预剪枝和后剪枝两种方法。

总的来说,完整的决策树包含特征选择决策树构建决策树剪枝三大方面。

二、特征选择

为了构建一棵性能良好的决策树,我们需要从训练集中不断选取最具有区分度(分类能力)的特征。一般来说,我们通过三个指标来实现这一目标。

1. 信息增益

为了说明信息增益,我们需要引入信息熵的概念。在信息论和概率论中,熵是一种描述随机变量不确定性的度量方式,也可以用来描述样本集合的不纯度。熵越低,样本的不确定性就越低,纯度则越高。

假设当前样本数据集 D D D中的第 k k k个类所占比例为 p k ( k = 1 , 2 , ⋯   , Y ) p_k(k=1,2,\cdots,Y) pk(k=1,2,,Y),那么该样本数据集的熵可以定义为:

E ( D ) = − ∑ i = 1 Y p k l o g ( p k ) E(D) = - \sum_{i=1}^{Y}{p_k log(p_k)} E(D)=i=1Ypklog(pk)

假设离散随机变量 ( X , Y ) (X, Y) (X,Y)的联合概率分布为:

P ( X = x i , Y = y i ) = p i j ( i = 1 , 2 , ⋯   , m , j = 1 , 2 , ⋯   , n ) P(X=x_i, Y=y_i) = p_{ij}(i=1, 2, \cdots, m, \quad j=1, 2, \cdots, n) P(X=xi,Y=yi)=pij(i=1,2,,m,j=1,2,,n)

条件熵 E ( Y ∣ X ) E(Y|X) E(YX)表示在已知随机变量 X X X的条件下对 Y Y Y的不确定性的度量,它可以定义为在给定 X X X的条件下 Y Y Y的条件概率分布的熵对 X X X的数学期望。条件熵可以表示为:

E ( Y ∣ X ) = ∑ i = 1 m p i ⋅ E ( Y ∣ X = x i ) E(Y|X) = \sum_{i=1}^{m}{p_i \cdot E(Y|X=x_i)} E(YX)=i=1mpiE(YX=xi)

其中, p i = P ( X = x i ) , i = 1 , 2 , ⋯   , m p_i = P(X=x_i), i=1, 2, \cdots, m pi=P(X=xi),i=1,2,,m。在利用实际数据进行计算时,熵和条件熵中的概率计算都是基于极大似然估计得到,对应的熵和条件熵也叫经验熵和经验条件熵。

信息增益是指在得到了某个特征X的信息之后,使得类Y的信息不确定性减少的程度。或者说,信息增益代表了某特征带来的分类确定性的增量,特征的信息增益越大,目标分类的确定性也就越大。

假设训练集 D D D的经验熵为 E ( D ) E(D) E(D),给定特征 A A A的条件下 D D D的经验条件熵为 E ( D ∣ A ) E(D|A) E(DA),那么信息增益可定义为经验熵 E ( D ) E(D) E(D)与经验条件熵之差:

g ( D , A ) = E ( D ) − E ( D ∣ A ) g(D, A) = E(D) - E(D|A) g(D,A)=E(D)E(DA)

构建决策树时可以使用信息增益进行特征选择,特征的信息增益越大,代表了其分类能力越强,ID3算法就是基于信息增益做特征选择的。

我们举一个例子来演示信息增益的计算。

例1:假设有20位同学,其中有10位喜欢篮球,10位不喜欢篮球。在20位同学中有12位男同学,其中9位喜欢篮球,3位不喜欢篮球;有8位女同学,其中1位喜欢篮球,7位不喜欢篮球。那么性别(男/女)的信息增益是多少?

import numpy as np

def entropy(freq: list) -> float:
    """计算信息熵
    """
    freq = np.array([i for i in freq if i > 0])
    proba = freq / freq.sum()
    entropy = - (proba * np.log2(proba)).sum()
  
    return entropy

if __name__ == '__main__':
    # 原始数据
    like_basketball = [10, 10]
    male_like_basketball = [9, 3]
    female_like_basketball = [1, 7]
  
    # 经验熵
    entropy_init = entropy(like_basketball)
  
    # 条件熵 
    entropy_cond = 10 / 20 * entropy(male_like_basketball) + \
        10 / 20 * entropy(female_like_basketball)
  
    # 信息增益
    info_gain = entropy_init - entropy_cond
  
    print('经验熵:{0}\n条件熵:{1}\n信息增益:{2}'.format(
        entropy_init, entropy_cond, info_gain))

结果为:

经验熵:1.0
条件熵:0.6774212838293646
信息增益:0.3225787161706354

2. 信息增益率

信息增益存在一个问题:当某个特征分类取值较多时,该特征的信息增益计算结果会放大。取极端情况,如有一个特征为编号,每个样本对应了唯一的一个编号,这种情况下的信息纯度很高,那么基于这个特征得到的信息增益就很大。

信息增益率可以解决信息增益的上述问题。特征 A A A对数据集 D D D的信息增益率可以定义为其信息增益 g ( D , A ) g(D,A) g(D,A)与数据集 D D D关于特征 A A A取值的熵 E A ( D ) E_A(D) EA(D)的比值:

g R ( D , A ) = g ( D , A ) E A ( D ) g_R(D, A) = \frac{g(D, A)}{E_A(D)} gR(D,A)=EA(D)g(D,A)

其中,

E A ( D ) = − ∑ i = 1 n ∣ D i ∣ ∣ D ∣ l o g 2 ∣ D i ∣ ∣ D ∣ E_A(D) = -\sum_{i=1}^{n}{\frac{|D_i|}{|D|} log_2 \frac{|D_i|}{|D|}} EA(D)=i=1nDDilog2DDi

n n n表示特征 A A A的取值个数。C4.5算法是基于信息增益率进行特征选择的。

我们仍以例1来演示,通过前文可知,一共有10名男同学和10名女同学,那么我们可以据此计算出 E A ( D ) E_A(D) EA(D)

gender_cnt = [12, 8]
entropy_gender = entropy(gender_cnt)
gain_rate = info_gain / entropy_gender

print('信息增益率:{0}'.format(gain_rate))

结果为:

信息增益率:0.33222979419649123

3. 基尼系数

基尼系数也是一种较好的特征选择方法。假设样本有 K K K个类,样本属于第 k k k类的概率为 p k p_k pk,则该样本类别概率分布的基尼系数为:

G i n i ( p ) = ∑ k = 1 k p k ( 1 − p k ) = 1 − ∑ k = 1 K p k 2 Gini(p) = \sum_{k=1}^{k}{p_k(1-p_k)} = 1 - \sum_{k=1}^{K}p_k^2 Gini(p)=k=1kpk(1pk)=1k=1Kpk2

对于给定训练集 D D D C k C_k Ck是属于第 k k k类样本的集合,则该训练集的基尼指数为:

G i n i ( D ) = 1 − ∑ k = 1 K ( ∣ C k ∣ ∣ D ∣ ) 2 Gini(D) = 1 - \sum_{k=1}^{K} \left( \frac{|C_k|}{|D|}\right)^2 Gini(D)=1k=1K(DCk)2

如果训练集 D D D根据特征 A A A的某一取值 a a a划分为 D 1 D_1 D1 D 2 D_2 D2两个部分,那么在特征 A A A这个条件下,训练集 D D D的基尼系数为:

G i n i ( D , A ) = ∣ D 1 ∣ ∣ D ∣ G i n i ( D 1 ) + ∣ D 2 ∣ ∣ D ∣ G i n i ( D 2 ) Gini(D, A) = \frac{|D_1|}{|D|}Gini(D_1) +\frac{|D_2|}{|D|}Gini(D_2) Gini(D,A)=DD1Gini(D1)+DD2Gini(D2)

与信息熵的定义相似,训练集 D D D的基尼指数 G i n i ( D ) Gini(D) Gini(D)表示该集合的不确定性(不纯度), G i n i ( D , A ) Gini(D, A) Gini(D,A)表示训练集经过 A A A的划分后的不确定性(不纯度)。在分类过程中,我们总是希望不确定性越低越好,即 G i n i ( D , A ) Gini(D, A) Gini(D,A)越小越好。CART算法使用基尼系数来进行特征选择。

仍以例1来演示基尼系数的计算。

def gini(freq: list) -> float:
    """计算基尼系数
    """
    freq = np.array([i for i in freq if i > 0])
    proba = freq / freq.sum()
    g = 1 - (proba ** 2).sum()
  
    return g

gini_male = 12 / 20 * gini(male_like_basketball) + 8 / 20 * gini(female_like_basketball)

print('基尼系数:{0}'.format(gini_male))

结果为:

基尼系数:0.3125

三、决策树模型

三大经典决策树模型分别为ID3、C4.5、CART,它们都是通过递归地选择最优特征来构建决策树。如前文所述,在评估最优特征时,它们分别使用了信息增益、信息增益率和基尼系数三个指标。

ID3和C4.5算法仅有决策树的生成,不包含决策树剪枝的部分,因此容易过拟合。CART算法除了用于分类外,还可用于回归,也包含决策树剪枝,因此现在应用更为广泛。

1. ID3

ID3算法的全称为Iterative Dichotomiser 3,即迭代二叉树。其核心是基于信息增益递归地选择最优特征构造决策树。

简单来阐述,ID3算法的思路为:

  1. 首先预设一个决策树根节点,然后对所有特征计算信息增益;
  2. 选择一个信息增益最大的特征作为最优特征,根据该特征的不同取值建立子结点;
  3. 接着对每个子结点递归地调用上述方法,直到信息增益很小或者没有特征可选时,将这些子结点作为叶子结点,并以该叶子结点上的多数类作为预测类。

给定训练集 D D D、特征集合 A A A以及信息增益阈值 ϵ \epsilon ϵ,ID3算法的流程如下:

  1. 如果 D D D中所有实例属于同一类别 C k C_k Ck,那么所构建的决策树 T T T为单结点树,并且类 C k C_k Ck即为该结点的预测类。
  2. 如果 T T T不是单结点树,则计算特征集合 A A A中各特征对 D D D的信息增益,选择信息增益最大的特征 A g A_g Ag
  3. 如果 A g A_g Ag的信息增益小于阈值 ϵ \epsilon ϵ,则将 T T T视为单结点树,并将 D D D中所属数量最多的类 C k C_k Ck作为该结点的预测类,并返回 T T T
  4. 否则,对 A g A_g Ag的每一取值 a i a_i ai,按照 A g = a i A_g = a_i Ag=ai D D D划分为若干非空子集 D i D_i Di,以 D i D_i Di中的多数类作为预测类并构建子结点,由结点和子结点构成树 T T T并返回。
  5. 对第 i i i个子结点,以 D i D_i Di为训练集,以 A − A g A-A_g AAg为特征集,递归地调用上述步骤,即可得到决策树子树 T i T_i Ti并返回。

2. C4.5

C4.5算法实际上是对ID3算法的改进。

  1. ID3算法使用信息增益做特征选择,倾向于选择取值水平较多的特征。针对这一问题,C4.5算法改为使用信息增益率。
  2. ID3算法不可以处理缺失值,C4.5算法可以。
  3. ID3算法不支持连续值特征,C4.5算法支持。
  4. ID3算法不支持后剪枝,C4.5算法支持后剪枝。

给定训练集 D D D、特征集合 A A A以及信息增益阈值 ϵ \epsilon ϵ,C4.5算法的流程如下:

  1. 如果 D D D中所有实例属于同一类别 C k C_k Ck,那么所构建的决策树 T T T为单结点树,并且类 C k C_k Ck即为该结点的预测类。
  2. 如果 T T T不是单结点树,则计算特征集合 A A A中各特征对 D D D的信息增益,选择信息增益最大的特征 A g A_g Ag
  3. 如果 A g A_g Ag的信息增益率小于阈值 ϵ \epsilon ϵ,则将 T T T视为单结点树,并将 D D D中所属数量最多的类 C k C_k Ck作为该结点的预测类,并返回 T T T
  4. 否则,对 A g A_g Ag的每一取值 a i a_i ai,按照 A g = a i A_g = a_i Ag=ai D D D划分为若干非空子集 D i D_i Di,以 D i D_i Di中的多数类作为预测类并构建子结点,由结点和子结点构成树 T T T并返回。
  5. 对第 i i i个子结点,以 D i D_i Di为训练集,以 A − A g A-A_g AAg为特征集,递归地调用上述步骤,即可得到决策树子树 T i T_i Ti并返回。

3. CART

CART算法的全称为分类与回归树(classification and regression tree),它既可用于分类,又可用于回归,这是它与ID3/C4.5之间的主要区别之一,此处我们仅讨论CART算法用于分类的场景。此外,CART算法中的特征选择使用的是基尼系数。最后,CART算法不仅包含了决策树的生成算法,还包括了决策树的剪枝算法。

CART生成的决策树为二叉树,内部结点取值为“是”和“否”,这种方法等价于递归地二分每个特征,将特征空间划分为有限个子空间,并在这些子空间上确定预测的概率分布,即前述的预测条件概率分布。

其算法流程为:

  1. 给定训练集 D D D和特征集 A A A,对于每个特征 a a a及其所有取值 a i a_i ai,根据 a = a i a=a_i a=ai将数据集划分为 D 1 D_1 D1 D 2 D_2 D2两个部分,并计算 a = a i a=a_i a=ai时的基尼系数。
  2. 取基尼系数最小的特征及其相应的划分点作为最优特征和最优化分点,据此将当前结点划分为两个子结点,将训练集根据特征取值分配到两个子结点中。
  3. 对两个子结点递归地调用上述步骤,直至满足停止条件,最终生成CART分类决策树。

4. 对比

决策树模型分类树结构特征选择连续值处理缺失值处理剪枝处理
ID3分类多叉树信息增益不可以不可以不可以
C4.5分类多叉树信息增益率可以可以可以
CART分类二叉树基尼系数可以可以可以

四、决策树剪枝

决策树剪枝一般包含两种方法:预剪枝(pre-pruning)和后剪枝(post-pruning)。

1. 预剪枝

预剪枝,是指在决策树生成过程中提前停止树的增长的一种剪枝算法。其主要思路有:

  • 提前设定决策树的深度,当达到这一深度时,停止生长。
  • 当某结点的所有样本属于同一类别,停止生长。
  • 提前设定某个阈值,当某结点的样本数小于该阈值时,停止生长。
  • 提前设定某个阈值,当分裂带来的性能提升小于该阈值时,停止生长。

预剪枝方法直接、简单高效,适用于大规模求解问题。目前在主流的集成学习模型中,很多算法用到了预剪枝的思想。但因为决策树的构建使用的是启发式方法,具有局部最优的问题,预剪枝提前停止树的生长,存在一定的欠拟合风险。

2. 后剪枝

主流的后剪枝方法有四种:悲观错误剪枝(Pessimistic Error Pruning,PEP),最小错误剪枝(Minimum Error Pruning,MEP),代价复杂度剪枝(Cost-Complexity Pruning,CCP)和基于错误的剪枝(Error-Based Pruning,EBP)。C4.5采用悲观错误剪枝,CART采用代价复杂度剪枝。

后剪枝主要通过极小化决策树整体损失函数来实现。前文我们提到,决策树学习的目标是最小化如下损失函数:

L α ( T ) = ∑ t = 1 ∣ T ∣ N t H t ( T ) + α ∣ T ∣ L_{\alpha}(T) = \sum_{t=1}^{|T|}{N_t H_t(T) +\alpha |T|} Lα(T)=t=1TNtHt(T)+αT

其中,经验熵 H t ( T ) H_t(T) Ht(T)可以表示为:

H t ( T ) = − ∑ k N t k N t l o g N t k N t H_t(T) = - \sum_k \frac{N_{tk}}{N_{t}} log \frac{N_{tk}}{N_{t}} Ht(T)=kNtNtklogNtNtk

两式合并有:

L α ( T ) = ∑ t = 1 ∣ T ∣ N t H t ( T ) + α ∣ T ∣ = − ∑ t = 1 ∣ T ∣ ∑ k = 1 K N t k l o g N t k N t + α ∣ T ∣ = L ( T ) + α ∣ T ∣ \begin{aligned} L_{\alpha}(T) &= \sum_{t=1}^{|T|}{N_t H_t(T) +\alpha |T|} \\ &=-\sum_{t=1}^{|T|} \sum_{k=1}^K N_{tk} log \frac{N_{tk}}{N_{t}} + \alpha |T| \\ &=L(T) + \alpha |T| \end{aligned} Lα(T)=t=1TNtHt(T)+αT=t=1Tk=1KNtklogNtNtk+αT=L(T)+αT

其中, L ( T ) L(T) L(T)为模型的经验误差项, ∣ T ∣ |T| T表示决策树的复杂度(结点数), α ≥ 0 \alpha \ge 0 α0为正则化参数,用于调控经验误差项和正则化项之间的权重关系。

决策树后剪枝就是在正则化参数 α \alpha α确定的情况下,选择损失函数 L α ( T ) L_{\alpha}(T) Lα(T)最小的决策树模型。给定算法生成的决策树 T T T和正则化参数 α \alpha α,后剪枝算法的流程如下:

  1. 计算每个树节点的经验熵 H t ( T ) H_t(T) Ht(T)
  2. 递归地自底向上回缩,假设一组叶子结点回缩到父节点前后的数分别为 T b e f o r e T_{before} Tbefore T a f t e r T_{after} Tafter,其对应的损失函数分别为 L α ( T b e f o r e ) L_{\alpha}(T_{before}) Lα(Tbefore) L α ( T a f t e r ) L_{\alpha}(T_{after}) Lα(Tafter),如果 L α ( T b e f o r e ) ≥ L α ( T a f t e r ) L_{\alpha}(T_{before}) \ge L_{\alpha}(T_{after}) Lα(Tbefore)Lα(Tafter),则进行剪枝,将父节点变为新的叶子结点。
  3. 重复上一步,直至得到损失函数最小的子树 T α T_{\alpha} Tα

CART算法使用的正是后剪枝方法。CART后剪枝首先通过计算子树的损失函数来实现剪枝并得到一个子树序列,然后通过交叉验证的方法从子树序列中选取最优子树。

  1. 初始化 α b e s t = ∞ \alpha_{best} = \infty αbest=,最优子树集合 ω = T \omega = {T} ω=T
  2. 从叶子结点开始自下而上计算内部节点 t t t的损失函数 L α ( T t ) L_{\alpha}(T_t) Lα(Tt)、叶子结点数 ∣ T t ∣ |T_t| Tt,以及正则化阈值 α = m i n { L ( T ) − L ( T t ) ∣ T t ∣ − 1 , α b e s t } \alpha = min\{\frac{L(T)-L(T_t)}{|T_t| - 1}, \alpha_{best}\} α=min{Tt1L(T)L(Tt),αbest},更新 α b e s t = α \alpha_{best} = \alpha αbest=α
  3. 得到所有节点的 α \alpha α值集合 M M M
  4. 从M中选择最大的值 α k \alpha_k αk,自上而下地访问子树 t t t的内部节点,如果 L ( T ) − L ( T t ) ∣ T t ∣ − 1 ≤ α k \frac{L(T)-L(T_t)}{|T_t|-1} \le \alpha_k Tt1L(T)L(Tt)αk,则进行剪枝,并决定叶子结点 t t t的预测值。
  5. ω = ω ∪ T k \omega = \omega \cup T_k ω=ωTk M = M − α k M=M-{\alpha_k} M=Mαk
  6. 如果 M M M不为空,则回到步骤4。否则已得到了所有的可选最优子树集合 w w w
  7. 采用交叉验证在 w w w选择最优子树 T α T_{\alpha} Tα

五、优缺点

1. 优点

  1. 简单直观,生成的决策树很直观。
  2. 基本不需要预处理,不需要提前归一化和处理缺失值。
  3. 使用决策树预测的代价是 O ( l o g 2 N ) O(log2N) O(log2N) N N N为样本数。
  4. 既可以处理离散值也可以处理连续值。很多算法只是专注于离散值或者连续值。
  5. 可以处理多维度输出的分类问题。
  6. 相比于神经网络之类的黑盒分类模型,决策树在逻辑上可以很好解释。
  7. 可以交叉验证的剪枝来选择模型,从而提高泛化能力。
  8. 对于异常点的容错能力好,健壮性高。

2. 缺点

  1. 决策树算法非常容易过拟合,导致泛化能力不强。可以通过设置节点最少样本数量和限制决策树深度来改进。
  2. 决策树会因为样本发生一点的改动,导致树结构的剧烈改变。这个可以通过集成学习之类的方法解决。
  3. 寻找最优的决策树是一个NP难题,我们一般是通过启发式方法,容易陷入局部最优。可以通过集成学习的方法来改善。
  4. 有些比较复杂的关系,决策树很难学习,比如异或。这个就没有办法了,一般这种关系可以换神经网络分类方法来解决。
  5. 如果某些特征的样本比例过大,生成决策树容易偏向于这些特征。这个可以通过调节样本权重来改善。

六、代码实战

1. sklearn

在sklearn中,使用决策树进行分类预测非常简单,下面是一个来自官方文档的例子。

from sklearn.tree import DecisionTreeClassifier

X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = DecisionTreeClassifier()
clf = clf.fit(X, Y)

# 预测
print(clf.predict([[2, 2]])

# 预测概率
print(clf.predict_proba([[2, 2]])

我们还可以将决策树通过可视化的方式呈现出来。

from sklearn.datasets import load_iris
from sklearn import tree
import matplotlib.pyplot as plt

# 以iris数据为例
iris = load_iris()
clf = tree.DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

# 可视化
plt.figure(figsize=(36, 24))
tree.plot_tree(clf, feature_names=iris.feature_names, 
               filled=True, proportion=True, fontsize=14)

2. PySpark

在PySpark中使用决策树模型稍显复杂。

import numpy as np
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, TrainValidationSplit
from pyspark.ml.evaluation import BinaryClassificationEvaluator


spark = SparkSession.builder.appName('test').getOrCreate()


# 准备数据
# 数据可以从Hive中读取,或从pandas.DataFrame格式创建等。
# 此处假设一份用于二分类预测模型训练的数据已准备好,
data = YOUR_PYSPARK_DATAFRAME
features = YOUR_FEATURE_COLUMN_NAMES
label_col = YOUR_LABEL_COLUMNS

# 数据集分割
traindf, testdf = data.randomSplit([0.8, 0.2], seed=1)

# 特征向量化
vec_assembler = VectorAssembler(inputCols=features, outputCol='features')

# 决策树
dtree = DecisionTreeClassifier(
    seed=1,
    labelCol=label_col,
    featuresCol='features',
    predictionCol='pred',
    probabilityCol='proba',
    maxDepth=5,
    minInstancesPerNode=3,
    impurity='gini',
    maxBins=10
)

# 训练模型
pipeline = Pipeline(stages=[vec_assembler, dtree])
model = pipeline.fit(traindf)

# 特征重要性
feat_importances = list(zip(features, model.stages[1].featureImportances))
df_importances = pd.DataFrame(sorted(feat_importances, key=lambda x: x[1], reverse=True), 
                              columns=['feature', 'importances'])
df_importances.head()

# 预测
df_pred = model.transform(testdf) 
to_array = F.udf(lambda x: x.toArray().tolist(), ArrayType(DoubleType()))
df_pred = df_pred.withColumn('proba_score', to_array('proba')[1])

我们还可以在PySpark中使用网格搜索来确定最佳参数。

# 特征向量化
vec_assembler = VectorAssembler(inputCols=features, outputCol='features')

# 随机森林
dtree = DecisionTreeClassifier(
    seed=1,
    labelCol=label_col,
    featuresCol='features',
    predictionCol='pred',
    probabilityCol='proba',
    impurity='gini',
    # maxDepth=5,
    # minInstancesPerNode=3,
    # maxBins=10
)

# 流水线
pipeline = Pipeline(stages=[vec_assembler, dtree])

# 设置网格参数
param_grid = ParamGridBuilder() \
    .baseOn({dtree.labelCol:'label'}) \
    .baseOn({dtree.featuresCol: 'features'}) \
    .baseOn({dtree.predictionCol: 'pred'}) \
    .baseOn({dtree.probabilityCol: 'proba'}) \
    .addGrid(dtree.minInstancesPerNode, [3, 5, 7]) \
    .addGrid(dtree.maxDepth, [10, 12, 15, 20]) \
    .addGrid(dtree.maxBins, [5, 10, 15]) \
    .build()


# 模型评估
evaluator = BinaryClassificationEvaluator()

# 交叉验证
cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=evaluator,
    numFolds=5,
    seed=1024
)

# 开始执行
a = time.time()
cvModel = cv.fit(traindf)
b = time.time()
print(b - a)

# 打印最佳参数
params = cvModel.getEstimatorParamMaps()
avg_metrics = cvModel.avgMetrics
all_params = list(zip(params, avg_metrics))
best_param = sorted(all_params, key=lambda x: x[1], reverse=True)[0]
for p, v in best_param[0].items():
    print("{}: {}".format(p.name, v))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/177285.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python副业技术总结,手把手教你用宝塔面板部署Django程序

前言 最近写了几个Django项目,写完以后怎么让对方测试成了问题,因为之前都是自己在本地写的练习项目,对于部署这一块很陌生,不知道怎么操作,内心很忐忑。没办法,只能硬着头皮上,一边百度&#…

17种编程语言实现排序算法-插入排序

开源地址 https://gitee.com/lblbc/simple-works/tree/master/sort/ 覆盖语言:C、C、C#、Java、Kotlin、Dart、Go、JavaScript(JS)、TypeScript(TS)、ArkTS、swift、PHP。 覆盖平台:安卓(Java、Kotlin)、iOS(SwiftUI)、Flutter(Dart)、Window桌面(C#)、…

ARP渗透与攻防(四)之WireShark截获用户数据

ARP-WireShark截获用户数据 系列文章 ARP渗透与攻防(一)之ARP原理 ARP渗透与攻防(二)之断网攻击 ARP渗透与攻防(三)之流量分析 1.WireShark工具介绍 wireshark的官方下载网站: WireShark wireshark是非常流行的网络封包分析软件,功能十分强大。可以…

PowerShell 学习笔记:压缩、解压缩文件

在自动构建的时候&#xff0c;最常用的就是压缩备份项目的源文件&#xff0c;PowerShell提供了相关命令。Compress-Archive&#xff08;压缩文件&#xff09;Compress-Archive[-Path] <String[]>[-DestinationPath] <String>[-CompressionLevel <String>][-P…

Crossplane - 比 Terraform 更先进的云基础架构管理平台?

&#x1f449;️URL: https://crossplane.io/ &#x1f4dd;Description: 将云基础架构和服务组成自定义平台 API 简介 在 11 月的 KCD 上海现场&#xff0c;听了一场阿里云的工程师关于他们自己的多云基础架构管理工具的介绍&#xff0c;前边的引言部分有介绍到 Terraform&am…

连续系统的数字PID控制仿真-3

在连续系统的数字PID控制仿真-2的基础上&#xff0c;利用S函数实现PID离散控制器的Simulink仿真。在S函数中&#xff0c;采用初始化函数、更新函数和输出函数&#xff0c;即 mdlInitializeSizes函数、mdIUpdates函数和mdlOutputs函数。在初始化中采用sizes 结构&#xff0c;选择…

什么是GRACE CPU --- Grace CPU架构详解

深入详解GRACE CPU架构 NVIDIA Grace CPU 是 NVIDIA 开发的第一款数据中心 CPU。 通过将 NVIDIA 专业知识与 Arm 处理器、片上结构、片上系统 (SoC) 设计和弹性高带宽低功耗内存技术相结合&#xff0c;NVIDIA Grace CPU 从头开始构建&#xff0c;以创建世界上第一个超级芯片 用…

Spring控制反转(IoC)和依赖注入(DI)

Spring官网&#xff1a;spring.io1.spring 2.SprinMVC 3.Maven高级 4.SpringBoot 5.MyBatisPlus为什么要学Spring?简化开发&#xff0c;降低企业级开发的复杂度框架整合&#xff0c;高效整合其他技术&#xff0c;提高企业级应用开发与运行效率Spring 系统架构IOC&#xff08;I…

分享132个ASP源码,总有一款适合您

ASP源码 分享132个ASP源码&#xff0c;总有一款适合您 下面是文件的名字&#xff0c;我放了一些图片&#xff0c;文章里不是所有的图主要是放不下...&#xff0c; 132个ASP源码下载链接&#xff1a;https://pan.baidu.com/s/1bk2hftqR5NTdUIT2zvmbiw?pwdke5x 提取码&#x…

离散数学与组合数学-04图论

文章目录离散数学与组合数学-04图论4.1 图的引入4.1.1 图的示例4.1.2 无序对和无序积4.1.3 图的定义4.2 图的表示4.2.1 集合表示和图形表示4.2.2 矩阵表示法4.2.3 邻接点与邻接边4.3 图的分类4.3.1 按边的方向分类4.3.2 按平行边分类4.3.3 按权值分类4.3.4 综合分类方法4.4 图论…

干货 | 移动互联网应用程序(APP)个人信息安全自我评测工具

以下内容整理自清华大学《数智安全与标准化》课程大作业期末报告同学的汇报内容。第一部分&#xff1a;研究背景概述截止今年6月&#xff0c;我国已经有APP 232万款&#xff0c;手机网民达到10.47亿&#xff0c;在APP中大规模的个人信息收集和使用成为常态&#xff0c;个人信息…

【算法题】1828. 统计一个圆中点的数目

插&#xff1a; 前些天发现了一个巨牛的人工智能学习网站&#xff0c;通俗易懂&#xff0c;风趣幽默&#xff0c;忍不住分享一下给大家。点击跳转到网站。 坚持不懈&#xff0c;越努力越幸运&#xff0c;大家一起学习鸭~~~ 题目&#xff1a; 给你一个数组 points &#xff0c;…

Kubernetes:基于命令行终端UI的管理工具 K9s

写在前面 K9s 是一个基于终端UI的 K8S 管理工具博文内容为 k9s 在 windows、Linux 以及docker 安装Demo简单的 热键使用。理解不足小伙伴帮忙指正 我所渴求的&#xff0c;無非是將心中脫穎語出的本性付諸生活&#xff0c;為何竟如此艱難呢 ------赫尔曼黑塞《德米安》 K9s 是一…

客快物流大数据项目(一百零八):Spring Cloud 技术栈

文章目录 Spring Cloud 技术栈 ​​​​前言 一、微服务技术栈

如果物理学真的不存在?

最近过年&#xff0c;看「三体」电视剧。开始看剧情&#xff0c;觉得代入感挺不好的&#xff0c;特别林子健演的那个作战中心的长官&#xff0c;镜头从远处拉过去&#xff0c;看着他昂首挺胸慢慢走过去的样子。之后&#xff0c;讲到了火鸡和农场主的故事&#xff0c;这个时候再…

C++STL剖析(二)—— vector的概念和使用

文章目录1. vector的介绍2. vector的常见构造3. vector的遍历方式&#x1f351; [ ] 下标&#x1f351; 迭代器&#x1f351; 范围for4. vector 迭代器使用&#x1f351; begin 和 end&#x1f351; rbegin 和 rend5. vector 空间增长问题&#x1f351; size&#x1f351; cap…

37.Isaac教程--自由空间分割(道路分割)

自由空间分割 ISAAC教程合集地址: https://blog.csdn.net/kunhe0512/category_12163211.html 文章目录自由空间分割快速开始推理训练数据模拟数据设置与模拟器的通信来自公共数据集的真实数据具有自主数据收集的自由空间分割的真实数据自主数据收集通过地图规划路径监测机器人位…

JavaEE 突击 5 - Spring 更简单的读取和存储对象(2)

Spring 更简单的读取和存储对象 - 2三 . 获取 Bean 对象3.1 属性注入3.1.1 原理3.1.2 相关问题能在启动类里面调用 [Autowired ](/Autowired ) 注解吗[Autowired ](/Autowired ) 能使用多次吗Autowired 修饰的私有方法名字可以是其他的吗3.1.3 属性注入的优点和缺点3.2 Setter …

关于Kubernetes 桌面客户端 Aptakube 的一些笔记整理

写在前面 分享一个 k8s 桌面客户端 AptakubeAptakube 不是一个开源的产品&#xff0c;现在需要付费&#xff0c;最初是开源的这里简单了解下理解不足小伙伴帮忙指正 我所渴求的&#xff0c;無非是將心中脫穎語出的本性付諸生活&#xff0c;為何竟如此艱難呢 ------赫尔曼黑塞《…

redis学习看这一篇文章就够了

第一章 redis简介 第1节 NoSQL 1.1 NoSQL简介 NoSQL&#xff0c;泛指非关系型的数据库&#xff0c;NoSQL即Not-Only SQL&#xff0c;它可以作为关系型数据库的良好补充。随着互联网web2.0网站的兴起&#xff0c;非关系型的数据库现在成了一个极其热门的新领域&#xff0c;非…