人工智能|机器学习——强大的 Scikit-learn 可视化让模型说话

news2025/1/11 19:48:46

一、显示 API 简介

使用 utils.discovery.all_displays 查找可用的 API。

Sklearn 的utils.discovery.all_displays可以让你看到哪些类可以使用。

from sklearn.utils.discovery import all_displays
displays = all_displays()
displays

Scikit-learn (sklearn) 总是会在新版本中添加 "Display "API,因此这里可以了解你的版本中有哪些可用的 API 。例如,在我的 Scikit-learn 1.4.0 中,就有这些类:

[('CalibrationDisplay', sklearn.calibration.CalibrationDisplay),
 ('ConfusionMatrixDisplay',
  sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay),
 ('DecisionBoundaryDisplay',
  sklearn.inspection._plot.decision_boundary.DecisionBoundaryDisplay),
 ('DetCurveDisplay', sklearn.metrics._plot.det_curve.DetCurveDisplay),
 ('LearningCurveDisplay', sklearn.model_selection._plot.LearningCurveDisplay),
 ('PartialDependenceDisplay',
  sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay),
 ('PrecisionRecallDisplay',
  sklearn.metrics._plot.precision_recall_curve.PrecisionRecallDisplay),
 ('PredictionErrorDisplay',
  sklearn.metrics._plot.regression.PredictionErrorDisplay),
 ('RocCurveDisplay', sklearn.metrics._plot.roc_curve.RocCurveDisplay),
 ('ValidationCurveDisplay',
  sklearn.model_selection._plot.ValidationCurveDisplay)]

二、显示决策边界

使用 inspection.DecisionBoundaryDisplay 显示决策边界

如果使用 Matplotlib 来绘制,会很麻烦:

  • 使用 np.linspace 设置坐标范围;

  • 使用 plt.meshgrid 计算网格;

  • 使用 plt.contourf 绘制决策边界填充;

  • 然后使用 plt.scatter 绘制数据点。

现在,使用 inspection.DecisionBoundaryDisplay 可以简化这一过程:

from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

iris = load_iris(as_frame=True)
X = iris.data[['petal length (cm)', 'petal width (cm)']]
y = iris.target


svc_clf = make_pipeline(StandardScaler(), 
                        SVC(kernel='linear', C=1))
svc_clf.fit(X, y)

display = DecisionBoundaryDisplay.from_estimator(svc_clf, X, 
                                                 grid_resolution=1000,
                                                 xlabel="Petal length (cm)",
                                                 ylabel="Petal width (cm)")
plt.scatter(X.iloc[:, 0], X.iloc[:, 1], c=y, edgecolors='w')
plt.title("Decision Boundary")
plt.show()

使用 DecisionBoundaryDisplay 绘制三重分类模型。

请记住,Display 只能绘制二维数据,因此请确保数据只有两个特征或更小的维度。

三、概率校准

要比较分类模型,使用 calibration.CalibrationDisplay 进行概率校准,概率校准曲线可以显示模型预测的可信度。

CalibrationDisplay使用的是模型的 predict_proba。如果使用支持向量机,需要将 probability 设为 True:

from sklearn.calibration import CalibrationDisplay
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier

X, y = make_classification(n_samples=1000,
                           n_classes=2, n_features=5,
                           random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                            test_size=0.3, random_state=42)
proba_clf = make_pipeline(StandardScaler(), 
                          SVC(kernel="rbf", gamma="auto", 
                              C=10, probability=True))
proba_clf.fit(X_train, y_train)

CalibrationDisplay.from_estimator(proba_clf, 
                                            X_test, y_test)

hist_clf = HistGradientBoostingClassifier()
hist_clf.fit(X_train, y_train)

ax = plt.gca()
CalibrationDisplay.from_estimator(hist_clf,
                                  X_test, y_test,
                                  ax=ax)
plt.show()

CalibrationDisplay.

