【MATLAB第116期】基于MATLAB的NBRO-XGBoost的SHAP可解释回归模型(敏感性分析方法)

news2025/4/23 13:44:33

【MATLAB第116期】基于MATLAB的NBRO-XGBoost的SHAP可解释回归模型(敏感性分析方法)

引言

该文章实现了一个可解释的回归模型,使用NBRO-XGBoost(方法可以替换,但是需要有一定的编程基础)来预测特征输出。该模型利用七个变量参数作为输入特征进行训练。为了提高可解释性,应用了SHapley Additive exPlanations(SHAP),去深入了解每个参数对模型预测的贡献。

第112期用了BP神经网络作为代理模型, xgboost难度在于其环境配置,其次代码的更改,以及最重要的是计算速度,目前已解决上述问题。

一、NBRO-XGBoost模型训练

%  清空环境变量
warning off             % 关闭报警信息
close all               % 关闭开启的图窗
clear                   % 清空变量
clc                     % 清空命令行

%%  导入数据
res = xlsread('数据集.xlsx');
%%  数据分析
num_size = 0.8;                              % 训练集占数据集比例
outdim = 1;                                  % 最后1列为输出
num_samples = size(res, 1);                  % 样本个数
res = res(randperm(num_samples), :);         % 打乱数据集(不希望打乱时,注释该行)
num_train_s = round(num_size * num_samples); % 训练集样本个数
f_ = size(res, 2) - outdim;                  % 输入特征维度
id=1;                                        % 输出第几个因变量()
%%  划分训练集和测试集
P_train = res(1: num_train_s, 1: f_)';
T_train = res(1: num_train_s, f_ +id)';
%T_train=mean(T_train0);
M = size(P_train, 2);

P_test = res(num_train_s + 1: end, 1: f_)';
T_test = res(num_train_s + 1: end, f_ + id)';
%T_test=mean(T_test0);
N = size(P_test, 2);

%%  数据归一化
[p_train, ps_input] = mapminmax(P_train, 0, 1);%将训练集和测试集的数据调整到0到1之间
p_test = mapminmax('apply', P_test, ps_input);

[t_train, ps_output] = mapminmax(T_train, 0, 1);% 对测试集数据做归一化
t_test = mapminmax('apply', T_test, ps_output);

%%  数据转置 为适应模型的建立
p_train = p_train'; p_test = p_test';
t_train = t_train'; t_test = t_test';

%%  参数设置
fun = @getObjValue;                 % 目标函数
dim = 3;                            % 优化参数个数
lb  = [10, 10, 0.01];             % 优化参数目标下限(最大迭代次数,深度,学习率)
ub  = [200,  20,  1];             % 优化参数目标上限(最大迭代次数,深度,学习率)
SearchAgents_no = 10;                            % 种群数量6
Max_iteration = 10;                 % 最大迭代次数20
params.objective = 'reg:linear';    % 回归函数

%%  优化算法
[Best_score ,Best_pos, curve] = NRBO(SearchAgents_no, Max_iteration, lb, ub, dim, fun);

%%  获取最优参数
num_trees = round(Best_pos(1, 1));         % 迭代次数
params.max_depth = round(Best_pos(1, 2));  % 树的深度
params.eta = Best_pos(1, 3);        % 学习率

%%  建立模型
for j = 1 : size(t_train,2)  
%% 建立模型
 model(j) = xgboost_train(p_train, t_train(:,j), params, num_trees);  
%%  模型预测
 t_sim1(:,j) = xgboost_test(p_train, model(j));
 t_sim2(:,j) = xgboost_test(p_test , model(j));
end

