shap-An introduction to explainable AI with Shapley values

news2025/1/15 6:24:12

An introduction to explainable AI with Shapley values

  • 训练模型
  • 检查模型系数
  • 使用部分依赖图的更完整的图片
  • 从部分相关性图中读取SHAP值
  • Shapley值的可加性
  • 解释additive regression模型
  • 解释non-additive boosted tree模型
  • 解释线性逻辑回归模型
  • 解释non-additive boosted tree逻辑回归模型
  • 处理相关特征
  • transformers NLP模型的解释

用到的环境是python3.7(基于上一篇文章的环境),再装了个interpret、transformers、datasets。。。

官方的代码在https://github.com/shap/shap/blob/master/docs/example_notebooks/overviews/An%20introduction%20to%20explainable%20AI%20with%20Shapley%20values.ipynb

训练模型

这是用Shapley值解释机器学习模型的介绍。Shapley值是合作博弈论中一种广泛使用的方法,具有理想的性质。

在使用Shapley值来解释复杂模型之前,了解它们如何适用于简单模型是很有帮助的。最简单的模型类型之一是标准线性回归,因此下面我们在加州住房数据集上训练线性回归模型。该数据集由1990年加利福尼亚州20640栋房屋组成,我们的目标是从8个不同的特征预测房价的中位数的自然对数。

8个特征:

  • MedInc - median income in block group
  • HouseAge - median house age in block group
  • AveRooms - average number of rooms per household
  • AveBedrms - average number of bedrooms per household
  • Population - block group population
  • AveOccup - average number of household members
  • Latitude - block group latitude
  • Longitude - block group longitude

首先获取california数据集,随机抽取1000条数据,X100就是从这1000条数据中抽取100条数据,模型使用sklearn的LinearRegression()。

import pandas as pd
import shap
import sklearn

# a classic housing price dataset
X,y = shap.datasets.california(n_points=1000)

X100 = shap.utils.sample(X, 100) # 100 instances for use as the background distribution

# a simple linear model
model = sklearn.linear_model.LinearRegression()
model.fit(X, y)

又是上次那个警告,,,忽略。。。
在这里插入图片描述

输出:

LinearRegression()

检查模型系数

理解线性模型最常见的方法是检查为每个特征学习的系数。当我们改变每个输入特征时,这些系数告诉我们模型输出的变化程度:

print("Model coefficients:\n")
for i in range(X.shape[1]):
    print(X.columns[i], "=", model.coef_[i].round(5))

model.coef_[i]获取特征的系数,round(5)保留5位小数
model.intercept_ 获取模型截距
y=系数1 * 特征1+系数2 * 特征2+…+截距

输出:

Model coefficients:

MedInc = 0.45769
HouseAge = 0.01153
AveRooms = -0.12529
AveBedrms = 1.04053
Population = 5e-05
AveOccup = -0.29795
Latitude = -0.41204
Longitude = -0.40125

虽然系数可以很好地告诉我们当我们改变输入特征的值时会发生什么,但就其本身而言,它们并不是衡量特征整体重要性的好方法。这是因为每个系数的值取决于输入特征的尺度。例如,如果我们以分钟而不是年来衡量房屋的年龄,那么HouseAge特征的系数将变为0.0115/(3652460)=2.18e-8。很明显,房子建成后的年数并不比分钟数更重要,但它的系数值要大得多。这意味着系数的大小不一定是线性模型中特征重要性的良好衡量标准。

使用部分依赖图的更完整的图片

https://christophm.github.io/interpretable-ml-book/pdp.html

部分依赖图(partial dependence plots,PDP)显示了目标函数(即我们的机器学习模型)和一组特征之间的依赖关系,并边缘化其他特征的值(也就是补充特征)。 它们是通过将模型应用于一组数据、改变感兴趣特征的值同时保持补充特征的值不变可以分析模型输出来计算特征变量对模型预测结果影响的函数关系:例如近似线性关系、单调关系或者更复杂的关系。

为了理解特征在模型中的重要性,有必要了解更改该特征如何影响模型的输出,以及该特征值的分布。为了将其可视化为线性模型,我们可以构建一个经典的部分依赖图,并将特征值的分布显示为x轴上的直方图:

shap.partial_dependence_plot(
    "MedInc", model.predict, X100, ice=False,
    model_expected_value=True, feature_expected_value=True
)

输出:
在这里插入图片描述
上图中的灰色水平线表示模型应用于加州住房数据集时的预期值。垂直灰线表示收入中位数特征的平均值。请注意,蓝色的部分依赖性图线(当我们将中值收入特征固定为给定值时,它是模型输出的平均值)总是穿过两条灰色期望值线的交叉点。我们可以将该交点视为关于数据分布的部分依赖图的“中心”。当我们接下来转向Shapley值时,这种中心化的影响将变得显而易见。

从部分相关性图中读取SHAP值

https://christophm.github.io/interpretable-ml-book/shapley.html

机器学习模型基于Shapley值的解释背后的核心思想是使用合作博弈论的公平分配结果来分配模型输出的信用𝑓(𝑥) 在其输入特征中。为了将博弈论与机器学习模型联系起来,既需要将模型的输入特征与游戏中的玩家相匹配,也需要将模型函数与游戏规则相匹配。由于在博弈论中,玩家可以加入或不加入游戏,我们需要一种功能“加入”或“不加入”模型的方法。定义特征“加入”模型意味着什么的最常见方法是,当我们知道该特征的值时,说该特征“加入了模型”,当我们不知道该特征值时,它没有加入模型。评估现有模型𝑓 当只有一个子集𝑆 的特征是模型的一部分,我们使用条件期望值公式来整合其他特征。该构想可以采取两种形式:

E [ f ( X ) ∣ X S = x S ] E[f(X) \mid X_S = x_S] E[f(X)XS=xS]

or

E [ f ( X ) ∣ d o ( X S = x S ) ] E[f(X) \mid do(X_S = x_S)] E[f(X)do(XS=xS)]

在第一种形式中,我们知道S中特征的值,因为我们观察到了它们。在第二种形式中,我们知道S中特征的值,因为我们设置了它们。一般来说,第二种形式通常更可取,因为它告诉我们,如果我们干预并改变其输入,模型将如何表现,也因为它更容易计算。在本教程中,我们将完全关注第二个公式。我们还将使用更具体的术语“SHAP值”来指代应用于机器学习模型的条件期望函数的Shapley值。

SHAP值的计算可能非常复杂(通常是NP-hard),但线性模型非常简单,我们可以从部分依赖图中读取SHAP值。当我们在解释一个预测𝑓(𝑥)时,特定特征 𝑖 的SHAP值只是预期模型输出和在部分依赖图上特征值 x i x_i xi之间的差值:

# compute the SHAP values for the linear model
explainer = shap.Explainer(model.predict, X100)
shap_values = explainer(X)

# make a standard partial dependence plot
sample_ind = 20
shap.partial_dependence_plot(
    "MedInc", model.predict, X100, model_expected_value=True,
    feature_expected_value=True, ice=False,
    shap_values=shap_values[sample_ind:sample_ind+1,:]
)

输出:
在这里插入图片描述
经典部分依赖图和SHAP值之间的紧密对应意味着,如果我们在整个数据集上绘制特定特征的SHAP值,我们将准确地追踪出该特征的部分依赖图的以平均值为中心的版本:

shap.plots.scatter(shap_values[:,"MedInc"])

在这里插入图片描述

Shapley值的可加性

Shapley值的一个基本性质是,它们总是总和为所有玩家都在场时的游戏结果和没有玩家在场时的比赛结果之间的差异。对于机器学习模型,这意味着对于所解释的预测,所有输入特征的SHAP值将总是总和为基线(预期)模型输出和当前模型输出之间的差。最简单的方法是通过瀑布图,从我们对房价的背景预期开始𝐸[𝑓(𝑋)],然后一次添加一个功能,直到达到当前模型输出𝑓(𝑥):

# the waterfall_plot shows how we get from shap_values.base_values to model.predict(X)[sample_ind]
shap.plots.waterfall(shap_values[sample_ind], max_display=14)

在这里插入图片描述

解释additive regression模型

