7. 回归决策树
学习目标:
- 知道回归决策树的实现原理
前面已经讲到,关于数据类型,我们主要可以把其分为两类,①连续型数据和②离散型数据。
在面对不同数据时,决策树也可以分为两大类型:
- 分类决策树:主要用于处理离散型数据
- 回归决策树:主要用于处理连续型数据
连续性数据主要用于回归;离散型数据主要用于分类
7.1 原理概述
不管是回归决策树还是分类决策树,都会存在两个核心问题:
- 如何选择划分结点?
- 如何决定叶节点的输出值?
一个回归树对应着输入空间(即特征空间)的一个划分结点以及在划分单元上的输出值。在分类树中,我们采用信息论中的方法,通过计算选择最佳划分点。而在回归树中,采用的是启发式的方法。
假如我们有 n n n 个特征,每个特征有 s i ( i ∈ ( 1 , n ) ) s_i(i \in (1, n)) si(i∈(1,n)) 个取值,那我们遍历所有特征,尝试该特征所有取值,对空间进行划分,直到取到特征 j j j 的取值 s s s,使得损失函数最小,这样就得到了一个划分点。描述该过程的公式如下:
min j s [ min c 1 L ( y i , c i ) + min c 2 L ( y i , c 2 ) ] \underset{js}{\min}[\underset{c_1}{\min}\ \mathcal{L}(y_i, c_i) + \underset{c_2} {\min} \ \mathcal{L}(y_i, c_2)] jsmin[c1min L(yi,ci)+c2min L(yi,c2)]
其中:
- n n n 表示特征的数量
- s i s_i si 表示第 i i i 个特征的取值数量
- j j j 和 s s s 分别表示最佳划分点的特征和取值
- c 1 c_1 c1 和 c 2 c_2 c2 分别表示划分后两个区域内固定的输出值。
- L \mathcal{L} L 表示损失函数
假设将输入空间划分为 M M M 个单元: R 1 , R 2 , . . . , R m R_1, R_2, ..., R_m R1,R2,...,Rm,那么每个区域的输出值就是 c m = a v g ( y i ∣ x i ∈ R m ) c_m = \mathrm{avg}(y_i|x_i \in R_m) cm=avg(yi∣xi∈Rm),也就是该区域内所有点 y y y 值的平均数。
其中:
- M M M 表示输入空间被划分成的单元数
- R 1 , R 2 , . . . , R m R_1, R_2, ..., R_m R1,R2,...,Rm 表示每个单元
- c m c_m cm 表示每个区域的输出值,它等于该区域内所有点 y y y 值的平均数
- a v g ( y i ∣ x i ∈ R m ) \mathrm{avg}(y_i|x_i \in R_m) avg(yi∣xi∈Rm) 表示在 x i x_i xi 属于区域 R m R_m Rm 的条件下,所有 y i y_i yi 值的平均数。
Q1:“单元”是什么?
A1:在决策树中,单元(也称为区域)是指输入空间被划分成的子区域。决策树通过不断地选择最佳划分点来将输入空间划分成若干个单元,每个单元内的数据点具有相似的特征。每个单元都有一个固定的输出值,用来预测该区域内数据点的目标值。
Q2:划分节点就是结点吗?划分节点可以是叶子结点吗?
A2:划分节点是指决策树中的非叶子节点,它用来将输入空间划分成若干个子区域。每个划分节点都有一个划分条件,用来决定数据点属于哪个子区域。划分节点不是叶子节点,叶子节点是指决策树中没有子节点的节点,它表示一个单元,用来预测该区域内数据点的目标值。
Q3:单元=叶子结点,划分点非叶子结点的结点,对吗?
A3:是的。在决策树中,每个单元都对应着一个叶子节点,每个叶子节点都表示一个单元。划分节点是非叶子节点,它用来将输入空间划分成若干个子区域。
举例:如下图,假如我们想要对楼内居民的年龄进行回归,将楼划分为 3 个区域 R 1 , R 2 , R 3 R_1, R_2, R_3 R1,R2,R3(红线)。那么 R 1 R_1 R1 的输出就是第一列四个居民年龄的平均值, R 2 R_2 R2 的输出就是第二列四个居民年龄的平均值, R 3 R_3 R3 的输出就是第三、四列八个居民年龄的平均值。
7.2 算法描述
输入:训练数据集
D
D
D
输出:回归树
f
(
x
)
f(x)
f(x)
在训练数据集所在的输入空间中,递归的将每个区域划分为两个子区域并决定每个子区域上的输出值,构建二叉决策树:
一、选择最优切分特征 j j j 与切分点 s s s,求解
min j , s [ min c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + min c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] \underset{j, s}{\min}\left[ \underset{c_1}{\min} \sum_{x_i \in R_1(j, s)} (y_i - c_1)^2 + \underset{c_2}{\min} \sum_{x_i \in R_2(j, s)}(y_i - c_2)^2 \right] j,smin c1minxi∈R1(j,s)∑(yi−c1)2+c2minxi∈R2(j,s)∑(yi−c2)2
遍历特征 j j j,对固定的切分特征 j j j 扫描切分点 s s s,选择使得上式达到最小值的对 ( j , s ) (j, s) (j,s)
二、用选定的对 ( j , s ) (j, s) (j,s) 划分区域并决定相应的输出值:
R 1 ( j , s ) = x ∣ x ( j ) ≤ s R_1(j, s) = x|x^{(j)} \le s R1(j,s)=x∣x(j)≤s
R 2 ( j , s ) = x ∣ x ( j ) > s R_2(j, s) = x|x^{(j)} >s R2(j,s)=x∣x(j)>s
c ^ m = 1 N ∑ x 1 ∈ R m ( j , s ) y i 其中 x ∈ R m , m = 1 , 2 \hat{c}_m = \frac{1}{N}\sum_{x_1 \in R_m(j, s)} y_i \ \ 其中x \in R_m, m = 1, 2 c^m=N1x1∈Rm(j,s)∑yi 其中x∈Rm,m=1,2
三、继续对两个子区域调用步骤一和二,直至满足停止条件。
四、将输入空间划分为 M M M 个区域 R 1 , R 2 , . . . , R M R_1, R_2, ..., R_M R1,R2,...,RM,生成决策树:
f ( x ) = ∑ m = 1 M c ^ m I ( x ∈ R m ) f(x) = \sum_{m = 1}^M \hat{c}_mI(x\in R_m) f(x)=m=1∑Mc^mI(x∈Rm)
其中:
- D D D 表示训练数据集
- f ( x ) f(x) f(x) 表示回归树
- j j j 和 s s s 分别表示最优切分特征和切分点
- R 1 ( j , s ) R_1(j, s) R1(j,s) 和 R 2 ( j , s ) R_2(j, s) R2(j,s) 分别表示根据最优切分特征和切分点划分出的两个子区域
- c 1 c_1 c1 和 c 2 c_2 c2 分别表示两个子区域内的输出值
- c ^ m \hat{c}_m c^m 表示第 m m m 个区域内的输出值,它等于该区域内所有点 y y y 值的平均数
- M M M 表示输入空间被划分成的区域数
- R 1 , R 2 , . . . , R M R_1, R_2, ..., R_M R1,R2,...,RM 表示每个区域。
7.3 简单实例
为了易于理解,接下来通过一个简单实例加深对回归决策树的理解。训练数据见下表,我们的目标是得到一棵最小二乘回归树。
x x x(特征值) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
---|---|---|---|---|---|---|---|---|---|---|
y y y(目标值) | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 | 8.9 | 8.7 | 9 | 9.05 |
7.3.1 实例计算过程
一、选择最优的切分特征 j j j 与最优切分点 s s s:
- 确定第一个问题:选择最优切分特征
- 在本数据集中,只有一个特征,因此最优切分特征自然是 x x x
- 确定第二个问题:我们考虑 9 个切分点
[
1.5
,
2.5
,
3.5
,
4.5
,
5.5
,
6.5
,
7.5
,
8.5
,
9.5
]
[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]
[1.5,2.5,3.5,4.5,5.5,6.5,7.5,8.5,9.5]。
- 损失函数定义为平方损失函数 L ( y , f ( x ) ) = [ f ( x ) − y ] 2 \mathcal{L}(y, f(x)) = [f(x) - y]^2 L(y,f(x))=[f(x)−y]2,其中 f ( x ) f(x) f(x) 为预测值, y y y 为真实值(目标值)
- 将上述 9 个切分点依此代入下面的公式,其中 c m = a v g ( y i ∣ x i ∈ R m ) c_m = \mathrm{avg}(y_i | x_i \in R_m) cm=avg(yi∣xi∈Rm)
a. 计算子区域输出值:
当切分点 s = 1.5 s=1.5 s=1.5 时,数据被分为两个子区域: R 1 R_1 R1 和 R 2 R_2 R2。 R 1 R_1 R1 包括特征值为 1 1 1 的数据点,而 R 2 R_2 R2 包括特征值为 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 2,3,4,5,6,7,8,9,10 2,3,4,5,6,7,8,9,10 的数据点。
c 1 c_1 c1 和 c 2 c_2 c2 分别是这两个子区域的输出值。它们的计算方法是将各自子区域内的目标值相加,然后除以数据点的数量。因此,这两个区域的输出值分别为:
- c 1 = 5.56 c_1 = 5.56 c1=5.56
- c 2 = 5.7 + 5.91 + 6.4 + 6.8 + 7.05 + 8.9 + 8.7 + 9 + 9.05 9 = 7.50 c_2= \frac{5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05}{9} = 7.50 c2=95.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05=7.50
当切分点 s = 2.5 s=2.5 s=2.5 时,数据被分为两个子区域: R 1 R_1 R1 和 R 2 R_2 R2。 R 1 R_1 R1 包括特征值为 1 , 2 1,2 1,2 的数据点,而 R 2 R_2 R2 包括特征值为 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 3,4,5,6,7,8,9,10 3,4,5,6,7,8,9,10 的数据点。
c 1 c_1 c1 和 c 2 c_2 c2 分别是这两个子区域的输出值。它们的计算方法是将各自子区域内的目标值相加,然后除以数据点的数量。因此,这两个区域的输出值分别为:
-
c 1 = 5.56 + 5.7 2 = 5.63 c_1 = \frac{5.56 + 5.7}{2} = 5.63 c1=25.56+5.7=5.63
-
c 2 = 5.91 + 6.4 + 6.8 + 7.05 + 8.9 + 8.7 + 9 + 9.05 8 = 7.73 c_2 = \frac{5.91+6.4+6.8+7.05+8.9+8.7+9+9.05}{8} = 7.73 c2=85.91+6.4+6.8+7.05+8.9+8.7+9+9.05=7.73
同理,我们可以得到其他各切分点的子区域输出值,如下表所示:
s s s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
c 1 c1 c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 | 6.24 | 6.62 | 6.88 | 7.11 |
c 2 c2 c2 | 7.5 | 7.73 | 7.99 | 8.25 | 8.54 | 8.91 | 8.92 | 9.03 | 9.05 |
b. 计算损失函数值,找到最优切分点:
把 c 1 c_1 c1, c 2 c_2 c2 的值代入到平方损失函数 L ( y , f ( x ) ) = [ f ( x ) − y ] 2 \mathcal{L}(y, f(x)) = [f(x) - y]^2 L(y,f(x))=[f(x)−y]2,其中 f ( x ) f(x) f(x) 为预测值, y y y 为真实值(目标值)
当s=1.5时:总损失为:
L = ∑ x i ∈ R 1 [ f ( x i ) − y i ] 2 + ∑ x i ∈ R 2 [ f ( x i ) − y i ] 2 = [ 5.56 − 5.56 ] 2 + [ 7.50 − 5.7 ] 2 + [ 7.50 − 5.91 ] 2 + . . . + [ 7.50 − 9.05 ] 2 = 0 + ( 1.8 ) 2 + ( 1.59 ) 2 + . . . + ( − 1.55 ) 2 = 15.72 \begin{aligned} \mathcal{L} &= \sum_{x_i \in R_1} [f(x_i) - y_i]^2 + \sum_{x_i \in R_2} [f(x_i) - y_i]^2 \\ &= [5.56 - 5.56]^2 + [7.50 - 5.7]^2 + [7.50 - 5.91]^2 + ... + [7.50 - 9.05]^2 \\ &= 0 + (1.8)^2 + (1.59)^2 + ... + (-1.55)^2 & = 15.72 \end{aligned} L=xi∈R1∑[f(xi)−yi]2+xi∈R2∑[f(xi)−yi]2=[5.56−5.56]2+[7.50−5.7]2+[7.50−5.91]2+...+[7.50−9.05]2=0+(1.8)2+(1.59)2+...+(−1.55)2=15.72
当切分点 s = 2.5 s=2.5 s=2.5 时,总损失为:
L = ∑ x i ∈ R 1 [ f ( x i ) − y i ] 2 + ∑ x i ∈ R 2 [ f ( x i ) − y i ] 2 = [ 5.63 − 5.56 ] 2 + [ 5.63 − 5.7 ] 2 + [ 7.73 − 5.91 ] 2 + . . . + [ 7.73 − 9.05 ] 2 = ( 0.07 ) 2 + ( − 0.07 ) 2 + ( 1.82 ) 2 + . . . + ( − 1.32 ) 2 \begin{aligned} \mathcal{L} &= \sum_{x_i \in R_1} [f(x_i) - y_i]^2 + \sum_{x_i \in R_2} [f(x_i) - y_i]^2 \\ &= [5.63 - 5.56]^2 + [5.63 - 5.7]^2 + [7.73 - 5.91]^2 + ... + [7.73 - 9.05]^2 \\ &= (0.07)^2 + (-0.07)^2 + (1.82)^2 + ... + (-1.32)^2 \end{aligned} L=xi∈R1∑[f(xi)−yi]2+xi∈R2∑[f(xi)−yi]2=[5.63−5.56]2+[5.63−5.7]2+[7.73−5.91]2+...+[7.73−9.05]2=(0.07)2+(−0.07)2+(1.82)2+...+(−1.32)2
同理,计算得到其他各切分点的损失函数值,可获得下表:
s s s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
---|---|---|---|---|---|---|---|---|---|
m ( s ) m(s) m(s) | 15.72 | 12.07 | 8.36 | 5.78 | 3.91 | 1.93 | 8.01 | 11.73 | 15.74 |
显然取 s = 6.5 s=6.5 s=6.5 时, m ( s ) m(s) m(s) 最小。因此第一个划分变量 [ j = x , s = 6.5 ] [j=x,s=6.5] [j=x,s=6.5]
Q:为什么要用
m
(
s
)
m(s)
m(s),不应该是
L
(
y
,
f
(
x
)
)
\mathcal{L}(y, f(x))
L(y,f(x))吗?
A:
m
(
s
)
m(s)
m(s) 和
L
\mathcal{L}
L 都表示损失函数。在回归决策树中,损失函数用于衡量划分后的子区域内预测值与真实值之间的差异。不同的文献或资料可能会使用不同的符号来表示损失函数,但它们的意义是相同的。
m ( s ) m(s) m(s) 用于表示在切分点 s s s 处的损失函数值。因此,当计算不同切分点处的损失函数值时,使用 m ( s ) m(s) m(s) 或 L \mathcal{L} L 都是可以的。
二、用选定的 ( j , s ) (j, s) (j,s) 划分区域,并决定输出值:
- 两个区域分别是: R 1 = 1 , 2 , 3 , 4 , 5 , 6 R_1={1,2,3,4,5,6} R1=1,2,3,4,5,6, R 2 = 7 , 8 , 9 , 10 R_2={7,8,9,10} R2=7,8,9,10
- 输出值 c m = a v g ( y i ∣ x i ∈ R m ) c_m = \mathrm{avg}(y_i|x_i\in R_m) cm=avg(yi∣xi∈Rm), c 1 = 6.24 c_1 =6.24 c1=6.24, c 2 = 8.91 c_2 = 8.91 c2=8.91
三、调用步骤一、二,继续划分:
对 R 1 R_1 R1 继续进行划分:
x x x(特征值) | 1 | 2 | 3 | 4 | 5 | 6 |
---|---|---|---|---|---|---|
y y y(目标值) | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 |
取切分点 [ 1.5 , 2.5 , 3.5 , 4.5 , 5.5 ] [1.5,2.5,3.5,4.5,5.5] [1.5,2.5,3.5,4.5,5.5],则各区域的输出值 c c c 如下表:
s s s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
c 1 c1 c1 | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 |
c 2 c2 c2 | 6.37 | 6.54 | 6.75 | 6.93 | 7.02 |
计算损失函数值 m ( s ) m(s) m(s):
s s s | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
---|---|---|---|---|---|
m ( s ) m(s) m(s) | 1.3087 | 0.754 | 0.2771 | 0.4368 | 1.0644 |
s = 3.5 s=3.5 s=3.5 时, m ( s ) m(s) m(s) 最小。
循环…
回归决策树的划分终止条件通常有以下几种:
- 子区域中的数据点数量小于预先设定的阈值。
- 子区域中的数据点目标值的方差小于预先设定的阈值。
- 树的深度达到预先设定的最大深度。
当满足上述任意一个条件时,划分过程将终止。这些条件可以根据具体问题进行调整,以获得最佳的模型性能。
四、生成回归树
假设在生成 3 个区域之后停止划分,那么最终生成的回归树形式如下:
T = { 5.72 x ≤ 3.5 6.75 3.5 ≤ x ≤ 6.5 8.91 6.5 < x T = \begin{cases} 5.72 & x \le 3.5 \\ 6.75 & 3.5 \le x \le 6.5 \\ 8.91 & 6.5 < x \end{cases} T=⎩ ⎨ ⎧5.726.758.91x≤3.53.5≤x≤6.56.5<x
这棵回归树的结构如下:
[j=x,s=6.5]
/ \
[j=x,s=3.5] R_2
/ \
R_{11} R_{12}
其中, R 11 R_{11} R11、 R 12 R_{12} R12 和 R 2 R_2 R2 都是叶子节点。
这棵回归树有三个叶子节点,分别对应三个子区域 R 11 R_{11} R11、 R 12 R_{12} R12 和 R 2 R_2 R2。根节点的划分变量为 [ j = x , s = 6.5 ] [j=x,s=6.5] [j=x,s=6.5],它将数据分为两个子区域: R 1 R_1 R1 和 R 2 R_2 R2。根节点的左子节点对应子区域 R 1 R_1 R1,它的划分变量为 [ j = x , s = 3.5 ] [j=x,s=3.5] [j=x,s=3.5],将子区域 R 1 R_1 R1 再次分为两个子区域: R 11 R_{11} R11 和 R 12 R_{12} R12。根节点的左子节点的左右子节点分别对应子区域 R 11 R_{11} R11 和 R 12 R_{12} R12,它们都是叶子节点。根节点的右子节点对应子区域 R 2 R_2 R2,它也是一个叶子节点。
其中:
- j j j 和 s s s 分别表示切分特征和切分点
-
j
=
x
j=x
j=x 表示切分特征为
x
x
x,而
s
=
6.5
s=6.5
s=6.5 表示切分点为
6.5
6.5
6.5
- 当切分变量为 [ j = x , s = 6.5 ] [j=x,s=6.5] [j=x,s=6.5] 时,数据将根据特征 x x x 的值被分为两个子区域: R 1 R_1 R1 和 R 2 R_2 R2。子区域 R 1 R_1 R1 包括特征值小于等于 6.5 6.5 6.5 的数据点,而子区域 R 2 R_2 R2 包括特征值大于 6.5 6.5 6.5 的数据点。
- 因此,当切分变量为 [ j = x , s = 6.5 ] [j=x,s=6.5] [j=x,s=6.5] 时,数据将根据特征 x x x 的值被分为两个子区域。
小结:
- 输入:训练数据集 D D D
- 输出:回归树 f ( x ) f(x) f(x)
- 流程:在训练数据集所在的输入空间中,递归的将每个区域划分为两个子区域并决定每个子区域上的输出值,构建二叉决策树:
- 选择最优切分特征 j j j 与切分点 s s s,求解 min j , s [ min c 1 ∑ x i ∈ R 1 ( j , s ) ( y i − c 1 ) 2 + min c 2 ∑ x i ∈ R 2 ( j , s ) ( y i − c 2 ) 2 ] \underset{j, s}{\min}\left[ \underset{c_1}{\min} \sum_{x_i \in R_1(j, s)} (y_i - c_1)^2 + \underset{c_2}{\min} \sum_{x_i \in R_2(j, s)}(y_i - c_2)^2 \right] j,smin[c1min∑xi∈R1(j,s)(yi−c1)2+c2min∑xi∈R2(j,s)(yi−c2)2] —— 遍历特征 j j j,对固定的切分特征 j j j 扫描切分点 s s s,选择使得上式达到最小值的对 ( j , s ) (j, s) (j,s)
- 用选定的对
(
j
,
s
)
(j, s)
(j,s) 划分区域并决定相应的输出值:
R 1 ( j , s ) = x ∣ x ( j ) ≤ s R_1(j, s) = x|x^{(j)} \le s R1(j,s)=x∣x(j)≤s
R 2 ( j , s ) = x ∣ x ( j ) > s R_2(j, s) = x|x^{(j)} >s R2(j,s)=x∣x(j)>s
c
^
m
=
1
N
∑
x
1
∈
R
m
(
j
,
s
)
y
i
其中
x
∈
R
m
,
m
=
1
,
2
\hat{c}_m = \frac{1}{N}\sum_{x_1 \in R_m(j, s)} y_i \ \ 其中x \in R_m, m = 1, 2
c^m=N1x1∈Rm(j,s)∑yi 其中x∈Rm,m=1,2
3. 继续对两个子区域调用步骤一和二,直至满足停止条件。
4. 将输入空间划分为
M
M
M 个区域
R
1
,
R
2
,
…
,
R
M
R_1, R_2 , …, R_M
R1,R2,…,RM,生成决策树
f
(
x
)
=
∑
m
=
1
M
c
^
m
I
(
x
∈
R
m
)
f(x) = \sum_{m = 1}^M \hat{c}_mI(x\in R_m)
f(x)=∑m=1Mc^mI(x∈Rm):
7.4 回归决策树和线性回归对比
import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeRegressor
from sklearn.linear_model import LinearRegression
from pylab import mpl
# 设置中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]
# 设置正常显示符号
mpl.rcParams["axes.unicode_minus"] = False
# 1. ⽣成数据
x = np.array(list(range(1, 11))).reshape(-1, 1) # 使其变为列向量
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05])
# 2. 训练模型
model_1 = DecisionTreeRegressor(max_depth=1) # 决策树模型
model_2 = DecisionTreeRegressor(max_depth=3) # 决策树模型
model_3 = LinearRegression() # 线性回归模型
model_1.fit(x, y)
model_2.fit(x, y)
model_3.fit(x, y)
# 3. 模型预测
X_test = np.arange(0.0, 10.0, 0.01).reshape(-1, 1) # ⽣成1000个数,⽤于预测模型
predict_1 = model_1.predict(X_test)
predict_2 = model_2.predict(X_test)
predict_3 = model_3.predict(X_test)
# 4. 结果可视化
plt.figure(dpi=300)
plt.scatter(x, y, label="原始数据(目标值)")
plt.plot(X_test, predict_1, label="回归决策树: max_depth=1")
plt.plot(X_test, predict_2, label="回归决策树: max_depth=3")
plt.plot(X_test, predict_3, label="线性回归")
plt.xlabel("数据")
plt.ylabel("预测值")
plt.title("线性回归与回归决策树效果对比")
plt.grid(alpha=0.5)
plt.legend()
plt.show()
结果:
8. 决策树总结
8.1 优点
- 易于理解和解释。
- 决策树的结构可以可视化,非专家也能很容易理解。
- 数据准备简单。
- 决策树不需要对数据进行复杂的预处理,例如归一化或去除缺失值。
- 能够同时处理数值型和分类数据。
- 不受数据缩放的影响。
- 计算成本相对较低。
这些优点使得决策树在许多领域都得到了广泛应用。
8.2 缺点
- 容易过拟合。决策树模型可能会产生过于复杂的模型,导致泛化能力较差。
- 可以通过剪枝、设置叶节点所需的最小样本数或设置树的最大深度来避免过拟合。
- 不稳定性。微小的数据变化可能会导致生成完全不同的树。
- 这个问题可以通过决策树集成来缓解。
- 对连续性字段预测困难。
- 当类别太多时,错误率可能会增加较快。
这些缺点需要在使用决策树时予以注意。
8.3 改进的方法
针对决策树的缺点,有一些改进方法可以使用。例如:
- 避免过拟合。
- 可以通过剪枝、设置叶节点所需的最小样本数或设置树的最大深度来避免过拟合。
- 剪枝包括预剪枝和后剪枝。
- 前者通过对连续型变量设置阈值,来控制树的深度,或者控制节点的个数,在节点开始划分之前就进行操作,进而防止过拟合现象。
- 后者是自底向上对非叶节点进行考察,如果这个内部节点换成叶节点能提升决策树的泛化能力,那就把它换掉。
- 使用决策树集成。
- 可以通过集成多个决策树来提高模型的稳定性和准确性。
- 例如,随机森林算法就是基于决策树的集成学习算法,它通过构建多棵决策树并结合它们的预测结果来提高模型的准确性和稳定性。
- 可以通过集成多个决策树来提高模型的稳定性和准确性。
- 对连续性字段进行离散化处理。
- 可以将连续性字段离散化为分类变量,以便决策树能够更好地处理。
- 对类别不平衡的数据进行重采样。
- 可以对类别不平衡的数据进行重采样,以减少错误率。
这些方法可以帮助改进决策树模型,提高其准确性和稳定性。