SHAP 依赖图

news2024/10/21 6:59:52

SHAP 依赖图

SHAP 依赖图用于可视化单个特征对机器学习模型预测结果的影响,具体来说,x 轴是特征值,y 轴是 SHAP 值(度量特征对预测结果的重要性),这些图可以直观地显示出某个特征是对模型预测起正向还是负向作用

代码实现

数据集加载

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
plt.rcParams['font.family'] = 'Times New Roman'
plt.rcParams['axes.unicode_minus'] = False
import warnings
warnings.filterwarnings("ignore")
df = pd.read_csv('UCI Heart Disease Dataset.csv')
# 划分特征和目标变量
X = df.drop(['target'], axis=1)
y = df['target']
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, 
                                                    random_state=42, stratify=df['target'])
df.head()

        首先,需要加载数据集并将其划分为特征 X 和目标变量 y,然后进行训练集和测试集的划分。目标变量是我们要预测的值,X 是输入的特征,这是一个分类任务,目标是预测患者是否患有心脏病。虽然是分类任务,但无论是分类问题还是回归问题,SHAP 依赖图的使用方式和原理是相同的,都可以用来解释模型中各个特征对预测结果的贡献

训练机器学习模型

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import GridSearchCV

# GBT模型参数
params_gbt = {
    'learning_rate': 0.02,            # 学习率,控制每一步的步长,用于防止过拟合。典型值范围:0.01 - 0.1
    'max_depth': 3,                   # 树的深度,控制模型复杂度
    'random_state': 42,               # 随机种子,用于重现模型的结果
    'subsample': 0.7,                 # 每次迭代时随机选择的样本比例,用于增加模型的泛化能力
}

# 初始化Gradient Boosting分类模型
model_gbt = GradientBoostingClassifier(**params_gbt)

# 定义参数网格,用于网格搜索
param_grid = {
    'n_estimators': [100, 200, 300],  # 树的数量
    'max_depth': [3, 4, 5],               # 树的深度
    'learning_rate': [0.01, 0.1],   # 学习率
}

# 使用GridSearchCV进行网格搜索和k折交叉验证
grid_search = GridSearchCV(
    estimator=model_gbt,
    param_grid=param_grid,
    scoring='neg_log_loss',  # 评价指标为负对数损失
    cv=5,                    # 5折交叉验证
    n_jobs=-1,               # 并行计算
    verbose=1                # 输出详细进度信息
)

# 训练模型
grid_search.fit(X_train, y_train)

# 使用最优参数训练模型
best_model = grid_search.best_estimator_

        这里使用了梯度提升树(GBT),这是一个强大且常用的机器学习算法,通过网格搜索进行参数优化

计算 SHAP 值

import shap
explainer = shap.TreeExplainer(best_model)
# 计算shap值为numpy.array数组
shap_values_numpy = explainer.shap_values(X)
# 计算shap值为Explanation格式
shap_values_Explanation = explainer(X)

       

 模型训练完毕后,可以使用 shap 包来计算 SHAP 值,SHAP 值用于衡量特定特征对模型输出的影响,这里分别通过 explainer.shap_values(X) 计算 SHAP 值为数组格式以便自定义绘制,和通过 explainer(X) 计算为 Explanation 格式,直接使用 SHAP 自带的绘图函数进行可视化

SHAP自带绘图函数实现依赖图

默认参数下绘制

# 绘制 'age' 特征的SHAP依赖图
shap.dependence_plot('age', shap_values_Explanation.values, X_test, show=False)
# 添加 SHAP=0 的横线
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.savefig("SHAP Dependence Plot_1.pdf", format='pdf',bbox_inches='tight',dpi=1200)
plt.show()

        上图展示了 age(年龄) 特征对模型预测结果的 SHAP 值的依赖关系,说明不同年龄段如何影响模型的预测

  • X 轴(age):表示年龄的取值范围,从 30 到 75 岁

  • Y 轴(SHAP value for age):表示年龄对模型预测的影响。SHAP 值为正时,表示该年龄段增加了模型预测的概率;SHAP 值为负时,表示该年龄段降低了预测的概率

