第100+12步 ChatGPT学习:R实现KNN分类

news2024/10/6 16:31:00

基于R 4.2.2版本演示

一、写在前面

有不少大佬问做机器学习分类能不能用R语言,不想学Python咯。

答曰:可!用GPT或者Kimi转一下就得了呗。

加上最近也没啥内容写了,就帮各位搬运一下吧。

二、R代码实现KNN分类

(1)导入数据

我习惯用RStudio自带的导入功能:

(2)建立KNN模型

# Load necessary libraries
library(caret)
library(pROC)
library(ggplot2)

# Assume 'data' is your dataframe containing the data
# Set seed to ensure reproducibility
set.seed(123)

# Split data into training and validation sets (80% training, 20% validation)
trainIndex <- createDataPartition(data$X, p = 0.8, list = FALSE)
trainData <- data[trainIndex, ]
validData <- data[-trainIndex, ]

# Convert the target variable to a factor for classification
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)

# Define control method for training with cross-validation
trainControl <- trainControl(method = "cv", number = 10)

# Fit KNN model on the training set
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl, preProcess = "scale")

# Predict on the training and validation sets
trainPredict <- predict(model, trainData, type = "prob")[,2]
validPredict <- predict(model, validData, type = "prob")[,2]

# Convert true values to factor for ROC analysis
trainData$X <- as.factor(trainData$X)
validData$X <- as.factor(validData$X)

# Calculate ROC curves and AUC values
trainRoc <- roc(response = trainData$X, predictor = trainPredict)
validRoc <- roc(response = validData$X, predictor = validPredict)

# Plot ROC curves with AUC values
ggplot(data = data.frame(fpr = trainRoc$specificities, tpr = trainRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "blue") +
  geom_area(alpha = 0.2, fill = "blue") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Training ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.1, label = paste("Training AUC =", round(auc(trainRoc), 2)), hjust = 0.5, color = "blue")

ggplot(data = data.frame(fpr = validRoc$specificities, tpr = validRoc$sensitivities), aes(x = 1 - fpr, y = tpr)) +
  geom_line(color = "red") +
  geom_area(alpha = 0.2, fill = "red") +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", color = "black") +
  ggtitle("Validation ROC Curve") +
  xlab("False Positive Rate") +
  ylab("True Positive Rate") +
  annotate("text", x = 0.5, y = 0.2, label = paste("Validation AUC =", round(auc(validRoc), 2)), hjust = 0.5, color = "red")

# Calculate confusion matrices based on 0.5 cutoff for probability
confMatTrain <- table(trainData$X, trainPredict >= 0.5)
confMatValid <- table(validData$X, validPredict >= 0.5)

# Function to plot confusion matrix using ggplot2
plot_confusion_matrix <- function(conf_mat, dataset_name) {
  conf_mat_df <- as.data.frame(as.table(conf_mat))
  colnames(conf_mat_df) <- c("Actual", "Predicted", "Freq")
  
  p <- ggplot(data = conf_mat_df, aes(x = Predicted, y = Actual, fill = Freq)) +
    geom_tile(color = "white") +
    geom_text(aes(label = Freq), vjust = 1.5, color = "black", size = 5) +
    scale_fill_gradient(low = "white", high = "steelblue") +
    labs(title = paste("Confusion Matrix -", dataset_name, "Set"), x = "Predicted Class", y = "Actual Class") +
    theme_minimal() +
    theme(axis.text.x = element_text(angle = 45, hjust = 1), plot.title = element_text(hjust = 0.5))
  
  print(p)
}
# Now call the function to plot and display the confusion matrices
plot_confusion_matrix(confMatTrain, "Training")
plot_confusion_matrix(confMatValid, "Validation")

# Extract values for calculations
a_train <- confMatTrain[1, 1]
b_train <- confMatTrain[1, 2]
c_train <- confMatTrain[2, 1]
d_train <- confMatTrain[2, 2]

a_valid <- confMatValid[1, 1]
b_valid <- confMatValid[1, 2]
c_valid <- confMatValid[2, 1]
d_valid <- confMatValid[2, 2]

# Training Set Metrics
acc_train <- (a_train + d_train) / sum(confMatTrain)
error_rate_train <- 1 - acc_train
sen_train <- d_train / (d_train + c_train)
sep_train <- a_train / (a_train + b_train)
precision_train <- d_train / (b_train + d_train)
F1_train <- (2 * precision_train * sen_train) / (precision_train + sen_train)
MCC_train <- (d_train * a_train - b_train * c_train) / sqrt((d_train + b_train) * (d_train + c_train) * (a_train + b_train) * (a_train + c_train))
auc_train <- roc(response = trainData$X, predictor = trainPredict)$auc

