【可解释性机器学习】详解Python的可解释机器学习库:SHAP

news2024/9/24 11:31:20

详解Python的可解释机器学习库:SHAP

  • SHAP介绍
  • SHAP的用途
  • SHAP的工作原理
  • 解释器Explainer
    • 局部可解释性Local Interper
      • 单个prediction的解释
      • 多个预测的解释
      • 获取单个样本的Top N个特征值及其对应的SHAP值
    • 全局可解释性Global Interper
      • summary_plot
      • Feature Importance
      • Interaction Values
      • dependence_plot
  • 其他类型的explainers
  • 一个使用SHAP计算神经网络影响的示例
  • 参考资料

可解释机器学习在这几年慢慢成为了机器学习的重要研究方向。作为数据科学家需要防止模型存在偏见,且帮助决策者理解如何正确地使用我们的模型。越是严苛的场景,越 需要模型提供证明它们是如何运作且避免错误的证据

关于模型解释性,除了线性模型和决策树这种天生就有很好解释性的模型意外,sklean中有很多模型都有importance这一接口,可以查看特征的重要性。其实这已经含沙射影地体现了模型解释性的理念。只是传统的importance的计算方法其实有很多争议,且并不总是一致。

SHAP介绍

SHAP是Python开发的一个“模型解释”包,可以解释任何机器学习模型的输出。其名称来源于SHapley Additive exPlanation在合作博弈论的启发下SHAP构建一个加性的解释模型,所有的特征都视为“贡献者”。对于每个预测样本,模型都产生一个预测值,SHAP value就是该样本中每个特征所分配到的数值。

假设第i个样本为 X i X_i Xi,第i个样本的第j个特征为 X i j X_{i_j} Xij,模型对该样本的预测值为 y i y_i yi,整个模型的基线(通常是所有样本的目标变量的均值)为 y b a s e y_{base} ybase,那么SHAP value服从以下等式:
y i = y b a s e + f ( X i 1 ) + f ( X i 2 ) + . . . + f ( X i k ) y_i=y_{base}+f(X_{i_1})+f(X_{i_2})+...+f(X_{i_k}) yi=ybase+f(Xi1)+f(Xi2)+...+f(Xik)
其中 f ( X i j ) f(X_{i_j}) f(Xij) X i j X_{i_j} Xij的SHAP值,直观上看, f ( X i 1 ) f(X_{i_1}) f(Xi1)就是第i个样本中第1个特征对最终预测值 y i y_i yi的贡献值,当 f ( X j 1 ) > 0 f(X_{j_1})>0 f(Xj1)>0,说明该特征提升了预测值,也正向作用;反之,说明该特征使得预测值降低,有反向作用。

传统的feature importance只告诉哪个特征重要,但并不清楚该特征是怎样影响预测结果的SHAP value最大的优势是SHAP能对于反映出每一个样本中的特征的影响力,而且还表现出影响的正负性
SHAP示意图通过pip install shap即可安装:

import shap

# 首先训练好一个XGBoost model
X,y = shap.datasets.boston()
model = xgboost.train({"learning_rate": 0.01}, xgboost.DMatrix(X, label=y), 100)

SHAP的用途

SHAP值(SHapley Additive exPlanations的缩写)从预测中把每一个特征的影响分解出来。可以把它应用到类似于下面的场景当中:

  • 模型认为银行不应该给某人放贷,但是法律上需要银行给出每一笔拒绝放贷的原因。
  • 医务人员想要确定对不同的病人而言,分别是哪些因素导致他们有患某种疾病的风险,这样就可以因人而异地采取针对性的卫生干预措施,直接处理这些风险因素。

SHAP的工作原理

