深度学习PyTorch之13种模型精度评估公式及调用方法

news2025/3/12 2:40:07

深度学习pytorch之22种损失函数数学公式和代码定义
深度学习pytorch之19种优化算法(optimizer)解析
深度学习pytorch之4种归一化方法(Normalization)原理公式解析和参数使用
深度学习pytorch之简单方法自定义9类卷积即插即用
实时语义分割之BiSeNetv2(2020)结构原理解析及建筑物提取实践

文章目录

  • 摘要
    • 1. Accuracy Score
    • 2. Balanced Accuracy
    • 3. Brier Score Loss
    • 4. Cohen's Kappa
    • 5. F1/F-beta Score
    • 6. Hamming Loss
    • 7. Hinge Loss
    • 8. Jaccard Score
    • 9. Log Loss
    • 10. Matthews Correlation
    • 11. Precision
    • 12. Recall
    • 13. Zero-One Loss
  • 关键参数说明
  • 可执行代码示例

摘要

模型训练后需要评估模型性能,因此需要了解各种评估指标的具体用法和背后的数学原理,本博客以清晰的格式呈现分类任务评估指标的名称、调用示例、公式说明。

1. Accuracy Score

调用方式:

from sklearn.metrics import accuracy_score
acc = accuracy_score(y_true, y_pred, normalize=True, sample_weight=None)

公式:

Accuracy = (TP + TN) / (TP + TN + FP + FN)

2. Balanced Accuracy

调用方式:

from sklearn.metrics import balanced_accuracy_score
bal_acc = balanced_accuracy_score(y_true, y_pred, sample_weight=None, adjusted=False)

公式:

Balanced Accuracy = (Recall_Class1 + Recall_Class2 + … +Recall_ClassN) / N
调整后版本:BalancedAcc_adj = (BalancedAcc - 1/N) / (1 -1/N)

3. Brier Score Loss

调用方式:

from sklearn.metrics import brier_score_loss
brier = brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=1)

公式:

Brier Score = 1/N * Σ(y_true_i - y_prob_i)^2

(适用于概率预测的校准度评估)

4. Cohen’s Kappa

调用方式:

from sklearn.metrics import cohen_kappa_score
kappa = cohen_kappa_score(y1, y2, labels=None, weights=None, sample_weight=None)

公式:

κ = (p_o - p_e) / (1 - p_e) 其中 p_o 为观察一致率,p_e 为期望一致率

5. F1/F-beta Score

调用方式:

from sklearn.metrics import f1_score, fbeta_score
f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
fbeta = fbeta_score(y_true, y_pred, beta=0.5, average='macro')

公式:

Fβ = (1 + β²) * (precision * recall) / (β² * precision + recall) 当 β=1
时为 F1 Score

6. Hamming Loss

调用方式:

from sklearn.metrics import hamming_loss
hamming = hamming_loss(y_true, y_pred, sample_weight=None)

公式:

Hamming Loss = 1/N * Σ(预测错误的标签数 / 总标签数) (多标签任务专用)

7. Hinge Loss

调用方式:

from sklearn.metrics import hinge_loss
hinge = hinge_loss(y_true, pred_decision, labels=None, sample_weight=None)

公式:

Hinge Loss = max(0, 1 - y_true * pred_decision) 的平均值 (SVM模型常用)

8. Jaccard Score

调用方式:

from sklearn.metrics import jaccard_score
jaccard = jaccard_score(y_true, y_pred, average='samples')

公式:

Jaccard = TP / (TP + FP + FN)

即IOU,多用于图像分割评估

9. Log Loss

调用方式:

from sklearn.metrics import log_loss
logloss = log_loss(y_true, y_pred, eps=1e-15, normalize=True, labels=None)

公式:

Log Loss = -1/N * Σ[y_true_i * log(y_pred_i) + (1-y_true_i) *log(1-y_pred_i)]

交叉熵损失,需概率预测输入

10. Matthews Correlation

调用方式:

from sklearn.metrics import matthews_corrcoef
mcc = matthews_corrcoef(y_true, y_pred, sample_weight=None)

公式:

