混淆矩阵、准确率、查准率、查全率、DSC、IoU、敏感度的计算

news2025/1/8 5:59:21

1.背景介绍

在训练的模型的时候,需要评价模型的好坏,就涉及到混淆矩阵、准确率、查准率、查全率、DSC、IoU、敏感度的计算。

2、混淆矩阵的概念

所谓的混淆矩阵如下表所示:

TP:真正类,真的正例被预测为正例

FN:假负类,样本为正例,被预测为负类

FP:假正类 ,原本实际为负,但是被预测为正例

TN:真负类,真的负样本被预测为负类。

从混淆矩阵当中,可以得到更高级的分类指标:Accuracy(准确率),Precision(查准率),Recall(查全率),Specificity(特异性),Sensitivity(灵敏度)。

3. 常用的分类指标

3.1 Accuracy(准确率)

不管是哪个类别,只要预测正确,其数量都放在分子上,而分母是全部数据量。常用于表示模型的精度,当数据类别不平衡时,不能用于模型的评价。

Accuracy=\frac{TP+TN}{TP+FN+FN+TN}

3.2 Precision(查准率)

即所有预测为正的样本中,预测正确的样本的所占的比重。

Precision = \frac{TP}{TP+FP}

3.3  Recall(查全率)

真实的为正的样本,被正确检测出来的比重。

Recall=\frac{TP}{TP+FN}

3.4 Specificity(特异性)

特异性指标,也称 负正类率(False Positive Rate, FPR),计算的是模型错识别为正类的负类样本占所有负类样本的比例,一般越低越好。

FPR = \frac{FP}{TN+FP}

3.5 DSC(Dice coefficient)

Dice系数,是一种相似性度量,度量二进制图像分割的准确性。

如图所示红色的框的区域时Groudtruth,而蓝色的框为预测值Prediction。

DSC=\frac{2\left | G\sqcap P \right |}{\left | p \right |+\left | G \right |}

3.6 IoU(交并比)

IoU=\frac{p\sqcap G}{p\bigsqcup G}

3.7 Sensitivity(灵敏度)

反应的时预测正确的区域在Groundtruth中所占的比重。

SEN=\frac{\left | p \left | \sqcap \right |g\right | }{\left | G \right | }

4. 计算程序

ConfusionMatrix 这个类可以直接计算出混淆矩阵

from collections import defaultdict, deque
import datetime
import time
import torch
import torch.nn.functional as F
import torch.distributed as dist
import errno
import os


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{value:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


class ConfusionMatrix(object):
    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.mat = None

    def update(self, a, b):
        n = self.num_classes
        if self.mat is None:
            # 创建混淆矩阵
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            # 寻找GT中为目标的像素索引
            k = (a >= 0) & (a < n)
            # 统计像素真实类别a[k]被预测成类别b[k]的个数(这里的做法很巧妙)
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        h = self.mat.float()
        # 计算全局预测准确率(混淆矩阵的对角线为预测正确的个数)
        acc_global = torch.diag(h).sum() / h.sum()
        # 计算每个类别的准确率
        acc = torch.diag(h) / h.sum(1)
        # 计算每个类别预测与真实目标的iou
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.mat)

    def __str__(self):
        acc_global, acc, iu = self.compute()
        return (
            'global correct: {:.1f}\n'
            'average row correct: {}\n'
            'IoU: {}\n'
            'mean IoU: {:.1f}').format(
                acc_global.item() * 100,
                ['{:.1f}'.format(i) for i in (acc * 100).tolist()],
                ['{:.1f}'.format(i) for i in (iu * 100).tolist()],
                iu.mean().item() * 100)


class DiceCoefficient(object):
    def __init__(self, num_classes: int = 2, ignore_index: int = -100):
        self.cumulative_dice = None
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.count = None

    def update(self, pred, target):
        if self.cumulative_dice is None:
            self.cumulative_dice = torch.zeros(1, dtype=pred.dtype, device=pred.device)
        if self.count is None:
            self.count = torch.zeros(1, dtype=pred.dtype, device=pred.device)
        # compute the Dice score, ignoring background
        pred = F.one_hot(pred.argmax(dim=1), self.num_classes).permute(0, 3, 1, 2).float()
        dice_target = build_target(target, self.num_classes, self.ignore_index)
        self.cumulative_dice += multiclass_dice_coeff(pred[:, 1:], dice_target[:, 1:], ignore_index=self.ignore_index)
        self.count += 1

    @property
    def value(self):
        if self.count == 0:
            return 0
        else:
            return self.cumulative_dice / self.count

    def reset(self):
        if self.cumulative_dice is not None:
            self.cumulative_dice.zero_()

        if self.count is not None:
            self.count.zeros_()

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.cumulative_dice)
        torch.distributed.all_reduce(self.count)

分类指标的计算

import torch

# SR : Segmentation Result
# GT : Ground Truth

def get_accuracy(SR,GT,threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)
    corr = torch.sum(SR==GT)
    tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3)
    acc = float(corr)/float(tensor_size)

    return acc

