SHAP(一):使用 XGBoost 预测英雄联盟获胜

news2025/1/11 22:43:34

SHAP(一):使用 XGBoost 预测英雄联盟获胜

本笔记本使用 Kaggle 数据集 英雄联盟排名比赛,其中包含从 2014 年开始的 180,000 场英雄联盟排名比赛。 根据这些数据,我们构建了一个 XGBoost 模型,根据有关该球员比赛表现的统计数据来预测该球员的球队是否会获胜。

这里使用的方法适用于任何数据集。 我们使用此数据集来说明 SHAP 值如何帮助使梯度增强树(例如 XGBoost)可解释。 由于数据集的大小、交互作用、包含分类和连续特征及其可解释性(特别是对于游戏玩家),该数据集适合作为各个方面的一个很好的例子。 有关 SHAP 值的更多信息,请参阅:https://github.com/shap/shap

from pathlib import Path

import matplotlib.pyplot as pl
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import train_test_split

import shap
d:\work\miniconda3\lib\site-packages\scipy\__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.3
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"

加载数据集

要自己运行此程序,您需要从 Kaggle 下载数据集并确保下面的“前缀”变量正确。 为此,请点击上面给出的链接并下载并提取数据。 如果需要,更改“前缀”变量。

# read in the data
folder_path = Path("./data/league-of-legends-ranked-matches/")
matches = pd.read_csv(folder_path / "matches.csv")
participants = pd.read_csv(folder_path / "participants.csv")
stats1 = pd.read_csv(folder_path / "stats1.csv", low_memory=False)
stats2 = pd.read_csv(folder_path / "stats2.csv", low_memory=False)
stats = pd.concat([stats1, stats2])

# merge into a single DataFrame
a = pd.merge(
    participants, matches, left_on="matchid", right_on="id", suffixes=("", "_matches")
)
allstats_orig = pd.merge(
    a, stats, left_on="matchid", right_on="id", suffixes=("", "_stats")
)
allstats = allstats_orig.copy()

# drop games that lasted less than 10 minutes
allstats = allstats.loc[allstats["duration"] >= 10 * 60, :]

# Convert string-based categories to numeric values
cat_cols = ["role", "position", "version", "platformid"]
for c in cat_cols:
    allstats[c] = allstats[c].astype("category")
    allstats[c] = allstats[c].cat.codes
allstats["wardsbought"] = allstats["wardsbought"].astype(np.int32)

X = allstats.drop(columns=["win"])
y = allstats["win"]

# convert all features we want to consider as rates
rate_features = [
    "kills",
    "deaths",
    "assists",
    "killingsprees",
    "doublekills",
    "triplekills",
    "quadrakills",
    "pentakills",
    "legendarykills",
    "totdmgdealt",
    "magicdmgdealt",
    "physicaldmgdealt",
    "truedmgdealt",
    "totdmgtochamp",
    "magicdmgtochamp",
    "physdmgtochamp",
    "truedmgtochamp",
    "totheal",
    "totunitshealed",
    "dmgtoobj",
    "timecc",
    "totdmgtaken",
    "magicdmgtaken",
    "physdmgtaken",
    "truedmgtaken",
    "goldearned",
    "goldspent",
    "totminionskilled",
    "neutralminionskilled",
    "ownjunglekills",
    "enemyjunglekills",
    "totcctimedealt",
    "pinksbought",
    "wardsbought",
    "wardsplaced",
    "wardskilled",
]
for feature_name in rate_features:
    X[feature_name] /= X["duration"] / 60  # per minute rate

# convert to fraction of game
X["longesttimespentliving"] /= X["duration"]

# define friendly names for the features
full_names = {
    "kills": "Kills per min.",
    "deaths": "Deaths per min.",
    "assists": "Assists per min.",
    "killingsprees": "Killing sprees per min.",
    "longesttimespentliving": "Longest time living as % of game",
    "doublekills": "Double kills per min.",
    "triplekills": "Triple kills per min.",
    "quadrakills": "Quadra kills per min.",
    "pentakills": "Penta kills per min.",
    "legendarykills": "Legendary kills per min.",
    "totdmgdealt": "Total damage dealt per min.",
    "magicdmgdealt": "Magic damage dealt per min.",
    "physicaldmgdealt": "Physical damage dealt per min.",
    "truedmgdealt": "True damage dealt per min.",
    "totdmgtochamp": "Total damage to champions per min.",
    "magicdmgtochamp": "Magic damage to champions per min.",
    "physdmgtochamp": "Physical damage to champions per min.",
    "truedmgtochamp": "True damage to champions per min.",
    "totheal": "Total healing per min.",
    "totunitshealed": "Total units healed per min.",
    "dmgtoobj": "Damage to objects per min.",
    "timecc": "Time spent with crown control per min.",
    "totdmgtaken": "Total damage taken per min.",
    "magicdmgtaken": "Magic damage taken per min.",
    "physdmgtaken": "Physical damage taken per min.",
    "truedmgtaken": "True damage taken per min.",
    "goldearned": "Gold earned per min.",
    "goldspent": "Gold spent per min.",
    "totminionskilled": "Total minions killed per min.",
    "neutralminionskilled": "Neutral minions killed per min.",
    "ownjunglekills": "Own jungle kills per min.",
    "enemyjunglekills": "Enemy jungle kills per min.",
    "totcctimedealt": "Total crown control time dealt per min.",
    "pinksbought": "Pink wards bought per min.",
    "wardsbought": "Wards bought per min.",
    "wardsplaced": "Wards placed per min.",
    "turretkills": "# of turret kills",
    "inhibkills": "# of inhibitor kills",
    "dmgtoturrets": "Damage to turrets",
}
feature_names = [full_names.get(n, n) for n in X.columns]
X.columns = feature_names

