第100+23步 ChatGPT学习:概率校准 Sigmoid Calibration

news2025/1/13 13:25:35

基于Python 3.9版本演示

一、写在前面

最近看了一篇在Lancet子刊《eClinicalMedicine》上发表的机器学习分类的文章:《Development of a novel dementia risk prediction model in the general population: A large, longitudinal, population-based machine-learning study》。

学到一种叫做“概率校准”的骚操作,顺手利用GPT系统学习学习。

文章中用的技术是:保序回归(Isotonic regression)。

为了体现举一反三,顺便问了GPT还有哪些方法也可以实现概率校准。它给我列举了很多,那么就一个一个学习吧。

这一期,介绍一个叫做 Sigmoid Calibration 的方法。

二、Sigmoid Calibration

Sigmoid Calibration是一种后处理技术,用于改进机器学习分类器的概率估计。它通常应用于二元分类器的输出,将原始得分转换为校准后的概率。该技术使用逻辑(Sigmoid)函数将分类器的得分映射到概率上,旨在确保预测的概率更准确地反映真实结果的可能性。

(1)Sigmoid Calibration 的基本步骤

1)训练分类器:在训练数据上训练你的二元分类器。

2)获取原始得分:收集分类器在验证数据集上的原始得分或 logits。

3)拟合逻辑回归模型:使用验证数据集拟合一个逻辑回归模型,将原始得分映射为概率。

4)预测校准后的概率:使用拟合的逻辑回归模型,将分类器的原始得分转换为校准后的概率。

(2)Sigmoid Calibration 的使用

对于逻辑回归模型,通常不需要进行Sigmoid校准,因为逻辑回归本身就是基于Sigmoid函数来计算概率的。然而,在一些情况下,即使是逻辑回归模型,校准仍然可能有帮助。以下是一些可能需要校准的情况:

1)类不平衡问题:如果训练数据集中存在严重的类别不平衡问题,即某个类别的数据明显多于其他类别,逻辑回归模型的概率估计可能会偏向于较多的类别。在这种情况下,校准可以帮助调整概率估计,使其更准确地反映实际的类别分布。

2)模型训练不充分:如果逻辑回归模型没有充分训练,可能会导致概率估计不准确。校准可以在一定程度上纠正这种情况。

3)训练和测试数据分布不同:如果训练数据和测试数据的分布存在差异,逻辑回归模型的概率估计可能不适用于测试数据。在这种情况下,可以使用校准技术对模型的输出进行调整。

4)多模型集成:在多模型集成(例如集成学习)中,不同模型的输出需要组合在一起。校准可以确保不同模型的输出概率具有一致性,从而提高集成模型的性能。

三、Sigmoid Calibration代码实现

下面,我编一个1比3的不太平衡的数据进行测试,对照组使用不进行校准的SVM模型,实验组就是加入校准的SVM模型,看看性能能够提高多少?

(1)不进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve

# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values

# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)

# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

# 使用SVM分类器
classifier = SVC(kernel='linear', probability=True)
classifier.fit(X_train, y_train)

# 预测结果
y_pred = classifier.predict(X_test)
y_testprba = classifier.decision_function(X_test)

y_trainpred = classifier.predict(X_train)
y_trainprba = classifier.decision_function(X_train)

# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)

# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):
    for second_index in range(len(cm_test[first_index])):
        plt.text(first_index, second_index, cm_test[first_index][second_index])

plt.show()

# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):
    for second_index in range(len(cm_train[first_index])):
        plt.text(first_index, second_index, cm_train[first_index][second_index])

plt.show()

# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):
    a = cm[0, 0]
    b = cm[0, 1]
    c = cm[1, 0]
    d = cm[1, 1]
    acc = (a + d) / (a + b + c + d)
    error_rate = 1 - acc
    sen = d / (d + c)
    sep = a / (a + b)
    precision = d / (b + d)
    F1 = (2 * precision * sen) / (precision + sen)
    MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))
    auc_score = roc_auc_score(y_true, y_pred_prob)
    
    metrics = {
        "Accuracy": acc,
        "Error Rate": error_rate,
        "Sensitivity": sen,
        "Specificity": sep,
        "Precision": precision,
        "F1 Score": F1,
        "MCC": MCC,
        "AUC": auc_score
    }
    return metrics

metrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)

print("Performance Metrics (Test):")
for key, value in metrics_test.items():
    print(f"{key}: {value:.4f}")

print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():
print(f"{key}: {value:.4f}")

结果输出:

记住这些个数字。

这个参数的SVM还没有LR好。

(2)进行校准的SVM模型(默认参数)

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import confusion_matrix, roc_auc_score, roc_curve
from sklearn.calibration import CalibratedClassifierCV

# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values

# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.30, random_state=666)

# 标准化数据
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

# 使用SVM分类器
classifier = SVC(kernel='rbf', C= 0.1, probability=True)
classifier.fit(X_train, y_train)

# 进行Sigmoid校准
calibrated_classifier = CalibratedClassifierCV(base_estimator=classifier, method='sigmoid', cv='prefit')
calibrated_classifier.fit(X_test, y_test)

# 预测结果
y_pred = calibrated_classifier.predict(X_test)
y_testprba = calibrated_classifier.predict_proba(X_test)[:, 1]

y_trainpred = calibrated_classifier.predict(X_train)
y_trainprba = calibrated_classifier.predict_proba(X_train)[:, 1]

# 混淆矩阵
cm_test = confusion_matrix(y_test, y_pred)
cm_train = confusion_matrix(y_train, y_trainpred)
print(cm_train)
print(cm_test)

# 绘制测试集混淆矩阵
classes = list(set(y_test))
classes.sort()
plt.imshow(cm_test, cmap=plt.cm.Blues)
indices = range(len(cm_test))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_test)):
    for second_index in range(len(cm_test[first_index])):
        plt.text(first_index, second_index, cm_test[first_index][second_index])

plt.show()

# 绘制训练集混淆矩阵
classes = list(set(y_train))
classes.sort()
plt.imshow(cm_train, cmap=plt.cm.Blues)
indices = range(len(cm_train))
plt.xticks(indices, classes)
plt.yticks(indices, classes)
plt.colorbar()
plt.xlabel('Predicted')
plt.ylabel('Actual')
for first_index in range(len(cm_train)):
    for second_index in range(len(cm_train[first_index])):
        plt.text(first_index, second_index, cm_train[first_index][second_index])

plt.show()

# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):
    a = cm[0, 0]
    b = cm[0, 1]
    c = cm[1, 0]
    d = cm[1, 1]
    acc = (a + d) / (a + b + c + d)
    error_rate = 1 - acc
    sen = d / (d + c)
    sep = a / (a + b)
    precision = d / (b + d)
    F1 = (2 * precision * sen) / (precision + sen)
    MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))
    auc_score = roc_auc_score(y_true, y_pred_prob)
    
    metrics = {
        "Accuracy": acc,
        "Error Rate": error_rate,
        "Sensitivity": sen,
        "Specificity": sep,
        "Precision": precision,
        "F1 Score": F1,
        "MCC": MCC,
        "AUC": auc_score
    }
    return metrics

metrics_test = calculate_metrics(cm_test, y_test, y_testprba)
metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)

print("Performance Metrics (Test):")
for key, value in metrics_test.items():
    print(f"{key}: {value:.4f}")

print("\nPerformance Metrics (Train):")
for key, value in metrics_train.items():
    print(f"{key}: {value:.4f}")

看看结果:

总体来看,仅仅训练集起作用了,验证集差强人意。

四、换个策略

参考那篇文章的策略:采用五折交叉验证来建立和评估模型,其中四折用于训练,一折用于评估,在训练集中,其中三折用于建立SVM模型,另一折采用Sigmoid Calibration概率校正,在训练集内部采用交叉验证对超参数进行调参。

代码:

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split, KFold, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import confusion_matrix, roc_auc_score, make_scorer