def get_sensitivity(SR,GT,threshold=0.5):
    # Sensitivity == Recall
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TP : True Positive
    # FN : False Negative
    TP = ((SR==1)+(GT==1))==2
    FN = ((SR==0)+(GT==1))==2

    SE = float(torch.sum(TP))/(float(torch.sum(TP+FN)) + 1e-6)     
    
    return SE

def get_specificity(SR,GT,threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TN : True Negative
    # FP : False Positive
    TN = ((SR==0)+(GT==0))==2
    FP = ((SR==1)+(GT==0))==2

    SP = float(torch.sum(TN))/(float(torch.sum(TN+FP)) + 1e-6)
    
    return SP

def get_precision(SR,GT,threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TP : True Positive
    # FP : False Positive
    TP = ((SR==1)+(GT==1))==2
    FP = ((SR==1)+(GT==0))==2

    PC = float(torch.sum(TP))/(float(torch.sum(TP+FP)) + 1e-6)

    return PC

def get_F1(SR,GT,threshold=0.5):
    # Sensitivity == Recall
    SE = get_sensitivity(SR,GT,threshold=threshold)
    PC = get_precision(SR,GT,threshold=threshold)

    F1 = 2*SE*PC/(SE+PC + 1e-6)

    return F1

def get_JS(SR,GT,threshold=0.5):
    # JS : Jaccard similarity
    SR = SR > threshold
    GT = GT == torch.max(GT)
    
    Inter = torch.sum((SR+GT)==2)
    Union = torch.sum((SR+GT)>=1)
    
    JS = float(Inter)/(float(Union) + 1e-6)
    
    return JS

def get_DC(SR,GT,threshold=0.5):
    # DC : Dice Coefficient
    SR = SR > threshold
    GT = GT == torch.max(GT)

    Inter = torch.sum((SR+GT)==2)
    DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6)

    return DC

参考文献:

混淆矩阵的概念-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1415111.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OSI七层模型 | TCP/IP模型 | 网络和操作系统的联系 | 网络通信的宏观流程

文章目录 1.OSI七层模型2.TCP/IP五层(或四层)模型3.网络通信的宏观流程3.1.同网段通信3.2.跨网段通信 1.OSI七层模型 在计算机通信诞生之初&#xff0c;不同的厂商都生产自己的设备&#xff0c;都有自己的网络通讯标准&#xff0c;导致了不同厂家之间各种协议不兼容&#xff0…

Linux文本三剑客---grep

grep&#xff08;从文本或字符串种过滤特定内容。&#xff09; 格式&#xff1a;Usage: grep [OPTION]... PATTERNS [FILE]... 常用选项&#xff1a; -E 等价于 egrep 扩展正则 -i 忽略大小写 -w 匹配单词 -o 仅显示匹配内容 -r 递归匹配 -c 统计匹配的行数 -v 取反 -n 行号 -A…

ABAP 状态栏排除某些按钮

ABAP 状态栏排除某些按钮 GUI State状态栏 在状态栏这里有这些按钮&#xff0c;现在在导出界面要排除掉这些按钮&#xff1a; 将要排除的按钮追加到gt_code内表&#xff1a; gt_fcode功能码内表的定义 DATA:gt_fcode TYPE TABLE OF sy-ucomm,完整程序 *&---------…

SpringBoot不同的@Mapping使用

文章目录 一、介绍二、使用 一、介绍 一般Mapping类注解在Spring框架中用于将HTTP请求映射到对应的处理器方法。它们各自对应于不同类型的HTTP方法&#xff0c;主要用于RESTful Web服务中。以下是每个注解的作用&#xff1a; GetMapping: 用于映射HTTP GET请求到处理器方法。通…

Android读写文件,适配Q以上

Android Q升级了文件系统&#xff0c;访问文件不仅仅是说动态权限了&#xff0c;有各种限制。权限什么的就不赘述了&#xff0c;下面介绍一下在10以上的系统中访问文件。 首先是打开文件管理器 /*** 打开文件管理器 存储卡和外接U盘都可以访问*/public void openFileManager()…

【揭秘】ForkJoinTask全面解析

内容摘要 ForkJoinTask的显著优点在于其高效的并行处理能力&#xff0c;它能够将复杂任务拆分成多个子任务&#xff0c;并利用多核处理器同时执行&#xff0c;从而显著提升计算性能&#xff0c;此外&#xff0c;ForkJoinTask还提供了简洁的API和强大的任务管理机制&#xff0c…

Blender教程-编辑模式点线面的选择-06

一、新建立方体 ShiftA新建立方体用于演示 二、模式切换 按TAB键切换为编辑模式 点模式 在点模式下可以选择中物体的顶点。 线模式&#xff08;边模式&#xff09; 面模式 在熟悉编辑模式下的点线面基础操作以后&#xff0c;我们后续建模会以此为基础教程。

D. Epic Transformation(堆+贪心)

思路&#xff1a;我们删的策略是从次数多的数开始删&#xff0c;每次取两种不同的数&#xff0c;每种删去一个&#xff0c;然后放回堆中。 代码&#xff1a; void solve(){int n;cin >> n;map<int,int>mp;for(int i 1;i < n;i ){int x;cin >> x;mp[x] …

CAD-autolisp(二)——选择集、命令行设置对话框、符号表

目录 一、选择集1.1 选择集的创建1.2 选择集的编辑1.3 操作选择集 二、命令行设置对话框2.1 设置图层2.2 加载线型2.3 设置字体样式2.4 设置标注样式&#xff08;了解即可&#xff09; 三、符号表3.1 简介3.2 符号表查找3.2 符号表删改增 一、选择集 定义&#xff1a;批量选择…

ubuntu 20.04 更新 autoconf 版本

前言 由于最近打算交叉编译 python&#xff0c;依赖 libffi 库&#xff0c;而交叉编译 libffi 库&#xff0c;由于使用的是 github 上的 libffi&#xff0c;又提示 autoconf 版本太低了&#xff0c;所以&#xff0c;先更新 autoconf 的版本 当前 ubuntu 20.04 上安装的 autuco…

【数据分享】2015年泛第三极65国1km分辨率土壤侵蚀强度数据集(免费获取)

土壤数据是在环境、农业、生态等相关研究中都非常常用的数据&#xff01;我们之前发表过一篇介绍土壤数据来源的文章&#xff08;可查看之前推送的文章获悉详情&#xff09;&#xff01; 土壤侵蚀强度是土壤的重要属性&#xff01;本次我们给大家带来的是2015年泛第三极65国1k…

《合成孔径雷达成像算法与实现》Figure5.19

clc clear close all距离向参数 R_eta_c 20e3; % 景中心斜距 Tr 25e-6; % 发射脉冲时宽 Kr 0.25e12; % 距离向调频率 Fr 7.5e6; % 距离向采样率 Nrg 256; % 距离线采样点数 Bw abs(Kr*Tr); …

8.15合并区间(LC56)

算法&#xff1a; 和452. 用最少数量的箭引爆气球 (opens new window)和 435. 无重叠区间 (opens new window)都是一个套路。 这几道题都是判断区间重叠&#xff0c;区别就是判断区间重叠后的逻辑&#xff0c;本题是判断区间重贴后要进行区间合并。 步骤&#xff1a; 先排序…

自动化脚本不稳定,原来是软件弹窗惹的祸,2个方法解决!

很多同学在学习 App 自动化或者在项目中落地实践 App 自动化时&#xff0c;会发现编写的自动化脚本无缘无故的执行失败、不稳定。 而导致其问题很大原因是因为应用的各种弹窗&#xff08;升级弹窗、使用过程提示弹窗、评价弹窗等等&#xff09;&#xff0c;比如这样的&#xff…

nodejs下载 安装 配置环境

目录 1.下载 2、配置环境 1.下载 下载地址&#xff1a;https://nodejs.org/en/download/ 下载完成后&#xff0c;双击安装包&#xff0c;开始安装&#xff0c;一直点next即可。我把安装路径设置为 D:\Program Files\nodejs\ 安装完之后打开终端 windowR cmd 回车&#xff1…

C++:组合、继承与多态

面向对象设计的重要目的之一就是代码重用&#xff0c;这也是C的重要特性之一。代码重用鼓励人们使用已有的&#xff0c;得到认可并经过测试的高质量代码。多态允许以常规方式书写代码来访问多种现有的且已专门化了的相关类。继承和多态是面向对象程序设计方法的两个最主要的特性…

fix bug: FileNotFoundError: [Errno 2] No such file or directory: ‘nvcc‘

1.问题描述 运行的代码设计pycuda,会调用nvcc&#xff0c;确定已经安装cuda toolkit&#xff0c;在terminal中云运行 nvcc -V 能得到想到的结果&#xff1a; 但是在 pycharm中运行代码时提示&#xff1a; FileNotFoundError: [Errno 2] No such file or directory: nvcc 2.…

细数语音识别中的几个former

随着Transformer在人工智能领域掀起了一轮技术革命&#xff0c;越来越多的领域开始使用基于Transformer的网络结构。目前在语音识别领域中&#xff0c;Tranformer已经取代了传统ASR建模方式。近几年关于ASR的研究工作很多都是基于Transformer的改进&#xff0c;本文将介绍其中应…

拦截器的简单使用

拦截器的简单使用 拦截器的使用创建拦截器preHandle 目标方法执行前执行postHandle 目标方法执行后执行afterCompletion 视图渲染后执行 拦截器使用场景返回值注册拦截器运用拦截器 拦截器的使用 创建拦截器 首先,我们需要创建一个拦截器器的类,并且需要继承自HandlerIntercep…

java分布式锁(详解)

本地锁 浏览器把100w请求由网关随机往下传&#xff0c;在集群情况下&#xff0c;每台服务都放行10w请求过来&#xff0c;这时候每台服务都用的是本地锁是跨JVM的&#xff0c; 列如这些服务都没有49企业&#xff0c;此时有几个服务进行回原了打击在DB上面&#xff0c;那后期把这…