MCC = (TPTN - FPFN) / √((TP+FP)(TP+FN)(TN+FP)(TN+FN))

适用于类别不平衡的二分类

11. Precision

调用方式:

from sklearn.metrics import precision_score
precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)

公式:

Precision = TP / (TP + FP)

12. Recall

调用方式:

from sklearn.metrics import recall_score
recall = recall_score(y_true, y_pred, average='macro', zero_division=0)

公式:

Recall = TP / (TP + FN)

13. Zero-One Loss

调用方式:

from sklearn.metrics import zero_one_loss
zero_one = zero_one_loss(y_true, y_pred, normalize=True)

公式:

Zero-One Loss = 1 - Accuracy

直接统计错误预测比例

关键参数说明

参数说明
average计算方式:None(各类单独计算)、‘micro’(全局统计)、‘macro’(各类平均)、‘weighted’(按支持数加权)
zero_division处理除零情况:0(返回0)、1(返回1)或’warn’(返回0并警告)
sample_weight样本权重数组
pos_label指定正类标签(仅二分类有效)
labels指定要评估的类别列表
betaF-beta中召回率的权重(>1侧重召回率,<1侧重精确率)

可执行代码示例

以下程序采用常用的accuracy, precision, recall, f1对分类结果进行评估,注意替换下列文件夹,两个文件夹内均为8位单波段影像,采用相同命名。

  • label_dir = ‘label’ # 替换为实际路径
  • pred_dir = ‘pred’ # 替换为实际路径
import os
import numpy as np
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt


def load_images_and_labels(label_dir, pred_dir):
    """
    读取标签图像和预测图像,假设它们的像素值代表类别标签。
    :param label_dir: 实际标签图像的文件夹路径
    :param pred_dir: 预测标签图像的文件夹路径
    :return: 实际标签和预测标签的列表
    """
    labels = []
    preds = []

    # 获取文件列表
    label_files = sorted(os.listdir(label_dir))
    pred_files = sorted(os.listdir(pred_dir))

    # 遍历每个图像文件加载标签和预测
    for label_file, pred_file in zip(label_files, pred_files):
        label_path = os.path.join(label_dir, label_file)
        pred_path = os.path.join(pred_dir, pred_file)

        # 加载图像并转换为灰度
        label_img = Image.open(label_path).convert('L')  # 灰度图
        pred_img = Image.open(pred_path).convert('L')  # 灰度图

        # 假设灰度值代表类标签
        label = np.array(label_img)
        pred = np.array(pred_img)

        # 扁平化数组,以便计算评估指标
        labels.extend(label.flatten())
        preds.extend(pred.flatten())

    return np.array(labels), np.array(preds)


def evaluate_model(labels, preds):
    """
    计算模型的评估指标
    :param labels: 实际标签
    :param preds: 预测标签
    """
    # 计算评估指标
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average='weighted', zero_division=0)
    recall = recall_score(labels, preds, average='weighted', zero_division=0)
    f1 = f1_score(labels, preds, average='weighted', zero_division=0)


    # 打印评估指标
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")


    # 可选:绘制混淆矩阵
    from sklearn.metrics import confusion_matrix
    import seaborn as sns
    cm = confusion_matrix(labels, preds)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=np.unique(labels), yticklabels=np.unique(labels))
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()


if __name__ == "__main__":
    # 设置实际标签和预测标签的文件夹路径
    label_dir = 'label'  # 替换为实际路径
    pred_dir = 'pred'  # 替换为实际路径

    # 加载标签和预测数据
    labels, preds = load_images_and_labels(label_dir, pred_dir)

    # 评估模型
    evaluate_model(labels, preds)

输出结果:
Accuracy: 0.9681
Precision: 0.9686
Recall: 0.9681
F1 Score: 0.9683

绘制混淆矩阵:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2313509.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

tomcat单机多实例部署

一、部署方法 多实例可以运行多个不同的应用&#xff0c;也可以运行相同的应用&#xff0c;类似于虚拟主机&#xff0c;但是他可以做负载均衡。 方式一&#xff1a; 把tomcat的主目录挨个复制&#xff0c;然后把每台主机的端口给改掉就行了。 优点是最简单最直接&#xff0c;…

