R与机器学习系列|15.可解释的机器学习算法(Interpretable Machine Learning)(下)

news2024/9/20 17:41:34

今天我们介绍可解释机器学习算法的最后一部分,基于XGBoost算法的SHAP值可视化。关于SHAP值其实我们之前的很多个推文中都介绍到,不论是R版本的还是Python版本的,亦不论是普通的分类问题还是生存数据模型的。在此推文中我们将基于XGBoost模型理解SHAP值的计算过程。此外,我们之前的SHAP可视化是基于别人封装好的函数。在今天的推文中,我们将学习如何使用ggplot2实现更加美观的SHAP值可视化。

生存数据机器学习算法模型的SHAP值可视化

R与机器学习系列|shapviz——机器学习“黑箱模型”SHAP值可视化

机器学习|SHAP value的另一种R可视化方式以及Python实现SHAP value可视化

机器学习|分享一篇25分临床预测模型文章,再次体现SHAP 值在机器学习中的重要性!

R学习|R复现机器学习算法XGBoost特征重要性解释——SHAP value

SHAP值在机器学习算法中的重要性主要体现在以下几个方面:

解释模型预测结果:SHAP值能够解释单个样本预测结果的贡献。它告诉我们每个特征对于某个特定预测结果的影响程度,从而帮助我们理解模型是如何基于输入特征做出预测的。

特征重要性评估:SHAP值可以用来评估特征的重要性。通过分析多个样本的SHAP值,我们可以得出哪些特征对于整体模型的性能影响最大,从而在特征选择、降维等任务中提供指导。

模型调试与验证:通过检查每个样本的SHAP值,可以帮助我们识别模型在某些特定预测上可能出现的问题。如果某个样本的预测与真实值相差较大,SHAP值可以揭示哪些特征导致了这种预测差异。

透明性和可信度:SHAP值的计算基于合理的博弈论原理,它们为模型的预测结果提供了一种可解释的解释。这可以增加模型的可信度,特别是在需要对模型决策做出解释的场景中。

特征交互分析:SHAP值不仅仅告诉我们单个特征的影响,还可以揭示不同特征之间的交互作用对预测结果的影响。这对于理解特征之间的复杂关系以及模型如何从这些关系中学习非常有帮助。

我们也可以看到SHAP值对模型的解释在高分的机器学习文献中出现的还是很频繁,如下面的两篇分别发表在EClinicalMedicine和 JAMA surgery上的文章。


Tsai, Shang-Feng et al. “Development and validation of an insulin resistance model for a population without diabetes mellitus and its clinical implication: a prospective cohort study.” EClinicalMedicine vol. 58 101934. 4 Apr. 2023, doi:10.1016/j.eclinm.2023.101934
Bertsimas, Dimitris et al. “Using Artificial Intelligence to Find the Optimal Margin Width in Hepatectomy for Colorectal Cancer Liver Metastases.” JAMA surgery vol. 157,8 (2022): e221819. doi:10.1001/jamasurg.2022.1819

1.1介绍

真实的 Shapley 值在理论上被认为是最优的;然而,真实SHAP值的计算会花费大量的时间。因此,iml包提供了近似 的Shapley 值计算方法。此外,Lundberg 和 Lee也开发了其他SHAP值的近似计算方法,虽然不是纯粹的模型无关方法,但也适用于基于树的模型,并且在大多数 XGBoost 算法实现中(包括 xgboost 包)完全可行。与 iml 的近似方法类似,这种基于树的 Shapley 值估计方法也是一种近似估计的方法,但其运行的时间远远要比iml包的计算时间短。为了演示,我们将使用第 12.5.2 节中使用的特征和最终创建的 XGBoost 模型。

1.2 SHAP计算

为了说明我们上面提到的问题,我们利用之前的数据再xgboost中拟合一个模型。xgboost算法的执行、参数调整及特征重要性解释在之前的章节中也有介绍。这里不过多介绍。首先,我们加载相关依赖包。

# Helper packages
library(tidyverse)    # for general data wrangling needs
# Modeling packages
library(gbm)      # for original implementation of regular and stochastic GBMs
library(h2o)      # for a java-based implementation of GBM variants
library(xgboost)  # for fitting extreme gradient boosting
library(rsample)# for data split
library(caret)# dummy funtion for categorical variables

然后我们加载需要用到的数据。需要注意的是,我们将一个变量处理为多分类变量,已说明独热编码在xgboost模型数据预处理中的应用。此外,如果这里直接将多分类变量处理为数值型变量,那么最后的SHAP图里面也不会看到该变量其他哑变量的信息。
此外,因为xgboost的输入特征文件格式为矩阵,如果这个时候不对多分类变量进行虚拟编码,那么直接转换为矩阵后数据维度便会出错。

