混淆矩阵细致理解

news2024/11/25 6:50:21

1、什么是混淆矩阵

混淆矩阵(Confusion Matrix)是深度学习和机器学习领域中的一个重要工具,用于评估分类模型的性能。它提供了一个清晰的视觉方式来展示模型的预测结果与真实标签之间的关系,尤其在分类任务中,帮助我们了解模型的强项和弱点。

混淆矩阵通常是一个二维矩阵,其中包含四个关键的指标:

预测值=正例预测值=反例
真实值=正例TPFN
真实值=反例FPTN

  1. 真正例(True Positives,TP): 模型正确地预测了正类别的样本数量。

  2. 真负例(True Negatives,TN): 模型正确地预测了负类别的样本数量。

  3. 假正例(False Positives,FP): 模型错误地将负类别的样本预测为正类别的数量。

  4. 假负例(False Negatives,FN): 模型错误地将正类别的样本预测为负类别的数量。

你可以选择看我之前写过的这一篇博客。其实很好理解,比如TP,它就是正确的预测了正确的样本,FP就是错误的预测为了正确的样本。同理,TN就是正确的预测了错误的样本,FN就是错误的预测了错误的样本。

2、从混淆矩阵得到分类指标

然后我们来看上图,这里就能够得到一些指标运算的公式:

准确率(Accuracy):它是分类正确的样本数与总样本数的比率。准确率通常用来衡量模型在整个数据集上的性能,但在不平衡类别的情况下可能不太适用。

准确率 = (TP + TN) / (TP + TN + FP + FN)

精确率(Precision):精确率是指被模型正确预测为正类别的样本数占所有预测为正类别的样本数的比率。它用于衡量模型在正类别预测中的准确性。

精确率 = TP / (TP + FP)

召回率(Recall):召回率是指被模型正确预测为正类别的样本数占所有真实正类别的样本数的比率。它用于衡量模型在识别正类别样本中的能力。

召回率 = TP / (TP + FN)

F1 分数:F1 分数是精确率和召回率的调和平均值,用于综合评估模型的性能。它对精确率和召回率都进行了考虑,特别适用于不平衡类别的情况。

F1 分数 = 2 * (精确率 * 召回率) / (精确率 + 召回率)

IoU(Intersection over Union):IoU 用于语义分割等任务,它是真实正类别区域与模型预测正类别区域的交集与并集之比。

IoU = TP / (TP + FP + FN) 

3、使用pytorch构建混淆矩阵

最初要写的目的也是为了回顾一下之前所学的,并且想要在训练过程中能写一个类方便调用。先说一下思路。

首先,这个是针对标签的,我需要一个num_classes,也就是分类数,以便我先创建一个分类数大小的矩阵。

然后在不计算梯度的情况下,我们需要筛选出合适的像素点,这里简单来说就是一行代码:

k = (t >= 0) & (t < n)
  • t是真实类别的张量,其中包含了每个像素的真实类别标签。
  • n是类别总数,表示模型可以进行分类的类别数量。

然后,通过t[k]与p[k]就可以确定正确的像素范围。将每个选定像素的真实类别标签乘以总类别数n,以获得一个在混淆矩阵中的行索引,然后再加上p[k],就是我们混淆矩阵的索引。

inds = n * t[k].to(torch.int64) + p[k]

torch.bincount是用于统计 inds 中每个索引出现的次数。minlength 参数指定了输出张量的长度,这里设置为 n**2,以确保输出张量的长度足够容纳混淆矩阵的所有元素。

以上就是我的思路,这里大家可以自己打印出来看看每个步骤是怎么实现的:

import torch

num_classes = n = 3
mat = torch.zeros((n, n), dtype=torch.int64)

true_labels = t = torch.tensor([0, 1, 2, 0, 1, 2])  # 真实标签
predicted_labels = p = torch.tensor([0, 1, 1, 0, 2, 1])  # 预测结果

with torch.no_grad():
    k = (t >= 0) & (t < n)
    inds = n * t[k].to(torch.int64) + p[k]
    print(inds)
    mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)
print(mat)

打印出来的值:

tensor([0, 4, 7, 0, 5, 7])
tensor([[2, 0, 0],
            [0, 1, 1],
            [0, 2, 0]])

这个混淆矩阵解释如下:

  • 第一行表示真实标签为类别0的样本,模型将其分为类别0的次数为2次。
  • 第二行表示真实标签为类别1的样本,模型将其中一个分为类别1,另一个分为类别2。
  • 第三行表示真实标签为类别2的样本,模型将其都分为类别1。

然后对照我们原本设定的数据也是完全符合的。

4、使用pytorch构建分类指标

将混淆矩阵 mat 转换为浮点数张量 h ,以便进行后续计算

h = mat.float()

全局预测准确率,混淆矩阵的对角线表示的是真实和预测相对应的个数

acc_global = torch.diag(h).sum()/h.sum()

计算每个类别的准确率

acc = torch.diag(h)/h.sum(1)

 计算每个类别预测与真实目标的iou,IoU = TP / (TP + FP + FN)

iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))

