Parallelize your massive SHAP computations with MLlib and PySpark

https://medium.com/towards-data-science/parallelize-your-massive-shap-computations-with-mllib-and-pyspark-b00accc8667c

(能翻墙直接看原文)

A stepwise guide for efficiently explaining your models using SHAP.

Photo by Pietro Jeng on Unsplash

Introduction to MLlib

Apache Spark’s Machine Learning Library (MLlib) is designed primarily for scalability and speed by leveraging the Spark runtime for common distributed use cases in supervised learning like classification and regression, unsupervised learning like clustering and collaborative filtering and in other cases like dimensionality reduction. In this article, I cover how we can use SHAP to explain a Gradient Boosted Trees (GBT) model that has fit our data at scale.

What are Gradient Boosted Trees?

Before we understand what Gradient Boosted Trees are, we need to understand boosting. Boosting is an ensemble technique that sequentially combines a number of weak learners to achieve an overall strong learner. In case of Gradient Boosted Trees, each weak learner is a decision tree that sequentially minimizes the errors (MSE in case of regression and log loss in case of classification) generated by the previous decision tree in that sequence. To read about GBTs in more detail, please refer to this blog post.

Understanding our imports

from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
from pyspark.ml.classification import GBTClassificationModel
import shap
import pyspark.sql.functions as F
from pyspark.sql.types import *

The first two imports are for initializing a Spark session. It will be used for converting our pandas dataframe to a spark one. The third import is used to load our GBT model into memory which will be passed to our SHAP explainer to generate explanations. The SHAP explainer itself will be initialized using the SHAP package using the fourth import. The penultimate and last import is for performing SQL functions and using SQL types. These will be used in our User-Defined Function (UDF) which I shall describe later.

Converting our MLlib GBT feature vector to a Pandas dataframe

The SHAP Explainer takes a dataframe as input. However, training an MLlib GBT model requires data preprocessing. More specifically, the categorical variables in our data needs to be converted into numeric variables using either Category Indexing or One-Hot Encoding. To learn more about how to train a GBT model, refer to this article). The resulting “features” column is a SparseVector (to read more on it, check the “Preprocess Data” section in this example). It looks like something below:

SparseVector features column description — 1. default index value, 2. vector length, 3. list of indexes of the feature columns, 4. list of data values at the corresponding index at 3. [Image by author]

The “features” column shown above is for a single training instance. We need to transform this SparseVector for all our training instances. One way to do it is to iteratively process each row and append to our pandas dataframe that we will feed to our SHAP explainer (ouch!). There is a much faster way, which leverages the fact that we have all of our data loaded in memory (if not, we can load it in batches and perform the preprocessing for each in-memory batch). In Shikhar Dua’s words:

1. Create a list of dictionaries in which each dictionary corresponds to an input data row.

2. Create a data frame from this list.

So, based on the above method, we get something like this:

rows_list = []
for row in spark_df.rdd.collect(): 
    dict1 = {} 
    dict1.update({k:v for k,v in zip(spark_df.cols,row.features)})
    rows_list.append(dict1) 
pandas_df = pd.DataFrame(rows_list)

If rdd.collect() looks scary, it’s actually pretty simple to explain. Resilient Distributed Datasets (RDD) are fundamental Spark data structures that are an immutable distribution of objects. Each dataset in an RDD is further subdivided into logical partitions that can be computed in different worker nodes of our Spark cluster. So, all PySpark RDD collect() does is retrieve data from all the worker nodes to the driver node. As you might guess, this is a memory bottleneck, and if we are handling data larger than our driver node’s memory capacity, we need to increase the number of our RDD partitions and filter them by partition index. Read how to do that here.

Don’t take my word on the execution performance. Check out the stats.

Performance profiling for inserting rows to a pandas dataframe. [Source (Thanks to Mikhail_Sam and Peter Mortensen): here]

Here are the metrics from one of my Databricks notebook scheduled job runs:

Input size: 11.9 GiB (~12.78GB), Total time Across All Tasks: 20 min, Number of records: 165.16K

Summary Metrics for 125 Completed Tasks executed by the stage that run the above cell. [Image by author]

