全代码 | 随机森林在回归分析中的经典应用

news2025/1/12 16:04:05

公众号后台记录了发表过文章的各项阅读指标包括:内容标题,总阅读人数,总阅读次数,总分享人数,总分享次数,阅读后关注人数,送达阅读率,分享产生阅读次数,首次分享率,每次分享带来阅读次数,阅读完成率。

我们尝试利用机器学习中的随机森林算法预测下,是否存在某些指标或指标组合可以预测阅读后关注人数。

数据格式和读入数据

数据集包括1588篇文章的9个统计指标。

  • 阅读统计矩阵: WeChatOfficialAccount.txt

  • 阅读后关注人数:

    WeChatOfficialAccountFollowers.txt

feature_file <- "data/WeChatOfficialAccount.txt"
metadata_file <- "data/WeChatOfficialAccountFollowers.txt"

feature_mat <- read.table(feature_file, row.names = 1, header = T, sep="\t", stringsAsFactors =T)

# 处理异常的特征名字
# rownames(feature_mat) <- make.names(rownames(feature_mat))

metadata <- read.table(metadata_file, row.names=1, header=T, sep="\t", stringsAsFactors =T)

dim(feature_mat)
## [1] 1588    9

阅读统计表示例如下:

feature_mat[1:4,1:5]
##   TotalReadingPeople TotalReadingCounts TotalSharingPeople TotalSharingCounts ReadingRate
## 1               8278              11732                937               1069      0.0847
## 2               8951              12043                828                929      0.0979
## 3              18682              22085                781                917      0.0608
## 4               4978               6166                525                628      0.0072

Metadata表示例如下

head(metadata)
##   FollowersAfterReading
## 1                   227
## 2                   188
## 3                   119
## 4                   116
## 5                   105
## 6                   100

样品筛选和排序

样本表和表达表中的样本顺序对齐一致也是需要确保的一个操作。

feature_mat_sampleL <- rownames(feature_mat)
metadata_sampleL <- rownames(metadata)

common_sampleL <- intersect(feature_mat_sampleL, metadata_sampleL)

# 保证表达表样品与METAdata样品顺序和数目完全一致
feature_mat <- feature_mat[common_sampleL,,drop=F]
metadata <- metadata[common_sampleL,,drop=F]

判断是分类还是回归 

前面读数据时已经给定了参数stringsAsFactors =T,这一步可以忽略了。

  • 如果group对应的列为数字,转换为数值型 - 做回归

  • 如果group对应的列为分组,转换为因子型 - 做分类

# R4.0之后默认读入的不是factor,需要做一个转换
# devtools::install_github("Tong-Chen/ImageGP")
library(ImageGP)

# 此处的FollowersAfterReading根据需要修改
group = "FollowersAfterReading"

# 如果group对应的列为数字,转换为数值型 - 做回归
# 如果group对应的列为分组,转换为因子型 - 做分类
if(numCheck(metadata[[group]])){
    if (!is.numeric(metadata[[group]])) {
      metadata[[group]] <- mixedToFloat(metadata[[group]])
    }
} else{
  metadata[[group]] <- as.factor(metadata[[group]])
}

随机森林初步分析 

library(randomForest)

# 查看参数是个好习惯
# 有了前面的基础概述,再看每个参数的含义就明确了很多
# 也知道该怎么调了
# 每个人要解决的问题不同,通常不是别人用什么参数,自己就跟着用什么参数
# 尤其是到下游分析时
# ?randomForest

# 查看源码
# randomForest:::randomForest.default

加载包之后,直接分析一下,看到结果再调参。

# 设置随机数种子,具体含义见 https://mp.weixin.qq.com/s/6plxo-E8qCdlzCgN8E90zg
set.seed(304)

# 直接使用默认参数
rf <- randomForest(feature_mat, metadata[[group]])

查看下初步结果, 随机森林类型判断为分类,构建了500棵树,每次决策时从随机选择的3个指标中做最优决策 (mtry),平均平方残基 Mean of squared residuals: 39.82736,解释的变异度 % Var explained: 74.91。结果看上去一般。

rf
## 
## Call:
##  randomForest(x = feature_mat, y = metadata[[group]]) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 3
## 
##           Mean of squared residuals: 39.82736
##                     % Var explained: 74.91

