六、决策树算法(DT,DecisionTreeClassifier)(有监督学习)

news2025/1/16 8:07:18

决策树(DT)是一种用于分类和回归的非参数监督学习方法。其目标是创建一个模型,通过学习从数据特征中推断出的简单决策规则来预测目标变量的值。一棵树可以看作是一个片断常数近似值。

一、算法思路

具体可参考博文:七、决策树算法和集成算法

基尼系数Gini:衡量选择标准的不确定程度;说白了,就是越不确定Gini系数越高
需要选择最小的Gini系数来决定决策树下一级别分类的标准
在这里插入图片描述
以基尼系数为核心的决策树称为CART决策树(Classification and Regression Tree)
一般看到的决策树都是二叉树,这只是一种选择,并不代表所有决策树都是二叉树
决策树的生成容易造成过拟合现象的产生,需要剪枝操作来放弃一些约束条件达到防止过拟合的效果

官网决策树算法介绍:1.10. Decision Trees

二、官网API

官网API

class sklearn.tree.DecisionTreeClassifier(*, criterion='gini', splitter='best', max_depth=None, min_samples_split=2, min_samples_leaf=1, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, class_weight=None, ccp_alpha=0.0)

导包:from sklearn.tree import DecisionTreeClassifier

①特征标准选择criterion

criterion衡量分割质量的函数
gini’:用于衡量gini不纯度,默认值
log_loss’和’entropy’:均用于衡量Shannon信息增益

具体官网详情如下:
在这里插入图片描述

使用方法

DecisionTreeClassifier(criterion='entropy')

②splitter

splitter用于选择每个节点分割的策略
best’:最佳分割,默认值
random’:最佳随机分割

具体官网详情如下:
在这里插入图片描述

使用方法

DecisionTreeClassifier(splitter='random')

③max_features

max_features寻找最佳分割时需要考虑的特征数量
auto’:每次分割时,特征数量根据具体情况进行自动选择
sqrt’:每次分割时,特征数量采用max_features=sqrt(n_features)
log2’:每次分割时,特征数量采用max_features=log2(n_features)
int’:在每次分割时考虑max_features特征
float’:特征值是一个分数;每次分割时,特征数量采用max(1,int(max_features * n_features_in_))
None’:每次分割时,特征数量采用max_features=n_features,默认值

具体官网详情如下:
在这里插入图片描述

使用方法

DecisionTreeClassifier(max_features='auto')

④随机种子random_state

控制估计器的随机性,随即状态实例,如果要是为了对比,需要控制变量的话,这里的随机种子最好设置为同一个整型数

在每次分割时,即使splitter=‘best’,特征也总是随机排列的
当 max_features < n_features 时,算法会在每次分割时随机选择 max_features,然后再从中找出最佳分割
但是,即使 max_features=n_features 在不同的运行中找到的最佳分割也可能不同
如果多个分割点的改进标准相同,且必须随机选择一个分割点,就会出现这种情况;为了在拟合过程中获得确定的行为,必须将 random_state 设为整数

具体官网详情如下:
在这里插入图片描述

使用方法

DecisionTreeClassifier(random_state=42)

三、代码实现

①导包

这里需要评估、训练、保存和加载模型,以下是一些必要的包,若导入过程报错,pip安装即可

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import joblib
%matplotlib inline
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score

②加载数据集

数据集可以自己简单整个,csv格式即可,我这里使用的是6个自变量X和1个因变量Y
在这里插入图片描述

fiber = pd.read_csv("./fiber.csv")
fiber.head(5) #展示下头5条数据信息

在这里插入图片描述

③划分数据集

前六列是自变量X,最后一列是因变量Y

常用的划分数据集函数官网API:train_test_split
在这里插入图片描述
test_size:测试集数据所占比例
train_size:训练集数据所占比例
random_state:随机种子
shuffle:是否将数据进行打乱
因为我这里的数据集共48个,训练集0.75,测试集0.25,即训练集36个,测试集12个

X = fiber.drop(['Grade'], axis=1)
Y = fiber['Grade']

X_train, X_test, y_train, y_test = train_test_split(X,Y,train_size=0.75,test_size=0.25,random_state=42,shuffle=True)

print(X_train.shape) #(36,6)
print(y_train.shape) #(36,)
print(X_test.shape) #(12,6)
print(y_test.shape) #(12,)

④构建DT模型

参数可以自己去尝试设置调整

dtc = DecisionTreeClassifier()

⑤模型训练

就这么简单,一个fit函数就可以实现模型训练

dtc.fit(X_train,y_train)

⑥模型评估

把测试集扔进去,得到预测的测试结果

y_pred = dtc.predict(X_test)

看看预测结果和实际测试集结果是否一致,一致为1否则为0,取个平均值就是准确率

accuracy = np.mean(y_pred==y_test)
print(accuracy)

也可以通过score得分进行评估,计算的结果和思路都是一样的,都是看所有的数据集中模型猜对的概率,只不过这个score函数已经封装好了,当然传入的参数也不一样,需要导入accuracy_score才行,from sklearn.metrics import accuracy_score