SHAP值通过与某一特征取基线值时的预测做对比,来解释该特征取某一特定值的影响。例如,对一个球队会不会赢得“最佳球员”称号进行了预测。可能会有以下疑问:

  • 预测的结果有多大的程度是由球队进了3个球这一事实影响的?但是,如果我们像下面这样重新表述一下的话,那么给出具体、定量的答案还是比较容易的:
  • 预测的结果由多大的程度是由球队进了3个球这一事实影响的,而不是某些基线进球数

当然,每个球队都由很多特征,所以,如果我们能回答“进球数”的问题,那么我们也能对其它特征重复这一过程。

SHAP值用一种保证良好性质的的方式做这件事。具体而言,用如下等式对预测进行分解:

sum(SHAP values for all features) = pred_for_team - pred_for_baseline_values

也就是说,用所有特征的SHAP值的加和来解释为什么预测结果与基线不同。这就允许我们用像下面这样的一幅图来对预测进行分解:
示例图
该如何解释这张图呢?
可以看到预测的结果是0.7,而基准值是0.4979引起预测增加的特征值是粉色的,它们的长度表示特征影响的程度。引起预测降低的特征值是蓝色的。最大的影响来自Goal Scored等于2的时候。但ball possesion的值则对降低预测的值具有比较有意义的影响。如果把粉色条状图的长度与蓝色条状图的长度相减,差值就等于基准值到预测值之间的距离

解释器Explainer

在SHAP中进行模型解释需要先创建一个explainer,SHAP支持很多类型的explainer(例如deep、gradient、kernel、tree、sampling等),以tree为例,它支持常用的XGB、LGB、CatBoost等树集成算法。

explainer = shap.TreeExplainer(model) # #这里的model在准备工作中已经完成建模,模型名称就是model
shap_values = explainer.shap_values(X) # 传入特征矩阵X,计算SHAP值

上面的shap_values对象是一个包含两个array的list。第一个array是负向结果的SHAP值,而第二个array是正向结果的SHAP值。通常从预测正向结果的角度考虑模型的预测结果,所以会拿出正向结果的SHAP值(拿出shap_values[1])。

局部可解释性Local Interper

Local可解释性提供了预测的细节,侧重于解释单个预测是如何生成的。它可以帮助决策者信任模型,并且解释各个特征是如何影响模型单次的决策

单个prediction的解释

SHAP提供极其强大的数据可视化功能,来展示模型或预测的解释结果。

# 可视化第一个prediction的解释
shap.initjs()
shap.force_plot(explainer.expected_value, shap_values[0,:], X.iloc[0,:])

单个prediction的解释上图的"explanation"展示了每个特征都各自有其贡献,将模型的预测结果从基本值(base value)推动到最终的取值(model output);将预测推高的特征用红色表示,将预测推低的特征用蓝色表示。

基本值(base_value)是我们传入数据集上模型预测值的均值,可以通过自己计算来验证:

y_base = explainer.expected_value
print(y_base) # 14.230186
pred = model.predict(xgboost.DMatrix(X))
print(pred.mean()) # 14.230188

多个预测的解释

如果对多个样本进行解释,将上述形式旋转90度然后水平并排放置,可以看到整个数据集的explanations :

shap.initjs()
shap.force_plot(explainer.expected_value, shap_values, X)

多个预测的解释

获取单个样本的Top N个特征值及其对应的SHAP值

函数代码如下:

# 获取单个样本的Top N特征值和对应的SHAP值
def get_topN_reason(old_list, features, top_num = 3, min_value = 0.0):
  # 输出shap值最高的N个标签
  feature_importance_dict = {}
  for i, f in zip(old_list, features):
    feature_importance_dict[f] = i
  new_dict = dict(sorted(feature_importance_dict.items(), key=lambda e: e[1], reverse=True))
  return_dict = {}
  for k ,v in new_dict.items():
    if top_num > 0:
      if v >= min_value:
        return_dict[k] = v
        top_num -= 1
      else:
        break
    else:
      break
  return return_dict
print(get_topN_reason(old_list=shap_values[505], features=X.columns.values)) # 这里选取第505个样本

