多分类的ROC曲线绘制思路

news2025/1/16 18:58:09

目录

一、什么是ROC曲线

二、AUC面积

三、代码示例

1、二分类问题

2、多分类问题


一、什么是ROC曲线

我们通常说的ROC曲线的中文全称叫做接收者操作特征曲线(receiver operating characteristic curve),也被称为感受性曲线。

该曲线有两个维度,横轴为fpr(假正率),纵轴为tpr(真正率)

  • 准确率(accuracy):(TP+TN)/ ALL =(3+4)/ 10 准确率是所有预测为正确的样本除以总样本数,用以衡量模型对正负样本的识别能力。
  • 错误率(error rate):(FP+FN)/ ALL =(1+2)/ 10 错误率就是识别错误的样本除以总样本数。
  • 假正率(fpr):FP / (FP+TN) = 2 / (2+4)假正率就是真实负类中被预测为正类的样本数除以所有真实负类样本数。
  • 真正率(tpr):TP / (TP+FN)= 3 / (3+1)真正率就是真实正类中被预测为正类的样本数除以所有真实正类样本数。

  • 这个曲线是怎么画出来的呢?

这幅曲线的每个点都对应一个(fpr,tpr),看过之前混淆矩阵的话,感觉一堆数据好像最终只能算出一个fpr和tpr,那是如何获得这么多的点的呢?

我们y会有一个预测为正类的概率,我们会将这个概率从大到小进行排序,然后再每个概率阈值的情况下进行将数据分类,能够计算出一组新的fpr和tpr,这样就会再不同的概率阈值下计算出多组点。

二、AUC面积

上面说了什么是ROC曲线,横轴为fpr,纵轴为tpr,fpr就是假正率,就是真实的负类被预测为正类的样本数除以所有真实的样本数,而tpr是真正率,是正式的正类被预测为正类的数目除以所有正类的样本数,所以我们的目标就是让假正率越小,真正率越大(负样本被误判的样本数越少,正样本被预测正确的样本数越大),那么对应图中ROC曲线来说,我们希望的是fpr越小,tpr越大,即曲线越靠近左上方越好,那么下面的面积也就成为了我们的衡量指标,俗称AUC(Area Under Curve)。

解释下图中四个点的含义:

  • (0,0):该点代表fpr和tpr都为0,也就是说此时负类样本全部预测正确,而正类样本全部预测错误。
  • (0,1):该点代表fpr为0,tpr为1,就是说此时负类样本全部预测正确,正类样本也全部预测正确,这也验证了我们上面说曲线越靠近左上方越好。
  • (1,0):该点代表fpr为1,tpr为0,即负类样本全部被预测为正类样本,而正类样本全部被预测为负类样本,这是最坏的情况,也就对应了图中的右下角。
  • (1,1):该点代表fpr为1,tpr为1,即负类样本全部预测错误,但是正类样本全部预测正确

y=x这条直线对应fpr和tpr是相等的,也就是我们模型评估正类和负类的能力是一样的,我们一般认为曲线要在该直线上方才有意义。

三、代码示例

下面是官方给出的示例,y代表着真实分类,scores代表着y对应的每个样本被预测为正类的概率,也就是说对于第一个样本来说,预测为2的概率为0.1,预测为1的概率为0.9,以此类推。

roc_curve(y_true,scores,pos_label): 对应的参数分别为y的真实标签,预测为正类的概率,pos_label 是指明哪个标签为正类,因为默认都是-1和1,1被当作正类,如果y对应的不是这个,就会报错,所以需要特别指明一下。

返回值为对应的fpr,tprthresholds

1、二分类问题

二分类示例:

from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
import numpy as np
 
y = np.array([1, 1, 2, 2])
scores = np.array([0.1, 0.4, 0.35, 0.8])
fpr, tpr, thresholds = roc_curve(y, scores, pos_label=2)

# 基于roc_curve函数返回fpr、tpr序列计算AUC的值(和roc_auc_score等价)
auc(fpr, tpr), roc_auc_score(y, pred)

注意:roc_curve只针对二分类情况,多分类情况有点特殊。下面给出多分类示例。

2、分类问题

对于多分类问题,ROC曲线的获取主要有两种方法:

假设测试样本个数为m,类别个数为n。在训练完成后,计算出每个测试样本在各类别下的概率或置信度,得到一个[m, n]形状的矩阵P,每一行表示一个测试样本在各类别下概率值(按类别标签排序)。相应地,将每个测试样本的标签转换为类似二进制的形式,每个位置用来标记是否属于对应的类别(也按标签排序,这样才和前面对应),由此也可以获得一个[m, n]的标签矩阵L。

  • 方法一:

每种类别下,都可以得到m个测试样本为该类别的概率(矩阵P中的列)。所以,根据概率矩阵P和标签矩阵L中对应的每一列,可以计算出各个阈值下的假正例率(FPR)和真正例率(TPR),从而绘制出一条ROC曲线。这样总共可以绘制出n条ROC曲线。最后对n条ROC曲线取平均,即可得到最终的ROC曲线。

  • 方法二:

