ROC 曲线详解

news2025/4/9 10:43:20

前言

ROC 曲线是一种坐标图式的分析工具,是由二战中的电子和雷达工程师发明的,发明之初是用来侦测敌军飞机、船舰,后来被应用于医学、生物学、犯罪心理学。

如今,ROC 曲线已经被广泛应用于机器学习领域的模型评估,说到这里就不得不提到 Tom Fawcett 大佬,他一直在致力于推广 ROC 在机器学习领域的应用,他发布的论文《An introduction to ROC analysis》[1]更是被奉为 ROC 的经典之作(引用 2.2w 次),知名机器学习库 scikit-learn 中的 ROC 算法就是参考此论文实现,可见其影响力!

不知道大多数人是否和我一样,对于 ROC 曲线的理解只停留在调用 scikit-learn 库的函数,对于它的背后原理和公式所知甚少。

前几天我重读了《An introduction to ROC analysis》终于将 ROC 曲线彻底搞清楚了,独乐乐不如众乐乐!如果你也对 ROC 的算法及实现感兴趣,不妨花些时间看完全文,相信你一定会有所收获!

图片

一、什么是 ROC 曲线

下图中的蓝色曲线就是 ROC 曲线,它常被用来评价二值分类器的优劣,即评估模型预测的准确度。

二值分类器,就是字面意思它会将数据分成两个类别(正/负样本)。例如:预测银行用户是否会违约、内容分为违规和不违规,以及广告过滤、图片分类等场景。篇幅关系这里不做多分类 ROC 的讲解。

图片

坐标系中纵轴为 TPR(真阳率/命中率/召回率)最大值为 1,横轴为 FPR(假阳率/误判率)最大值为 1,虚线为基准线(最低标准),蓝色的曲线就是 ROC 曲线。其中 ROC 曲线距离基准线越远,则说明该模型的预测效果越好。(TPR: True positive rate; FPR: False positive rate)

  • ROC 曲线接近左上角:模型预测准确率很高

  • ROC 曲线略高于基准线:模型预测准确率一般

  • ROC 低于基准线:模型未达到最低标准,无法使用

二、背景知识

考虑一个二分类模型, 负样本(Negative) 为 0,正样本(Positive) 为 1。即:

  • 标签 y 的取值为 0 或 1。

  • 模型预测的标签为 \hat{y},取值也是 0 或 1。

因此,将 y\hat{y} 两两组合就会得到 4 种可能性,分别称为:

图片

2.1 公式

ROC 曲线的横坐标为 FPR(False Positive Rate),纵坐标为 TPR(True Positive Rate)。FPR 统计了所有负样本中 预测错误(FP) 的比例,TPR 统计了所有正样本中 预测正确(TP) 的比例,其计算公式如下,其中 # 表示统计个数,例如 #N 表示负样本的个数,#P 表示正样本的个数

