Python绘制多分类ROC曲线

news2024/11/18 3:19:51

目录

1 数据集介绍

1.1 数据集简介

1.2 数据预处理

 2随机森林分类

2.1 数据加载

2.2 参数寻优

2.3 模型训练与评估

3 绘制十分类ROC曲线

第一步,计算每个分类的预测结果概率

第二步,画图数据准备

第三步,绘制十分类ROC曲线


1 数据集介绍

1.1 数据集简介

分类数据集为某公司手机上网满意度数据集,数据如图所示,共7020条样本,关于手机满意度分类的特征有网络覆盖与信号强度、手机上网速度、手机上网稳定性等75个特征。

1.2 数据预处理

常规数据处理流程,详细内容见上期随机森林处理流程:

xxx 链接

  • 缺失值处理

  • 异常值处理

  • 数据归一化

  • 分类特征编码

处理完后的数据保存为 手机上网满意度.csv文件(放置文末)

 2随机森林分类

2.1 数据加载

第一步,导入包

import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier# 随机森林回归
from sklearn.model_selection import train_test_split,GridSearchCV,cross_val_score
from sklearn.metrics import accuracy_score # 引入准确度评分函数
from sklearn.metrics import mean_squared_error
from sklearn import preprocessing
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rc("font", family='Microsoft YaHei')

第二步,加载数据

net_data = pd.read_csv('手机上网满意度.csv')
Y = net_data.iloc[:,3]   
X= net_data.iloc[:, 4:]
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=200)  # 随机数种子
print(net_data.shape)
net_data.describe()

2.2 参数寻优

第一步,对最重要的超参数n_estimators即决策树数量进行调试,通过不同数目的树情况下,在训练集和测试集上的均方根误差来判断

## 分析随着树数目的变化,在测试集和训练集上的预测效果
rfr1 = RandomForestClassifier(random_state=1)
n_estimators = np.arange(50,500,50) # 420,440,2
train_mse = []
test_mse = []
for n in n_estimators:
    rfr1.set_params(n_estimators = n) # 设置参数
    rfr1.fit(X_train,Y_train) # 训练模型
    rfr1_lab = rfr1.predict(X_train)
    rfr1_pre = rfr1.predict(X_test)
    train_mse.append(mean_squared_error(Y_train,rfr1_lab))
    test_mse.append(mean_squared_error(Y_test,rfr1_pre))

## 可视化不同数目的树情况下,在训练集和测试集上的均方根误差
plt.figure(figsize=(12,9))
plt.subplot(2,1,1)
plt.plot(n_estimators,train_mse,'r-o',label='trained MSE',color='darkgreen')
plt.xlabel('Number of trees')
plt.ylabel('MSE')
plt.grid()
plt.legend()

plt.subplot(2,1,2)
plt.plot(n_estimators,test_mse,'r-o',label='test MSE',color='darkgreen')
index = np.argmin(test_mse)
plt.annotate('MSE:'+str(round(test_mse[index],4)),
             xy=(n_estimators[index],test_mse[index]),
             xytext=(n_estimators[index]+2,test_mse[index]+0.000002),
             arrowprops=dict(facecolor='red',shrink=0.02))
plt.xlabel('Number of trees')
plt.ylabel('MSE')
plt.grid()
plt.legend()
plt.tight_layout()
plt.show()

以及最优参数和最高得分进行分析,如下所示

###调n_estimators参数
ScoreAll = []
for i in range(50,500,50):
    DT = RandomForestClassifier(n_estimators = i,random_state = 1) #,criterion = 'entropy'
    score = cross_val_score(DT,X_train,Y_train,cv=6).mean()
    ScoreAll.append([i,score])
ScoreAll = np.array(ScoreAll)

max_score = np.where(ScoreAll==np.max(ScoreAll[:,1]))[0][0] ##这句话看似很长的,其实就是找出最高得分对应的索引
print("最优参数以及最高得分:",ScoreAll[max_score])  
plt.figure(figsize=[20,5])
plt.plot(ScoreAll[:,0],ScoreAll[:,1],'r-o',label='最高得分',color='orange')
plt.xlabel('n_estimators参数')
plt.ylabel('分数')
plt.grid()
plt.legend()
plt.show()

