ROC和AUC也不是评估机器学习性能的金标准

news2024/11/15 1:57:28

对于不平衡数据集,AUC值是分类器效果评估的常用标准。但如果在解释时不仔细,它也会有一些误导。以Davis and Goadrich (2006)中的模型为例。如图所示,左侧展示的是两个模型的ROC曲线,右侧展示的是precision-recall曲线 (PRC)。

图片

Precision值和Recall值是既矛盾又统一的两个指标,为了提高Precision值,分类器需要尽量在“更有把握”时才把样本预测为正样本,但此时往往会因为过于保守而漏掉很多“没有把握”的正样本,导致Recall值降低。

RPC的横轴是召回率,纵轴是精准率。对于一个分类模型来说,其PRC线上的一个点代表着,在某一阈值下,模型将大于该阈值的结果判定为正样本,小于该阈值的结果判定为负样本,此时返回结果对应的召回率和精准率。整条PRC曲线是通过将阈值从高到低移动而生成的。

上图是PRC曲线样例图,其中实线代表模型A的PRC曲线,虚线代表模型B的PRC曲线。原点附近代表当阈值最大时模型的精准率和召回率 (阈值越大,鉴定出的样品越真,能鉴定出的样品越少)。

模型1 (Curve 1)的AUC值为0.813, 模型2 (Curve 2)的AUC值为0.875, 从AUC值角度看模型2更优一点。但是右侧的precision-recall曲线却给出完全不同的结论。模型1 (precision-recall Curve 1)下的面积为0.513,模型2 (precision-recall Curve 2)下的面积为0.038。模型1在较低的假阳性率(FPR<0.2)时有较高的真阳性率。

图片

我们再看另一个关于ROC曲线误导性的例子 Fawcett (2005). 这里有两套数据集:一套为平衡数据集(两类分组为1:1关系),一套为非平衡数据集(两类分组为10:1关系)。每套数据集分别构建2个模型并绘制ROC曲线,从Fig \ref(fig:rocprbalanceimbalance) a,c 可以看出,数据集是否平衡对ROC曲线的影响很小。只是在两个模型之间有一些差别,实线代表的模型在假阳性率较低时 (FPR<0.1)真阳性率低于虚线代表的模型。但precision-recall curve (PRC)曲线却差别很大。对于平衡数据集,两个模型的召回率 (recall)和精准率precision都比较好。对于非平衡数据集,虚线代表的分类模型在较低的召回率时就有较高的精准率。

图片

因此,Saito and Rehmsmeier (2015)推荐在处理非平衡数据集时使用PRC曲线,它所反映的信息比ROC曲线更明确。

我们对前面5个模型计算下AUPRC,与AUC结果基本吻合,up效果最好,其次是weightedsmotedownoriginal。值得差距稍微拉大了一些。

library("PRROC")
calc_auprc <- function(model, data){

  index_class2 <- data$Class == minorityClass
  index_class1 <- data$Class == majorityClass

  predictions <- predict(model, data, type = "prob")

  pr.curve(predictions[[minorityClass]][index_class2],
           predictions[[minorityClass]][index_class1],
           curve = TRUE)

}

# Get results for all 5 models

model_list_pr <- model_list %>%
  map(calc_auprc, data = imbal_test)

model_list_pr %>%
  map(function(the_mod) the_mod$auc.integral)

计算的AUPRC值如下(越大越好)

## $original
## [1] 0.5155589
## 
## $weighted
## [1] 0.640687
## 
## $down
## [1] 0.5302778
## 
## $up
## [1] 0.6461067
## 
## $SMOTE
## [1] 0.6162899

我们绘制PRC曲线观察各个模型的分类效果。基于选定的分类阈值,up samplingweighting有着最好的精准率和召回率 (单个分组的准确率)。而原始分类器则效果最差。

图片

假如加权分类器在召回率 (recall)为75%时,精准率可以达到50% (下面曲线中略低于50%),则F1得分为0.6

原始分类器在召回率为75%时,精准率为25% (下面曲线略高于25%),则F1得分为0.38

也就是说,当构建好了这两个分类器,并设置一个分类阈值 (不同模型的阈值不同)后,都可以在样品少的分组中获得75%的召回率。但是对于加权模型,有50%的预测为属于样品少的分组的样品是预测对的。而对于原始模型,只有25%预测为属于样品少的分组的样品是预测对的。

# Plot the AUPRC curve for all 5 models

results_list_pr <- list(NA)
num_mod <- 1

