机器学习入门实例-MNIST手写数据集-多分分类错误分析多标签分类多输出分类

news2025/1/16 1:56:25

多分类

随机梯度下降、随机森林和朴素贝叶斯都可以处理多分类问题,而logistic回归、支持向量机是严格的二分类分类器,但是可以用一些方法将多个二分类分类器组合在一起完成多分类任务。

1. OvR(one-versus-the-rest、one-versus-all)
比如识别手写数字时,直接训练10个二分分类器,遇到新图片时,分别跑这10个模型,然后选得分最高的那个作为识别结果。

2. OvO(one-versus-one)
比如识别手写数字时,为每一对数字(0-1,0-2…8-9)训练一个二分分类器。那么N个类需要训练 N x (N - 1) / 2 个分类器。遇到新图片时,分别跑45个模型,看哪个类赢得多。OvO的优点是只在当前这对数字的训练集上训练。相比于在大的训练集上训练较少的分类器,在小的数据集上训练较多分类器会更快些。一些算法很难在大数据集上扩展,比如SVM,就比较适合OvO。

当使用二分分类器处理多分类任务时,scikit learn会根据算法特性,自动选择使用OvR策略还是OvO策略。

from sklearn.svm import SVC
some_digit = X[5]
svm_clf = SVC()
svm_clf.fit(X_train, y_train)
some_digit_scores = svm_clf.decision_function([some_digit])
print(svm_clf.predict([some_digit]))
print(some_digit_scores)
print(np.argmax(some_digit_scores))
print(svm_clf.classes_)

[2]
[[ 2.7442516  -0.30125634  9.31399134  7.26965016  3.75753908  3.74771902
   0.71155138  2.76650678  8.2765048   6.21840456]]
2
[0 1 2 3 4 5 6 7 8 9]

由上面可以看到,decision_function的分数有10个,说明是OvR,分数最高的确实是index=2,也就是“2”类。
对于天然能多分类的分类器,比如SGD,它是不需要经过OvR或者OvO策略的,其decision_function()返回的值也会大很多(之前有过负几万到正几千)。

分类器训练时,classes_属性存储了分类列表(已根据值排序),但是index不一定跟类本身一致。这里只是刚好一致。

如果要强制使用OvO或OvR,那么只要创建一个 实例,然后把分类器传递进去即可。

import time
from sklearn.svm import SVC
some_digit = X[5]
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier

start = time.time()
ovo_clf = OneVsOneClassifier(SVC())
ovo_clf.fit(X_train, y_train)
print(ovo_clf.predict([some_digit]))
end = time.time()
print(end - start)
print(len(ovo_clf.estimators_))

start = time.time()
ovr_clf = OneVsRestClassifier(SVC())
ovr_clf.fit(X_train, y_train)
print(ovr_clf.predict([some_digit]))
end = time.time()
print(end - start)
print(len(ovr_clf.estimators_))

[2]
156.24309158325195
45
[2]
1002.182160139084
10

可以看到OvR确实耗时更多。

错误分析

假设已经选好了模型,现在想要提升它的效果,那么方法之一就是分析模型中存在什么错误,予以解决。

假设我们使用一个SGD模型,先做一个简单的缩放,可以看到准确率提升了一些:

from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score

