第100+15步 ChatGPT学习:R实现Ababoost分类

news2025/1/16 5:57:13

基于R 4.2.2版本演示

一、写在前面

有不少大佬问做机器学习分类能不能用R语言,不想学Python咯。

答曰:可!用GPT或者Kimi转一下就得了呗。

加上最近也没啥内容写了,就帮各位搬运一下吧。

二、R代码实现Ababoost分类

(1)导入数据

我习惯用RStudio自带的导入功能:

(2)建立Ababoost模型(默认参数)

# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)

# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)

# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]

# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)

# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)

# Fit Random Forest model on the training set
model <- train(X ~ ., data = trainData, method = "ada", trControl = trainControl)

# Print the best parameters found by the model
best_params <- model$bestTune
cat("The best parameters found are:\n")
print(best_params)

# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]

# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)

# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "blue") +
  geom_area(alpha = 0.2, fill = "blue") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Training ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")

ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "red") +
  geom_area(alpha = 0.2, fill = "red") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Validation ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")

# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)

# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {
  conf_mat_df <- as.data.frame(as.table(conf_mat))
  colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")
  
  p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +
    geom_tile(color = "white") +
    geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +
    scale_fill_gradient(low = "white", high = "steelblue") +
    labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))
  
  print(p)
}

# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")

# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]

a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]

# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc

# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc

# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")

cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

在R语言中,使用 caret 包训练Ababoost模型时,最关键的可调参数不多,下面是一些可以调整的关键参数:

①Iter: 这是最重要的参数之一,代表弱学习器的数量,即AdaBoost算法中的迭代次数。较大的nIter值通常可以提高模型的复杂度和拟合能力,但也可能导致过拟合。

②maxdepth: 这是决策树的最大深度。AdaBoost通常使用决策树作为其弱学习器。通过调整maxdepth可以控制单个决策树的复杂度,从而影响整个集成模型的复杂度。

③nu: 这个参数是学习率(也称为收缩参数或步长)。它用于更新每次迭代中模型权重。较小的nu值可以使模型学习得更加谨慎,通常可以减少过拟合的风险,但可能需要更多的迭代次数来收敛。

结果输出(默认参数):

在默认参数中,caret包已经默默帮我们吧上面三个参数进行测试和寻优。

从AUC来看,Ababoost随便一跑,就跑出个不错的结果。不过有些过拟合了,验证集的性能稍微差些。

三、Ababoost手动调参方法(3个值)

设置iter值取值50、100、200、400、600;maxdepth取值1、2、5、7和9;nu取值0.01、0.1、0.5:

# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)

# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)

# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]

# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)

# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)

# Define the tuning grid with correct parameter names
tuneGrid <- expand.grid(iter = c(50, 100, 200, 400, 600),
                        maxdepth = c(1, 2, 5, 7, 9),
                        nu = c(0.01, 0.1, 0.5))

# Train the model using the ada method and the corrected tuning grid
model <- train(X ~ ., data = trainData, method = "ada", trControl = trainControl, tuneGrid = tuneGrid)


# Print the best parameters found by the model
best_params <- model$bestTune
cat("The best parameters found are:\n")
print(best_params)

# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]

# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)

# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "blue") +
  geom_area(alpha = 0.2, fill = "blue") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Training ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")

ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "red") +
  geom_area(alpha = 0.2, fill = "red") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Validation ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")

# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)

# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {
  conf_mat_df <- as.data.frame(as.table(conf_mat))
  colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")
  
  p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +
    geom_tile(color = "white") +
    geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +
    scale_fill_gradient(low = "white", high = "steelblue") +
    labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))
  
  print(p)
}

# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")

# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]

a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]

# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc

# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc

# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")

cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

结果输出:

以上是找到的相对最优参数组合,看看具体性能:

还不让入默认的性能好呢。

看看GPT给的参数的取值建议,祝各位调得开心:

iter (迭代次数): 这个参数通常设置在10到1000之间。较小的数据集可能需要较少的迭代,而较大或较复杂的数据集可能需要更多的迭代。通常开始可以尝试50, 100, 200等值,然后根据模型的性能来调整。

