SHAP(四):NHANES I 生存模型

news2025/2/25 10:28:51

SHAP(四):NHANES I 生存模型

这是一个 Cox 比例风险模型,基于来自 NHANES I 的数据以及来自 NHANES I 流行病学随访研究。 它旨在说明 SHAP 值如何能够以传统上仅由线性模型提供的清晰度解释 XGBoost 模型。 我们在数据中看到有趣的非线性模式,这表明了这种方法的潜力。 请记住,我们尚未对数据进行检查以校准当前的实验室测试,因此您不应将结果视为可操作的医学见解,而应将其视为概念证明。

请注意,对 Cox 损失和 SHAP 交互效果的支持最近才合并,因此您需要最新的 XGBoost 主版本才能运行此笔记本。

import matplotlib.pylab as pl
import xgboost
from sklearn.model_selection import train_test_split

import shap

1.创建 XGBoost 数据对象

这使用了 SHAP 数据集模块中可用的 NHANES I 数据的预处理子集。

X, y = shap.datasets.nhanesi()
X_display, y_display = shap.datasets.nhanesi(
    display=True
)  # human readable feature values

xgb_full = xgboost.DMatrix(X, label=y)

# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
xgb_train = xgboost.DMatrix(X_train, label=y_train)
xgb_test = xgboost.DMatrix(X_test, label=y_test)

2.训练 XGBoost 模型

# use validation set to choose # of trees
params = {"eta": 0.002, "max_depth": 3, "objective": "survival:cox", "subsample": 0.5}
model_train = xgboost.train(
    params, xgb_train, 10000, evals=[(xgb_test, "test")], verbose_eval=1000
)
[0]	test-cox-nloglik:7.26952
[1000]	test-cox-nloglik:6.55767
[2000]	test-cox-nloglik:6.48836
[3000]	test-cox-nloglik:6.47129
[4000]	test-cox-nloglik:6.46786
[5000]	test-cox-nloglik:6.46583
[6000]	test-cox-nloglik:6.46623
[7000]	test-cox-nloglik:6.46841
[8000]	test-cox-nloglik:6.46972
[9000]	test-cox-nloglik:6.47175
[9999]	test-cox-nloglik:6.47396
# train final model on the full data set
params = {"eta": 0.002, "max_depth": 3, "objective": "survival:cox", "subsample": 0.5}
model = xgboost.train(
    params, xgb_full, 5000, evals=[(xgb_full, "test")], verbose_eval=1000
)
[0]	test-cox-nloglik:8.88073
[1000]	test-cox-nloglik:8.17142
[2000]	test-cox-nloglik:8.08556
[3000]	test-cox-nloglik:8.04853
[4000]	test-cox-nloglik:8.0248
[4999]	test-cox-nloglik:8.00511

3.检查性能

C 统计量衡量我们如何根据人们的生存时间对他们进行排序(1.0 是完美排序)。

def c_statistic_harrell(pred, labels):
    total = 0
    matches = 0
    for i in range(len(labels)):
        for j in range(len(labels)):
            if labels[j] > 0 and abs(labels[i]) > labels[j]:
                total += 1
                if pred[j] > pred[i]:
                    matches += 1
    return matches / total


# see how well we can order people by survival
c_statistic_harrell(model_train.predict(xgb_test, ntree_limit=5000), y_test)
0.835090082176807

4.解释模型对整个数据集的预测

shap_values = shap.TreeExplainer(model).shap_values(X)

4.1 SHAP 摘要图

XGBoost 的 SHAP 值解释了模型的边际输出,即 Cox 比例风险模型的死亡对数几率的变化。 我们可以从下面看到,根据模型,死亡的主要危险因素是年老。 死亡风险的下一个最有力的指标是男性。

该摘要图取代了特征重要性的典型条形图。 它告诉我们哪些特征是最重要的,以及它们对数据集的影响范围。 颜色使我们能够匹配特征值的变化如何影响风险的变化(例如高白细胞计数导致高死亡风险)。

shap.summary_plot(shap_values, X)

在这里插入图片描述

4.2 SHAP 相关图

SHAP 摘要图给出了每个特征的总体概述,而 SHAP 依赖图显示了模型输出如何随特征值变化。 请注意,每个点都是一个人,单个特征值的垂直分散是由模型中的交互效应产生的。 自动选择用于着色的功能来突出显示可能驱动这些交互的因素。 稍后我们将了解如何使用 SHAP 交互值检查模型中是否确实存在交互。 请注意,SHAP 汇总图的行是将 SHAP 相关图的点投影到 y 轴上,然后由特征本身重新着色得到的。

