SHAP(五):使用 XGBoost 进行人口普查收入分类

news2025/1/23 11:18:54

SHAP(五):使用 XGBoost 进行人口普查收入分类

本笔记本演示了如何使用 XGBoost 预测个人年收入超过 5 万美元的概率。 它使用标准 UCI 成人收入数据集。 要下载此笔记本的副本,请访问 github。

XGBoost 等梯度增强机方法对于具有多种形式的表格样式输入数据的此类预测问题来说是最先进的。 Tree SHAP(arXiv 论文)允许精确计算树集成方法的 SHAP 值,并已直接集成到 C++ XGBoost 代码库中。 这允许快速精确计算 SHAP 值,无需采样,也无需提供背景数据集(因为背景是从树木的覆盖范围推断出来的)。

在这里,我们演示如何使用 SHAP 值来理解 XGBoost 模型预测。

import matplotlib.pylab as pl
import numpy as np
import xgboost
from sklearn.model_selection import train_test_split

import shap

# print the JS visualization code to the notebook
shap.initjs()

1.加载数据集

X, y = shap.datasets.adult()
X_display, y_display = shap.datasets.adult(display=True)

# create a train/test split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7)
d_train = xgboost.DMatrix(X_train, label=y_train)
d_test = xgboost.DMatrix(X_test, label=y_test)

2.训练模型

params = {
    "eta": 0.01,
    "objective": "binary:logistic",
    "subsample": 0.5,
    "base_score": np.mean(y_train),
    "eval_metric": "logloss",
}
model = xgboost.train(
    params,
    d_train,
    5000,
    evals=[(d_test, "test")],
    verbose_eval=100,
    early_stopping_rounds=20,
)
[0]	test-logloss:0.54663
[100]	test-logloss:0.36373
[200]	test-logloss:0.31793
[300]	test-logloss:0.30061
[400]	test-logloss:0.29207
[500]	test-logloss:0.28678
[600]	test-logloss:0.28381
[700]	test-logloss:0.28181
[800]	test-logloss:0.28064
[900]	test-logloss:0.27992
[1000]	test-logloss:0.27928
[1019]	test-logloss:0.27935

3.经典特征归因

在这里,我们尝试 XGBoost 附带的全局特征重要性计算。 请注意,它们都是相互矛盾的,这激励了 SHAP 值的使用,因为它们具有一致性保证(意味着它们将正确排序特征)。

xgboost.plot_importance(model)
pl.title("xgboost.plot_importance(model)")
pl.show()


在这里插入图片描述

xgboost.plot_importance(model, importance_type="cover")
pl.title('xgboost.plot_importance(model, importance_type="cover")')
pl.show()


在这里插入图片描述

xgboost.plot_importance(model, importance_type="gain")
pl.title('xgboost.plot_importance(model, importance_type="gain")')
pl.show()


在这里插入图片描述

4,解释预测

在这里,我们使用集成到 XGBoost 中的 Tree SHAP 实现来解释整个数据集(32561 个样本)。

# this takes a minute or two since we are explaining over 30 thousand samples in a model with over a thousand trees
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)

4.1 可视化单个预测

请注意,我们使用“显示值”数据框,因此我们得到了漂亮的字符串而不是类别代码。

shap.force_plot(explainer.expected_value, shap_values[0, :], X_display.iloc[0, :])

在这里插入图片描述

4.2 将许多预测可视化

为了让浏览器满意,我们只可视化 1,000 个人。

shap.force_plot(
    explainer.expected_value, shap_values[:1000, :], X_display.iloc[:1000, :]
)

在这里插入图片描述

5.平均重要性条形图

这取整个数据集中 SHAP 值大小的平均值,并将其绘制为简单的条形图。

shap.summary_plot(shap_values, X_display, plot_type="bar")


在这里插入图片描述

6.SHAP 概要图

我们没有使用典型的特征重要性条形图,而是使用每个特征的 SHAP 值的密度散点图来确定每个特征对验证数据集中个体的模型输出有多大影响。 特征按所有样本的 SHAP 值大小之和排序。 有趣的是,关系特征比资本收益特征具有更大的总体模型影响,但对于那些资本收益重要的样本,它比年龄具有更大的影响。 换句话说,资本收益对少数预测的影响较大,而年龄对所有预测的影响较小。