# Validation Set Metrics
acc_valid <- (a_valid + d_valid) / sum(confMatValid)
error_rate_valid <- 1 - acc_valid
sen_valid <- d_valid / (d_valid + c_valid)
sep_valid <- a_valid / (a_valid + b_valid)
precision_valid <- d_valid / (b_valid + d_valid)
F1_valid <- (2 * precision_valid * sen_valid) / (precision_valid + sen_valid)
MCC_valid <- (d_valid * a_valid - b_valid * c_valid) / sqrt((d_valid + b_valid) * (d_valid + c_valid) * (a_valid + b_valid) * (a_valid + c_valid))
auc_valid <- roc(response = validData$X, predictor = validPredict)$auc

# Print Metrics
cat("Training Metrics\n")
cat("Accuracy:", acc_train, "\n")
cat("Error Rate:", error_rate_train, "\n")
cat("Sensitivity:", sen_train, "\n")
cat("Specificity:", sep_train, "\n")
cat("Precision:", precision_train, "\n")
cat("F1 Score:", F1_train, "\n")
cat("MCC:", MCC_train, "\n")
cat("AUC:", auc_train, "\n\n")

cat("Validation Metrics\n")
cat("Accuracy:", acc_valid, "\n")
cat("Error Rate:", error_rate_valid, "\n")
cat("Sensitivity:", sen_valid, "\n")
cat("Specificity:", sep_valid, "\n")
cat("Precision:", precision_valid, "\n")
cat("F1 Score:", F1_valid, "\n")
cat("MCC:", MCC_valid, "\n")
cat("AUC:", auc_valid, "\n")

在R语言中,caret包提供了一个通用的接口来训练KNN模型。使用caret的train函数来训练KNN模型时,可以调整多种参数来优化模型的性能:

基本参数:

①formula: 指定模型的公式,如Y ~ .,表示使用数据框中的所有其他变量来预测Y。

②data: 提供包含训练数据的数据框。

③method: 对于KNN模型,这个参数应设置为"knn"。

④preProcess: 预处理步骤,常用的包括标准化("scale")和中心化("center"),对于KNN这一步非常重要因为KNN依赖于变量的距离度量。

⑤trControl: 一个trainControl对象,定义了模型训练的各种控制策略,如交叉验证的类型和重复次数。

trainControl 函数的参数:

①method: 训练的方法,如交叉验证("cv"),重复交叉验证("repeatedcv"),留一交叉验证("LOOCV")等。

②number: 对于"cv"和"repeatedcv",这个参数定义了折数。

③repeats: 当使用"repeatedcv"时,定义重复的次数。

④search: 参数搜索方法,默认为"grid"。也可以设置为"random"进行随机搜索。

⑤savePredictions: 是否保存预测结果,通常用于后续分析。

模型性能调整参数:

使用KNN时,最关键的参数之一是邻居的数量(K值)。这可以通过train函数的以下参数来调整:

①tuneLength: 这个参数决定了在参数搜索中考虑多少个不同的K值。

②tuneGrid: 这是一个数据框,可以自定义K值的具体范围,例如expand.grid(k = c(1, 5, 10))

结果输出(默认参数):

三、KNN调参方法

如前所述,KNN的关键参数就是K值,所以可以对其进行一个暴力测试,比如取值1到10:

# 定义交叉验证的控制方法,启用网格搜索
trainControl <- trainControl(method = "cv", number = 10)
# 定义K值的网格搜索范围
tuneGrid <- expand.grid(k = 1:10)
# 在训练集上拟合KNN模型,指定网格搜索的K值
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl,
               tuneGrid = tuneGrid, preProcess = "scale")
# 查看模型结果,找出最优的K值
print(model)

解读:

定义交叉验证的控制方法:使用trainControl函数设定交叉验证的详细参数。

定义K值的网格:使用tuneGrid参数在train函数中指定K值的范围。

拟合模型:使用train函数训练模型,同时应用预处理步骤(比如标准化数据),以确保每个特征在距离计算中具有等同的权重。

结果输出:

注意:用了caret包的train函数,并且通过网格搜索指定了一系列的参数(如K值的范围),那么这个函数会自动选择表现最好的参数配置来训练最终的模型。train函数的输出即是基于你提供的训练数据和参数搜索范围内表现最优的模型。因此,当你调用predict函数进行预测时,使用的就是这个最优化的模型。所以,下面的代码不变。

结果吧,跟之前的完全一样:

因为caret包对于KNN模型默认进行一系列的K值尝试,通常这个范围是1到最多的邻居数,但具体的最大K值依赖于caret的内部设置。在大多数情况下,它会尝试如1, 5, 7, 9等常用的K值。所以,我们默认参数的时候,其实软件自动给我们寻找最优K值了。可以用这个代码输出最有K值:

# Print the best K value used by the model
best_k <- model$bestTune$k
cat("The best K value found is:", best_k, "\n")

K值就是9,跟我们自行调参的一致。

那我们猛点,把K的范围设置的宽一些:

# 定义交叉验证的控制方法,启用网格搜索
trainControl <- trainControl(method = "cv", number = 10)
# 定义K值的网格搜索范围
tuneGrid <- expand.grid(k = 1:20)
# 在训练集上拟合KNN模型,指定网格搜索的K值
model <- train(X ~ ., data = trainData, method = "knn", trControl = trainControl,
               tuneGrid = tuneGrid, preProcess = "scale")
# 查看模型结果,找出最优的K值
print(model)

结果:

K=19,性能指标如下,似乎大同小异:

四、最后

数据嘛:

链接:https://pan.baidu.com/s/1rEf6JZyzA1ia5exoq5OF7g?pwd=x8xm

提取码:x8xm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1849814.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Java实训中心管理系统设计和实现(源码+LW+调试文档+讲解等)

&#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者&#xff0c;博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌&#x1f497; &#x1f31f;文末获取源码数据库&#x1f31f; 感兴趣的可以先收藏起来&#xff0c;…

基于Java学生干部管理系统设计和实现(源码+LW+调试文档+讲解等)

&#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者&#xff0c;博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌&#x1f497; &#x1f31f;文末获取源码数据库&#x1f31f; 感兴趣的可以先收藏起来&#xff0c;…

【机器学习 复习】 第1章 概述

一、概念 1.机器学习是一种通过先验信息来提升模型能力的方式。 即从数据中产生“模型”( model )的算法&#xff0c;然后对新的数据集进行预测。 2.数据集&#xff08;Dataset&#xff09;&#xff1a;所有数据的集合称为数据集。 训练集&#xff1a;用来训练出一个适合模…

聊聊Vue中的Router(路由)

Vue构造的是一个单页面应用 在 Vue 中&#xff0c;router&#xff08;路由&#xff09;用于定义应用的不同页面路径和组件之间的映射关系&#xff0c;通过路由从而实现页面的切换和导航功能 vue中所有的xxx.vue文件&#xff0c;都是路由组件&#xff0c;这些组件都会被vue读取…

MySQL 死锁查询和解决死锁

来了来了来了&#xff01;客户现场又要骂街了&#xff0c;你们这是什么破系统怎么这么慢啊&#xff1f;&#xff01;&#xff1f;&#xff01; 今天遇到了mysql死锁&#xff0c;直接导致服务器CPU被PUA直接GUA了&#xff01; 别的先别管&#xff0c;先看哪里死锁&#xff0c;或…

【Springcloud微服务】Docker下篇

&#x1f525; 本文由 程序喵正在路上 原创&#xff0c;CSDN首发&#xff01; &#x1f496; 系列专栏&#xff1a;Springcloud微服务 &#x1f320; 首发时间&#xff1a;2024年6月22日 &#x1f98b; 欢迎关注&#x1f5b1;点赞&#x1f44d;收藏&#x1f31f;留言&#x1f4…

【MySQL数据库】:MySQL视图特性

视图的概念 视图是一个虚拟表&#xff0c;其内容由查询定义&#xff0c;同真实的表一样&#xff0c;视图包含一系列带有名称的列和行数据。视图中的数据并不会单独存储在数据库中&#xff0c;其数据来自定义视图时查询所引用的表&#xff08;基表&#xff09;&#xff0c;在每…

1931java Web披萨店订餐系统idea开发mysql数据库web结构java编程计算机网页源码servlet项目

一、源码特点 java Web 披萨店订餐系统是一套完善的信息管理系统&#xff0c;结合java 开发技术和bootstrap完成本系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用 B/S模式开发。 视频地址&#xff1a;…

从零开始的Ollama指南:部署私域大模型

大模型相关目录 大模型&#xff0c;包括部署微调prompt/Agent应用开发、知识库增强、数据库增强、知识图谱增强、自然语言处理、多模态等大模型应用开发内容 从0起步&#xff0c;扬帆起航。 大模型应用向开发路径&#xff1a;AI代理工作流大模型应用开发实用开源项目汇总大模…

c语言回顾-结构体(2)