下面我们给出了每个 NHANES I 特征的 SHAP 依赖图,揭示了有趣但预期的趋势。 请记住,其中一些值的校准可能与现代实验室测试不同,因此得出结论时要小心。

# we pass "Age" instead of an index because dependence_plot() will find it in X's column names for us
# Systolic BP was automatically chosen for coloring based on a potential interaction to check that
# the interaction is really in the model see SHAP interaction values below
shap.dependence_plot("Age", shap_values, X)

在这里插入图片描述

# we pass display_features so we get text display values for sex
shap.dependence_plot("Sex", shap_values, X, display_features=X_display)

在这里插入图片描述

# setting show=False allows us to continue customizing the matplotlib plot before displaying it
shap.dependence_plot("Systolic BP", shap_values, X, show=False)
pl.xlim(80, 225)
pl.show()

在这里插入图片描述

shap.dependence_plot("Poverty index", shap_values, X)

在这里插入图片描述

shap.dependence_plot(
    "White blood cells", shap_values, X, display_features=X_display, show=False
)
pl.xlim(2, 15)
pl.show()

在这里插入图片描述

shap.dependence_plot("BMI", shap_values, X, display_features=X_display, show=False)
pl.xlim(15, 50)
pl.show()

在这里插入图片描述

shap.dependence_plot("Serum magnesium", shap_values, X, show=False)
pl.xlim(1.2, 2.2)
pl.show()

在这里插入图片描述

shap.dependence_plot("Sedimentation rate", shap_values, X)

在这里插入图片描述

shap.dependence_plot("Serum protein", shap_values, X)

在这里插入图片描述

shap.dependence_plot("Serum cholesterol", shap_values, X, show=False)
pl.xlim(100, 400)
pl.show()

请添加图片描述

shap.dependence_plot("Pulse pressure", shap_values, X)

在这里插入图片描述

shap.dependence_plot("Serum iron", shap_values, X, display_features=X_display)

在这里插入图片描述

shap.dependence_plot("TS", shap_values, X)

在这里插入图片描述

shap.dependence_plot("Red blood cells", shap_values, X)

在这里插入图片描述

5.计算 SHAP 交互值

有关更多详细信息,请参阅 Tree SHAP 论文,但简单地说,SHAP 交互值是 SHAP 值对更高阶交互的推广。 最新版本的 XGBoost 中使用 pred_interactions 标志实现了成对交互的快速精确计算。 使用此标志,XGBoost 为每个预测返回一个矩阵,其中主效应位于对角线上,交互效应位于非对角线上。 主效应类似于线性模型获得的 SHAP 值,交互效应捕获所有高阶交互,并将它们划分为成对交互项。 请注意,整个交互矩阵的总和是模型当前输出与预期输出之间的差,因此非对角线上的交互效应被分成两半(因为每个都有两个)。 绘制交互效果时,SHAP 包会自动将非对角线值乘以 2,以获得完整的交互效果。

# takes a couple minutes since SHAP interaction values take a factor of 2 * # features
# more time than SHAP values to compute, since this is just an example we only explain
# the first 2,000 people in order to run quicker
shap_interaction_values = shap.TreeExplainer(model).shap_interaction_values(
    X.iloc[:2000, :]
)

5.1 SHAP 交互值汇总图

SHAP 交互值矩阵的汇总图绘制了汇总图矩阵,其中对角线上有主效应,对角线外有交互效应。

shap.summary_plot(shap_interaction_values, X.iloc[:2000, :])

在这里插入图片描述

5.2 SHAP 交互值依赖图

对 SHAP 交互值 a 运行依赖图可以让我们分别观察主效应和交互效应。

下面我们绘制了年龄的主要影响以及年龄的一些交互影响。 将年龄的主效应图与早期的年龄 SHAP 值图进行比较可以提供丰富的信息。 主效应图没有垂直分散,因为相互作用效应全部以非对角线项捕获。

