sklearn中使用决策树

news2025/1/16 11:09:34

1.示例

criterion可以是信息熵,entropy,可以是基尼系数gini

# -*-coding:utf-8-*-
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
wine=load_wine()

# print ( wine.feature_names )
#(178, 13)
print(wine.data.shape)


Xtrain,Xtest,Ytrain,Ytest=train_test_split(wine.data,wine.target,test_size=0.3)

#random_state=30:输入任意整数,会一直长同一棵树,让模型稳定下来
clf=tree.DecisionTreeClassifier(criterion="entropy",random_state=30,splitter="best")
# clf=tree.DecisionTreeClassifier(criterion="entropy")
clf=clf.fit(Xtrain,Ytrain)
#返回预测准确度accuracy
score=clf.score(Xtest,Ytest)

print( score )

import graphviz
dot_data=tree.export_graphviz(clf,
                              feature_names=wine.feature_names,
                              class_names=["wine1","wine2","wine3"],
                              filled=True,
                              rounded=True)
graph=graphviz.Source(dot_data)
#生成pdf文件
graph.render(view=True, format="pdf", filename="tree_pdf")
print ( graph )
#feature_importances_:每个特征在决策树中的重要成都
print(clf.feature_importances_)
print ( [*zip(wine.feature_names,clf.feature_importances_)] )

决策树生成的pdf 

 2.示例

max_depth:这参数用来控制决策树的最大深度。以下示例,构建1~10深度的决策时,看哪个深度的决策树的精确率(score)高

# -*-coding:utf-8-*-
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

plt.switch_backend("TkAgg")

wine=load_wine()

# print ( wine.feature_names )
#(178, 13)
print(wine.data.shape)


import pandas as pd
# print (pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1))
#所有的train,test必须是二维矩阵
Xtrain,Xtest,Ytrain,Ytest=train_test_split(wine.data,wine.target,test_size=0.3)

test=[]
bestScore=-1
bestClf=None
for i in range(10):
    clf=tree.DecisionTreeClassifier(max_depth=i+1,
                                    criterion="entropy",
                                    random_state=30,
                                    splitter="random")
    clf=clf.fit(Xtrain,Ytrain)
    score=clf.score(Xtest,Ytest)
    test.append(score)
    if score>bestScore:
        bestScore=score
        bestClf=clf
print(test)
print(test.index(bestScore))
#predict返回每个测试样本的分类/回归结果
predicted=bestClf.predict(Xtest)
print(predicted)

#返回每个测试样本的叶子节点的索引
leaf=bestClf.apply(Xtest)
print(leaf)

plt.plot(range(1,11),test,color="red",label="max_depth")
plt.legend()
plt.show()

结果:

(178, 13)
[0.5555555555555556, 0.8148148148148148, 0.9444444444444444, 0.9259259259259259, 0.8518518518518519, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334, 0.8333333333333334]
2
[0 1 0 1 2 0 1 1 1 2 2 0 0 2 0 1 1 0 0 0 0 1 1 0 2 1 0 2 2 1 2 1 1 1 1 0 1
 2 2 0 1 1 2 0 2 1 1 0 1 1 2 1 2 2]
[12  7 12 11  3 12  7  7  4  3  3 12 12  3 12  9  7 12 12 12 12  7  9 12
  3  9 12  3  3  4  3  4  7  7  7 12  7  3  3 12  9  9  3 12  3  7  7 12
  7  7  3  7  3  3]

3.交叉熵验证的示例 

# -*-coding:utf-8-*-
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeRegressor
import sklearn
from sklearn.datasets import fetch_california_housing

housing=fetch_california_housing()
# print(housing)
# print(housing.data)
# print(housing.target)

regressor=DecisionTreeRegressor(random_state=0)

#cv=10,10次交叉验证,default:cv=5
#scoring="neg_mean_squared_error",评价指标是负的均方误差
cross_res=cross_val_score(regressor,
                housing.data,
                housing.target,
                scoring="neg_mean_squared_error",
                cv=10)
