机器学习 | 计算分类算法的ROC和AUC曲线以随机森林为例

news2025/4/28 2:56:13

受试者工作特征(ROC)曲线和曲线下面积(AUC)是常用的分类算法评价指标,本文将讨论如何计算随机森林分类器的ROC 和 AUC。

ROC 和 AUC是量化二分类区分阳性和阴性类别能力的度量。ROC曲线是针对不同分类阈值的真阳性率(TPR)对假阳性率(FPR)的图。TPR是真阳性与阳性示例总数的比率,而FPR是假阳性与阴性示例总数的比率。AUC是ROC曲线下面积,范围为0.0至1.0,值越高表示分类器性能越好。

具体步骤

1.导入所需模块

from sklearn.ensemble import RandomForestClassifier 
from sklearn.metrics import roc_curve, roc_auc_score 
from sklearn.datasets import load_breast_cancer 
from sklearn.model_selection import train_test_split 
import matplotlib.pyplot as plt

这里我们导入所需的模块,包括分别来自sklearn.ensemble和sklearn.metrics模块的RandomForestClassifier和roc_curve函数。我们还从sklearn.datasets模块导入load_breast_cancer函数来加载乳腺癌数据集,并从sklearn.model_selection模块导入train_test_split函数来将数据集拆分为训练集和测试集。最后,我们从matplotlib库中导入pyplot模块来绘制ROC曲线。

2.加载并拆分数据集

加载数据集并分离特征和目标值,然后拆分训练和测试数据集。

df = load_breast_cancer(as_frame=True) 
df = df.frame 

x = df.drop('target',axis=1) 
y = df[['target']] 
# Split the dataset into training and test sets 
X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

3.训练随机森林分类器

# Train a Random Forest classifier 
rf = RandomForestClassifier(n_estimators=5, max_depth=2) 
rf.fit(X_train, y_train)

在这里,我们使用RandomForestClassifier函数训练一个随机森林分类器,其中包含5个估计量和最大深度2。我们使用拟合方法将分类器拟合到训练数据。

4.获取测试集的预测类概率

# Get predicted class probabilities for the test set 
y_pred_prob = rf.predict_proba(X_test)[:, 1] 

在这里,我们使用随机森林分类器的predict_proba方法来获得测试集的预测类概率。该方法返回一个形状数组(n_samples,n_classes),其中n_samples是测试集中的样本数,n_classes是问题中的类数。因为我们使用的是二元分类器,所以n_classes等于2,我们感兴趣的是正类的概率。 这是数组的第二列。因此,我们使用 [:,1] 索引来获得正类概率的一维数组。

5.计算不同分类阈值的假阳性率(FPR)和真阳性率(TPR)

# Compute the false positive rate (FPR) 
# and true positive rate (TPR) for different classification thresholds 
fpr, tpr, thresholds = roc_curve(y_test, y_pred_prob, pos_label=1)

在这里,我们使用sklearn.metrics模块中的roc_curve函数来计算不同分类阈值的假阳性率(FPR)和真阳性率(TPR)。该函数将测试集的真标签(y_test)和阳性类的预测类概率(y_pred_prob)作为输入。它返回三个数组:fpr,其包含不同阈值的FPR值; tpr,其包含不同阈值的TPR值;以及thresholds,其包含阈值。

6.计算ROC AUC评分

# Compute the ROC AUC score 
roc_auc = roc_auc_score(y_test, y_pred_prob) 
roc_auc

输出

0.9787264420331239

这里我们使用sklearn.metrics模块中的roc_auc_score函数来计算ROC AUC分数。该函数将测试集的真标签(y_test)和阳性类的预测类概率(y_pred_prob)作为输入。它返回表示ROC曲线下面积的标量值。

7.绘制ROC曲线

# Plot the ROC curve 
plt.plot(fpr, tpr, label='ROC curve (area = %0.2f)' % roc_auc) 
# roc curve for tpr = fpr 
plt.plot([0, 1], [0, 1], 'k--', label='Random classifier') 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive Rate') 
plt.title('ROC Curve') 
plt.legend(loc="lower right") 
plt.show()