# create train/validation split
Xt, Xv, yt, yv = train_test_split(X, y, test_size=0.2, random_state=10)
dt = xgb.DMatrix(Xt, label=yt.values)
dv = xgb.DMatrix(Xv, label=yv.values)

训练 XGBoost 模型

params = {
    "objective": "binary:logistic",
    "base_score": np.mean(yt),
    "eval_metric": "logloss",
}
model = xgb.train(
    params,
    dt,
    num_boost_round=10,
    evals=[(dt, "train"), (dv, "valid")],
    early_stopping_rounds=5,
    verbose_eval=25,
)
[0]	train-logloss:0.57255	valid-logloss:0.57258
[9]	train-logloss:0.34293	valid-logloss:0.34323

解释XGBoost模型

由于 Tree SHAP 算法是在 XGBoost 中实现的,因此我们可以快速计算数千个样本的精确 SHAP 值。 单个预测的 SHAP 值(包括最后一列中的预期输出)总和为该预测的模型输出。

# compute the SHAP values for every prediction in the validation dataset
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(Xv)

解释单个玩家赢得特定比赛的机会

SHAP 值的总和等于模型的预期输出与当前玩家的当前输出之间的差值。 请注意,对于 Tree SHAP 实现,解释的是模型的边际输出,而不是转换后的输出(例如逻辑回归的概率)。 这意味着该模型的 SHAP 值的单位是对数优势比。 较大的正值意味着玩家可能会获胜,而较大的负值意味着他们可能会失败。

shap.force_plot(explainer.expected_value, shap_values[0, :], Xv.iloc[0, :])
xs = np.linspace(-4, 4, 100)
pl.xlabel("Log odds of winning")
pl.ylabel("Probability of winning")
pl.title("How changes in log odds convert to probability of winning")
pl.plot(xs, 1 / (1 + np.exp(-xs)))
pl.show()

在这里插入图片描述

总结所有特征对整个数据集的影响

特定预测的特征的 SHAP 值表示当我们观察该特征时模型预测的变化程度。 在下面的汇总图中,我们将单个特征(例如“goldearned”)的所有 SHAP 值绘制成一行,其中 x 轴是 SHAP 值(对于该模型,以获胜的对数赔率为单位)。 通过对所有特征执行此操作,我们可以看到哪些特征对模型的预测有很大影响(例如“goldearned”),哪些特征对预测影响很小(例如“kills”)。 请注意,当点在线上不一致时,它们会垂直堆积以显示密度。 每个点也根据该特征的值从高到低进行着色。

shap.summary_plot(shap_values, Xv)

在这里插入图片描述

检查特征的变化如何改变模型的预测

我们上面训练的 XGBoost 模型非常复杂,但是通过绘制某个特征的 SHAP 值与所有玩家的该特征的实际值,我们可以看到特征值的变化如何影响模型的输出。 请注意,这些图与标准部分依赖图非常相似,但它们提供了额外的优势,即显示上下文对于特征的重要性(或者换句话说,交互项的重要性)。 有多少交互项影响特征的重要性是通过数据点的垂直分散来捕获的。 例如,在游戏中每分钟仅赚取 100 金币可能会使某些玩家的获胜几率降低 10 倍,而另一些玩家则只会降低 3 倍。 为什么是这样? 因为这些玩家的其他特征会影响赚取金币对于赢得游戏的重要性。 请注意,一旦您每分钟赚取至少 500 金币,垂直价差就会缩小,这意味着其他功能的背景对于高金币收入者而言不如低金币收入者那么重要。 我们用另一个最能解释交互效应方差的特征对数据点进行着色。 例如,如果你死得不多,赚到的金币少还不算太糟糕,但如果你也死了很多次,那就真的很糟糕了。