Working with the SHAP Library

We are now ready to pass our preprocessed dataset to the SHAP TreeExplainer. Remember that SHAP is a local feature attribution method that explains individual predictions as an algebraic sum of the shapley values of the features of our model.

We use a TreeExplainer for the following reasons:

  1. Suitable: TreeExplainer is a class that computes SHAP values for tree-based models (Random Forest, XGBoost, LightGBM, GBT, etc).
  2. Exact: Instead of simulating missing features by random sampling, it makes use of the tree structure by simply ignoring decision paths that rely on the missing features. The TreeExplainer output is therefore deterministic and does not vary based on the background dataset.
  3. Efficient: Instead of iterating over each possible feature combination (or a subset thereof), all combinations are pushed through the tree simultaneously, using a more complex algorithm to keep track of each combination’s result — reducing complexity from O(TL2ᵐ) for all possible coalitions to the polynomial O(TLD²) (where is the number of features, is number of trees, is maximum number of leaves and is maximum tree depth).

The check_additivity = False flag runs a validation check to verify if the sum of SHAP values equals to the output of the model. However, this flag requires predictions to be run that are not supported by Spark, so it needs to be set to False as it is ignored anyway. Once we get the SHAP values, we convert it into a pandas dataframe from a Numpy array, so that it is easily interpretable.

One thing to note is that the dataset order is preserved when we convert a Spark dataframe to pandas, but the reverse is not true.

The points above lead us to the code snippet below:

gbt = GBTClassificationModel.load('your-model-path') 
explainer = shap.TreeExplainer(gbt)
shap_values = explainer(pandas_df, check_additivity = False)
shap_pandas_df = pd.DataFrame(shap_values.values, cols = pandas_df.columns)

An Introduction to Pyspark UDFs and when to use them

How PySpark UDFs distribute individual tasks to worker (executor) nodes [Source: here]

User-Defined Functions are complex custom functions that operate on a particular row of our dataset. These functions are generally used when the native Spark functions are not deemed sufficient to solve the problem. Spark functions are inherently faster than UDFs because it is natively a JVM structure whose methods are implemented by local calls to Java APIs. However, PySpark UDFs are Python implementations that requires data movement between the Python interpreter and the JVM (refer to Arrow 4 in the picture above). This inevitably introduces some processing delay.

If no processing delays can be tolerated, the best thing to do is create a Python wrapper to call the Scala UDF from PySpark itself. A great example is shown in this blog. However, using a PySpark UDF was sufficient for my use case, since it is easy to understand and code.

The code below explains the Python function to be executed on each worker/executor node. We just pick up the highest SHAP values (absolute values as we want to find the most impactful negative features as well) and append it to the respective pos_features and neg_features list and in turn append both these lists to a features list that is returned to the caller.

def shap_udf(row):
    dict = {} 
    pos_features = [] 
    neg_features = [] 
    for feature in row.columns: 
        dict[feature] = row[feature]     dict_importance = {key: value for key, value in
    sorted(dict.items(), key=lambda item: __builtin__.abs(item[1]),   
    reverse = True)}     for k,v in dict_importance.items(): 
        if __builtin__.abs(v) >= <your-threshold-shap-value>: 
             if v > 0: 
                 pos_features.append((k,v)) 
             else: 
                 neg_features.append((k,v)) 
   features = [] 
   features.append(pos_features[:5]) 
   features.append(neg_features[:5])    return features

We then register our PySpark UDF with our Python function name (in my case, it is shap_udf) and specify the return type (mandatory in Python and Java) of the function in the parameters to F.udf(). There are two lists in the outer ArrayType(), one for positive features and the other for negative ones. Since each individual list comprises of at most 5 (feature-name, shap-value) StructType() pairs, it represents the inner ArrayType(). Below is the code:

udf_obj = F.udf(shap_udf, ArrayType(ArrayType(StructType([ StructField(‘Feature’, StringType()), 
StructField(‘Shap_Value’, FloatType()),
]))))

Now, we just create a new Spark dataframe with a column called ‘Shap_Importance’ that invokes our UDF for each row of the spark_shapdf dataframe. To split the positive and negative features, we create two columns in a new Spark dataframe called final_sparkdf. Our final code-snippet looks like below:

new_sparkdf = spark_df.withColumn(‘Shap_Importance’, udf_obj(F.struct([spark_shapdf[x] for x in spark_shapdf.columns])))final_sparkdf = new_sparkdf.withColumn(‘Positive_Shap’, final_sparkdf.Shap_Importance[0]).withColumn(‘Negative_Shap’, new_sparkdf.Shap_Importance[1])

And finally, we have extracted all the important features of our GBT model per testing instance without the use of any explicit for loops! The consolidated code can be found in the below GitHub gist.

from pyspark.sql import SparkSession
from pyspark import SparkContext, SparkConf
from pyspark.ml.classification import GBTClassificationModel
import shap
import pyspark.sql.functions as  F
from pyspark.sql.types import *

#convert the sparse feature vector that is passed to the MLlib GBT model into a pandas dataframe. 
#This 'pandas_df' will be passed to the Shap TreeExplainer.
rows_list = []
for row in spark_df.rdd.collect(): 
  dict1 = {}
  dict1.update({k:v for k,v in zip(spark_df.cols,row.features)})
  rows_list.append(dict1)
 
pandas_df = pd.DataFrame(rows_list)

#Load the GBT model from the path you have saved it
gbt = GBTClassificationModel.load("<your path where the GBT model is loaded>") 
#make sure the application where your notebook runs has access to the storage path!

explainer = shap.TreeExplainer(gbt)
#check_additivity requires predictions to be run that is not supported by spark [yet], so it needs to be set to False as it is ignored anyway.
shap_values = explainer(pandas_df, check_additivity = False)
shap_pandas_df = pd.DataFrame(shap_values.values, cols = pandas_df.columns)

spark = SparkSession.builder.config(conf=SparkConf().set("spark.master", "local[*]")).getOrCreate()
spark_shapdf = spark.createDataFrame(shap_pandas_df)


def shap_udf(row): #work on a single spark dataframe row, for all rows. This work is distributed among all the worker nodes of your Apache Spark cluster.
  dict = {}
  pos_features = []
  neg_features = []

  for feature in row.columns:
    dict[feature] = row[feature]

  dict_importance = {key: value for key, value in sorted(dict.items(), key=lambda item: __builtin__.abs(item[1]), reverse = True)}

  for k,v in dict_importance.items():
    if __builtin__.abs(v) >= <your-threshold-shap-value>:
      if v > 0:
        pos_features.append((k,v))
      else:
        neg_features.append((k,v))
  features = []
  #taking top 5 features from pos and neg features. We can increase this number.
  features.append(pos_features[:5])
  features.append(neg_features[:5])

  return features


udf_obj = F.udf(shap_udf, ArrayType(ArrayType(StructType([
  StructField('Feature', StringType()),
  StructField('Shap_Value', FloatType()),
]))))

new_sparkdf = spark_df.withColumn('Shap_Importance', udf_obj(F.struct([spark_shapdf[x] for x in spark_shapdf.columns])))
final_sparkdf = new_sparkdf.withColumn('Positive_Shap', final_sparkdf.Shap_Importance[0]).withColumn('Negative_Shap', new_sparkdf.Shap_Importance[1])

Get the most impactful Positive and Negative SHAP values from our fitted GBT Model

P.S. This is my first attempt at writing an article and if there are any factual or statistical inconsistencies, please reach out to me and I shall be more than happy to learn together with you! :)

References

[1] Soner Yıldırım, Gradient Boosted Decision Trees-Explained (2020), Towards Data Science

[2] Susan Li, Machine Learning with PySpark and MLlib — Solving a Binary Classification Problem (2018), Towards Data Science

[3] Stephen Offer, How to Train XGBoost With Spark (2020), Data Science and ML

[4] Use Apache Spark MLlib on Databricks (2021), Databricks

[5] Umberto Griffo, Don’t collect large RDDs (2020), Apache Spark — Best Practices and Tuning

[6] Nikhilesh Nukala, Yuhao Zhu, Guilherme Braccialli, Tom Goldenberg (2019), Spark UDF — Deep Insights in Performance, QuantumBlack

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/713761.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