观察下模型对训练集的预测效果,看上去一致性还可以。

library(ggplot2)

followerDF <- data.frame(Real_Follower=metadata[[group]], Predicted_Follower=predict(rf, newdata=feature_mat))

sp_scatterplot(followerDF, xvariable = "Real_Follower", yvariable = "Predicted_Follower",
               smooth_method = "auto") + coord_fixed(1)

图片

随机森林标准操作流程

拆分训练集和测试集

library(caret)
seed <- 1
set.seed(seed)
train_index <- createDataPartition(metadata[[group]], p=0.75, list=F)
train_data <- feature_mat[train_index,]
train_data_group <- metadata[[group]][train_index]

test_data <- feature_mat[-train_index,]
test_data_group <- metadata[[group]][-train_index]
dim(train_data)
## [1] 1192    9
dim(test_data)
## [1] 396   9

Boruta特征选择鉴定关键分类变量

# install.packages("Boruta")
library(Boruta)
set.seed(1)

boruta <- Boruta(x=train_data, y=train_data_group, pValue=0.01, mcAdj=T, 
       maxRuns=300)

boruta
## Boruta performed 14 iterations in 5.917085 secs.
##  8 attributes confirmed important: AverageReadingCountsForEachSharing, FirstSharingRate,
## ReadingRate, TotalReadingCounts, TotalReadingCountsOfSharing and 3 more;
##  1 attributes confirmed unimportant: ReadingFinishRate;

查看下变量重要性鉴定结果(实际上面的输出中也已经有体现了),8个重要的变量,0个可能重要的变量 (tentative variable, 重要性得分与最好的影子变量得分无统计差异),1个不重要的变量。

table(boruta$finalDecision)
## 
## Tentative Confirmed  Rejected 
##         0         8         1

绘制鉴定出的变量的重要性。变量少了可以用默认绘图,变量多时绘制的图看不清,需要自己整理数据绘图。

定义一个函数提取每个变量对应的重要性值。

library(dplyr)
boruta.imp <- function(x){
  imp <- reshape2::melt(x$ImpHistory, na.rm=T)[,-1]
  colnames(imp) <- c("Variable","Importance")
  imp <- imp[is.finite(imp$Importance),]

  variableGrp <- data.frame(Variable=names(x$finalDecision), 
                            finalDecision=x$finalDecision)

  showGrp <- data.frame(Variable=c("shadowMax", "shadowMean", "shadowMin"),
                        finalDecision=c("shadowMax", "shadowMean", "shadowMin"))

  variableGrp <- rbind(variableGrp, showGrp)

  boruta.variable.imp <- merge(imp, variableGrp, all.x=T)

  sortedVariable <- boruta.variable.imp %>% group_by(Variable) %>% 
    summarise(median=median(Importance)) %>% arrange(median)
  sortedVariable <- as.vector(sortedVariable$Variable)


  boruta.variable.imp$Variable <- factor(boruta.variable.imp$Variable, levels=sortedVariable)

  invisible(boruta.variable.imp)
}
boruta.variable.imp <- boruta.imp(boruta)

head(boruta.variable.imp)
##                             Variable Importance finalDecision
## 1 AverageReadingCountsForEachSharing   4.861474     Confirmed
## 2 AverageReadingCountsForEachSharing   4.648540     Confirmed
## 3 AverageReadingCountsForEachSharing   6.098471     Confirmed
## 4 AverageReadingCountsForEachSharing   4.701201     Confirmed
## 5 AverageReadingCountsForEachSharing   3.852440     Confirmed
## 6 AverageReadingCountsForEachSharing   3.992969     Confirmed

只绘制Confirmed变量。从图中可以看出重要性排名前4的变量都与“分享”相关 (分享产生阅读次数, 总分享人数, 总分享次数,首 次分享率),文章被分享对于增加关注是很重要的。

library(ImageGP)

sp_boxplot(boruta.variable.imp, melted=T, xvariable = "Variable", yvariable = "Importance",
           legend_variable = "finalDecision", legend_variable_order = c("shadowMax", "shadowMean", "shadowMin", "Confirmed"),
           xtics_angle = 90, coordinate_flip =T)

图片

提取重要的变量和可能重要的变量

boruta.finalVarsWithTentative <- data.frame(Item=getSelectedAttributes(boruta, withTentative = T), Type="Boruta_with_tentative")
data <- cbind(feature_mat, metadata)

