医学图像分割任务的测试代码

news2025/2/3 12:10:55

测试集进行测试

import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import (
    precision_score,
    recall_score,
    f1_score,
    roc_curve,
    auc,
    confusion_matrix,
)
import matplotlib.pyplot as plt
from utils import NiiDataset
from model.UNet import UNet

# 加载最佳模型
best_unet_model = r"D:\PytnonProject\Segment\best_unet_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load(best_unet_model))
model.eval()  # 设置为评估模式

# 定义测试数据集
test_image_paths = [
    r"D:\Data\DegmentData\OriginalNii\DCE\CAO_ZHAN_GUO.nii",  # 替换为实际的测试图像路径
    r"D:\Data\DegmentData\OriginalNii\DCE\CHAI_GUI_LAN.nii",
    # 添加其他测试数据路径...
]

test_mask_paths = [
    r"D:\Data\DegmentData\ROI\CAO_ZHAN_GUO-label.nii",  # 替换为实际的测试掩码路径
    r"D:\Data\DegmentData\ROI\CHAI_GUI_LAN-label.nii",
    # 添加其他测试掩码路径...
]

# 创建测试数据集和数据加载器
test_dataset = NiiDataset(test_image_paths, test_mask_paths)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)  # 批量大小为 1

# 定义评估指标
def dice_coefficient(y_true, y_pred):
    intersection = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    if denominator == 0:
        return 1.0  # 如果分母为零,返回 1(表示完全匹配)
    return (2.0 * intersection) / denominator

def iou(y_true, y_pred):
    intersection = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - intersection
    if union == 0:
        return 1.0  # 如果分母为零,返回 1(表示完全匹配)
    return intersection / union

# 初始化指标
dice_scores = []
iou_scores = []
precisions = []
recalls = []
f1_scores = []
sensitivities = []
specificities = []
auc_scores = []

# 用于 ROC 曲线的数据
all_masks = []
all_predictions = []
all_probabilities = []

# 测试过程
with torch.no_grad():  # 禁用梯度计算
    for images, masks in test_dataloader:
        images = images.to(device)  # 形状: [1, 1, 480, 480]
        masks = masks.to(device)    # 形状: [1, 1, 480, 480]

        # 前向传播
        outputs = model(images)  # 形状: [1, 1, 480, 480]

        # 将输出转换为概率和二进制掩码
        probabilities = outputs.cpu().numpy().flatten()  # 形状: [480 * 480]
        predictions = (outputs > 0.5).float().cpu().numpy().flatten()  # 形状: [480 * 480]
        masks = masks.cpu().numpy().flatten()  # 形状: [480 * 480]

        # 保存用于 ROC 曲线的数据
        all_masks.extend(masks)
        all_predictions.extend(predictions)
        all_probabilities.extend(probabilities)

        # 检查 masks 和 predictions 是否只包含一个类别
        if np.all(masks == 0) and np.all(predictions == 0):
            # 如果 masks 和 predictions 都为全 0,则跳过该样本
            continue

        # 计算指标
        dice = dice_coefficient(masks, predictions)
        iou_score = iou(masks, predictions)
        precision = precision_score(masks, predictions, zero_division=0)
        recall = recall_score(masks, predictions, zero_division=0)
        f1 = f1_score(masks, predictions, zero_division=0)

        # 计算混淆矩阵
        cm = confusion_matrix(masks, predictions)
        if cm.size == 1:
            # 如果混淆矩阵只有一个值(全 0 或全 1)
            if np.all(masks == 0):
                tn, fp, fn, tp = cm[0, 0], 0, 0, 0
            else:
                tn, fp, fn, tp = 0, 0, 0, cm[0, 0]
        else:
            tn, fp, fn, tp = cm.ravel()

        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # 灵敏度
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # 特异度

        # 保存指标
        dice_scores.append(dice)
        iou_scores.append(iou_score)
        precisions.append(precision)
        recalls.append(recall)
        f1_scores.append(f1)
        sensitivities.append(sensitivity)
        specificities.append(specificity)