在这里插入图片描述
这里我们使用pyplot模块的plot函数来绘制ROC曲线。我们在x轴上传递FPR值,在y轴上传递TPR值。我们还将ROC AUC评分作为面积添加到图中。我们绘制虚线来表示随机分类器,其具有从(0,0)到(1,1)的直线的ROC曲线。我们为图添加轴标签和标题,以及显示ROC AUC得分和随机分类器线的图例。

说明:
ROC曲线是对于不同分类阈值,y轴上的真阳性率(TPR)对x轴上的假阳性率(FPR)的图。ROC曲线显示了分类器在不同阈值下区分阳性和阴性类别的能力。一个完美的分类器的TPR为1,FPR为0,对应于图的左上角。另一方面,随机分类器将具有从(0,0)到(1,1)的直线的ROC曲线,这是图中的虚线。ROC曲线越接近左上角,分类器的性能越好。

ROC曲线可用于选择分类器的最佳阈值,这取决于TPR和FPR之间的权衡。接近1的阈值将具有较低的FPR但较高的TPR,而接近0的阈值将具有较高的FPR但较低的TPR。

8.绘制预测类概率

# Plot the predicted class probabilities 
plt.hist(y_pred_prob, bins=10) 
plt.xlim(0, 1) 
plt.title('Histogram of predicted probabilities') 
plt.xlabel('Predicted probability of Setosa') 
plt.ylabel('Frequency') 
plt.show() 

在这里插入图片描述

多分类的ROC曲线示例

这里使用sklearn.datasets的iris数据集,它有3个类。ROC曲线可用于二分类,因此,这里我们将使用来自sklearn.multiclass的OneVsRestClassifier和Random forest作为分类器,绘制ROC曲线。

from sklearn.ensemble import RandomForestClassifier 
from sklearn.metrics import roc_curve, roc_auc_score 
from sklearn.datasets import load_iris 
from sklearn.multiclass import OneVsRestClassifier 
from sklearn.model_selection import train_test_split 
import matplotlib.pyplot as plt 


# Load the iris dataset 
iris = load_iris() 

# Split the dataset into training and test sets 
X_train, X_test, y_train, y_test = train_test_split(iris.data, 
													iris.target, 
													test_size=0.5, 
													random_state=23) 

# Train a Random Forest classifier 
clf = OneVsRestClassifier(RandomForestClassifier()) 

# fit model 
clf.fit(X_train, y_train) 

# Get predicted class probabilities for the test set 
y_pred_prob = clf.predict_proba(X_test) 

# Compute the ROC AUC score 
roc_auc = roc_auc_score(y_test, y_pred_prob, multi_class='ovr') 
print('ROC AUC Score :',roc_auc) 

# roc curve for Multi classes 
colors = ['orange','red','green'] 
for i in range(len(iris.target_names)):	 
	fpr, tpr, thresh = roc_curve(y_test, y_pred_prob[:,i], pos_label=i) 
	plt.plot(fpr, tpr, linestyle='--',color=colors[i], label=iris.target_names[i]+' vs Rest') 
# roc curve for tpr = fpr 
plt.plot([0, 1], [0, 1], 'k--', label='Random classifier') 
plt.title('Multiclass (Iris) ROC curve') 
plt.xlabel('False Positive Rate') 
plt.ylabel('True Positive rate') 
plt.legend() 
plt.show()

输出

ROC AUC Score : 0.9795855072463767

在这里插入图片描述

总结

总之,计算随机森林分类器的ROC AUC分数在Python中是一个简单的过程。sklearn.metrics模块提供了计算ROC曲线、ROC AUC评分和PR曲线的函数。ROC曲线和PR曲线是评估二值分类器性能的有用工具,它们可以帮助基于不同评估指标之间的权衡来选择分类器的最佳阈值。

