多分类中混淆矩阵的TP,TN,FN,FP计算

news2025/1/17 21:40:09

关于混淆矩阵,各位可以在这里了解:混淆矩阵细致理解_夏天是冰红茶的博客-CSDN博客

上一篇中我们了解了混淆矩阵,并且进行了类定义,那么在这一节中我们将要对其进行扩展,在多分类中,如何去计算TP,TN,FN,FP。

原理推导

这里以三分类为例,这里来看看TP,TN,FN,FP是怎么分布的。

类别1的标签:

类别2的标签:

类别3的标签:

这样我们就能知道了混淆矩阵的对角线就是TP

TP = torch.diag(h)

 假正例(FP)是模型错误地将负类别样本分类为正类别的数量

FP = torch.sum(h, dim=1) - TP

假负例(FN)是模型错误地将正类别样本分类为负类别的数量

FN = torch.sum(h, dim=0) - TP

最后用总数减去除了 TP 的其他三个元素之和得到 TN

TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)

逻辑验证

这里借用上一篇的例子,假如我们这个混淆矩阵是这样的:

tensor([[2, 0, 0],
            [0, 1, 1],
            [0, 2, 0]])

为了方便讲解,这里我们对其进行一个简单的编号,即0—8:

012
345
678

torch.sum(h, dim=1) 可得 tensor([2., 2., 2.]) , torch.sum(h, dim=0) 可得 tensor([2., 3., 1.]) 。

  •  TP:   tensor([2., 1., 0.]) 
  •  FP:   tensor([0., 1., 2.]) 
  •  TN:   tensor([4., 2., 3.]) 
  •  FN:   tensor([0., 2., 1.])

我们先来看看TP的构成,对应着矩阵的对角线2,1,0;FP在类别1中占3,6号位,在类别2中占1,7号位,在类别3中占2,5号位,加起来即为0,1,2;TN在类别1中占4,5,7,8号位,在类别2中占边角位,在类别3中占0,1,3,4号位,加起来即为4,2,3;FN在类别1中占1,2号位,在类别2中占3,5号位,在类别3中占6,7号位,加起来即为0,2,1。

补充类定义

import torch
import numpy as np

class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, t, p):
        n = self.num_classes
        if self.mat is None:
            # 创建混淆矩阵
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=t.device)
        with torch.no_grad():
            # 寻找GT中为目标的像素索引
            k = (t >= 0) & (t < n)
            # 统计像素真实类别t[k]被预测成类别p[k]的个数
            inds = n * t[k].to(torch.int64) + p[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    @property
    def ravel(self):
        """
        计算混淆矩阵的TN, FP, FN, TP
        """
        h = self.mat.float()
        n = self.num_classes
        if n == 2:
            TP, FN, FP, TN = h.flatten()
            return TP, FN, FP, TN
        if n > 2:
            TP = h.diag()
            FN = h.sum(dim=1) - TP
            FP = h.sum(dim=0) - TP
            TN = torch.sum(h) - (torch.sum(h, dim=0) + torch.sum(h, dim=1) - TP)

            return TP, FN, FP, TN

    def compute(self):
        """
        主要在eval的时候使用,你可以调用ravel获得TN, FP, FN, TP, 进行其他指标的计算
        计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)
        计算每个类别的准确率
        计算每个类别预测与真实目标的iou,IoU = TP / (TP + FP + FN)
        """
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
            acc_global.item() * 100,
            ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
            ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
            iu.mean().item() * 100)

我在代码中添加了属性修饰器,以便我们可以直接的进行调用,并且也考虑到了二分类与多分类不同的情况。

性能指标

关于这些指标在网上有很多介绍,这里就不细讲了