请注意,当散点不适合在线时,它们会堆积起来以显示密度,每个点的颜色代表该个体的特征值。

shap.summary_plot(shap_values, X)


在这里插入图片描述

7.SHAP 相关图

SHAP 依赖图显示单个特征对整个数据集的影响。 他们绘制了多个样本中某个特征的值与该特征的 SHA 值的关系图。 SHAP 依赖图与部分依赖图类似,但考虑了特征中存在的交互效应,并且仅在数据支持的输入空间区域中定义。 单个特征值处的 SHAP 值的垂直分散是由交互效应驱动的,并且选择另一个特征进行着色以突出可能的交互。

for name in X_train.columns:
    shap.dependence_plot(name, shap_values, X, display_features=X_display)


在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
)

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

8.简单的监督聚类

按 shap_values 对人们进行聚类会导致与手头的预测任务相关的组(在本例中是他们的收入潜力)。

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

shap_pca50 = PCA(n_components=12).fit_transform(shap_values[:1000, :])
shap_embedded = TSNE(n_components=2, perplexity=50).fit_transform(shap_values[:1000, :])
from matplotlib.colors import LinearSegmentedColormap

cdict1 = {
    "red": (
        (0.0, 0.11764705882352941, 0.11764705882352941),
        (1.0, 0.9607843137254902, 0.9607843137254902),
    ),
    "green": (
        (0.0, 0.5333333333333333, 0.5333333333333333),
        (1.0, 0.15294117647058825, 0.15294117647058825),
    ),
    "blue": (
        (0.0, 0.8980392156862745, 0.8980392156862745),
        (1.0, 0.3411764705882353, 0.3411764705882353),
    ),
    "alpha": ((0.0, 1, 1), (0.5, 1, 1), (1.0, 1, 1)),
}  # #1E88E5 -> #ff0052
red_blue_solid = LinearSegmentedColormap("RedBlue", cdict1)
f = pl.figure(figsize=(5, 5))
pl.scatter(
    shap_embedded[:, 0],
    shap_embedded[:, 1],
    c=shap_values[:1000, :].sum(1).astype(np.float64),
    linewidth=0,
    alpha=1.0,
    cmap=red_blue_solid,
)
cb = pl.colorbar(label="Log odds of making > $50K", aspect=40, orientation="horizontal")
cb.set_alpha(1)
cb.outline.set_linewidth(0)
cb.ax.tick_params("x", length=0)
cb.ax.xaxis.set_label_position("top")
pl.gca().axis("off")
pl.show()


在这里插入图片描述

for feature in ["Relationship", "Capital Gain", "Capital Loss"]:
    f = pl.figure(figsize=(5, 5))
    pl.scatter(
        shap_embedded[:, 0],
        shap_embedded[:, 1],
        c=X[feature].values[:1000].astype(np.float64),
        linewidth=0,
        alpha=1.0,
        cmap=red_blue_solid,
    )
    cb = pl.colorbar(label=feature, aspect=40, orientation="horizontal")
    cb.set_alpha(1)
    cb.outline.set_linewidth(0)
    cb.ax.tick_params("x", length=0)
    cb.ax.xaxis.set_label_position("top")
    pl.gca().axis("off")
    pl.show()


在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

训练每棵树只有两个叶子的模型,因此特征之间没有交互项

强制模型没有交互项意味着某个特征对结果的影响不依赖于任何其他特征的值。 这反映在下面的 SHAP 相关图中,因为没有垂直扩展。 垂直分布反映了一个特征的单个值可能对模型输出产生不同的影响,具体取决于个体呈现的其他特征的上下文。 然而,对于没有交互项的模型,无论个体可能具有哪些其他属性,特征总是具有相同的影响。

与传统的部分相关图相比,SHAP 相关图的优点之一是能够区分具有交互项和不具有交互项的模型。 换句话说,SHAP 相关图通过给定特征值处散点图的垂直方差给出了交互项大小的概念。