variableFactor <- rev(levels(boruta.variable.imp$Variable))

sp_scatterplot(data, xvariable = group, yvariable = variableFactor[1], smooth_method = "auto")

图片

因为变量不多,也可以用ggpairs看下所有变量之间,以及它们与响应变量的相关性怎样?

library(GGally)

ggpairs(data, progress = F)

图片

交叉验证选择参数并拟合模型

定义一个函数生成一些列用来测试的mtry (一系列不大于总变量数的数值)。

generateTestVariableSet <- function(num_toal_variable){
  max_power <- ceiling(log10(num_toal_variable))
  tmp_subset <- c(unlist(sapply(1:max_power, function(x) (1:10)^x, simplify = F)), ceiling(max_power/3))
  #return(tmp_subset)
  base::unique(sort(tmp_subset[tmp_subset<num_toal_variable]))
}
# generateTestVariableSet(78)

选择关键特征变量相关的数据

# 提取训练集的特征变量子集
boruta_train_data <- train_data[, boruta.finalVarsWithTentative$Item]
boruta_mtry <- generateTestVariableSet(ncol(boruta_train_data))

使用 Caret 进行调参和建模

library(caret)

if(file.exists('rda/wechatRegression.rda')){
  borutaConfirmed_rf_default <- readRDS("rda/wechatRegression.rda")
} else {

# Create model with default parameters
trControl <- trainControl(method="repeatedcv", number=10, repeats=5)

seed <- 1
set.seed(seed)
# 根据经验或感觉设置一些待查询的参数和参数值
tuneGrid <- expand.grid(mtry=boruta_mtry)

borutaConfirmed_rf_default <- train(x=boruta_train_data, y=train_data_group, method="rf", 
                                    tuneGrid = tuneGrid, # 
                                    metric="RMSE", #metric='Kappa'
                                    trControl=trControl)
saveRDS(borutaConfirmed_rf_default, "rda/wechatRegression.rda")
}

borutaConfirmed_rf_default
## Random Forest 
## 
## 1192 samples
##    8 predictor
## 
## No pre-processing
## Resampling: Cross-Validated (10 fold, repeated 5 times) 
## Summary of sample sizes: 1073, 1073, 1073, 1072, 1073, 1073, ... 
## Resampling results across tuning parameters:
## 
##   mtry  RMSE      Rsquared   MAE     
##   1     6.441881  0.7020911  2.704873
##   2     6.422848  0.7050505  2.720557
##   3     6.418449  0.7052825  2.736505
##   4     6.431665  0.7039496  2.742612
##   5     6.453067  0.7013595  2.754239
##   6     6.470716  0.6998307  2.758901
##   7     6.445304  0.7020575  2.756523
## 
## RMSE was used to select the optimal model using the smallest value.
## The final value used for the model was mtry = 3.

绘制准确性随超参的变化曲线

plot(borutaConfirmed_rf_default)

图片

绘制贡献最高的 20 个变量 (Boruta评估的变量重要性跟模型自身评估的重要性略有不同)

dotPlot(varImp(borutaConfirmed_rf_default))

图片

提取最终选择的模型,评估其效果。

borutaConfirmed_rf_default_finalmodel <- borutaConfirmed_rf_default$finalModel

首先采用训练数据集评估构建的模型的训练效果,RMSE=3.1Rsquared=0.944,还是挺不错的。

# 获得模型结果评估参数
predictions_train <- predict(borutaConfirmed_rf_default_finalmodel, newdata=train_data)
postResample(pred = predictions_train, obs = train_data_group)
##      RMSE  Rsquared       MAE 
## 3.1028533 0.9440182 1.1891391

采用测试数据评估模型的预测效果,RMSE=6.2Rsquared=0.825,还可以。后续用下其它方法看看能否提高。

predictions_train <- predict(borutaConfirmed_rf_default_finalmodel, newdata=test_data)
postResample(pred = predictions_train, obs = test_data_group)
##      RMSE  Rsquared       MAE 
## 6.2219834 0.8251457 2.7212806
library(ggplot2)

testfollowerDF <- data.frame(Real_Follower=test_data_group, Predicted_Follower=predictions_train)