从图中可以看到:

  • 年龄在 50 到 60 岁之间 对模型预测结果有显著的正面影响,SHAP 值较高,说明模型在这个年龄段倾向于预测目标事件的发生

  • 70 岁左右,SHAP 值开始变为负数,意味着在这个年龄段,模型预测发生的概率降低

  • 颜色代表了 thal(地中海贫血类型) 这一交互特征的影响,红色表示更高的值,蓝色表示较低的值,可以看到,thal 的不同取值对 SHAP 值的分布有一定影响,尤其是在 SHAP 值较大的区域,红色点较为集中

绘制无颜色条的年龄 SHAP 依赖图

# 绘制 'age' 特征的 SHAP 依赖图,不显示颜色条
shap.dependence_plot('age', shap_values_Explanation.values, X_test, show=False, interaction_index=None)
# 添加 SHAP=0 的横线
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
plt.savefig(r"SHAP Dependence Plot_2.pdf", format='pdf',bbox_inches='tight',dpi=1200)
plt.show()

        在这里,通过设置 interaction_index=None 可以关闭颜色条,不显示交互特征的影响。不过,该函数目前没有内置参数可以直接在 SHAP 值为 0 的位置添加一条横线。为了实现这一功能,可以利用 matplotlib 的 plt.axhline() 方法,在绘制依赖图后手动添加横线

接下来,还可以通过 explainer.shap_values(X) 格式绘制这个shap依赖图,以便实现自定义绘图

自定义绘图

将 SHAP 值转换为 DataFrame 格式以便于自定义绘图

# 将 SHAP 值转换为 DataFrame
shap_values_df = pd.DataFrame(shap_values, columns=X_test.columns)
shap_values_df.head()

单个shap依赖图绘制

# 绘制散点图,x轴是'age'特征,y轴是SHAP值
plt.scatter(X_test['age'], shap_values_df['age'], s=10)
# 添加shap=0的横线
plt.axhline(y=0, color='black', linestyle='-.', linewidth=1)
# 设置x和y轴标签  
plt.xlabel('age', fontsize=12)
plt.ylabel('SHAP value for age', fontsize=12)
# 隐藏顶部和右侧的脊柱
plt.gca().spines['top'].set_visible(False)
plt.gca().spines['right'].set_visible(False)
plt.savefig(r"SHAP Dependence Plot_3.pdf", format='pdf',bbox_inches='tight',dpi=1200)
plt.show()

        代码生成一个 SHAP 值依赖图,其中展示了特征 age 对模型输出的贡献,同时对图表进行了一些格式上的优化,比如隐藏不必要的边框线条、在 SHAP=0 处添加一条基准线,并最终将图像保存为高分辨率的 PDF 文件。相比于直接使用 shap.dependence_plot() 的默认作图方式,这种方法提供了更高的灵活性,特别是在定制化绘图方面,可以根据不同场景、需求对图表进行高度定制,从而提高可视化的效果和表达的准确性。

多个shap依赖图绘制

# 定义绘制 SHAP 依赖图的函数
def plot_shap_dependence(feature_list, df, shap_values_df, file_name="SHAP_Dependence_Plots.pdf"):
    fig, axs = plt.subplots(2, 3, figsize=(12, 8), dpi=1200)
    plt.subplots_adjust(hspace=0.4, wspace=0.4)
    
    # 循环绘制每个特征的 SHAP 依赖图
    for i, feature in enumerate(feature_list):
        row = i // 3  # 行号
        col = i % 3   # 列号
        ax = axs[row, col]
        
        # 绘制散点图,x轴是特征值,y轴是SHAP值
        ax.scatter(df[feature], shap_values_df[feature], s=10)
        
        # 添加shap=0的横线
        ax.axhline(y=0, color='black', linestyle='-.', linewidth=1)
        
        # 设置x和y轴标签
        ax.set_xlabel(feature, fontsize=12)
        ax.set_ylabel(f'SHAP value for\n{feature}', fontsize=12)
        
        # 隐藏顶部和右侧的脊柱
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
    # 隐藏最后一个空图表的坐标轴 (若画布未关闭)
    axs[1, 2].axis('off')
    plt.savefig(file_name, format='pdf', bbox_inches='tight')
    plt.show()