\text{FPR}=\frac{\#\text{FP}}{\#\text{N}} $$ $$\text{TPR}=\frac{\#\text{TP}}{\#\text{P}}

2.2 计算方法

下面举一个实际例子作为讲解,以下表 5 个样本为例,讲解如何计算 FPR 和 TPR

id真实标签  y预测标签 \hat{y}
111
210
300
411
501

正样本数 \#P=3,负样本数\#N=2

其中 y=0\hat{y}=1的样本有 1 个,即 \#FP=1,所以 FPR=1/2=0.5

其中 y=1\hat{y}=1 的样本有 2 个,即 \#TP=2,所以 FPR=2/3

FPR 和 TPR 的取值范围均是 0 到 1 之间。对于 FPR,我们希望其越小越好。而对于 TPR,我们希望其越大越好。

至此,我们已经介绍完如何计算 FPR 和 TPR 的值,下面将会讲解如何绘制 ROC 曲线。

三、绘制 ROC 曲线

讲到这里,可能有的同学会问:ROC 不是一条曲线吗?讲了这么多它到底应该怎么画呢?下面将分为两部分讲解如何绘制 ROC 曲线,直接打通你的“任督二脉”彻底拿下 ROC 曲线:

  • 第一部分:通过手绘的方式讲解原理

  • 第二部分:Python 代码实现,代码清爽易读

3.1 手绘 ROC 曲线

一般在二分类模型里(标签取值为 0 或 1),会默认设定一个阈值 (threshold)。当预测分数大于这个阈值时,输出 1,反之输出 0。我们可以通过调节这个阈值,改变模型预测的输出,进而画出 ROC 曲线。

以下面表格中的 20 个点为例,介绍如何人工画出 ROC 曲线,其中正样本和负样本都是 10 个,即 \#P = \#N = 10

id真实标签预测分数id真实标签预测分数
11.9111.4
21.8120.39
30.7131.38
41.6140.37
51.55150.36
61.54160.35
70.53171.34
80.52180.33
91.51191.30
100.505200.1

当设定阈值为 0.9 时,只有第一个点预测为 1,其余都为 0,故 \#FP=0\#TP=1,计算出 FPR=0/10=0TPR=1/10=0.1,画出点 (0,0.1)

当设定阈值为 0.8 时,只有前两个点预测为 1,其余都为 0,故 $\#FP=0、\#TP=2$,计算出 FPR=0/10=0 、TPR=2/10=0.2,画出点 (0,0.2)

当设定阈值为 0.7 时,只有前三个点预测为 1,其余都为 0,故 \#FP=1\#TP=2,计算出 FPR=1/10=0.1TPR=2/10=0.2,画出点 (0.1,0.2)。

以此类推,画出的 ROC 曲线如下:

图片

因此,在画 ROC 曲线前,需要将预测分数从大到小排序,然后将预测分数依次设定为阈值,分别计算 FPRTPR。而对于基准线,假设随机预测为正样本的概率为 x,即 \Pr(\hat{y}=1)=x 由于 FPR 计算的是负样本中,预测为正样本的概率,因此 FPR= x(同理,TPR= x)。所以,基准线为从点 (0, 0) 到 (1, 1) 的斜线

3.2 Python 代码

接下来,我们将结合代码讲解如何在 Python 中绘制 ROC 曲线。

下面的代码参考了《An Introduction to ROC Analysis》[2]中的算法 1(伪代码)。值得一提的是,知名机器学习库 scikit-learn 的 roc_curve 函数[3] 也参考了这个算法。

图片

下面我自己实现的 roc 函数可以理解为是简化版的 roc_curve,这里的代码逻辑更加简洁易懂,算法的时间复杂度 O ( n log ⁡ n ) O(n\log n) O(nlogn)。

完整的代码如下:

# import numpy as np
def roc(y_true, y_score, pos_label):
    """
    y_true:真实标签
    y_score:模型预测分数
    pos_label:正样本标签,如“1”
    """
    # 统计正样本和负样本的个数
    num_positive_examples = (y_true == pos_label).sum()
    num_negtive_examples = len(y_true) - num_positive_examples

    tp, fp = 0, 0
    tpr, fpr, thresholds = [], [], []
    score = max(y_score) + 1
    
    # 根据排序后的预测分数分别计算fpr和tpr
    for i in np.flip(np.argsort(y_score)):
        # 处理样本预测分数相同的情况
        if y_score[i] != score:
            fpr.append(fp / num_negtive_examples)
            tpr.append(tp / num_positive_examples)
            thresholds.append(score)
            score = y_score[i]
            
        if y_true[i] == pos_label:
            tp += 1
        else:
            fp += 1

    fpr.append(fp / num_negtive_examples)
    tpr.append(tp / num_positive_examples)
    thresholds.append(score)

    return fpr, tpr, thresholds

导入上面 3.1 表格中的数据,通过上面实现的 roc 方法,计算 ROC 曲线的坐标值。

import numpy as np

y_true = np.array(
    [1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0]
)
y_score = np.array([
    .9, .8, .7, .6, .55, .54, .53, .52, .51, .505,
    .4, .39, .38, .37, .36, .35, .34, .33, .3, .1
])

fpr, tpr, thresholds = roc(y_true, y_score, pos_label=1)

最后,通过 Matplotlib 将计算出的 ROC 曲线坐标绘制成图。

import matplotlib.pyplot as plt

plt.plot(fpr, tpr)
plt.axis("square")
plt.xlabel("False positive rate")
plt.ylabel("True positive rate")
plt.title("ROC curve")
plt.show()

图片

至此,ROC 的基础知识部分就全部讲完了,如果还想深入了解的同学可以继续往下看。

四、联邦学习中的 ROC 平均

图片

顾名思义,ROC 平均就是将多条 ROC 曲线“平均化”。那么,什么场景需要做 ROC 平均呢?例如:横向联邦学习中,由于样本都在用户本地,服务器可以采用 ROC 平均的方式,计算近似的全局 ROC 曲线

ROC 的平均有两种方法:垂直平均、阈值平均,下面将逐一进行讲解,并给出 Python 代码实现。

4.1 垂直平均

图片

垂直平均(Vertical averaging)的思想是,选取一些 FPR 的点,计算其平均的 TPR 值。下面是论文中的算法描述的伪代码,看不懂可直接略过看 Python 代码实现部分。

图片

下面是 Python 的代码实现:

# import numpy as np
def roc_vertical_avg(samples, FPR, TPR):
    """
    samples:选取FPR点的个数
    FPR:包含所有FPR的列表
    TPR:包含所有TPR的列表
    """
    nrocs = len(FPR)
    tpravg = []
    fpr = [i / samples for i in range(samples + 1)]

    for fpr_sample in fpr:
        tprsum = 0
        # 将所有计算的tpr累加
        for i in range(nrocs):
            tprsum += tpr_for_fpr(fpr_sample, FPR[i], TPR[i])
        # 计算平均的tpr
        tpravg.append(tprsum / nrocs)

    return fpr, tpravg

# 计算对应fpr的tpr
def tpr_for_fpr(fpr_sample, fpr, tpr):
    i = 0
    while i < len(fpr) - 1 and fpr[i + 1] <= fpr_sample:
        i += 1

    if fpr[i] == fpr_sample:
        return tpr[i]
    else:
        return interpolate(fpr[i], tpr[i], fpr[i + 1], tpr[i + 1], fpr_sample)

# 插值
def interpolate(fprp1, tprp1, fprp2, tprp2, x):
    slope = (tprp2 - tprp1) / (fprp2 - fprp1)
    return tprp1 + slope * (x - fprp1)

4.2 阈值平均

图片

阈值平均(Threshold averaging)的思想是,选取一些阈值的点,计算其平均的 FPR 和 TPR。

图片

下面是 Python 的代码实现:

# import numpy as np
def roc_threshold_avg(samples, FPR, TPR, THRESHOLDS):
    """
    samples:选取FPR点的个数
    FPR:包含所有FPR的列表
    TPR:包含所有TPR的列表
    THRESHOLDS:包含所有THRESHOLDS的列表
    """
    nrocs = len(FPR)
    T = []
    fpravg = []
    tpravg = []

    for thresholds in THRESHOLDS:
        for t in thresholds:
            T.append(t)
    T.sort(reverse=True)

    for tidx in range(0, len(T), int(len(T) / samples)):
        fprsum = 0
        tprsum = 0
        # 将所有计算的fpr和tpr累加
        for i in range(nrocs):
            fprp, tprp = roc_point_at_threshold(FPR[i], TPR[i], THRESHOLDS[i], T[tidx])
            fprsum += fprp
            tprsum += tprp
        # 计算平均的fpr和tpr
        fpravg.append(fprsum / nrocs)
        tpravg.append(tprsum / nrocs)

    return fpravg, tpravg

# 计算对应threshold的fpr和tpr
def roc_point_at_threshold(fpr, tpr, thresholds, thresh):
    i = 0
    while i < len(fpr) - 1 and thresholds[i] > thresh:
        i += 1
    return fpr[i], tpr[i]

五、最后

本文由浅入深地详细介绍了 ROC 曲线算法,包含算法原理、公式、计算、源码实现和讲解,希望能够帮助读者一口气搞懂 ROC。

虽然 ROC 是个不起眼的知识点,但能网上能彻底讲清楚 ROC 的文章并不多。所以我又花时间重温了一遍 Tom Fawcett 的经典论文《An introduction to ROC analysis》[4],并将论文的内容抽丝剥茧、配上通俗易懂的 Python 代码,最终写出了这篇文章。再次致敬🫡 Tom Fawcett,感谢他在机器学习领域的贡献!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1198998.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

「题解」反转链表 返回中间节点

文章目录 &#x1f349;题目1&#xff1a;反转链表&#x1f349;解析&#x1f34c;解法一&#xff1a;创建一个新链表&#x1f34c;解法二&#xff1a;直接操作原链表 &#x1f349;题目2&#xff1a;返回中间节点&#x1f34c;解法一&#xff1a;快慢指针&#x1f34c;解法二&…

2023年【汽车驾驶员(高级)】找解析及汽车驾驶员(高级)复审考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 汽车驾驶员&#xff08;高级&#xff09;找解析是安全生产模拟考试一点通总题库中生成的一套汽车驾驶员&#xff08;高级&#xff09;复审考试&#xff0c;安全生产模拟考试一点通上汽车驾驶员&#xff08;高级&#…

【Linux】WSL安装Kali及基本操作

&#x1f60f;★,:.☆(&#xffe3;▽&#xffe3;)/$:.★ &#x1f60f; 这篇文章主要介绍WSL安装Kali及基本操作。 学其所用&#xff0c;用其所学。——梁启超 欢迎来到我的博客&#xff0c;一起学习&#xff0c;共同进步。 喜欢的朋友可以关注一下&#xff0c;下次更新不迷路…

采用示波器显示扭矩传感器模拟信号

扭矩传感器输出的信号波形通常是模拟电压信号&#xff0c;可以通过示波器等仪器进行分析。扭矩传感器的输出信号波形通常有两种类型&#xff1a;正弦波和方波。 应变片传感器扭矩测量采用应变电测技术。在弹性轴上粘贴应变计组成测量电桥&#xff0c;当弹性轴受扭矩产生微小变…

IPV4过渡IPV6的关键技术NAT(Network AddressTranslation,网络地址转换)

文章目录 NAT的由来NAT基本工作机制NAT技术的分类推荐阅读 NAT的由来 随着物联网、工业互联网、5G的快速发展&#xff0c;网络应用对IP地址的需求呈现出爆炸式的增长。 然而&#xff0c;早在2011年&#xff0c;ICANN就发布公告称最后五组IP地址已分配完毕&#xff0c;已无IPv4…

华为ensp搭建小型园区网络规划

文章目录 前言一、拓扑图二、数据规划三、设备配置四.配置命令1.配置接入层交换机ACC11.1 设备命名&#xff0c;创建VLAN1.2 配置eth-trunk 11.3 配置用户端 2.配置核心层交换机CORE2.1设备命名2.2配置Eth-Trunk2.3 vlan配置ip2.4 上行接口配置 3.DHCP配置3.1 CORE: 4.配置路由…

【CASS精品教程】cass3d 11.0加载超大影像、三维模型、点云数据

CAD2016+CASS11.0(内置3d)下载与安装: 【CASS精品教程】CAD2016+CASS11.0安装教程(附CASS11.0安装包下载)https://geostorm.blog.csdn.net/article/details/132392530 一、cass11.0 3d支持的数据 cass11.0中的3d模块增加了多种数据的支持,主要有: 1. 三维模型 点击…

Python文件、文件夹操作汇总

目录 一、概览 二、文件操作 2.1 文件的打开、关闭 2.2 文件级操作 2.3 文件内容的操作 三、文件夹操作 四、常用技巧 五、常见使用场景 5.1 查找指定类型文件 5.2 查找指定名称的文件 5.3 查找指定名称的文件夹 5.4 指定路径查找包含指定内容的文件 一、概览 ​在…

Least Square Method 最小二乘法(图文详解,必懂)

最小二乘法是一种求解线性回归模型的优化方法&#xff0c;其目标是最小化数据点和拟合直线之间的残差平方和。这意味着最小二乘法关注的是找到一个直线&#xff0c;使得所有数据点与该直线的偏差的平方和最小。在数学公式中&#xff0c;如果y是实际值&#xff0c;y是函数估计值…

头歌答案Python——JSON基础

目录 ​编辑 Python——JSON基础 第1关&#xff1a;JSON篇&#xff1a;JSON基础知识 任务描述 第2关&#xff1a;JSON篇&#xff1a;使用json库 任务描述 Python——XPath基础 第1关&#xff1a;XPath 路径表达式 任务描述 第2关&#xff1a;XPath 轴定位 任务描述…

计算机毕业设计:疲劳驾驶检测识别系统 python深度学习 YOLOv5 (包含文档+源码+部署教程)

[毕业设计]2023-2024年最新最全计算机专业毕设选题推荐汇总 1、项目介绍 基于YOLOv5的疲劳驾驶检测系统使用深度学习技术检测常见驾驶图片、视频和实时视频中的疲劳行为&#xff0c;识别其闭眼、打哈欠等结果并记录和保存&#xff0c;以防止交通事故发生。本文详细介绍疲劳驾…

2023-11-12 LeetCode每日一题(Range 模块)

2023-03-29每日一题 一、题目编号 715. Range 模块二、题目链接 点击跳转到题目位置 三、题目描述 Range模块是跟踪数字范围的模块。设计一个数据结构来跟踪表示为 半开区间 的范围并查询它们。 半开区间 [left, right) 表示所有 left < x < right 的实数 x 。 实…

服务号如何升级订阅号

服务号和订阅号有什么区别&#xff1f;服务号转为订阅号有哪些作用&#xff1f;首先我们要知道服务号和订阅号有什么区别。服务号侧重于对用户进行服务&#xff0c;每月可推送4次&#xff0c;每次最多8篇文章&#xff0c;发送的消息直接显示在好友列表中。订阅号更侧重于信息传…

利用移动性的比例公平蜂窝调度测量和算法

&#xff08;一支笔一包烟&#xff0c;一节论文看一天 &#xff09;&#xff08;一张纸一瓶酒&#xff0c;一道公式推一宿&#xff09; 摘要1. 引言2. 相关工作3. 模型和问题公式4. 预测FPF调度 &#xff08; P F &#xff09; 2 S &#xff08;PF&#xff09;^2S &#xff08;…

在线制作仿真病历证明软件,易语言实现病例报告生成器,取画板快照+标签+编辑框

闲着无聊用易语言开发了一个病例生成器&#xff0c;当然我加了水印的&#xff0c;这个图片你就算截图你也用不了&#xff0c;模板是从百度图库搜的&#xff0c;很多&#xff0c;我就随便找了一个&#xff0c;然后实现逻辑就是加了一个画板&#xff0c;然后载入了素材图&#xf…

常见面试题-Redis底层的SDS、ZipList、ListPack

Redis 的 SDS 了解吗&#xff1f; 答&#xff1a; Redis 创建了 SDS&#xff08;simple dynamic string&#xff09; 的抽象类型作为 String 的默认实现 SDS 的结构如下&#xff1a; struct sdshdr {// 字节数组&#xff0c;用于保存字符串char buf[];// buf[]中已使用字节…

Xilinx FPGA平台DDR3设计详解(一):DDR SDRAM系统框架

DDR SDRAM&#xff08;双倍速率同步动态随机存储器&#xff09;是一种内存技术&#xff0c;它可以在时钟信号的上升沿和下降沿都传输数据&#xff0c;从而提高数据传输的速率。DDR SDRAM已经发展了多代&#xff0c;包括DDR、DDR2、DDR3、DDR4和DDR5&#xff0c;每一代都有不同的…

中国国内机场信息集成系统厂家现状情况

机场信息集成系统在本世纪初进入中国市场&#xff0c;早期的信息集成系统提供商以外企为主&#xff0c;后来国内企业迅速发展。但在2008年前&#xff0c;民航总局设立了机场信息系统的入门门槛&#xff0c;也就是需要民航空管工程及机场弱电系统建设资质要求&#xff0c;该要求…

Linux学习教程(第二章 Linux系统安装)3

第二章 Linux系统安装 十一、Linux远程管理协议&#xff08;RFB、RDP、Telnet和SSH&#xff09; 提到远程管理&#xff0c;通常指的是远程管理服务器&#xff0c;而非个人计算机。个人计算机可以随时拿来用&#xff0c;服务器通常放置在机房中&#xff0c;用户无法直接接触到…

【云备份|| 日志 day6】文件业务处理模块

云备份day6 业务处理 业务处理 云备份项目中 &#xff0c;业务处理模块是针对客户端的业务请求进行处理&#xff0c;并最终给与响应。而整个过程中包含以下要实现的功能&#xff1a; 借助网络通信模块httplib库搭建http服务器与客户端进行网络通信针对收到的请求进行对应的业…