score = dtc.score(X_test,y_test)#得分
print(score)

⑦模型测试

拿到一条数据,使用训练好的模型进行评估
这里是六个自变量,我这里随机整个test = np.array([[16,18312.5,6614.5,2842.31,25.23,1147430.19]])
扔到模型里面得到预测结果,prediction = dtc.predict(test)
看下预测结果是多少,是否和正确结果相同,print(prediction)

test = np.array([[16,18312.5,6614.5,2842.31,25.23,1147430.19]])
prediction = dtc.predict(test)
print(prediction) #[2]

⑧保存模型

lsvc是模型名称,需要对应一致
后面的参数是保存模型的路径

joblib.dump(dtc, './dtc.model')#保存模型

⑨加载和使用模型

dtc_yy = joblib.load('./dtc.model')

test = np.array([[11,99498,5369,9045.27,28.47,3827588.56]])#随便找的一条数据
prediction = dtc_yy.predict(test)#带入数据,预测一下
print(prediction) #[4]

完整代码

模型训练和评估,不包含⑦⑧⑨。

from sklearn.tree import DecisionTreeClassifier
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

fiber = pd.read_csv("./fiber.csv")
# 划分自变量和因变量
X = fiber.drop(['Grade'], axis=1)
Y = fiber['Grade']
#划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=0)

dtc = DecisionTreeClassifier(criterion='entropy',splitter='random',max_features='auto',random_state=42)
dtc.fit(X_train,y_train)#模型拟合
y_pred = dtc.predict(X_test)#模型预测结果
accuracy = np.mean(y_pred==y_test)#准确度
score = dtc.score(X_test,y_test)#得分
print(accuracy)
print(score)

test = np.array([[11,99498,5369,9045.27,28.47,3827588.56]])#随便找的一条数据
prediction = dtc.predict(test)#带入数据,预测一下
print(prediction)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1031668.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

linux升级glibc-2.28

1.准备工作 1.1升级gcc到gcc8 # 安装devtoolset-8-gcc yum install centos-release-scl yum install devtoolset-8 scl enable devtoolset-8 -- bash# 启用工具 source /opt/rh/devtoolset-8/enable # 安装GCC-8 yum install -y devtoolset-8-gcc devtoolset-8-gcc-c devtoolse…

【C语言】数组和指针刷题练习

指针和数组我们已经学习的差不多了&#xff0c;今天就为大家分享一些指针和数组的常见练习题&#xff0c;还包含许多经典面试题哦&#xff01; 一、求数组长度和大小 普通一维数组 int main() {//一维数组int a[] { 1,2,3,4 };printf("%d\n", sizeof(a));//整个数组…

[plugin:vite:css] [sass] Undefined mixin.

前言&#xff1a; vite vue3 TypeScript环境 scss报错&#xff1a; [plugin:vite:css] [sass] Undefined mixin. 解决方案&#xff1a; 在vite.config.ts文件添加配置 css: {preprocessorOptions: {// 导入scss预编译程序scss: {additionalData: use "/resources/_ha…

如何使用远程桌面软件进行远程工作

远程工作提供了更大的灵活性和自由度&#xff0c;使得可以在任何地点工作。而要实现高效的远程工作&#xff0c;一个关键的工具就是远程桌面软件。本文将详细介绍如何使用远程桌面软件进行远程工作&#xff0c;以帮助读者提高工作效率。 一、了解远程桌面软件的基本原理 远程桌…

带你一步实现《栈》(括号匹配问题)

栈的结构及概念 栈是一种特殊的线性表&#xff0c;只允许在固定的一端插入或删除数据&#xff0c;进行插入和删除的一端被称为栈顶&#xff0c;另一端称为栈底。栈中的数据遵循后进先出原则 LIFO&#xff08;LAST IN FIRST OUT) 俗称栈的插入过程叫做压栈&#xff0c;入栈&…

Batbot智慧能源管理云平台:拥抱数字化,提高能源效率!

我们拥抱数字化&#xff0c;以帮助提高能源效率。 政府已采取措施增强国家的环境信誉&#xff0c;旨在实现雄心勃勃的法定目标&#xff0c;即到2035年&#xff0c;将国家温室气体排放量减少78%&#xff08;与1990年相比&#xff09;。 拥抱数字化&#xff0c;提高能源效率&a…

HTTP 协商缓存 Last-Modified,If-Modified-Since

浏览器第一次跟服务器请求一个资源&#xff0c;服务器在返回这个资源的同时&#xff0c;在respone header加上Last-Modified属性&#xff08;表示这个资源在服务器上的最后修改时间&#xff09;&#xff1a; ----------------------------------------------------------------…

ThinkPHP5,使用unionAll取出两个毫无相关字段表的数据且分页

一&#xff1a;首先来了解一下 union 和 unionAll 1&#xff1a;取结果的并集&#xff0c;是否去重 union&#xff1a;对两个结果集进行并集操作&#xff0c;不包括重复行&#xff0c;相当于distinct&#xff0c;同时进行默认规则的排序&#xff1b; unionAll&#xff1a;对两…