sgd_clf = SGDClassifier()
sgd_clf.fit(X_train, y_train)
print(cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy"))
# [0.88735 0.83815 0.8769 ]

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
print(cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy"))
# [0.89565 0.89795 0.9003 ]

查看混淆矩阵

from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import cross_val_score, cross_val_predict
from sklearn.metrics import confusion_matrix

sgd_clf = SGDClassifier()
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
print(conf_mx)
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()

[[5597    0   21    7    9   50   37    6  195    1]
 [   0 6405   38   27    4   46    4    8  195   15]
 [  27   27 5262   98   70   28   68   38  330   10]
 [  24   17  118 5265    2  220   25   42  347   71]
 [  12   14   47   12 5225   11   39   27  281  174]
 [  29   16   31  181   54 4514   74   17  440   65]
 [  30   15   46    3   44   95 5561    5  118    1]
 [  20   10   53   34   52   13    3 5710  149  221]
 [  19   63   45   92    3  130   28   11 5412   48]
 [  24   18   30   69  121   38    1  180  288 5180]]

在这里插入图片描述
从上图可以看出,主对角线是浅色的,说明很多图像都是分对了。但是主对角线上一些区域是浅灰色的,应该考虑是图像较少或者分类效果不佳,下面通过一些方法区分这两种情况:
首先将混淆矩阵中的每个数据除以对应的类实例个数

# 按行求和,并且保持原有维度不变
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_max = conf_mx / row_sums
# 把主对角线设为0,所以主对角线变暗,只关注错误的情况
np.fill_diagonal(norm_conf_max, 0)
plt.matshow(norm_conf_max, cmap=plt.cm.gray)
plt.show()

在这里插入图片描述
在上图中,行表示真实的分类,列表示预测的分类。由8号列颜色浅可以看出,许多图都被错误地分类到8了。所以应该努力减少分类错误的8,比如增加一些看起来像8但实际不是8的实例,或者增加一些特性(比如计算图中的封闭环数目,因为8有2个,6有1个,5没有),或者对图像做一些预处理,让某些特征更明显(比如前面说的封闭环)。

另外3和5也有较明显的浅色块,那是因为SGD算法给了每个像素每类一个权重,有新图像进来的时候SGD会计算权重和。因为某些3和5的图只在少量像素处有区别的话,模型就很容易搞混。而且如果某个3的图像顶部有连笔,并且连笔向左偏的话,就更容易搞错了。因此,还可以在图像预处理阶段将图像居中、摆正。

多标签分类

multilabel classification:比如相册的人物标注问题。一张图片里有多个人像,所以要输出一组标签。

some_digit = X[5]
print(y_train[5])

from sklearn.neighbors import KNeighborsClassifier

# 形成一行[True,False..] 长度与y_train相等
y_train_large = (y_train >= 7)
# 同理,判断y_train中每个元素是否是奇数
y_train_odd = (y_train % 2 == 1)
# np.c_将两列合在一起,组成一个二维数组,其中每个元素的第一位表示是否>=7,第二位表示是否为奇数
y_multi_label = np.c_[y_train_large, y_train_odd]
# 支持多标签
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_multi_label)
print(knn_clf.predict([some_digit]))
y_train_knn_pred = cross_val_predict(knn_clf, X_train, y_multi_label, cv=3)
print(f1_score(y_multi_label, y_train_knn_pred, average="macro"))

2
[[False False]]
0.976410265560605

此时,每个标签的权重是一样大的,如果希望每个标签有不一样的权重:

print(f1_score(y_multi_label, y_train_knn_pred, average="weighted"))
0.9778357403921755

多输出分类

multioutput classification:输出的标签可能有多个,每个标签也可能有多个类。比如从图片中去掉噪点的任务:输入一张有噪点的图片,输出一张“干净”的图像,其中包含很多个像素点,每个点的取值范围是0~255。

from sklearn.neighbors import KNeighborsClassifier
knn_clf = KNeighborsClassifier()

# 添加随机噪声
noise = np.random.randint(0, 100, (len(X_train), 784))
X_train_mod = X_train + noise
noise = np.random.randint(0, 100, (len(X_test), 784))
X_test_mod = X_test + noise
y_train_mod = X_train
y_test_mod = X_test

# 绘制有噪声的图像,即输入
some_digit_image = X_test_mod[5].reshape(28, 28)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()

# 模型拟合、预测
knn_clf.fit(X_train_mod, y_train_mod)
clean_dit = knn_clf.predict([X_test_mod[5]])

# 绘制输出图像
some_digit_image = clean_dit.reshape(28, 28)
plt.imshow(some_digit_image, cmap="binary")
plt.axis("off")
plt.show()

加了noise的图像
“预测结果”

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/443842.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

webp格式怎么转换成jpg,3种常用工具方法

在日常办公中,我们经常需要在网上找到一些图片进行编辑。但有时候我们会遇到Webp格式的图片,而有些软件无法直接编辑它们。Webp是一种由谷歌开发的图像文件格式,它提供了有损压缩和无损压缩(可逆压缩)的功能,近年来,它…

Kafka基础篇学习笔记整理

Kafka基础篇学习笔记整理 生产者数据发送流程批量与定时发送缓冲区大小send发送消息消息可靠性发布确认机制重试机制 消息顺序性问题如何避免重试导致消息顺序错乱自定义拦截器自定义序列化器自定义分区器幂等与事务kafka实现幂等kafka实现事务事务的隔离级别使用演示 消费者重…

【状态估计】用于描述符 LTI 和 LPV 系统的分析、状态估计和故障检测的算法(Matlab代码实现)

💥 💥 💞 💞 欢迎来到本博客 ❤️ ❤️ 💥 💥 🏆 博主优势: 🌞 🌞 🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 …

免费文案生成器-免费文案改写神器

推荐一款高效免费自动写作软件,让你的写作效率飞升! 写作,对于众多的从业者或学生来说都是必不可少的工作内容。然而,许多人在写作时遇到了各种各样的困难,例如缺乏灵感、引用不足、缺乏逻辑性等等。为了解决这些问题…

Linux DHCP服务

DHCP 作用 DHCP动态主机配置协议作为服务端负责集中给客户端分配各种网络地址参数(主要包括IP地址、子网掩码、广播地址、默认网关地址、DNS服务器地址) 传输协议端口 服务端 UDP 67端口 客户端 UDP 68端口 工作原理 1) 客户端广播发送DISCOVER报文寻找服务端 2) 服务端广播发…

