python基于自己模型数据集和权重混淆矩阵生成

news2024/11/24 19:43:44

混淆矩阵(Confusion Matrix)是用于评估分类模型性能的一种表格形式。它显示了在分类问题中模型的预测结果与实际标签之间的各种组合情况。

混淆矩阵通常用于二分类问题,但也可以扩展到多分类问题。对于二分类问题,它由四个重要的指标组成:

真正例(True Positive, TP):模型预测为正例,并且实际上是正例的数量。
真反例(True Negative, TN):模型预测为反例,并且实际上是反例的数量。
假正例(False Positive, FP):模型预测为正例,但实际上是反例的数量。也称为"误报"。
假反例(False Negative, FN):模型预测为反例,但实际上是正例的数量。也称为"漏报"。

混淆矩阵的一般形式如下:
在这里插入图片描述

使用混淆矩阵可以计算多个衡量分类器性能的指标,如准确率(Accuracy)、精确率(Precision)、召回率(Recall,也称为敏感度或真正例率)和 F1 值等。这些指标可以通过混淆矩阵中的各个元素计算得出:

准确率(Accuracy):分类器预测正确的样本占总样本数的比例,计算公式为 (TP + TN) / (TP + TN + FP + FN) 。
精确率(Precision):正例预测正确的比例,计算公式为 TP / (TP + FP) 。
召回率(Recall):正例被正确预测为正例的比例,计算公式为 TP / (TP + FN) 。
F1 值:综合考虑了精确率和召回率的指标,计算公式为 2 (Precision Recall) / (Precision + Recall) 。

混淆矩阵提供了更详细和全面地评估分类模型性能的能力,帮助我们了解预测中的误报和漏报情况。通过分析混淆矩阵,我们可以获得对分类器在每个类别上的表现有关的宝贵见解,并对分类结果进行优化。

废话不多数,上代码:

def draw_confusion_matrix(label_true, label_pred, label_name, normlize, title="Confusion Matrix", pdf_save_path=None, dpi=100):
    """

    @param label_true: 真实标签,比如[0,1,2,7,4,5,...]
    @param label_pred: 预测标签,比如[0,5,4,2,1,4,...]
    @param label_name: 标签名字,比如['cat','dog','flower',...]
    @param normlize: 是否设元素为百分比形式
    @param title: 图标题
    @param pdf_save_path: 是否保存,是则为保存路径pdf_save_path=xxx.png | xxx.pdf | ...等其他plt.savefig支持的保存格式
    @param dpi: 保存到文件的分辨率,论文一般要求至少300dpi
    @return:

    example:
            draw_confusion_matrix(label_true=y_gt,
                          label_pred=y_pred,
                          label_name=["Angry", "Disgust", "Fear", "Happy", "Sad", "Surprise", "Neutral"],
                          normlize=True,
                          title="Confusion Matrix on Fer2013",
                          pdf_save_path="Confusion_Matrix_on_Fer2013.png",
                          dpi=300)

    """
    cm1=confusion_matrix(label_true, label_pred)
    cm = confusion_matrix(label_true, label_pred)
    if normlize:
        row_sums = np.sum(cm, axis=1)
        cm = cm / row_sums[:, np.newaxis]
    cm=cm.T
    cm1=cm1.T
    plt.imshow(cm, cmap='Blues')
    plt.title(title)
    plt.xlabel("Predict label")
    plt.ylabel("Truth label")
    plt.yticks(range(label_name.__len__()), label_name)
    plt.xticks(range(label_name.__len__()), label_name, rotation=45)

    plt.tight_layout()

    plt.colorbar()

    for i in range(label_name.__len__()):
        for j in range(label_name.__len__()):
            color = (1, 1, 1) if i == j else (0, 0, 0)	# 对角线字体白色,其他黑色
            value = float(format('%.1f' % (cm[i, j]*100)))
            value1=str(value)+'%\n'+str(cm1[i, j])
            plt.text(i, j, value1, verticalalignment='center', horizontalalignment='center', color=color)

    # plt.show()
    if not pdf_save_path is None:
        plt.savefig(pdf_save_path, bbox_inches='tight',dpi=dpi)



labels_name = ['bananaquit', 'Black Skimmer', 'Black Throated Bushtiti', 'Cockatoo']

y_gt=[]
y_pred=[]