sp_scatterplot(testfollowerDF, xvariable = "Real_Follower", yvariable = "Predicted_Follower",
               smooth_method = "auto") + coord_fixed(1)

图片

随机森林回归的不足

随机森林回归模型预测出的值不会超出训练集中响应变量的取值范围,不能用于外推。

可以使用Regression-Enhanced Random Forests (RERFs)作为一个解决方案。

References

  1. https://medium.com/swlh/random-forest-and-its-implementation-71824ced454f

  2. https://neptune.ai/blog/random-forest-regression-when-does-it-fail-and-why

  3. https://levelup.gitconnected.com/random-forest-regression-209c0f354c84

  4. https://rpubs.com/Isaac/caret_reg

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1941275.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《低代码指南》——Oracle APEX : AI在低代码开发中的创新应用

在低代码开发领域,我们正在目睹人工智能(AI)集成所带来的显著进展。Oracle公司最新推出的APEX 24.1版本,便是这一趋势的明显体现,其集成的AI功能旨在极大提高开发者的生产力,同时简化应用程序的创建过程。 Contents 将变革性的AI整合到低代码平台 将AI技术引入低代码平台…

UniMERNet - 数学公式识别转LaTeX

文章目录 一、关于 UniMERNet演示视频 二、快速入门1、克隆repo并下载模型2、安装3、运行UniMERNet 四、评估1、下载UniMER-Test数据集2、运行评估代码3、与SOTA方法的性能比较&#xff08;BLEU&#xff09;4、不同方法的可视化结果 五、UniMER数据集1、导言2、数据集下载 六、…

【Linux 13】文件系统

文章目录 &#x1f308; 一、前言&#x1f308; 二、文件操作的系统接口⭐ 1. 打开文件 open⭐ 2. 关闭文件 close⭐ 3. 写入文件 write⭐ 4. 读取文件 read &#x1f308; 三、文件描述符⭐ 1. 文件描述符介绍⭐ 2. 提前被分配的文件描述符 0 1 2⭐ 3. 文件描述符的分配规则 &…

【Qt】Qt的坐标转换(mapToGlobal)

1、QPoint QWidget::mapToGlobal(const QPoint &pos) const 将小部件坐标转换为全局坐标。mapToGlobal(QPoint(0,0))可以得到小部件左上角像素的全局坐标。2、QPoint QWidget::mapToParent(const QPoint &pos) const 将小部件坐标转换为父部件坐标。如果小部件没有父部…

Linux内存的概念及管理

1、内存概念 内存是指计算机中所安装的随机存取内存的容量&#xff0c;储存是指计算机内硬盘的容量。硬盘应当是计算机的“外存”。内存应当是在主板上的一些存储器&#xff0c;用来保存CPU运算使用过程中的中间数据和计算结果&#xff0c;当不用这些数据时&#xff0c;它们被保…

鸿蒙仓颉语言【Redis仓颉语言客户端】

特性 支持RESP2和RESP3协议接口设计兼容jedis接口语义丰富的管理命令支持支持单连接多线程模式支持发布订阅模式支持哨兵模式和集群模式完备的单元测试覆盖架构简洁&#xff0c;易于扩展 开发计划 2024.3.22 完成支持单机模式的RESP2和RESP3协议的客户端&#xff0c;提供Bet…

实际生活中网段不通的典型分析及处理方案

关于端口&#xff1a; 应用层&#xff1a; FTP TELNET SMTP DNS TFTP SNMP 端口号&#xff1a; 21 23 25 53 69 161 传输层&#xff1a; TCP UDP&#xff08;DNS两个都占…

7月21日,贪心练习

大家好呀&#xff0c;今天带来一些贪心算法的应用解题、 一&#xff0c;柠檬水找零 . - 力扣&#xff08;LeetCode&#xff09; 解析&#xff1a; 本题的贪心体现在对于20美元的处理上&#xff0c;我们总是优先把功能较少的10元作为找零&#xff0c;这样可以让5元用处更大 …

野兔在线工具箱系统全新升级改版,基于TP8和yetuadmin后台实现

野兔在线工具箱系统全新升级改版&#xff0c;基于TP8和yetuadmin后台实现 系统名称&#xff1a;野兔在线工具系统 系统语言&#xff1a;支持多语言&#xff0c;大概有20种 系统源码&#xff1a;不加密&#xff0c;开源 系统开发&#xff1a;PHPMySQL (基于thinkphp8&#x…

