随机森林算法(Random Forest)的二分类问题

news2025/1/11 1:50:58

二分类问题

  • 1. 数据导入
  • 2. RF模型构建
    • 2.1 调参:mtry和ntree
    • 2.2 运行模型
  • 3. 模型测试
  • 4.绘制混淆矩阵
  • 5.绘制ROC曲线
  • 6. 参考

1. 数据导入

library(dplyr) #数据处理使用
library(data.table) #数据读取使用
library(randomForest) #RF模型使用
library(caret) # 调参和计算模型评价参数使用
library(pROC) #绘图使用
library(ggplot2) #绘图使用
library(ggpubr) #绘图使用
library(ggprism) #绘图使用
library(skimr) #查看数据分布
library(caTools) #训练集和测试集划分
setwd("D:/BaiduNetdiskDownload")
# 读取数据
data <- fread("data.txt",data.table = F)  # 替换为你的数据文件名或路径
as.data.frame(data)
skim(data)#数据鸟瞰
colnames(data)
hist(data$feature7, breaks = 50) # 查看数据分布

data <- data[1:1000, ] # 选择前1000个样本进行计算

在这里插入图片描述

  数据一共包含了35723个样本,214个特征,选择其中前1000个样本进行模型构建(数据太大,这样更快一些)。

在这里插入图片描述
  查看一下数据分布情况,是不是符合一定的规律,如正态性之类的。

2. RF模型构建

  数据集分割为训练集和测试集

# 分割数据为训练集和测试集
set.seed(123)  # 设置随机种子,保证结果可复现
split <- sample.split(data$type, SplitRatio = 0.8)  # 将数据按照指定比例分割
train_data <- subset(data, split == TRUE)  # 训练集
test_data <- subset(data, split == FALSE)  # 测试集

# 定义训练集特征和目标变量
X_train <- train_data[, -1]
y_train <- as.factor(train_data[, 1]) #将第一列的标签转换为因子变量

2.1 调参:mtry和ntree

  mtry:随机选择特征数目

# 2.1 mtry的取值是平方根(对于分类问题)或总特征数的三分之一(对于回归问题)
# mtry: 表示每棵决策树在进行节点分裂时考虑的特征数量
# 创建训练控制对象
ctrl <- trainControl(method = "cv", number = 10) # 选择10折交叉验证。
# 定义参数网格
grid <- expand.grid(mtry = c(2: 6))  # 每棵树中用于分裂的特征数量,这里只是随便给的测试,主要为了介绍如何调参,并非最优选择。

# 使用caret包进行调参
rf_model <- train(x = X_train, y = y_train,
                  method = "rf",
                  trControl = ctrl,
                  tuneGrid = grid)

# 输出最佳模型和参数
print(rf_model)

结果:

Random Forest 

800 samples
213 predictors
  2 classes: 'malignant', 'normal' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 720, 720, 720, 721, 720, 719, ... 
Resampling results across tuning parameters:

  mtry  Accuracy   Kappa    
  2     0.9650768  0.9060232
  3     0.9663268  0.9098219
  4     0.9650768  0.9065433
  5     0.9675768  0.9138408
  6     0.9675768  0.9141520

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 5.

选择mtry=5即可

  ntree:

# 调整Caret没有提供的参数
# 如果我们想调整的参数Caret没有提供,可以用下面的方式自己手动调参。
# 用刚刚调参的最佳mtry值固定mtry
grid <- expand.grid(mtry = c(5))  # 每棵树中用于分裂的特征数量

# 定义模型列表,存储每一个模型评估结果
modellist <- list()

# 调整的参数是决策树的数量
for (ntree in c(50, 70, 90)) {
  set.seed(123)
  fit <- train(x = X_train, y = y_train, method="rf", 
               metric="Accuracy", tuneGrid=grid, 
               trControl=ctrl, ntree=ntree)
  key <- toString(ntree)
  modellist[[key]] <- fit
  print(ntree)
}

# compare results
results <- resamples(modellist)
# 输出最佳模型和参数
summary(results)

结果:

Call:
summary.resamples(object = results)

Models: 50, 70, 90 
Number of resamples: 10 

Accuracy 
     Min.  1st Qu.    Median      Mean 3rd Qu.      Max. NA's
50 0.9500 0.962500 0.9748418 0.9687492   0.975 0.9875000    0
70 0.9375 0.953125 0.9688233 0.9637647   0.975 0.9750000    0
90 0.9375 0.962500 0.9748418 0.9674838   0.975 0.9876543    0

Kappa 
        Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