线性模型的部分依赖图与SHAP值有如此密切的联系的原因是,模型中的每个特征都是独立于其他特征处理的(效果只是相加在一起)。我们可以在放宽直线的线性要求的同时保持这种相加性质。这导致了众所周知的一类广义可加性模型(GAM)。虽然有很多方法可以训练这些类型的模型(比如将XGBoost模型设置为depth-1),但我们将使用专门为此设计的可解释的解释器。

pip install interpret

又是红凸凸的,习惯就好。。。
在这里插入图片描述

# fit a GAM model to the data
import interpret.glassbox
model_ebm = interpret.glassbox.ExplainableBoostingRegressor(interactions=0)
model_ebm.fit(X, y)

# explain the GAM model with SHAP
explainer_ebm = shap.Explainer(model_ebm.predict, X100)
shap_values_ebm = explainer_ebm(X)

# make a standard partial dependence plot with a single SHAP value overlaid
fig,ax = shap.partial_dependence_plot(
    "MedInc", model_ebm.predict, X100, model_expected_value=True,
    feature_expected_value=True, show=False, ice=False,
    shap_values=shap_values_ebm[sample_ind:sample_ind+1,:]
)

输出:
在这里插入图片描述

shap.plots.scatter(shap_values_ebm[:,"MedInc"])

输出:
在这里插入图片描述

# the waterfall_plot shows how we get from explainer.expected_value to model.predict(X)[sample_ind]
shap.plots.waterfall(shap_values_ebm[sample_ind])

输出:
在这里插入图片描述

# the waterfall_plot shows how we get from explainer.expected_value to model.predict(X)[sample_ind]
shap.plots.beeswarm(shap_values_ebm)

输出:
在这里插入图片描述

解释non-additive boosted tree模型

# train XGBoost model
import xgboost
model_xgb = xgboost.XGBRegressor(n_estimators=100, max_depth=2).fit(X, y)

# explain the GAM model with SHAP
explainer_xgb = shap.Explainer(model_xgb, X100)
shap_values_xgb = explainer_xgb(X)

# make a standard partial dependence plot with a single SHAP value overlaid
fig,ax = shap.partial_dependence_plot(
    "MedInc", model_xgb.predict, X100, model_expected_value=True,
    feature_expected_value=True, show=False, ice=False,
    shap_values=shap_values_xgb[sample_ind:sample_ind+1,:]
)

输出:
在这里插入图片描述

shap.plots.scatter(shap_values_xgb[:,"MedInc"])

输出:
在这里插入图片描述

shap.plots.scatter(shap_values_xgb[:,"MedInc"], color=shap_values)

输出:
在这里插入图片描述

解释线性逻辑回归模型

# a classic adult census dataset price dataset
X_adult,y_adult = shap.datasets.adult()

# a simple linear logistic model
model_adult = sklearn.linear_model.LogisticRegression(max_iter=10000)
model_adult.fit(X_adult, y_adult)

def model_adult_proba(x):
    return model_adult.predict_proba(x)[:,1]
def model_adult_log_odds(x):
    p = model_adult.predict_log_proba(x)
    return p[:,1] - p[:,0]

这。。。有点烦。。。可能网太烂了。。。
在这里插入图片描述
老规矩,翻源码。。。C:\Users\gxx\anaconda3\envs\tf-py37\Lib\site-packages\shap。。。
改成自己保存在本地的路径,保存,重启notebook。。。
在这里插入图片描述

注意,解释线性逻辑回归模型的概率在输入中不是线性的。

# make a standard partial dependence plot
sample_ind = 18
fig,ax = shap.partial_dependence_plot(
    "Capital Gain", model_adult_proba, X_adult, model_expected_value=True,
    feature_expected_value=True, show=False, ice=False
)

输出:
在这里插入图片描述
如果我们使用SHAP来解释线性逻辑回归模型的概率,我们会看到强烈的相互作用效应。这是因为线性逻辑回归模型在概率空间中不是可加性的。

# compute the SHAP values for the linear model
background_adult = shap.maskers.Independent(X_adult, max_samples=100)
explainer = shap.Explainer(model_adult_proba, background_adult)
shap_values_adult = explainer(X_adult[:1000])

