Parallelize your massive SHAP computations with MLlib and PySpark

news2025/1/11 23:44:40

https://medium.com/towards-data-science/parallelize-your-massive-shap-computations-with-mllib-and-pyspark-b00accc8667c

(If you can get around the firewall, read the original article directly.)

A stepwise guide for efficiently explaining your models using SHAP.

Photo by Pietro Jeng on Unsplash

Introduction to MLlib

Apache Spark’s Machine Learning Library (MLlib) is designed primarily for scalability and speed. It leverages the Spark runtime for common distributed use cases: supervised learning such as classification and regression, unsupervised learning such as clustering, and other tasks such as collaborative filtering and dimensionality reduction. In this article, I cover how we can use SHAP to explain a Gradient Boosted Trees (GBT) model that has been fit to our data at scale.

What are Gradient Boosted Trees?

Before we look at Gradient Boosted Trees, we need to understand boosting. Boosting is an ensemble technique that sequentially combines a number of weak learners into an overall strong learner. In the case of Gradient Boosted Trees, each weak learner is a decision tree that is fit to reduce the errors (measured by MSE for regression and log loss for classification) left by the trees before it in the sequence. To read about GBTs in more detail, please refer to this blog post.
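To make the idea of each tree correcting the errors of the ones before it concrete, here is a minimal pure-Python sketch of gradient boosting for regression with squared error, where the negative gradient is simply the residual. The one-split decision stumps, the learning rate of 0.5, the toy data, and the function names are all illustrative assumptions. This is not how MLlib implements GBTs internally.

```python
def fit_stump(xs, residuals):
    """Fit a one-split regression stump to the residuals by least squares."""
    best = None
    for threshold in sorted(set(xs))[:-1]:  # every split leaving both sides non-empty
        left = [r for x, r in zip(xs, residuals) if x <= threshold]
        right = [r for x, r in zip(xs, residuals) if x > threshold]
        left_mean, right_mean = sum(left) / len(left), sum(right) / len(right)
        sse = (sum((r - left_mean) ** 2 for r in left)
               + sum((r - right_mean) ** 2 for r in right))
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    _, threshold, left_value, right_value = best
    return lambda x: left_value if x <= threshold else right_value


def fit_boosted_stumps(xs, ys, n_trees=20, learning_rate=0.5):
    """Sequentially add stumps, each fit to the residuals of the ensemble so far."""
    base = sum(ys) / len(ys)  # initial prediction: the mean target
    stumps = []

    def predict(x):
        return base + sum(learning_rate * stump(x) for stump in stumps)

    for _ in range(n_trees):
        # For squared error, the negative gradient is the residual y - F(x)
        residuals = [y - predict(x) for x, y in zip(xs, ys)]
        stumps.append(fit_stump(xs, residuals))
    return predict


xs = [1, 2, 3, 4, 5, 6]
ys = [1.0, 1.2, 0.9, 3.0, 3.1, 2.9]
model = fit_boosted_stumps(xs, ys)
mse = sum((y - model(x)) ** 2 for x, y in zip(xs, ys)) / len(xs)
```

Each round fits a stump to what the current ensemble still gets wrong, so the training error falls round by round. The learning rate shrinks each tree's contribution, a common way to make boosting less prone to overfitting.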

Understanding our imports

from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
from pyspark.ml.classification import GBTClassificationModel
import shap
import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import *

The first two imports are used to initialize a Spark session, which we need in order to convert our pandas dataframe into a Spark one. The third import loads our trained GBT model into memory so that it can be passed to our SHAP explainer to generate explanations, and the fourth import provides the SHAP package used to build that explainer. The last two imports provide Spark SQL functions and data types, which we will use in our User-Defined Function (UDF), described later.

Converting our MLlib GBT feature vector to a Pandas dataframe

The SHAP explainer takes a dataframe as input. However, training an MLlib GBT model requires data preprocessing. More specifically, the categorical variables in our data need to be converted into numeric variables using either category indexing or one-hot encoding (to learn more about how to train a GBT model, refer to this article). The resulting “features” column is a SparseVector (to read more about it, check the “Preprocess Data” section in this example). It looks something like this:

SparseVector features column description — 1. default index value, 2. vector length, 3. list of indexes of the feature columns, 4. list of data values at the corresponding index at 3. [Image by author]
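A SparseVector stores only its length plus the indices and values of its non-zero entries; every other position is implicitly 0.0. The pure-Python sketch below mimics what densifying such a vector does (in PySpark, SparseVector.toArray() returns the dense NumPy array) and pairs each position with a feature name. The feature names, the sample vector, and the helper name are hypothetical.

```python
def sparse_to_dense(size, indices, values):
    """Expand a (size, indices, values) sparse representation into a dense list."""
    dense = [0.0] * size  # positions not listed in `indices` are implicit zeros
    for i, v in zip(indices, values):
        dense[i] = v
    return dense


# Hypothetical feature names, in the order they were passed to the VectorAssembler
feature_cols = ["num_clicks", "income", "is_member", "num_visits"]

# A length-4 sparse vector with non-zeros at positions 1 and 3,
# which PySpark prints as (4,[1,3],[52000.0,7.0])
dense = sparse_to_dense(4, [1, 3], [52000.0, 7.0])
row_dict = dict(zip(feature_cols, dense))
```

The resulting dictionary is exactly one element of the list of dictionaries built in the next step.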

The “features” column shown above is for a single training instance. We need to transform this SparseVector for all our training instances. One way to do it is to iteratively process each row and append to our pandas dataframe that we will feed to our SHAP explainer (ouch!). There is a much faster way, which leverages the fact that we have all of our data loaded in memory (if not, we can load it in batches and perform the preprocessing for each in-memory batch). In Shikhar Dua’s words:

1. Create a list of dictionaries in which each dictionary corresponds to an input data row.

2. Create a data frame from this list.

So, based on the above method, we get something like this:

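import pandas as pd

# feature_cols (hypothetical name): the input column names, in the order passed to the VectorAssembler
rows_list = []
for row in spark_df.rdd.collect():  # brings every row to the driver
    dense_values = row.features.toArray()  # SparseVector -> dense NumPy array
    rows_list.append(dict(zip(feature_cols, dense_values)))  # step 1: one dict per row
pandas_df = pd.DataFrame(rows_list)  # step 2: build the dataframe once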

If rdd.collect() looks scary, it’s actually pretty simple to explain. Resilient Distributed Datasets (RDDs) are Spark’s fundamental data structure: immutable, distributed collections of objects. Each RDD is divided into logical partitions that can be computed on different worker nodes of our Spark cluster, and all collect() does is retrieve the data from every worker node and bring it to the driver node. As you might guess, this is a memory bottleneck. If we are handling data larger than our driver node’s memory capacity, we need to increase the number of RDD partitions and filter them by partition index, collecting one subset at a time. Read how to do that here.
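The partition-by-partition idea can be simulated without a cluster. In the pure-Python sketch below, rows are spread round-robin across a chosen number of partitions, and the "driver" loop holds only one partition's rows at a time while still processing the whole dataset. The helper names are hypothetical; in PySpark, the corresponding building blocks are rdd.repartition(), rdd.mapPartitionsWithIndex() (to keep only the partition you want), and rdd.toLocalIterator(), which streams an RDD to the driver one partition at a time.

```python
def partition_rows(rows, num_partitions):
    """Assign rows to partitions round-robin, like a simple repartition."""
    partitions = [[] for _ in range(num_partitions)]
    for i, row in enumerate(rows):
        partitions[i % num_partitions].append(row)
    return partitions


def collect_partition(partitions, index):
    """Return only the rows of one partition, like filtering by partition index."""
    return list(partitions[index])


def process_in_batches(rows, num_partitions, process):
    """Bring one partition at a time to the 'driver' and process it there."""
    # The partitions live in one list here only because we are simulating the cluster
    partitions = partition_rows(rows, num_partitions)
    results, peak_batch = [], 0
    for index in range(num_partitions):
        batch = collect_partition(partitions, index)
        peak_batch = max(peak_batch, len(batch))  # most rows held by the driver at once
        results.extend(process(batch))
    return results, peak_batch


rows = list(range(10))
results, peak = process_in_batches(rows, 4, lambda batch: [r * r for r in batch])
```

With 10 rows and 4 partitions, the driver never holds more than 3 rows at once, yet every row is processed. Raising the partition count shrinks that peak further, at the cost of more round trips.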

Don’t take my word for the execution performance. Check out the stats.

Performance profiling for inserting rows to a pandas dataframe. [Source (Thanks to Mikhail_Sam and Peter Mortensen): here]
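Why is appending so slow? Appending to a dataframe typically builds a new object and copies the existing rows each time, so n appends cost on the order of n²/2 row copies, while the list-of-dictionaries recipe builds the table once. The stdlib sketch below is a simplified model of both approaches and does not use pandas: records_to_columns mimics how a list of dictionaries becomes columns (missing keys are filled with None where pandas would use NaN), and append_row_by_row counts the copies made by repeated appends. Both function names are hypothetical.

```python
def records_to_columns(records):
    """Build a column-oriented table from a list of row dicts in a single pass per column."""
    columns = {}
    for record in records:
        for key in record:
            columns.setdefault(key, None)  # dicts keep first-seen key order
    return {col: [record.get(col) for record in records] for col in columns}


def append_row_by_row(records):
    """Simulate repeated appends: each one copies the whole existing table."""
    table, rows_copied = [], 0
    for record in records:
        rows_copied += len(table)  # cost of copying the old table into a new one
        table = table + [record]   # a new list, as a non-in-place append would create
    return table, rows_copied


rows_list = [{"a": 1.0, "b": 2.0}, {"a": 3.0}, {"b": 4.0, "c": 5.0}]
table = records_to_columns(rows_list)
# table == {"a": [1.0, 3.0, None], "b": [2.0, None, 4.0], "c": [None, None, 5.0]}
_, rows_copied = append_row_by_row([{"a": float(i)} for i in range(1000)])
# rows_copied == 499500, i.e. n * (n - 1) / 2 row copies for n = 1000
```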

Here are the metrics from one of my Databricks notebook scheduled job runs:

Input size: 11.9 GiB (~12.78GB), Total time Across All Tasks: 20 min, Number of records: 165.16K

Summary Metrics for 125 Completed Tasks executed by the stage that ran the above cell. [Image by author]
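As a quick sanity check on the units above: a GiB is 2^30 bytes while a GB is 10^9 bytes. Spreading the reported totals over the 125 tasks also gives rough per-task averages; these averages are derived here from the reported numbers, not reported by Spark, and 165.16K is treated as 165,160 records.

```python
GIB = 2 ** 30  # bytes in a gibibyte (binary unit)
GB = 10 ** 9   # bytes in a gigabyte (decimal unit)

input_gb = 11.9 * GIB / GB            # the reported 11.9 GiB expressed in decimal gigabytes
seconds_per_task = 20 * 60 / 125      # 20 min of total task time spread over 125 tasks
records_per_task = 165_160 / 125      # ~165.16K records spread over 125 tasks
```

This confirms the ~12.78 GB figure and works out to roughly 9.6 seconds and about 1,321 records per task.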

Working with the SHAP Library

We are now ready to pass our preprocessed dataset to the SHAP TreeExplainer. Remember that SHAP is a local feature attribution method: it explains each individual prediction as the sum of a base value (the model’s expected output) and the Shapley values of our model’s features.
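To see what that sum means, here is a brute-force sketch for a tiny, hypothetical three-feature model. As a simplifying assumption, a feature that is "missing" from a coalition is replaced by a fixed baseline value; TreeExplainer instead uses the tree structure, as described next. The Shapley value of feature i is a weighted average of its marginal contribution over every coalition of the other features, and the values add up to f(x) - f(baseline). All function names are illustrative.

```python
from itertools import combinations
from math import factorial


def model(x):
    """Hypothetical model with an interaction term between features 0 and 1."""
    return 2.0 * x[0] + 1.0 * x[1] + 3.0 * x[0] * x[1] - 0.5 * x[2]


def coalition_value(f, x, baseline, coalition):
    """Evaluate f with the features outside the coalition set to their baseline values."""
    z = [x[i] if i in coalition else baseline[i] for i in range(len(x))]
    return f(z)


def exact_shapley(f, x, baseline):
    """Enumerate every coalition: exponential in the number of features."""
    m = len(x)
    phi = [0.0] * m
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            # Shapley weight |S|! (m - |S| - 1)! / m!
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for subset in combinations(others, size):
                s = set(subset)
                gain = (coalition_value(f, x, baseline, s | {i})
                        - coalition_value(f, x, baseline, s))
                phi[i] += weight * gain
    return phi


x = [1.0, 2.0, 4.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(model, x, baseline)  # approximately [5.0, 5.0, -2.0]
base_value = model(baseline)
```

The interaction term 3·x0·x1 contributes 6 to f(x) and is split evenly between features 0 and 1, and base_value + sum(phi) reproduces f(x) = 8. That identity is the additivity property discussed below.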

We use a TreeExplainer for the following reasons:

  1. Suitable: TreeExplainer is a class that computes SHAP values for tree-based models (Random Forest, XGBoost, LightGBM, GBT, etc.).
  2. Exact: Instead of simulating missing features by random sampling, it makes use of the tree structure by simply ignoring decision paths that rely on the missing features. The TreeExplainer output is therefore deterministic and does not vary based on the background dataset.
  3. Efficient: Instead of iterating over each possible feature combination (or a subset thereof), it pushes all combinations through the tree simultaneously, using a more complex algorithm to keep track of each combination’s result. This reduces the complexity from O(TL2ᵐ) for all possible coalitions to the polynomial O(TLD²), where m is the number of features, T is the number of trees, L is the maximum number of leaves, and D is the maximum tree depth.
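The "ignore decision paths that rely on missing features" idea can be sketched for a single tree. Each node records its cover, the number of training samples that reached it. When a node's split feature is in the coalition, we follow the branch that x actually takes; when it is missing, we take the cover-weighted average of both children. This follows the recursive expected-value procedure described in the TreeSHAP paper. The tree, its covers, and the helper name are hypothetical, and the sketch evaluates one coalition at a time rather than implementing the polynomial-time TreeSHAP algorithm itself.

```python
# A tiny hypothetical regression tree stored as nested dicts.
# Internal nodes hold a split; "cover" counts the training samples reaching each node.
TREE = {
    "feature": 0, "threshold": 0.5, "cover": 100,
    "left": {"value": 1.0, "cover": 60},
    "right": {
        "feature": 1, "threshold": 0.5, "cover": 40,
        "left": {"value": 3.0, "cover": 10},
        "right": {"value": 5.0, "cover": 30},
    },
}


def expected_value(node, x, coalition):
    """Tree-path-dependent estimate of E[f(x) | features in coalition are known]."""
    if "value" in node:  # leaf
        return node["value"]
    left, right = node["left"], node["right"]
    if node["feature"] in coalition:
        # Feature is known: follow the branch x actually takes
        child = left if x[node["feature"]] <= node["threshold"] else right
        return expected_value(child, x, coalition)
    # Feature is missing: average both branches, weighted by training cover
    return (left["cover"] * expected_value(left, x, coalition)
            + right["cover"] * expected_value(right, x, coalition)) / node["cover"]


x = [1.0, 0.0]
v_empty = expected_value(TREE, x, set())    # cover-weighted mean of all leaves
v_f0 = expected_value(TREE, x, {0})         # only feature 0 known
v_all = expected_value(TREE, x, {0, 1})     # the leaf x actually lands in
```

No background dataset or random sampling is involved: the result depends only on x and the training covers stored in the tree, which is why the explanations are deterministic. Enumerating coalitions over this value function would give exact Shapley values at exponential cost; TreeSHAP obtains the same values in polynomial time by tracking all coalitions in a single pass down the tree.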

The check_additivity flag controls a validation check that verifies whether the base value plus the SHAP values sums to the model’s output. However, this check requires running predictions with the model, which SHAP does not support for Spark models, so the check is skipped anyway; passing check_additivity=False simply disables it explicitly. Once we get the SHAP values, we convert them from a NumPy array into a pandas dataframe so that they are easy to interpret.

One thing to note is that the dataset order is preserved when we convert a Spark dataframe to pandas, but the reverse is not true.

The points above lead us to the code snippet below:

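gbt = GBTClassificationModel.load('your-model-path')  # fitted MLlib GBT model
explainer = shap.TreeExplainer(gbt)
# SHAP cannot run Spark predictions, so disable the additivity check explicitly
shap_values = explainer(pandas_df, check_additivity=False)
# One row per instance and one column per feature, in pandas_df's column order
shap_pandas_df = pd.DataFrame(shap_values.values, columns=pandas_df.columns)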

An Introduction to PySpark UDFs and when to use them

How PySpark UDFs distribute individual tasks to worker (executor) nodes [Source: here]

User-Defined Functions (UDFs) are custom functions that operate on one row of our dataset at a time. They are generally used when Spark’s native functions are not sufficient to solve the problem. Native Spark functions are inherently faster than UDFs because they run inside the JVM, where their methods are implemented by local calls to Java APIs. PySpark UDFs, by contrast, are Python implementations that require data to be moved between the JVM and the Python interpreter (refer to arrow 4 in the picture above). This inevitably introduces some processing delay.

If this delay cannot be tolerated, the best option is to implement the UDF in Scala and call it from PySpark through a Python wrapper. A great example is shown in this blog. However, a PySpark UDF was sufficient for my use case, since it is easy to understand and code.

The code below defines the Python function to be executed on each worker/executor node. It ranks a row’s features by the absolute value of their SHAP values (absolute, because we want to find the most impactful negative features as well), keeps those that meet a threshold, and splits them into the pos_features and neg_features lists. It then returns the top five entries of each list, wrapped in an outer list, to the caller.

SHAP_THRESHOLD = 0.0  # placeholder: replace with <your-threshold-shap-value>

def shap_udf(row):
    # row is a Row holding one instance's SHAP values, keyed by feature name
    shap_by_feature = row.asDict()
    # Rank by absolute SHAP value so strong negative contributors rank high too
    # (builtin abs is safe because pyspark.sql.functions is imported as F, not with *)
    ranked = sorted(shap_by_feature.items(), key=lambda item: abs(item[1]), reverse=True)
    pos_features = []
    neg_features = []
    for k, v in ranked:
        if abs(v) >= SHAP_THRESHOLD:
            if v > 0:
                pos_features.append((k, v))
            else:
                neg_features.append((k, v))
    # Return [top-5 positive, top-5 negative] lists of (feature, shap_value) pairs
    return [pos_features[:5], neg_features[:5]]
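Since only the top five features of each sign are kept, a variant of the same logic can select them with heapq.nlargest instead of sorting every feature. The sketch below is an alternative, not the author’s implementation; top_shap_features is a hypothetical name, and it takes a plain dictionary so that the ranking logic can be unit-tested locally without a Spark session.

```python
import heapq


def top_shap_features(shap_by_feature, threshold=0.0, k=5):
    """Return [top-k positive, top-k negative] (feature, value) pairs ranked by |SHAP value|."""
    candidates = [(f, v) for f, v in shap_by_feature.items() if abs(v) >= threshold]
    positives = [(f, v) for f, v in candidates if v > 0]
    negatives = [(f, v) for f, v in candidates if v <= 0]  # zero goes here, as in shap_udf
    # nlargest(k, ...) matches sorted(..., reverse=True)[:k], ties keep input order
    return [heapq.nlargest(k, positives, key=lambda item: abs(item[1])),
            heapq.nlargest(k, negatives, key=lambda item: abs(item[1]))]


shap_row = {"a": 0.30, "b": -0.50, "c": 0.05, "d": -0.01, "e": 0.20}
pos, neg = top_shap_features(shap_row, threshold=0.02, k=2)
```

Here "d" is dropped by the threshold, the two largest positive contributions are kept, and "b" is the only negative feature left. Selecting k items this way avoids a full sort, which only matters when a model has many features.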

We then register our PySpark UDF with our Python function (in my case, shap_udf) and specify the function’s return type in the parameters to F.udf(). This matters here because PySpark assumes StringType() when no return type is given (and in Java the return type is mandatory). The outer ArrayType() holds two lists, one for positive features and the other for negative ones. Each of those lists contains at most five (feature-name, shap-value) pairs stored as StructType() entries, which is what the inner ArrayType() represents. Below is the code:

# Outer array: [positives, negatives]; inner arrays: (Feature, Shap_Value) structs
udf_obj = F.udf(shap_udf, ArrayType(ArrayType(StructType([
    StructField('Feature', StringType()),
    StructField('Shap_Value', FloatType()),
]))))

Now, we create a new Spark dataframe with a column called ‘Shap_Importance’ that invokes our UDF on each row of spark_shapdf, the Spark version of our SHAP values dataframe. To split the positive and negative features, we then add two columns in a new Spark dataframe called final_sparkdf. Our final code snippet looks like this:

spark_shapdf = spark.createDataFrame(shap_pandas_df)  # assumes an active SparkSession named spark
new_sparkdf = spark_shapdf.withColumn('Shap_Importance', udf_obj(F.struct([spark_shapdf[x] for x in spark_shapdf.columns])))
final_sparkdf = new_sparkdf.withColumn('Positive_Shap', new_sparkdf.Shap_Importance[0]) \
    .withColumn('Negative_Shap', new_sparkdf.Shap_Importance[1])

And finally, we have extracted the most important features of our GBT model for every test instance, with the per-row ranking running in parallel on the executors instead of in a loop on the driver! The consolidated code can be found in the GitHub gist below.

Get the most impactful Positive and Negative SHAP values from our fitted GBT Model

P.S. This is my first attempt at writing an article and if there are any factual or statistical inconsistencies, please reach out to me and I shall be more than happy to learn together with you! :)

References

[1] Soner Yıldırım, Gradient Boosted Decision Trees-Explained (2020), Towards Data Science

[2] Susan Li, Machine Learning with PySpark and MLlib — Solving a Binary Classification Problem (2018), Towards Data Science

[3] Stephen Offer, How to Train XGBoost With Spark (2020), Data Science and ML

[4] Use Apache Spark MLlib on Databricks (2021), Databricks

[5] Umberto Griffo, Don’t collect large RDDs (2020), Apache Spark — Best Practices and Tuning

[6] Nikhilesh Nukala, Yuhao Zhu, Guilherme Braccialli, Tom Goldenberg, Spark UDF — Deep Insights in Performance (2019), QuantumBlack