其中:

  • old_list:shap_value中某个array的单个元素(类型是list),这里我选择的是array中的505号样本
  • features: 与old_list的列数相同,主要用于输出的特征能让人看得懂
  • top_num:展示前N个最重要的特征
  • min_value: 限制了shap值的最小值

输出结果:

{'LSTAT': 1.11883, 'NOX': 0.10774355, 'TAX': 0.061408427}

全局可解释性Global Interper

Global可解释性:寻求理解模型的overall structure(总体结构)。这往往比解释单个预测困难得多,因为它涉及到对模型的一般工作原理作出说明,而不仅仅是一个预测

summary_plot

summary plot 为每个样本绘制其每个特征的SHAP值,这可以更好地理解整体模式,并允许发现预测异常值每一行代表一个特征,横坐标为SHAP值一个点代表一个样本,颜色表示特征值(红色高,蓝色低)。比如,这张图表明LSTAT特征较高的取值会降低预测的房价

# summarize the effects of all the features
shap.summary_plot(shap_values, X)

summary_plot

Feature Importance

传统的importance的计算方法效果不好,SHAP提供了另一种计算特征重要性的思路取每个特征的SHAP值的绝对值的平均值作为该特征的重要性,得到一个标准的条形图(multi-class则生成堆叠的条形图):

shap.summary_plot(shap_values, X, plot_type="bar")

Feature Importance

Interaction Values

interaction value是将SHAP值推广到更高阶交互的一种方法。树模型实现了快速、精确的两两交互计算,这将为每个预测返回一个矩阵,其中主要影响在对角线上,交互影响在对角线外。这些数值往往揭示了有趣的隐藏关系(交互作用):

shap_interaction_values = explainer.shap_interaction_values(X)
shap.summary_plot(shap_interaction_values, X)

Interaction Values

dependence_plot

为了理解单个feature如何影响模型的输出,可以将该feature的SHAP值与数据集中所有样本的feature值进行比较。由于SHAP值表示一个feature对模型输出中的变动量的贡献,下面的图表示随着特征RM变化的预测房价(output)的变化。单一RM(特征)值垂直方向上的色散表示与其他特征的相互作用,为了帮助揭示这些交互作用,“dependence_plot函数”自动选择另一个用于着色的feature。在这个案例中,RAD特征着色强调了RM(每栋房屋的平均房间数)对RAD值较高地区的房价影响较小

# create a SHAP dependence plot to show the effect of a single feature across the whole dataset
shap.dependence_plot("RM", shap_values, X)

dependence_plot

其他类型的explainers

SHAP库可用的explainers有:

  • deep:用于计算深度学习模型,基于DeepLIFT算法
  • gradient:用于深度学习模型,综合了SHAP、集成梯度、和SmoothGrad等思想,形成单一期望值方程
  • kernel:模型无关,适用于任何模型
  • linear:适用于特征独立不相关的线性模型
  • tree:适用于树模型和基于树模型的集成算法
  • sampling :基于特征独立性假设,当你想使用的后台数据集很大时,kenel的一个很好的替代方案

重点介绍Kernel Explainer,适用于任何模型,但性能不一定是最优的,可能很慢;例如KNN算法只能使用kernel explainer。不过可以用K-mean聚类算法对数据集进行summarizing,这样可以有效提高Kenel的速度(当然,会损失一些准确性),例如:

import time
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsRegressor

X_train, X_test, y_train, y_val = train_test_split(X, y, random_state=1)

knn = KNeighborsRegressor().fit(X_train, y_train)

X_train_summary = shap.kmeans(X_train, 10)
t0 = time.time()
explainerKNN = shap.KernelExplainer(knn.predict, X_train_summary)
shap_values_KNN_train = explainerKNN.shap_values(X_train)
shap_values_KNN_test = explainerKNN.shap_values(X_test)
timeit=time.time()-t0
timeit
'''
103.51293921470642
'''

