机器学习入门实例-MNIST手写数据集-二分分类效果评估

news2024/10/1 15:17:41

接上文的Binary Classifier,将数据分成“是2”和“非2”两类。

Performance Measures 分类效果评价方法

Accuracy(准确性)

y_train_2 = (y_train == 2)
...
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_2)

from sklearn.model_selection import cross_val_score
accuracy = cross_val_score(sgd_clf, X_train, y_train_2, cv=4, scoring="accuracy")
print(accuracy) 
# [0.97066667 0.9674     0.97653333 0.9748    ]

虽然这个准确性不错,但是如果数据有偏向性呢?假设构造一个分类器,对所有数据都分类成“非2”,再看准确性:

class Never2Classifier(BaseEstimator):
    def fit(self, X, y=None):
        return self
    def predict(self, X):
        return np.zeros((len(X), 1), dtype=bool)
...

from sklearn.model_selection import cross_val_score
never_2_clf = Never2Classifier()
print(cross_val_score(never_2_clf, X_train, y_train_2, cv=4, scoring="accuracy"))
# [0.90253333 0.90093333 0.90033333 0.899     ]

这个说明只有10%的图像是2(肯定的啊,只是0~9十个数字的手写库,一定会保证每个数字占1/10左右)。所以,对于有偏的数据集(skewed datasets,指某些类比其他类拥有更多的数据),准确性(accuracy)并不是一个很好的指标。

Confusion Matrix(混淆矩阵)

# 训练模型
y_train_2 = (y_train == 2)
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train_2)

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix

# 计算混淆矩阵
# cross_val_predict进行k-fold cross-validation,但是不返回分数
# 返回的是每个fold上的预测
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_2, cv=4)
print(confusion_matrix(y_train_2, y_train_pred))

# 输出
[[53230   812]
 [  847  5111]]

解释:
每一行表示一个分类,每一列表示一个预测分类。比如第一行表示“非2”的分类(the negative class),所以53230个被正确判定为“非2”(true negatives),812个被错误地判定为“是2”(false positive);第二行表示“是2”的分类(the positive class),847被错误地判定为“非2”(false negatives),5111个被正确判定为“是2”(true positive)。一个好的分类器应该只有true positive和true negative,所以应该只在主对角线(左上到右下)有非零数值,其他位置都应该是0。

预测为正预测为负
实际为正TPFN
实际为负FPTN

精确率 precision = TP / (TP + FP)
召回率 recall = TP / (TP + FN) 也称为 sensitivity、true positive rate(TPR)
记忆方法:precision表示当预测为正时,正确的概率,所以÷竖向;recall表示只能预测出多少正,所以÷横向。

from sklearn.metrics import precision_score, recall_score
print(precision_score(y_train_2, y_train_pred)) # 其实就是 5111/(5111+812) 约0.863
print(recall_score(y_train_2, y_train_pred)) # 其实就是 5111/(5111+847) 约0.858

F1 score则结合了precision和recall。只有二者都大时,f1 score才会大。
在这里插入图片描述
使用方法是类似的:

from sklearn.metrics import precision_score, recall_score, f1_score

print(f1_score(y_train_2, y_train_pred)) # 0.8603652891170777

Precision/Recall Trade-off(精确率和召回率的权衡)

对于precision和recall双高的情况,f1 score挺好的,但对于某些特殊情况,比如一个监控系统,完全可以低precision,高recall(可以有很多次假警报,即高FP,但基本所有小偷都抓到了,即低FN)。但是这样的话,f1是比较低的。

从随机梯度下降(SGD)分类器的设计思路考虑:对于每个实例,代入decision function计算分数。如果这个分数高于阈值,就会分到positive类,否则分入negative类。这样的话,如果阈值降低,会有更多的FP,因为FP和FN的总量是一定的,那么FN会降低,所以recall会升高,precision会降低;相反,如果阈值升高,FP变少,FN变多,则recall降低,precision升高,如下图所示:
在这里插入图片描述
虽然scikit learn不允许直接获取这个阈值,但是却可以拿到用decision function计算出的分数。那么,可以先拿到所有实例的分数(cross_val_predict),然后绘制recall-threshold,precision-threshold曲线,然后选择threshold。

