机器学习之数据分离与混淆矩阵

news2024/11/16 23:35:11

文章目录

    • @[TOC](文章目录)
  • 前言
  • 数据分离
  • 混淆矩阵
  • 實戰
  • 總結

前言

例如:随着人工智能的不断发展,机器学习这门技术也越来越重要,很多人都开启了学习机器学习,本文就介绍了机器学习的基础内容之数据分离与混淆矩阵。


数据分离

前面我们讲过的实例基本流程都是这样的:
数据载入→数据可视化与预处理→模型创建→全数据用于模型训练→模型评估

按照这个流程的前提是我们要有新数据来作为测试数据供我们使用,那么如果我们没有新数据用于模型评估表现怎么办呢?这时就需要用到数据分离了。

简单来说数据分离就是对全数据进行数据分离,部分数据用来训练,部分数据用于新数据的结果预测。一般就是二八分還是三七分吧

通常来说分为3步:
在这里插入图片描述
小例
在这里插入图片描述
在这里插入图片描述

这里就是将3个数据作为训练数据,两个数据作为测试数据。

混淆矩阵

在我们前面的分类任务中,都是计算测试数据集预测准确率(accuracy)以评估模型表现,但如果只用accuracy会有很大的局限性。
在这里插入图片描述
举个例子:
在这里插入图片描述

如果建立模型如模型1所示,预测准确率就为90%;但如果模型如模型2所示,预测所有的样本结果都是1,也就是说按这个模型无论测试数据是什么,结果都为1,这样就会导致无论数据怎样,我的准确率都是90%,这时得到的准确率就是空准确率。

在这里插入图片描述
这就是只用accuracy的局限性,即我们不知道各类型数据的预测情况,就拿上例来说,我们不知道0中预测准确率是多少,也不知道1中预测准确率是多少,我们只知道整体准确率是90%。同时,我们也无法看出模型错误预测的类型,我们并不知道是0预测不对还是1预测不对,这样会给我们优化模型带来很多不便。

所以,为了更好地评估模型,我们可以用混淆矩阵。

混淆矩阵:又称为误差矩阵,用于衡量分类算法的准确程度
具体的混淆矩阵如下图所示:
在这里插入图片描述
在这里插入图片描述

另外,通过混淆矩阵可以计算更丰富的模型评估指标:

在这里插入图片描述
举个例子:

在这里插入图片描述

在这里插入图片描述
这时我们可以发现负样本中预测正确的比例为0,我们就可以清楚地知道这个指标的预测是有问题的,进而可以修改优化我们的模型。

在这里插入图片描述

在这些指标中没有说谁更重要、谁不重要,其重要度取决于应用场景

以上課件來自慕課

實戰

好壞質檢分類實戰task:

  1. 基於data_class_raw.csv數據,根據高斯分佈概率密度函數,尋找異常點並剔除
  2. 基於data_class_processed.csv數據,進行PCA處理,確定重要數據維度及成分
  3. 完成數據分離,數據分離參數:random_state=4,test_size=0.4
  4. 建立KNN模型完成分類,n_beighbors取10,計算分類準確率,可視化分類邊界
  5. 計算測試數據集對應的混淆矩陣,計算準確率、召回率、特異度、精確率、F1分數
  6. 嘗試不同的n_neighbors(1-20),計算其在訓練數據集、測試數據集上的準確率並作圖
#load the data
import pandas as pd
import numpy as np
data = pd.read_csv('data_class_raw.csv')
data.head()

在这里插入图片描述

#define X and y
X = data.drop(['y'],axis=1)
y = data.loc[:,'y']

#visualize the data
from matplotlib import pyplot as plt
fig1 = plt.figure(figsize=(5,5))
bad = plt.scatter(X.loc[:,'x1'][y==0],X.loc[:,'x2'][y==0])
good = plt.scatter(X.loc[:,'x1'][y==1],X.loc[:,'x2'][y==1])
plt.legend((good,bad),('good','bad'))
plt.title('raw data')
plt.xlabel('x1')
plt.ylabel('x2')
plt.show()

在这里插入图片描述

#anomy detection
from sklearn.covariance import EllipticEnvelope
ad_model = EllipticEnvelope(contamination=0.02)
ad_model.fit(X[y==0])
y_predict_bad = ad_model.predict(X[y==0])
print(y_predict_bad)

-1為異常點

在这里插入图片描述