通过SHAP,用knn模型在整个"波士顿房价"数据集上跑完需要1个小时。如果我们牺牲一些精度,通过k-means聚类对数据进行summarizing,可以将时间缩短到2分钟。

一个使用SHAP计算神经网络影响的示例

在此示例中,使用SHAP计算使用 Python 和 scikit-learn 的神经网络的特征影响 。对于这个例子,使用 scikit-learn 的糖尿病数据集,它是一个回归数据集。

首先安装shap库。

!pip install shap

然后,导入相关库

import shap
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPRegressor
from sklearn.pipeline import make_pipeline

from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split

导入数据,并划分训练集和测试集:

# 加载数据集和特征名称
X,y = load_diabetes(return_X_y=True)
features = load_diabetes()['feature_names']
# 拆分训练和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

这里使用的模型是MLP,即多层感知机。首先对特征进行归一化。该模型本身是一个前馈神经网络,在隐藏层有 5 个神经元,10000 个 epoch 和一个具有自适应学习率的逻辑激活函数。在现实生活中,可以在设置这些值之前适当地优化这些超参数。

model = make_pipeline(
    StandardScaler(),
    MLPRegressor(hidden_layer_sizes=(5,),activation='logistic', max_iter=10000,learning_rate='invscaling',random_state=0)
)
model.fit(X_train,y_train)

输出结果
现在是 SHAP 部分。首先,需要创建一个名为explainer的对象。它是在输入中接受模型的预测方法和训练数据集的对象。为了使 SHAP 模型与模型无关,它围绕训练数据集的点执行扰动,并计算这种扰动对模型的影响。这是一种重采样技术,其样本数量稍后设置。这种方法与另一种称为 LIME 的著名方法有关,该方法已被证明是原始 SHAP 方法的一个特例。结果是对 SHAP 值的统计估计。
所以,首先定义解释器对象

explainer = shap.KernelExplainer(model.predict,X_train)
'''
WARNING:shap:Using 296 background data samples could cause slower run times. Consider using shap.sample(data, K) or shap.kmeans(data, K) to summarize the background as K samples.
'''

现在可以计算SHAP值。请记住,它们是通过对训练数据集重新采样并计算对这些扰动的影响来计算的,因此必须定义适当数量的样本。对于此示例,使用 100 个样本。然后,在测试数据集上计算影响。
计算SHAP值
出现一个漂亮的进度条并显示计算的进度,这可能很慢。

最后,得到一个 (n_samples, n_features) 的numpy 数组。每个元素都是该记录的该特征的 shap 值。请记住,SHAP值是针对每个特征和每个记录计算的
SHAP值
现在可以绘制“summary_plot”。

shap.summary_plot(shap_values, X_test, feature_names=features)

summary_plot
每行的每个点都是测试数据集的记录。这些特征从最重要的一个到不太重要的排序。可以看到s5是最重要的特征。该特征的值越高,对目标的影响越积极。该值越低,贡献越负。

更深入地了解特定记录,还可以绘制的一个非常有用的图称为force_plot

shap.initjs()
shap.force_plot(explainer.expected_value, shap_values[0,:] ,X_test[0,:],feature_names=features)

force_plot
从图中可以看出113.90 是预测值。基准值(base value)是目标变量在所有记录中的平均值。每个条带都显示了其特征在将目标变量的值推得更远或更接近基准值方面的影响。红色条纹表明它们的特征将价值推向更高的价值。蓝色条纹表明它们的特征将值推向较低的值。条纹越宽,贡献越高(绝对值)。这些贡献的总和将目标变量的值从花瓶值推到最终的预测值。

对于这个特定的记录,bmi、bp、s2、sex和s5值对预测值有正贡献s5仍然是这条记录中最重要的变量,因为它的贡献是最宽的(它具有最大的条带)。显示负贡献的变量是s1和s4,但它不足以使预测值低于基值。因此,由于总的正贡献(红色条纹)大于负贡献(蓝色条纹),因此最终值大于基准值

