- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
前言
- 机器学习是深度学习和数据分析的基础,接下来将更新常见的机器学习算法
- 注意:在打数学建模比赛中,机器学习用的也很多,可以一起学习
- 决策树模型数学原理很复杂,强烈推荐看书,看书,看书,这里推荐《统计学习方法》和《机器学习西瓜书》。
- 这里只是介绍了决策树组成,但是原理没有详细介绍,后面会出详介绍篇章。
- 最近开学,更新不太及时,请大家见谅,欢迎收藏 + 点赞 + 关注
文章目录
- 决策树模型
- 简介
- 建立决策树的方法
- 分类案例
- 导入数据和数据分析
- 划分自变量和因变量
- 模型训练
- 模型预测结果
- 回归案例
- 导入数据
- 划分数据
- 创建模型
- 模型预测与训练
- 模型评估
- 树图绘制
决策树模型
简介
定义(统计学习方法):分类决策树模型是一种描述对实例进行分类的树形结构,决策树由节点、有向边组成,节点类型有两种,内部节点和叶子节点,内部节点表示一个特征或者属性,叶子节点表示一个类。
决策树与if-then:
学过任何语言的人都知道if-else
结构,决策树也是这样,如果if
满足某一种条件,则归到一类,不满足条件的归到另外一类,如此循环判断,一直到所有特征、属性和类都归类到某一类,最终形成一颗树,注意:一个原则互斥且完备。
决策树过程:
特征选择、建立决策树、决策树剪枝三个过程
决策树解决问题:
回归和分类,如果分类的叶子节点,就是回归
,否则就是分类
。
建立决策树的方法
决策树背后由很多的数学原理,这里只介绍信息增益、信息增益比、基尼系数,其他的概念推荐翻阅统计学习方法和西瓜书,想要电子版资料的可以私聊我。
建议:这一部分一定要看书,推荐统计学习方法和机器学习西瓜书,书中有很详细的案例帮助我们理解。67y
以下概念均来自于《统计学习方法》
信息增益: 特征A对训练数据集D的信息增益g(D.A),定义为集合D的经验熵H(D)
与特征A
给定条件下D
的经验条件H(DA)之差
,即:
g ( D , A ) = H ( D ) − H ( D ∣ A ) g\left(D,A\right)=H\left(D\right)-H\left(D|A\right) g(D,A)=H(D)−H(D∣A)
信息增益比:特征A对训练数据集D的信息增益比gR(D)定为其信息增益 g(D,A)与训练数据集 D 关于特征 A的值的熵 HA(D)之比。即:
g R ( D , A ) = g ( D , A ) H A ( D ) g_{R}(D,A)=\frac{g(D,A)}{H_{A}(D)} gR(D,A)=HA(D)g(D,A)
其中: H A ( D ) = − ∑ i = 1 n ∣ D i ∣ ∣ D ∣ log 2 ∣ D i ∣ ∣ D ∣ H_{A}(D)=-\sum_{i=1}^{n}\frac{\left|D_{i}\right|}{\left|D\right|}\log_{2}\frac{\left|D_{i}\right|}{\left|D\right|} HA(D)=−i=1∑n∣D∣∣Di∣log2∣D∣∣Di∣ ,n表示特征A的数量。
基尼指数:分类问题中,假设有区个类,样本点属于第k 类的概率为 pk,则概率分布的基尼指数定义为:
G i n i ( p ) = ∑ k = 1 K p k ( 1 − p k ) = 1 − ∑ k = 1 K p k 2 Gini\left(p\right)=\sum_{k=1}^{K}p_{k}\left(1-p_{k}\right)=1-\sum_{k=1}^{K}p_{k}^{2} Gini(p)=k=1∑Kpk(1−pk)=1−k=1∑Kpk2
建议:看书,通过案例和公式来理解。
分类案例
简介:通过鸢尾花的叶子特征,构建判别叶子类别的树。
导入数据和数据分析
import numpy as np
import pandas as pd
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
columns = ['花萼-length', '花萼-width', '花瓣-length', '花瓣-width', 'class']
data = pd.read_csv(url, names=columns)
data
花萼-length | 花萼-width | 花瓣-length | 花瓣-width | class | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
1 | 4.9 | 3.0 | 1.4 | 0.2 | Iris-setosa |
2 | 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
3 | 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
4 | 5.0 | 3.6 | 1.4 | 0.2 | Iris-setosa |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | Iris-virginica |
146 | 6.3 | 2.5 | 5.0 | 1.9 | Iris-virginica |
147 | 6.5 | 3.0 | 5.2 | 2.0 | Iris-virginica |
148 | 6.2 | 3.4 | 5.4 | 2.3 | Iris-virginica |
149 | 5.9 | 3.0 | 5.1 | 1.8 | Iris-virginica |
150 rows × 5 columns
# 查看值的类别和数量
data['class'].value_counts()
结果:
class
Iris-setosa 50
Iris-versicolor 50
Iris-virginica 50
Name: count, dtype: int64
# 查看变量信息
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150 entries, 0 to 149
Data columns (total 5 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 花萼-length 150 non-null float64
1 花萼-width 150 non-null float64
2 花瓣-length 150 non-null float64
3 花瓣-width 150 non-null float64
4 class 150 non-null object
dtypes: float64(4), object(1)
memory usage: 6.0+ KB
# 查看缺失值
data.isnull().sum()
结果:
花萼-length 0
花萼-width 0
花瓣-length 0
花瓣-width 0
class 0
dtype: int64
# 查看特征的统计变量
data.describe()
结果:
花萼-length | 花萼-width | 花瓣-length | 花瓣-width | |
---|---|---|---|---|
count | 150.000000 | 150.000000 | 150.000000 | 150.000000 |
mean | 5.843333 | 3.054000 | 3.758667 | 1.198667 |
std | 0.828066 | 0.433594 | 1.764420 | 0.763161 |
min | 4.300000 | 2.000000 | 1.000000 | 0.100000 |
25% | 5.100000 | 2.800000 | 1.600000 | 0.300000 |
50% | 5.800000 | 3.000000 | 4.350000 | 1.300000 |
75% | 6.400000 | 3.300000 | 5.100000 | 1.800000 |
max | 7.900000 | 4.400000 | 6.900000 | 2.500000 |
# 查看特征变量的相关性
name_corr = ['花萼-length', '花萼-width', '花瓣-length', '花瓣-width']
corr = data[name_corr].corr()
print(corr)
花萼-length 花萼-width 花瓣-length 花瓣-width
花萼-length 1.000000 -0.109369 0.871754 0.817954
花萼-width -0.109369 1.000000 -0.420516 -0.356544
花瓣-length 0.871754 -0.420516 1.000000 0.962757
花瓣-width 0.817954 -0.356544 0.962757 1.000000
说明:特征变量之间存在共线性问题
划分自变量和因变量
X = data.iloc[:, [0, 1, 2, 3]].values # .values转化成矩阵
y = data.iloc[:, 4].values
模型训练
from sklearn import tree
model = tree.DecisionTreeClassifier()
model.fit(X, y) # 模型训练
# 打印模型结构
r = tree.export_text(model)
模型预测结果
# 随机选取值
x_test = X[[0, 30, 60, 90, 120, 130], :]
y_pred_prob = model.predict_proba(x_test) # 预测概率
y_pred = model.predict(x_test) # 预测值
print("\n===模型===")
print(r)
===模型===
|--- feature_3 <= 0.80
| |--- class: Iris-setosa
|--- feature_3 > 0.80
| |--- feature_3 <= 1.75
| | |--- feature_2 <= 4.95
| | | |--- feature_3 <= 1.65
| | | | |--- class: Iris-versicolor
| | | |--- feature_3 > 1.65
| | | | |--- class: Iris-virginica
| | |--- feature_2 > 4.95
| | | |--- feature_3 <= 1.55
| | | | |--- class: Iris-virginica
| | | |--- feature_3 > 1.55
| | | | |--- feature_2 <= 5.45
| | | | | |--- class: Iris-versicolor
| | | | |--- feature_2 > 5.45
| | | | | |--- class: Iris-virginica
| |--- feature_3 > 1.75
| | |--- feature_2 <= 4.85
| | | |--- feature_0 <= 5.95
| | | | |--- class: Iris-versicolor
| | | |--- feature_0 > 5.95
| | | | |--- class: Iris-virginica
| | |--- feature_2 > 4.85
| | | |--- class: Iris-virginica
print("\n===测试数据===")
print(x_test)
===测试数据===
[[5.1 3.5 1.4 0.2]
[4.8 3.1 1.6 0.2]
[5. 2. 3.5 1. ]
[5.5 2.6 4.4 1.2]
[6.9 3.2 5.7 2.3]
[7.4 2.8 6.1 1.9]]
print("\n===预测所属类别概率===")
print(y_pred_prob)
===预测所属类别概率===
[[1. 0. 0.]
[1. 0. 0.]
[0. 1. 0.]
[0. 1. 0.]
[0. 0. 1.]
[0. 0. 1.]]
print("\n===测试所属类别==")
print(y_pred)
===测试所属类别==
['Iris-setosa' 'Iris-setosa' 'Iris-versicolor' 'Iris-versicolor'
'Iris-virginica' 'Iris-virginica']
回归案例
通过鸢尾花三个特征,预测花瓣长度。
导入数据
import pandas as pd
import numpy as np
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
names = ['花萼-width', '花萼-length', '花瓣-width', '花瓣-length', 'class']
data = pd.read_csv(url, names=names)
data
花萼-width | 花萼-length | 花瓣-width | 花瓣-length | class | |
---|---|---|---|---|---|
0 | 5.1 | 3.5 | 1.4 | 0.2 | Iris-setosa |
1 | 4.9 | 3.0 | 1.4 | 0.2 | Iris-setosa |
2 | 4.7 | 3.2 | 1.3 | 0.2 | Iris-setosa |
3 | 4.6 | 3.1 | 1.5 | 0.2 | Iris-setosa |
4 | 5.0 | 3.6 | 1.4 | 0.2 | Iris-setosa |
... | ... | ... | ... | ... | ... |
145 | 6.7 | 3.0 | 5.2 | 2.3 | Iris-virginica |
146 | 6.3 | 2.5 | 5.0 | 1.9 | Iris-virginica |
147 | 6.5 | 3.0 | 5.2 | 2.0 | Iris-virginica |
148 | 6.2 | 3.4 | 5.4 | 2.3 | Iris-virginica |
149 | 5.9 | 3.0 | 5.1 | 1.8 | Iris-virginica |
150 rows × 5 columns
划分数据
# 划分数据
X = data.iloc[:, [0, 1, 2]]
y = data.iloc[:, 3]
创建模型
from sklearn import tree
model = tree.DecisionTreeRegressor()
model.fit(X, y) # 模型训练
模型预测与训练
x_test = X.iloc[[0, 1, 50, 51, 100, 120], :]
y_test = y.iloc[[0, 1, 50, 51, 100, 120]] # 只有一列
y_pred = model.predict(x_test)
模型评估
# 输出原始值和真实值
df = pd.DataFrame()
df['原始值'] = y_test
df['预测值'] = y_pred
df
原始值 | 预测值 | |
---|---|---|
0 | 0.2 | 0.25 |
1 | 0.2 | 0.20 |
50 | 1.4 | 1.40 |
51 | 1.5 | 1.50 |
100 | 2.5 | 2.50 |
120 | 2.3 | 2.30 |
from sklearn.metrics import mean_absolute_error
# 误差计算
mse = mean_absolute_error(y_test, y_pred)
mse
结果:
0.008333333333333331
# 打印树结构
r = tree.export_text(model)
print(r)
# 树模型结构比较复杂,可以运行后面代码绘图展示。
|--- feature_2 <= 2.45
| |--- feature_1 <= 3.25
| | |--- feature_1 <= 2.60
| | | |--- value: [0.30]
| | |--- feature_1 > 2.60
| | | |--- feature_0 <= 4.85
| | | | |--- feature_0 <= 4.35
| | | | | |--- value: [0.10]
| | | | |--- feature_0 > 4.35
| | | | | |--- feature_2 <= 1.35
| | | | | | |--- value: [0.20]
| | | | | |--- feature_2 > 1.35
| | | | | | |--- feature_1 <= 2.95
| | | | | | | |--- value: [0.20]
| | | | | | |--- feature_1 > 2.95
| | | | | | | |--- feature_0 <= 4.65
| | | | | | | | |--- value: [0.20]
| | | | | | | |--- feature_0 > 4.65
| | | | | | | | |--- feature_1 <= 3.05
| | | | | | | | | |--- value: [0.20]
| | | | | | | | |--- feature_1 > 3.05
| | | | | | | | | |--- value: [0.20]
| | | |--- feature_0 > 4.85
| | | | |--- feature_0 <= 4.95
| | | | | |--- feature_1 <= 3.05
| | | | | | |--- value: [0.20]
| | | | | |--- feature_1 > 3.05
| | | | | | |--- value: [0.10]
| | | | |--- feature_0 > 4.95
| | | | | |--- value: [0.20]
| |--- feature_1 > 3.25
| | |--- feature_2 <= 1.55
| | | |--- feature_1 <= 4.30
| | | | |--- feature_1 <= 3.95
| | | | | |--- feature_1 <= 3.85
| | | | | | |--- feature_1 <= 3.65
| | | | | | | |--- feature_0 <= 5.30
| | | | | | | | |--- feature_2 <= 1.45
| | | | | | | | | |--- feature_1 <= 3.55
| | | | | | | | | | |--- feature_2 <= 1.35
| | | | | | | | | | | |--- value: [0.30]
| | | | | | | | | | |--- feature_2 > 1.35
| | | | | | | | | | | |--- truncated branch of depth 3
| | | | | | | | | |--- feature_1 > 3.55
| | | | | | | | | | |--- value: [0.20]
| | | | | | | | |--- feature_2 > 1.45
| | | | | | | | | |--- value: [0.20]
| | | | | | | |--- feature_0 > 5.30
| | | | | | | | |--- feature_1 <= 3.45
| | | | | | | | | |--- value: [0.40]
| | | | | | | | |--- feature_1 > 3.45
| | | | | | | | | |--- value: [0.20]
| | | | | | |--- feature_1 > 3.65
| | | | | | | |--- feature_0 <= 5.20
| | | | | | | | |--- feature_1 <= 3.75
| | | | | | | | | |--- value: [0.40]
| | | | | | | | |--- feature_1 > 3.75
| | | | | | | | | |--- value: [0.30]
| | | | | | | |--- feature_0 > 5.20
| | | | | | | | |--- value: [0.20]
| | | | | |--- feature_1 > 3.85
| | | | | | |--- value: [0.40]
| | | | |--- feature_1 > 3.95
| | | | | |--- feature_0 <= 5.35
| | | | | | |--- value: [0.10]
| | | | | |--- feature_0 > 5.35
| | | | | | |--- value: [0.20]
| | | |--- feature_1 > 4.30
| | | | |--- value: [0.40]
| | |--- feature_2 > 1.55
| | | |--- feature_0 <= 4.90
| | | | |--- value: [0.20]
| | | |--- feature_0 > 4.90
| | | | |--- feature_0 <= 5.05
| | | | | |--- feature_1 <= 3.45
| | | | | | |--- value: [0.40]
| | | | | |--- feature_1 > 3.45
| | | | | | |--- value: [0.60]
| | | | |--- feature_0 > 5.05
| | | | | |--- feature_1 <= 3.35
| | | | | | |--- value: [0.50]
| | | | | |--- feature_1 > 3.35
| | | | | | |--- feature_2 <= 1.65
| | | | | | | |--- value: [0.20]
| | | | | | |--- feature_2 > 1.65
| | | | | | | |--- feature_1 <= 3.60
| | | | | | | | |--- value: [0.20]
| | | | | | | |--- feature_1 > 3.60
| | | | | | | | |--- feature_0 <= 5.55
| | | | | | | | | |--- value: [0.40]
| | | | | | | | |--- feature_0 > 5.55
| | | | | | | | | |--- value: [0.30]
|--- feature_2 > 2.45
| |--- feature_2 <= 4.75
| | |--- feature_2 <= 4.15
| | | |--- feature_1 <= 2.65
| | | | |--- feature_2 <= 3.95
| | | | | |--- feature_2 <= 3.75
| | | | | | |--- feature_2 <= 3.15
| | | | | | | |--- value: [1.10]
| | | | | | |--- feature_2 > 3.15
| | | | | | | |--- value: [1.00]
| | | | | |--- feature_2 > 3.75
| | | | | | |--- feature_0 <= 5.55
| | | | | | | |--- value: [1.10]
| | | | | | |--- feature_0 > 5.55
| | | | | | | |--- value: [1.10]
| | | | |--- feature_2 > 3.95
| | | | | |--- feature_0 <= 5.90
| | | | | | |--- feature_0 <= 5.65
| | | | | | | |--- feature_1 <= 2.40
| | | | | | | | |--- value: [1.30]
| | | | | | | |--- feature_1 > 2.40
| | | | | | | | |--- value: [1.30]
| | | | | | |--- feature_0 > 5.65
| | | | | | | |--- value: [1.20]
| | | | | |--- feature_0 > 5.90
| | | | | | |--- value: [1.00]
| | | |--- feature_1 > 2.65
| | | | |--- feature_0 <= 5.75
| | | | | |--- feature_0 <= 5.40
| | | | | | |--- value: [1.40]
| | | | | |--- feature_0 > 5.40
| | | | | | |--- value: [1.30]
| | | | |--- feature_0 > 5.75
| | | | | |--- feature_2 <= 4.05
| | | | | | |--- feature_2 <= 3.95
| | | | | | | |--- value: [1.20]
| | | | | | |--- feature_2 > 3.95
| | | | | | | |--- value: [1.30]
| | | | | |--- feature_2 > 4.05
| | | | | | |--- value: [1.00]
| | |--- feature_2 > 4.15
| | | |--- feature_2 <= 4.45
| | | | |--- feature_0 <= 5.80
| | | | | |--- feature_1 <= 2.65
| | | | | | |--- value: [1.20]
| | | | | |--- feature_1 > 2.65
| | | | | | |--- feature_1 <= 2.95
| | | | | | | |--- feature_1 <= 2.80
| | | | | | | | |--- value: [1.30]
| | | | | | | |--- feature_1 > 2.80
| | | | | | | | |--- value: [1.30]
| | | | | | |--- feature_1 > 2.95
| | | | | | | |--- value: [1.20]
| | | | |--- feature_0 > 5.80
| | | | | |--- feature_1 <= 2.95
| | | | | | |--- value: [1.30]
| | | | | |--- feature_1 > 2.95
| | | | | | |--- feature_0 <= 6.25
| | | | | | | |--- value: [1.50]
| | | | | | |--- feature_0 > 6.25
| | | | | | | |--- value: [1.40]
| | | |--- feature_2 > 4.45
| | | | |--- feature_0 <= 5.15
| | | | | |--- value: [1.70]
| | | | |--- feature_0 > 5.15
| | | | | |--- feature_1 <= 3.25
| | | | | | |--- feature_1 <= 2.95
| | | | | | | |--- feature_2 <= 4.65
| | | | | | | | |--- feature_0 <= 5.85
| | | | | | | | | |--- value: [1.30]
| | | | | | | | |--- feature_0 > 5.85
| | | | | | | | | |--- feature_0 <= 6.55
| | | | | | | | | | |--- value: [1.50]
| | | | | | | | | |--- feature_0 > 6.55
| | | | | | | | | | |--- value: [1.30]
| | | | | | | |--- feature_2 > 4.65
| | | | | | | | |--- feature_1 <= 2.85
| | | | | | | | | |--- value: [1.20]
| | | | | | | | |--- feature_1 > 2.85
| | | | | | | | | |--- value: [1.40]
| | | | | | |--- feature_1 > 2.95
| | | | | | | |--- feature_2 <= 4.55
| | | | | | | | |--- value: [1.50]
| | | | | | | |--- feature_2 > 4.55
| | | | | | | | |--- feature_2 <= 4.65
| | | | | | | | | |--- value: [1.40]
| | | | | | | | |--- feature_2 > 4.65
| | | | | | | | | |--- feature_1 <= 3.15
| | | | | | | | | | |--- value: [1.50]
| | | | | | | | | |--- feature_1 > 3.15
| | | | | | | | | | |--- value: [1.40]
| | | | | |--- feature_1 > 3.25
| | | | | | |--- value: [1.60]
| |--- feature_2 > 4.75
| | |--- feature_2 <= 5.05
| | | |--- feature_0 <= 6.75
| | | | |--- feature_0 <= 5.80
| | | | | |--- value: [2.00]
| | | | |--- feature_0 > 5.80
| | | | | |--- feature_1 <= 2.35
| | | | | | |--- value: [1.50]
| | | | | |--- feature_1 > 2.35
| | | | | | |--- feature_0 <= 6.25
| | | | | | | |--- value: [1.80]
| | | | | | |--- feature_0 > 6.25
| | | | | | | |--- feature_2 <= 4.95
| | | | | | | | |--- feature_1 <= 2.60
| | | | | | | | | |--- value: [1.50]
| | | | | | | | |--- feature_1 > 2.60
| | | | | | | | | |--- value: [1.80]
| | | | | | | |--- feature_2 > 4.95
| | | | | | | | |--- feature_0 <= 6.50
| | | | | | | | | |--- value: [1.90]
| | | | | | | | |--- feature_0 > 6.50
| | | | | | | | | |--- value: [1.70]
| | | |--- feature_0 > 6.75
| | | | |--- feature_0 <= 6.85
| | | | | |--- value: [1.40]
| | | | |--- feature_0 > 6.85
| | | | | |--- value: [1.50]
| | |--- feature_2 > 5.05
| | | |--- feature_1 <= 3.05
| | | | |--- feature_0 <= 6.35
| | | | | |--- feature_0 <= 5.85
| | | | | | |--- feature_1 <= 2.75
| | | | | | | |--- value: [1.90]
| | | | | | |--- feature_1 > 2.75
| | | | | | | |--- value: [2.40]
| | | | | |--- feature_0 > 5.85
| | | | | | |--- feature_1 <= 2.85
| | | | | | | |--- feature_1 <= 2.65
| | | | | | | | |--- value: [1.40]
| | | | | | | |--- feature_1 > 2.65
| | | | | | | | |--- feature_1 <= 2.75
| | | | | | | | | |--- value: [1.60]
| | | | | | | | |--- feature_1 > 2.75
| | | | | | | | | |--- value: [1.50]
| | | | | | |--- feature_1 > 2.85
| | | | | | | |--- feature_1 <= 2.95
| | | | | | | | |--- value: [1.80]
| | | | | | | |--- feature_1 > 2.95
| | | | | | | | |--- value: [1.80]
| | | | |--- feature_0 > 6.35
| | | | | |--- feature_0 <= 7.50
| | | | | | |--- feature_0 <= 7.15
| | | | | | | |--- feature_1 <= 2.75
| | | | | | | | |--- feature_0 <= 6.55
| | | | | | | | | |--- value: [1.90]
| | | | | | | | |--- feature_0 > 6.55
| | | | | | | | | |--- value: [1.80]
| | | | | | | |--- feature_1 > 2.75
| | | | | | | | |--- feature_0 <= 6.60
| | | | | | | | | |--- feature_2 <= 5.55
| | | | | | | | | | |--- feature_2 <= 5.35
| | | | | | | | | | | |--- value: [2.00]
| | | | | | | | | | |--- feature_2 > 5.35
| | | | | | | | | | | |--- value: [1.80]
| | | | | | | | | |--- feature_2 > 5.55
| | | | | | | | | | |--- feature_2 <= 5.70
| | | | | | | | | | | |--- value: [2.15]
| | | | | | | | | | |--- feature_2 > 5.70
| | | | | | | | | | | |--- value: [2.20]
| | | | | | | | |--- feature_0 > 6.60
| | | | | | | | | |--- feature_0 <= 6.75
| | | | | | | | | | |--- value: [2.30]
| | | | | | | | | |--- feature_0 > 6.75
| | | | | | | | | | |--- value: [2.10]
| | | | | | |--- feature_0 > 7.15
| | | | | | | |--- feature_2 <= 5.95
| | | | | | | | |--- value: [1.60]
| | | | | | | |--- feature_2 > 5.95
| | | | | | | | |--- feature_1 <= 2.85
| | | | | | | | | |--- value: [1.90]
| | | | | | | | |--- feature_1 > 2.85
| | | | | | | | | |--- value: [1.80]
| | | | | |--- feature_0 > 7.50
| | | | | | |--- feature_2 <= 6.80
| | | | | | | |--- feature_2 <= 6.35
| | | | | | | | |--- value: [2.30]
| | | | | | | |--- feature_2 > 6.35
| | | | | | | | |--- feature_1 <= 2.90
| | | | | | | | | |--- value: [2.00]
| | | | | | | | |--- feature_1 > 2.90
| | | | | | | | | |--- value: [2.10]
| | | | | | |--- feature_2 > 6.80
| | | | | | | |--- value: [2.30]
| | | |--- feature_1 > 3.05
| | | | |--- feature_1 <= 3.25
| | | | | |--- feature_2 <= 5.95
| | | | | | |--- feature_0 <= 6.60
| | | | | | | |--- feature_2 <= 5.40
| | | | | | | | |--- feature_0 <= 6.45
| | | | | | | | | |--- value: [2.30]
| | | | | | | | |--- feature_0 > 6.45
| | | | | | | | | |--- value: [2.00]
| | | | | | | |--- feature_2 > 5.40
| | | | | | | | |--- value: [1.80]
| | | | | | |--- feature_0 > 6.60
| | | | | | | |--- feature_2 <= 5.50
| | | | | | | | |--- feature_2 <= 5.25
| | | | | | | | | |--- value: [2.30]
| | | | | | | | |--- feature_2 > 5.25
| | | | | | | | | |--- value: [2.10]
| | | | | | | |--- feature_2 > 5.50
| | | | | | | | |--- feature_2 <= 5.65
| | | | | | | | | |--- value: [2.40]
| | | | | | | | |--- feature_2 > 5.65
| | | | | | | | | |--- value: [2.30]
| | | | | |--- feature_2 > 5.95
| | | | | | |--- value: [1.80]
| | | | |--- feature_1 > 3.25
| | | | | |--- feature_0 <= 7.45
| | | | | | |--- feature_2 <= 5.85
| | | | | | | |--- feature_2 <= 5.65
| | | | | | | | |--- feature_2 <= 5.50
| | | | | | | | | |--- value: [2.30]
| | | | | | | | |--- feature_2 > 5.50
| | | | | | | | | |--- value: [2.40]
| | | | | | | |--- feature_2 > 5.65
| | | | | | | | |--- value: [2.30]
| | | | | | |--- feature_2 > 5.85
| | | | | | | |--- feature_0 <= 6.75
| | | | | | | | |--- value: [2.50]
| | | | | | | |--- feature_0 > 6.75
| | | | | | | | |--- value: [2.50]
| | | | | |--- feature_0 > 7.45
| | | | | | |--- feature_0 <= 7.80
| | | | | | | |--- value: [2.20]
| | | | | | |--- feature_0 > 7.80
| | | | | | | |--- value: [2.00]
树图绘制
from sklearn.tree import export_graphviz
import graphviz
#设置字体
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"] # 显示中文
# 使用export_graphviz生成DOT文件
dot_data = export_graphviz(model, out_file=None,
feature_names=['花萼-width', '花萼-length', '花瓣-width'],
class_names=['花瓣-length'],
filled=True, rounded=True,
special_characters=True)
# 使用graphviz渲染DOT文件
graph = graphviz.Source(dot_data)
graph.render("decision_tree") # 将图形保存为PDF或其它格式
graph.view() # 在默认查看器中打开图形
图太长了,不方便展示,可以运行代码绘制。
ue: [2.50]
| | | | | | | |— feature_0 > 6.75
| | | | | | | | |— value: [2.50]
| | | | | |— feature_0 > 7.45
| | | | | | |— feature_0 <= 7.80
| | | | | | | |— value: [2.20]
| | | | | | |— feature_0 > 7.80
| | | | | | | |— value: [2.00]
## 树图绘制
```python
from sklearn.tree import export_graphviz
import graphviz
#设置字体
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"] # 显示中文
# 使用export_graphviz生成DOT文件
dot_data = export_graphviz(model, out_file=None,
feature_names=['花萼-width', '花萼-length', '花瓣-width'],
class_names=['花瓣-length'],
filled=True, rounded=True,
special_characters=True)
# 使用graphviz渲染DOT文件
graph = graphviz.Source(dot_data)
graph.render("decision_tree") # 将图形保存为PDF或其它格式
graph.view() # 在默认查看器中打开图形
图太长了,不方便展示,可以运行代码绘制。