# 计算 AUC
fpr, tpr, thresholds = roc_curve(all_masks, all_probabilities)
roc_auc = auc(fpr, tpr)
auc_scores.append(roc_auc)

# 打印平均指标
print(f"Average Dice Coefficient: {np.mean(dice_scores):.4f}")
print(f"Average IoU: {np.mean(iou_scores):.4f}")
print(f"Average Precision: {np.mean(precisions):.4f}")
print(f"Average Recall: {np.mean(recalls):.4f}")
print(f"Average F1 Score: {np.mean(f1_scores):.4f}")
print(f"Average Sensitivity: {np.mean(sensitivities):.4f}")
print(f"Average Specificity: {np.mean(specificities):.4f}")
print(f"Average AUC: {np.mean(auc_scores):.4f}")

# 创建 checkpoint 文件夹(如果不存在)
checkpoint_dir = "checkpoint"
os.makedirs(checkpoint_dir, exist_ok=True)

# 保存 ROC 曲线的数据到 checkpoint 文件夹下的 .npz 文件
roc_data_path = os.path.join(checkpoint_dir, "roc_data.npz")
np.savez(roc_data_path, fpr=fpr, tpr=tpr, thresholds=thresholds, roc_auc=roc_auc)
print(f"ROC 曲线的数据已保存到文件: {roc_data_path}")

# 绘制 ROC 曲线
plt.figure()
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("Receiver Operating Characteristic (ROC) Curve")
plt.legend(loc="lower right")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2291258.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

为AI聊天工具添加一个知识系统 之83 详细设计之24 度量空间之1 因果关系和过程:认知金字塔

本文要点 度量空间 在本项目(为AI聊天工具添加一个知识系统 )中 是出于对“用”的考量 来考虑的。这包括: 相对-位置 力用(“相”)。正如 法力,相关-速度 体用 (“体”)。例如 重…

如何配置Java JDK

步骤1:点击资源,点击Java下载 https://www.oracle.com/ 步骤2:点击java下载、JDK23下载,下载第一行第一个 步骤3:解压到一个空文件夹下,复制lib地址 步骤4:在设置里面搜索“高级系统设置”;点击…

CodeGPT使用本地部署DeepSeek Coder

