【SHAP解释运用】基于python的树模型特征选择+随机森林回归预测+SHAP解释预测

news2025/1/24 17:58:17

1.导入必要的库

import pandas as pd  
import numpy as np  
import matplotlib.pyplot as plt  
import seaborn as sns  
from sklearn.model_selection import train_test_split  
from sklearn.ensemble import RandomForestRegressor  
from sklearn.tree import export_graphviz  
#from sklearn.inspection import plot_partial_dependence   
from sklearn.metrics import mean_squared_error  
import shap  
import warnings  

2.设置忽略警告与显示字体、负号

warnings.filterwarnings("ignore")  
  
# 设置Matplotlib的字体属性  
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用于中文显示,你可以更改为其他支持中文的字体  
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号 

3.导入数据集

3.1加载数据
# 1. 加载数据  
df = pd.read_excel('train.xlsx')  
X = df.iloc[:, :-1]  # 特征  
y = df.iloc[:, -1]   # 标签  
3.2查看数据分布
1.箱线图
plt.figure(figsize=(30, 6))  
sns.boxplot(data=df)  
plt.title('Box Plots of Dataset Features', fontsize=40, color='black')  
# 如果需要设置坐标轴标签的字体大小和颜色  
plt.xlabel('X-axis Label', fontsize=20, color='red')  # 设置x轴标签的字体大小和颜色  
plt.ylabel('Y-axis Label', fontsize=20, color='green')  # 设置y轴标签的字体大小和颜色  
  