50 0.8624248 0.9016393 0.9330971 0.9177543 0.9354318 0.9682035    0
70 0.8360656 0.8754931 0.9164091 0.9043374 0.9354318 0.9373532    0
90 0.8360656 0.8992677 0.9330971 0.9144723 0.9354318 0.9673519    0

选择ntree为90即可

2.2 运行模型

# 使用最佳参数训练最终模型
final_model <- randomForest(x = X_train, y = y_train,mtry = 5,ntree = 90)
# 输出最终模型
print(final_model)

结果:

Call:
 randomForest(x = X_train, y = y_train, ntree = 90, mtry = 5) 
               Type of random forest: classification
                     Number of trees: 90
No. of variables tried at each split: 5

        OOB estimate of  error rate: 2.38%
Confusion matrix:
          malignant normal class.error
malignant       586      5 0.008460237
normal           14    195 0.066985646

3. 模型测试

# 在测试集上进行预测
X_test <- test_data[, -1]
y_test <- as.factor(test_data[, 1])
test_predictions <- predict(final_model, newdata = test_data)

# 计算模型指标
confusion_matrix <- confusionMatrix(test_predictions, y_test)
accuracy <- confusion_matrix$overall["Accuracy"]
precision <- confusion_matrix$byClass["Pos Pred Value"]
recall <- confusion_matrix$byClass["Sensitivity"]
f1_score <- confusion_matrix$byClass["F1"]

# 输出模型指标
print(confusion_matrix)
print(paste("Accuracy:", accuracy))
print(paste("Precision:", precision))
print(paste("Recall:", recall)) # sensitivity
print(paste("F1 Score:", f1_score))

结果:

> print(confusion_matrix)
Confusion Matrix and Statistics

           Reference
Prediction  malignant normal
  malignant       146      3
  normal            2     49
                                          
               Accuracy : 0.975           
                 95% CI : (0.9426, 0.9918)
    No Information Rate : 0.74            
    P-Value [Acc > NIR] : <2e-16          
                                          
                  Kappa : 0.9346          
                                          
 Mcnemar's Test P-Value : 1               
                                          
            Sensitivity : 0.9865          
            Specificity : 0.9423          
         Pos Pred Value : 0.9799          
         Neg Pred Value : 0.9608          
             Prevalence : 0.7400          
         Detection Rate : 0.7300          
   Detection Prevalence : 0.7450          
      Balanced Accuracy : 0.9644          
                                          
       'Positive' Class : malignant       
                                          
> print(paste("Accuracy:", accuracy))
[1] "Accuracy: 0.975"
> print(paste("Precision:", precision))
[1] "Precision: 0.979865771812081"
> print(paste("Recall:", recall)) # sensitivity
[1] "Recall: 0.986486486486487"
> print(paste("F1 Score:", f1_score))
[1] "F1 Score: 0.983164983164983"

4.绘制混淆矩阵

# 绘制混淆矩阵热图
# 将混淆矩阵转换为数据框
confusion_matrix_df <- as.data.frame.matrix(confusion_matrix$table)
colnames(confusion_matrix_df) <- c("cluster1","cluster2")
rownames(confusion_matrix_df) <- c("cluster1","cluster2")
draw_data <- round(confusion_matrix_df / rowSums(confusion_matrix_df),2)
draw_data$real <- rownames(draw_data)
draw_data <- melt(draw_data)

ggplot(draw_data, aes(real,variable, fill = value)) +
  geom_tile() +
  geom_text(aes(label = scales::percent(value))) +
  scale_fill_gradient(low = "#F0F0F0", high = "#3575b5") +
  labs(x = "True", y = "Guess", title = "Confusion matrix") +
  theme_prism(border = T)+
  theme(panel.border = element_blank(),
        axis.ticks.y = element_blank(),
        axis.ticks.x = element_blank(),
        legend.position="none")

在这里插入图片描述

5.绘制ROC曲线

# 绘制ROC曲线需要将预测结果以概率的形式输出
test_predictions <- predict(final_model, newdata = test_data,type = "prob")

# 计算ROC曲线的参数
roc_obj <- roc(response = y_test, predictor = test_predictions[, 2])
roc_auc <- auc(roc_obj)

# 将ROC对象转换为数据框
roc_data <- data.frame(1 - roc_obj$specificities, roc_obj$sensitivities)