Java开发者如何接入并使用DeepSeek

目录 一、准备工作 二、添加DeepSeek SDK依赖 三、初始化DeepSeek客户端 四、数据上传与查询 五、数据处理与分析 六、实际应用案例 七、总结 【博主推荐】&#xff1a;最近发现了一个超棒的人工智能学习网站&#xff0c;内容通俗易懂&#xff0c;风格风趣幽默&#xff…

win10电脑鼠标速度突然变的很慢?

电脑鼠标突然变很慢&#xff0c;杀毒检测后没问题&#xff0c;鼠标设置也没变&#xff0c;最后发现可能是误触鼠标的“DPI”调节键。 DPI调节键在鼠标滚轮下方&#xff0c;再次点击即可恢复正常鼠标速度。 如果有和-的按键&#xff0c;速度变快&#xff0c;-速度变慢。 图源&…

第四次CCF-CSP认证(含C++源码)

第四次CCF-CSP认证 第一道&#xff08;easy&#xff09;思路及AC代码 第二道&#xff08;easy&#xff09;思路及AC代码遇到的问题 第三道&#xff08;mid&#xff09;思路及AC代码 第一道&#xff08;easy&#xff09; 题目链接 思路及AC代码 这题就是将这个矩阵旋转之后输出…

Netty基础—1.网络编程基础一

大纲 1.什么是OSI开放系统互连 2.OSI七层模型各层的作用 3.TCP/IP协议的简介 4.TCP和UDP的简介 5.TCP连接的三次握手 6.TCP连接的四次挥手 7.TCP/IP中的数据包 8.TCP通过确认应答与序列号提高可靠性 9.HTTP请求的传输过程 10.HTTP协议报文结构 11.Socket、短连接、长…

98.在 Vue3 中使用 OpenLayers 根据 Resolution 的不同显示不同的地图

在 Vue3 中使用 OpenLayers 根据 Resolution 的不同显示不同的地图 前言 在 Web GIS&#xff08;地理信息系统&#xff09;应用开发中&#xff0c;地图的 Resolution&#xff08;分辨率&#xff09;是一个重要的概念。不同的 Resolution 适用于不同的地图层级&#xff0c;有时…

unity学习64,第3个小游戏:一个2D跑酷游戏

目录 学习参考 素材资源导入 1 创建项目 1.1 创建1个2D项目 1.2 导入素材 2 背景图bg 2.0 bg素材 2.1 创建背景 2.2 修改素材&#xff0c;且修改摄像机等 2.2.1 修改导入的原始prefab素材 2.2.2 对应调整摄像机 2.2.3 弄好背景 2.3 背景相关脚本实现 2.3.1 错误…

在本地部署DeepSeek等大模型时,需警惕的潜在安全风险

在本地部署DeepSeek等大模型时&#xff0c;尽管数据存储在本地环境&#xff08;而非云端&#xff09;&#xff0c;但仍需警惕以下潜在安全风险&#xff1a; 1. 模型与数据存储风险 未加密的存储介质&#xff1a;若训练数据、模型权重或日志以明文形式存储&#xff0c;可能被物…

【redis】string类型相关操作:SET、GET、MSET、MGET、SETNX、SETEX、PSETEX

文章目录 二进制存储编码转换SET 和 GETSETGET MSET 和 MGETSETNX、SETEX 和 PSETEX Redis 所有的 key 都是字符串&#xff0c;value 的类型是存在差异的 二进制存储 Redis 中的字符串&#xff0c;直接就是按照二进制数据的方式存储的 不仅仅可以存储文本数据&#xff0c;还可…

GaussDB安全配置指南:从认证到防御的全方面防护

一、引言 随着企业数据规模的扩大和云端化进程加速&#xff0c;数据库安全性成为运维的核心挑战之一。GaussDB作为一款高性能分布式数据库&#xff0c;提供了丰富的安全功能。本文将从 ​认证机制、权限控制、数据加密、审计日志​ 等维度&#xff0c;系统性地讲解如何加固 Ga…

Ubuntu20.04搭建gerrit code review

