【可解释性机器学习】解释基于XGBoost对泰坦尼克号数据集的预测过程和结果

news2025/3/12 13:40:38

解释基于XGBoost对泰坦尼克号数据集的预测过程和结果

  • 1. 训练数据
  • 2. 简单的 XGBoost 分类器
  • 3. 解释重量
  • 4. 解释预测
  • 5. 添加文本特性
  • 参考资料

本文介绍如何分析XGBoost分类器的预测( eli5也支持 XGBoost和大多数 scikit-learn树集成的回归)。 我们将使用 Titanic数据集,它很小且没有太多特征,但仍然足够有趣。

使用XGBoost 0.81和从https://www.kaggle.com/c/titanic/data下载的数据(它也存储在eli5源码库中:https://github.com/TeamHG-Memex/eli5/blob/master/notebooks/titanic-train.csv)。

1. 训练数据

首先,加载数据:

import pandas as pd
# 直接从github代码仓库位置加载
url = "https://github.com/TeamHG-Memex/eli5/blob/017c738f8dcf3e31346de49a390835ffafad3f1b/notebooks/titanic-train.csv?raw=true"
data = pd.read_csv(url)
data.head()

数据前五行
变量说明:

  • Age: 年龄
  • Cabin: 船舱
  • Embarked: 出发港 (C = 瑟堡港; Q = 皇后镇; S = 南安普敦)
  • Fare: 乘客票价
  • Name: 姓名
  • Parch: 船上父母/子女人数
  • Pclass: 乘客类别 (1 = 1st; 2 = 2nd; 3 = 3rd)
  • Sex: 性别
  • Sibsp: 船上兄弟姐妹/配偶人数
  • Survived: 幸存(0 = No; 1 = Yes)
  • Ticket: 船票号码

接下来,把数据和我们试图预测的特征(是否生存)分开:

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

data = data.to_dict('records') # 首先将data的每一行都转换为字典,其中键是列明,值是单元格中的数据
_all_xs = [{k: v for k, v in row.items() if k != 'Survived'} for row in data]
_all_ys = np.array([int(row['Survived']) for row in data]) # 标签数据

all_xs, all_ys = shuffle(_all_xs, _all_ys, random_state=0) # 打乱顺序
train_xs, valid_xs, train_ys, valid_ys = train_test_split(all_xs, all_ys, test_size=0.25, random_state=0)
print('{} items total, {:.1%} true'.format(len(all_xs), np.mean(all_ys)))
'''
891 items total, 38.4% true
'''

我们只做最少的预处理:将明显连续的Age和Fare变量转换为 float,将SibSp和Parch转换为整数。删除缺少的年龄值。

for x in all_xs:
    if x['Age']:
        x['Age'] = float(x['Age'])
    else:
        x.pop('Age')
    x['Fare'] = float(x['Fare'])
    x['SibSp'] = int(x['SibSp'])
    x['Parch'] = int(x['Parch'])

2. 简单的 XGBoost 分类器

首先使用 xbgoost.XGBClassifiersklearn.feature_extraction.DictVectorizer 构建一个非常简单的分类器,并使用 10 折交叉验证检查其准确性:

from xgboost import XGBClassifier
from sklearn.feature_extraction import DictVectorizer
from sklearn.pipeline import make_pipeline
from sklearn.model_selection import cross_val_score

clf = XGBClassifier()
vec = DictVectorizer()
pipeline = make_pipeline(vec, clf)

def evaluate(_clf):
  scores = cross_val_score(_clf, all_xs, all_ys, scoring='accuracy', cv=10)
  print('Accuracy: {:.3f} ± {:.3f}'.format(np.mean(scores), 2 * np.std(scores)))
  _clf.fit(train_xs, train_ys)  # so that parts of the original pipeline are fitted

evaluate(pipeline)
'''
Accuracy: 0.828 ± 0.061
'''

上面的代码有一个棘手的地方:可能只是将 dense=True 传递给 DictVectorizer:毕竟,在这种情况下矩阵很小。 但这不是一个很好的解决方案,因为我们将失去区分缺失特征和零值特征的能力

3. 解释重量

为了计算预测,XGBoost 对所有树的预测求和。 树的数量由 n_estimators 参数控制,默认为 100。 每棵树本身并不是一个很好的预测器,但通过对所有树求和,XGBoost 能够在许多情况下提供可靠的估计。 这是其中一棵树:

booster = clf.get_booster()
original_feature_names = booster.feature_names
booster.feature_names = vec.get_feature_names()
print(booster.get_dump()[0])
# recover original feature names
booster.feature_names = original_feature_names
'''
0:[Sex=female<-9.53674316e-07] yes=1,no=2,missing=1
	1:[Age<13] yes=3,no=4,missing=4
		3:[SibSp<2] yes=7,no=8,missing=7
			7:leaf=0.145454556
			8:leaf=-0.125
		4:[Fare<26.2687492] yes=9,no=10,missing=9
			9:leaf=-0.151515156
			10:leaf=-0.0727272779
	2:[Pclass<2.5] yes=5,no=6,missing=5
		5:[Fare<12.1750002] yes=11,no=12,missing=11
			11:leaf=0.0500000007
			12:leaf=0.175193802
		6:[Fare<24.8083496] yes=13,no=14,missing=13
			13:leaf=0.0365591422
			14:leaf=-0.151999995

'''

可以看到这棵树检查了 Sex、Age、Pclass、Fare 和 SibSp 特征。 leaf 给出了单个树的决定,并且它们对ensemble中的所有树求和

eli5.show_weights()检查特征重要性:

from eli5 import show_weights
show_weights(clf, vec=vec)

输出结果
有几种不同的方法可以计算特征重要性。 默认情况下,使用“gain”,即特征在树中使用时的平均增益。 其他类型是“weight”——一个特征被用来分割数据的次数,以及“cover”——特征的平均覆盖率。 您可以使用 importance_type 参数传递它。

现在知道两个最重要的特征是 Sex=female 和 Pclass=3,但仍然不知道 XGBoost 如何根据它们的值来决定做出什么样的预测。

4. 解释预测

为了更好地了解我们的分类器是如何工作的,使用 eli5.show_prediction() 检查单个预测:

from eli5 import show_prediction
show_prediction(clf, valid_xs[1], vec=vec, show_feature_values=True)

输出结果
Weight表示每个特征对所有树的最终预测的贡献程度。 权重计算的思路在http://blog.datadive.net/interpreting-random-forests/中有描述;eli5XGBoost和大多数scikit-learn树集成提供了该算法的独立实现。

在这里,可以看到分类器认为成为女性是件好事,但乘坐三等车厢是不好的。 一些特征的值是“Missing”(我们通过 show_feature_values=True 来查看值):这意味着该特征缺失,所以在这种情况下最好不要在南安普顿上船。 这是我们决定使用稀疏矩阵的地方——我们仍然看到 Parch 为零,而不是缺失。

可以使用 feature_filter 参数仅显示存在的特征:它是一个接受特征名称和值的函数,并为应该显示的特征返回 True 值:

no_missing = lambda feature_name, feature_value: not np.isnan(feature_value)
show_prediction(clf, valid_xs[1], vec=vec, show_feature_values=True, feature_filter=no_missing)

输出结果

5. 添加文本特性

现在将 Name 字段视为分类的,就像其他文本特征一样。 但是在这个数据集中,每个名字都是唯一的,所以 XGBoost 根本不使用这个特性,因为它是一个很差的鉴别器:它在第 3 部分的权重表中不存在。

但是 Name 仍然可能包含一些有用的信息。 我们不猜测如何最好地对其进行预处理以及提取哪些特征,所以使用最通用的char ngram 向量化器:

from sklearn.pipeline import FeatureUnion
from sklearn.feature_extraction.text import CountVectorizer

vec2 = FeatureUnion([
    ('Name', CountVectorizer(
        analyzer='char_wb',
        ngram_range=(3, 4),
        preprocessor=lambda x: x['Name'],
        max_features=100,
    )),
    ('All', DictVectorizer()),
])
clf2 = XGBClassifier()
pipeline2 = make_pipeline(vec2, clf2)
evaluate(pipeline2)
'''
Accuracy: 0.824 ± 0.076
'''

在这种情况下,管道更加复杂,稍微改进了结果,但改进并不显着。 继续查看特征重要性:

show_weights(clf2, vec=vec2)

输出结果
看到现在有很多特征来自 Name 字段(事实上,仅基于 Name 的分类器给出了大约 0.79 的准确度)。 以这种方式列出的名称特征不是很有用,当我们检查预测时它们更有意义。 这里隐藏了缺失的特征,因为文本中有很多缺失的特征,但它们并不是很有趣:

from IPython.display import display

for idx in [4, 5, 7]:
    display(show_prediction(clf2, valid_xs[idx], vec=vec2, show_feature_values=True, feature_filter=no_missing))

输出结果
Name 字段中的文本特征直接在文本中突出显示,权重之和在权重表中显示为“Name: Highlighted in text (sum)”。

