机器学习:集成学习之随机森林

news2024/9/21 18:43:57

目录

前言

一、集成学习

1.集成学习的含义

2.集成学习的代表

3.集成学习的应用

二、随机森林

1.随机森林的特点

2.随机森林生成步骤

3.随机森林优点

4.随机森林的缺点

三、代码实现

1.完整代码

2.数据预处理

3.创建并训练模型

4.测试模型

总结


前言

        随机森林是一种集成学习方法,主要用于分类和回归任务。它通过构建多个决策树并将其结果结合起来,提高模型的准确性。每棵树在训练时使用数据的随机子集和特征的随机子集,从而降低过拟合风险,并增强模型的泛化能力。最终预测是通过对所有树的预测结果进行投票(分类)或平均(回归)来实现的。

 

一、集成学习

1.集成学习的含义

        集成学习是将多个基础学习器进行组合,来实现比单一学习器显著优越的学习性能

 

2.集成学习的代表

  • bagging方法:典型的是随机森林
  • boosting方法:典型的是Xgboost
  • stacking方法:堆叠模型

 

3.集成学习的应用

  1. 分类问题集成
  2. 回归问题集成
  3. 特征选取集成

 

 

二、随机森林

1.随机森林的特点

  1. 数据采样随机:随机从训练集中选取自定百分比的数据
  2. 特征选取随机:随机从特征中选取自定百分比的特征
  3. 森林:很多树
  4. 基分类器为决策树

 

2.随机森林生成步骤

  1. 生成多个决策树
    1. 从原始数据集中通过Bootstrap抽样生成多个子集,每个子集用于训练一棵决策树。
    2. 在每棵树的训练过程中,随机选择特征子集进行节点分裂,增加树的多样性。
  2. 预测与投票
    1. 对于分类任务,通过对所有决策树的预测结果进行投票,选择票数最多的类别作为最终预测。
    2. 对于回归任务,通过对所有决策树的预测结果进行平均,得到最终的预测值。

 

3.随机森林优点

  1. 具有极高的准确率。
  2. 随机性的引入,使得随机森林的抗噪声能力很强。
  3. 随机性的引入,使得随机森林不容易过拟合。
  4. 能够处理很高维度的数据,不用做特征选择。
  5. 容易实现并行化计算。

 

4.随机森林的缺点

  1. 当随机森林中的决策树个数很多时,训练时需要的空间和时间会较大。
  2. 随机森林模型还有许多不好解释的地方,有点算个黑盒模型,

 

 

三、代码实现

  • 本次使用的是多特征二分类数据

1.完整代码

import pandas as pd
from sklearn.model_selection import train_test_split


# 可视化混淆矩阵
def cm_plot(y, yp):
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    cm = confusion_matrix(y, yp)
    plt.matshow(cm, cmap=plt.cm.Blues)
    plt.colorbar()
    for x in range(len(cm)):
        for y in range(len(cm)):
            plt.annotate(cm[x, y], xy=(y, x), horizontalalignment='center',
                         verticalalignment='center')
            plt.ylabel('True label')
            plt.xlabel('Predicted label')
    return plt


data = pd.read_csv('spambase.csv')

x = data.iloc[:, :-1]  # 取出特征数据
y = data.iloc[:, -1]   # 取出标签

x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.3, random_state=0)

"""
n_estimators:决策树的个数
max_feature:特征的个数
"""
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(
    n_estimators=100
    , max_features=0.8  # 80%的特征
    , random_state=0
)
rf.fit(x_train, y_train)

from sklearn import metrics

train_predict = rf.predict(x_train)
print(metrics.classification_report(y_train, train_predict))

test_predict = rf.predict(x_test)
print(metrics.classification_report(y_test, test_predict))

cm_plot(y_test, test_predict).show()

输出:

  • 可视化混淆矩阵——测试集

  • 混淆矩阵
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      1954
           1       1.00      1.00      1.00      1263

    accuracy                           1.00      3217
   macro avg       1.00      1.00      1.00      3217
weighted avg       1.00      1.00      1.00      3217

              precision    recall  f1-score   support

           0       0.94      0.97      0.95       831
           1       0.95      0.91      0.93       549

    accuracy                           0.94      1380
   macro avg       0.95      0.94      0.94      1380
weighted avg       0.94      0.94      0.94      1380

 

2.数据预处理

  • 取出训练集,测试集的特征数据和标签
import pandas as pd
from sklearn.model_selection import train_test_split