一、环境准备 1. 安装 Java 环境‌ Gerrit 依赖 Java 运行环境&#xff08;推荐 JDK 8&#xff09;&#xff1a; sudo apt install openjdk-11-jdk 验证安装&#xff1a; java -version ‌2. 安装 Git sudo apt install git ‌3. 可选依赖 数据库‌&#xff1a;Gerrit …

MacOS安装FFmpeg和FFprobe

按照网上很多教程安装&#xff0c;结果都失败了&#xff0c;后来才发现是路径问题&#xff0c;其实安装过程很简单&#xff08;无奈&#xff09; 第一步&#xff1a; 在官网下载 打开页面后&#xff0c;可以看到FFmpeg、FFprobe、FFplay和FFserver的下载图标 第二步&#xff1…

Redis7系列:设置开机自启

前面的文章讲了Redis和Redis Stack的安装&#xff0c;随着服务器的重启&#xff0c;导致Redis 客户端无法连接。原来的是Redis没有配置开机自启。此文记录一下如何配置开机自启。 1、修改配置文件 前面的Redis和Redis Stack的安装的文章中已经讲了redis.config的配置&#xf…

SpringAI介绍及本地模型使用方法

博客原文地址 前言 Spring在Java语言中一直稳居高位&#xff0c;与AI的洪流碰撞后也产生了一些有趣的”化学反应“&#xff0c;当然你要非要说碰撞属于物理反应也可以&#xff0c; 在经历了一系列复杂的反应方程后&#xff0c;Spring家族的新成员——SpringAI&#xff0c;就…

Unity 基础知识总结(持续更新中...)

引擎基础 Unity有哪几个主要窗口&#xff1f; Scene窗口 用于场景搭建和UI界面拼接 Game窗口 游戏运行预览 Hierarchy窗口 查看和调整场景对象层级结构 Project窗口 游戏工程资源 Inspector创建 属性查看器&#xff0c;属性设置、脚本组件挂载 Unity提供了几种光源…

IDEA接入阿里云百炼中免费的通义千问[2025版]

安装deepseek 上一篇文章IDEA安装deepseek最新教程2025中说明了怎么用idea安装codeGPT插件&#xff0c;并接入DeepSeek&#xff0c;无奈接入的官方api已经不能使用了&#xff0c;所以我们尝试从其他地方接入 阿里云百炼https://bailian.console.aliyun.com/ 阿里云百炼‌是阿…

3.03-3.09 Web3 游戏周报:Sunflower Land 周留存率 74.2%,谁是本周最稳链游?

回顾上周的区块链游戏概况&#xff0c;查看 Footprint Analytics 与 ABGA 最新发布的数据报告。 【3.03–3.09】Web3 游戏行业动态 Sui 背后开发公司 Mysten Labs 宣布收购游戏开发平台 ParasolYescoin 创始人因合伙人纠纷被警方带走&#xff0c;案件升级为刑事案件Animoca B…

NVIDIA k8s-device-plugin源码分析与安装部署

在《kubernetes Device Plugin原理与源码分析》一文中&#xff0c;我们从源码层面了解了kubelet侧关于device plugin逻辑的实现逻辑&#xff0c;本文以nvidia管理GPU的开源github项目k8s-device-plugin为例&#xff0c;来看看设备插件侧的实现示例。 一、Kubernetes Device Pl…

langChainv0.3学习笔记(初级篇)

LangChain自0.1版本发布以来&#xff0c;已经历了显著的进化&#xff0c;特别是向AI时代的适应性提升。在0.1版本中&#xff0c;LangChain主要聚焦于提供基本的链式操作和工具集成&#xff0c;帮助开发者构建简单的语言模型应用。该版本适用于处理简单任务&#xff0c;但在应对…

聚焦两会:科技与发展并进,赛逸展2025成创新新舞台

在十四届全国人大三次会议和全国政协十四届三次会议期间&#xff0c;代表委员们围绕多个关键议题展开深入讨论&#xff0c;为国家未来发展谋篇布局。其中&#xff0c;技术竞争加剧与经济转型需求成为两会焦点&#xff0c;将在首都北京举办的2025第七届亚洲消费电子技术贸易展&a…