看起来姓名分类器试图从标题“先生”中推断出性别和地位。 “Mr.”不好是因为女人先得救,做“Mrs.(结婚)”比“Miss.”好。姓名分类器也在尝试挑选姓名的某些部分,尤其是结尾,或许作为社会地位的代表。 如果来自三等舱,那么成为“Mary”尤其糟糕。

参考资料

[1] Explaining XGBoost predictions on the Titanic dataset

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/187519.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构】8.5 归并排序

文章目录相邻两个有序子序列的归并归并排序算法归并排序算法分析基本思想 将两个或两个以上的有序子序列归并为一个有序序列。在内部排序中&#xff0c;通常采用的是2-路归并排序。 即&#xff1a;将两个位置相邻的有序子序列 R[l…m] 和 R[m1…n] 归并为一个有序序列 R[l…n]…

1个寒假能学会多少网络安全技能?

现在可以看到很多标题都声称三个月内就可以转行网络安全领域&#xff0c;并且成为月入15K的网络工程师。那么&#xff0c;这个寒假的时间能学多少网络安全知识&#xff1f;是否能入门网络安全工程师呢&#xff1f; 答案是肯定的。 虽然网络完全知识是一门广泛的学科&#xff…

在线支付系列【9】微信支付之申请微信公众号

有道无术&#xff0c;术尚可求&#xff0c;有术无道&#xff0c;止于术。 文章目录前言申请微信公众号前言 由于微信支付的产品体系全部搭载于微信的社交体系之上&#xff0c;所以直连商户或服务商接入微信支付之前&#xff0c;都需要有一个微信社交载体&#xff0c;该载体对应…

天啦撸~ChatGPT通过国际软件测试工程师(ISTQB)认证了~

天啦撸&#xff01;目前最火的AI应用ChatGPT通过ISTQB认证了~ 近期&#xff0c;国外的一位工程师&#xff0c;放出了他用ChatGPT通过认证的相关信息。 ChatGPT相信大家都知道是什么了&#xff0c;ISTQB相信很多测试小伙伴也不陌生&#xff0c;而且很多考证的小伙伴也对此梦寐以…

Linux之网络性能测试工具netperf实践

一、netperf简介 Netperf是一种网络性能的测量工具&#xff0c;主要针对基于TCP或UDP的传输。Netperf根据应用的不同&#xff0c;可以进行不同模式的网络性能测试&#xff0c;即批量数据传输&#xff08;bulk data transfer&#xff09;模式和请求/应答&#xff08;request/rep…

公司通知要大裁员,hr太强势,和所有人吵起来,老板见势不妙,不得不答应大家要求,把HR一起裁掉了!...

在裁员中&#xff0c;hr一般都会代表老板的利益和员工对抗&#xff0c;但如果hr和员工闹翻了&#xff0c;老板会维护hr吗&#xff1f;一位网友说&#xff1a;一上班就收到消息要裁员&#xff0c;立马让报上名单面谈&#xff0c;锁电脑关权限。后面那些人面谈的时候吵起来了&…

OpenAI Chatgpt注册及使用教程

零、什么是chatgpt?​ 1、简介 ChatGPT&#xff08;Chat Generative Pre-trained Transformer&#xff09;是OpenAI于 2022 年 11 月推出的聊天机器人。它建立在 OpenAI 的GPT-3大型语言模型家族之上&#xff0c;并经过微调&#xff08;一种迁移学习的方法&#xff09;…

双点双向的ISIS与OSPF、OSPF与OSPF、ISIS与ISIS环境以路由策略解决(1tag、2tag、4tag介绍与配置)

3.1.1 双点双向的ISIS与OSPF、OSPF与OSPF、ISIS与ISIS环境以路由策略解决&#xff08;1tag、2tag、4tag介绍与配置&#xff09; OSPF与ISIS双点双向 次优的产生与解决&#xff1a; 由于OSPF引入外部路由之后其优先级为150&#xff0c;再由ASBR进行双向引入之后。 原先OSPF外部…

闲鱼自动化软件——筛选/发送系统 V20已经测试完毕

做程序&#xff0c;就是不断地改&#xff0c;不断地优化。当改动达到一定程序&#xff0c;已经和前面形成断代&#xff0c;程序的升级时刻便到了。V20做了哪些更改或优化。1。优化抓取&#xff1a;在抓取环境优化参数&#xff0c;使抓取更顺滑&#xff0c;抓取数据效果上更准确…

智能家居创意DIY-Homekit智能灯