%%  数据反归一化
T_sim1 = mapminmax('reverse', t_sim1', ps_output);
T_sim2 = mapminmax('reverse', t_sim2', ps_output);

figure
plot(1 : length(curve), curve, 'LineWidth', 1.5);
title('NRBO-XGboost', 'FontSize', 10);
xlabel('迭代次数', 'FontSize', 10);
ylabel('适应度值', 'FontSize', 10);
grid on
%% 评价指标
%%  绘图
for i = 1 : 1

figure
%%  均方根误差
error1 = sqrt(sum((T_sim1(i, :) - T_train(i, :)).^2) ./ M);
error2 = sqrt(sum((T_sim2(i, :) - T_test (i, :)).^2) ./ N);

subplot(2, 1, 1)
plot(1: M, T_train(i, :), 'r-', 1: M, T_sim1(i, :), 'b-', 'LineWidth', 1)
legend('真实值', 'NRBO-XGboost预测值')
xlabel('预测样本')
ylabel('预测结果')
string = {'训练集预测结果对比'; ['RMSE=' num2str(error1)]};
title(string)
xlim([1, M])
grid

subplot(2, 1, 2)
plot(1: N, T_test(i, :), 'r-', 1: N, T_sim2(i, :), 'b-o', 'LineWidth', 1)
legend('真实值', 'NRBO-XGboost预测值')
xlabel('预测样本')
ylabel('预测结果')
string = {'测试集预测结果对比'; ['RMSE=' num2str(error2)]};
title(string)
xlim([1, N])
grid

%%  相关指标计算
%  R2
R1(i) = 1 - norm(T_train(i, :) - T_sim1(i, :))^2 / norm(T_train(i, :) - mean(T_train(i, :)))^2;
R2(i) = 1 - norm(T_test (i, :) - T_sim2(i, :))^2 / norm(T_test (i, :) - mean(T_test (i, :)))^2;

disp(['输出:', num2str(i), '  训练集数据的R2为:', num2str(R1(i))])
disp(['输出:', num2str(i), '  测试集数据的R2为:', num2str(R2(i))])

%  MAE
mae1(i) = sum(abs(T_sim1(i, :) - T_train(i, :))) ./ M ;
mae2(i) = sum(abs(T_sim2(i, :) - T_test (i, :))) ./ N ;

disp(['输出:', num2str(i), '  训练集数据的MAE为:', num2str(mae1(i))])
disp(['输出:', num2str(i), '  测试集数据的MAE为:', num2str(mae2(i))])

%  MBE
mbe1(i) = sum(T_sim1(i, :) - T_train(i, :)) ./ M ;
mbe2(i) = sum(T_sim2(i, :) - T_test (i, :)) ./ N ;

disp(['输出:', num2str(i), '  训练集数据的MBE为:', num2str(mbe1(i))])
disp(['输出:', num2str(i), '  测试集数据的MBE为:', num2str(mbe2(i))])

end

二、模型训练结果:

输出:1 训练集数据的R2为:1
输出:1 测试集数据的R2为:0.90554
输出:1 训练集数据的MAE为:0.011996
输出:1 测试集数据的MAE为:1.6451
输出:1 训练集数据的MBE为:-1.1766e-06
输出:1 测试集数据的MBE为:0.27044

在这里插入图片描述
在这里插入图片描述

三、SHAP分析

1、生成随机数据
在本部分,生成一组合成输入数据用于SHAP分析。这种合成数据允许在受控和一致的方式下评估模型的特征贡献。步骤包括:

样本数量:脚本设置生成的合成样本数量为200(numSamples = 200).
特征范围: 定义操作参数在特定范围内,选择训练数据中各个输入变量的最大值和最小值
不需要手动再输入

VarMin=min(P_train');%各个参数下限
VarMax=max(P_train');%各个参数上限

随机数据生成: 使用rand函数在定义的范围内为每个特征生成随机值,创建200个样本.
并行计算提高速度

parfor i=1:size(x,2)
x_shap(:,i)=VarMin(i)+ (VarMax(i) - VarMin(i)) * rand(numSamples, 1);
end

此生成数据用于评估SHAP值并分析每个特征如何影响模型的预测。生成随机输入数据确保了SHAP分析中特征值的广泛范围,便于更全面地评估特征重要性.

2、计算SHAP值
该代码计算神经网络模型的SHapley Additive exPlanations(SHAP)值。SHAP值量化了每个特征对模型预测的贡献。该过程包括:

  1. 预分配SHAP值矩阵:初始化一个矩阵以存储所有输入样本和特征的SHAP值.
    2.计算参考值:将参考值计算为所有输入特征的平均值,用于在排除或包含特征时进行比较.
    3.计算SHAP值:对于每个输入样本,使用自定义的shapley_ann函数计算SHAP值,该函数迭代所有可能的特征组合以确定每个特征对预测的贡献.
    4.自定义的shapley函数接受一个训练好的神经网络(net)、当前输入样本和参考值来计算每个特征的SHAP值。该方法提供了对单个特征如何影响模型输出的洞察.
% ------------------------------------
function shapValues = shapley_nrbo_xgb_fast(x_shap, refValue) % 
    

    % 使用Shapley公式计算SHAP值
    如果有7个特征,则依次分析每个特征的累计贡献值
    当分析第1个特征时,排除当前特征,即 1  0  0  0  0  0  0
    迭代所有可能的特征组合 
   for i=12^(D-1)
    xt1: 每个样本的特征变量输入值(处理后)   1*7
    xt2: 计算的每个样本平均值(处理后)       1*7
    xt3: 当分析不同特征时,将该特征值替换为平均值。  1*7
    shapValues=shapValues+xgboost_test(xt3)-xgboost_test(xt2)   
  end

3、可视化
------蜂群图:为每个特征创建散点图(蜂群图),显示所有样本的SHAP值。特征值被标准化并颜色编码以提高可解释性.
包括轴标签、网格、框以提高清晰度以及带有操作参数标签的颜色条. 此SHAP摘要图有助于理解哪些特征对模型的预测影响最大以及特征在样本中的变化情况.显示每个特征对模型预测的贡献。
在这里插入图片描述

-----条形图
计算平均绝对SHAP值:计算每个特征的绝对SHAP值的平均值,以量化每个特征的整体重要性.
条形图可视化:创建一个水平条形图,特征按其平均绝对SHAP值排序。这提供了模型中特征重要性的清晰、排序表示. 结果的SHAP摘要条形图有助于识别哪些特征对模型的预测影响最大.

在这里插入图片描述

四、代码获取

1.阅读首页置顶文章
2.关注CSDN
3.根据自动回复消息,私信回复“116期”以及相应指令,即可获取对应下载方式。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2340818.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

微信公众号消息模板推送没有“详情“按钮?无法点击跳转

踩坑!!!!踩坑!!!!踩坑!!!! 如下 简单说下我的情况,按官方文档传参url了 、但就是看不到查看详情按钮 。如下 真凶&#x…

电动单座V型调节阀的“隐形守护者”——阀杆节流套如何解决高流速冲刷难题

电动单座V型调节阀的“隐形守护者”——阀杆节流套如何解决高流速冲刷难题? 在工业自动化控制中,电动单座V型调节阀因其精准的流量调节能力,成为石油、化工等领域的核心设备。然而,长期高流速工况下,阀芯与阀座的冲刷腐…

自动驾驶与机器人算法学习

自动驾驶与机器人算法学习 直播与学习途径 欢迎你的点赞关注~

【网络编程】TCP数据流套接字编程

目录 一. TCP API 二. TCP回显服务器-客户端 1. 服务器 2. 客户端 3. 服务端-客户端工作流程 4. 服务器优化 TCP数据流套接字编程是一种基于有连接协议的网络通信方式 一. TCP API 在TCP编程中,主要使用两个核心类ServerSocket 和 Socket ServerSocket Ser…

从零开始配置 Zabbix 数据库监控:MySQL 实战指南

Zabbix作为一款开源的分布式监控工具,在监控MySQL数据库方面具有显著优势,能够为数据库的稳定运行、性能优化和故障排查提供全面支持。以下是使用Zabbix监控MySQL数据库的配置。 一、安装 Zabbix Agent 和 MySQL 1. 安装 Zabbix Agent services:zabbix…

Java学习手册:RESTful API 设计原则

一、RESTful API 概述 REST(Representational State Transfer)即表述性状态转移,是一种软件架构风格,用于设计网络应用程序。RESTful API 是符合 REST 原则的 Web API,通过使用 HTTP 协议和标准方法(GET、…

读一篇AI论文并理解——通过幻觉诱导优化缓解大型视觉语言模型中的幻觉

目录 论文介绍 标题 作者 Publish Date Time PDF文章下载地址 文章理解分析 📄 中文摘要:《通过幻觉诱导优化缓解大型视觉语言模型中的幻觉》 🧠 论文核心动机 🚀 创新方法:HIO(Hallucination-In…

IOT项目——物联网 GPS

GeoLinker - 物联网 GPS 可视化工具 项目来源制作引导 项目来源 [视频链接] https://youtu.be/vi_cIuxDpcA?sigMaOKv681bAirQF8 想要在任何地方追踪任何东西吗?在本视频中,我们将向您展示如何使用 ESP32 和 Neo-6M GPS 模块构建 GPS 跟踪器——这是一…

Java学习手册:HTTP 协议基础知识

一、HTTP 协议概述 HTTP(HyperText Transfer Protocol)即超文本传输协议,是用于从万维网(WWW:World Wide Web )服务器传输超文本到本地浏览器的传输协议。它是一个应用层协议,基于请求-响应模型…

【含文档+PPT+源码】基于微信小程序的健康饮食食谱推荐平台的设计与实现

课程目标: 教你从零开始部署运行项目,学习环境搭建、项目导入及部署,含项目源码、文档、数据库、软件等资料 课程简介: 本课程演示的是一款基于微信小程序的健康饮食食谱推荐平台的设计与实现,主要针对计算机相关专…

Redis 慢查询分析与优化

Redis 慢查询分析与优化 参考书籍 : https://weread.qq.com/web/reader/d5432be0813ab98b6g0133f5kd8232f00235d82c8d161fb2 以下从配置参数、耗时细分、分析工具、优化策略四个维度深入解析 Redis 慢查询问题,结合实战调优建议,帮助开发者…

使用达梦官方管理工具SQLark快速生成数据库ER图并导出

在数据库设计与开发中,实体-关系图(ER 图)作为数据建模的核心工具,能够直观呈现表结构、字段属性及表间关系,是团队沟通和文档维护的重要工具。然而,许多开发者在实际工作中常面临一个痛点:手动…

模型 替罪羊效应

系列文章分享模型,了解更多👉 模型_思维模型目录。转嫁罪责于无辜,维系群体控制与稳定 1 替罪羊效应的应用 1.1 多品牌危机中的行业“背锅侠” 行业背景:食品行业爆发大规模安全危机,多家企业卷入某类食品重金属超标…

TapData × 梦加速计划 | 与 AI 共舞,TapData 携 AI Ready 实时数据平台亮相加速营,企业数据基础设施现代化

在实时跃动的数据节拍中,TapData 与 AI 共舞,踏出智能未来的新一步。 4月10日,由前海产业发展集团、深圳市前海梦工场、斑马星球科创加速平台等联合发起的「梦加速计划下一位独角兽营」正式启航。 本次加速营以“打造下一位独角兽企业”为目…

15.电感特性在EMC设计中的运用

电感特性在EMC设计中的运用 1. 共模电感与差模电感的差异2. 电感的高频等效特性![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/b4dc000672af4dd69a528450eb42cf10.png)3. 电感在EMC设计中的使用注意事项3.1 LC滤波计算3.2 并联型多级浪涌防护的电感退耦 1. 共模电感…

uniapp Vue2升级到Vue3,并发布到微信小程序的快捷方法

目录 前言:升级项目的两种方式步骤一、新建项目 【选择-默认模版】二、修改-pages.json三、补充-缺少的文件四、修改-Main.js按照 [官方文档-vue2升级vue3迁移指南](https://uniapp.dcloud.net.cn/tutorial/migration-to-vue3.html) 修改 五、升级-uni-ui扩展组件的…

数据重构如何兼顾效率与性能稳定?zStorage 全闪存分布式存储的技术实践与实测数据

点击蓝字 关注我们 zStorage 作为数据库场景下的全闪存分布式存储,除了性能要好,更重要的是要在各种情况下都能保持“稳定”的好。一个高并发的交易型业务数据库,如果出现轻微的IO抖动,就可能造成数据库并发事务提交的排队&#x…

A2A + MCP:构建实用人工智能系统的超强组合

构建真正有效的连接型人工智能系统的挑战 如果你正在构建人工智能应用,这种情况可能听起来很熟悉: 你需要特定的人工智能能力来解决业务问题。你找到了完成每个单独任务的出色工具。但把所有东西连接在一起却占据了大部分开发时间,还创建了…

力扣每日打卡17 49. 字母异位词分组 (中等)

力扣 49. 字母异位词分组 中等 前言一、题目内容二、解题方法1. 哈希函数2.官方题解2.1 前言2.2 方法一:排序2.2 方法二:计数 前言 这是刷算法题的第十七天,用到的语言是JS 题目:力扣 49. 字母异位词分组 (中等) 一、题目内容 给…

Word处理控件Spire.Doc系列教程:C# 为 Word 文档设置背景颜色或背景图片

在 Word 文档中,白色是默认的背景设置。一般情况下,简洁的白色背景足以满足绝大多数场景的使用需求。但是,如果您需要创建简历、宣传册或其他创意文档,设置独特的背景颜色或图片能够极大地增强文档的视觉冲击力。本文将演示如何使…