shap.dependence_plot(
    ("Age", "Age"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

现在我们绘制涉及年龄的交互效应。 这些效应捕获了原始 SHAP 图中存在但上面的主效应图中缺失的所有垂直色散。 下图涉及年龄和性别,显示基于性别的死亡风险差距因年龄而异,并在 60 岁时达到峰值。

shap.dependence_plot(
    ("Age", "Sex"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

shap.dependence_plot(
    ("Age", "Systolic BP"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

shap.dependence_plot(
    ("Age", "White blood cells"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

shap.dependence_plot(
    ("Age", "Poverty index"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

shap.dependence_plot(
    ("Age", "BMI"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

shap.dependence_plot(
    ("Age", "Serum magnesium"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

Now we show a couple examples with systolic blood pressure.

shap.dependence_plot(
    ("Systolic BP", "Systolic BP"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

shap.dependence_plot(
    ("Systolic BP", "Age"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

shap.dependence_plot(
    ("Systolic BP", "Age"),
    shap_interaction_values,
    X.iloc[:2000, :],
    display_features=X_display.iloc[:2000, :],
)

在这里插入图片描述

import matplotlib.pylab as pl
import numpy as np
tmp = np.abs(shap_interaction_values).sum(0)
for i in range(tmp.shape[0]):
    tmp[i, i] = 0
inds = np.argsort(-tmp.sum(0))[:50]
tmp2 = tmp[inds, :][:, inds]
pl.figure(figsize=(12, 12))
pl.imshow(tmp2)
pl.yticks(
    range(tmp2.shape[0]), X.columns[inds], rotation=50.4, horizontalalignment="right"
)
pl.xticks(
    range(tmp2.shape[0]), X.columns[inds], rotation=50.4, horizontalalignment="left"
)
pl.gca().xaxis.tick_top()
pl.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1278880.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JDK17的安装与配置

JDK17的安装与配置 下载地址安装步骤配置环境变量验证安装是否成功 下载地址 此jdk17安装的系统是win10系统 https://www.oracle.com/java/technologies/downloads/ 这里选择JDK17进行下载 下载完成之后,显示如下图: 安装步骤 自定义的安装路径&…

【从删库到跑路 | MySQL总结篇】事务详细介绍

个人主页:兜里有颗棉花糖 欢迎 点赞👍 收藏✨ 留言✉ 加关注💓本文由 兜里有颗棉花糖 原创 收录于专栏【MySQL学习专栏】🎈 本专栏旨在分享学习MySQL的一点学习心得,欢迎大家在评论区讨论💌 目录 一、事务…

多项式拟合求解

目录 简介 基本原理 例1 例2 例3 参考资料 简介 多项式拟合可以用最小二乘求解,不管是一元高阶函数,还是多元多项式函数,还是二者的混合,都可以通过统一的方法求解。当然除了最小二乘法,还是其他方法可以求解&…

分享一个国内可用的免费GPT4-AI提问AI绘画网站工具

一、前言 ChatGPT GPT4.0,Midjourney绘画,相信对大家应该不感到陌生吧?简单来说,GPT-4技术比之前的GPT-3.5相对来说更加智能,会根据用户的要求生成多种内容甚至也可以和用户进行创作交流。 然而,GPT-4对普…

【尾递归】

尾递归 如果函数在返回前才进行递归调用,则该函数可以被编译器或解释器优化,使其在空间效率上与迭代相当。这种情况被称为「尾递归 tail recursion」。 普通递归:当函数返回到上一层级的函数后,需要继续执行代码,因此…

CookieSession Redis 到JWT会话管理历史

单应用时期,通常使用 Cookies 和 Session 进行会话管理。 用户登录后,服务器创建一个唯一的会话标识符(Session ID),将其存储在浏览器的 Cookies 中,并在服务端维护一个关联该标识符的会话对象。 这种方…

【C++】类和对象——初始化列表和static修饰成员

首先我们来谈一下初始化列表,它其实是对于我们前边构造函数体内初始化的一种补充,换一种说法,它以后才是我们构造函数的主体部分。 我们先考虑一个问题,就是一个类里面有用引用或const初始化的成员变量,比如说&#xf…

已解决AttributeError: module ‘gradio‘ has no attribute ‘outputs‘

问题描述 Traceback (most recent call last): File "/media/visionx/monica/project/ResShift/app.py", line 118, in <module> gr.outputs.File(label"Download the output")AttributeError: module gradio has no attribute outputs 解决办…

Java高级技术-单元测试

单元测试 Junit单元测试框架 Junit单元测试-快速入门 方法类 测试类 Junit框架的基本注解

同旺科技 USB TO SPI / I2C --- 调试W5500_Ping测试

所需设备&#xff1a; 内附链接 1、USB转SPI_I2C适配器(专业版); 首先&#xff0c;连接W5500模块与同旺科技USB TO SPI / I2C适配器&#xff0c;如下图&#xff1a; 设置寄存器&#xff1a; SHAR&#xff08;源MAC地址寄存器&#xff09;&#xff0c;该寄存器用来设置源MAC…

中国人工智能

随着科技的飞速发展&#xff0c;人工智能&#xff08;AI&#xff09;作为一项前沿技术在各个领域展现出了强大的潜力。本文将探讨中国人工智能的历史、现状&#xff0c;并展望其未来发展。 人工智能的起源与历史 人工智能的概念最早诞生于1956年的美国达特茅斯学院的夏季研讨会…

231202 刷题日报

周四周五&#xff0c;边值班边扯皮&#xff0c;没有刷题。。 今天主要是做了: 1. 稀疏矩阵压缩&#xff0c;十字链表法 2. 快速排序 3.349. 两个数组的交集​​​​​ 4. 174. 地下城游戏 要注意溢出问题&#xff01;

KNN实战-图像识别

数据说明 是在循环0-9的数字一直循环500次所得到的数据&#xff0c;然后以手写照片的形式存在 识别的步骤 加载数据构建目标值构建模型参数调优可视化展示 加载数据 import numpy as np import matplotlib.pyplot as plt # 记载数据 data np.load(./digit.npy) data构建目…

【HDFS】调试慢节点pipiline ack信息

Client - DN1 - DN2 - DN3 DN3 send ack:[0][d3]。 DN2 send ack: [从dn2入队到收到dn3的ack耗时,0] [d2,d3]。 DN1 send ack: [pkt从dn1入队到收到dn2的ack耗时,pkt从dn2入队到收到dn3的ack耗时,0] [d1,d2,d3]。 Client receive: 就是DN1发送过来数据。 客户端收到的第一个…

000FreeCAD源码学习--MainGui.cpp

目录 1 MainGui.cpp源代码 2 int main()函数分析 3 编译运行截图 FreeCADMain项目下的MainGui.cpp 1 MainGui.cpp源代码 int main( int argc, char ** argv ) { #if defined (FC_OS_LINUX) || defined(FC_OS_BSD)setlocale(LC_ALL, ""); // use native environm…

【C++干货铺】继承 | 多继承 | 虚继承

个人主页点击直达&#xff1a;小白不是程序媛 C系列专栏&#xff1a;C干货铺 代码仓库&#xff1a;Gitee 目录 继承的概念及定义 继承的概念 继承的定义 继承基类成员访问方式的变化 基类和派生类的赋值转化 继承中的作用域 派生类的默认成员函数 构造函数 拷贝构造…

判断是否有环形链表

问题描述&#xff1a; 给定一个链表&#xff0c;判断链表中是否有环。 如果链表中有某个节点&#xff0c;可以通过连续跟踪next指针再次到达&#xff0c;则链表中存在环。为了表示给定链表中的环&#xff0c;我们使用整数pos来表示链表尾连接到链表中的位置&#xff08;索引从0…

抖音视频如何无水印保存?抖音视频无水印保存教程

抖音视频如何无水印保存&#xff1f;当下短视频盛行时代&#xff0c;抖音作为当下主流短视频平台之一&#xff0c;每天都有数以亿计的用户在抖音上分享自己的创作&#xff0c;然后当我们遇到感兴趣的视频&#xff0c;下载保存后会发现带有水印&#xff0c;那么抖音视频如何无水…

scrapy介绍,并创建第一个项目

一、scrapy简介 scrapy的概念 Scrapy是一个Python编写的开源网络爬虫框架。它是一个被设计用于爬取网络数据、提取结构性数据的框架。 Scrapy 使用了Twisted异步网络框架&#xff0c;可以加快我们的下载速度。 Scrapy文档地址&#xff1a;http://scrapy-chs.readthedocs.io/z…

LangChain 17 LangSmith调试、测试、评估和监视基于任何LLM框架构建的链和智能代理

LangChain系列文章 LangChain 实现给动物取名字&#xff0c;LangChain 2模块化prompt template并用streamlit生成网站 实现给动物取名字LangChain 3使用Agent访问Wikipedia和llm-math计算狗的平均年龄LangChain 4用向量数据库Faiss存储&#xff0c;读取YouTube的视频文本搜索I…