【Python机器学习系列】一文教你绘制多分类任务的ROC曲线-宏平均ROC曲线(案例+源码)

news2024/12/24 2:21:05

这是我的第345篇原创文章。

一、引言

ROC曲线是用于评估二分类模型性能的工具,它展示了模型在不同阈值下的真阳性率与假阳性率之间的关系,但是标准的ROC并不能运用于多分类任务种,于是扩展出了宏平均ROC曲线。

宏平均ROC曲线是多分类问题中对ROC曲线的扩展,在多分类任务中,我们需要计算每一类别相对于其他所有类别的ROC曲线,然后对所有这些ROC曲线进行平均,从而得到宏平均ROC曲线,其主要步骤如下:

  • 逐类计算ROC曲线:对于每个类别,将其视为正类,其他所有类别视为负类,计算出相应的ROC曲线,也就是可以看作对每个类别进行独热编码

  • 计算AUC值:计算每个类别对应的AUC值

  • 平均化:对所有类别的AUC值进行平均,从而得到宏平均AUC值,同时,将各类别的ROC曲线取平均,得到宏平均ROC曲线

宏平均ROC曲线的优点在于它平等地考虑了每个类别的性能,适用于类别数量不平衡的情况,不过,由于它对所有类别进行了简单平均,如果某些类别比其他类别更加重要,宏平均ROC可能无法完全反映分类器的实际性能。

二、实现过程

2.1 准备数据

data = pd.read_csv(r'data.csv')
df = pd.DataFrame(data)
print(df.head())

该多分类数据存在3个类别:

图片

2.2 提取目标变量

target = 'Type'
features = df.columns.drop(target)
print(data["Type"].value_counts()) # 顺便查看一下样本是否平衡

图片

2.3 划分数据集

# df = shuffle(df)
X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=0)

2.4 归一化

mm1 = MinMaxScaler()   # 特征进行归一化
X_train_m = mm1.fit_transform(X_train)

2.5 模型的构建

model = RandomForestClassifier()

2.6 模型的训练

model.fit(X_train_m, y_train)

2.7 模型的推理

X_test_m = mm1.transform(X_test)
y_pred = model.predict(X_test_m)
y_scores = model.predict_proba(X_test_m)
print(y_pred)
acc = accuracy_score(y_test, y_pred) # 准确率acc
print(f"acc: \n{acc}")
cm = confusion_matrix(y_test, y_pred) # 混淆矩阵
print(f"cm: \n{cm}")
cr = classification_report(y_test, y_pred) # 分类报告
print(f"cr:  \n{cr}")

图片

2.8 模型的评价

acc = accuracy_score(y_test, y_pred) # 准确率acc
print(f"acc: \n{acc}")
cm = confusion_matrix(y_test, y_pred) # 混淆矩阵
print(f"cm: \n{cm}")
cr = classification_report(y_test, y_pred) # 分类报告
print(f"cr:  \n{cr}")

结果:

图片

把混淆矩阵进行可视化:

图片

2.9 绘制ROC曲线

计算宏平均ROC:

# 将y标签转换成one-hot形式
ytest_one_rf = label_binarize(y_test, classes=[1, 2, 3])

# 宏平均法计算AUC
rf_AUC = {}
rf_FPR = {}
rf_TPR = {}

for i in range(ytest_one_rf.shape[1]):
    rf_FPR[i], rf_TPR[i], thresholds = roc_curve(ytest_one_rf[:, i], y_scores[:, i])
    rf_AUC[i] = auc(rf_FPR[i], rf_TPR[i])
print(rf_AUC)

# 合并所有的FPR并排序去重
pass

# 计算宏平均TPR
rf_TPR_all = np.zeros_like(rf_FPR_final)
for i in range(ytest_one_rf.shape[1]):
    rf_TPR_all += np.interp(rf_FPR_final, rf_FPR[i], rf_TPR[i])
rf_TPR_final = rf_TPR_all / ytest_one_rf.shape[1]

# 计算最终的宏平均AUC
rf_AUC_final = auc(rf_FPR_final, rf_TPR_final)
AUC_final_rf = rf_AUC_final  # 最终AUC

print(f"Macro Average AUC with Random Forest: {AUC_final_rf}")

利用随机森林模型对测试集进行预测,并计算每个类别的预测概率。然后,将实际标签 ytest 转换为 one-hot 编码形式,以便进行多分类的 ROC 曲线分析,接着,通过逐类别计算 ROC 曲线和 AUC 值,并保存到字典中,最后,通过合并所有类别的 FPR 值并计算宏平均 TPR,从而得到最终的宏平均 AUC 值,用于评估随机森林模型在多分类任务中的整体性能。