data = pd.read_csv('spambase.csv')

x = data.iloc[:, :-1]
y = data.iloc[:, -1]

x_train, x_test, y_train, y_test = \
    train_test_split(x, y, test_size=0.3, random_state=0)

 

3.创建并训练模型

  • 创建一个100个决策树的随机森林,每棵树选取80%的特征进行训练
"""
n_estimators:决策树的个数
max_feature:特征的个数
"""
from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(
    n_estimators=100
    , max_features=0.8  # 80%的特征
    , random_state=0
)
rf.fit(x_train, y_train)

 

4.测试模型

  • 使用训练集数据和测试集数据进行测试,得到结果
from sklearn import metrics

train_predict = rf.predict(x_train)
print(metrics.classification_report(y_train, train_predict))

test_predict = rf.predict(x_test)
print(metrics.classification_report(y_test, test_predict))

cm_plot(y_test, test_predict).show()

输出:

  • 虽然训练集数据进行测试时正确率非常高,看起来像过拟合
  • 但是不用担心,测试集正确率并没有下降多少
  • 说明该模型并没有过拟合
  • 可以看出随机森林不仅正确率高,还不容易过拟合
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      1954
           1       1.00      1.00      1.00      1263

    accuracy                           1.00      3217
   macro avg       1.00      1.00      1.00      3217
weighted avg       1.00      1.00      1.00      3217

              precision    recall  f1-score   support

           0       0.94      0.97      0.95       831
           1       0.95      0.91      0.93       549

    accuracy                           0.94      1380
   macro avg       0.95      0.94      0.94      1380
weighted avg       0.94      0.94      0.94      1380

 

总结

        本篇讲述了集成学习的概念,随机森林的概念,特点,步骤和优缺点,最后使用代码实例演示了随机森林的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2074914.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

集合及数据结构第十二节(下)————哈希表、字符串常量池和练习题

系列文章目录 集合及数据结构第十二节(下)————哈希表、字符串常量池和练习题 哈希表、字符串常量池和练习题 哈希表的概念冲突-概念冲突-避免冲突-解决冲突严重时的解决办法冲突严重时的解决办法的实现性能分析和 java 类集的关系Hashmap的使用案…

R8RS标准之重要特性及用法实例(四十)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列…

LDR6500Type-C pd OTGi协议芯片讲解

LDR6500是一款由乐得瑞科技推出的USB-C DRP(Dual Role Port,双角色端口)接口USB PD(Power Delivery,功率传输)通信芯片。这款芯片具备一系列先进的功能和特点,特别适合于手机音频转接器、USB Ty…

QT中引入SQLITE3数据库

1、把sqlite3.dll、.h、.lib这三个文件拷贝到工程目录下 2、在pro文件中配置一下即可 LIBS $$PWD/sqlite3.lib 3、保存一下pro文件 4、引入sqlite3.h头文件 5、验证 先新建一个文件夹data,若没有user.db,则会自动新建;有就直接使用 运行成…

UTONMOS:探索未来游戏的元宇宙纪元新篇章

元宇宙游戏,作为融合了虚拟现实(VR)、增强现实(AR)、区块链、人工智能(AI)等前沿技术的综合性数字世界,元宇宙游戏不仅重新定义了游戏的边界,更预示着一个沉浸式、交互性…

YOlOV5入门教程

前言 因项目需求,所以要使用yolo进行操作,现在对yolov5进行教程,代码可以在这下载:https://github.com/ultralytics/yolov5 项目结构 下载完成后可以看到资源如图所示。 1.1.github文件夹 ISSUE_TEMPLATE 目录 这个目录下的文件…

Cesium 展示——绘制水面动态升高