from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
y_scores = cross_val_predict(sgd_clf, X_train, y_train_2, cv=4, method="decision_function")
precisions, recalls, thresholds = precision_recall_curve(y_train_2, y_scores)
# np.argmax指寻找第一个满足precisions>=0.9的索引
recall_90_precision = recalls[np.argmax(precisions >= 0.9)] # 打印出来是 0.8217522658610272
threshold_90_precision = thresholds[np.argmax(precisions >= 0.9)] # 868.8893539759117
# 绘图
plt.figure(figsize=(8, 4))
# precision最后一个值是1,recall最后一个值是0,可以不必显示。而且threshold也比它俩少一个数值
plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
plt.legend(loc="center right", fontsize=16)
plt.xlabel("Threshold", fontsize=16)
plt.grid(True)
# x轴范围
plt.axis([-50000, 50000, 0, 1])
# r:表示红色虚线
# 绘制线段
plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [0.9, 0.9], "r:")
plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], "r:")
# 绘制两个点
plt.plot([threshold_90_precision], [0.9], "ro")
plt.plot([threshold_90_precision], [recall_90_precision], "ro")
plt.show()

在这里插入图片描述

y_train_pred_90 = (y_scores >= threshold_90_precision)

from sklearn.metrics import precision_score, recall_score

print(precision_score(y_train_2, y_train_pred_90))
print(recall_score(y_train_2, y_train_pred_90))

0.9
0.8217522658610272

例如,以下代码演示了如何将阈值设置为 0.3:
在下面的代码中,predict_proba() 函数返回的是一个二维数组,第一列是预测为负例的概率,第二列是预测为正例的概率。我们通过 [:, 1] 来获取预测为正例的概率,并与阈值比较,将结果转换为 0 或 1。

from sklearn.linear_model import SGDClassifier

clf = SGDClassifier(loss='log')
clf.fit(X_train, y_train)

threshold = 0.3
y_pred = (clf.predict_proba(X_test)[:, 1] > threshold).astype(int)

ROC Curve

ROC曲线的横轴是FPR(False positive rate),纵轴是TPR(True positive rate,也就是recall,sensitivity)。其中,TNR(True negative rate,也叫specifitivity)。ROC也是sensitivity和1-specificity的曲线。
FPR = 1 - TNR
TNR = TN / (TN+FP)

根据模型的预测结果,将样本按照从高到低的概率值排序,然后在不同的阈值下计算 TPR 和 FPR,就可以得到 ROC 曲线。ROC 曲线越靠近左上角,说明模型的性能越好。

from sklearn.metrics import roc_curve
fpr, tpr, thresholds = roc_curve(y_train_2, y_scores)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, linewidth=2)
plt.plot([0, 1], [0, 1], 'b--') # dashed diagonal
plt.axis([0, 1, 0, 1])
plt.xlabel('FPR')
plt.ylabel('TPR(Recall)')
plt.grid(True)
fpr_90 = fpr[np.argmax(tpr >= recall_90_precision)]
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.show()

在这里插入图片描述
比较两个分类器的好坏可以用AUC。AUC就是area under the curve,指ROC曲线下面的面积。最好的分类器auc=1,随便一个分类器auc=0.5,scikit learn提供了计算函数。

from sklearn.metrics import roc_auc_score
print(roc_auc_score(y_train_2, y_scores))
# 0.9735367947441673

另一种曲线PR:
在这里插入图片描述
选PR还是ROC?
如果positive类比较少或者相比false negative,更关心false positive,那就应该使用PR,否则使用ROC。

比较随机森林和SGD:
在这里插入图片描述

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_curve

forest_clf = RandomForestClassifier(random_state=42)
# 返回值是一个array,每行代表一个实例,包含其在各类的概率
y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_2, cv=4,
                                    method="predict_proba")
# 取第二列
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_2, y_scores_forest)