绘制随机森林分类器在多分类任务中的 ROC 曲线,并计算并展示了每个类别的 AUC 值以及宏平均 ROC 曲线的 AUC:

plt.figure(figsize=(10, 5), dpi=300)
# 使用不同的颜色和线型
plt.plot(rf_FPR[0], rf_TPR[0], color='#1f77b4', linestyle='-', label='Class 1 ROC  AUC={:.4f}'.format(rf_AUC[0]), lw=2)
plt.plot(rf_FPR[1], rf_TPR[1], color='#ff7f0e', linestyle='-', label='Class 2 ROC  AUC={:.4f}'.format(rf_AUC[1]), lw=2)
plt.plot(rf_FPR[2], rf_TPR[2], color='#2ca02c', linestyle='-', label='Class 3 ROC  AUC={:.4f}'.format(rf_AUC[2]), lw=2)

# 宏平均ROC曲线
plt.plot(rf_FPR_final, rf_TPR_final, color='#000000', linestyle='-', label='Macro Average ROC  AUC={:.4f}'.format(rf_AUC_final), lw=3)
# 45度参考线
plt.plot([0, 1], [0, 1], color='gray', linestyle='--', lw=2, label='45 Degree Reference Line')
plt.xlabel('False Positive Rate (FPR)', fontsize=15)
plt.ylabel('True Positive Rate (TPR)', fontsize=15)
plt.title('Random Forest Classification ROC Curves and AUC', fontsize=18)
plt.grid(linestyle='--', alpha=0.7)
plt.legend(loc='lower right', framealpha=0.9, fontsize=12)
plt.savefig('RF_optimized.pdf', format='pdf', bbox_inches='tight')
plt.show()

结果:

图片

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2062041.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

工业控制常用“对象“数据类型汇总(数据结构篇)

合理巧妙的数据结构会大大简化项目的编程工作量,所以任何项目前期第一步应该是设计巧妙的数据结构、封装对象属性。这样会使我们的编程快捷和高效。这篇博客作为数据类型汇总,会不间断更新。 1、普通电机轴对象 2、普通电机轴对象(详细结构变量) TYPE "udtMotorAxis&q…

机器学习的入门笔记(第十五周)

本周观看了B站up主霹雳吧啦Wz的图像处理的课程, 课程链接:霹雳吧啦Wz的个人空间-霹雳吧啦Wz个人主页-哔哩哔哩视频 下面是本周的所看的课程总结。 利用GoogLeNet进行图像分类 GoogLeNet是由 Google 提出的卷积神经网络架构,于 2014 年在 …

没有用的小技巧之---接入网线,有内网没有外网,但是可以登录微信

打开控制面板,找到网络和Internet 选择Internet选项 点击连接,选择局域网设置 取消勾选代理服务器

JetBrains CLion 2024.2 (macOS, Linux, Windows) - C 和 C++ 跨平台 IDE

JetBrains CLion 2024.2 (macOS, Linux, Windows) - C 和 C 跨平台 IDE JetBrains 跨平台开发者工具 请访问原文链接:https://sysin.org/blog/jetbrains-clion/,查看最新版。原创作品,转载请保留出处。 作者主页:sysin.org Jet…

实战勤务指挥系统解决方案

4. 总体设计方案 方案围绕业务需求、接口需求和安全需求进行设计,包括语音集成、视频图像集成和第三方系统集成,以实现多系统联动和资源共享。 5. 系统特色 系统特色包括高度融合的指挥应用模式、简化的指挥流程、高效的管理机制,以及基于…

《Windows PE》2.1 初识PE文件

Windows PE文件(Portable Executable file)是一种可执行文件格式,用于Windows操作系统中的可执行程序、动态链接库(DLL)和驱动程序等。它是一种规范化的文件格式,定义了文件的结构和组织方式,以…

go设计模式———抽象工厂模式

抽象工厂模式概念 抽象工厂模式是一种设计模式,它允许创建一系列相关的对象,而无需指定具体的类。具体来说,抽象工厂定义了用于创建不同产品的接口,但实际的创建工作则由具体的工厂类完成。每个具体工厂负责创建一组相关的产品&am…

谷歌账号停用后申诉了,也收到了谷歌的邮件,如何判断谷歌申诉是否成功,成功了怎么办?被拒绝谷歌账号就废了吗?