PR(precision-recall)曲线是二元分类问题的另一个评估指标。PR曲线是针对不同分类阈值的精确度(y轴)对召回率(x轴)的图。精确度被定义为真阳性的数量除以真阳性加假阳性的数量,而召回率被定义为真阳性的数量除以真阳性加假阴性的数量。PR曲线显示了分类器在最小化误报的同时预测阳性类别的能力。

与ROC曲线相比,PR曲线更适合不平衡数据集,其中阳性类别中的样本数量远小于阴性类别中的样本数量。当假阳性和假阴性的成本不同时,PR曲线也很有用,因为它可以帮助基于精确度-召回率权衡为分类器选择最佳阈值。

重要的是要注意,ROC AUC不应该是用于评估分类器性能的唯一度量。其他指标,如精确度、召回率和F1分数,也可能有用,具体取决于应用程序的具体要求。此外,重要的是要考虑数据中正面和负面示例的总体分布以及不平衡类对评估指标的潜在影响。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1956366.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Mac电脑 系统监测工具 System Dashboard Pro【简单操作,小白轻松上手】

Mac分享吧 文章目录 效果一、下载软件二、开始安装1、双击运行软件,将其从左侧拖入右侧文件夹中,等待安装完毕2、应用程序显示软件图标,表示安装成功 三、运行测试安装完成!!! 效果 一、下载软件 下载软件…

opencascade AIS_PlaneTrihedron 源码学习

AIS_PlaneTrihedron 前言 构建一个可选择的2D轴系在3D绘图中。 这个轴系可以放置在3D系统中的任何位置,提供一个用于在平面中绘制曲线和形状的坐标系。 有三种选择模式: 模式0 选择整个平面“trihedron” 模式1 选择平面“trihedron”的原点 模式2 选择…

Nuxt.js 路由管理:useRouter 方法与路由中间件应用

title: Nuxt.js 路由管理:useRouter 方法与路由中间件应用 date: 2024/7/28 updated: 2024/7/28 author: cmdragon excerpt: 摘要:本文介绍了Nuxt 3中useRouter方法及其在路由管理和中间件应用中的功能。内容包括使用useRouter添加、移除路由&#xf…

Cesium高性能渲染海量矢量建筑

0、数据输入为类似Geojson的压缩文件和纹理图片,基于DrawCommand命令绘制; 1、自定义建筑几何,包括顶点、法线、纹理等; 2、自定义纹理贴图,包括按建筑高度贴图、mipmap多级纹理; 3、自定义批处理表&…

我的新书《Android系统多媒体进阶实战》正式发售了!!!

我的新书要正式发售了,把链接贴在下面,感兴趣的朋友可以支持下。 ❶发售平台:当当,京东,抖音北航社平台,小红书,b站 ❷目前当当和京东已开启预售 ❸当当网 https://u.dangdang.com/KIDHJ ❹…

22 B端产品经理与MySQL基本查询、排序(2)

MySQL基本常识 MySQL:一种关系型数据库管理系统。是按照数据结构来组织、存储和管理数据的仓库。 数据库:是一些关联数据表的集合。 数据表:表是数据的矩阵,看起来像电子表格,如下图:user表和admin表。 …

⌈ 传知代码 ⌋ 红外小目标检测

💛前情提要💛 本文是传知代码平台中的相关前沿知识与技术的分享~ 接下来我们即将进入一个全新的空间,对技术有一个全新的视角~ 本文所涉及所有资源均在传知代码平台可获取 以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦&#x…

keil5导入程序到stm32的开发板

如图, 1,安装mdk_514.exe 2,安装Keil.STM32F1xx_DFP.1.0.5.pack 3,注册方法(仅限学生使用):http://www.openedv.com/thread-69384-1-1.html 点击keil程序的上面魔法棒, 在device中…

类中的function无法正确被matlab所识别,该怎么操作呢?

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

【Linux】CentOS更换国内阿里云yum源(超详细)