首先,对于一个测试样本:1)标签只由0和1组成,1的位置表明了它的类别(可对应二分类问题中的‘’正’’),0就表示其他类别(‘’负‘’);2)要是分类器对该测试样本分类正确,则该样本标签中1对应的位置在概率矩阵P中的值是大于0对应的位置的概率值的。基于这两点,将标签矩阵L和概率矩阵P分别按行展开,转置后形成两列,这就得到了一个二分类的结果。所以,此方法经过计算后可以直接得到最终的ROC曲线。

上面的两个方法得到的ROC曲线是不同的,当然曲线下的面积AUC也是不一样的。 在python中,方法1和方法2分别对应sklearn.metrics.roc_auc_score函数中参数average值为’macro’和’micro’的情况。下面参考sklearn官网提供的例子,对两种方法进行实现。

# 引入必要的库
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp

# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 将标签二值化
y = label_binarize(y, classes=[0, 1, 2]) # 三个类别

# 设置种类
n_classes = y.shape[1]

# 训练模型并预测
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape

# shuffle and split training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,random_state=0)

# Learn to predict each class against the other
classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
                                 random_state=random_state))
y_score = classifier.fit(X_train, y_train).predict_proba(X_test) # 获得预测概率

# 计算每一类的ROC
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Compute micro-average ROC curve and ROC area(方法二)
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

# Compute macro-average ROC curve and ROC area(方法一)
# First aggregate all false positive rates
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))

# Then interpolate all ROC curves at this points
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += interp(all_fpr, fpr[i], tpr[i])

# Finally average it and compute AUC
mean_tpr /= n_classes
fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

# Plot all ROC curves
lw=2
plt.figure()
plt.plot(fpr["micro"], tpr["micro"],
         label='micro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["micro"]),
         color='deeppink', linestyle=':', linewidth=4)

plt.plot(fpr["macro"], tpr["macro"],
         label='macro-average ROC curve (area = {0:0.2f})'
               ''.format(roc_auc["macro"]),
         color='navy', linestyle=':', linewidth=4)

colors = cycle(['aqua', 'darkorange', 'cornflowerblue'])
for i, color in zip(range(n_classes), colors):
    plt.plot(fpr[i], tpr[i], color=color, lw=lw,
             label='ROC curve of class {0} (area = {1:0.2f})'
             ''.format(i, roc_auc[i]))

plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([-0.02, 1.0])
plt.ylim([0.0, 1.02])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/512635.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

第四节 Linux 特殊权限SUID、SGID、SBIT

目录 1.Set UID 简称 SUID 2.Set GID 简称 SGID 3.Sticky Bit 简称 SBIT 1.Set UID 简称 SUID 简称 SUID 限制与功能: SUID权限仅对二进制程序有效; 执行者对于该程序需要具有x的执行权限; 本权限仅在执行该程序的过程中有效&#xff1…

Softmax简介

Softmax是一种数学函数,通常用于将一组任意实数转换为表示概率分布的实数。其本质上是一种归一化函数,可以将一组任意的实数值转化为在[0, 1]之间的概率值,因为softmax将它们转换为0到1之间的值,所以它们可以被解释为概率。如果其…

VSCode +gdb+gdbserver远程调试arm开发板

一、下载编译器 从ARM官网下载gcc-arm编译器,编译器中自带gdb和gdbserver,可以省去自己编译。 注:gdb是电脑端程序,gdbserver是arm开发板程序 arm官网链接:https://developer.arm.com/downloads/-/arm-gnu-toolchain-d…

速卖通、Lazada、美客多、亚马逊新品流量如何利用测评快速提升?

熟悉亚马逊的卖家应该清楚,亚马逊对于新发布的产品会有一定的流量倾向,特别是产品刚上架的2-4周,你的产品将在搜索结果中显示更多,排名比通常情况下更快。 第一步:优化好自己的产品listing1.新品上架标题要点标题权重…

SLM27524一款能够有效驱动MOSFET和IGBT电源开关双通道低侧栅极驱动器

深力科电子为“数据中心服务器电源”推荐一款双通道大非反相低侧栅极驱动器 SLM27524,该产品能够有效驱动MOSFET和IGBT电源开关。SLM27524采用一种能够从内部极大地降低击穿电流的设计,将高峰值的源电流和灌电流脉冲提供给电容负载,从而实现了…

NDK OpenGL离屏渲染与工程代码整合

NDK​系列之OpenGL离屏渲染与工程代码整合,本节主要是对上一节OpenGL渲染画面效果代码进行封装设计,将各种特效代码进行分离解耦,便于后期增加其他特效。 实现效果: 实现逻辑: 1.封装BaseFilter过滤器基类&#xff0c…

C++ 多线程编程(四) 原子类型atomic

C 11增加了原子类型atomic类&#xff0c;在一定条件下可以实现无锁编程。 1. 简介 atomic是一个模板类&#xff0c;定义如下&#xff1a; template< class T > struct atomic; atomic可以实现无锁编程&#xff0c;在效率上要比mutex高很多&#xff0c;直接看个直观的…

有道云笔记常用快捷键