很明显,决策树个数设置在400的时候回归森林预测模型的测试集均方根误差最小,得分最高,效果最显著。因此,我们通过网格搜索进行小范围搜索,构建随机森林预测模型时选取的决策树个数为400。

第二步,在确定决策树数量大概范围后,搜索决策树的最大深度的最高得分,如下所示

# 探索max_depth的最佳参数
ScoreAll = []  
for i in range(4,14,2):  
    DT = RandomForestClassifier(n_estimators = 400,random_state = 1,max_depth =i ) #,criterion = 'entropy'  
    score = cross_val_score(DT,X_train,Y_train,cv=6).mean()  
    ScoreAll.append([i,score])  
ScoreAll = np.array(ScoreAll)  
    
max_score = np.where(ScoreAll==np.max(ScoreAll[:,1]))[0][0] 
print("最优参数以及最高得分:",ScoreAll[max_score])    
plt.figure(figsize=[20,5])  
plt.plot(ScoreAll[:,0],ScoreAll[:,1]) 
plt.xlabel('max_depth最佳参数')
plt.ylabel('分数')
plt.grid()
plt.legend() 
plt.show()  

决策树的深度最终设置为10。

2.3 模型训练与评估

# 随机森林 分类模型  
model = RandomForestClassifier(n_estimators=400,max_depth=10,random_state=1) # min_samples_leaf=11
# 模型训练
model.fit(X_train, Y_train)
# 模型预测
y_pred = model.predict(X_test)
print('训练集模型分数:', model.score(X_train,Y_train))
print('测试集模型分数:', model.score(X_test,Y_test))
print("训练集准确率: %.3f" % accuracy_score(Y_train, model.predict(X_train)))
print("测试集准确率: %.3f" % accuracy_score(Y_test, y_pred))

绘制混淆矩阵:

# 混淆矩阵
from sklearn.metrics import confusion_matrix
import matplotlib.ticker as ticker

cm = confusion_matrix(Y_test, y_pred,labels=[1,2,3,4,5,6,7,8,9,10]) # ,
print('混淆矩阵:\n', cm)
labels=['1','2','3','4','5','6','7','8','9','10']
from sklearn.metrics import ConfusionMatrixDisplay
cm_display = ConfusionMatrixDisplay(cm,display_labels=labels).plot()

3 绘制十分类ROC曲线

第一步,计算每个分类的预测结果概率

from sklearn.metrics import roc_curve,auc
df = pd.DataFrame()
pre_score = model.predict_proba(X_test)
df['y_test'] = Y_test.to_list()

df['pre_score1'] = pre_score[:,0]
df['pre_score2'] = pre_score[:,1]
df['pre_score3'] = pre_score[:,2]
df['pre_score4'] = pre_score[:,3]
df['pre_score5'] = pre_score[:,4]
df['pre_score6'] = pre_score[:,5]
df['pre_score7'] = pre_score[:,6]
df['pre_score8'] = pre_score[:,7]
df['pre_score9'] = pre_score[:,8]
df['pre_score10'] = pre_score[:,9]

pre1 = df['pre_score1']
pre1 = np.array(pre1)

pre2 = df['pre_score2']
pre2 = np.array(pre2)

pre3 = df['pre_score3']
pre3 = np.array(pre3)

pre4 = df['pre_score4']
pre4 = np.array(pre4)

pre5 = df['pre_score5']
pre5 = np.array(pre5)

pre6 = df['pre_score6']
pre6 = np.array(pre6)

pre7 = df['pre_score7']
pre7 = np.array(pre7)

pre8 = df['pre_score8']
pre8 = np.array(pre8)

pre9 = df['pre_score9']
pre9 = np.array(pre9)

pre10 = df['pre_score10']
pre10 = np.array(pre10)

第二步,画图数据准备

y_list = df['y_test'].to_list()
pre_list=[pre1,pre2,pre3,pre4,pre5,pre6,pre7,pre8,pre9,pre10]

lable_names=['1','2','3','4','5','6','7','8','9','10']
colors1 = ["r","b","g",'gold','pink','y','c','m','orange','chocolate']
colors2 = "skyblue"# "mistyrose","skyblue","palegreen"
my_list = []
linestyles =["-", "--", ":","-", "--", ":","-", "--", ":","-"]