# 加载数据
dataset = pd.read_csv('8PSMjianmo.csv')
X = dataset.iloc[:, 1:20].values
Y = dataset.iloc[:, 0].values

# 标准化数据
sc = StandardScaler()
X = sc.fit_transform(X)

# 五折交叉验证
kf = KFold(n_splits=5, shuffle=True, random_state=666)

# 超参数调优参数网格
param_grid = {
    'C': [0.1, 1, 10, 100],
    'kernel': ['linear', 'rbf']
}

# 计算并打印性能参数
def calculate_metrics(cm, y_true, y_pred_prob):
    a = cm[0, 0]
    b = cm[0, 1]
    c = cm[1, 0]
    d = cm[1, 1]
    acc = (a + d) / (a + b + c + d)
    error_rate = 1 - acc
    sen = d / (d + c)
    sep = a / (a + b)
    precision = d / (b + d)
    F1 = (2 * precision * sen) / (precision + sen)
    MCC = (d * a - b * c) / (np.sqrt((d + b) * (d + c) * (a + b) * (a + c)))
    auc_score = roc_auc_score(y_true, y_pred_prob)
    
    metrics = {
        "Accuracy": acc,
        "Error Rate": error_rate,
        "Sensitivity": sen,
        "Specificity": sep,
        "Precision": precision,
        "F1 Score": F1,
        "MCC": MCC,
        "AUC": auc_score
    }
    return metrics

# 初始化结果列表
results_train = []
results_test = []

# 初始化变量以跟踪最优模型
best_auc = 0
best_model = None
best_X_train = None
best_X_test = None
best_y_train = None
best_y_test = None

# 交叉验证过程
for train_index, test_index in kf.split(X):
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = Y[train_index], Y[test_index]

    # 内部交叉验证进行超参数调优和模型训练
    inner_kf = KFold(n_splits=4, shuffle=True, random_state=666)
    grid_search = GridSearchCV(SVC(probability=True), param_grid, cv=inner_kf, scoring='roc_auc')
    grid_search.fit(X_train, y_train)
    model = grid_search.best_estimator_

    # Sigmoid Calibration 概率校准
    calibrated_svm = CalibratedClassifierCV(model, method='sigmoid', cv='prefit')
    calibrated_svm.fit(X_train, y_train)

    # 评估模型
    y_trainpred = calibrated_svm.predict(X_train)
    y_trainprba = calibrated_svm.predict_proba(X_train)[:, 1]
    cm_train = confusion_matrix(y_train, y_trainpred)
    metrics_train = calculate_metrics(cm_train, y_train, y_trainprba)
    results_train.append(metrics_train)
    
    y_pred = calibrated_svm.predict(X_test)
    y_testprba = calibrated_svm.predict_proba(X_test)[:, 1]
    cm_test = confusion_matrix(y_test, y_pred)
    metrics_test = calculate_metrics(cm_test, y_test, y_testprba)
    results_test.append(metrics_test)
    
    # 更新最优模型
    if metrics_test['AUC'] > best_auc:
        best_auc = metrics_test['AUC']
        best_model = calibrated_svm
        best_X_train = X_train
        best_X_test = X_test
        best_y_train = y_train
        best_y_test = y_test
        best_params = grid_search.best_params_

    print("Performance Metrics (Train):")
    for key, value in metrics_train.items():
        print(f"{key}: {value:.4f}")
    
    print("\nPerformance Metrics (Test):")
    for key, value in metrics_test.items():
        print(f"{key}: {value:.4f}")
    print("\n" + "="*40 + "\n")

# 使用最优模型评估性能
y_trainpred = best_model.predict(best_X_train)
y_trainprba = best_model.predict_proba(best_X_train)[:, 1]
cm_train = confusion_matrix(best_y_train, y_trainpred)
metrics_train = calculate_metrics(cm_train, best_y_train, y_trainprba)