data<-read.csv("diabetes.csv",header = T)
data%>%
  mutate(Pregnancies=case_when(
    Pregnancies<3~"A",
    Pregnancies>=3 &Pregnancies<=6~"B",
    Pregnancies>6~"C"
  ))->data
data$Pregnancies<-as.factor(data$Pregnancies)
# Stratified sampling with the rsample package
set.seed(123)
split <- initial_split(data, prop = 0.7, 
                       strata = "Outcome")
data_train  <- training(split)
data_test   <- testing(split)

data_train2=select(data_train, -Outcome)

独热编码

dmytr = dummyVars(" ~ .", data =data_train2, fullRank=T)
data_train3 = predict(dmytr, newdata =data_train2)

X <-data_train3
Y<- data_train[,ncol(data_train)]

此时的X为经过独热编码之后的特征矩阵。下面我们利用之前的超参数直接建立xgboost模型

# optimal parameter list
params <- list(
  eta = 0.01,
  max_depth = 3,
  min_child_weight = 3,
  subsample = 0.5,
  colsample_bytree = 0.5
)

# train final model
xgb.fit.final <- xgboost(
  params = params,
  data = X,
  label = Y,
  nrounds = 602,
  objective = "binary:logistic",
  verbose = 0
)

然后我们将特征重新由低到高进行标准化

feature_values <- X %>%
  as.data.frame() %>%
  mutate_all(scale) %>%
  gather(feature, feature_value) %>% 
  pull(feature_value)

然后我们计算特征的SHAP值以及SHAP重要性等参数

shap_df <- xgb.fit.final %>%
  predict(newdata = X, predcontrib = TRUE) %>%
  as.data.frame() %>%
  select(-BIAS) %>%
  gather(feature, shap_value) %>%
  mutate(feature_value = feature_values) %>%
  group_by(feature) %>%
  mutate(shap_importance = mean(abs(shap_value)))

1.3 SHAP可视化

现在,我们已经计算得到了这些特征的SHAP值,下面我们进行可视化。首先我们使用ggplot2进行可视化,严格的来说是基于ggplot2的蜂群图可视化。看过SHAP图后可以看到其实就是一个散点图,横坐标是SHAP值,纵坐标是每个特征,每个点代表一个观测值。此外,纵坐标按照SHAP值的重要性进行排序。

library(ggbeeswarm)
p1 <- ggplot(shap_df, aes(x = shap_value, y = reorder(feature, shap_importance))) +
  geom_quasirandom(groupOnX = FALSE, varwidth = TRUE, size =1, alpha = 0.8, aes(color = shap_value)) +
  scale_color_gradient(low = "#ffcd30", high = "#6600cd") +
  labs(x="SHAP value",y="")+
  theme_bw()+
  theme(axis.text = element_text(color = "black"),
        panel.border = element_rect(linewidth = 1))+
  geom_vline(xintercept = 0,linetype="dashed",color="grey",linewidth=1)

p1 
基于ggplot2的SHAP值可视化

从上图中我们可以看出患者血糖对结局影响最大,其次是年龄、BMI。

下面我们再根据SHAP重要性值做一个SHAP重要性图

p2 <- shap_df %>% 
  select(feature, shap_importance) %>%
  filter(row_number() == 1) %>%
  ggplot(aes(x = reorder(feature, shap_importance), y = shap_importance,fill=feature)) +
  geom_col(alpha=0.6) +
  coord_flip() +
  xlab(NULL) +
  ylab("mean(|SHAP value|)")+
  scale_fill_brewer(palette = "Set1")+
  theme_bw()+
  theme(legend.position = "",
        axis.text = element_text(color = "black"),
        panel.border = element_rect(linewidth = 1))
p2
SHAP重要性图

我们也可以把两个拼图展示

library(patchwork)
plot<-p1+p2&
  plot_layout(widths = c(2,1))
plot
SHAP值可视化及SHAP重要性排序

下面我们用之前封装好的SHAP.R函数看看效果

source("shap.R")
shap_result = shap.score.rank(xgb_model =xgb.fit.final, 
                              X_train =data_train3,
                              shap_approx = F)

#计算前10个特征的SHAP值
shap_long_hd = shap.prep(X_train =data_train3 , top_n =9)
#SHAP值可视化
shapR<-plot.shap.summary(data_long =shap_long_hd)
shapR