四、显示混淆矩阵

在评估分类模型和处理不平衡数据时,需要查看精确度和召回率。使用 metrics.ConfusionMatrixDisplay绘制混淆矩阵(TP、FP、TN 和 FN)。

from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplay

digits = fetch_openml('mnist_784', version=1)
X, y = digits.data, digits.target
rf_clf = RandomForestClassifier(max_depth=5, random_state=42)
rf_clf.fit(X, y)

ConfusionMatrixDisplay.from_estimator(rf_clf, X, y)
plt.show()

五、Roc 和 Det 曲线

因为经常并列评估Roc 和 Det 曲线,因此把metrics.RocCurveDisplay 和 metrics.DetCurveDisplay两个图表放在一起。

  • RocCurveDisplay比较模型的 TPR 和 FPR。对于二分类,希望 FPR 低而 TPR 高,因此左上角是最佳位置。Roc 曲线向这个角弯曲。

由于 Roc 曲线停留在左上角附近,右下角是空的,因此很难看到模型差异。

  • 使用 DetCurveDisplay 绘制一条带有 FNR 和 FPR 的 Det 曲线。它使用了更多空间,比 Roc 曲线更清晰。Det 曲线的最佳点是左下角。

from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import DetCurveDisplay

X, y = make_classification(n_samples=10_000, n_features=5,
                           n_classes=2, n_informative=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                             test_size=0.3, random_state=42,
                                                   stratify=y)


classifiers = {
    "SVC": make_pipeline(StandardScaler(), 
                        SVC(kernel="linear", C=0.1, random_state=42)),
    "Random Forest": RandomForestClassifier(max_depth=5, random_state=42)
}

fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 4))
for name, clf in classifiers.items():
    clf.fit(X_train, y_train)
    
    RocCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_roc, name=name)
    DetCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_det, name=name)

六、调整阈值

在数据不平衡的情况下,希望调整召回率和精确度。可以使用使用 metrics.PrecisionRecallDisplay 调整阈值

  • 对于电子邮件欺诈,需要高精确度。

  • 而对于疾病筛查,则需要高召回率来捕获更多病例。

那么可以调整阈值,但调整多少才合适呢?因此可以使用metrics.PrecisionRecallDisplay 来绘制相关图表。

from xgboost import XGBClassifier
from sklearn.datasets import load_wine
from sklearn.metrics import PrecisionRecallDisplay

wine = load_wine()
X, y = wine.data[wine.target<=1], wine.target[wine.target<=1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,
                                               stratify=y, random_state=42)

xgb_clf = XGBClassifier()
xgb_clf.fit(X_train, y_train)

PrecisionRecallDisplay.from_estimator(xgb_clf, X_test, y_test)
plt.show()

这表明可以按照 Scikit-learn 的设计绘制模型,就像这里的 xgboost

七、回归模型评估

Scikit-learn 的 metrics.PredictionErrorDisplay 绘制残差图可以帮助评估回归模型。

from sklearn.svm import SVR
from sklearn.metrics import PredictionErrorDisplay

rng = np.random.default_rng(42)
X = rng.random(size=(200, 2)) * 10
y = X[:, 0]**2 + 5 * X[:, 1] + 10 + rng.normal(loc=0.0, scale=0.1, size=(200,))

reg = make_pipeline(StandardScaler(), SVR(kernel='linear', C=10))
reg.fit(X, y)

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[0], kind="actual_vs_predicted")
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[1], kind="residual_vs_predicted")
plt.show()

图表展示预测值与实际值的比较,左图适合线性回归。然而,并非所有数据都是完全线性的,因此,请参考右图。右图展示了实际值与预测值的差异,即残差图。残差图的香蕉形状暗示我们的数据可能不适合线性回归。考虑将核函数从"线性" 转换为 "rbf" ,残差图会更好。

reg = make_pipeline(StandardScaler(), 
                    SVR(kernel='rbf', C=10))