我们打印出来,大家可以对照着公式进行对比即可:

tensor(0.5000)   tensor([1.0000, 0.5000, 0.0000])   tensor([1.0000, 0.2500, 0.0000])

其他的指标以后在补充,我想的是在评估的时候使用,所以这里的指标最好还是写一个类定义。

我也将其放进了pyzjr当中,欢迎大家pip安装使用。

class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, t, p):
        n = self.num_classes
        if self.mat is None:
            # 创建混淆矩阵
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=t.device)
        with torch.no_grad():
            # 寻找GT中为目标的像素索引
            k = (t >= 0) & (t < n)
            # 统计像素真实类别t[k]被预测成类别p[k]的个数
            inds = n * t[k].to(torch.int64) + p[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        """
        计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)
        计算每个类别的准确率
        计算每个类别预测与真实目标的iou,IoU = TP / (TP + FP + FN)
        """
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
            acc_global.item() * 100,
            ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
            ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
            iu.mean().item() * 100)


if __name__=="__main__":
    num_classes = 3
    confusion_matrixs = ConfusionMatrix(num_classes)

    # 模拟一些真实标签和预测结果
    true_labels = torch.tensor([0, 1, 2, 0, 1, 2])  # 真实标签
    predicted_labels = torch.tensor([0, 1, 1, 0, 2, 1])  # 预测结果

    # 更新混淆矩阵
    confusion_matrixs.update(true_labels, predicted_labels)

    # 打印混淆矩阵及评估指标报告
    print("Confusion Matrix:")
    print(confusion_matrixs.mat)
    print("\nEvaluation Report:")
    print(confusion_matrixs)

打印的信息:

Confusion Matrix:
tensor([[2, 0, 0],
        [0, 1, 1],
        [0, 2, 0]])

Evaluation Report:
global correct: 50.0
average row correct: ['100.0', '50.0', '0.0']
IoU: ['100.0', '25.0', '0.0']
mean IoU: 41.7

5、与sklearn的官方实现进行比较 

同样的数据,我们放入到sklearn中看看是否正确呢?

from sklearn.metrics import confusion_matrix

confusion_matrix_result = confusion_matrix(true_labels, predicted_labels)
print(confusion_matrix_result)

[[2 0 0]
 [0 1 1]
 [0 2 0]]

不看数据类型,元素都是一一对应的,说明我们的实现是正确的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1021555.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浅谈应急照明系统在民用建筑的设计应用与产品选型

贾丽丽 安科瑞电气股份有限公司 上海嘉定 201801 【摘要】应急照明分为备用照明、安全照明及疏散照明。文章介绍了应急照明系统的设计、灯具选择、灯具布置、配电等要求。并结合实例进行疏散照明的计算&#xff0c;以指导应急照明系统的设计与应用。 【关键词】照度&#xf…

大数据学习1.4-xShell配置Hadoop