class ModelIndex():
    def __init__(self,TP, FN, FP, TN, e=1e-5):
        self.TN = TN
        self.FP = FP
        self.FN = FN
        self.TP = TP
        self.e = e

    def Precision(self):
        """精确度衡量了正类别预测的准确性"""
        return self.TP / (self.TP + self.FP + self.e)

    def Recall(self):
        """召回率衡量了模型对正类别样本的识别能力"""
        return self.TP / (self.TP + self.FN + self.e)

    def IOU(self):
        """表示模型预测的区域与真实区域之间的重叠程度"""
        return self.TP / (self.TP + self.FP + self.FN + self.e)

    def F1Score(self):
        """F1分数是精确度和召回率的调和平均数"""
        p = self.Precision()
        r = self.Recall()
        return 2*p*r / (p + r + self.e)

    def Specificity(self):
        """特异性是指模型在负类别样本中的识别能力"""
        return self.TN / (self.TN + self.FP + self.e)

    def Accuracy(self):
        """准确度是模型正确分类的样本数量与总样本数量之比"""
        return self.TP + self.TN / (self.TP + self.TN + self.FP + self.FN + self.e)

    def FP_rate(self):
        """False Positive Rate,假阳率是模型将负类别样本错误分类为正类别的比例"""
        return self.FP / (self.FP + self.TN + self.e)

    def FN_rate(self):
        """False Negative Rate,假阴率是模型将正类别样本错误分类为负类别的比例"""
        return self.FN / (self.FN + self.TP + self.e)

    def Qualityfactor(self):
        """品质因子综合考虑了召回率和特异性"""
        r = self.Recall()
        s = self.Specificity()
        return r+s-1

参考文章:多分类中TP/TN/FP/FN的计算_Hello_Chan的博客-CSDN博客 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1016091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1_图神经网络GNN基础知识学习

文章目录 安装PyTorch Geometric安装工具包 在KarateClub数据集上使用图卷积网络 (GCN) 进行节点分类两个画图函数Graph Neural Networks数据集&#xff1a;Zacharys karate club network.PyTorch Geometric数据集介绍 edge_index使用networkx可视化展示 Graph Neural Networks…

cenos自动启动tomcat

首先创建一个脚本 关闭tomcat 等待2分钟 启动tomcat 并且把日志输出在 /usr/local/tomcat/tomcatchognqi.log #!/bin/bashexport JAVA_HOME/usr/local/jdk/jdk1.8.0_211 export JRE_HOME$JAVA_HOME/jre# 日志文件路径和文件名 LOG_FILE"/usr/local/tomcat/tomcatchognqi.…

Ubuntu20.4搭建基于iRedMail的邮件服务器

iRedMail 是一个基于 Linux/BSD 系统的零成本、功能完备、成熟的邮件服务器解决方案。基于ubuntu20.4搭建基于iRedMail的邮件服务器包括环境配置&#xff0c;iRedMail安装与配置&#xff0c;iRedMail调整邮件附件大小等3个小节进行描述。具体如下详细描述。 1 环境配置 1.设置…

企业架构LNMP学习笔记56

MongoDB数据类型操作&#xff1a;CURD 1、添加数据&#xff1a; mongodb里存储数据的格式文档形式&#xff0c;以bson格式的文档形式。 创建数据库&#xff1a; > use tp5shop switched to db tp5shop > db.getName() tp5shop使用切换库&#xff0c;不存在自动创建&am…

Jmeter 自动化性能测试常见问题汇总

一、request 请求超时设置 timeout 超时时间是可以手动设置的&#xff0c;新建一个 http 请求&#xff0c;在“高级”设置中找到“超时”设置&#xff0c;设置连接、响应时间为2000ms。 1. 请求连接超时&#xff0c;连不上服务器。 现象&#xff1a; Jmeter表现形式为&…

计算机竞赛 机器视觉人体跌倒检测系统 - opencv python

0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 机器视觉人体跌倒检测系统 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c;学长非常推荐&#xff01; &#x1f947;学长这里给一个题目综合评分(每项满分5分) 难度系数&…

【AIGC】Stable Diffusion Prompt 每日一练0915

一、前言 1.1 写在前面 本文是一个系列&#xff0c;有点类似随笔&#xff0c;每天一次更新&#xff0c;重点就Stable Diffusion Prompt进行专项训练&#xff0c;本文是第一篇《Stable Diffusion Prompt 每日一练0915》。 1.2 项目背景 stable diffusion提示词(prompt)是用于…

[vue问题]开发中问题集合

“TypeError: Cannot read property ‘Request’ of undefined” 这是测试文件的报错&#xff0c;最后发现是因为项目启动的时候就报错了&#xff0c;是其它错误导致的&#xff0c;所以测试文件才会提示这种错误&#xff0c;当启动报错修复后&#xff0c;该问题没有了 热加载…

Acwing算法心得——猜测短跑队员的速度(重写比较器)

大家好&#xff0c;我是晴天学长&#xff0c;今天的算法题用到了比较器的知识&#xff0c;是经常会用到的一个知识点&#xff0c;常见与同种数据的排序&#xff0c;需要的小伙伴请自取哦&#xff01;如果觉得写的不错的话&#xff0c;可以点个关注哦&#xff0c;后续会继续更新…