八、绘制学习曲线

学习曲线主要研究模型的泛化效果和训练测试数据之间的差异或偏差。接下来,使用 model_selection.LearningCurveDisplay 绘制学习曲线,并比较了决策树分类器和梯度提升分类器在不同训练数据下的表现。

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import LearningCurveDisplay

X, y = make_classification(n_samples=1000, n_classes=2, n_features=10,
                           n_informative=2, n_redundant=0, n_repeated=0)

tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42)
gb_clf = GradientBoostingClassifier(n_estimators=50, max_depth=3, tol=1e-3)

train_sizes = np.linspace(0.4, 1.0, 10)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
LearningCurveDisplay.from_estimator(tree_clf, X, y,
                                    train_sizes=train_sizes,
                                    ax=axes[0],
                                    scoring='accuracy')
axes[0].set_title('DecisionTreeClassifier')
LearningCurveDisplay.from_estimator(gb_clf, X, y,
                                    train_sizes=train_sizes,
                                    ax=axes[1],
                                    scoring='accuracy')
axes[1].set_title('GradientBoostingClassifier')
plt.show()

从图中可以看出,虽然基于树的 GradientBoostingClassifier 在训练数据上保持了良好的准确性,但其在测试数据上的泛化能力与 DecisionTreeClassifier 相比并无明显优势。

九、可视化参数调整

为了改善泛化效果差的模型,可以尝试通过调整正则化参数来提高性能。传统的方法是使用 "GridSearchCV" 或 "Optuna" 等工具来实现模型调整,然而这些方法只能找出整体表现最佳的模型,且调整过程并不直观。如果需要调整特定参数以测试其对模型的影响,建议使用 model_selection.ValidationCurveDisplay 来直观地观察模型在参数变化时的表现。

from sklearn.model_selection import ValidationCurveDisplay
from sklearn.linear_model import LogisticRegression

param_name, param_range = "C", np.logspace(-8, 3, 10)
lr_clf = LogisticRegression()

ValidationCurveDisplay.from_estimator(lr_clf, X, y,
                                      param_name=param_name,
                                      param_range=param_range,
                                      scoring='f1_weighted',
                                      cv=5, n_jobs=-1)
plt.show()

十、讨论

尝试过所有这些显示后,我必须承认一些遗憾:

  • 最大的遗憾是这些 API 大多数缺乏详细的教程,这可能也是与 Scikit-learn 的详尽文档相比不为人知的原因。

  • 这些应用程序接口散布在不同的软件包中,因此很难从一个地方引用它们。

  • 代码仍然非常基础。通常需要将其与 Matplotlib 的 API 搭配使用才能完成工作。一个典型的例子是 "DecisionBoundaryDisplay",在绘制决策边界后,还需要使用 Matplotlib 来绘制数据分布。

  • 它们很难扩展。除了一些验证参数的方法外,很难用工具或方法来简化模型的可视化过程;最终需要重写了很多东西。

这些 API 希望得到更多关注,并且随着版本升级,可视化 API 也能更易用。

在机器学习中,用可视化方式解释模型与训练模型同样重要。

本文介绍了当前版本 scikit-learn 中的各种绘图 API,利用这些 API,可以简化一些 Matplotlib 代码,缓解学习曲线,并简化模型评估过程。由于篇幅有限,未对每个 API 进行详细介绍。如果有兴趣,可以查看 [官方文档:https://scikit-learn.org/stable/visualizations.html?ref=dataleadsfuture.com] 了解更多详情。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1651053.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

(三十六)第 6 章 树和二叉树(二叉树的顺序存储表示实现)

1. 背景说明 2. 示例代码 1) errorRecord.h // 记录错误宏定义头文件#ifndef ERROR_RECORD_H #define ERROR_RECORD_H#include <stdio.h> #include <string.h> #include <stdint.h>// 从文件路径中提取文件名 #define FILE_NAME(X) strrchr(X, \\) ?…