1.创建hadoop目录 mkdir /usr/local/hadoop 2.切换到hadoop中 cd /usr/local/hadoop/ 3.将hadoop直接拖到xShell中 4.解压hadoop tar -zxvf hadoop-2.7.1.tar.gz 5.配置环境变量 vi /etc/profile export PATH$PATH:/usr/local/hadoop/hadoop-2.7.1/bin 6.加载配置文件(不能…

【刷题】蓝桥杯

蓝桥杯2023年第十四届省赛真题-平方差 - C语言网 (dotcpp.com) 初步想法&#xff0c;x y2 − z2&#xff08;yz)(y-z) 即xa*b&#xff0c;ayz&#xff0c;by-z 2yab 即ab是2的倍数就好了。 即x存在两个因数之和为偶数就能满足条件。 但时间是&#xff08;r-l&#xff09;*x&am…

Mybatis学习笔记8 查询返回专题

1.返回实体类 2.返回List<实体类> 3.返回Map 4.返回List<Map> 5.返回Map<String,Map> 6.resultMap结果集映射 7.返回总记录条数 新建模块 依赖 目录结构 1.返回实体类 如果返回多条,用单个实体接收会出异常 2.返回List<实体类> 即使返回一条记…

【软考】系统集成项目管理工程师(四)项目管理一般知识

一、 项目 1、 项目的定义 为大到特定的目的、使用一定的资源、在确定的期间内、为特定发起人而提供独特的产品、服务或成果而进行的一次性努力。 2、项目目标 分类描述成果性目标项目目标&#xff1a;满足客观要求的产品、系统、服务或者成果&#xff1b;例&#xff1a;①…

删除链表中所有含有val的节点

给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链表中所有满足 Node.val val 的节点&#xff0c;并返回 新的头节点 。 示例 1&#xff1a; 输入&#xff1a;head [1,2,6,3,4,5,6], val 6 输出&#xff1a;[1,2,3,4,5] 思路1&#xff1a;遍历查找&#xf…

大数据学习1.1-Centos8网络配置

1.查看虚拟网卡 2.配置网络信息 打勾处取消 记住箭头的数字 3.修改 网络连接 4.进入虚拟网络 5.进入属性 6.修改IPv4 5.将iIP和DNS进行修改 6.配置网络信息-进入修改网络配置文件 # 进入root用户 su root # 进入网络配置文件 cd /etc/sysconfig/network-scripts/ # 修改网络配…

C. Beautiful Sets of Points(找规律杂题)

解析&#xff1b; 由于坐标必须为整数&#xff0c;并且距离不能为整数&#xff0c;则同行同列不能存在多个“好点”。 则每行每列只能放一个点&#xff0c;所以最多的点数量即为 min&#xff08;n&#xff0c;m&#xff09;1 #include<bits/stdc.h> using namespace std…

以数据为中心的安全市场快速增长

根据Adroit Market Research的数据&#xff0c;2021年全球以数据为中心的安全市场规模估计为27.6亿美元&#xff0c;预计到2030年将增长至393.48亿美元&#xff0c;2021年至2030年的复合年增长率为30.9%。 研究人员表示&#xff0c;以数据为中心的安全强调保护数据本身&#x…

arcgis栅格按某列属性导出

栅格属性如下 可以看出栅格对应很多属性值&#xff0c;我们要按其中一个属性值作为栅格值 操作如下 只需将栅格值赋值为所需属性数据

后端配置(宝塔):SSH终端设置

一、打开SSH开关 在“安全”中找到SSH管理&#xff0c;按图打开对应按钮 二、复制秘钥 点击“查看密钥”&#xff0c;对密钥进行复制 三、添加服务器 在终端页面添加新的服务器 四、进行密钥连接 输入IP地址&#xff0c;进行root登录&#xff0c;私钥即在“安全”界面复制的…

微服务保护-隔离和降级

隔离和降级 限流是一种预防措施&#xff0c;虽然限流可以尽量避免因高并发而引起的服务故障&#xff0c;但服务还会因为其它原因而故障。 而要将这些故障控制在一定范围&#xff0c;避免雪崩&#xff0c;就要靠线程隔离&#xff08;舱壁模式&#xff09;和熔断降级手段了。 线…

MFC串口通信控件MSCOMM32.OCX的安装注册

MSCOMM32.OCX是一个与Microsoft Corporation开发的MSComm控件相关联的文件。MSComm控件是软件应用程序用来与调制解调器、条形码读取器和其他串行设备等设备建立串行通信的通信控件。 下载地址1 https://download.csdn.net/download/m0_60352504/88345092 下载地址2 https://ww…

【无标题】mysql 普通用户连接报错: MySql server has gone away

1、mysql 普通用户连接报错&#xff1a; MySql server has gone away 2、进入mysql错误日志位置查看输出日志显示错误为&#xff1a; [Warning] [MY-013130] [Server] Aborted connection 47 to db: unconnected user: tjcx host: 10.195.11.4 (init_connect command failed; …

437. 路径总和 III

给定一个二叉树的根节点 root &#xff0c;和一个整数 targetSum &#xff0c;求该二叉树里节点值之和等于 targetSum 的 路径 的数目。 路径 不需要从根节点开始&#xff0c;也不需要在叶子节点结束&#xff0c;但是路径方向必须是向下的&#xff08;只能从父节点到子节点&am…

vue柱状图+折线图组合

<template><div id"main" style"width: 100%;height: 500px; padding-top: .6rem"></div> </template>data() {return {weekData: ["1周","2周","3周","4周","5周","6周&…

九芯电子丨语音智能风扇,助您畅享智慧生活

回忆童年时期的传统机械风扇&#xff0c;那“古老”的扇叶连摆动看起来是那么吃力。在一个闷热的夏夜&#xff0c;风扇的噪音往往令人印象深刻。但在今天&#xff0c;静音家用风扇已取代了传统的机械风扇。与此同时&#xff0c;随着智能化的发展&#xff0c;智能家居已逐渐成为…

springcloud3 指定nacos的服务名称和配置文件的group,名称空间

一 指定读取微服务的配置文件 1.1 工程结构 1.2 nacos的配置 1.配置文件 2.内容 1.3 微服务的配置文件 1.bootstrap.yml内容 2.application.yml文件内容 1.4 验证访问 控制台&#xff1a; 1.5 nacos服务空间名称和groupid配置 1.配置文件配置 2.nacos的查看

C语言生成随机数、C++11按分布生成随机数学习

C语言生成随机数 如果只要产生随机数而不需要设定范围的话&#xff0c;只要用rand()就可以&#xff1b;rand()会返回一随机数值, 范围在0至RAND_MAX 间&#xff1b;RAND_MAX定义在stdlib.h, 其值为2147483647&#xff1b; 如果想要获取在一定范围内的数的话&#xff0c;直接做…

新版发布 | Cloudpods v3.10.5 和 v3.9.13 正式发布

Cloudpods v3.10.5 本期发布中&#xff0c;ocboot 部署脚本有较多变化&#xff0c;首先支持以非 root 用户执行安装流程&#xff0c;其次响应社区的呼吁&#xff0c;增加了–stack 参数&#xff0c;允许 Allinone 一键安装仅包含私有云&#xff08;参数为 edge&#xff09;或云…