第三步,绘制十分类ROC曲线

plt.figure(figsize=(12,5),facecolor='w')
for i in range(10):
    roc_auc = 0
     #添加文本信息
    if i==0:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=1)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.3, 0.01, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==1:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=2)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.3, 0.11, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==2:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=3)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.3, 0.21, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==3:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=4)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.3, 0.31, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==4:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=5)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.3, 0.41, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==5:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=6)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.6, 0.01, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==6:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=7)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.6, 0.11, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==7:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=8)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.6, 0.21, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==8:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=9)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.6, 0.31, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    elif i==9:
        fpr, tpr, threshold = roc_curve(y_list,pre_list[i],pos_label=10)
        # 计算AUC的值
        roc_auc = auc(fpr, tpr)
        plt.text(0.6, 0.41, "class "+lable_names[i]+' :ROC curve (area = %0.2f)' % roc_auc)
    my_list.append(roc_auc)
    # 添加ROC曲线的轮廓
    plt.plot(fpr, tpr, color = colors1[i],linestyle = linestyles[i],linewidth = 3,
             label = "class:"+lable_names[i])  #  lw = 1,
    #绘制面积图
    plt.stackplot(fpr, tpr, colors=colors2, alpha = 0.5,edgecolor = colors1[i]) #  alpha = 0.5,
   
# 添加对角线
plt.plot([0, 1], [0, 1], color = 'black', linestyle = '--',linewidth = 3)
plt.xlabel('1-Specificity')
plt.ylabel('Sensitivity')
plt.grid()
plt.legend()
plt.title("手机上网稳定性ROC曲线和AUC数值")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1298772.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

TailwindCSS 如何设置 placeholder 的样式

前言 placeholder 在前端多用于 input、textarea 等任何输入或者文本区域的标签,它用户在用户输入内容之前显示一些提示。浏览器自带的 placeholder 样式可能不符合设计规范,此时就需要通过 css 进行样式美化。 当项目中使用 TailwindCSS 处理样式时&a…

手把手教你使用axure9画出图书出借的功能界面(原型模型)从0实现图书借阅界面

问题 设计图书出借的功能界面,并使用axure画出界面原型(pc端或移动端都可以)。就你的设计,你觉得有哪些方面需要跟用户沟通确认? 一、登录界面 1.先将图片背景改成灰色 2.插入文本框 3.插入文字,输入图书…

物联网第十四周总结

本周任务 消息转换器 PostgreSQL学习,JetLinks配置PostgreSQL 问题与总结 JetLinks配置PostgreSQL的时候,启动报错 2023-12-08 09:34:30.478 ERROR 19028 --- [actor-tcp-nio-1] o.h.e.r.e.r.r.R2dbcReactiveSqlExecutor : > Error: c…

Elasticsearch 8.9 refresh刷Es缓冲区的数据到Lucene,更新segemnt,使数据可见

一、相关API的handler1、接受HTTP请求的hander(RestRefreshAction)2、往数据节点发送刷新请求的action(TransportRefreshAction)3、数据节点接收主节点refresh传输的action(TransportShardRefreshAction) 二、在IndexShard执行refresh操作1、根据入参决定是使用lucene提供的阻塞…

Http请求(bug)——路径变量传参遇到特殊符号的问题 URL中的#,?,符号作用

前言 本篇博客分析路径变量传参遇到特殊符号的问题,阐述了URL中的#,?,&符号作用。 目录 前言引出路径变量传参遇到特殊符号的问题问题描述问题分析 URL中的 #,?,&符号的作用URL中# 的作…

【探索Linux】—— 强大的命令行工具 P.21(多线程 | 线程同步 | 条件变量 | 线程安全)

阅读导航 引言一、线程同步1. 竞态条件的概念2. 线程同步的概念 二、条件变量1. 条件变量函数⭕使用前提(1)初始化条件变量(2)等待条件满足(3)唤醒等待pthread_cond_broadcast()pthread_cond_signal() &…

Qexo博客后台管理部署