可以看到结果是一致的。
我们还可以利用这些信息来创建与PDPs(部分依赖图)相对应的另一种方法。基于Shapley值的依赖图将一个特征的Shapley值显示在y轴上,将该特征的值显示在x轴上。通过为数据集中的所有观察值绘制这些值,我们可以看到随着特征的值变化,其归因重要性如何变化。

shap_df %>% 
  filter(feature %in% c("BMI", "Glucose")) %>%
  ggplot(aes(x = feature_value, y = shap_value)) +
  geom_point(aes(color = shap_value)) +
  scale_colour_viridis_c(name = "Feature value\n(standardized)", option = "C") +
  facet_wrap(~ feature, scales = "free") +
  scale_y_continuous('Shapley value', labels = scales::comma) +
  xlab('Normalized feature value')+
  theme_bw()

我们可以看到BMI和血糖与SHAP值明显正相关,随着这两个特征值增大,SHAP值也逐渐增大,说明对结局的影响也增加。
终于,这个系列(有监督机器学习)更新到今天结束了。希望大家都有收获,下个系列我们再见!


图源于网络

参考来源:Bradley Boehmke & Brandon Greenwell R与机器学习



喜欢的朋友记得点赞、收藏、关注哦!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2130787.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

付费进群付费入群流量掘金入群系统九牧云版源码系统搭建

适用于各类资源类付费进群领取&#xff0c;私域类项目经营等 简洁大气直观。流量掘金类必备。 前端展示视频&#xff1a; https://pan.baidu.com/s/1lqyGCOrfmE4LDXb1cm-eDQ?pwdvnk6 https://yun.ktbf.xyz/s/by6jIzghpb 大致功能&#xff1a; 支持域名防红模式 支持对接…

QT+OSG+OSG-earth如何在窗口显示一个地球

1、环境配置 系统&#xff1a;windows10系统 QT:版本5.15.2 编译器&#xff1a;MSVC2019_64bit 编辑器&#xff1a;QT Creator OSG版本&#xff1a;3.7.0 64位 为MSVC环境下编译 osgQt:为第三方编译的库&#xff0c;OSG因为版本不同已经不提供osgQt的…

【一文就懂】计算机视觉期刊和会议缩写

下面IEEE相关的期刊及其缩写&#xff0c;并重新整理为期刊和会议两个部分。 期刊缩写 期刊全称缩写IEEE Transactions on Pattern Analysis and Machine IntelligenceIEEE Trans. Pattern Anal. Mach. Intell.IEEE Transactions on Image ProcessingIEEE Trans. Image Proce…

用于大数据分析的数据存储格式:Parquet、Avro 和 ORC 的性能和成本影响

高效的数据处理对于依赖大数据分析做出明智决策的企业和组织至关重要。显著影响数据处理性能的一个关键因素是数据的存储格式。本文探讨了不同存储格式&#xff08;特别是 Parquet、Avro 和 ORC&#xff09;对 Google Cloud Platform &#xff08;GCP&#xff09; 上大数据环境…

机器学习--支持向量机(SVM)

支持向量机(线性) S V M SVM SVM 引入 S V M SVM SVM 用于解决的问题也是 c l a s s i f i c a t i o n classification classification&#xff0c;这里 y ∈ { − 1 , 1 } y \in \{-1, 1\} y∈{−1,1} 比如说这样一个需要分类的训练数据&#xff1a; 我们可以有很多直线来…

Vue3 的 shallowRef 和 shallowReactive:优化性能

大家对 Vue3 的 ref 和 reactive 都很熟悉&#xff0c;那么对 shallowRef 和 shallowReactive 是否了解呢&#xff1f; 在编程和数据结构中&#xff0c;“shallow”&#xff08;浅层&#xff09;通常指对数据结构的最外层进行操作&#xff0c;而不递归地处理其内部或嵌套的数据…

Brave编译指南2024 Windows篇:安装Git(四)

1.引言 在编译Brave浏览器的过程中&#xff0c;Git是必不可少的工具之一。作为最流行的分布式版本控制系统&#xff0c;Git允许开发者高效地管理和协作开发源码。通过Git&#xff0c;您可以轻松获取、更新和提交Brave的源码版本&#xff0c;并跟踪所有更改记录。无论是独立开发…

大模型入门 ch 03:注意力机制

本文是github上的大模型教程LLMs-from-scratch的学习笔记&#xff0c;教程地址&#xff1a;教程链接 Chapter 3&#xff1a; Attention Mechanism 本文首先从固定参数的注意力机制说起&#xff0c;然后拓展到可以训练的注意力机制&#xff0c;然后加入掩码mask&#xff0c;最后…