Unity - 带耗时 begin ... end 的耗时统计的Log - TSLog

CSharp Code // jave.lin 2023/04/21 带 timespan 的日志 (不帶 log hierarchy 结构要求,即: 不带 stack 要求)using System; using System.Collections.Generic; using System.IO; using UnityEditor; using UnityEngine;public…

Qt 学生信息数据库管理

1 添加样式表 我们采用了样式表 通过添加Qt resources文件 添加前缀 添加文件,将我们的图标进行添加 2 拖动部件 用到的部件 Label 标签Pushbutton 按钮table view 视图LineEdit 输入框 3 程序编写 1 配置sql环境 在 pro文件中 添加 连接数据库跟访问数据…

Qt模型视图结构

一.模型视图介绍 1.Model/View(模型/视图结构) 视图(View)是显示和编辑数据的界面组件, 模型(Model)是视图和原始数据之间的接口 2.视图组件有:QListView QTreeView QTableView,QColumnView,QHeaderView 模型组件有:QStringListM…

MyBatis详解(2)

8、自定义映射resultMap 8.1、resultMap处理字段和属性的映射关系 若字段名和实体类中的属性名不一致&#xff0c;则可以通过resultMap设置自定义映射 <!--resultMap&#xff1a;设置自定义映射属性&#xff1a;id&#xff1a;表示自定义映射的唯一标识type&#xff1a;查询…

PCIE_DMA实例二:xapp1052的EDK仿真

目录 一&#xff1a;前言 二&#xff1a;前期准备 三&#xff1a;操作步骤 四&#xff1a;仿真结果 五&#xff1a;总结 一&#xff1a;前言 对于有的同学&#xff0c;想要学习基于FPGA的PCIe DMA控制器设计&#xff0c;但是手上没有合适的Xilinx开发板&#xff0c;而且xap…

ETCD(五)写请求执行过程