目录 1. 前言2. 打开终端3. 确保虚拟机已经联网4. 备份现有yum配置文件5. 下载阿里云yum源6. 清理缓存7. 重新生成缓存8. 测试安装gcc 1. 前言 有些同学在安装完CentOS操作系统后,在系统内安装比如:gcc等软件的时候出现这种情况:&#xff08…

SpringBoot3如何整合Redis?

SpringBoot应该不用介绍!它是Spring当前最火的一个框架,整合Spring Boot 3和Redis可以显著提升应用程序的性能,特别是在处理大量数据和需要快速访问的场景下。 在Spring Boot中,从1.x版本到2.x版本的Redis连接方式发生了变化&…

点脂成金携手北京新颜兴医疗美容医院,共启战略合作新篇章

2024年7月24日上午,点脂成金品牌方与北京新颜兴医疗美容医院在京举行了隆重的签约仪式,宣布达成战略合作关系,共同开启医疗美容领域的设备共享新篇章。 签约仪式在北京纯脂医疗美容门诊部有限公司举行,现场氛围热烈而庄重。点脂成…

使用 WebSocket 实现实时聊天

个人名片 🎓作者简介:java领域优质创作者 🌐个人主页:码农阿豪 📞工作室:新空间代码工作室(提供各种软件服务) 💌个人邮箱:[2435024119qq.com] &#x1f4f1…

基于opencv的人脸识别(实战)

前言 经过这几天的学习,我已经跃跃欲试了,相信大家也是,所以我决定自己做一个人脸识别程序。我会把自己的思路和想法都在这篇博客内讲清楚,大家可以当个参考,🌟仅供学习使用🌟。 &#x1f31f…

黑马程序员2024最新SpringCloud微服务开发与实战 个人学习心得、踩坑、与bug记录Day5 全网最快最全

你好,我是Qiuner. 为帮助别人少走弯路和记录自己编程学习过程而写博客 这是我的 github https://github.com/Qiuner ⭐️ gitee https://gitee.com/Qiuner 🌹 如果本篇文章帮到了你 不妨点个赞吧~ 我会很高兴的 😄 (^ ~ ^) 想看更多 那就点个关注吧 我会…

树莓派_Opencv学习笔记23:模版样本匹配

今日继续学习树莓派4B 4G:(Raspberry Pi,简称RPi或RasPi) 本人所用树莓派4B 装载的系统与版本如下: 版本可用命令 (lsb_release -a) 查询: ​ Opencv 版本是4.5.1: ​ Python 版本3.7.3: 今日学习Opencv样本…

香烟商品销售网站

1 香烟商品销售网站概述 1.1 课题简介 1.2 设计目的 1.3 系统开发所采用的技术 1.4 系统功能模块 2 数据库设计 2.1 建立的数据库名称 2.2 所使用的表 3 香烟商品销售网站设计与实现 1. 注册登录: 2. 分页查询: 3. 分页条件(精确、…

速卖通卖家如何利用自养号测评,让店铺曝光量飙升?

在速卖通这个竞争激烈的跨境电商平台上,店铺曝光率是决定销售成败的关键因素之一。为了在众多商家中脱颖而出,增加速卖通店铺曝光显得尤为重要。速卖通怎么增加店铺曝光? 速卖通怎么增加店铺曝光? 1、优化产品列表 速卖通的产品列表是买家…

数据库实验:连接查询

一、实验目的: 1、掌握使用两种写法完成连接查询:叉积连接语法和内连接语法。 2、掌握使用外连接语法完成查询。 3、掌握使用派生表完成下列查询。 二、实验内容: 1. 使用连接实现查询,查询订单号为‘000005’的订单订购的玩具…

windows 安装docker桌面版

下载 下载两个: git桌面版 docker desktop 启动docker 执行安装文件,启动 更新wsl2 假如报错,会提示失败原因。 win10会提示跳转到: https://learn.microsoft.com/zh-cn/windows/wsl/install-manual#step-4—download-the-l…