for(the_pr in model_list_pr){

  results_list_pr[[num_mod]] <- 
    data_frame(recall = the_pr$curve[, 1],
               precision = the_pr$curve[, 2],
               model = names(model_list_pr)[num_mod])

  num_mod <- num_mod + 1

}

results_df_pr <- bind_rows(results_list_pr)

results_df_pr$model <- factor(results_df_pr$model, 
                               levels=c("original", "down","SMOTE","up","weighted"))

# Plot ROC curve for all 5 models

custom_col <- c("#000000", "#009E73", "#0072B2", "#D55E00", "#CC79A7")

ggplot(aes(x = recall,  y = precision, group = model), data = results_df_pr) +
  geom_line(aes(color = model), size = 1) +
  scale_color_manual(values = custom_col) +
  geom_abline(intercept =
                sum(imbal_test$Class == minorityClass)/nrow(imbal_test),
              slope = 0, color = "gray", size = 1)  +
  theme_bw(base_size = 18) + coord_fixed(1)

图片

基于AUPRC进行调参,修改参数summaryFunction = prSummarymetric = "AUC"。参考https://topepo.github.io/caret/measuring-performance.html (或者看之前的推文)。

# Set up control function for training
ctrlprSummary <- trainControl(method = "repeatedcv",
                     number = 10,
                     repeats = 5,
                     summaryFunction = prSummary,
                     classProbs = TRUE)

# Build a standard classifier using a gradient boosted machine

set.seed(5627)
orig_fit2 <- train(Class ~ .,
                  data = imbal_train,
                  method = "gbm",
                  verbose = FALSE,
                  metric = "AUC",
                  trControl = ctrlprSummary)

# Use the same seed to ensure same cross-validation splits
ctrlprSummary$seeds <- orig_fit$control$seeds

# Build weighted model

weighted_fit2 <- train(Class ~ .,
                      data = imbal_train,
                      method = "gbm",
                      verbose = FALSE,
                      weights = model_weights,
                      metric = "AUC",
                      trControl = ctrlprSummary)

# Build down-sampled model

ctrlprSummary$sampling <- "down"

down_fit2 <- train(Class ~ .,
                  data = imbal_train,
                  method = "gbm",
                  verbose = FALSE,
                  metric = "AUC",
                  trControl = ctrlprSummary)

# Build up-sampled model

ctrlprSummary$sampling <- "up"

up_fit2 <- train(Class ~ .,
                data = imbal_train,
                method = "gbm",
                verbose = FALSE,
                metric = "AUC",
                trControl = ctrlprSummary)

# Build smote model
ctrlprSummary$sampling <- "smote"

smote_fit2 <- train(Class ~ .,
                   data = imbal_train,
                   method = "gbm",
                   verbose = FALSE,
                   metric = "AUC",
                   trControl = ctrlprSummary)

model_list2 <- list(original = orig_fit2,
                   weighted = weighted_fit2,
                   down = down_fit2,
                   up = up_fit2,
                   SMOTE = smote_fit2)

评估下基于prSummary调参后模型的性能,SMOTE处理后的模型效果有提升,其它模型相差不大。

model_list_pr2 <- model_list2 %>%
  map(calc_auprc, data = imbal_test)

model_list_pr2 %>%
  map(function(the_mod) the_mod$auc.integral)

计算的AUPRC值如下(越大越好)

## $original
## [1] 0.5155589
## 
## $weighted
## [1] 0.640687
## 
## $down
## [1] 0.5302778
## 
## $up
## [1] 0.6461067
## 
## $SMOTE
## [1] 0.6341753

绘制PRC曲线

# Plot the AUPRC curve for all 5 models

results_list_pr <- list(NA)
num_mod <- 1

for(the_pr in model_list_pr2){

  results_list_pr[[num_mod]] <- 
    data_frame(recall = the_pr$curve[, 1],
               precision = the_pr$curve[, 2],
               model = names(model_list_pr)[num_mod])

  num_mod <- num_mod + 1

}

results_df_pr <- bind_rows(results_list_pr)

results_df_pr$model <- factor(results_df_pr$model, 
                               levels=c("original", "down","SMOTE","up","weighted"))

# Plot ROC curve for all 5 models

custom_col <- c("#000000", "#009E73", "#0072B2", "#D55E00", "#CC79A7")

ggplot(aes(x = recall,  y = precision, group = model), data = results_df_pr) +
  geom_line(aes(color = model), size = 1) +
  scale_color_manual(values = custom_col) +
  geom_abline(intercept =
                sum(imbal_test$Class == minorityClass)/nrow(imbal_test),
              slope = 0, color = "gray", size = 1)  +
  theme_bw(base_size = 18) + coord_fixed(1)