基于 onsemi NCV78343 NCV78964的汽车矩阵式大灯方案

一、方案描述 大联大世平集团针对汽车矩阵大灯&#xff0c;推出 基于 onsemi NCV78343 & NCV78964的汽车矩阵式大灯方案。 开发板搭载的主要器件有 onsemi 的 Matrix Controller NCV78343、LED Driver NCV78964、Motor Driver NCV70517、以及 NXP 的 MCU S32K344。 二、开…

抖音微信超火国庆节国旗头像生成源码

源码介绍&#xff1a; 抖音微信超火国庆节国旗头像生成源码&#xff0c;静态页前端生成速度超快&#xff01;源码直接上传到服务器即可使用。 1、打开地址后点击上传->选一张你喜欢的头像->然后点右边箭头符合选款式->最后点保存头像->按照提示 2、保存到手机即…

开源多场景问答社区论坛Apache Answer本地部署并发布至公网使用

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

CCDO|数据跃动未来:首席数据官如何引领构建活数据引擎

在数字化浪潮汹涌澎湃的今天&#xff0c;数据已成为企业最宝贵的资产之一&#xff0c;它不仅记录着过去&#xff0c;更预示着未来的方向。随着大数据、人工智能、云计算等技术的飞速发展&#xff0c;数据的潜力被前所未有地激发&#xff0c;而首席数据官&#xff08;CDO&#x…

T4周:猴痘病识别

>- **&#x1f368; 本文为[&#x1f517;365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客** >- **&#x1f356; 原作者&#xff1a;[K同学啊](https://mtyjkh.blog.csdn.net/)** 1. 设置GPU 如果使用的是CPU可以忽略这步 …

Eclipse折叠if、else、try catch的{}

下载插件com.cb.eclipse.folding_1.0.6.jar。将插件放到eclipse的dropins文件夹中。修改配置&#xff0c;然后保存&#xff0c;重启Eclipse即可。

Flink快速上手

Flink快速上手 批处理Maven配置pom文件java编写wordcount代码 有界流处理无界流处理 批处理 Maven配置pom文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://ww…

《深度学习》深度学习 框架、流程解析、动态展示及推导

目录 一、深度学习 1、什么是深度学习 2、特点 3、神经网络构造 1&#xff09;单层神经元 • 推导 • 示例 2&#xff09;多层神经网络 3&#xff09;小结 4、感知器 神经网络的本质 5、多层感知器 6、动态图像示例 1&#xff09;一个神经元 相当于下列状态&…

通信原理:绪论

1、消息、信号与信息 消息&#xff1a; 通信系统要传输的对象&#xff0c;是具体的、物理上存在的东西。也是信息的载体。形式多种&#xff1a; 连续消息&#xff1a;语音、温度、活动图片.离散消息&#xff1a;数据、符号、文字. 信息&#xff1a; 消息中所蕴含的内容&…

proteus+51单片机+实验(LCD1620、定时器)

目录 1.LCD1602液晶显示屏 1.1基本概念 1.1.1LCD的简介 1.1.2LCD的显示原理 ​​​1.1.3LCD的硬件电路 1.1.4LCD的常见指令 1.1.5LCD的时序 ​​​​​​​1.2代码 1.2.1写命令和写数据操作 1.2.2初始化和测试代码 1. 3.3功能函数 1.3proteus代码 1.3.1器件代码 1.…

几种手段mfc140u.dll丢失的解决方法,了解mfc140u.dll

在使用Windows操作系统时&#xff0c;许多用户可能会遇到“找不到mfc140u.dll”或“mfc140u.dll未找到”的错误提示。这个错误通常是由于该文件丢失或损坏所致。本文将详细介绍mfc140u.dll文件的作用、丢失的原因及其解决方法&#xff0c;帮助您快速恢复系统的正常运行。 一、m…

无人机视角的道路损害数据集,2400张图像,包括纵向裂缝(LC)、横向裂缝(TC)、鳄鱼裂缝(AC)、斜裂(OC)、修补(RP)和坑洞(PH),共2.3GB

数据集名称 无人机视角的道路损害数据集 数据集描述 这是一个专注于道路损害检测的数据集&#xff0c;包含了从无人机视角拍摄的2400张高清图像&#xff0c;涵盖了六种典型的道路损害类型&#xff1a;纵向裂缝&#xff08;LC&#xff09;、横向裂缝&#xff08;TC&#xff0…