# train final model on the full data set
params = {
    "eta": 0.05,
    "max_depth": 1,
    "objective": "binary:logistic",
    "subsample": 0.5,
    "base_score": np.mean(y_train),
    "eval_metric": "logloss",
}
model_ind = xgboost.train(
    params,
    d_train,
    5000,
    evals=[(d_test, "test")],
    verbose_eval=100,
    early_stopping_rounds=20,
)
[0]	test-logloss:0.54113
[100]	test-logloss:0.35499
[200]	test-logloss:0.32848
[300]	test-logloss:0.31901
[400]	test-logloss:0.31331
[500]	test-logloss:0.30930
[600]	test-logloss:0.30619
[700]	test-logloss:0.30371
[800]	test-logloss:0.30184
[900]	test-logloss:0.30035
[1000]	test-logloss:0.29913
[1100]	test-logloss:0.29796
[1200]	test-logloss:0.29695
[1300]	test-logloss:0.29606
[1400]	test-logloss:0.29525
[1500]	test-logloss:0.29471
[1565]	test-logloss:0.29439
shap_values_ind = shap.TreeExplainer(model_ind).shap_values(X)

请注意,下面的交互颜色条对于该模型来说没有意义,因为它没有交互。

for name in X_train.columns:
    shap.dependence_plot(name, shap_values_ind, X, display_features=X_display)
invalid value encountered in divide
invalid value encountered in divide

在这里插入图片描述

invalid value encountered in divide
invalid value encountered in divide

在这里插入图片描述

invalid value encountered in divide
invalid value encountered in divide

在这里插入图片描述

invalid value encountered in divide
invalid value encountered in divide

在这里插入图片描述

invalid value encountered in divide
invalid value encountered in divide

在这里插入图片描述

invalid value encountered in divide
invalid value encountered in divide

在这里插入图片描述

invalid value encountered in divide
invalid value encountered in divide

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1279558.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++ day48 打家劫舍

题目1:198 打家劫舍 题目链接:打家劫舍 对题目的理解 专业小偷偷盗房屋的钱财,每个房屋存放的金额用非负整数数组表示; 如果两间相邻的房屋在同一晚上被小偷闯入,系统会自动报警; 不触动警报装置的情况…

吸积效应:为什么接口会越来越臃肿?我们从一个接口说起

欢迎大家关注公众号「JAVA前线」查看更多精彩分享文章,主要包括源码分析、实际应用、架构思维、职场分享、产品思考等等,同时欢迎大家加我微信「java_front」一起交流学习 1 从一个接口说起 1.1 初始接口 假设现在有一个创建订单接口: pub…

C语言每日一题(44)删除排序链表中的重复元素 II

力扣 82 删除排序链表中的重复元素 II 题目描述 给定一个已排序的链表的头 head , 删除原始链表中所有重复数字的节点,只留下不同的数字 。返回 已排序的链表 。 示例 1: 输入:head [1,2,3,3,4,4,5] 输出:[1,2,5]示…

mac安装解压缩rar后缀文件踩坑

mac默认能够解压缩zip后缀的文件,如果是rar后缀的自己需要下载相关的工具解压 下载地址: https://www.rarlab.com/download.htm mac我是因特尔芯片所以下载 x64 然后解压缩文件进入目录 rar中 将可执行文件 rar、unrar 移动到 /usr/local/bin目录下即…

【高效开发工具系列】jackson入门使用

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

ASP.NET Core MVC过滤器

1、过滤器分为授权过滤、资源访问过滤、操作方法(Action)过滤、结果过滤、异常过滤、终结点过滤。上一次咱们没有说异常过滤和终结点过滤,不过老周后面会说的。对这些过滤器,你有印象就行了。 2、所有过滤器接口都有同步版本和异…

可视化开源编辑器Swagger Editor本地部署并实现远程访问管理编辑文档

最近,我发现了一个超级强大的人工智能学习网站。它以通俗易懂的方式呈现复杂的概念,而且内容风趣幽默。我觉得它对大家可能会有所帮助,所以我在此分享。点击这里跳转到网站。 文章目录 Swagger Editor本地接口文档公网远程访问1. 部署Swagge…

【Python百练——第3练】矩形类及操作

💐作者:insist-- 💐个人主页:insist-- 的个人主页 理想主义的花,最终会盛开在浪漫主义的土壤里,我们的热情永远不会熄灭,在现实平凡中,我们终将上岸,阳光万里 ❤️欢迎点…

一缕青丝寄相思