似乎是谷歌分工机制的更新,最近谷歌账号“被停用”的情况貌似多了起来,许多朋友在谷歌账号提示活动异常,要输入手机号码恢复账号的时候,无论是否立刻恢复,很快好像就迎来了“您的账号已停用”的结果。或者有一些朋友许…

多元统计分析——基于R语言的单车使用情况可视化分析

注:基于R语言的单车使用情况可视化分析为实验记录,存在不足,自行改进。 一、提出问题(要解决或分析的问题) 1 、用户对共享单车的使用习惯,环境对共享单车运营带来的影响? 2 、共享单车的租赁…

【北京仁爱堂】痉挛性斜颈的健康指导

痉挛性斜颈是一种肌肉紧张异常症,仅限于颈部肌肉的肌张力障碍。当患者患有痉挛性斜颈,会表现为颈部肌肉间歇性或持续不规则的收缩,因此患者的头颈部会出现扭曲、歪斜、姿势异常等症状,多发于30-40岁左右中年人 一、 痉挛性斜颈的5…

mac和windows上安装nvm管理node版本

NVM 是 node version manager 的缩写,它是一个用来管理电脑上 node 版本的命令行工具,在日常前端开发中是一个跟 node 一样会经常用到的工具,可以很方便的让我们快速切换不同的node版本。 mac 上安装 nvm 1、下载安装 nvm 下载安装可以直…

【机器学习】逻辑回归原理(极大似然估计,逻辑函数Sigmod函数模型详解!!!)

目录 🍔 逻辑回归应用场景 🍔 极大似然估计 2.1 为什么要有极大似然估计? 2.2 极大似然估计步骤 2.3 极大似然估计的例子 🍔 Sigmod函数模型 3.1 逻辑斯特函数的由来 3.2 Sigmod函数绘图 3.3 进一步探究-加入线性回归 3…

【爬虫】 使用AI编写B站爬虫代码

记录一次,自己不写一行代码,所有的代码全由AI编写的过程。 本次使用的AI工具为:Claude 其他AI工具同理。 首先,观察哔哩哔哩网页的结构,定位到了包含视频信息的关键元素。右键检查或打开F12,找到最左侧的这…

2024前端面试题-js篇

1.js有哪些数据类型 基础数据类型:string,number,boolean,null,undefined,bigInt,symbol 引用数据类型:Object 2.js检测数据类型的方式 typeof:其中数组、对象、null都会被判断为object&…

基于WebSocket打造的一款SSH客户端

引用:Java打造一款SSH客户端,而且已开源_java ssh客户端-CSDN博客 由于原作者是放在Github上,不方便下载,所以下载下来,转存到码云上,地址:https://gitee.com/lfw1024/web-ssh 为了满足一些小白…

计算机毕业设计选题推荐-股票数据可视化分析与预测-Python爬虫

✨作者主页:IT毕设梦工厂✨ 个人简介:曾从事计算机专业培训教学,擅长Java、Python、微信小程序、Golang、安卓Android等项目实战。接项目定制开发、代码讲解、答辩教学、文档编写、降重等。 ☑文末获取源码☑ 精彩专栏推荐⬇⬇⬇ Java项目 Py…

Windows10拿到shell后远程登录

一、准备工作 kali机:192.168.19.130 win10:192.168.19.133 主机:192.168.1.73(自己操作可以在kali上,我这是因为反弹的shell在主机上) 二、开启远程登录 1.win10上关闭实时保护,并且运行了…

[C++] C++11详解 (一)

标题:[C] C11详解 (一) 水墨不写bug 目录 前言 一、列表初始化 二、STL的初始化列表(initializer_list —— Cplusplus.com) 三、声明方式(auto、decltype、nullptr) 1.auto ​编辑 2.decltype 正文开始&#x…

Spark环境搭建-Local

目录 Local下的角色分布: Anaconda On Linux 安装 (单台服务器) 1.下载安装 2.国内源 下载Spark安装包 1.下载 2.解压 3.环境变量 测试 监控 Local下的角色分布: 资源管理: Master:Local进程本身 Worker:L…

UE5.4 - 编辑器页面和概念术语

目录 一. 打开新项目 二. 主页面 1.菜单栏 2.工具栏 3.视口 4.内容侧滑菜单/内容浏览器 5.底部工具栏 6.大纲 7.细节面板 三. 虚幻引擎术语 四. 进一步的术语 五. 总结 一. 打开新项目 选择 虚幻引擎 -> 库 -> 启动 选择类型,选择示例的项目,可以把这些都选选…