# 绘制ROC曲线
ggplot(roc_data, aes(x = 1 - roc_obj$specificities, y = roc_obj$sensitivities)) +
  geom_line(color = "#0073C2FF", size = 1.5) +
  geom_segment(aes(x = 0, y = 0, xend = 1, yend = 1), linetype = "dashed", color = "gray") +
  geom_text(aes(x = 0.8, y = 0.2, label = paste("AUC =", round(roc_auc, 2))), size = 4, color = "black") +
  coord_cartesian(xlim = c(0, 1), ylim = c(0, 1)) +
  theme_pubr() +
  labs(x = "1 - Specificity", y = "Sensitivity") +
  ggtitle("ROC Curve") +
  theme(plot.title = element_text(size = 14, face = "bold"))+
  theme_prism(border = T)

在这里插入图片描述

6. 参考

  • 机器学习之分类器性能指标之ROC曲线、AUC值

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1128407.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件测试之 测试用例 如何设计

在软件开发过程中&#xff0c;测试是一个至关重要的环节&#xff0c;它有助于确保软件的质量和稳定性。而测试用例设计则是测试过程中的一个关键步骤&#xff0c;它帮助测试团队确定如何测试软件以发现潜在的问题和缺陷。本文将介绍测试用例设计的基本概念和步骤&#xff0c;以…

产品研发团队协作神器!10款提效工具大盘点!

在如今科技驱动的时代&#xff0c;产品研发团队面临着前所未有的竞争压力和不断变化的市场需求。为了在这个激烈的环境中脱颖而出&#xff0c;团队需要高效协作并充分利用先进的工具来提高生产力和创新能力。 本文将为你盘点产品研发团队协作必备的10个提效工具&#xff0c;这…

TiDB、MySQL与Oracle的char字段

文章目录 MySQLTiDBOracle结论 我们组在团队内维护了一套TiDB&#xff0c;有时候会有其他同事来请教一些问题&#xff0c;当然遇到比较复杂的问题&#xff0c;我也会直接抛给DBA。今天有个同事来问了一下TiDB的char字段查询是否需要补空格。在我的印象中&#xff0c;TiDB是高度…

【Docker】Docker学习之一:离线安装Docker步骤

前言&#xff1a;基于Ubuntu Jammy 22.04 (LTS)版本安装和测试 1、Docker安装 1.1、离线安装 步骤一&#xff1a;官网下载 docker 安装包 wget https://download.docker.com/linux/static/stable/x86_64/docker-24.0.6.tgz步骤二&#xff1a;解压安装包; tar -zxvf docker…

安防监控视频汇聚平台EasyCVR增加AI算法列表接口的实现方法

安防监控视频汇聚平台EasyCVR基于云边端一体化架构&#xff0c;具有强大的数据接入、处理及分发能力&#xff0c;可提供视频监控直播、云端录像、云存储、录像检索与回看、智能告警、平台级联、云台控制、语音对讲、智能分析等功能。平台既具备传统安防监控的能力&#xff0c;也…

C++进阶语法之函数和指针【学习笔记(三)】

文章目录 1、C 函数1.1 函数的定义1.2 函数原型&#xff08;function prototypes&#xff09;1.3 参数&#xff08;parameter&#xff09;——值传递&#xff08;pass by value&#xff09;1.4 重载&#xff08;overloading&#xff09;1.5 函数传参——传递数组&#xff08;ar…

Linux下控制GPIO的三种方法

https://blog.csdn.net/qq_41076734/article/details/124669908 1. 应用空间控制gpio 1.1简介 在/sys/class/gpio/下有个export文件&#xff0c;向export文件写入要操作的GPIO号&#xff0c;使得该GPIO的操作接口从内核空间暴露到用户空间&#xff0c;GPIO的操作接口包括dir…

Flink学习笔记(四):Flink 四大基石之 Window 和 Time

文章目录 1、 概述2、 Flink 的 Window 和 Time2.1、Window API2.1.1、WindowAssigner2.1.2、Trigger2.1.3、Evictor 2.2、窗口类型2.2.1、Tumbling Windows2.2.2、Sliding Windows2.2.3、Session Windows2.2.4、Global Windows 2.3、Time 时间语义2.4、乱序和延迟数据处理2.5、…

linux系统安装Googletest单元测试框架

环境信息 系统&#xff1a;ubuntn cmake版本&#xff1a;3.5.1 gcc版本&#xff1a;5.4.0 1、下载googletest git clone https://github.com/google/googletest.git注意&#xff01;不选branch的话默认下载最新版本&#xff08;需要编译器能够支持C14&#xff09;&#xff0c;…

如何生成osg的动画路径文件

目录 1. 前言 2. 生成动画路径文件 2.1. 粗糙方式 2.2. 精确方式 1. 前言 在进行osg的开发中&#xff0c;有时需要对模型按某个路径或规则进行动画&#xff0c;如下&#xff1a; 奶牛在10秒时间段从起始的osg::Vec3d(0.0, 18, 1.0)位置 匀速直线运动到osg::Vec3d(0.0, -8, …