文章目录 需求分析需求 如图,绘制水面动态升高,作为洪水淹没的效果 分析 我们首先需要绘制一个面然后给这个面一个高度,在回调函数中进行动态设置值【这里有两种,一种是到达水面一定高度停止升高,一种是水面重新升高】/*** @description :洪水淹没* @author : Hukang*…

关闭IDEA启动画面

新版IDEA启动时启动画面居中且无法最小化,所以想把它给隐藏掉。(此操作不会加快启动速度) 在快捷方式后加入参数 nosplash,记得有个空格。

Java | Leetcode Java题解之第374题猜数字大小

题目&#xff1a; 题解&#xff1a; public class Solution extends GuessGame {public int guessNumber(int n) {int left 1, right n;while (left < right) { // 循环直至区间左右端点相同int mid left (right - left) / 2; // 防止计算时溢出if (guess(mid) < 0)…

CSV文件的高级处理:从大型文件处理到特殊字符管理

目录 一、处理大型CSV文件 1.1 面临的挑战 1.2 使用Pandas库 1.3 注意事项 二、跳过无效行 2.1 无效行的原因 2.2 使用异常处理机制 2.3 注意事项 三、处理特殊字符 3.1 特殊字符的问题 3.2 使用引号包围字段 3.3 使用库函数处理特殊字符 结论 CSV&#xff08;Com…

Web大学生网页作业成品——节日端午节介绍网页设计与实现(HTML+CSS)(5个页面)

&#x1f389;&#x1f389;&#x1f389; 常见网页设计作业题材有**汽车、环保、明星、文化、国家、抗疫、景点、人物、体育、植物、公益、图书、节日、游戏、商城、旅游、家乡、学校、电影、动漫、非遗、动物、个人、企业、美食、婚纱、其他**等网页设计题目, 可满足大学生网…

计算机网络面试真题总结(三)

文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ 文章收录在网站&#xff1a;http://hardyfish.top/ TCP 和 UDP 分别对应的常见应用层协议有哪些&#xff1f; TCP 对应…

帮助我们从曲线图中获取数据的软件分享——GetData Graph Digitizer

在科技论文写作和数据分析过程中&#xff0c;我们常常需要将自己的数据与前人的研究成果进行对比。然而&#xff0c;有时我们只能从别人的论文中获得一张包含坐标轴的曲线图&#xff0c;而无法直接获取原始数据。在这种情况下&#xff0c;GetData Graph Digitizer 软件就显得尤…

(24)(24.4) MultiWii/DJI/HDZero OSD (version 4.2 and later)(三)

文章目录 前言 3 显示端口OSD 前言 经过 WTF-OSD 修改的 HDZero、Walksnail 和 DJI 能够进行 DisplayPort 操作。 3 显示端口OSD DisplayPort 是一种 MSP 协议扩展&#xff0c;允许自动驾驶仪在兼容的外部操作系统上远程绘制文本。DisplayPort 是一种 MSP 协议扩展&#xf…

架构师篇-21、工作坊实战DDD分解业务

课程内容&#xff1a; 采用工作坊的教学模式共创主题一&#xff1a;DDD业务分析步骤共创主题二&#xff1a;DDD领域模型输出共创主题三&#xff1a;业务架构蓝图输出 收益&#xff1a; 如何采用DDD进行业务分解&#xff1f;【循序渐进不断实践】共创输出项目业务架构图及业务…

xtrabackup 用户权限

xtrabackup 用户权限 1.1、建用户及授权 The database user needs the following privileges on the tables/databases to be backed up: RELOAD and LOCK TABLES (unless the --no-lock option is specified) in order to FLUSH TABLES WITH READ LOCK and FLUSH ENGINE LO…

【C++】vector(上)

个人主页~ vector类 一、vector的介绍和使用1、vector的介绍2、vector的使用&#xff08;1&#xff09;vector的定义&#xff08;2&#xff09;vector iterator的使用&#xff08;3&#xff09;vector 空间增长&#xff08;4&#xff09;vector的增删查改&#xff08;5&#xf…

linux qt编写串口软件

1.界面布局 界面的简单设置&#xff0c;用到了 1.输入显示栏 2.数据发送栏 3.选择栏 4.16进制显示栏 和若干pushbottom label&#xff0c;布局就是横竖横竖这样布局下去 对界面进行基础的对齐美化 1.右侧布局的对齐 添加设置代码 右上选项已对齐 有个校验位一开始忘记添加…

PostgreSQL:后端开发者的瑞士军刀

PostgreSQL&#xff1a;后端开发者的瑞士军刀 在后端开发的世界中&#xff0c;PostgreSQL不仅是一个数据库&#xff0c;它更像是一个多功能的瑞士军刀&#xff0c;为开发者提供了强大的工具来构建和维护复杂的数据系统。作为一名资深后端开发者&#xff0c;我想分享一些关于Po…

Nginx四层负载均衡

1、Nginx四层负载均衡 1.1 负载均衡概述 负载均衡是一种分布式计算技术&#xff0c;用于将网络流量和用户请求分散到多台服务器上&#xff0c;以此来提高网络服务的可用性和可靠性。它通过优化资源使用、最大化吞吐量以及最小化响应时间&#xff0c;增强了网络、服务器和数据…