参考资料

[1] SHAP:Python的可解释机器学习库
[2] Python:使用SHAP库将前N个重要特征提取出来
[3] SHAP VALUES —— 什么影响了你的决定?
[4] SHAP 机器学习模型解释可视化工具
[5] 机器学习模型可解释性进行到底 —— SHAP值理论(一)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/187791.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SkyWalking 极简入门

SkyWalking 极简入门1.概述1.1 概念1.2 功能列表1.3 整体架构1.4 官方文档2. 搭建 SkyWalking 单机环境2.1 Elasticsearch 搭建2.2 下载 SkyWalking 软件包2.3 SkyWalking OAP 搭建2.4 SkyWalking UI 搭建2.5 SkyWalking Agent2.5.1 Shell2.5.2 IDEA3. 搭建 SkyWalking 集群环境…

【4】Linux实用操作

学习笔记目录 初识Linux--入门Linux基础命令--会用Linux权限管控--懂权限Linux实用操作--熟练实战软件部署--深入掌握脚本&自动化--用的更强项目实战--学到经验云平台技术--紧跟潮流 各类小技巧(快捷键) ctrl c强制停止 Linux某些程序的运行&am…

AI作画:文心一格赋能艺术与设计创作

针对视觉内容创作门槛高、耗时长等行业痛点问题,百度推出了基于文心大模型的AI艺术创作产品文心一格。通过文心一格核心系统的技术创新,让AI作画普惠大众,提升创作效率。目前,文心一格产品已经对外发布使用,大众用户均…

Docker容器基本操作

docker中的容器就是一个轻量级的虚拟机,是镜像运行起来的一个状态,本文就先来看看容器的基本操作。 查看容器 查看容器 启动docker后,使用docker ps命令可以查看当前正在运行的容器: 查看所有容器 上面这条命令是查看当前正在…

[强网杯 2019]高明的黑客

目录 信息收集 正则测试 python脚本 getshell 信息收集 $_GET[ganVMUq3d] ; eval($_GET[ganVMUq3d] ?? ); $_GET[jVMcNhK_F] ; system($_GET[jVMcNhK_F] ?? ); $_GET[cXjHClMPs] ; echo {$_GET[cXjHClMPs]}; 下载gz解压后得到几千个php文件,简单看…

【项目实战】count(1) 、count(col)、count(*) 如何选择?

一、背景 有时候会看业务执行的情况,如查看多少用户已经领取了礼品等,需要看数据库的计数或统计用户使用情况时,往往会使用聚合函数COUNT(),聚合函数有很多种,列出如官网的截图 而其中常用的聚合函数主要是包括以下&…

Linux常用命令——readelf命令