print(cross_res)
[-1.30551334 -0.78405711 -0.72809865 -0.50413232 -0.79683323 -0.83698199
 -0.56591889 -1.03621067 -1.02786488 -0.51371889]

4.Titanic生存者预测

数据来源:

Titanic - Machine Learning from Disaster | Kaggle

数据预处理

读取数据 

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
#---------设置pd,在pycharm中显示完全表格-------
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 1000)
#----------------------------------------
data=pd.read_csv("./data.csv")
print (data.head(5))
print(data.info())

   PassengerId  Survived  Pclass                                                 Name     Sex   Age  SibSp  Parch            Ticket     Fare Cabin Embarked
0            1         0       3                              Braund, Mr. Owen Harris    male  22.0      1      0         A/5 21171   7.2500   NaN        S
1            2         1       1  Cumings, Mrs. John Bradley (Florence Briggs Thayer)  female  38.0      1      0          PC 17599  71.2833   C85        C
2            3         1       3                               Heikkinen, Miss. Laina  female  26.0      0      0  STON/O2. 3101282   7.9250   NaN        S
3            4         1       1         Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1      0            113803  53.1000  C123        S
4            5         0       3                             Allen, Mr. William Henry    male  35.0      0      0            373450   8.0500   NaN        S
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB
None

Process finished with exit code 0

筛选特征

data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
print(data.head())
print(data.info())
   PassengerId  Survived  Pclass     Sex   Age  SibSp  Parch     Fare Embarked
0            1         0       3    male  22.0      1      0   7.2500        S
1            2         1       1  female  38.0      1      0  71.2833        C
2            3         1       3  female  26.0      0      0   7.9250        S
3            4         1       1  female  35.0      1      0  53.1000        S
4            5         0       3    male  35.0      0      0   8.0500        S
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 9 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Sex          891 non-null    object 
 4   Age          714 non-null    float64
 5   SibSp        891 non-null    int64  
 6   Parch        891 non-null    int64  
 7   Fare         891 non-null    float64
 8   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 62.8+ KB
None

处理缺失值