maxdepth (树的最大深度): 这个参数一般设置在1到10之间。深度为1意味着使用决策树桩(仅一个决策点),这有助于防止过拟合,是AdaBoost中常用的设置。但对于更复杂的数据模式,可能需要更深的树。可以尝试的值包括1, 2, 3, 5等。

nu (学习率): 学习率的典型取值范围是0.01到1。较小的学习率(如0.01, 0.1)可以使模型学习得更稳健,但收敛速度可能较慢,需要更多的迭代次数。较高的学习率可以加快学习速度,但可能导致模型在训练过程中不稳定。

四、最后

数据嘛:

链接:https://pan.baidu.com/s/1rEf6JZyzA1ia5exoq5OF7g?pwd=x8xm

提取码:x8xm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1920725.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

打造热销爆款:LazadaShopee店铺测评与关键词策略

面对Lazada和Shopee平台上店铺销量难以突破的困境&#xff0c;卖家们往往寻求各种解决方案。其中&#xff0c;店铺测评作为提升店铺信誉、优化产品排名及增加曝光度的有效手段&#xff0c;正逐渐成为卖家关注的焦点。以下将深入探讨店铺测评的好处、实施技巧及自养号的关键要素…

RK3588部署YOLOV8-seg的问题

在使用YOLOV8-seg训练出来的pt模型转为onnx的时候&#xff0c;利用以下仓库地址转。 git clone https://github.com/airockchip/ultralytics_yolov8.git 在修改ultralytics/cfg/default.yaml中的task&#xff0c;mode为model为自己需要的内容后&#xff0c; 执行以下语句 cd …

2024最新修复微信公众号无限回调系统源码下载 免授权开心版

2024最新修复微信公众号无限回调系统源码下载 免授权开心版 微信公众平台回调比较麻烦&#xff0c;还不能多次回调&#xff0c;于是搭建一个多域名回调的源码很有必要。 测试环境&#xff1a;Nginx1.24PHP7.2MySQL5.6 图片&#xff1a;

5G与未来通信技术

随着科技的迅猛发展&#xff0c;通信技术也在不断演进。5G技术作为第五代移动通信技术&#xff0c;已成为现代通信技术的一个重要里程碑。本文将详细介绍5G及其对未来通信技术的影响&#xff0c;重点探讨超高速互联网和边缘网络的应用。 一、超高速互联网 1. 低延迟 5G技术最显…

一个vue页面复用方案

前言 问大家一个问题&#xff0c;曾经的你是否也遇到过&#xff0c;一个项目中有好几个页面长得基本相同&#xff0c;但又差那么一点&#xff0c;想用 vue extends 继承它又不能按需继承html模板部分&#xff0c;恰好 B 页面需要用的 A 页面 80% 的模板&#xff0c;剩下的 20%…

在Anaconda环境中安装TensorFlow+启动jupyter notebook

1.打开cmd&#xff0c;输入C:\Users\xy>conda create -n tensorflow python3.7 这是在环境中创建了一个名为tensorflow的环境&#xff0c;具体会显示以下信息&#xff1a; C:\Users\xy>conda create -n tensorflow python3.7 Retrieving notices: ...working... done Co…

等保从哪些方面进行测评

等保&#xff0c;全名叫做信息安全等级保护&#xff0c;顾名思义就是指根据信息系统在国家安全、社会稳定、经济秩序和公共利益方便的中重要程度以及风险威胁、安全需求、安全成本等因素&#xff0c;将其划分不同的安全保护等级并采取相应等级的安全保护技术、管理措施、以保障…

Python面试宝典第11题:最长连续序列

题目 给定一个未排序的整数数组 nums &#xff0c;找出数字连续的最长序列&#xff08;不要求序列元素在原数组中连续&#xff09;的长度。请你设计并实现时间复杂度为 O(n) 的算法解决此问题。 示例 1&#xff1a; 输入&#xff1a;nums [100,4,200,1,3,2] 输出&#xff1a;…

前端JS特效第30集:jQuery焦点图插件edslider

jQuery焦点图插件edslider&#xff0c;先来看看效果&#xff1a; 部分核心的代码如下(全部代码在文章末尾)&#xff1a; <!DOCTYPE html> <html lang"zh"> <head> <meta charset"UTF-8"> <meta http-equiv"X-UA-Compatib…

台湾精锐APEX伺服行星减速机发热原因及解决方案