model_weight_path = "./best_CBAM_model.pth"
models = Xception(num_classes = 4)
models.load_state_dict(torch.load(model_weight_path))




models.eval()
for index, (imgs, labels) in enumerate(test_dl):
    labels_pd = models(imgs)
    predict_np = np.argmax(labels_pd.cpu().detach().numpy(), axis=-1).tolist()
    labels_np = labels.numpy().tolist()

    y_pred.extend(predict_np)
    y_gt.extend(labels_np)
print("预测标签为:", y_pred)
print("真实标签为", y_gt)



draw_confusion_matrix(label_true=y_gt,
                      label_pred=y_pred,
                      label_name=labels_name,
                      normlize=True,
                      title="Confusion Matrix",
                      pdf_save_path="Confusion_Matrix.jpg",
                      dpi=300)

结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/846583.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

dotNet 之网络TCP

**硬件支持型号 点击 查看 硬件支持 详情** DTU701 产品详情 DTU702 产品详情 DTU801 产品详情 DTU802 产品详情 DTU902 产品详情 G5501 产品详情 ARM dotnet 编程 dotNet使用TCP,可以使用Socket和TcpClient 、TcpListener类 2种,对于高级用户&…

win11下docker安装testsigma自动化测试平台教程

Testsigma是一个基于云端的、支持测试左移的、以AI驱动测试的自动化平台,适用于Web、移动应用以及RESTful服务等各种应用的测试服务。 一、如何开始使用? 有三种方式:1、直接使用 Testsigma Cloud(目前已经不开放个人邮箱注册&am…

LeetCode[207]课程表

难度:Medium 题目: 你这个学期必须选修 numCourses 门课程,记为 0 到 numCourses - 1 。 在选修某些课程之前需要一些先修课程。 先修课程按数组 prerequisites 给出,其中 prerequisites[i] [ai, bi] ,表示如果要学习…

CRITICAL_SECTION 用法

#include <stdio.h> #include <windows.h> typedef RTL_CRITICAL_SECTION CRITICAL_SECTION; CRITICAL_SECTION g_cs; //声明关键段 // 共享资源 char g_cArray[10]; unsigned int g_Count 0; DWORD WINAPI ThreadProc10(LPVOID pParam) { // 进入临界区 …

返回一组数据中出现频率最多的元素(众数),可能是一个或多个statistics.multimode()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 返回一组数据中出现频率最多的 元素(众数)&#xff0c;可能是一个或多个 statistics.multimode() 选择题 下列说法错误的是? import statistics data [0, 1, 1, 2, 2, 3] print(【显示】d…

无涯教程-Perl - endservent函数

描述 此功能告诉系统您不再期望使用getservent从服务文件中读取条目。 语法 以下是此函数的简单语法- endservent返回值 此函数不返回任何值。 例 以下是显示其基本用法的示例代码- #!/usr/bin/perlwhile(($name, $aliases, $port_number,$protocol_name)getservent())…

【宝藏系列】嵌入式软件设计的 7 种架构模式

【宝藏系列】嵌入式软件设计的 7 种架构模式 文章目录 【宝藏系列】嵌入式软件设计的 7 种架构模式前言1️⃣分层架构2️⃣多层架构3️⃣管道/过滤器架构4️⃣客户端、过滤器架构5️⃣模型、视图、控制器架构&#xff08;MVC&#xff09;6️⃣事件驱动架构7️⃣微服务架构 前言…

【java】访问权限

java访问权限 publicprotecteddefaultprivate内部类 java中访问权限修饰符有以下4个&#xff1a;public、protected、default、private public public代表着公共的&#xff0c;在java源码中。公共类只能有一个&#xff0c;而且必须和源码文件名相同。 我们发现一直写的main方法…

基于Java+SpringBoot+Vue的企业客户信息反馈平台设计与实现(源码+LW+部署文档等)

博主介绍&#xff1a; 大家好&#xff0c;我是一名在Java圈混迹十余年的程序员&#xff0c;精通Java编程语言&#xff0c;同时也熟练掌握微信小程序、Python和Android等技术&#xff0c;能够为大家提供全方位的技术支持和交流。 我擅长在JavaWeb、SSH、SSM、SpringBoot等框架…

高薪Offer收割机之聚集索引和非聚集索引

什么是聚集索引&#xff0c;非聚集索引&#xff0c;回表查询&#xff0c;覆盖索引 聚集索引就是将数据存储与索引放到了一起&#xff0c;索引结构的叶子节点保存了行数据&#xff0c;一张表必须有且只有一个聚集索引。 如果存在主键&#xff0c;主键就是聚集索引&#xff0c;…

180天,小卡拉米 - 编程路线,学习计划!

作者&#xff1a;小傅哥 博客&#xff1a;https://bugstack.cn 沉淀、分享、成长&#xff0c;让自己和他人都能有所收获&#xff01;&#x1f604; 职业生涯这条路&#xff0c;我在前面10年开的路&#xff0c;将让你少走很多弯路&#x1f463;&#xff01; 工作了这么多年&…

Nginx开启gzip网页传输压缩配置

场景 Nginx 服务器为网页压缩专门提供了 gz 模块&#xff0c;并且模块中的相关指令均可以设置在http、server或location块中&#xff0c; 实现服务器端按照指定的设置进行压缩。 CentOS7中解压tar包的方式安装Nginx&#xff1a; CentOS7中解压tar包的方式安装Nginx_centos7…

ospf减少LSA更新

实验及实验要求 一、思路 1.根据区域划分IP地址 2.使公网可通---写缺省 3.使R3成为MGRE中心站点&#xff0c;R5、R6、R7为分支站点 4.一个个去配置ospf区域和RIP区域&#xff0c;确保每个区域配置无误 5.区域0要更改OSPF在接口的工作类型为broadcast &#xff0c;并使R3为…

WDM设备栈

图 1一块USB主控制器卡&#xff08;主控芯片为VIA VL805&#xff09; 图中的板卡包含USB主控制器&#xff08;USB Host Controller&#xff09;、USB集线器&#xff08;USB Hub&#xff09;&#xff0c;这里USB集线器扩展了4个USB端口&#xff08;USB Port&#xff09;。 PCI…

智慧工地源码:数字孪生智慧工地可视化解决方案

一、智慧工地建设背景 我国经济发展正从传统粗放式的高速增长阶段&#xff0c;进入高效率、低成本、可持续的中高速增长阶段。随着现代建筑的复杂度和体量等不断增加&#xff0c;施工现场管理的内容越来越多&#xff0c;管理的技术难度和要求在不断提高。传统的施工现场管理模…

C++ 测试框架 GoogleTest 初学者入门篇 丙

断言 什么是断言&#xff1f;断言是用来对表达式执行比较的代码块&#xff0c;调用时类似函数。当表达式一致时&#xff0c;断言返回成功&#xff0c;否则失败。 googletest 的断言是一组宏定义。分为 ASSERT_* 和 EXPECT_* 两种。 比如 ASSERT_EQ(1, 2);EXPECT_EQ(1, 2);上…

Element-UI简介

目录 安装 常用组件 Container 布局容器 Button 按钮 MessageBox 弹框 Form 表单验证 element-ui是一个前端的ui框架&#xff0c;封装了很多已经写好的ui组件&#xff0c;例如表单组件&#xff0c;布局组件&#xff0c;表格组件.......是一套桌面端组件。 Element - 网站…

JAVA SpringBoot 项目 多线程、线程池的使用。

1.1 线程&#xff1a; 线程就是进程中的单个顺序控制流&#xff0c;也可以理解成是一条执行路径 单线程&#xff1a;一个进程中包含一个顺序控制流&#xff08;一条执行路径&#xff09; 多线程&#xff1a;一个进程中包含多个顺序控制流&#xff08;多条执行路径&#xff0…

斗轮机无线控制系统技改方案

一、应用背景 马钢的前身是成立于1953年的马鞍山铁厂&#xff0c;2019年马钢集团正式成为中国宝武控股子公司。马钢产品以建筑用型线材为主&#xff0c;满足重型工业厂房、轻钢结构、高层建筑、桥梁结构、工业管道等构件的加工需要。目前马钢在岗员工4.8万人&#xff0c;具备了…

采用模块化方式编译

一、前言 比如&#xff1a;uImage下有很多驱动文件&#xff0c;但是驱动开发时&#xff0c;要频繁更改驱动文件&#xff0c;如果每次编译整个uImage编译会浪费时间&#xff0c;所以引入模块化方式编译&#xff0c;把驱动设置为模块化编译&#xff0c;这样每次更改或重新编译时…