一、什么是智能灯 传统的灯泡是通过手动打开和关闭开关来工作。有时&#xff0c;它们可以通过声控、触控、红外等方式进行控制&#xff0c;或者带有调光开关&#xff0c;让用户调暗或调亮灯光。 智能灯泡内置有芯片和通信模块&#xff0c;可与手机、家庭智能助手、或其他智能…

若依-excel预览功能实现

实现效果及源码 实现效果如下图所示&#xff1a; 实现思路&#xff1a; 1.动态表格&#xff1a;定义表头数组&#xff0c;表格遍历表头生成表格列 2.读取excel文件内容&#xff0c;封装表头&#xff0c;绑定表格数据 代码修改 首先参考若依官网&#xff0c;先实现excel导入功…

C++基础——C++ 字符串

C基础——C 字符串C 字符串C 风格字符串C 中的 String 类C 字符串 C 提供了以下两种类型的字符串表示形式&#xff1a; C 风格字符串C 引入的 string 类类型 C 风格字符串 C 风格的字符串起源于 C 语言&#xff0c;并在 C 中继续得到支持。字符串实际上是使用 null 字符 ‘…

126、【回溯算法】leetcode ——332. 重新安排行程:回溯算法(C++版本)

题目描述 原题链接&#xff1a;332. 重新安排行程 解题思路 本题要解决的问题&#xff1a; 需要构建起始与目的机场的映射关系&#xff1b;每次选择目的机场时&#xff0c;需要选择当前最小字母顺序的机场&#xff1b;从“JFK”之后依次飞往&#xff0c;并且可能会有多条路径…

58同城AI Lab在WeNet中开源Efficient Conformer模型

2022年8月&#xff0c;58同城TEG-AI Lab语音技术团队完成了WeNet端到端语音识别的大规模落地&#xff0c;替换了此前基于Kaldi的系统&#xff0c;并针对业务需求对识别效果和推理速度展开优化&#xff0c;取得了优异的效果&#xff0c;当前录音文件识别引擎处理语音时长达1000万…

非标设备ERP管理系统可以帮助企业解决哪些管理难题?

多品种、小批量、交货周期短、非标准化生产是大多数非标设备制造企业共同的特性&#xff0c;这就要求非标设备制造企业应具备足够的经营、技术、生产和管理力量&#xff0c;否则就会顾此失彼&#xff0c;产品质量难以得到保证。非标设备制造企业常见的管理难题&#xff08;1&am…

DynaSLAM-2 DynaSLAM中Mask R-CNN部分源码解析(Ⅰ)

目录 1.Mask R-CNN源码地址 2.Mask R-CNN效果 3.项目配置 4.源码使用 1.Mask R-CNN源码地址 Mask R-CNN源码地址https://github.com/matterport/Mask_RCNN/releases 这里我们拿Mask R-CNN2.1版本进行讲解。 2.Mask R-CNN效果 最传统最核心的功能就是物体检测了…

4款让人心疼的电脑软件,由于免费又实用,常被同行挤压

许多小众软件&#xff0c;免费、实用、体验好、无广告&#xff0c;出淤泥而不染&#xff0c;却因过于良心备受排挤&#xff0c;让人唏嘘。 1、oCam 市面上的视频录屏工具&#xff0c;要么限制时长&#xff0c;要么附上水印&#xff0c;需要使用完整功能必须付费&#xff0c;oca…

Java项目调用C++端的订阅功能,获得推送数据(从设计到代码全栈完整过程)

前言 有关java和C的交互的基本概念和知识&#xff0c;本文不再详述。有需要的可以参考我的这篇文章。 JNI、DLL、SO等相关概念 开发背景 C项目端开发了一套股票市场资讯推送的功能&#xff0c;多个小组都会用到该功能&#xff0c;为了避免重复开发&#xff0c;中台小组要负担…

SpringBoot项目集成logback日志分等级配置

背景&#xff1a; 日志的作用&#xff1a; boot项目集成logback&#xff1a; 一、单模块项目配置&#xff1a; 1、添加依赖 2、添加logback-spring.xml配置文件到resources目录下 3、接下来启动一下项目&#xff0c;就可以看到我们的日志已经区分等级打印了 二、多微服务…

DVWA之SQL注入

Low(数字型注入)1、先确定正常和不正常的回显回显&#xff0c;就是显示正在执行的批处理命令及执行的结果等。输入1时&#xff0c;有回显&#xff0c;是正常的 数据库语句&#xff1a; select * from table where id 1输入5时&#xff0c;有回显&#xff0c;是正常的 数据库语句…