在线Linux命令查询工具(http://www.lzltool.com/LinuxCommand) readelf 用于显示elf格式文件的信息 补充说明 readelf命令用来显示一个或者多个elf格式的目标文件的信息,可以通过它的选项来控制显示哪些信息。这里的elf-file(s)就表示那些被检查的文件。可以支持…

java基于ssm滑雪场门票在线售票系统的设计与实现

基于jsp技术设计并实现了滑雪售票系统。该系统基于B/S即所谓浏览器/服务器模式,应用SSM框架,选择MySQL作为后台数据库。系统主要包括个人中心、用户管理、票务信息管理、购票信息管理、技巧交流、系统管理等功能模块。 性能测试主要通过模拟系统运行环境…

蓝桥杯刷题017——轨道炮(贪心)

2019国赛轨道炮 题目描述 小明在玩一款战争游戏。地图上一共有 N 个敌方单位,可以看作 2D 平面上的点。其中第 i 个单位在 0 时刻的位置是 (Xi​,Yi​),方向是 Di​ (上下左右之一, 用U/D/L/R 表示),速度是 Vi​。 小明的武器是轨道炮&#x…

我来告诉你,ChatGPT 该怎么对接到自己的项目中!

作者:小傅哥 博客:https://bugstack.cn 沉淀、分享、成长,让自己和他人都能有所收获!😄 1. 项目介绍 《ChatGPT AI 问答助手》 开源免费项目,涵盖爬虫接口、ChatGPT API对接、DDD架构设计、镜像打包、Dock…

六十年间中国经济总量增长245倍

中国GDP(现价美元)走势图回顾2022年,中国经济顶住了来自防控疫情及国际变化的巨大压力,全年GDP实现3%的同比增长,其中年末冬季GDP同比增长2.9%。从主要指标显示,12月第三产业、投资、地产都有不同程度的改善…

HTB_Base_php弱类型与find命令提权

至此,startpoint部分就完成了,也初步建立的渗透测试的思路和体系 温故而知新,可以为师矣 文章目录信息收集php弱类型 & strcmp绕过GTFOBins 二进制常用命令find提权信息收集 nmap -sC -A -Pn ip开放端口22、80 dirsearch 扫描目录&…

一文5000字手把手教你使用jenkins搭建一个中小企业前端项目部署环境

本文你能学到什么? 传统发布和现在发布的对比和区别 项目案例-手动上传服务器,使用jenkins上传服务器 配置不同的发布环境 配置域名 配置https 配置钉钉机器人通知【文末有配套资源领取】 服务器购买:抢占式实例 Jenkins 安装 示例服务器为 阿里…

系统移植 tf-a

1、从ST官网下载TF-A源码,将TF-A源码拷贝到ubuntu中并进行解压。 2、进入tf-a源码,阅读README.HOW_TO.txt帮助文档。 3、分析帮助文档。 分析文档可得,移植tf-a到开发板中步骤为:对tf-a源码进行解压;打补丁&#xff…

ElasticSearch从入门到出门【上】

文章目录初识elasticsearch了解ESelasticsearch的作用ELK技术栈elasticsearch和lucene为什么不是其他搜索技术?倒排索引正向索引倒排索引正向和倒排ES的一些概念文档和字段索引和映射mysql与elasticsearch安装elasticsearch部署单点es部署kibana安装IK分词器在线安装…

SpringBoot 使用自定义的方式整合Druid数据源(powernode document)(内含源代码)

SpringBoot 使用自定义的方式整合Druid数据源(powernode document)(内含源代码) 源代码下载链接地址:https://download.csdn.net/download/weixin_46411355/87404561 目录SpringBoot 使用自定义的方式整合Druid数据源…

SQLSERVER 的 nolock 到底是怎样的无锁?

一:背景 1. 讲故事 相信绝大部分用 SQLSERVER 作为底层存储的程序员都知道 nolock 关键词,即使当时不知道也会在踩过若干阻塞坑之后果断的加上 nolock,但这玩意有什么注意事项呢?这就需要了解它的底层原理了。 二:n…

python之np.sum()用法详解

python库numpy提供的求和方法np.sum(),可以对数组和矩阵进行求和。sum方法可以接收多个参数,主要是数组a,坐标轴axis,数据类型dtype,初始值initial。其中,axis对于我们来说比较容易迷糊,这个值对求和有什么影响?一般来…

Linux常用命令——rcp命令

在线Linux命令查询工具(http://www.lzltool.com/LinuxCommand) rcp 使在两台Linux主机之间的文件复制操作更简单 补充说明 rcp命令使在两台Linux主机之间的文件复制操作更简单。通过适当的配置,在两台Linux主机之间复制文件而无需输入密码,就像本地文…

React基础入门(一)

1、React简介 官网 英文官网: https://reactjs.org/ 中文官网: https://react.docschina.org/ 描述介绍 用于动态构建用户界面的 JavaScript 库(只关注于视图) 由Facebook开源 React特点 1、声明式编码 2、组件化编码 3、React Native 编写原生应用 4、高效(优秀…