# 还可以调整刻度线的长度、宽度等属性  
plt.tick_params(axis='x', labelsize=20, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  
plt.tick_params(axis='y', labelsize=20, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性    
plt.xticks(rotation=45)  # 如果特征名很长,可以旋转x轴标签  
plt.show()

        结果如图3-1所示:

图3-1

        结果图实在丑陋,这是由数据分布不均衡造成的,这里重点不是数据清洗,就这样凑着用吧。

2.分布图
# 注意:distplot 在 seaborn 0.11.0+ 中已被移除  
# 你可以分别使用 histplot 和 kdeplot  
  
plt.figure(figsize=(50, 10))  
for i, feature in enumerate(df.columns, 1):  
    plt.subplot(1, len(df.columns), i)  
    sns.histplot(df[feature], kde=True, bins=30, label=feature,color='blue') 
    plt.title(f'QQ plot of {feature}', fontsize=40, color='black')  
    # 如果需要设置坐标轴标签的字体大小和颜色  
    plt.xlabel('X-axis Label', fontsize=35, color='red')  # 设置x轴标签的字体大小和颜色  
    plt.ylabel('Y-axis Label', fontsize=40, color='green')  # 设置y轴标签的字体大小和颜色  
  
    # 还可以调整刻度线的长度、宽度等属性  
    plt.tick_params(axis='x', labelsize=40, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  
    plt.tick_params(axis='y', labelsize=40, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性 
plt.tight_layout()  
plt.show()

        结果如图3-2所示:

图3-2

3.QQ图
from scipy import stats
plt.figure(figsize=(50, 10))  
for i, feature in enumerate(df.columns, 1):  
    plt.subplot(1, len(df.columns), i)  
    stats.probplot(df[feature], dist="norm", plot=plt)  
    plt.title(f'QQ plot of {feature}', fontsize=40, color='black')  
    # 如果需要设置坐标轴标签的字体大小和颜色  
    plt.xlabel('X-axis Label', fontsize=35, color='red')  # 设置x轴标签的字体大小和颜色  
    plt.ylabel('Y-axis Label', fontsize=40, color='green')  # 设置y轴标签的字体大小和颜色  
  
    # 还可以调整刻度线的长度、宽度等属性  
    plt.tick_params(axis='x', labelsize=40, colors='blue', length=5, width=1)  # 设置x轴刻度线、刻度标签的更多属性  
    plt.tick_params(axis='y', labelsize=40, colors='deepskyblue', length=5, width=1)  # 设置y轴刻度线、刻度标签的更多属性   
plt.tight_layout()  
plt.show()

        结果如图3-3所示:

图3-3

4.树模型特征选择

# 4. 特征选择(使用随机森林的特征重要性)  
rf = RandomForestRegressor(n_estimators=100, random_state=42)  
rf.fit(X_scaled, y)  
importances = rf.feature_importances_  
indices = np.argsort(importances)[::-1]  
  
# 可视化特征重要性  
plt.figure(figsize=(10,7))  
plt.title("Feature importances")  
plt.bar(range(X.shape[1]), importances[indices],align="center", color='cyan')
plt.xticks(range(X.shape[1]), [X.columns[i] for i in indices], rotation='vertical')  
plt.xlim([-1, X.shape[1]])  
plt.show()

        特征重要性比较如图4-1所示:

图4-1

5.随机森林回归预测

# 划分训练集和测试集  
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)  
  
# 随机森林回归预测  
rf_final = RandomForestRegressor(n_estimators=100, random_state=42)  
rf_final.fit(X_train, y_train)  
y_pred = rf_final.predict(X_test)  
mse = mean_squared_error(y_test, y_pred)  
print(f"Mean Squared Error: {mse}")  
 
# 预测结果输出与比对
plt.figure()
plt.plot(np.arange(21), y_test[:100], "go-", label="True value")
plt.plot(np.arange(21), y_pred[:100], "ro-", label="Predict value")
plt.title("True value And Predict value")
plt.legend()
plt.show()
  

        预测结果如图5-1所示:

图5-1

        由图5-1结合这里的误差Mean Squared Error: 16.092619015714185,说明预测效果很一般,不过本身数据集没有经过清洗,数据分布不合理,有这样的结果也能接受。我一般使用matlab进行数据清晰和标准化,matlab暂时打不开,先搁置,后面我会出数据标准化的文章。

5.SHAP库解释预测

5.1shap库下载安装

        这里的shap库我已经下载安装过了,没有下载安装的在pycharm终端、Anaconda Promt终端等等执行命令进行下载安装,最好带上清华镜像源,在网络信号不好时也能顺利安装且节省时间。

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple shap
5.2waterfall
shap.plots.waterfall(shap_values[0]) # For the first observation

        结果如图5-1所示:

图5-1

5.3forceplot
#相互作用图
force_plot1 = shap.force_plot(explainer.expected_value, np.mean(shap_values, axis=0), np.mean(X_test, axis=0),feature_label,matplotlib=True, show=False)
shap_interaction_values = explainer.shap_interaction_values(X_test)
shap.summary_plot(shap_interaction_values,X_test)

        结果如图5-2所示:

图5-2

5.4特征影响图
shap.plots.force(explainer.expected_value,shap_values.values,shap_values.data)

        结果如图5-3所示:

图5-3

5.5特征密度散点图:summary_plot/beeswarm
5.5.1summary_plot
# 创建SHAP解释器
explainer = shap.TreeExplainer(rf)

# 计算SHAP值
shap_values = explainer.shap_values(X_test)

#特征标签
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']

plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
plt.rcParams['font.size'] = 13  # 设置字体大小为14
# 现在创建 SHAP 可视化 

#配色   viridis  Spectral   coolwarm  RdYlGn  RdYlBu  RdBu  RdGy  PuOr  BrBG PRGn  PiYG 
shap.summary_plot(shap_values, X_test,feature_names=feature_label)

#粉红色点:表示该特征值在这个观察中对模型预测产生了正面影响(增加预测值)
#蓝色点:表示该特征值在这个观察中对模型预测产生了负面影响(降低预测值)
#水平轴(SHAP 值)显示了影响的大小。点越远离中心线(零点),该特征对模型输出的影响越大
#图中垂直排列的特征按影响力从上到下排序。上方的特征对模型输出的总体影响更大,而下方的特征影响较小。
# 最上方的特征显示了大量的正面和负面影响,表明它在不同的观察值中对模型预测的结果有很大的不同影响。
# 中部的特征也显示出两种颜色的点,但点的分布更集中,影响相对较小。
# 底部的特征对模型的影响最小,且大部分影响较为接近零,表示这些特征对模型预测的贡献较小

        结果如图5-4所示:

图5-4


# 创建SHAP解释器
explainer = shap.TreeExplainer(rf)
# 计算SHAP值
shap_values = explainer.shap_values(X_test)
#特征标签
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = 'Times New Roman'
plt.rcParams['font.size'] = 13  # 设置字体大小为14
# 现在创建 SHAP 可视化 
#配色   viridis  Spectral   coolwarm  RdYlGn  RdYlBu  RdBu  RdGy  PuOr  BrBG PRGn  PiYG 
shap.summary_plot(shap_values,X_test,feature_names=feature_label,cmap='Spectral')

使颜色丰富些如图5-5所示:

图5-5

5.5.2beeswarm
# summarize the effects of all the features
# 样本决策图
shap.initjs()
shap_values = explainer(X_test)
expected_value = explainer.expected_value
shap.plots.beeswarm(shap_values)

结果如图5-6所示:

图5-6

5.6特征重要性SHAP值
shap.summary_plot(shap_values,X_test,feature_names=feature_label,plot_type='bar')
#主要表示绝对重要值的大小,把SHAP value 的样本取了绝对平均值

        或者:

shap.plots.bar(shap_values)

        结果如图5-7、图5-8所示,本质都是一样的:

图5-7

图5-8

5.7聚类热力图:heatmap plot

#热图
shap.initjs()
shap_values = explainer(X_test)
shap.plots.heatmap(shap_values)

        结果如图5-9所示:

图5-9

5.7层次聚类shap值
# 层次聚类 + SHAP值
clust = shap.utils.hclust(X, y, linkage="single")
shap.plots.bar(shap_values, clustering=clust, clustering_cutoff=1)

        结果如图5-10所示:

图5-10

5.8决策图
# 样本决策图
shap.initjs()
shap_values = explainer.shap_values(X_test)
expected_value = explainer.expected_value
shap.decision_plot(expected_value, shap_values,feature_label)

        结果如图5-11所示:

图5-11

变形1:由数值 -> 概率

# 样本决策图
shap.initjs()
shap_values = explainer.shap_values(X_test)
expected_value = explainer.expected_value
feature_label=['feature1','feature2','feature3','feature4','feature5','feature6','feature7']
shap.decision_plot(expected_value, shap_values, feature_label, link='logit')

        结果如图5-12所示:

图5-12

变形2:高亮某个样本线highlight

shap.decision_plot(expected_value, shap_values, feature_label, highlight=12)

        结果如图5-13所示:

图5-13

5.9特征依赖图:dependence_plot
5.9.1单个特征依赖
shap.dependence_plot('feature1', shap_values,X_test, interaction_index=None)

        结果如图5.14所示:

图5-14

5.9.2相互依赖图
shap.dependence_plot('feature3', shap_values,X_test, interaction_index='feature4')

        结果如图5-15所示:

图5-15

5.10相互作用图:summary_plot
shap.summary_plot(shap_interaction_values,X_test)

        结果如图5-16所示:

图5-16

具体的每种解释图的含义可以搜寻以下参考文章:

代码借鉴:http://t.csdnimg.cn/6JWrj

理论借鉴   

http://t.csdnimg.cn/6JWrj

http://t.csdnimg.cn/H9X0B

http://t.csdnimg.cn/zvtA8

http://t.csdnimg.cn/nygl6

http://t.csdnimg.cn/zyHy0

http://t.csdnimg.cn/rTPw2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1865054.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

对input输入框脱敏的实现(input输入时可回删、可粘贴)

目录 1.要解决的问题2.第一回合:substring replace3.第二回合:移魂大法4.第三回合:移花接木5.第四回合:万佛归宗 写在前面: 如有转载,务必注明出处,否则后果自负。 1.要解决的问题 继续与客户…

从云原生视角看 AI 原生应用架构的实践

本文核心观点: 基于大模型的 AI 原生应用将越来越多,容器和微服务为代表的云原生技术将加速渗透传统业务。API 是 AI 原生应用的一等公民,并引入了更多流量,催生企业新的生命力和想象空间。AI 原生应用对网关的需求超越了传统的路…

云计算运维工程师的突发状况处理

云计算运维工程师在应对突发的故障和紧急情况时,需要采取一系列迅速而有效的措施来最小化服务中断的时间并恢复系统的稳定性。 以下是一些关键步骤和策略: 快速响应: 立即识别并确认故障的性质和范围。通知团队成员和相关的利益相关者,确保所有人了解当前情况。故障诊断:…

Web Worker 学习及使用

了解什么是 Web Worker 提供了可以在后台线程中运行 js 的方法。可以不占用主线程,不干扰用户界面,可以用来执行复杂、耗时的任务。 在worker中运行的是另一个全局上下文,不能直接获取 Window 全局对象。不同的 worker 可以分为专用和共享&…

使用 Vue Router 的 meta 属性实现多种功能

在 Vue.js 中,Vue Router 提供了强大的路由管理功能。通过 meta 属性,我们可以在路由定义中添加自定义元数据,以实现访问控制、页面标题设置、角色权限管理、页面过渡效果等多种功能。本文将总结如何使用 meta 属性来实现这些常见的功能。 1.…

gici-open学习日记(7):GNSS图优化——RTK

gici-open学习日记——GNSS RTK图优化 前言初始化RTK的调用rearrangePhasesAndCodes双差构造formPhaserangeDDPair周跳探测cycleSlipDetectionSD添加参数块模糊度参数部分addSdAmbiguityParameterBlocks添加双差伪距残差addDdPseudorangeResidualBlocks添加双差相位残差addDdPh…

springcloud-gateway 路由加载流程

问题 Spring Cloud Gateway版本是2.2.9.RELEASE,原本项目中依赖服务自动发现来自动配置路由到微服务的,但是发现将spring.cloud.gateway.discovery.locator.enabledfalse 启动之后Gateway依然会将所有微服务自动注册到路由中,百思不得其解&a…

手把手从零开始搭建远程访问服务

远程访问服务工具——FRP frp 是一个能够实现内网穿透的高性能的反向代理应用,支持 TCP、UDP、HTTP、HTTPS 等多种协议。可以将内网服务以安全、便捷的方式通过具有公网的服务器来转发。 资源链接 根据自己服务型号和操作系统来选取对应的文件,不知道的…

汽车EDI: BMW EDI项目案例

宝马集团是全世界成功的汽车和摩托车制造商之一,旗下拥有BMW、MINI和Rolls-Royce三大品牌;同时提供汽车金融和高档出行服务。作为一家全球性公司,宝马集团在14个国家拥有31家生产和组装厂,销售网络遍及140多个国家和地区。 本文主…

mitt通信

一、mitt介绍 mitt是一款轻量级的组件通信插件(大小仅为200字节左右) 二、mitt安装 npm install --save mitt三、使用 1.在组件中使用 import mitt from mitt //创建mitt实例 const emitter mitt()// 监听事件 emitter.on(foo, e > console.log(foo, e) )// 通过通配符监…

09. Java ThreadLocal 的使用

1. 前言 本节内容主要是对 ThreadLocal 进行深入的讲解,具体内容点如下: 了解 ThreadLocal 的诞生,以及总体概括,是学习本节知识的基础;了解 ThreadLocal 的作用,从整体层面理解 ThreadLocal 的程序作用&…

VC++开发积累——vc++6.0中删除函数的方法,右键,Delete

目录 引出插曲:删除函数的方法多行注释的实现代码输入的自动提示搜索出来,标记和取消标记跳转到上一步的位置 ctrl TAB 总结其他规范和帮助文档创建第一个Qt程序对象树概念信号signal槽slot自定义信号和槽1.自定义信号2.自定义槽3.建立连接4.进行触发 自…

千呼新零售2.0-OCR拍照识别采购单

千呼新零售2.0系统是零售行业连锁店一体化收银系统,包括线下收银线上商城连锁店管理ERP管理商品管理供应商管理会员营销等功能为一体,线上线下数据全部打通。 适用于商超、便利店、水果、生鲜、母婴、服装、零食、百货、宠物、中医养生、大健康等连锁店…

Python 实现Excel转TXT,或TXT文本导入Excel

Excel是一种具有强大的数据处理和图表制作功能的电子表格文件,而TXT则是一种简单通用、易于编辑的纯文本文件。将Excel转换为TXT可以帮助我们将复杂的数据表格以文本的形式保存,方便其他程序读取和处理。而将TXT转换为Excel则可以将文本文件中的数据导入…

AI引领创意潮流:高效生成图片,参考图助力,一键保存到指定文件夹

在这个数字与创意交融的时代,我们迎来了AI绘画的新纪元。借助先进的AI技术,我们不仅能够高效生成图片,还能在参考图的启发下,激发无限创意,让您的想象力在数字世界中自由翱翔。 首助编辑高手软件中的魔法智能绘图板块&…

PMP证书在国内已经泛滥了,大家怎么看?

目前,越来越多的人获得了PMP证书。自1999年PMP引入中国以来,全国累计PMP考试人数接近60万人次,通过PMP认证的人数约为42万人。虽然这个数据看起来很大,但绝对不能说是过多。 首先,PMP在中国并不普遍。根据美国项目管理…

解决go语言对接s3的SDK上传文件遇到的问题

先看正确的配置 问题1 配置文件中的OssEndpoint 不管是minio还是oss需要带上http://或者https:// 否则会出现这个问题 operation error S3: PutObject, exceeded maximum number of attempts, 3, https response error StatusCode: 0, RequestID: , HostID: , request send …

qt报错:“QtRunWork”任务返回了 false,但未记录错误。

qt报错:“QtRunWork”任务返回了 false,但未记录错误。 说明情况一 说明 这个报错可能的原因有很多,这里只写一种,以后遇到再进行补充。 情况一 如果 Q_OBJECT 宏未正确处理,通常会出现类似的错误。 要使用信号与槽…

视频汇聚平台LntonCVS视频集中存储平台技术解决方案

安防视频监控技术是一种利用各种监控设备捕捉实时画面,并将其传输至监控中心或数据存储设备的技术。随着科技的不断进步,监控视频技术也在不断改进,应用领域也在不断扩展。 然而,尽管技术进步,当前视频监控技术仍然面临…

PointCloudLib-特征(Features)-全局对齐空间分布 (GASD) 描述符

全局对齐空间分布 (GASD) 描述符 本文档介绍用于高效对象识别和姿势估计的全局对齐空间分布 ([GASD]) 全局描述符。 GASD 基于对表示对象实例的整个点云的参考系的估计,该参考系用于将其与规范坐标系对齐。之后,根据对齐点云的 3D 点的空间分布方式计算对齐点云的描述符…