F5 同步/刷新 Shift AltD 插入当前时间&#xff1a; CTRL B 加粗 CTRL I 斜体字 CTRL U 下划线 CTRL E 删除线 CTRL D 任务框 CTRL 1 变成标题1 CTRL 2 变成标题2 CTRL 3 变成标题3 CTRL 4 变成标题4 CTRL G 高亮块 CTRL H 加水平线 当前行成无序列表&a…

npm安装依赖实践总结

node下载地址&#xff1a;https://nodejs.org/en/download/releases 。可以看到node版本、npm版本、node_module版本。 【1】npm的全局安装路径 查看默认值&#xff1a; npm get prefix默认是C:\Users\你的用户名\AppData\Roaming\npm 可以通过 npm config prefix 更改全局…

为什么PCB设计完成后需要放置mark点

PCB设计中的Mark点是指一些标记点&#xff0c;通常用于促进PCB制造和组装过程中的准确性和一致性。这些标记点在制造过程中可以帮助操作员进行自动化定位&#xff0c;从而确保所有部件都被正确组装到其正确位置&#xff0c;这对于确保产品的质量和可靠性至关重要。 下面&#…

springboot抵御即跨站脚本(XSS)攻击

抵御即跨站脚本(XSS)攻击 XSS攻击通常指的是通过利用网站系统保存系统的漏洞&#xff0c;通过巧妙的方法把恶意指令注入到网页&#xff0c;用户加载网页的时候会自动执行恶意脚本 比如&#xff1a; <script>alert(xss); </script> 如果客户能在你的浏览器执行j…

C# Setting.settings . 配置用法

1、定义 在Settings.settings文件中定义配置字段。把作用范围定义为&#xff1a;User则运行时可更改(用户范围的字段数据更改存储在用户信息中&#xff0c;不在该程序文件中)&#xff0c;Applicatiion则运行时不可更改。可以使用数据网格视图(VS软件的Properties 下面的Settin…

几何深度学习 - 利用几何先验知识的深度学习

深度学习很难。 虽然通用逼近定理表明足够复杂的神经网络原则上可以逼近“任何东西”&#xff0c;但不能保证我们可以找到好的模型。 尽管如此&#xff0c;通过明智地选择模型架构&#xff0c;深度学习取得了巨大进步。 这些模型架构对归纳偏差进行编码&#xff0c;为模型提供…

makefile 条件判断语句

文章目录 前言一、条件判断语句的语法说明二、ifeq / ifneq三、ifdef / ifndef代码讲解&#xff1a; 四、经典示例总结 前言 一、条件判断语句的语法说明 makefile 中支持条件判断语句。 可以根据条件的值决定 make 的执行。可以 比较 两个不同变量或者变量和常量值。 条件判…

物理环境与网络通信安全

物理环境与网络通信安全 物理和环境安全环境安全设施安全-安全区域与边界防护传输安全 OSI七层模型ISO/OSI七层模型结构OSI模型特点OSI安全体系结构 TCP/IP协议安全TCP/IP协议结构网络接口层网络互联层核心协议-IP协议传输层协议-TCP&#xff08;传输控制协议&#xff09;传输层…

AMB300系列母线槽红外测温解决方案南沙XX养殖项目案例分享

安科瑞 耿敏花 一、 行业背景 随着当今社会的发展和用电量的急剧上升&#xff0c;现代化工程设施和装备的涌现&#xff0c;封闭式母线即母线槽因方便、节能、载流量大、机械强度高 、安装灵活、寿命长等特点&#xff0c;逐渐取代传统电缆&#xff0c;广泛应用于室内变压…

Window下载Android源码

Android 10源码下载 想要研究Android 源码的同学可以用此方法进行下载。源码从清华大学开源软件镜像站&#xff08;https://mirrors.tuna.tsinghua.edu.cn/help/AOSP/&#xff09;下载。 使用Linux的同学直接参照清华镜像站提供的使用帮助(https://mirrors.tuna.tsinghua.edu…

面试题:react、 vue中的key有什么作用? (key的内部原理)

面试题:react、 vue中的key有什么作用? &#xff08;key的内部原理) 1.虚拟DOM中key的作用: key是虚拟DOM对象的标识&#xff0c;当状态中的数据发生变化时&#xff0c;Vue会根据【新数据】生成【新的虚拟DON】,随后Vue进行【新虚拟DOM】与【旧虚拟DOM】的差异比较&#xff0…

如何用ChatGPT做内容营销方案和选题计划,同时生产和优化内容?

该场景对应的关键词库&#xff08;31个&#xff09;&#xff1a; 内容营销、目标、主题、类型、选题计划、素材、推广策略、优化方案、渠道、目标受众、竞争对手、行业背景、转化率、品牌知名度、客户参与度、销售、发布频率、选题阶段、生产阶段、推广阶段、预算分配、人群特征…

Python+Yolov5舰船侦测识别

程序示例精选 PythonYolov5舰船侦测识别 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对<<PythonYolov5舰船侦测识别>>编写代码&#xff0c;代码整洁&#xff0c;规则&#xff0c…