前端:鼠标点击实现高亮特效

一、实现思路 获取鼠标点击位置 通过鼠标点击位置设置高亮裁剪动画 二、效果展示 三、按钮组件代码 <template><buttonclass"blueBut"click"clickHandler":style"{backgroundColor: clickBut ? rgb(31, 67, 117) : rgb(128, 128, 128),…

docker login 报错: http: server gave HTTP response to HTTPS client

环境&#xff1a; 自建 Harbor、Docker 1. 问题分析 # 命令&#xff0c;这里用的是 IP&#xff0c;可以为域名 docker login -u test 172.16.51.182:31120 # 输入密码 Password:# 报错如下&#xff1a; Error response from daemon: Get "https://172.16.51.182:31120/…

[DDR4] DDR 简史

依公知及经验整理&#xff0c;原创保护&#xff0c;禁止转载。 专栏 《深入理解DDR4》 存和硬盘&#xff0c;这对电脑的左膀右臂&#xff0c;共同扛起了存储的重任。内存以其超凡的存取速度闻名&#xff0c;但一旦断电&#xff0c;内存中的数据也会消失。它就像我们的工作桌面&…

基于WPF技术的换热站智能监控系统14--搭建西门子PLC通信环境

1、安装博途软件V15 本项目需要用到西门子PLC&#xff0c;系统所需的数据来自现场PLC实时采集的数据&#xff0c;所以需要配置PLC的通信环境&#xff0c;具体请看以下博客文章。 windows10企业版安装西门子博途V15---01准备环境_博途v15.1安装需求-CSDN博客 windows10企业…

5.Sentinel入门与使用

5.Sentinel入门与使用 1.什么是 Sentinel?Sentinel 主要有以下几个功能: 2.为什么需要 Sentinel?3.Sentinel 基本概念3.1 什么是流量控制?3.1.1 常见流量控制算法3.1.2 Sentinel 流量控制流控效果介绍如下: 3.2 什么是熔断?熔断策略 3.3 Sentinel 组成&#xff08;资源和规…

[vue3]组件通信

自定义属性 父组件中给子组件绑定属性, 传递数据给子组件, 子组件通过props选项接收数据 props传递的数据, 在模版中可以直接使用{{ message }}, 在逻辑中使用props.message defineProps defineProps是编译器宏函数, 就是一个编译阶段的标识, 实际编译器解析时, 遇到后会进行…

【Oracle APEX开发小技巧1】转换类型实现显示小数点前的 0 以 及常见类型转换

在 apex 交互式式网格中&#xff0c;有一数值类型为 NUMBER&#xff0c;保留小数点后两位的项&#xff0c;在 展示时小数点前的 0 不显示。 效果如下&#xff1a; 转换前&#xff1a; m.WEIGHT_COEFFICIENT 解决方案&#xff1a; 将 NUMBER&#xff08;20&#xff0c;2&#xf…

模拟电子技术基础(一)--本证半导体与杂质半导体

半导体分为两大类&#xff1a;本征半导体和杂质半导体。这两种类型的半导体在电子结构和电导特性上有显著的区别。 本征半导体&#xff08;Intrinsic Semiconductor&#xff09; 定义和组成&#xff1a;本征半导体是纯净的半导体&#xff0c;没有任何杂质原子。最常见的本征半…

2023年13个最适合销售电子书的WordPress主题

欢迎来到我们用于销售电子书和其他数字/可下载产品&#xff08;软件、应用程序、图标集、主题等&#xff09;的最佳WordPress主题的完整集合。 这些主题有内置的支付网关&#xff0c;可以通过 PayPal、信用卡等处理安全支付。&#xff08;易于配置&#xff01;&#xff09; 最…

Python轮子:Excel读写利器——openpyxl

原文链接&#xff1a;http://www.juzicode.com/python-module-openpyxl 在之前的xlwt和xlrd的文章中我们介绍了Excel访问的2个工具&#xff0c;它们分别只能对Excel文件进行写或者读&#xff0c;今天再介绍一个可以对Excel进行读和写的工具——openpyxl。需要注意的是openpyxl…