什么是函数式编程(functional programming)?在JavaScript中如何实现函数式编程的概念?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐ 函数式编程&#xff08;Functional Programming&#xff09;⭐ 纯函数&#xff08;Pure Functions&#xff09;⭐ 不可变性&#xff08;Immutability&#xff09;⭐ 高阶函数&#xff08;Higher-Order Functions&#xff09;⭐ 函数组合&…

jq命令安装与使用

目录 一、简介二、下载及安装1.Linux 安装2.Windows 安装3.测试安装结果 三、jq用法1.基本语法2.常见用法1&#xff09;格式化 JSON2&#xff09;获取属性3&#xff09;属性不存在情况处理4&#xff09;数组遍历、截取、展开5&#xff09;管道、逗号、加号6&#xff09;数据构造…

长话短说 CopyOnWrite 思想及其应用场景

长话短说 CopyOnWrite 思想及其应用场景。 CopyOnWrite(写入时复制)思想 CopyOnWrite(简称COW,中文意思是:写入时复制)就是在进行写操作时,先复制要改变的对象,对副本进行写操作,完成对副本的操作后,把原有对象的引用指向副本对象。 CopyOnWrite采用了读写分离的思想…

Day45:Element-Plus

Vue生态中包含了大量优秀的组件库&#xff0c;经过快速的学习&#xff0c;我们就能把这些现成的组件应用到自己的项目中了 1.常见组件库 UI组件库需要基于JS框架来实现&#xff0c;也就是说我们现在学习的是Vue3&#xff0c;也需要选择适配的组件库&#xff0c;在Vue生态中&a…

C++QT day6

1> 将之前定义的栈类和队列类都实现成模板类 栈&#xff1a; #include <iostream> #define MAX 128 using namespace std; template<typename T> class Stack_s { private:T *pnew T[MAX];//栈的数组int top;//记录栈顶的变量 public://构造函数Stack_s(int t…

数据结构--平衡二叉树

目录 平衡二叉树定义 平衡二叉树的插入 调整最小不平衡子树 LL型 RR型 LR型​编辑 RL型​编辑 练习 查找效率分析​编辑 回顾知识点 平衡二叉树的删除 例1 ​编辑 例2​编辑 例3 例4​编辑 ​编辑 平衡二叉树的删除回顾​编辑 定义插入操作插入新结点后如何调…

SoundSource 5 for Mac:专业音频控制器,让您的Mac听觉体验更出色

对于使用 Mac 的用户而言&#xff0c;拥有一个强大的音频控制工具是非常重要的。SoundSource 5 for Mac 是 Rogue Amoeba 公司开发的一款专业的音频控制器&#xff0c;它提供了丰富的功能和优秀的性能&#xff0c;让您能够轻松管理并优化Mac上的音频设置。 首先&#xff0c;So…

腾讯mini项目-【指标监控服务重构】2023-08-12

今日已办 Watermill Handler 将 4 个阶段的逻辑处理定义为 Handler 测试发现&#xff0c;添加的 handler 会被覆盖掉&#xff0c;故考虑添加为 middleware 且 4 个阶段的处理逻辑针对不同 topic 是相同的。 参考https://watermill.io/docs/messages-router/实现不同topic&am…

DS相关题目

DS相关题目 题目一&#xff1a;消失的数字 拿到这道题目之后&#xff0c;首先可以想到的一个解题方法就是&#xff0c;我们可以先排序&#xff0c;排完序之后&#xff0c;这个数组其实就是一个有序的数组了&#xff0c;那只用比较数组中的每一个元素和他对应的下标是不是相等的…

Day46-50:统计图表项目总结

建项 项目需求写法——可视化报表 可视化报表项目效果 开源表格样式库阿帕奇 绘制echarte图标的流程 在视图中放置一个容器&#xff0c;这个容器需要有一个固定的宽高获取容器&#xff0c;调用init方法&#xff0c;初始化echarts实例 let container document.querySelec…

stm32----ADC模数转换

一、ADC介绍 ADC&#xff0c;即模数转换器&#xff0c;它可以将模拟信号转化为数字信号。在stm32种一般有3个ADC&#xff0c;每个ADC有18个通道。 12位ADC是一种逐次逼近型模拟数字转换器&#xff0c;它有多达18个通道&#xff0c;可测量16个外部和两个内部信号源。各个通道的A…