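For a binary classification task:
- True positives (TP): the true label is 1 and the prediction is 1, i.e. the number of positive samples correctly predicted as positive.
- True negatives (TN): the true label is 0 and the prediction is 0, i.e. the number of negative samples correctly predicted as negative.
- False positives (FP): the true label is 0 and the prediction is 1, i.e. the number of negative samples wrongly predicted as positive (a Type I error in statistics).
- False negatives (FN): the true label is 1 and the prediction is 0, i.e. the number of positive samples wrongly predicted as negative (a Type II error in statistics).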
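As a quick check of the four definitions, the counts can be computed by hand with NumPy boolean masks. This is only a minimal sketch: the labels below are made-up illustration data, not taken from any real model.

```python
import numpy as np

# Hypothetical labels for illustration: 1 = positive, 0 = negative.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_pred = np.array([1, 0, 0, 1, 1, 0, 1, 0])

# Each count is the number of samples matching one (true, predicted) pair.
tp = int(np.sum((y_true == 1) & (y_pred == 1)))  # true 1, predicted 1
tn = int(np.sum((y_true == 0) & (y_pred == 0)))  # true 0, predicted 0
fp = int(np.sum((y_true == 0) & (y_pred == 1)))  # true 0, predicted 1
fn = int(np.sum((y_true == 1) & (y_pred == 0)))  # true 1, predicted 0

print("TP, TN, FP, FN =", tp, tn, fp, fn)
```

Every sample falls into exactly one of the four cells, so the counts always sum to the number of samples.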
1. Computing the confusion matrix
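# First install the scikit-learn and scikit-image packages
# Import the required functions
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.metrics import precision_recall_curve  # computes the PR curve
from sklearn.metrics import confusion_matrix  # computes the confusion matrix
# Compute the confusion matrix
a = confusion_matrix(y_true, y_pred)
# The result a is the confusion matrix, returned as a NumPy array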
Example:
>>> y_true = [2, 0, 2, 2, 0, 1]
>>> y_pred = [0, 0, 2, 2, 0, 2]
>>> confusion_matrix(y_true, y_pred)
array([[2, 0, 0],
       [0, 0, 1],
       [1, 0, 2]])
>>> y_true = ["cat", "ant", "cat", "cat", "ant", "bird"]
>>> y_pred = ["ant", "ant", "cat", "cat", "ant", "cat"]
>>> confusion_matrix(y_true, y_pred, labels=["ant", "bird", "cat"])
array([[2, 0, 0],
       [0, 0, 1],
       [1, 0, 2]])
# Rows correspond to the true labels and columns to the predicted labels
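To make the row/column convention concrete, here is a small sketch that reads individual cells and sums of the matrix from the first example. The variable names are chosen only for illustration.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

y_true = [2, 0, 2, 2, 0, 1]
y_pred = [0, 0, 2, 2, 0, 2]
cm = confusion_matrix(y_true, y_pred)

# cm[i, j] counts samples whose true class is i and predicted class is j.
true_2_pred_0 = cm[2, 0]

# Row sums give the number of samples per true class,
# column sums give the number of samples per predicted class.
per_class_true = cm.sum(axis=1)
per_class_pred = cm.sum(axis=0)

# The diagonal holds the correctly classified samples.
correct = np.trace(cm)

print(true_2_pred_0, per_class_true, per_class_pred, correct)
```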
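For binary classification:
>>> tn, fp, fn, tp = confusion_matrix([0, 1, 0, 1], [1, 1, 1, 0]).ravel()
# Note: y_true and y_pred must contain discrete class labels (integers here), not continuous scores; pass them as one-dimensional lists or one-dimensional arrays
import numpy as np
y_pred = np.asarray(y_pred)  # converts a list to an array if needed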
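The four counts returned by `ravel()` are enough to derive the usual binary metrics. The sketch below reuses the binary example above; the guards against division by zero are an added assumption for degenerate inputs, not something the original code includes.

```python
from sklearn.metrics import confusion_matrix

y_true = [0, 1, 0, 1]
y_pred = [1, 1, 1, 0]

# For labels {0, 1}, ravel() flattens the 2x2 matrix in the order tn, fp, fn, tp.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

# Precision: fraction of predicted positives that are truly positive.
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
# Recall: fraction of true positives that were found.
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
# F1: harmonic mean of precision and recall.
f1 = (2 * precision * recall / (precision + recall)
      if (precision + recall) > 0 else 0.0)

print(precision, recall, f1)
```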
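# Taking anomaly detection as an example
# If the model outputs scores (e.g. anomaly scores) instead of labels, they must first be converted to integer labels
# First find the best threshold: scores above it are labelled anomalous (1), the rest normal (0)
# Find the best threshold
precision, recall, thresholds = precision_recall_curve(y_true, scores)
a = 2 * precision * recall
b = precision + recall
f1 = np.divide(a, b, out=np.zeros_like(a), where=b != 0)
threshold = thresholds[np.argmax(f1)]
# Here scores are the anomaly scores output by the network. precision_recall_curve returns three one-dimensional arrays: precision and recall at each threshold, and the thresholds themselves (precision and recall carry one extra trailing element, precision 1 and recall 0, that has no threshold). It counts a sample as positive when score >= threshold. The F1 score is computed for every threshold, and the threshold with the highest F1 is taken as the best one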
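The threshold search can be tried end to end on a tiny synthetic example. In this sketch the labels and scores are made-up illustration data, and slicing `f1[:-1]` is a defensive choice (an assumption added here) so the index always lines up with `thresholds`.

```python
import numpy as np
from sklearn.metrics import precision_recall_curve

# Hypothetical toy data: 1 = anomalous, 0 = normal.
y_true = np.array([0, 0, 0, 1, 0, 1, 1, 1])
scores = np.array([0.10, 0.20, 0.35, 0.40, 0.55, 0.60, 0.80, 0.90])

precision, recall, thresholds = precision_recall_curve(y_true, scores)

# F1 at every point of the curve; where precision + recall == 0, F1 is set to 0.
numerator = 2 * precision * recall
denominator = precision + recall
f1 = np.divide(numerator, denominator,
               out=np.zeros_like(numerator), where=denominator != 0)

# precision/recall have one more element than thresholds, so drop the last point
# before taking the argmax.
best = int(np.argmax(f1[:-1]))
threshold = thresholds[best]

print("best threshold:", threshold, "best F1:", f1[best])
```

Selecting the threshold on the same data that is later used for evaluation gives an optimistic estimate, so a held-out validation split is the safer place to run this search.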
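# Once the best threshold is known, use it to turn scores into y_pred (all integers)
# >= is used because precision_recall_curve counts a score equal to the threshold as positive
# Method 1:
y_pred = np.where(np.asarray(scores) >= threshold, 1, 0)  # 1 at or above the threshold, 0 below it
# Method 2:
mask = np.asarray(scores, dtype=float).copy()
is_anomalous = mask >= threshold  # compare once, before overwriting; a second comparison on already-written 1s breaks when threshold >= 1
mask[is_anomalous] = 1
mask[~is_anomalous] = 0
y_pred = mask.astype(int)  # at or above the threshold becomes 1, below it becomes 0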
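Once scores are binarized, the result can be passed straight to `confusion_matrix`. The following is a minimal sketch on made-up data; the threshold value 0.40 is assumed for illustration and would normally come from the F1 search described above. The comparison uses `>=` because `precision_recall_curve` treats a score equal to the threshold as positive.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical toy data: 1 = anomalous, 0 = normal.
y_true = np.array([0, 0, 0, 1, 0, 1, 1, 1])
scores = np.array([0.10, 0.20, 0.35, 0.40, 0.55, 0.60, 0.80, 0.90])
threshold = 0.40  # assumed value for this sketch

# Binarize in one step: True/False cast to 1/0.
y_pred = (scores >= threshold).astype(int)

# labels=[0, 1] keeps the matrix 2x2 even if one class is missing from the data.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

print(cm)
print("TN, FP, FN, TP =", tn, fp, fn, tp)
```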
For confusion-matrix computation and result visualization, see the code of the following papers:
Paper 1: PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization
https://arxiv.org/pdf/2011.08785v1.pdf
Code: https://github.com/xiahaifeng1995/PaDiM-Anomaly-Detection-Localization-master
Paper 2: SimpleNet: A Simple Network for Image Anomaly Detection and Localization
https://arxiv.org/pdf/2303.15140v2.pdf
Code: https://github.com/donaldrr/simplenet