爬虫学习:XPath匹配网页数据

目录 一、安装XPath 二、XPath的基础语法 1.选取节点 三、使用XPath匹配数据 1.浏览器审查元素 2.具体实例 四、总结 一、安装XPath 控制台输入指令&#xff1a;pip install lxml 二、XPath的基础语法 XPath是一种在XML文档中查找信息的语言&#xff0c;可以使用它在HTM…

vue导出大量数据的表格方法

我目前的项目导出4万7数据没问题 先安装 npm install -S file-saver npm install xlsx0.16.0 -S npm install -D script-loader 我使用的版本是"file-saver": “^2.0.5”, “xlsx”: “^0.16.0” 新建Export2Excel.js //Export2Excel.js /* eslint-disable */ requ…

代码训练LeetCode(17)存在重复元素

代码训练(17)LeetCode之存在重复元素 Author: Once Day Date: 2024年5月7日 漫漫长路&#xff0c;才刚刚开始… 全系列文章可参考专栏: 十年代码训练_Once-Day的博客-CSDN博客 参考文章: 219. 存在重复元素 II - 力扣&#xff08;LeetCode&#xff09;力扣 (LeetCode) 全球…

基于微信小程序的图书馆预约系统的设计与实现

个人介绍 hello hello~ &#xff0c;这里是 code袁~&#x1f496;&#x1f496; &#xff0c;欢迎大家点赞&#x1f973;&#x1f973;关注&#x1f4a5;&#x1f4a5;收藏&#x1f339;&#x1f339;&#x1f339; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的…

200554-19-4,AF350琥珀酰亚胺酯具有较高的荧光量子产率

产品概述 AF350 NHS Ester&#xff0c;即AF350琥珀酰亚胺酯&#xff0c;是一种重要的荧光标记染料&#xff0c;具有广泛的应用领域和显著的性能特点。 中文名称&#xff1a;AF350琥珀酰亚胺酯 英文名称&#xff1a;AF350 NHS Ester&#xff0c;AlexaFluor350 SE CAS号&…

可视化数据报道:Kompas.ai如何用图表和动态效果讲述故事

在数字化时代&#xff0c;数据无处不在&#xff0c;而如何将这些数据转化为易于理解且吸引人的故事&#xff0c;成为信息传递的关键。数据可视化作为一种强有力的工具&#xff0c;能够帮助观众快速把握复杂信息的要点&#xff0c;增强记忆&#xff0c;并激发情感共鸣。本文将深…

Dask简介

目录 一、概述 二、编程模型 2.1 High-Level Collection 2.2 Low level Interface 三、调度框架 3.1 任务图 3.2 调度 3.3 优化 3.4 动态任务图 一、概述 Dask是一个灵活的Python并行计算库。 Dask由两部分组成&#xff1a; 为计算优化的动态任务调度&#xff1a;和A…

所向披靡のmakefile

在VS里敲代码&#xff0c;只需要FnF5就可以直接运行勒&#xff0c;在Linux下敲代码却要即敲命令还要用编辑器还要用编译器&#xff0c;那在Linux下有没有能帮我们进行自动化组建的工具呢&#xff1f; 当然有&#xff0c;超级巨星&#xff1a;makefile&#xff01;&#xff01;…

obs64无法定位程序输入点IsWow64Process2

obs安装后&#xff0c;打开提示&#xff1a;obs64无法定位程序输入点IsWow64Process2。 解决办法&#xff0c;找到obs.dll文件&#xff0c;并找软件打开。 &#xff08;我用的是 notepad打开的&#xff09; 用CTRLF 搜索 “IsWow64Process2” 对应的"32"改为"…

【容器】Pod 生命周期