fpr, tpr, thresholds = roc_curve(y_train_2, y_scores)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr,"b:", label="SGD")
plt.plot(fpr_forest, tpr_forest, label="Random Forest")
plt.plot([0, 1], [0, 1], 'b--') # dashed diagonal
plt.legend(loc="lower right")
plt.axis([0, 1, 0, 1])
plt.xlabel('FPR')
plt.ylabel('TPR(Recall)')
plt.grid(True)
fpr_90 = fpr[np.argmax(tpr >= recall_90_precision)]
plt.plot([fpr_90, fpr_90], [0., recall_90_precision], "r:")
plt.plot([0.0, fpr_90], [recall_90_precision, recall_90_precision], "r:")
plt.plot([fpr_90], [recall_90_precision], "ro")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/437610.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

连锁店销售管理系统有哪些功能?应该如何选购?

不管是直营还是加盟,想要实现门店的精细化管理,把不同门店的业绩做好,离不开连锁店销售管理系统的支持。 一款真正能够为连锁店经营带来帮助的连锁店销售管理系统应该具备哪些基本功能,以及选择连锁店销售管理系统时有哪些常见的问…

【科研工具】Zotero实现自动翻译

科研党基本都用过Zotero吧,方便文件管理和做笔记。我常使用的一款插件,可以实现paper英文内容的自动翻译为中文,非常简单、好用,现推荐给大家。 目录 一、下载zotero-pdf-translate插件 1.1 登录GitHub 1.2 找到.xpi文件并下载…

java 拼接字符串的方法

1.拼接字符串的方法,先要将字符串转化为数字类型,再根据需要拼接。这样可以避免直接拼接导致的错误。 2.将字符串转化为数字类型,这个就是一个循环。可以使用循环的方法,但是循环次数不宜太多,否则容易出错。 3.可以使…

微信小程序登陆(全流程-前后端)

环境要求 1.注册一个小程序 2.微信开发者工具 3.idea(springboot) 目录 项目准备 用户登陆 前端开发,传递code index.wxss index.js 后端编写,调用微信接口,获取openId 现在用户的所有信息都拿不到,只能用户自己填写 其…

MySQL的停止与启动、与客户端的连接(参见黑马程序员)