#年龄用均值填补
data["Age"]=data["Age"].fillna(data["Age"].mean())
print(data.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 9 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Sex          891 non-null    object 
 4   Age          891 non-null    float64
 5   SibSp        891 non-null    int64  
 6   Parch        891 non-null    int64  
 7   Fare         891 non-null    float64
 8   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 62.8+ KB
None
#删除有缺失值的行,Embarked缺了两行
data=data.dropna()
print(data.info())
<class 'pandas.core.frame.DataFrame'>
Int64Index: 889 entries, 0 to 890
Data columns (total 9 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  889 non-null    int64  
 1   Survived     889 non-null    int64  
 2   Pclass       889 non-null    int64  
 3   Sex          889 non-null    object 
 4   Age          889 non-null    float64
 5   SibSp        889 non-null    int64  
 6   Parch        889 non-null    int64  
 7   Fare         889 non-null    float64
 8   Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(2)
memory usage: 69.5+ KB
None

处理非数值的列

查看非数值列的所有值

print(data["Embarked"].unique())
print(data["Sex"].unique())

#------------结果如下----------
['S' 'C' 'Q']
['male' 'female']
labels=data["Embarked"].unique().tolist()
#x代表data[Embarked]的每一行的值,S-->0,C-->1,Q-->2
data["Embarked"]=data["Embarked"].apply(lambda x:labels.index(x))

#把条件为True的转为int行
#也可以这样写:data.loc[:,"Sex"]=(data["Sex"]=="male").astype("int")
#male-->0,female-->1
data["Sex"]=(data["Sex"]=="male").astype("int")

提取数据

x=data.iloc[:, data.columns!="Survived"]
y=data.iloc[:,data.columns=="Survived"]

#Xtrain:(622, 8)
#划分数据集和测试集
from sklearn.model_selection import train_test_split
Xtrain,Xtest,Ytrain,Ytest=train_test_split(x,y,test_size=0.3)

#把索引变为从0~622
for i in [Xtrain,Xtest,Ytrain,Ytest]:
    i.index=range(i.shape[0])

第一种方法构建决策树

# clf=DecisionTreeClassifier(random_state=25)
# clf=clf.fit(Xtrain,Ytrain)
# score=clf.score(Xtest,Ytest)
# print(score)
from sklearn.model_selection import cross_val_score
# clf=DecisionTreeClassifier(random_state=25)
# score=cross_val_score(clf,x,y,cv=10).mean()
# print(score)



tr=[]
te=[]
for i in range(10):
    clf=DecisionTreeClassifier(random_state=25,
                               max_depth=i+1,
                               criterion="entropy"
                               )
    clf=clf.fit(Xtrain,Ytrain)
    score_tr=clf.score(Xtrain,Ytrain)
    score_te=cross_val_score(clf,x,y,cv=10).mean()

    tr.append(score_tr)
    te.append(score_te)
print(max(te))
plt.plot(range(1,11),tr,color="red",label="train")
plt.plot(range(1,11),te,color="blue",label="test")
#1~10全部显示
plt.xticks(range(1,11))
plt.legend()
plt.show()

不同深度的决策树的测试集和训练集的表现 

 第二种方法构建决策树

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
import matplotlib.pyplot as plt
plt.switch_backend("TkAgg")
from sklearn.model_selection import GridSearchCV
import numpy as np

#---------设置pd,在pycharm中显示完全表格-------
pd.set_option('display.max_columns', 1000)
pd.set_option('display.width', 1000)
pd.set_option('display.max_colwidth', 1000)
#----------------------------------------
data=pd.read_csv("./data.csv")
# print (data.head(5))
# print(data.info())

#去掉姓名、Cabin、票号的特征
data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
# print(data.head())
# print(data.info())

#处理缺失值
#年龄用均值填补
data["Age"]=data["Age"].fillna(data["Age"].mean())
# print(data.info())

#删除有缺失值的行,Embarked缺了两行,所有的数据去掉不完整的行
data=data.dropna()
# print(data.info())

# print(data["Embarked"].unique())
# print(data["Sex"].unique())

labels=data["Embarked"].unique().tolist()
#x代表data[Embarked]的每一行的值,S-->0,C-->1,Q-->2
data["Embarked"]=data["Embarked"].apply(lambda x:labels.index(x))

#把条件为True的转为int行
#也可以这样写:data.loc[:,"Sex"]=(data["Sex"]=="male").astype("int")
#male-->0,female-->1
data["Sex"]=(data["Sex"]=="male").astype("int")

x=data.iloc[:, data.columns!="Survived"]
y=data.iloc[:,data.columns=="Survived"]


#Xtrain:(622, 8)
#划分数据集和测试集
from sklearn.model_selection import train_test_split
Xtrain,Xtest,Ytrain,Ytest=train_test_split(x,y,test_size=0.3)

#把索引变为从0~622
for i in [Xtrain,Xtest,Ytrain,Ytest]:
    i.index=range(i.shape[0])


from sklearn.model_selection import cross_val_score


clf=DecisionTreeClassifier(random_state=25)
#GridSearchCV:满足fit,score,交叉验证三个功能
#parameters:一串参数和这些参数对应的,我们希望网格搜索来搜索对应的参数的取值范围
parameters={
    "criterion":("gini","entropy"),
    "splitter":("best","random"),
    "max_depth":[*range(1,10)],
    "min_samples_leaf":[*range(1,50,5)],
    "min_impurity_decrease":[*np.linspace(0,0.5,20)]
}
GS=GridSearchCV(clf,parameters,cv=10)
gs=GS.fit(Xtrain,Ytrain)

#从输入的参数和参数取值中,返回最佳组合
print(gs.best_params_)

#网格搜索后的模型的评判标准
print(gs.best_score_)
{'criterion': 'entropy', 'max_depth': 3, 'min_impurity_decrease': 0.0, 'min_samples_leaf': 1, 'splitter': 'best'}
0.8297235023041475

这种方法构建的决策树的准确率比第一种的还低

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/843813.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue3 第四节 自定义hook函数以及组合式API

1.自定义hook函数 2.toRef和toRefs 3.shallowRef和shallowReactive 4.readonly和shallowReadonly 5.toRaw和markRaw 6.customref 一.自定义hook函数 ① 本质是一个函数&#xff0c;把setup函数中使用的Composition API 进行了封装,类似于vue2.x中的mixin 自定义hook函数…

Maven介绍-下载-安装-使用-基础知识

Maven介绍-下载-安装-使用-基础知识 Maven的进阶高级用法可查看这篇文章&#xff1a; Maven分模块-继承-聚合-私服的高级用法 文章目录 Maven介绍-下载-安装-使用-基础知识01. Maven1.1 初识Maven1.1.1 什么是Maven1.1.2 Maven的作用 02. Maven概述2.1 Maven介绍2.2 Maven模型…

F5 LTM 知识点和实验 12-使用规则和本地流量策略定制应用程序交付

第十一章:iapp(忽略) 第十二章:使用规则和本地流量策略定制应用程序交付 用最简单的术语来说,iRule是在网络流量通过BIGIP系统时对其执行的脚本。其思想非常简单:规则使您能够编写简单的网络感知代码片段,这些代码以各种方式影响您的网络流量。无论您是希望以BIG-IP内置…

SpringBoot 自动配置--常用配置

&#x1f600;前言 本篇博文是关于SpringBoot 自动配置的一些分享&#xff0c;希望能够帮助到您&#x1f60a; &#x1f3e0;个人主页&#xff1a;晨犀主页 &#x1f9d1;个人简介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章可以帮助到大家&#xff0c;您…

分布式任务调度平台——XXL-JOB

1、为什么需要任务调度平台 1.1、传统的定时任务实现方案不足 在Java中&#xff0c;传统的定时任务实现方案&#xff0c;比如Timer&#xff0c;Quartz等都或多或少存在一些问题&#xff1a; 不支持集群、不支持统计、没有管理平台、没有失败报警、没有监控等。在现在分布式的…

如何快速掌握水土保持方案编制

1、熟悉水土保持常用的主要法律法规、部委规章、规范性文件及技术规范与标准&#xff1b; 2、了解水土保持方案、监测及验收工作开展的流程&#xff1b; 3、熟悉水土保持方案、监测及验收工作需要收集的资料、现场踏勘注意事项&#xff1b; 4、熟悉常见水土保持工程施工工艺…

JavaSE_2.1——数组【概念、创建、内存分配】

今天是练习数组的第一天&#xff0c;后续继续 1、数组的定义以及声明 1.数组的定义&#xff1a;一组能够存储相同数据类型值的变量的集合 2.数组的赋值方式&#xff1a; New关键字:表示创建一个数组&#xff1b; &#xff08;1&#xff09;使用默认的初始值来初始化数组中…

我的Python教程:Tkinter组件布局管理的3种方式

**Tkinter组件布局管理可以使用pack()方法、grid()方法和place()方法。**pack()方法将组件放置在窗口中&#xff0c;grid()方法将组件放置在网格布局中&#xff0c;place()方法将组件放置在指定位置。 01使用pack()方法布局&#xff1a; 在Tkinter中&#xff0c;pack方法用于将…

【双指针_快乐数_C++】

题目解析 快乐数 算法原理 快慢双指针1、定义快慢指针 2、慢指针每次向后移动一步&#xff0c;快指针每次向后移动两步。 3、判断相遇的时候的值 编写代码 class Solution { public:int num_sum(int n){int sum 0;while(n!0){int t n%10;sumt*t;n n/10;}return sum;}bool…

Visio Studio Code 搭建Vue开发环境

一、安装Visual Studio Code 使用 Visual Studio Code&#xff08;VS Code&#xff09;开发 Vue.js 应用是一种常见的做法&#xff0c;以下是简要的步骤&#xff1a; 安装 VS Code&#xff1a; 如果您尚未安装 Visual Studio Code&#xff0c;您可以从官方网站&#xff08;htt…

作为一名软件测试工程师,需要具备哪些能力?

软件测试工程师是个神奇的职业&#xff0c;他是开发人员与老板之间的传话筒&#xff0c;也是开发人员与老板的好帮手。他不仅需要有销售的沟通能力&#xff0c;也需要具备编辑人员的文档撰写技巧。如此一个面面俱到的岗位&#xff0c;他需要具备的技能到底有哪些呢&#xff1f;…

互联网智能3D导诊系统源码支持微信小程序、APP

通过智能导诊&#xff0c;进行自助问询及挂号服务&#xff0c;减轻导诊台护士压力&#xff0c;挂号更加方便快捷。 系统技术架构&#xff1a;springbootredismybatis plusmysqlRocketMQ 一、智慧导诊系统开发原理 导诊系统从原理上大致可分为基于规则模板和基于数据模型两类…

人工智能进入到制造业后,可以带来哪些方面的新改变?

随着人工智能&#xff08;AI&#xff09;进入制造业&#xff0c;它有可能带来重大变化和改进。以下是人工智能可以给制造业带来的一些关键变化&#xff1a; 1.提高效率和生产力&#xff1a;人工智能可以通过分析大量数据并识别低效率来优化生产流程。它可以帮助简化制造运营、…

【100天精通python】Day28:文件与IO操作_JSON文件处理

目录 专栏导读 1. JSON数据格式简介 1.1 示例JSON数据 1.2 JSON文件的特点 2 json模块的常用操作 2.1 读写JSON文件的示例 2.2 解析JSON字符串 2.3 修改JSON数据 2.4 查询和操作嵌套数据 2.5 处理包含特殊字符的JSON文件 2.6 处理日期和时间 2.7 处理大型JSON文…

【web逆向】全报文加密及其登录流程的分析案例

aHR0cHM6Ly9oZWFsdGguZWxkZXIuY2NiLmNvbS9zaWduX2luLw 涉及加密库jsencrypt 定位加密点 先看加密的请求和响应&#xff1a; 全局搜索加密字段jsondata&#xff0c;这种非特定参数的一般一搜一个准&#xff0c;搜到就是断点。起初下的断点没停住&#xff0c;转而从调用栈单步…

MySQL—— 基础语法大全

MySQL—— 基础 一、MySQL概述1.1 、数据库相关概念1.2 、MySQL 客户端连接1.3 、数据模型 二、SQL2.1、SQL通用语法2.2、SQL分类2.3、DDL2.4、DML2.5、DQL2.6、DCL 三、函数四、约束五、多表查询六、事务 一、MySQL概述 1.1 、数据库相关概念 数据库、数据库管理系统、SQL&a…

谈谈DNS是什么?它的作用以及工作流程

作者&#xff1a;Insist-- 个人主页&#xff1a;insist--个人主页 作者会持续更新网络知识和python基础知识&#xff0c;期待你的关注 目录 一、DNS是什么&#xff1f; 二、DNS的作用 三、DNS查询流程 1、查看浏览器缓存 2、查看系统缓存 3、查看路由器缓存 4、查看ISP …

arcgis栅格数据之最佳路径分析

1、打开arcmap&#xff0c;加载数据&#xff0c;需要对影像进行监督分类&#xff0c;如下&#xff1a; 这里任选一种监督分类的方法&#xff08;最大似然法&#xff09;&#xff0c;如下&#xff1a; 这里会先生成一个.ecd文件&#xff0c;然后再利用.ecd文件对影像进行分类。如…

论 SoC上的Linux如何拉动外部I/O

在MCU中&#xff08;如classic autosr或其他RTOS&#xff09;&#xff0c;一般可以直接通过往对应的寄存器&#xff08;地址转为指针&#xff09;写值&#xff0c; 或者调用一些硬件抽象层或者驱动接口来拉动芯片提供的GPIO。 但是在Linux中&#xff0c;可能不会让应用层直接去…

[CKA]考试之查看pod的cpu

由于最新的CKA考试改版&#xff0c;不允许存储书签&#xff0c;本博客致力怎么一步步从官网把答案找到&#xff0c;如何修改把题做对&#xff0c;下面开始我们的 CKA之旅 题目为&#xff1a; Task 找出标签是namecpu-loader的Pod&#xff0c;并过滤出使用CPU最高的Pod&#…