y_pred = best_model.predict(best_X_test)
y_testprba = best_model.predict_proba(best_X_test)[:, 1]
cm_test = confusion_matrix(best_y_test, y_pred)
metrics_test = calculate_metrics(cm_test, best_y_test, y_testprba)

print("Performance Metrics of the Best Model (Train):")
for key, value in metrics_train.items():
    print(f"{key}: {value:.4f}")

print("\nPerformance Metrics of the Best Model (Test):")
for key, value in metrics_test.items():
    print(f"{key}: {value:.4f}")

# 打印最优模型的参数
print("\nBest Model Parameters:")
for key, value in best_params.items():
    print(f"{key}: {value}")

输出:

还是有提升的,不过并没有Platt Scaling的结果好。

五、最后

各位可以去试一试在其他数据或者在其他机器学习分类模型中使用的效果。

数据不分享啦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2084440.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java中超级重要的SPI机制

SPI机制是理解各个框架集成的重要思想,只用理解了SPI才能理解框架的集成和扩展。直观的就是SPI机制可以让你更好的理解springboot集成各个扩展。 废话不多说!! 1.什么是spi机制? Spi机制是java提供的一种服务发现机制&#xff0…

企业常用的10款主流图纸加密软件排行榜|企业图纸防泄密

图纸是企业的重要资产,其安全性直接关系到企业的核心竞争力。下面为大家介绍10款主流的图纸加密软件,每款软件都有其独特的功能和优势,帮助企业有效防止图纸泄密。 1. 安秉图纸加密软件 安秉图纸加密软件是一款专为企业用户设计的图纸加密解…

全球石油巨头哈里伯顿因网络攻击被迫关闭系统

美国能源服务巨头哈里伯顿公司在周五向联邦监管机构提交的文件中表示,在本周遭受网络攻击后,该公司主动关闭了某些系统以“帮助保护它们”。 该公司在提交给美国证券交易委员会的文件中表示,周三该公司获悉“未经授权的第三方获得了其系统某些部分的访问权限”,并正在与外…

如何开启让设备获取到IPv6?

前言 现在许多小伙伴拉的宽带基本上都是光猫进行拨号的。这个就导致很多小伙伴不知道如何让设备获取IPv6。 但好像还有小伙伴分不清光猫拨号和光猫桥接的区别,其实它们的区别就在于让设备直连光猫的网口或者光猫的Wi-Fi,就会出现两种情况: …

实战派六西格玛:培训只是热身,应用才是关键!

在当今竞争激烈的市场环境中,六西格玛作为一套卓越的质量管理策略与工具,其影响力已远远超越了单纯的制造领域,渗透至各行各业的运营管理之中。然而,许多企业在追逐六西格玛光环的过程中,却常常偏离了其核心轨道&#…

中国全球投资追踪相关数据(2005-2023年)

中国全球投资追踪的相关数据可以为了解中国在全球范围内的投资活动提供重要视角。根据美国企业研究所(American Enterprise Institute,AEI)编制的《中国全球投资追踪》数据库,该数据库详细追踪了2005年至2023年间中国的海外直接投…

wx.updateAppMessageShareData 自定义分享内容安卓无效

记录wx.updateAppMessageShareData 自定义分享内容安卓无效 bug,主要是因为微信公众平台要配置分享的链接域名。 情况微信公众号使用wxjsdk后分享api全都注入成功,自定义分享内容时ios正常安卓分享出去的是当前页面的url。 主要原因在于你自定义这个分享…

在linux 中如何将.c 文件转换为可执行文件

目录 一、引言 二、准备工作 三、编译单个.c 文件 1.预处理 2.编译 3.汇编 4.链接 四、编译多个.c 文件 五、调试和优化 六、总结 一、引言 在 Linux 环境下进行 C 语言编程时,将 .c 文件转换为可执行文件是一个关键的步骤。这个过程涉及到使用编译器和一…

携程:从MySQL迁移OceanBase的数据库发布系统实践

作者简介:杨晓军 现就职于携程的数据库团队,主要负责携程数据库的研发与管理,专注于提升数据库的稳定性。 自分布式关系型数据库OceanBase开源以来,携程已经在线上环境中进行了广泛的应用,取代了原先以MySQL为主力的业…