Qexo博客后台管理部署 个人主页 个人博客 参考文档 https://www.oplog.cn/qexo/本地部署 采用本地Docker部署管理本地Hexo 下载代码包 若无法下载使用科学工具下载到本地在上传到服务器 wget https://github.com/Qexo/Qexo/archive/refs/tags/3.0.1.zip# 解压 unzip Qexo…

SQL命令---修改字段的排列位置

介绍 使用sql语句表字段的排列顺序。 命令 alter table 表名 modify 字段名1 数据类型 first|after 字段名2;例子 将a表中的age字段改为表的第一个字段。 alter table a modify age int(12) first;下面是执行命令后的表结构: 将a表中的age字段放到name字段之…

【linux】查看CPU和内存信息

之前咱们一起学习了查看内存的和CPU的命令。 ​mpstat : 【linux】 mpstat 使用 uptime:【Linux】 uptime命令使用 CPU的使用率:【linux】查看CPU的使用率 nmon :【linux】nmon 工具使用 htop :【linux】htop 命令…

学习Linux(2)-学习Linux命令

Linux目录结构 Linux目录结构-菜鸟教程 /bin:bin 是 Binaries (二进制文件) 的缩写, 这个目录存放着最经常使用的命令。 /boot:这里存放的是启动 Linux 时使用的一些核心文件,包括一些连接文件以及镜像文件。 /dev :dev 是 De…

Cocos Creator:创建棋盘

Cocos Creator:创建棋盘 创建地图三部曲:1. 创建layout组件2. 创建预制体Prefab,做好精灵贴图:3. 创建脚本LayoutSprite.ts收尾工作: 创建地图三部曲: 1. 创建layout组件 使用layout进行布局,…

数据表记录的操作

一、数据添加 1、打开SSMS,附加数据库(数据库文件在自己的文件夹下面),并进行下面的设置: (1)设置“部门信息”表中的“编号”为主键(SSMS) 首先建立好所需的数据库库…

HNU计算机视觉作业三

前言 选修的是蔡mj老师的计算机视觉,上课还是不错的,但是OpenCV可能需要自己学才能完整把作业写出来。由于没有认真学,这门课最后混了80多分,所以下面作业解题过程均为自己写的,并不是标准答案,仅供参考 …

单臂路由与三层交换机

单臂路由 划分VLAN后同一VLAN的计算机属于同一个广播域,同一VLAN的计算机之间的通信是不成问题的。然而,处于不同VLAN的计算机即使是在同一交换机上,它们之间的通信也必须使用路由器。 图(a)是一种实现VLAN间路由的方…

ubuntu上搭建bazel编译环境,构建Android APP

背景是github上下载的工程,说明仅支持bazel编译,折腾了一天Android studio,失败。 不得不尝试单价bazel编译环境,并不复杂,过程记录如下 说明:ubuntu环境是20.04,pve虚拟机安装 1.安装jdk sudo…

docker-compose安装教程

1.确认docker-compose是否安装 docker-compose -v如上图所示表示未安装,需要安装。 如上图所示表示已经安装,不需要再安装,如果觉得版本低想升级,也可以继续安装。 2.离线安装 下载docker-compose安装包,上传到服务…

如何将html网页免费转为excel?

一、直接复制。 直接复制是最简单有效、快捷的解决方案,操作方法如下: 1、用鼠标像平常复制文本一样,将整个网页表格选中。 2、点击右键,点击“复制”。 3、打开excel软件,鼠标点击任意单元格。 4、点击右键&#…

leetcode7 移除列表中特定元素

给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外空间并 原地 修改输入数组。 元素的顺序可以改变。你不需要考虑数组中超出新长度后面…

【Cisco Packet Tracer】路由器 NAT实验

NAT的实现方式有三种,即静态转换Static Nat、动态转换Dynamic Nat和端口多路复用OverLoad。 静态转换是指内部本地地址一对一转换成内部全局地址,相当内部本地的每一台PC都绑定了一个全局地址。一般用于在内网中对外提供服务的服务器。 [3] 动态转换是指…

springboot+java医保付费绩效管理平台ssm

随着社会的飞速发展,特别是信息技术的迅猛发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势。当然,也不能排除医保付费及绩效管理行业。随着网络技术的不断成熟,医保付费及绩效管理的发展得到了促…