下图中的 y 轴代表该特征的 SHAP 值,因此 -4 表示观察该特征会将您的获胜对数几率降低 4,而值 +2 表示观察该特征会将您的获胜对数几率提高 2 。

请注意,这些图只是解释了 XGBoost 模型的工作原理,而不一定说明现实是如何工作的。 由于 XGBoost 模型是根据观察数据进行训练的,因此它不一定是因果模型,因此仅仅因为更改一个因素会使模型的获胜预测上升,并不总是意味着它会提高您的实际机会。

shap.dependence_plot(
    "Gold earned per min.", shap_values, Xv, interaction_index="Deaths per min."
)

在这里插入图片描述

# sort the features indexes by their importance in the model
# (sum of SHAP value magnitudes over the validation dataset)
top_inds = np.argsort(-np.sum(np.abs(shap_values), 0))

# make SHAP plots of the three most important features
for i in range(20):
    shap.dependence_plot(top_inds[i], shap_values, Xv)

在这里插入图片描述

在这里插入图片描述

请添加图片描述

请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1175436.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

48基于matlab的经验傅里叶分解,适用于非线性及非平稳时间序列分析,将信号进行精确分解。程序已调通,可直接运行。

基于matlab的经验傅里叶分解&#xff0c;适用于非线性及非平稳时间序列分析&#xff0c;将信号进行精确分解。程序已调通&#xff0c;可直接运行。

3D高斯泼溅(Splatting)简明教程

在线工具推荐&#xff1a; Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D场景编辑器 3D 高斯泼溅&#xff08;Splatting&#xff09;是用于实时辐射场渲染的 3D 高斯分布描述的一种光栅化技术&#xff0c;它允许实时渲染从小图像样…

优先级队列:PriorityQueue常用接口+构造+方法+源码分析+OJ练习

文章目录 PriorityQueue常用接口一.PriorityQueue 的特性二.PriorityQueue常用接口介绍1.优先级队列的构造2.插入/删除/获取优先级最高的元素3.PriorityQueue的扩容方式&#xff1a; PriorityQueue常用接口 一.PriorityQueue 的特性 1.Java集合框架中提供了 **PriorityQueue *…

ActiveMq学习⑦__ActiveMq协议

问题一、默认的61616端口如何更改&#xff1f; 问题二、你生产上的链接协议如何配置的&#xff1f;使用tcp吗&#xff1f; ActiveMQ 支持的client-broker 通讯协议有&#xff1a;TVP、NIO、UDP、SSL、Http(s)、VM。 其中配置TransportConnector 的文件在ActiveMQ 安装目录的co…

产品经理墨刀学习----注册页面

我们做的产品是一个校园论坛学习开发系统&#xff0c;目前才开始学习。 &#xff08;一&#xff09;流程图 &#xff08;二&#xff09;简单墨刀设计--注册页面 &#xff08;1&#xff09;有账号 &#xff08;a&#xff09;直接登录&#xff1a; &#xff08;b&#xff09;忘…

Git使用规范指南

文章目录 Git使用规范指南前言分支命名规范分支合并流程规范提交信息规范Angular提交规范注意事项 通用Git忽略文件配置 Git使用规范指南 前言 由于最近写完代码之后&#xff0c;Git使用不规范被领导说了&#xff0c;所以最近通过阅读大量的相关博客快速学习Git使用规范&#…

apachesolr启动带调试

这里solr.cmd报错&#xff0c;报错原因是java版本问题&#xff0c;后面发现这是因为多个java版本导致读取java_home失败&#xff0c; 那么我们修改solr.cmd中的JAVA_HOME为SOLR_JAVA_HOME IF DEFINED SOLR_JAVA_HOME set "JAVA_HOME%SOLR_JAVA_HOME%"环境变量将SOLR…

Qt全局定义

一、QtGlobal头文件 头文件中包含了Qt类库的一些全局定义&#xff0c;包括&#xff1a; 基本数据类型全局函数宏定义 二、基本数据类型 三、全局函数 四、宏定义 1.Qt版本相关的宏 1.1 QT_VERSION 这个宏展开为数值形式 0xMMNNPP (MM major, NN minor, PP patch) 表示…

P9831 [ICPC2020 Shanghai R] Gitignore

P9831 [ICPC2020 Shanghai R] Gitignore - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 只看题意翻译这道题是做不出来的&#xff0c;还要去看英文里面的规定&#xff08;这里就不放英文了&#xff09;&#xff0c;主要问题是不要公用子文件夹。 例如: 1 / a / 2 2 / a / 3…

【C语言】函数的系统化精讲(一)