# 使用函数绘制age、trestbps、chol、thalach、oldpeak的shap依赖图
feature_list = ['age', 'trestbps', 'chol', 'thalach', 'oldpeak']
plot_shap_dependence(feature_list, X_test, shap_values_df, file_name=r"SHAP Dependence Plot_4.pdf")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2213760.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

web前端-----html5----用户注册

以改图为例 <!DOCTYPE html> <html lang"en"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0"> <title>用户注册</title> </hea…

计算机网络:数据链路层 —— 扩展共享式以太网

文章目录 共享式以太网共享式以太网存在的问题在物理层扩展以太网扩展站点与集线器之间的距离扩展共享式以太网的覆盖范围和站点数量 在链路层扩展以太网网桥的主要结构网桥的基本工作原理透明网桥自学习和转发帧生成树协议STP 共享式以太网 共享式以太网是当今局域网中广泛采…

uni-app基础语法(一)

我们今天的学习目标 基础语法1. 创建新页面2.pages配置页面3.tabbar配置4.condition 启动模式配置 基础语法 1. 创建新页面 2.pages配置页面 属性类型默认值描述pathString配置页面路径styleObject配置页面窗口表现&#xff0c;配置项参考pageStyle 我们来通过style修改页面的…

CASA(Carnegie-Ames-Stanford Approach) 模型原理及实践技术

植被作为陆地生态系统的重要组成部分对于生态环境功能的维持具有关键作用。植被净初级生产力&#xff08;Net Primary Productivity, NPP&#xff09;是指单位面积上绿色植被在单位时间内由光合作用生产的有机质总量扣除自养呼吸的剩余部分。植被NPP是表征陆地生态系统功能及可…

C语言:在Visual Studio中使用C语言scanf输入%s出现的栈溢出问题

学了C之后就很少使用C语言了&#xff0c;今天帮同学解答C语言问题&#xff0c;遇到了一个我以前没有遇到过的问题。 一、问题描述 先看以下代码&#xff1a; #include<stdio.h> int main() {char str[100] { 0 };scanf_s("%s", str);printf("%s",…

2024 年 04 月编程语言排行榜,PHP 排名创新低?

编程语言的流行度总是变化莫测&#xff0c;每个月的排行榜都揭示着新的趋势。2024年4月的编程语言排行榜揭示了一个引人关注的现象&#xff1a;PHP的排名再次下滑&#xff0c;创下了历史新低。这种变化对于PHP开发者和整个技术社区来说&#xff0c;意味着什么呢&#xff1f; P…

Java Maven day1014

ok了家人们&#xff0c;今天学习了如何安装和配置Maven项目&#xff0c;我们一起去看看吧 一.Maven概述 1.1 Maven作用 Maven 是专门用于管理和构建 Java 项目的工具&#xff0c;它的主要功能有&#xff1a; 提供了一套标准化的项目结构 提供了一套标准化的构建流程&#x…

力扣41~45题

题41&#xff08;困难&#xff09;&#xff1a; 分析&#xff1a; 这题我开始没什么思路,记录第一个逼我看评论的&#xff0c;后面看评论的方法&#xff0c;真解&#xff0c;借助一个数组&#xff0c;将nums对应数字放对应位置&#xff0c;然后如果下标和数字不同就返回 pyth…

支撑每秒数百万订单无压力,SpringBoot + Disruptor 太猛了!

文章目录 一、支撑每秒数百万订单无压力&#xff0c;SpringBoot Disruptor 太猛了&#xff01;二、项目环境配置1.Maven 配置 (pom.xml)2.Yaml 配置 (application.yml)3.Disruptor 的核心实现4.定义事件工厂&#xff08;OrderEventFactory&#xff09;5.定义事件处理器&#x…

概率 随机变量以及分布