fig2 = plt.figure(figsize=(5,5))
bad = plt.scatter(X.loc[:,'x1'][y==0],X.loc[:,'x2'][y==0])
good = plt.scatter(X.loc[:,'x1'][y==1],X.loc[:,'x2'][y==1])
plt.scatter(X.loc[:,'x1'][y==0][y_predict_bad==-1],X.loc[:,'x2'][y==0][y_predict_bad==-1],marker='x',s=150)
plt.legend((good,bad),('good','bad'))
plt.title('raw data')
plt.xlabel('x1')
plt.ylabel('x2')
plt.show()

在这里插入图片描述

任務二

data = pd.read_csv('data_class_processed.csv')
data.head()
#define X and y
X = data.drop(['y'],axis=1)
y = data.loc[:,'y']
#pca
from sklearn.preprocessing import StandardScaler  #標準化處理
from sklearn.decomposition import PCA
X_norm = StandardScaler().fit_transform(X) #標準化處理
pca = PCA(n_components=2) #二維
X_reduced = pca.fit_transform(X_norm)
var_ratio = pca.explained_variance_ratio_ #標準差比例
print(var_ratio)
fig3 = plt.figure(figsize=(5,5))
plt.bar([1,2],var_ratio)
plt.show()

可以看出這兩個占比都很大也就沒必要降維了
在这里插入图片描述

# train and test split: random_state=4,test_size=0.4
from sklearn.model_selection import train_test_split #數據分離
X_train,X_test,y_train,y_test = train_test_split(X,y,random_state=4,test_size=0.4)
print(X_train.shape,X_test.shape,X.shape) #看一下訓練數據、測試數據占原始數據的多少

數據分離

在这里插入图片描述

#knn model
from sklearn.neighbors import KNeighborsClassifier
knn_10 = KNeighborsClassifier(n_neighbors=10)
knn_10.fit(X_train,y_train)
y_train_predict = knn_10.predict(X_train)
y_test_predict = knn_10.predict(X_test)

#calculate the accuracy
from sklearn.metrics import accuracy_score
accuracy_train = accuracy_score(y_train,y_train_predict)
accuracy_test = accuracy_score(y_test,y_test_predict)
print("training accuracy:",accuracy_train)
print("testing accuracy:",accuracy_test)

可以看出訓練數據準確率還可以,測試數據的不太行
在这里插入图片描述

#visualize the knn result and boundary
xx, yy = np.meshgrid(np.arange(0,10,0.05),np.arange(0,10,0.05)) #生成0-10間隔0.05的數據
print(xx)
print(xx.shape)

生成佈滿整個圖的點,主要是爲了畫出分類邊界
在这里插入图片描述

x_range = np.c_[xx.ravel(),yy.ravel()]
print(x_range.shape)
y_range_predict = knn_10.predict(x_range)

在这里插入图片描述

fig4 = plt.figure()
knn_bad = plt.scatter(x_range[:,0][y_range_predict==0],x_range[:,1][y_range_predict==0])
knn_good = plt.scatter(x_range[:,0][y_range_predict==1],x_range[:,1][y_range_predict==1])

bad = plt.scatter(X.loc[:,'x1'][y==0],X.loc[:,'x2'][y==0])
good = plt.scatter(X.loc[:,'x1'][y==1],X.loc[:,'x2'][y==1])
plt.legend((good,bad,knn_good,knn_bad),('good','bad','knn_good','knn_bad'))
plt.title('predoction data')
plt.xlabel('x1')
plt.ylabel('x2')
plt.show()

在这里插入图片描述

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(y_test,y_test_predict) #混淆矩陣
print(cm)

直接用sklearn裏面包生成混淆矩陣

在这里插入图片描述

在这里插入图片描述

TN = cm[0,0]
FP = cm[0,1]
FN = cm[1,0]
TP = cm[1,1]
print(TN,FP,FN,TP)

在这里插入图片描述

算一下各種指標

accuaracy = (TP + TN)/(TP + TN + FP + FN)
print(accuaracy)

可以看出這個跟上面算的是一樣的
在这里插入图片描述

recall = TP/(TP + FN)
print(recall)

在这里插入图片描述

specificity = TN/(TN + FP)
print(specificity)

在这里插入图片描述

precision = TP/(TP + FP)
print(precision)

在这里插入图片描述

f1 = 2 * precision * recall / (precision + recall)
print(f1)

在这里插入图片描述

#try different k and calcualte the accuracy for each
n = [i for i in range(1,21)]
accuaracy_train = []
accuaracy_test = []
for i in n:
    knn = KNeighborsClassifier(n_neighbors=i)
    knn.fit(X_train,y_train)
    y_train_predict = knn.predict(X_train)
    y_test_predict = knn.predict(X_test)
    accuaracy_train_i = accuracy_score(y_train,y_train_predict)
    accuaracy_test_i = accuracy_score(y_test,y_test_predict)
    accuaracy_train.append(accuaracy_train_i)
    accuaracy_test.append(accuaracy_test_i)
