# 测试集进行测试 (evaluate the trained model on the test set)
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
from sklearn.metrics import (
precision_score,
recall_score,
f1_score,
roc_curve,
auc,
confusion_matrix,
)
import matplotlib.pyplot as plt
from utils import NiiDataset
from model.UNet import UNet
# Load the best model checkpoint produced during training.
best_unet_model = r"D:\PytnonProject\Segment\best_unet_model.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=1, out_channels=1).to(device)
# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved from a GPU run cannot be loaded on a CPU-only machine, even though
# `device` above already falls back to the CPU.
state_dict = torch.load(best_unet_model, map_location=device)
model.load_state_dict(state_dict)
model.eval()  # inference mode (affects layers such as dropout / batch norm)
# Test-set file lists. NiiDataset presumably pairs them by index (i-th image
# with i-th mask) - TODO confirm against utils.NiiDataset.
# NOTE(review): "DegmentData" looks like a typo for "SegmentData" - verify the
# directory name on disk before changing it.
test_image_paths = [
    r"D:\Data\DegmentData\OriginalNii\DCE\CAO_ZHAN_GUO.nii",  # replace with the actual test image paths
    r"D:\Data\DegmentData\OriginalNii\DCE\CHAI_GUI_LAN.nii",
    # add further test image paths here...
]
test_mask_paths = [
    r"D:\Data\DegmentData\ROI\CAO_ZHAN_GUO-label.nii",  # replace with the actual test mask paths
    r"D:\Data\DegmentData\ROI\CHAI_GUI_LAN-label.nii",
    # add further test mask paths here...
]
# Both lists are maintained by hand, so fail fast if they get out of step
# instead of silently mis-pairing images and masks.
if len(test_image_paths) != len(test_mask_paths):
    raise ValueError(
        f"Got {len(test_image_paths)} test images but {len(test_mask_paths)} "
        "test masks; every image needs exactly one mask"
    )
# Build the test dataset and loader: batch size 1, no shuffling.
test_dataset = NiiDataset(test_image_paths, test_mask_paths)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)
# Evaluation metrics


def dice_coefficient(y_true, y_pred):
    """Return the Dice coefficient 2*sum(y_true*y_pred) / (sum(y_true) + sum(y_pred)).

    Inputs are array-likes of matching shape that numpy can multiply
    element-wise (binary masks in this script). When both inputs sum to zero,
    the masks are treated as a perfect match and 1.0 is returned.
    """
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    # Empty ground truth and empty prediction count as full agreement.
    return 1.0 if total == 0 else 2.0 * overlap / total
def iou(y_true, y_pred):
    """Return the intersection-over-union (Jaccard index) of two masks.

    The intersection is sum(y_true * y_pred) and the union is
    sum(y_true) + sum(y_pred) - intersection. When the union is zero (both
    masks empty), the masks are treated as a perfect match and 1.0 is returned.
    """
    overlap = np.sum(y_true * y_pred)
    union = np.sum(y_true) + np.sum(y_pred) - overlap
    # Empty ground truth and empty prediction count as full agreement.
    return 1.0 if union == 0 else overlap / union
# Per-sample metric accumulators: one entry per evaluated (non-skipped) sample.
dice_scores = []
iou_scores = []
precisions = []
recalls = []
f1_scores = []
sensitivities = []
specificities = []
auc_scores = []
# Pixel data pooled over the whole test set for the ROC curve. Each entry is one
# flattened per-sample array. They are concatenated once after the loop instead
# of being extended pixel by pixel into Python lists, which would create one
# Python object per pixel.
all_masks = []
all_predictions = []
all_probabilities = []
# Cut-off used to turn the model output into a binary prediction mask.
PREDICTION_THRESHOLD = 0.5
# Evaluation loop
with torch.no_grad():  # gradients are not needed for inference
    for images, masks in test_dataloader:
        # Forward pass. With batch_size=1 the output is presumably [1, 1, H, W]
        # (earlier comments suggested 480x480) - TODO confirm.
        outputs = model(images.to(device))
        # NOTE(review): the raw output is used directly as a probability, which
        # assumes the UNet ends with a sigmoid. If it returns logits, apply
        # torch.sigmoid here; otherwise the 0.5 threshold is meaningless
        # (the AUC is unaffected by a monotonic transform).
        probabilities = outputs.cpu().numpy().flatten()
        predictions = (probabilities > PREDICTION_THRESHOLD).astype(np.float32)
        # Masks are only needed on the CPU, so they are never moved to `device`.
        masks = masks.cpu().numpy().flatten()
        all_masks.append(masks)
        all_predictions.append(predictions)
        all_probabilities.append(probabilities)
        # Skip samples whose ground truth and prediction are both empty. They
        # still contribute to the pooled ROC data collected above.
        if np.all(masks == 0) and np.all(predictions == 0):
            continue
        dice_scores.append(dice_coefficient(masks, predictions))
        iou_scores.append(iou(masks, predictions))
        precisions.append(precision_score(masks, predictions, zero_division=0))
        recalls.append(recall_score(masks, predictions, zero_division=0))
        f1_scores.append(f1_score(masks, predictions, zero_division=0))
        # Pixel-level confusion counts, computed directly so all four are
        # defined even when a sample contains only one class.
        # Assumes binary 0/1 masks (any non-zero value counts as foreground) - TODO confirm.
        gt = masks.astype(bool)
        pred = predictions.astype(bool)
        tp = np.count_nonzero(gt & pred)
        fp = np.count_nonzero(~gt & pred)
        fn = np.count_nonzero(gt & ~pred)
        tn = np.count_nonzero(~gt & ~pred)
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0  # true positive rate
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # true negative rate
        sensitivities.append(sensitivity)
        specificities.append(specificity)
if not all_masks:
    raise RuntimeError("test_dataloader yielded no samples; nothing to evaluate")
all_masks = np.concatenate(all_masks)
all_predictions = np.concatenate(all_predictions)
all_probabilities = np.concatenate(all_probabilities)
# ROC / AUC over every pixel of the test set. This is computed once, after the
# loop, rather than recomputed on the growing pool for every sample.
fpr, tpr, thresholds = roc_curve(all_masks, all_probabilities)
roc_auc = auc(fpr, tpr)
auc_scores.append(roc_auc)
# Report the mean of each collected metric, four decimal places.
metric_summary = (
    ("Dice Coefficient", dice_scores),
    ("IoU", iou_scores),
    ("Precision", precisions),
    ("Recall", recalls),
    ("F1 Score", f1_scores),
    ("Sensitivity", sensitivities),
    ("Specificity", specificities),
    ("AUC", auc_scores),
)
for metric_label, metric_values in metric_summary:
    print(f"Average {metric_label}: {np.mean(metric_values):.4f}")
# Persist the ROC curve data as checkpoint/roc_data.npz, creating the
# "checkpoint" directory (relative to the working directory) if needed.
checkpoint_dir = "checkpoint"
roc_data_path = os.path.join(checkpoint_dir, "roc_data.npz")
os.makedirs(checkpoint_dir, exist_ok=True)
roc_payload = {
    "fpr": fpr,
    "tpr": tpr,
    "thresholds": thresholds,
    "roc_auc": roc_auc,
}
np.savez(roc_data_path, **roc_payload)
print(f"ROC 曲线的数据已保存到文件: {roc_data_path}")
# Plot the ROC curve against the chance diagonal and show it interactively.
roc_fig, roc_ax = plt.subplots()
roc_ax.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})")
roc_ax.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")  # random-guess baseline
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("Receiver Operating Characteristic (ROC) Curve")
roc_ax.legend(loc="lower right")
plt.show()