写请求过程 客户端执行写请求指令 etcdctl put hello world —endpoints 192.168.1.1:12379执行流程&#xff1a; 首先客户端通过负载均衡选择一个etcd节点发起gRPC put方法调用&#xff1b;服务器收到请求后经过gRPC拦截器、Quota模块校验&#xff0c;进入KV Server模块&am…

「物联网时代的新选择」漫途科技推出装配式物联网服务,轻松实现项目落地

随着物联网技术的不断发展&#xff0c;越来越多的企业开始重视物联网系统的应用。然而&#xff0c;在物联网时代&#xff0c;鱼龙混杂&#xff0c;小品牌厂商层出不穷&#xff0c;质量参差不齐&#xff0c;这为系统集成商寻找靠谱的供应商伙伴带来了极大的挑战。 一、如何找靠谱…

Nacos配置中心的配置是怎么加载到spring容器的?

首先看到 org.springframework.boot.SpringApplication#applyInitializers 这个方法。 protected void applyInitializers(ConfigurableApplicationContext context) {for (ApplicationContextInitializer initializer : getInitializers()) {Class<?> requiredType G…

单链表C语言实现

链表就是许多节点在逻辑上串起来的数据存储方式 是通过结构体中的指针将后续的节点串联起来 typedef int SLTDataType;//数据类型 typedef struct SListNode//节点 {SLTDataType data;//存储的数据struct SListNode* next;//指向下一个节点地址的指针 }SLTNode;//结构体类型的…

设计模式(GOF)之我见(0)——UML

这里直接梳理画类图时的几个类关系。 类图的语法和功能 关系说明举例依赖&#xff08;Dependency) 偶然的&#xff0c;陌生的。 对类B进行修改会影响到A。 例如&#xff1a;问路时&#xff0c;路人甲给路人乙带路&#xff0c;路人甲的指引必然会影响到路人乙&#xff0c;但是…

在ROS2中使用奥比中光(ORBBEC)的AstraPro深度相机

0.效果演示 1.下载SDK 到官网下载OpenNI2_SDK 记得是下载这个OpenNI2_SDK,而不是下载那个Orbbec_SDK. 2.拷贝至自定义目录 拷贝到你的ubuntu的一个文件夹中&#xff0c;并解压得到 ros2_astra_camera 文件夹 然后新建一个ros2_ws文件夹&#xff0c;再在ros2_ws文件夹中新建…

矩阵链相乘的乘法次数(动态规划)

Description 设 A1, A2, …, An 为矩阵序列&#xff0c;Ai 是阶为 Pi − 1 * Pi 的矩阵 i  1, 2, …, n.试确定矩阵的乘法顺序&#xff0c;使得计算 A1A2…An 过程中元素相乘的总次数最少.Input 多组数据第一行一个整数 n(1≤n≤300) &#xff0c;表示一共有 n 个矩…

真题详解(计算机知识)-软件设计(五十四)

真题详解&#xff08;归并&#xff09;-软件设计&#xff08;五十三)https://blog.csdn.net/ke1ying/article/details/130254861 若无条件转移汇编指令采用直接寻址&#xff0c;则该指令功能是将指令中的地址码送入_____? PC&#xff08;程序计数器&#xff09; 程序计数器&…

10种黑客类型,你知道几种?

黑客一般有 10 种类型 1、白帽黑客 白帽黑客是指通过实施渗透测试&#xff0c;识别网络安全漏洞&#xff0c;为政府及组织工作并获得授权或认证的黑客。他们也确保保护免受恶意网络犯罪。他们在政府提供的规章制度下工作&#xff0c;这就是为什么他们被称为道德黑客或网络安全…

Kyligence Zen 产品体验 --- 初识庐山真面目

简介 Kyligence Zen 是一款数据分析工具&#xff0c;其市场定位是一站式云端指标平台。它基于 Kyligence 核心 OLAP&#xff08;On-Line Analytical Processing&#xff09; 能力打造&#xff0c;提供集业务模型、指标管理、指标加工、数据服务于一体的一站式服务。 其基本的…