前言 前面讲了结构体的概念&#xff0c;定义&#xff0c;赋值&#xff0c;访问等知识&#xff0c;本节内容小编将讲解结构体的内存大小的计算以及通过结构体实现位段&#xff0c;话不多说&#xff0c;直接上干货&#xff01;&#xff01;&#xff01; 1.结构体内存对齐 说到计…

Ubuntu系统使用快速入门实践(八)—— git 命令使用

Ubuntu系统使用快速入门实践系列文章 下面是Ubuntu系统使用系列文章的总链接&#xff0c;本人发表这个系列的文章链接均收录于此 Ubuntu系统使用快速入门实践系列文章总链接 下面是专栏地址&#xff1a; Ubuntu系统使用快速入门实践系列文章专栏 文章目录 Ubuntu系统使用快速…

UDS服务——RequestDownload(0x34)

诊断协议那些事儿 诊断协议那些事儿专栏系列文章,本文介绍RequestDownload(0x34)—— 请求下载,用于给ECU下载数据的,最常见的应用就是在bootloader中,程序下载工具会发起请求下载,以完成ECU程序的升级。通过阅读本文,希望能对你有所帮助。 文章目录 诊断协议那些事儿…

【C++】一个极简但完整的C++程序

一、一个极简但完整的C程序 我们编写程序是为了解决问题和任务的。 1、任务&#xff1a; 某个书店将每本售出的图书的书名和出版社&#xff0c;输入到一个文件中&#xff0c;这些信息以书售出的时间顺序输入&#xff0c;每两周店主会手工计算每本书的销售量、以及每个出版社的…

【单片机毕业设计选题24020】-全自动鱼缸的设计与应用

系统功能: &#xff08;1&#xff09;检测并控制鱼缸水温&#xff0c;水温低于22℃后开启加热&#xff0c;高于28℃后关闭加热。 &#xff08;2&#xff09;定时喂食&#xff0c;每天12点和0点喂食一次&#xff0c;步进电机开启后再关闭模拟喂食。 &#xff08;3&#xff09…

初学者应该掌握的MySQL数据库的基本组成部分及概念

MySQL数据库作为一种开源的关系型数据库管理系统&#xff0c;被广泛应用于Web应用开发和数据存储。它具有高性能、易用性和可靠性等特点&#xff0c;是开发者们的首选之一。在本篇文章中&#xff0c;我们将详细介绍MySQL数据库的核心组成部分&#xff0c;帮助你深入理解这个强大…

杀疯了!PerfXCloud-AI大模型夏日狂欢来袭,向基石用户赠送 ∞ 亿Token!

【澎峰科技重磅消息】 在全球范围内大模型正逐渐成为强大的创新驱动力。在这个充满激情的夏日&#xff0c;PerfXCloud为开发者和企业带来了前所未有的福利&#xff1a; 1. 零成本亲密、深度体验大模型&#xff0c;提供大量示范案例。 2. 向基石用户赠送∞亿Token的激励计划。…

展讯-QMI8658和气压传感器驱动调试

1.调试QMI8658 参考demo&#xff0c;添加QMI8610相关内容 当前驱动路径位于&#xff1a;bsp/modules/input/misc/qmi8610/qmi8610.c 编译使用make sockoimage 用fastboot烧录 1.确定驱动被正常加载 代码添加之后&#xff0c;首先确定有没有进入当前驱动文件 dmesg |grep …

Gradle 自动化项目构建-Gradle 核心之 Project

一、前言 从明面上看&#xff0c;Gradle 是一款强大的构建工具&#xff0c;但 Gradle 不仅仅是一款强大的构建工具&#xff0c;它更像是一个编程框架。Gradle 的组成可以细分为如下三个方面&#xff1a; groovy 核心语法&#xff1a;包括 groovy 基本语法、闭包、数据结构、面…

C#调用OpenCvSharp实现图像的直方图均衡化

本文学习基于OpenCvSharp的直方图均衡化处理方式&#xff0c;并使用SkiaSharp绘制相关图形。直方图均衡化是一种图像处理方法&#xff0c;针对偏亮或偏暗的图像&#xff0c;通过调整图像的像素值来增强图像对比度&#xff0c;详细原理及介绍见参考文献1-4。   直方图均衡化第…

计算机组成原理笔记-第1章 计算机系统概论

第一章 计算机系统概论 笔记PDF版本已上传至Github个人仓库&#xff1a;CourseNotes&#xff0c;欢迎fork和star&#xff0c;拥抱开源&#xff0c;一起完善。 该笔记是最初是没打算发网上的&#xff0c;所以很多地方都为了自我阅读方便&#xff0c;我理解了的地方就少有解释&a…