一、基础定义及分类 1、随机变量 随机变量是一个从样本空间&#xff08;所有可能结果的集合&#xff09;到实数集的函数。&#xff08;随机变量的值可以是离散的&#xff0c;也可以是连续的。 &#xff09; 事件可以定义为随机变量取特定值的集合。 2、离散型随机变量 随机变…

怎么才能算AI智能体?

科技界对 AI 智能体的痴迷愈演愈烈。销售从智能体到自动化系统&#xff0c;比如像 Salesforce 和 Hubspot 这样的公司声称可以提供具有颠覆性的 AI 智能体。但是&#xff0c;我还没有看到一个真正令人信服、完全自主的基于 LLM 的智能体。市场上充斥着各种 “废物机器人”&…

OIDS与ERP:物料管理的高效协同

添加HanTop-MKT&#xff0c;咨询物料管理协同解决方案 客户案例 背景&#xff1a; 在当前快速发展的3C自动化行业&#xff0c;企业面临着前所未有的挑战。产品生命周期的缩短、个性化需求的增长以及市场变化的加速&#xff0c;都要求企业必须具备快速响应的能力。在这样的环…

一个月学会Java 第15天 枚举与Debug

Day15 枚举与Debug 这节课我们来看看枚举&#xff0c;和Debug&#xff0c;当我们学完并会用debug之后呢&#xff0c;编码会非常的舒服&#xff0c;而且debug就是调试嘛&#xff0c;所以我们会了debug之后&#xff0c;在程序哪里出问题也可以进行锁定。 第一章 枚举 枚举并不是非…

Spring Boot知识管理:提升团队协作效率

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常适…

探索巅峰性能 |迅为 RK3588开发板深度剖析

RK3588作为瑞芯微公司旗下一款高端处理器的杰出代表&#xff0c;凭借卓越的性能与多样化的外设接口成为了众多开发和爱好者的首选。随着RK3588在市场上的广泛应用&#xff0c;大家不禁要提出疑问&#xff1a;RK3588究竟强在何处&#xff1f;在2022年&#xff0c;北京迅为电子推…

【Linux网络编程】--- Linux基本指令(上)

Welcome to 9ilks Code World (๑•́ ₃ •̀๑) 个人主页: 9ilk (๑•́ ₃ •̀๑) 文章专栏&#xff1a; Linux网络编程 &#x1f3e0; ls命令 语法 : ls -[选项] [目录或文件] 功能 : 对于目录,该命令列出该目录下的所有子目录与文件;对于文件,将列出文件名…

STL.string(上)

string string类string类构造string类对象的容量操作size和lengthmax_sizeappend小总结下size、capacity、append、operatorresizereserve 初识迭代器附录1. vs下string结构的说明&#xff08;解释前文为什么capacity是16而不是别的&#xff09; 由于string创始初期没有参照导致…

1.centos 镜像

centos 它有官网的下载地址&#xff1a;https://vault.centos.org/ 选择想要的版本&#xff0c;我选择 centos7.8 进入到镜像目录 isos 选择 x86_64 选择想要的版本&#xff0c;我选择 CentOS-7-x86_64-DVD-2003.iso 安装就正常安装就行。我选择虚拟机安装。这个参考&…

一区鱼鹰优化算法+深度学习+注意力机制!OOA-TCN-LSTM-Attention多变量时间序列预测

一区鱼鹰优化算法深度学习注意力机制&#xff01;OOA-TCN-LSTM-Attention多变量时间序列预测 目录 一区鱼鹰优化算法深度学习注意力机制&#xff01;OOA-TCN-LSTM-Attention多变量时间序列预测预测效果基本介绍程序设计参考资料 预测效果 基本介绍 1.基于OOA-TCN-LSTM-Attenti…

Java 入门基础篇11 - java基础语法

一 流程控制 1.1 流程控制语句介绍 一个java程序有很多条语句组成&#xff0c;流程控制语句是用来控制程序中的各语句执行的顺序&#xff0c;通过流程语句控制让程序执行顺序达到我们想要实现的功能。 其流程控制方式采用结构化程序设计中规定的三种基本流程结构&#xff1a;…