虚幻5|技能栏优化(1)---优化技能UI,并添加多个技能

一.添加多一个技能格子并进行初始化清楚 1.打开技能UI把原先的事件构造后面的蓝图,全部选中,右键创建一个函数,命名为初始化 2.添加以下两个蓝图,用于清楚技能格子内容 2.在之前,事件构造后面的蓝图,不需…

人工智能是如何预测足球比赛的?看完这篇文章,你就全懂了!AutoPrediction

2024年欧洲杯开赛至今,德叔已经用人工智能预测了小组赛阶段的36场比赛,以及淘汰赛阶段的8场比赛,并且通过在网络上的发文,记录了所有这些比赛的预测结果。这些文章引来了不少朋友的围观,也让很多人对人工智能预测球赛这…

收藏夹里的“小网站”被误报违规不让上怎么办?如何将Chrome和Edge安装到 D 盘(含用户数据),重装系统也不会丢失收藏夹和密码?

当你用国产浏览器访问网站的时候,有时候会显示这个: 如果确实是违规网站,不让访问也没什么,但是很多都是误报啊,你这样直接来个大红横幅,还让人活不? 那遇到这种误报应当怎么办呢?有…

echarts 中 鼠标悬浮上加单位

tooltip: {trigger: "axis",valueFormatter: function (value: any) {return value "℃";},}, 效果:

开放式耳机的优缺点?开放式耳机王者带你一探究竟

盛夏时节,天气越来越热,小伙伴们都在抱怨,实在没法戴口罩了。实际上,大家只关注了呼吸,却忽视了一个问题,其实,我们的耳朵也是要“呼吸”的,闷热的天气里,长时间佩戴入耳…

S2P销讯通-主数据对于客户关系管理系统的重要性

由于业务发展,各大企业的业务系统经历了从无到有,从简单到复杂,从而形成了一个又一个的业务系统,比如OA、HR、CRM、ERP等等。 主数据在客户关系管理系统(CRM)中扮演着至关重要的角色。主数据是指那些描述企…

SD操作手册

1、创建条件记录 2、创建标准订单/寄售补货订单/寄售发货订单/借贷项凭证/退货订单 3、创建外向交货单 4、开票 创建条件记录-VK11 PR00代表净价 输入销售组织、分销渠道、客户、物料 创建销售订单-VA01 2.1创建标准SO 维护完表头的信息后,双击行项目&#xff…

Python发送多人邮件如何实现高效群发功能?

Python发送多人邮件的教程?怎么使用Python群发邮件? Python,作为一种强大的编程语言,提供了多种库和工具,使得群发邮件变得既简单又高效。AokSend将详细介绍如何使用Python发送多人邮件,并探讨一些提高群发…

wpf prism 《3》 弹窗 IOC

传统的弹窗 这种耦合度高 new 窗体() . Show(); new 窗体() . ShowDialog(); 利用Prism 自动的 IOC 弹窗的 必须 必须 必须 页面控件 弹窗的 必须 必须 必须 页面控件 弹窗的 必须 必须 必须 页面控件 弹窗的 必须 必须 必须 页面控件 弹窗的 必须 必须 必须 页面控件 》》否…

sql-labs41-45关通关攻略

第41关 一.查询数据库 http://127.0.0.1/Less-41/?id-1%20union%20select%201,2,database()--http://127.0.0.1/Less-41/?id-1%20union%20select%201,2,database()-- 二.查表 http://127.0.0.1/Less-41/?id-1%20union%20select%201,2,(select%20group_concat(table_name)…

【Java】—— Java面向对象基础:Java中类的构造器与属性初始化,Student类的实例

目录 定义Student类 在main方法中创建Student对象 结论 在Java中,类的构造器(Constructor)是一个特殊的方法,用于在创建对象时初始化对象的属性。今天,我们将通过一个简单的Student类实例,来探讨如何在J…