图片

PRCAUPRC是处理非平衡数据集的有效衡量方式。基于AUC指标来看,权重和重采样技术只带来了微弱的性能提升。但是这个改善更多体现在可以在较低假阳性率基础上获得较高真阳性率,模型的性能更均匀提升。在处理非平衡样本学习问题时,除了尝试调整权重和重采样之外,也不能完全依赖AUC值,而是依靠PRC曲线联合判断,以期获得更好的效果。

References

http://pages.cs.wisc.edu/~jdavis/davisgoadrichcamera2.pdf

http://people.inf.elte.hu/kiss/11dwhdm/roc.pdf

http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0118432

https://dpmartin42.github.io/posts/r/imbalanced-classes-part-2

https://zhuanlan.zhihu.com/p/64963796

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2151495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

前端univer创建、编辑excel

前端univer创建、编辑excel 源码在线demo&#xff1a;https://codesandbox.io/p/sandbox/univer-q87kqg?file/src/Demo.jsx univer官网地址&#xff1a;https://univer.ai/zh-CN/guides/sheet/introduction 安装univer npm install univerjs/core univerjs/design univerjs…

大模型爬虫—ScrapeGraphAI

大模型爬虫—ScrapeGraphAI 一、介绍 ScrapeGraphAI是一个网络爬虫 Python 库,使用大型语言模型和直接图逻辑为网站和本地文档(XML,HTML,JSON 等)创建爬取管道。 只需告诉库您想提取哪些信息,它将为您完成! scrapegraphai有三种主要的爬取管道可用于从网站(或本地文…

dockerfile 添加arthas 监控插件。容器添加arthas监控

1. arthas官网&#xff1a; 简介 | arthas 2. arthas下载地址&#xff1a; Releases alibaba/arthas GitHub 3. 下载版本&#xff1a; 4. 下载压缩包后&#xff0c;解压缩&#xff0c;放入Dockerfile 同级目录 5. dockerfile 命令&#xff1a; RUN mkdir -p /opt/arthas…

HarmonyOS鸿蒙开发实战(5.0)自定义全局弹窗实践

鸿蒙HarmonyOS开发实战往期文章必看&#xff1a; HarmonyOS NEXT应用开发性能实践总结 最新版&#xff01;“非常详细的” 鸿蒙HarmonyOS Next应用开发学习路线&#xff01;&#xff08;从零基础入门到精通&#xff09; 非常详细的” 鸿蒙HarmonyOS Next应用开发学习路线&am…

【RPA私教课:UIPath】RPA 赋能科技企业,登录时验证码自动截取

在某科技型企业里&#xff0c;专门设置了一个验证码接收系统。每当用户进行登录操作时&#xff0c;都必须从这个系统中抓取最新的登录验证码&#xff0c;以确保登录的安全性。 具体需求如下&#xff1a; 客户会预先在表格中妥善保存众多的账户和密码。当 RPA 机器人在业务系统…

weblogic CVE-2017-3506 靶场攻略

漏洞描述 Weblogic的WLS Security组件对外提供了webserver服务&#xff0c;其中使⽤了XMLDecoder来解析⽤户输⼊的XML数据&#xff0c;在解析过程中出现反序列化漏洞&#xff0c;可导致任意命令执⾏。 影响版本 受影响版本&#xff1a;WebLogic 10.3.6.0, 12.1.3.0, 12.2.1.…

idea启动oom了解决

解决 Error:java: java.lang.OutOfMemoryError: WrappedJavaFileObject[org.jetbrains.jps.javac.InputFileObject[file:///D:/mingan/pb/backend/src/main/java/com/cy/backend/service/impl/StorageServiceImpl.java]]pos36199: WrappedJavaFileObject[org.jetbrains.jps.j…

提升效率的AI工具集 - 轻松实现自动化

在这个快节奏、高效率的社会中&#xff0c;我们每个人都渴望能够找到提升工作效率的捷径。幸运的是&#xff0c;随着人工智能&#xff08;AI&#xff09;技术的迅猛发展&#xff0c;越来越多的AI工具涌现出来&#xff0c;为我们提供了强大的支持。这些工具不仅能够帮助我们提高…

JavaScript可视化

JavaScript 可视化通常涉及利用各种库和工具将数据转化为图形的形式&#xff0c;从而更直观地呈现信息。以下是一些流行的 JavaScript 可视化工具和库&#xff0c;以及一些关键知识点&#xff1a; 流行的 JavaScript 可视化库&#xff1a; 1. D3.js (Data-Driven Documents)&…

xtop:pt dmsa环境下如何写出timing data file

我正在「拾陆楼」和朋友们讨论有趣的话题,你⼀起来吧? 拾陆楼知识星球入口 xtop的输入数据之一就是sta timing data,它是由各scenario的报告组成的,

MedPrompt:基于提示工程的医学诊断准确率优化方法

Medprompt&#xff1a;基于提示工程的医学诊断准确率优化方法 秒懂大纲解法拆解MedPrompt 提示词全流程分析总结创意视角 论文&#xff1a;Can Generalist Foundation Models Outcompete Special-Purpose Tuning? Case Study in Medicine 秒懂大纲 ├── 1 研究背景【描述背…

VIVADO IP核之FIR插值器多相滤波仿真

VIVADO IP核之FIR插值器多相滤波仿真&#xff08;含有与MATLAB仿真数据的对比&#xff09; 目录 前言 一、滤波器系数生成 二、用MATLAB生成仿真数据 三、VIVADO FIR插值多相滤波器使用 四、VIVADO FIR插值多相滤波器仿真 五、VIVADO工程下载 总结 前言 网络上有许多文章…

生信初学者教程(五):R语言基础

文章目录 数据类型整型逻辑型字符型日期型数值型复杂数数据结构向量矩阵数组列表因子数据框ts特殊值缺失值 (NA)无穷大 (Inf)非数字 (NaN)安装R包学习材料R语言是一种用于统计计算和图形展示的编程语言和软件环境,广泛应用于数据分析、统计建模和数据可视化。1991年:R语言的最…

webpack4 target:“electron-renderer“ 打包加速配置

背景 昨天写得一篇Electron-vue asar 局部打包优化处理方案——绕开每次npm run build 超级慢的打包问题-CSDN博客文章浏览阅读754次&#xff0c;点赞19次&#xff0c;收藏11次。因为组员对于 Electron 打包过程存在比较迷糊的状态&#xff0c;且自己也没主动探索 Electron-vu…

CX8903:电动车手机充电器降压芯片,搭配协议实现快充

CX8903&#xff1a;一款专用于电动车手机充电器的降压芯片&#xff0c;搭配协议实现快充。 在城市的车水马龙中&#xff0c;电动自行车如灵动的精灵&#xff0c;便捷着我们的出行生活。在骑行的路上&#xff0c;随时保持连接&#xff0c;电动自行车手机充电器让手机电量满满。…

汽车应用生态系统的飞跃

在过去的几年里&#xff0c;汽车系统经历了前所未有的变革&#xff0c;驾驶员和乘客对于车内体验的期待已远远超越了传统的驾驶范畴。随着技术的不断进步&#xff0c;基于Android Automotive OS&#xff08;AAOS&#xff09;和Google Automotive Services&#xff08;GAS&#…

在 Python 中使用 JSON

了解如何在 Python 中使用 JSON&#xff0c;从基础到高级技术。本指南涵盖解析、序列化、API 集成和最佳实践。 1. JSON 简介 1.1. 什么是 JSON&#xff1f; JSON&#xff08;JavaScript 对象表示法&#xff09;是一种轻量级数据交换格式&#xff0c;人类可以轻松读取和写入…

线性规划中可行域为什么一定是凸的--证明

线性规划中的凸性证明 线性规划中可行域是凸的&#xff0c;这是自然能够想到和容易理解的道理。直观上&#xff0c;线性约束定义的可行域是由半平面的交集构成的&#xff0c;这些半平面的交集总是形成凸区域。 这么一个自然想到、容易理解的道理&#xff0c;怎么从数学上完备…

机器翻译与数据集_by《李沐:动手学深度学习v2》pytorch版

系列文章目录 文章目录 系列文章目录介绍机器翻译下载和预处理数据集词元化词表加载数据集训练模型对上述代码中出现的Vocab进行总体解释和逐行解释使用场景 小结练习答案1. num_examples 参数对词表大小的影响2. 对于没有单词边界的语言&#xff0c;单词级词元化的有效性 介绍…

低代码平台后端搭建-阶段完结

前言 最近又要开始为跳槽做准备了&#xff0c;发现还是写博客学的效率高点&#xff0c;在总结其他技术栈之前准备先把这个专题小完结一波。在这一篇中我又试着添加了一些实际项目中可能会用到的功能点&#xff0c;用来验证这个平台的扩展性&#xff0c;以及总结一些学过的知识。…