MFC工控项目实例之三theApp变量传递对话框参数

承接专栏《MFC工控项目实例之二主菜单制作》 用theApp变量传递对话框参数实时改变iPlotX坐标轴最小值、最大值。 1、新建IDD_SYS_DATA对话框&#xff0c;类名SYS_DATA。 三个编辑框IDC_EDIT1、IDC_EDIT2、IDC_EDIT3变量如图 2、SEAL_PRESSURE.h中添加代码 #include "re…

CleanMyMac X软件下载附加详细安装教程

​首先要介绍的是CleanMyMac X&#xff0c;这是一款极受欢迎的苹果电脑清理软件&#xff0c;它能够全面扫描你的电脑系统&#xff0c;清理无用的文件和垃圾&#xff0c;以释放硬盘空间&#xff0c;除了清理功能之外&#xff0c;CleanMyMac X 还可协助管理应用程序、优化性能、修…

python基础 002 - 2 常用数据类型

python的常用数据类型 int , 整型 1,2,3float ,小数&#xff0c;浮点类型1.2bool , boolean 布尔&#xff0c;真假。判断命题。True Flasestr &#xff0c;字符串 list , 列表 a []tuple, 元组 a ()dict , dictionary, 字典 a {}set , 集合 a {} 1 查看数据类型 typ…

某集团数字化转型蓝图规划项目案例(94页PPT)

案例介绍&#xff1a; 本集团数字化转型蓝图规划项目通过确定目标&#xff0c;如制定集团数字化转型的整体战略和规划&#xff0c;明确转型方向和目标。构建数字化业务体系&#xff0c;实现业务流程数字化、智能化。搭建数字化管理平台&#xff0c;提升集团内部的管理效率和决…

条件语句与循环结构

引言 条件语句和循环结构是C语言中构建程序逻辑的基本工具。它们允许程序根据条件执行不同的代码块和重复执行某些操作。本篇文章将详细介绍C语言中的条件语句和循环结构&#xff0c;包括if、else、switch语句&#xff0c;以及for、while、do-while循环的使用&#xff0c;帮助读…

IDEA快速入门03-代码头统一配置

三、代码规范配置 3.1 文件头和作者信息 配置入口&#xff1a;依次打开 File -> Settings -> Editor -> File and Code Templates。 Class /*** Copyright (C) 2020-${YEAR}, Glodon Digital Supplier & Purchaser BU.* * All Rights Reserved.*/ #if (${PACKA…

专业是软件工程,现在好迷茫,感觉什么都没有学到,该怎么办?

学习软件工程可能会遇到迷茫和困惑的时期&#xff0c;这很正常&#xff0c;尤其是在学习初期。这里有一些建议&#xff0c;或许可以帮助你找到方向&#xff1a; 明确目标&#xff1a;思考你学习软件工程的目的是什么&#xff0c;是为了将来从事软件开发工作&#xff0c;还是对编…

MyBatis 的多级缓存机制是怎么样运作的?

引言&#xff1a;上周三&#xff0c;小 X 去面试一家中厂&#xff0c;其中面试官问到 MyBatis 的多级缓存机制是怎么样运行的&#xff1f;这个问题可以好好准备一下&#xff0c;很多人可能只会用 MyBatisPlus&#xff0c;简单的多表联查 SQL 语句可能都写不出来&#xff0c;更别…

神经网络 torch.nn---nn.LSTM()

torch.nn - PyTorch中文文档 (pytorch-cn.readthedocs.io) LSTM — PyTorch 2.3 documentation LSTM层的作用 LSTM层:长短时记忆网络层&#xff0c;它的主要作用是对输入序列进行处理&#xff0c;对序列中的每个元素进行编码并保存它们的状态&#xff0c;以便后续的处理。 …

python-求分数序列和

[题目描述]&#xff1a; 输入&#xff1a; 输入一行一个正整数n(n≤30)。输出&#xff1a; 输出一行一个浮点数&#xff0c;表示分数序列前n 项的和&#xff0c;精确到小数点后4位。样例输入1 2 样例输出1 3.5000 来源/分类&#xff08;难度系数&#xff1a;一星&#xff09;…
最新文章