目前NV和github都托管了DeepSeek,生成Key后可以很方便的用CodeGPT接入。CodeGPT有三种方式使用AI,分别时Agents,Local LLMs(本地部署AI大模型),LLMs Cloud Model(云端大模型,从你自己…

JAVA安全—反射机制攻击链类对象成员变量方法构造方法

前言 还是JAVA安全,哎,真的讲不完,太多啦。 今天主要是讲一下JAVA中的反射机制,因为反序列化的利用基本都是要用到这个反射机制,还有一些攻击链条的构造,也会用到,所以就讲一下。 什么是反射…

【深度学习】softmax回归的简洁实现

softmax回归的简洁实现 我们发现(通过深度学习框架的高级API能够使实现)(softmax)线性(回归变得更加容易)。 同样,通过深度学习框架的高级API也能更方便地实现softmax回归模型。 本节继续使用Fashion-MNIST数据集,并保持批量大小为256。 import torch …

基础篇03-图像的基本运算

本节将简要介绍Halcon中有关图像的两类基本运算,分别是代数运算和逻辑运算。除此之外,还介绍几种特殊的代数运算。 目录 1.引言 2. 基本运算 2.1 加法运算 2.2 减法运算 2.3 乘法运算 2.4 除法运算 2.5 综合实例 3. 逻辑运算 3.1 逻辑与运算 …

工具的应用——安装copilot

一、介绍Copilot copilot是一个AI辅助编程的助手,作为需要拥抱AI的程序员可以从此尝试进入,至于好与不好,应当是小马过河,各有各的心得。这里不做评述。重点在安装copilot的过程中遇到了一些问题,然后把它总结下&…

Alibaba开发规范_编程规约之命名风格

文章目录 命名风格的基本原则1. 命名不能以下划线或美元符号开始或结束2. 严禁使用拼音与英文混合或直接使用中文3. 类名使用 UpperCamelCase 风格,但以下情形例外:DO / BO / DTO / VO / AO / PO / UID 等4. 方法名、参数名、成员变量、局部变量使用 low…

MATLAB中的IIR滤波器设计

在数字信号处理中,滤波器是消除噪声、提取特征或调整信号频率的核心工具。其中,无限脉冲响应(IIR)滤波器因其低阶数实现陡峭滚降的特性,被广泛应用于音频处理、通信系统和生物医学工程等领域。借助MATLAB强大的工具箱&…

vector容器(详解)

本文最后是模拟实现全部讲解,文章穿插有彩色字体,是我总结的技巧和关键 1.vector的介绍及使用 1.1 vector的介绍 https://cplusplus.com/reference/vector/vector/(vector的介绍) 了解 1. vector是表示可变大小数组的序列容器。…

python学opencv|读取图像(五十二)使用cv.matchTemplate()函数实现最佳图像匹配

【1】引言 前序学习了图像的常规读取和基本按位操作技巧,相关文章包括且不限于: python学opencv|读取图像-CSDN博客 python学opencv|读取图像(四十九)原理探究:使用cv2.bitwise()系列函数实现图像按位运算-CSDN博客…

【VUE案例练习】前端vue2+element-ui,后端nodo+express实现‘‘文件上传/删除‘‘功能

近期在做跟毕业设计相关的数据后台管理系统,其中的列表项展示有图片展示,添加/编辑功能有文件上传。 “文件上传/删除”也是我们平时开发会遇到的一个功能,这里分享个人的实现过程,与大家交流谈论~ 一、准备工作 本次案例使用的…

使用真实 Elasticsearch 进行高级集成测试

作者:来自 Elastic Piotr Przybyl 掌握高级 Elasticsearch 集成测试:更快、更智能、更优化。 在上一篇关于集成测试的文章中,我们介绍了如何通过改变数据初始化策略来缩短依赖于真实 Elasticsearch 的集成测试的执行时间。在本期中&#xff0…

【R语言】函数

一、函数格式 如下所示: hello:函数名;function:定义的R对象是函数而不是其它变量;():函数的输入参数,可以为空,也可以包含参数;{}:函数体,如果…

VSCode插件Live Server

简介:插件Live Server能够实现当我们在VSCode编辑器里修改 HTML、CSS 或者 JavaScript 文件时,它都能自动实时地刷新浏览器页面,让我们实时看到代码变化的效果。再也不用手动刷新浏览器了,节省了大量的开发过程耗时! 1…

50. 正点原子官方系统镜像烧写实验

一、Windows下使用OTG烧写系统 1、在Windos使用NXP提供的mfgtool来向开发烧写系统。需要用先将开发板的USB_OTG接口连接到电脑上。 Mfgtool工具是向板子先下载一个Linux系统,然后通过这个系统来完成烧写工作。 切记!使用OTG烧写的时候要先把SD卡拔出来&…

扩散模型(三)

相关阅读: 扩散模型(一) 扩散模型(二) Latent Variable Space 潜在扩散模型(LDM;龙巴赫、布拉特曼等人,2022 年)在潜在空间而非像素空间中运行扩散过程,这…

it基础使用--5---git远程仓库

it基础使用–5—git远程仓库 1. 按顺序看 -git基础使用–1–版本控制的基本概念 -git基础使用–2–gti的基本概念 -git基础使用–3—安装和基本使用 -git基础使用–4—git分支和使用 2. 什么是远程仓库 在第一篇文章中,我们已经讲过了远程仓库,每个本…

Baklib如何改变内容管理平台的未来推动创新与效率提升

内容概要 在信息爆炸的时代,内容管理平台成为了企业和个人不可或缺的工具。它通过高效组织、存储和发布内容,帮助用户有效地管理信息流。随着技术的发展,传统的内容管理平台逐渐暴露出灵活性不足、易用性差等局限性,这促使市场需…

16.[前端开发]Day16-HTML+CSS阶段练习(网易云音乐五)

完整代码 网易云-main-left-rank&#xff08;排行榜&#xff09; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name&q…