JVM面试题-JVM对象的创建过程、内存分配、内存布局、访问定位等问题详解

对象 内存分配的两种方式 指针碰撞 适用场合&#xff1a;堆内存规整&#xff08;即没有内存碎片&#xff09;的情况下。 原理&#xff1a;用过的内存全部整合到一边&#xff0c;没有用过的内存放在另一边&#xff0c;中间有一个分界指针&#xff0c;只需要向着没用过的内存…

【QT】QRadioButton的使用(17)

QRadioButton这个控件在实际项目中多用于多个QRadioButton控件选择其中一个这样的方式去执行&#xff0c;那么&#xff0c;今天这节就通过几个简单的例子来好好了解下QRadioButton的一个使用。 一.环境配置 1.python 3.7.8 可直接进入官网下载安装&#xff1a;Download Pyt…

PIL或Pillow学习2

接着学习下Pillow常用方法&#xff1a; PIL_test1.py : 9, Pillow图像降噪处理由于成像设备、传输媒介等因素的影响&#xff0c;图像总会或多或少的存在一些不必要的干扰信息&#xff0c;我们将这些干扰信息统称为“噪声”&#xff0c; 比如数字图像中常见的“椒盐噪声”&…

聊一聊Twitter的雪花算法

什么是Twitter的雪花算法方法&#xff1f; 这是一种在分布式系统中生成唯一ID的解决方案。Twitter在推文、私信、列表等方面使用这种方法。 •ID是唯一且可排序的•ID包含时间信息&#xff08;按日期排序&#xff09;•ID适用于64位无符号整数•仅包含数字值 符号位&#xff08…

芋道商城,基于 Vue + Uniapp 实现,支持分销、拼团、砍价、秒杀、优惠券、积分、会员等级、小程序直播、页面 DIY 等功能

商城简介 芋道商城&#xff0c;基于 芋道开发平台 构建&#xff0c;以开发者为中心&#xff0c;打造中国第一流的 Java 开源商城系统&#xff0c;全部开源&#xff0c;个人与企业可 100% 免费使用。 有任何问题&#xff0c;或者想要的功能&#xff0c;可以在 Issues 中提给艿艿…

【从0学习Solidity】 10. 控制流,用solidity实现插入排序

【从0学习Solidity】10. 控制流&#xff0c;用solidity实现插入排序 博主简介&#xff1a;不写代码没饭吃&#xff0c;一名全栈领域的创作者&#xff0c;专注于研究互联网产品的解决方案和技术。熟悉云原生、微服务架构&#xff0c;分享一些项目实战经验以及前沿技术的见解。关…

python中的NaN在质量控制中怎么处理?

一、数据中的缺省值 气象数据中经常存在缺省值&#xff0c;比如未入库的站点数据、比如海温格点实况数据中的陆地区域。这些缺省值往往被赋予NaN&#xff08;Not a Number&#xff0c;非数&#xff09;。NaN是计算机科学中数值数据类型的一类值&#xff0c;表示未定义或不可表…

远程端点管理和安全性

当今的企业网络环境是一个分布式动态环境&#xff0c;其中有许多需要管理、验证和保护的移动部件&#xff0c;而不会对最终用户的生产力产生任何威慑力。提供有效的端点管理安全性&#xff0c;同时仍提供无缝最终用户体验的解决方案至关重要。 Endpoint Central 执行的活动可确…

Linux高性能服务器编程 学习笔记 第六章 高级IO函数

pipe函数用于创建一个管道&#xff0c;以实现进程间通信&#xff1a; fd参数是一个包含两个int的数组。该函数成功时返回0&#xff0c;并将一对打开的文件描述符填入其参数指向的数组&#xff0c;如果失败&#xff0c;则返回-1并设置errno。 pipe函数创建的这两个文件描述符f…

用Python爬取短视频列表

短视频是一款备受欢迎的短视频分享平台&#xff0c;每天都有大量精彩的视频内容等待我们去探索。在本文中&#xff0c;我们将分享如何使用Python爬取短视频的视频列表&#xff0c;让您能够发现更多有趣的视频。 一、安装必要的库 在开始之前&#xff0c;确保已安装以下库&…

Unity——对象池

对象池是一种朴素的优化思想。在遇到需要大量创建和销毁同类物体的情景时&#xff0c;可以考虑使用对象池技术优化游戏性能。 一、为什么要使用对象池 在很多类型的游戏中都会创建和销毁大量同样类型的物体。例如&#xff0c;飞行射击游戏中有大量子弹&#xff0c;某些动作游戏…

java用easyexcel按模版导出

首先在项目的resources下面建一个template包&#xff0c;之后在下面创建一个模版&#xff0c;模版格式如下&#xff1a; 名称为 financeReportBillStandardTemplateExcel.xlsx&#xff1a; {.fee}类型的属性值&#xff0c;是下面实体类的属性&#xff0c;要注意这里面的格式&a…