10年8月16日七夕节男孩向女孩表白,女孩不知道那天是七夕,也没有读懂男孩的爱,女孩在9月22日中秋,向男孩打开了心门,男孩却没有懂女孩的心思.13年后的一封问候邮件,一束女孩的长发和回不去的青春 洒满阳光的午后 转眼间看到你的笑脸 微笑着你对我说 遇上你认识我真好 你说得好莫…

论文解读--Robust lane detection and tracking with Ransac and Kalman filter

使用随机采样一致性和卡尔曼滤波的鲁棒的车道线跟踪 摘要 在之前的一篇论文中,我们描述了一种使用霍夫变换和迭代匹配滤波器的简单的车道检测方法[1]。本文扩展了这项工作,通过结合逆透视映射来创建道路的鸟瞰视图,应用随机样本共识来帮助消…

力扣日记12.3-【二叉树篇】二叉树的所有路径

力扣日记:【二叉树篇】二叉树的所有路径 日期:2023.12.3 参考:代码随想录、力扣 257. 二叉树的所有路径 题目描述 难度:简单 给你一个二叉树的根节点 root ,按 任意顺序 ,返回所有从根节点到叶子节点的路径…

JAVA代码优化:CommandLineRunner(项目启动之前,预先加载数据)

CommandLineRunner接口是Spring Boot框架中的一个接口,用于在应用程序启动后执行一些特定的代码逻辑。它是一个函数式接口,只包含一个run方法,该方法在应用程序启动后被自动调用。可以帮助我们在应用程序启动后自动执行一些代码逻辑&#xff…

sharding-jdbc实现分库分表

shigen日更文章的博客写手,擅长Java、python、vue、shell等编程语言和各种应用程序、脚本的开发。记录成长,分享认知,留住感动。 😅😅最近几天的状态有点不对,所以有几天没有更新了。 当我们的数据量比较大…

css实现简单的抽奖动画效果和旋转效果,还有春联效果

使用css的animation和transform和transition可以实现简单的图片放大缩小,旋转,位移的效果,由此可以延伸的动画效果还是挺多的,比如图片慢慢放大,图片慢慢旋转并放大,图片慢慢变化位置等等, 抽奖…

使用pytorch从零开始实现迷你GPT

生成式建模知识回顾: [1] 生成式建模概述 [2] Transformer I,Transformer II [3] 变分自编码器 [4] 生成对抗网络,高级生成对抗网络 I,高级生成对抗网络 II [5] 自回归模型 [6] 归一化流模型 [7] 基于能量的模型 [8] 扩散模型 I, 扩散模型 II…

MathType 7.5.2中文版软件使用期到了怎么办?

MathType 7.5.2中文版作为一款专业的公式编辑器,MathType受到很多人的青睐,它可以将编辑好的公式保存成多种图片格式或透明图片模式,可以很方便的添加或移除符号、表达式等模板(只需要简单地用鼠标拖进拖出即可),也可以…

Verilog 入门(八)(验证)

文章目录 编写测试验证程序波形产生值序列重复模式 测试验证程序实例从文本文件中读取向量实例:时序检测器 测试验证程序用于测试和验证设计方法的正确性。Verilog 提供强有力的结构来说明测试验证程序。 编写测试验证程序 测试验证程序有三个主要目的:…

Java 表达式引擎

企业的需求往往是多样化且复杂的,对接不同企业时会有不同的定制化的业务模型和流程。我们在业务系统中使用表达式引擎,集中配置管理业务规则,并实现实时决策和计算,可以提高系统的灵活性和响应能力。 引入规则引擎似乎就能解决这个…

UE学习C++(1)创建actor

创建新C类 在 虚幻编辑器 中,点击 文件(File) 下拉菜单,然后选择 新建C类...(New C Class...) 命令: 此时将显示 选择父类(Choose Parent Class) 菜单。可以选择要扩展的…

你知道MySQL 中的 order by 是怎么工作的吗?

欢迎大家到我的博客浏览。order by是怎么工作的? | YinKais Blog今天我们来看一下 MySQL 中 “ order by ” 是怎么工作的。 我们以一个实际的例子,来探讨这个问题: 假设我们的表是这样定义的: CREATE TABLE t (id int(11) NOT…