print(accuaracy_train,accuaracy_test)

嘗試不同的n_neighbors(1-20),計算其在訓練數據集、測試數據集上的準確率並作圖

在这里插入图片描述

fig5 = plt.figure(figsize=(15,5))
plt.subplot(121)
plt.plot(n,accuaracy_train,marker='o')
plt.title('training accuracy vs n_neighbors')
plt.xlabel('n_neighbors')
plt.ylabel('accuaracy')
plt.subplot(122)
plt.plot(n,accuaracy_test,marker='o')
plt.title('testing accuracy vs n_neighbors')
plt.xlabel('n_neighbors')
plt.ylabel('accuaracy')

plt.show()

可視化看一下

在这里插入图片描述

可以看出n_neighbors值在20時效果最差,n_neighbors小時訓練數據效果好測試數據效果差,總體來説n_neighbors值和準確率並沒有很有規律,還是得多嘗試才知道n_neighbors取什麽好

總結

好壞質檢分類實戰summary:

  1. 通過進行異常監測,幫助找到了潛在的異常數據點
  2. 通過PCA分析,發現需要保留2維數據集
  3. 實現了訓練數據與測試數據的分離,並計算模型對於測試數據的預測準確率
  4. 計算得到混淆矩陣,實現模型更全面的評估
  5. 通過新的方法,可視化分類的決策邊界
  6. 通過調整核心參數n_neighbors值(K值),在計算對應的準確率,可以幫助我們更好的確定使用哪個模型
  7. 核心算法参考链接:
  • https://scikit-learn.org.cn/view/695.html
  • https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html#sklearn.neighbors.KNeighborsClassifier

這就是本次學習数据分离与混淆矩阵的筆記
附上本次實戰的數據集和源碼:
鏈接:https://github.com/fbozhang/python/tree/master/jupyter

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/70779.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多线程间的同步控制和通信

用多线程并发处理,目的是为了让程序更充分地利用CPU ,好能加快程序的处理速度和用户体验。如果每个线程各自处理的部分互不相干,那真是极好的,我们在程序主线程要做的同步控制最多也就是等待几个工作线程的执行完毕,如…

weston input 概述

weston input 概述 零、前言 本文描述了有关于 weston (基于 wayland 协议一种显示服务器的实现) 中有关于输入设备管理的部分;为了聚焦于此,本文不会对 weston 整体或 wayland 协议进行过多的阐述. 考虑到读者可能存在不同的需求,采用分层次的描述方式,主要面向以下两类人群…

Android Studio 导入opencv异常报错紧急救援

Download OpenCV from SourceForge.net 1、下载Android demo之后导入Android Studio 如下图所示 报错信息如下 A problem occurred configuring root project opencv_samples. > Could not resolve all artifacts for configuration :classpath.> Could not find org.j…

校园失物招领毕业设计,学生失物招领系统设计与实现,毕业设计怎么写论文源码开题报告需求分析怎么做

项目背景和意义 目的:本课题主要目标是设计并能够实现一个基于web网页的失物招领网站系统,整个网站项目使用了B/S架构,基于java的springboot框架下开发;管理员通过后台录入信息、管理信息,设置网站信息,管理…

8_4、Java基本语法之线程的通信

一、问题的引入 使用两个线程打印 1-100。线程1, 线程2交替打印? 二、解决问题涉及的方法 涉及到的三个方法: 1.wait():一旦执行此方法,那么调用此方法的线程就会进入阻塞状态,并释放同步监视器。 2.notify():一个线程…

如何使用htmlq提取html文件内容

htmlq能够对 HTML 数据进行 sed 或 grep 操作。我们可以使用 htmlq 搜索、切片和过滤 HTML 数据。让我们看看如何在 Linux 或 Unix 上安装和使用这个方便的工具并处理 HTML 数据。 什么是htmlq? htmlq类似于 jq,但用于 HTML。使用 CSS 选择器从 HTML 文…

[安装] HIVE搭建环境

一、生产环境hive集群架构 参考: hive2.3.7安装记录 hive基础入门与环境的搭建 基础篇七 Hive-2.3.9安装与配置 大数据之Hive 集群搭建 完整使用 数仓(十)hive的Metastore机制 二、前言快读 Hive安装分类 主要是metastore的服务搭建方…