概述 Pod的生命周期包含从Pod创建事件的触发到Pod被停止的整个流程。了解Pod的生命周期方便日常排障&#xff0c;并能帮助较深入了解K8s。 在Pod生命周期中有两个重要的标识&#xff1a;Pod Condition 和 Pod Phase。Pod Phase提供了一个Pod当前状况的概览&#xff0c;可以帮…

APP 在华为应用市场上架 保姆级别详细流程

1、作为一名干开发的程序员&#xff0c;第一次能把自己的APP 上架&#xff0c;对自己来说是多么有意义的一项成就 2、创建一个 华为的开发者账号 根据提示填写完注册的信息https://developer.huawei.com/consumer/cn/product/华为开发者产品 | 开发者平台 | 流量变现 | 华为开…

Three.js的几何形状

在创建物体的时候&#xff0c;需要传入两个参数&#xff0c;一个是几何形状【Geometry】&#xff0c;一个是材质【Material】 几何形状主要是存储一个物体的顶点信息&#xff0c;在Three中可以通过指定一些特征来创建几何形状&#xff0c;比如使用半径来创建一个球体。 立方体…

Android Studio查看xml文件的修改时间和记录

Android Studio查看xml文件的修改时间和记录 Android Studio里面如果是Java/Kotlin编写界面&#xff0c;可以点击函数开头上面的提交在直接&#xff0c;然后在编辑界面的左侧查看历史时间上的修改记录&#xff0c;但是xml文件里面没有直观的这样操作方式。 但xml里面可以通过快…

FileLink跨网文件交换,推动企业高效协作|半导体行业解决方案

随着信息技术的迅猛发展&#xff0c;全球信息产业已经迎来了前所未有的繁荣与变革。在这场科技革命中&#xff0c;半导体作为信息产业的基础与核心&#xff0c;其重要性日益凸显&#xff0c;半导体的应用场景和市场需求将进一步扩大。 然而&#xff0c;在这一繁荣的背后&#x…

微信公众号营销攻略,2024年微信引流商业最佳实践

确实&#xff0c;微信是中国市场上不可或缺的营销工具。下面是一些关于如何在微信上进行有效营销的最佳实践&#xff0c;以及如何通过微信公众号进行广告宣传&#xff0c;以提升品牌知名度并推动业务增长。 拥有一个微信公众号是进行微信营销的关键第一步。 通过公众号&#x…

UE5自动生成地形一:地形制作

UE5自动生成地形一&#xff1a;地形制作 常规地形制作地形编辑器地形管理添加植被手动修改部分地形的植被 置换贴图全局一致纹理制作地貌裸露岩石地形实例 常规地形制作 地形制作入门 地形导入部分 选择模式&#xff1a;地形模式。选择地形子菜单&#xff1a;管理->导入 …

吴恩达深度学习笔记:深度学习的 实践层面 (Practical aspects of Deep Learning)1.13-1.14

目录 第二门课: 改善深层神经网络&#xff1a;超参数调试、正 则 化 以 及 优 化 (Improving Deep Neural Networks:Hyperparameter tuning, Regularization and Optimization)第一周&#xff1a;深度学习的 实践层面 (Practical aspects of Deep Learning)1.13 梯度检验&#…

蓝桥杯单片机之模块代码《AT24C02》

过往历程 历程1&#xff1a;秒表 历程2&#xff1a;按键显示时钟 历程3&#xff1a;列矩阵按键显示时钟 历程4&#xff1a;行矩阵按键显示时钟 历程5&#xff1a;新DS1302 历程6&#xff1a;小数点精确后两位ds18b20 历程7&#xff1a;35定时器测量频率 文章目录 过往历…

微信小程序(Taro)获取经纬度并转化为具体城市

1、获取经纬度 申请权限&#xff0c;想要使用微信小程序获取经纬度的方法是要申请该方面的权限。 获取经纬度的方法有很多选择其中一个使用就好。 我使用的是Taro.getFuzzyLocation(&#xff09; 在app.config.js中需要添加设置 requiredPrivateInfos: ["getFuzzyLocat…