1、启动与停止 (1)Windowsr 输入 services.msc 在其中找MySQL并点鼠标右键,即可设定是停止还是启动 (2)以管理员身份打开cmd命令 (具体步骤:左下角点搜索输入cmd,在出现的选项里…

数字温湿度传感器DHT11

今天我们来说说一个新的模块DHT11——温湿度传感器 顾名思义,通过开发DHT11能够进行温湿度检测,是一个非常实用且有趣的模块,下面我们先对DHT11基本信息做一个了解,然后进行开发。 DHT11的优点: ►相对湿度和温度测…

算法篇——N个数之和大集合(js版)

1.两数之和 给定一个整数数组 nums 和一个整数目标值 target,请你在该数组中找出 和为目标值 target 的那 两个 整数,并返回它们的数组下标。 你可以假设每种输入只会对应一个答案。但是,数组中同一个元素在答案里不能重复出现。 你可以按…

如何招生?一文教你高职院校有效的招生技巧

生源,是每一所高校的生存之本和生命线。 近几年招生宣传工作作为高职院校招生工作中的重要环节之一,具有政策性强,涉及面广,工作量大等特点,直接关系到学校可持续发展问题。 随着新媒体时代的发展,高职院…

炫酷的3DCSS卡片样式

先效果图展示&#xff1a; 再上代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><style>*,*::after,*::before {margin: 0;padding: 0;box-sizing: bord…

【LeetCode】94.二叉树的中序遍历

1.问题 给定一个二叉树的根节点 root &#xff0c;返回 它的 中序 遍历 。 示例 1&#xff1a; 输入&#xff1a;root [1,null,2,3] 输出&#xff1a;[1,3,2] 示例 2&#xff1a; 输入&#xff1a;root [] 输出&#xff1a;[] 示例 3&#xff1a; 输入&#xff1a;root […

Vue3中使用 EventBus 实现兄弟组件传参

前言&#xff1a;EventBus vue3中没有了&#xff0c;EventBus&#xff0c;所以我们要自己写&#xff0c;但是非常简单。 步骤一&#xff1a;创建&#xff08;EventBus 容器&#xff09; 在src目录&#xff0c;创建个bus文件夹&#xff0c;存放 自己建立的 bus.js class Bus…

Springboot 中快速完成文件上传,整合多平台神器

哈喽&#xff0c;大家好~ 又是做好人好事的一天&#xff0c;有个小可爱私下问我有没有好用的springboot文件上传工具&#xff0c;这不巧了嘛&#xff0c;正好我私藏了一个好东西&#xff0c;顺便给小伙伴们也分享一下&#xff0c;demo地址放在文末了。 文件上传在平常不过的一…

1.黑马Springboot基础篇笔记

Springboot基础篇 1.快速上手Springboot 1.基础配置 1.parent 作用&#xff1a;指定jar包版本信息信息&#xff0c;避免依赖版本冲突 2.starter 作用:SpringBoot中常见项目名称&#xff0c;定义了当前项目使用的所有依赖坐标&#xff0c;以达到减少依赖配置的目的使用任意…

扬帆优配|逼近历史最高点!刚刚,A股这一板块沸腾!

今天早盘&#xff0c;A股整体小幅走强&#xff0c;上证指数创阶段性新高&#xff0c;并逼近年内最高点&#xff0c;科创50指数则大涨超2%领涨两市。 盘面上&#xff0c;新能源车、黄金、锂矿、建筑等板块涨幅居前&#xff0c;互联网、传媒娱乐、知识产权、博彩概念等板块跌幅居…

开源项目创始人的营销建议:让开源项目脱颖而出

来自开源创始人的营销建议 面对现实吧&#xff0c;如果你想让你的开源项目变成主业&#xff0c;就得投入一定的精力对它进行营销。 这并不意味着几篇空洞的文章加上夺人眼球的标题&#xff0c;而是要向用户清晰地传达产品的功能&#xff0c;并帮助他们轻松发现产品的优势。 本文…

什么是数智化招采?如何实现数智化招采(系统)?

数智化&#xff0c;是当今信息技术领域的一个热门话题。它的应用范围非常广泛&#xff0c;包括商业、医疗、科学、政府、城市、企业、社会等各个领域。随着现代信息技术的不断发展&#xff0c;数智化已经成为各行各业中的一个重要趋势。 什么是数智化招采 信息化是数据形成的…

ChatGPT实战100例 - (06) 10倍速可视化组织架构与人员协作流程

文章目录 ChatGPT实战100例 - (06) 10倍速可视化组织架构与人员协作流程一、需求与思路二、 组织架构二、 人员协作四、 总结 ChatGPT实战100例 - (06) 10倍速可视化组织架构与人员协作流程 一、需求与思路 管理研发团队的过程中&#xff0c;组织架构与人员协作流程的可视化是…

《商用密码应用与安全性评估》第一章密码基础知识1.7密码功能实现示例

保密性实现 访问控制&#xff1a;防止敌手访问敏感信息 信息隐藏&#xff1a;避免敌手发现敏感信息 信息加密&#xff1a;允许观测&#xff0c;但无法提炼信息 几种分组密码工作模式的区别&#xff1a; 名称全称优点缺点ECB电子密码本模式简单、快速、并行不抗重放CBC密码分组…

计算机:理解操作系统:内存篇(上)

内存篇 1. 什么是内存2. C/C内存模型2.1 代码段和数据段2.2 堆和栈 本节是操作系统系列教程的第三篇文章&#xff0c;属于操作系统第一章即基础篇&#xff0c;在真正开始操作系统相关章节前在这一部分回顾一些重要的主题&#xff0c;算是温故知新吧&#xff0c;以下是目录&…

瑞吉外卖项目——瑞吉外卖

软件开发整体介绍 软件开发流程 需求分析&#xff1a;产品原型、需求规格说明书 设计&#xff1a;产品文档、UI界面设计、概要设计、详细设计、数据库设计 编码&#xff1a;项目代码、单元测试 测试&#xff1a;测试用例、测试报告 上线运维&#xff1a;软件环境安装、配置…