[rsync] 基于rsync的同步

环境 Linux:CentOs7.5 rsync: 3.1.2 rsync安装 一般安装系统时会自带rsync,可通过如下命令查看已经安装的版本信息 rsync --version如果系统未安装,可通过如下方式安装 yum安装【建议】 使用root用户执行yum安装 yum install -y rsync安…

代码随想录刷题记录day37 0-1背包+分割等和子集

代码随想录刷题记录day37 0-1背包分割等和子集 0-1背包 问题:有n件物品和一个最多能背重量为w 的背包。第i件物品的重量是weight[i],得到的价值是value[i] 。每件物品只能用一次,求解将哪些物品装入背包里物品价值总和最大。 例题&#xf…

操作系统实验五 进程间通信-管道通信

实验目的 1.掌握利用管道机制实现进程间的通信的方法 2.了解利用消息缓冲队列机制实现进程间的通信的方法 3..了解利用共享存储区机制实现进程间的通信的方法 五个题目如下 1. 函数int pipe(int fd[2])创建一个管道,管道两端可分别用描述字fd[0]以及fd[1]来描述。需…

多元宇宙算法求解电力系统多目标优化问题(Matlab实现)【电气期刊论文复现与创新】

💥💥💥💞💞💞欢迎来到本博客❤️❤️❤️💥💥💥 🎉作者研究:🏅🏅🏅本科计算机专业,研究生电气学硕…

SpringCloud项目使用Nacos进行服务的注册

本篇介绍Spring cloud项目使用Nacos作为注册中心来进行服务注册及服务发现,并进行简单的测试来验证。 一、简介 nacos是一个集服务发现、服务配置、服务元数据以及流量管理于一体的管理中心,能帮助我们更好的发现、配置和管理微服务。 注意&#xff1…

家政公司网站毕业设计,家政服务系统设计与实现,毕业设计论文源码开题报告需求分析

项目背景和意义 目的:本课题主要目标是设计并能够实现一个基于web网页的家政服务预约系统,整个网站项目使用了B/S架构,基于java的springboot框架下开发;管理员通过后台录入信息、管理信息,设置网站信息,管理…

JS项目打包之ROLLUP.JS入门

一、目的 Rollup是一个用于JavaScript的模块打包器,它将小块代码编译成更大、更复杂的东西,例如库或应用程序。它为JavaScript ES6版本中包含的代码模块使用了新的标准化格式,而不是以前的特殊解决方案,如CommonJS和AMD。ES模块可…

Win10安装Nacos

Win10安装Nacos 文章目录Win10安装Nacos前言下载Nacos安装Nacos验证前言 最近在学微服务的东西,使用的是 Spring Cloud Alibaba 生态,Nacos就是其中关键的一环。 这是 Nacos 的官网地址:https://nacos.io/zh-cn/index.html 官网的文档对于…

Python中用PyTorch机器学习神经网络分类预测银行客户流失模型

分类问题属于机器学习问题的类别,其中给定一组特征,任务是预测离散值。分类问题的一些常见示例是,预测肿瘤是否为癌症,或者学生是否可能通过考试。 最近我们被客户要求撰写关于银行客户流失的研究报告,包括一些图形和…

@Scheduled定时任务搭配Redis防止多实例定时重复调用

有个Redis安装使用教程&#xff0c;可视化界面&#xff0c;有需要的话&#xff0c;可以打开这个链接去看一下 https://blog.csdn.net/weixin_45507672/article/details/105973279 创建个maven项目&#xff0c;在pom.xml文件加上以下依赖 <dependency><groupId>or…

4EVERLAND专用网关公告,免费体验

我们很高兴地宣布发布 4EVERLAND 专用 IPFS 网关&#xff01;与 4EVERLAND 公共网关一起&#xff0c;4EVERLAND 专用网关将为全世界的开发者和用户提供更快、更稳定地访问更能体现其品牌形象的 IPFS 内容。 专用网关的好处&#xff1a; 全球分布的边缘节点提供全球加速无速率…

[附源码]计算机毕业设计快转二手品牌包在线交易系统Springboot程序

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

Spring Boot 入门到精通(二)

文章目录五、SpringBoot整合MyBatis5.1 mapper 配置5.2 mapper映射配置&#xff1a;配置文件方式5.3 注解配置方式六. 自定义部分SpringMvc配置。6.1 SpringBoot整合日期转换器6.1.1 配置原理6.1.2 日期转换器整合6.2 SpringBoot整合拦截器七. Spring Boot 自定义日志配置&…