在实际运行过程中台湾精锐APEX伺服行星减速机常常会遇到发热的问题&#xff0c;这不仅影响减速机的正常运转&#xff0c;还可能缩短其使用寿命&#xff0c;甚至引发安全事故。因此&#xff0c;了解APEX伺服行星减速机发热的原因及相应的解决方案&#xff0c;对于保障生产线的稳…

C语言——基础框架、变量、运算符

基础框架&#xff1a; #include<stdio.h> //编译预处理指令int main() //程序的入口主函数main { //程序&#xff08;函数、功能&#xff09;结束标志return 0; //程序退出前返回给调用者&#xff08;操作系统&#xff09;的值…

MySQL实战45讲学习笔记(持续更新ing……)

文章目录 一、基础架构&#xff1a;一条SQL查询语句是如何执行的&#xff1f;概览连接器查询缓存分析器优化器执行器 二、日志系统&#xff1a;一条SQL更新语句是如何执行的&#xff1f;redo logbinlog两阶段提交 一、基础架构&#xff1a;一条SQL查询语句是如何执行的&#xf…

深度学习DeepLearning多元线性回归 学习笔记

文章目录 多维特征变量与术语公式多元线性回归正规方程法Mean normalizationZ-score normalization设置合适的学习率Feature engineering 多维特征 变量与术语 列属性xj属性数n x ⃗ \vec{x} x (i)行向量某个值 x ⃗ j i \vec{x}_j^i x ji​上行下列均值μ标准化标准差σsigm…

无线速度传感器

对高中物理实验中的速度测量方法进行改进&#xff0c;利用安装在小车上的无线光电门来测量小车运动过程中的速度&#xff0c;即满足了精度的要求&#xff0c;又可以研究物体的运动过程。无线光电门和数据接收器间采用蓝牙无线传输的方式&#xff0c;电脑端的软件使用Flash来制作…

vant-app中加的custom-class为啥审查元素时看不到自定义类名

如下图&#xff1a; 我们发现在左侧审查元素时确实看不到&#xff0c;但是在右侧是可以看到&#xff0c;而且样式是生效的。 是不是微信开发者工具的bug?

SQL基础-DQL 小结

SQL基础-DQL 小结 学习目标&#xff1a;学习内容&#xff1a;SELECTFROMWHEREGROUP BYHAVINGORDER BY运算符ASC 和 DESC 总结 学习目标&#xff1a; 1.理解DQL&#xff08;Data Query Language&#xff09;的基本概念和作用。 2.掌握SQL查询的基本语法结构&#xff0c;包括SEL…

微软子公司Xandr遭隐私诉讼,或面临巨额罚款

近日&#xff0c;欧洲隐私权倡导组织noyb对微软子公司Xandr提起了诉讼&#xff0c;指控其透明度不足&#xff0c;侵犯了欧盟公民的数据访问权。据指控&#xff0c;Xandr的行为涉嫌违反《通用数据保护条例》&#xff08;GFPR&#xff09;&#xff0c;因其处理信息并创建用于微目…

C#开发:VS2022中配置TFS(Team Foundation Server)和使用

第一步&#xff0c;点出团队资源管理器 第二步&#xff0c;输入服务器地址 第三步&#xff0c;输入配置地址和账密&#xff08;问管理员&#xff09; 输入配置地址&#xff1a;$/xxxx 输入工作区地址&#xff1a;本地随便一个路径 第四步&#xff0c;获取最新代码 第五步&#…

空调元件的介绍

保险丝管 1、保险丝管在电脑板上用FC1.2&#xff08;FUSE&#xff09;表示&#xff0c;主要用于起过电流保护。 2、故障现象&#xff1a;整机无电不工作 3、检测方法&#xff1a; 目测观察保险丝是否熔断&#xff0c;如是应更换&#xff1b; 4、注意事项&#xff1a; 如果电…

Python酷库之旅-第三方库Pandas(018)

目录 一、用法精讲 44、pandas.crosstab函数 44-1、语法 44-2、参数 44-3、功能 44-4、返回值 44-5、说明 44-6、用法 44-6-1、数据准备 44-6-2、代码示例 44-6-3、结果输出 45、pandas.cut函数 45-1、语法 45-2、参数 45-3、功能 45-4、返回值 45-5、说明 4…