在这里插入图片描述

shap.plots.scatter(shap_values_adult[:,"Age"])

输出:
在这里插入图片描述
如果我们解释模型的对数几率输出,我们会看到模型输入和模型输出之间存在完美的线性关系。重要的是要记住你所解释的模型的单位是什么,并且解释不同的模型输出可能会导致对模型行为的不同看法。

# compute the SHAP values for the linear model
explainer_log_odds = shap.Explainer(model_adult_log_odds, background_adult)
shap_values_adult_log_odds = explainer_log_odds(X_adult[:1000])

在这里插入图片描述

shap.plots.scatter(shap_values_adult_log_odds[:,"Age"])

输出:
在这里插入图片描述

# make a standard partial dependence plot
sample_ind = 18
fig,ax = shap.partial_dependence_plot(
    "Age", model_adult_log_odds, X_adult, model_expected_value=True,
    feature_expected_value=True, show=False, ice=False
)

输出:
在这里插入图片描述

解释non-additive boosted tree逻辑回归模型

# train XGBoost model
model = xgboost.XGBClassifier(n_estimators=100, max_depth=2).fit(X_adult, y_adult*1, eval_metric="logloss")

# compute SHAP values
explainer = shap.Explainer(model, background_adult)
shap_values = explainer(X_adult)

# set a display version of the data to use for plotting (has string values)
shap_values.display_data = shap.datasets.adult(display=True)[0].values

在这里插入图片描述
默认情况下,SHAP条形图将取数据集所有实例(行)上每个特征的平均绝对值。

shap.plots.bar(shap_values)

输出:
在这里插入图片描述
但是,平均绝对值并不是创建特征重要性全局度量的唯一方法,我们可以使用任何数量的变换。在这里,我们展示了使用最大绝对值如何突出资本收益和资本损失特征,因为它们具有罕见但高幅度的影响。

shap.plots.bar(shap_values.abs.max(0))

输出:
在这里插入图片描述
如果我们愿意处理更多的复杂性,我们可以使用蜂群图来总结每个特征的SHAP值的整个分布。

shap.plots.beeswarm(shap_values)

输出:
在这里插入图片描述
通过取绝对值并使用纯色,我们在条形图和全蜂群图的复杂性之间取得了折衷。请注意,上面的条形图只是以下蜂群图中所示值的汇总统计数据。

shap.plots.beeswarm(shap_values.abs, color="shap_red")

输出:
在这里插入图片描述

shap.plots.heatmap(shap_values[:1000])

输出:
在这里插入图片描述

shap.plots.scatter(shap_values[:,"Age"])

输出:
在这里插入图片描述

shap.plots.scatter(shap_values[:,"Age"], color=shap_values)

输出:
在这里插入图片描述

shap.plots.scatter(shap_values[:,"Age"], color=shap_values[:,"Capital Gain"])

输出:
在这里插入图片描述

shap.plots.scatter(shap_values[:,"Relationship"], color=shap_values)

输出:
在这里插入图片描述

处理相关特征

clustering = shap.utils.hclust(X_adult, y_adult)
shap.plots.bar(shap_values, clustering=clustering)

输出:
在这里插入图片描述

shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=0.8)

输出:
在这里插入图片描述

shap.plots.bar(shap_values, clustering=clustering, clustering_cutoff=1.8)

输出:
在这里插入图片描述

transformers NLP模型的解释

这展示了SHAP如何应用于具有高度结构化输入的复杂模型类型。

conda install transformers
conda install datasets
import transformers
import datasets
import torch
import numpy as np
import scipy as sp

# load a BERT sentiment analysis model
tokenizer = transformers.DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
model = transformers.DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
).cuda()

# define a prediction function
def f(x):
    tv = torch.tensor([tokenizer.encode(v, padding='max_length', max_length=500, truncation=True) for v in x]).cuda()
    outputs = model(tv)[0].detach().cpu().numpy()
    scores = (np.exp(outputs).T / np.exp(outputs).sum(-1)).T
    val = sp.special.logit(scores[:,1]) # use one vs rest logit units
    return val