重发布路由策略实验

实验要求 1.搭建拓扑 路由策略分析&#xff1a; 拓扑左边是rip协议&#xff0c;右边是ospf协议&#xff0c;想要实现全网可达可以采用多点双向重发布的方式。 对于rip协议使用偏移列表来干涉选路&#xff0c;对于ospf协议采用路由策略来干涉选路 2.配置ip r1 [AR1]interfac…

【初阶数据结构】深度解析七大常见排序|掌握底层逻辑与原理

初阶数据结构相关知识点可以通过点击以下链接进行学习一起加油&#xff01;时间与空间复杂度的深度剖析深入解析顺序表:探索底层逻辑深入解析单链表:探索底层逻辑深入解析带头双向循环链表:探索底层逻辑深入解析栈:探索底层逻辑深入解析队列:探索底层逻辑深入解析循环队列:探索…

什么是机器学习以及机器学习如今的社会现状!!

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;开发者-曼亿点 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 曼亿点 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a…

长效代理IP如何选用及代理服务分析

在这个数据为王、信息瞬息万变的时代&#xff0c;长效代理IP成为了众多开发者、数据科学家乃至普通网民手中的一把利器。它不仅能帮助我们解决地域管理&#xff0c;还能在保护隐私的同时&#xff0c;确保数据传输的稳定与安全。但面对市面上琳琅满目的代理服务&#xff0c;如何…

飞腾腾云S2500 Nginx单机环回测试性能调优方法

【写在前面】 飞腾开发者平台是基于飞腾自身强大的技术基础和开放能力&#xff0c;聚合行业内优秀资源而打造的。该平台覆盖了操作系统、算法、数据库、安全、平台工具、虚拟化、存储、网络、固件等多个前沿技术领域&#xff0c;包含了应用使能套件、软件仓库、软件支持、软件适…

Vulnhub靶场DC-7练习

目录 0x00 准备0x01 主机信息收集0x02 站点信息收集1. 获取用户名/密码2. ssh连接目标主机3. drush命令修改Drupal密码 0x03 漏洞查找与利用1. Drupal写入php木马2. 连接shell3. 反弹shell并提权 0x04 总结 0x00 准备 下载链接&#xff1a;https://download.vulnhub.com/dc/DC-…

深度学习每周学习总结N4:中文文本分类-Pytorch实现(基本分类(熟悉流程)、textCNN分类(通用模型)、Bert分类(模型进阶))

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制 目录 0. 总结&#xff1a;1. 基础模型a. 数据加载b. 数据预处理c. 模型搭建与初始化d. 训练函数e. 评估函数f.拆分数据集运行模型g. 结果可…

git命令学习分享

分布式版本控制系统&#xff0c;本地仓库和远程仓库相互独立。 使用repository仓库进行控制&#xff0c;可以对里面的文件进行跟踪&#xff0c;复原。 git config --global --list&#xff1a;查看git配置列表 cd ** &#xff1a;进入** cd .. &#xff1a;退回上一级 echo…

【人工智能】Transformers之Pipeline(四):零样本音频分类(zero-shot-audio-classification)

​​​​​​​ 目录 一、引言 二、零样本音频分类&#xff08;zero-shot-audio-classification&#xff09; 2.1 概述 2.2 意义 2.3 应用场景 2.4 pipeline参数 2.4.1 pipeline对象实例化参数​​​​​​​ 2.4.2 pipeline对象使用参数 2.4 pipeline实战 2.5 模…

TinyVue:与 Vue 交往八年的组件库

本文由体验技术团队莫春辉老师原创~ 去年因故停办的 VueConf&#xff0c;今年如约在深圳举行。作为东道主 & 上届 VueConf 讲师的我&#xff0c;没有理由不来凑个热闹。大会结束后&#xff0c;我见裕波在朋友圈转发 Jinjiang 的文章《我和 Vue.js 的十年》&#xff0c;我就…

版本控制工具

版本控制工具是用于记录代码文件变化历史、方便查阅特定版本修改情况的系统&#xff0c;一般分为集中式和分布式两种。以下是一些常见的版本控制工具&#xff1a; 集中式版本控制工具 Subversion&#xff08;SVN&#xff09; 简介&#xff1a;Subversion是一种集中式版本控制…