&#x1f308;write in front :&#x1f50d;个人主页 &#xff1a; 啊森要自信的主页 &#x1f308;作者寄语 &#x1f308;&#xff1a; 小菜鸟的力量不在于它的体型&#xff0c;而在于它内心的勇气和无限的潜能&#xff0c;只要你有决心&#xff0c;就没有什么事情是不可能的…

线程条件控制实现线程的同步

前面讲了互斥锁&#xff0c;但是总感觉有些功能互斥锁有些不够用。 条件变量是线程另一可用的同步机制。条件变量给多个线程提供了一个会合的场所。条件变量与互斥量一起使用时&#xff0c;允许线程以无竞争的方式等待特定的条件发生。 条件本身是由互斥量保护的。线程在改变条…

Spring,SpringBoot和SpringMVC的关系以及区别 —— 超准确,可当面试题!!!也可供零基础学习

&#x1f9f8;欢迎来到dream_ready的博客&#xff0c;&#x1f4dc;相信您对这几篇博客也感兴趣o (ˉ▽ˉ&#xff1b;) &#x1f4dc;什么是SpringMVC&#xff1f;简单好理解&#xff01;什么是应用分层&#xff1f;SpringMVC与应用分层的关系&#xff1f; 什么是三层架构&…

Danswer 接入 Llama 2 模型 | 免费在 Google Colab 上托管 Llama 2 API

一、前言 前面在介绍本地部署免费开源的知识库方案时&#xff0c;已经简单介绍过 Danswer《Danswer 快速指南&#xff1a;不到15分钟打造您的企业级开源知识问答系统》&#xff0c;它支持即插即用不同的 LLM 模型&#xff0c;可以很方便的将本地知识文档通过不同的连接器接入到…

Linux中的高级IO

文章目录 1.IO1.1基本介绍1.2基础io的低效性1.3如何提高IO效率1.4五种IO模型1.5非阻塞模式的设置 2.IO多路转接之Select2.1函数的基本了解2.2fd_set理解2.3完整例子代码&#xff08;会在代码中进行讲解&#xff09;2.4优缺点 3.多路转接之poll3.1poll函数的介绍3.2poll服务器3.…

初阶JavaEE(15)(Cookie 和 Session、理解会话机制 (Session)、实现用户登录网页、上传文件网页、常用的代码片段)

接上次博客&#xff1a;初阶JavaEE&#xff08;14&#xff09;表白墙程序-CSDN博客 Cookie 和 Session 你还记得我们之前提到的Cookie吗&#xff1f; Cookie是HTTP请求header中的一个属性&#xff0c;是一种用于在浏览器和服务器之间持久存储数据的机制&#xff0c;允许网站…

【51单片机】蜂鸣器(学习笔记)

一、蜂鸣器 1、蜂鸣器介绍 鸣器是一种将电信号转换为声音信号的器件&#xff0c;常用来产生设备的按键音、报警音等提示信号 有源蜂鸣器&#xff1a;内部自带振荡源&#xff0c;将正负极接上直流电压即可持续发声&#xff0c;频率固定无源蜂鸣器&#xff1a;内部不带振荡源&…

Educational Codeforces Round 157 (Rated for Div. 2) F. Fancy Arrays(容斥+组合数学)

题目 称一个长为n的数列a是fancy的&#xff0c;当且仅当&#xff1a; 1. 数组内至少有一个元素在[x,xk-1]之间 2. 相邻项的差的绝对值不超过k&#xff0c;即 t(t<50)组样例&#xff0c;每次给定n(1<n<1e9),x(1<x<40), 求fancy的数组的数量&#xff0c;答案…

【错误解决方案】ModuleNotFoundError: No module named ‘my_fake_useragent‘

1. 错误提示 ModuleNotFoundError: No module named my_fake_useragent&#xff0c;这意味着你试图导入一个名为 my_fake_useragent 的模块&#xff0c;但Python找不到这个模块。 2. 解决方案 检查模块名是否正确: 确保你试图导入的模块名是正确的。也许你拼写错误或者大小写不…

NeurIPS 2023 | 基于多模态统一表达的跨模态泛化

©PaperWeekly 原创 作者 | 夏炎 学校 | 浙江大学 研究方向 | 多模态 论文标题&#xff1a; Achieving Cross Modal Generalization with Multimodal Unified Representation 模型&代码地址&#xff1a; https://github.com/haihuangcode/CMG 在本文中&#xff0c;我们…

LeetCode 面试题 16.17. 连续数列

文章目录 一、题目二、C# 题解 一、题目 给定一个整数数组&#xff0c;找出总和最大的连续数列&#xff0c;并返回总和。 示例&#xff1a; 输入&#xff1a; [-2,1,-3,4,-1,2,1,-5,4] 输出&#xff1a; 6 解释&#xff1a; 连续子数组 [4,-1,2,1] 的和最大&#xff0c;为 6。…