# build an explainer using a token masker
explainer = shap.Explainer(f, tokenizer)

# explain the model's predictions on IMDB reviews
imdb_train = datasets.load_dataset("imdb")["train"]
shap_values = explainer(imdb_train[:10], fixed_context=1, batch_size=2)

bug又来了,,,估计是版本不兼容无法导入pyarrow。。。
在这里插入图片描述
当前pyarrow版本8.0.0
在这里插入图片描述
尝试安装pyarrow==4.0.1

pip install pyarrow==4.0.1

在这里插入图片描述
尝试安装pyarrow==6.0.0

pip install --user pyarrow==6.0.0

又又又有新bug。。。
在这里插入图片描述
莫非这需要魔法???
。。。
好的,这部分跳过。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1036780.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jenkins “Trigger/call builds on other project“用法及携带参数

1.功能 “Trigger/call builds on other project” 功能是 Jenkins 中的一个特性,允许您在某个项目的构建过程中触发或调用另一个项目的构建。 当您在 Jenkins 中启用了 “Trigger/call builds on other project” 功能并配置了相应的触发条件后,当主项…

python实验2

1、实验题目:个人用户信息注册 模拟用户个人信息注册,需要输入用户个人信息 姓名、性别、年龄、血型、身高、电话 信息,并输出显示。 源代码: print(用户个人信息注册) name input("请输入您的姓名:") sex…

Northstar 量化平台

基于 B/S 架构、可替代付费商业软件的一站式量化交易平台。具备历史回放、策略研发、模拟交易、实盘交易等功能。兼顾全自动与半自动的使用场景。 已对接国内期货股票、外盘美股港股。 面向程序员的量化交易软件,用于期货、股票、外汇、炒币等多种交易场景&#xff…

1.2 kV SiC SWITCH-MOS 在短路应力后的分析

标题:Analysis of 1.2 kV SiC SWITCH-MOS after Short-circuit Stress 摘要 本研究调查了在短路应力后1.2 kV SWITCH-MOS的残余损伤。在应力施加后,相当于SWITCH-MOS耐受时间的约80%,正向阻断状态下的漏电流急剧增加。发现SWITCH-MOS中的SB…

一起学数据结构(8)——二叉树中堆的代码实现

在上篇文章中提到,提到了二叉树中一种特殊的结构——完全二叉树。对于完全二叉树,在存储时,适合使用顺序存储。对于非完全二叉树,适合用链式存储。本文将给出完全二叉树的顺序结构以及相关的代码实现: 1. 二叉树的结构…

Categraf v0.3.22部署

wget https://github.com/flashcatcloud/categraf/releases/download/v0.3.22/categraf-v0.3.22-linux-amd64.tar.gz下载安装包。 sudo mkdir /opt/categraf创建一个目录。 tar zxf categraf-v0.3.22-linux-amd64.tar.gz -C /opt/categraf进行解压。 /opt/categraf/categ…

ORA-27102: out of memory

正在外面办事呢,项目经理打电话并截图说明,物理服务器增加内存后,他调整sgapga后,重启无法启动了,报错ORA-27102: out of memory。 SYSorcl> startup; ORA-27102: out of memory Linux-x86_64 Error: 28: No space…

9领域事件

本系列包含以下文章: DDD入门DDD概念大白话战略设计代码工程结构请求处理流程聚合根与资源库实体与值对象应用服务与领域服务领域事件(本文)CQRS 案例项目介绍 # 既然DDD是“领域”驱动,那么我们便不能抛开业务而只讲技术&…

深度学习综述:Computation-efficient Deep Learning for Computer Vision: A Survey

论文作者:Yulin Wang,Yizeng Han,Chaofei Wang,Shiji Song,Qi Tian,Gao Huang 作者单位:Tsinghua University; Huawei Inc. 论文链接:http://arxiv.org/abs/2308.13998v1 内容简介: 在过去的十年中,深度学习模型取…

原生HTML实现marquee向上滚动效果

实现原理&#xff1a;借助CSS3中animation动画以及原生JS克隆API <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /…

【MySQL集群二】使用MyCat和ProxySql代理MySQL集群