ubuntu 中使用Qt连接MMSQl,报错libqsqlodbc.so: undefined symbol: SQLAllocHandle

Qt4.8.7的源码编译出来的libqsqlodbc.so&#xff0c;在使用时报错libqsqlodbc.so: undefined symbol: SQLAllocHandle&#xff0c;需要在编译libqsqlodbc.so 的项目pro文件加上LIBS -L/usr/local/lib -lodbc。 这里的路径根据自己的实际情况填写。 编辑&#xff1a; 使用uni…

Python数据结构(队列)

Python数据结构&#xff08;队列&#xff09; 队列(queue)是只允许在一端进行插入操作&#xff0c;而在另一端进行删除操作的线性表。 队列是一种先进先出的 (First n First ut)的线性表&#xff0c;简称FIFO。允许插入的一端为队尾&#xff0c;允许删除的一端为队头&#xff…

深度学习——图像分类(CIFAR-10)

深度学习——图像分类&#xff08;CIFAR-10&#xff09; 文章目录 前言一、实现图像分类1.1. 获取并组织数据集1.2. 划分训练集、验证集1.3. 图像增广1.4. 引入数据集1.5. 定义模型1.6. 定义训练函数1.7. 训练模型并保存模型参数 二、生成一个桌面小程序2.1. 使用QT设计师设计界…

建筑工程模板分类以及特点?

建筑工程模板可以根据材料、结构和用途等方面进行分类。以下是一些常见的建筑工程模板分类及其特点&#xff1a; 1. 木质模板&#xff1a; - 特点&#xff1a;木质模板是最常见和传统的模板类型。它们由木材制成&#xff0c;具有良好的可塑性和可加工性。木质模板适用于各种形状…

京东数据分析:2023年9月京东白酒行业品牌销售排行榜

鲸参谋监测的京东平台9月份白酒市场销售数据已出炉&#xff01; 9月白酒市场的整体热度较高&#xff0c;贵州茅台先是与瑞幸联名推出酱香拿铁&#xff0c;后又宣布与德芙推出联名产品酒心巧克力&#xff0c;引起了诸多消费者的关注。在这一热度的加持下&#xff0c;从销售上看&…

Kali Linux 安装搭建 hadoop 平台 调用 wordcount 示例程序 详细教程

步骤一&#xff1a; 目标&#xff1a;*安装虚拟机&#xff0c;在自己虚拟机上完成hadoop的伪分布式安装。&#xff08;安装完成后要检查&#xff09;* 1&#xff09;前期环境准备&#xff1a;&#xff08;虚拟机、jdk、ssh&#xff09; 2&#xff09;SSH相关配置 安装SSH Se…

【MyBatis篇】MyBatis动态代理总结

本人正在浅学mybatis&#xff0c;正学到mybatis动态代理&#xff0c;在查询多方资料之后做出以下总结&#xff0c;以便于系统学习时回顾&#xff1b; 目录 MyBatis为什么引入动态代理 mybatis的动态代理 Dao代理技术 MyBatis为什么引入动态代理 因为程序员的 懒&#xff0c;…

访问控制2

文章目录 主要内容一.Role和ClusterRole1.ClusterRole示例&#xff0c;创建一个名为test-clusterrole且仅有创建Pod和deployment的集群角色代码如下&#xff08;示例&#xff09;: 2.YAML文件创建代码如下&#xff08;示例&#xff09;: 3.将udbs用户和Clusterrole进行绑定&…

0基础学习VR全景平台篇第112篇:控制点和遮罩工具 - PTGui Pro教程

上课&#xff01;全体起立~ 大家好&#xff0c;欢迎观看蛙色官方系列全景摄影课程&#xff01; 前情回顾&#xff1a;上节&#xff0c;我们用PTGui拼接了一张全景图&#xff0c;全景编辑器里的各项功能帮助我们进行了初步的检查和编辑。 之后我们需要使用【控制点】和【遮罩…

智慧垃圾站:AI视频智能识别技术助力智慧环保项目,以“智”替人强监管

一、背景分析 建设“技术先进、架构合理、开放智能、安全可靠”的智慧环保平台&#xff0c;整合环境相关的数据&#xff0c;对接已建业务系统&#xff0c;将环境相关数据进行统一管理&#xff0c;结合GIS技术进行监测、监控信息的展现和挖掘分析&#xff0c;实现业务数据的快速…