中间件代理MySQL MyCat安装MyCat介绍&#xff1a;步骤1&#xff1a;安装Java环境步骤2&#xff1a;下载并解压Mycat步骤3&#xff1a;配置Mycat步骤4&#xff1a;启动Mycat ProxySql安装ProxySql介绍&#xff1a;步骤1&#xff1a;更新系统步骤2&#xff1a;安装ProxySQL步骤3&…

数学笔记:傅里叶变化

1 介绍 简而言之&#xff0c;傅里叶变换把一个输入信号分解成一堆正弦波的叠加 比如&#xff0c;以下是一个波&#xff1a; 这个波可以分解为两个正弦波的叠加。 也就是说&#xff0c;当我们将两个正弦波相加时&#xff0c;就会得到原来的波 哪怕是一个方波 也可以分解成一组…

【块状链表C++】文本编辑器(指针中 引用 的使用)

》》》算法竞赛 /*** file * author jUicE_g2R(qq:3406291309)————彬(bin-必应)* 一个某双流一大学通信与信息专业大二在读 * * brief 一直在竞赛算法学习的路上* * copyright 2023.9* COPYRIGHT 原创技术笔记&#xff1a;转载…

稀疏奖励问题解决方案总览

方案简介 HER (Hindsight Experience Replay) - 2017年 思想 HER&#xff08;Hindsight Experience Replay&#xff09;是一种特别设计用于解决稀疏奖励问题的强化学习算法。它主要用于那些具有高度稀疏奖励和延迟奖励的任务&#xff0c;特别是在连续动作空间中&#xff0c;如机…

IDEA设置注释快捷键进行 注释对齐

给大家推荐一个嘎嘎好用的功能~ 相信大家在使用IDE写代码的时候&#xff0c;经常用到 Ctrl / 来注释代码吧&#xff0c;但是默认的是将注释在行首对齐&#xff0c;看着很让人不舒服。但是下面的操作会将注释会和当前代码对齐&#xff0c;还会自动保留一个空格&#xff0c;真的…

【用unity实现100个游戏之13】复刻类泰瑞利亚生存建造游戏——包括建造系统和库存系统

文章目录 前言素材人物瓦片其他 一、建造系统1. 定义物品类2. 绘制地图3. 实现瓦片选中效果4. 限制瓦片选择5. 放置物品功能6. 清除物品7. 生成和拾取物品功能 二、库存系统1. 简单绘制UI2. 零代码控制背包的开启关闭3. 实现物品的拖拽拖拽功能拖拽恢复问题 4. 拖拽放置物品5. …

【C语言精髓 之 指针】指针*、取地址、解引用*、引用

/*** file * author jUicE_g2R(qq:3406291309)————彬(bin-必应)* 一个某双流一大学通信与信息专业大二在读 * copyright 2023.9* COPYRIGHT 原创技术笔记&#xff1a;转载需获得博主本人同意&#xff0c;且需标明转载源* language …

人工智能驱动的自然语言处理:解锁文本数据的价值

文章目录 什么是自然语言处理&#xff1f;NLP的应用领域1. 情感分析2. 机器翻译3. 智能助手4. 医疗保健5. 舆情分析 使用Python进行NLP避免NLP中的陷阱结论 &#x1f389;欢迎来到AIGC人工智能专栏~人工智能驱动的自然语言处理&#xff1a;解锁文本数据的价值 ☆* o(≧▽≦)o *…

flutter web 优化和flutter_admin_template

文章目录 Flutter Admin TemplateLive demo: https://githubityu.github.io/live_flutter_adminWeb 优化 Setup登录注册英文 亮色主题 中文 暗黑主题管理员登录权限 根据权限动态添加路由 第三方依赖License最后参考学习 Flutter Admin Template Responsive web with light/da…

C++ 学习系列 -- std::vector (未完待续)

一 std::vector 是什么&#xff1f; vector 是c 中一种序列式容器&#xff0c;与前面说的 array 类似&#xff0c;其内存分配是连续的&#xff0c;但是与 array 不同的地方在于&#xff0c;vector 在运行时是可以动态扩容的&#xff0c;此外 vector 提供了许多方便的操作&…