【可解释性机器学习】基于ELI5使用解读LIME算法以及实战案例

news2024/12/25 13:06:07

LIME算法解读与实战案例

  • LIME论文简介
  • LIME算法原理
  • LIME算法要点
  • LIME的注意事项
  • LIME的代码实现
    • 对Pytorch搭建的模型进行解释
      • 使用LIME解释Pytorch构建的模型
  • 参考资料

LIME论文简介

LIME的全称为Local Interpretable Model-agnostic Explanations.
LIME
尽管被广泛采用,机器学习模型仍然主要是黑匣子。 然而,了解预测背后的原因对于评估信任非常重要,如果一个人计划根据预测采取行动,或者在选择是否部署新模型时,信任是基础。 这种理解还提供了对模型的洞察力,可用于将不可信的模型或预测转换为可信的模型。

在这项工作中,作者提出了 LIME,这是一种新颖的解释技术,通过在预测局部学习可解释模型,以可解释和忠实的方式解释任何分类器的预测。 作者还提出了一种通过以非冗余方式呈现具有代表性的个体预测及其解释来解释模型的方法,将任务构建为子模块优化问题。

通过解释文本(例如随机森林)和图像分类(例如神经网络)的不同模型来展示这些方法的灵活性。作者通过新的实验(包括模拟实验和人类受试者)在需要信任的各种场景中展示解释的效用:决定一个人是否应该相信一个预测,在模型之间进行选择,改进一个不可信的分类器,以及确定为什么一个分类器不应该被信任。
Explaining individual predictions

LIME算法原理

LIME的想法很简单,希望使用简单的模型来对复杂的模型进行解释。这里简单的模型可以是线性模型,因为我们可以通过查看线性模型的系数大小来对模型进行解释。 在这里, LIME只会对每一个样本进行解释(explain individual predictions).

LIME会产生一个新的数据集(这个数据集我们是通过对某一个样本数据进行变换得到),接着在这个新的数据集上, 我们训练一个简单模型(容易解释的模型), 我们希望简单模型在新数据集上的预测结果和复杂模型在该数据集上的预测结果是相似的。可以将我们的问题表达为下面的表达式:
explanation ( x ) = arg min ⁡ g ∈ G L ( f , g , π x ) + Ω ( g ) \text{explanation}(x)=\text{arg}\min_{g\in G}L(f,g,\pi_x)+\Omega (g) explanation(x)=arggGminL(f,g,πx)+Ω(g)
其中:

  • f f f表示原始的模型,即需要解释的模型
  • g g g表示简单的模型, G G G是简单模型的一个集合,例如所有可能的线性模型。
  • π x \pi_x πx表示新数据集中的数据 x ′ x' x与原始数据 x x x的距离
  • Ω ( g ) \Omega(g) Ω(g)表示模型 g g g的复杂程度。

希望原始模型 f f f与新模型 g g g预测值之间的误差是小的。简单来说,可以通过下面的式子来衡量两个式子预测值之间的差:
L ( f , g , w y ) = ∑ i = 1 N w y ( z i ) ( f ( z i ) − g ( z i ′ ) ) 2 \mathcal{L}(f, g, w^y)=\sum_{i=1}^N w^y(z_i)(f(z_i)-g(z'_i))^2 L(f,g,wy)=i=1Nwy(zi)(f(zi)g(zi))2
于是整个LIME的步骤如下(即训练模型 g g g的步骤):
①选择想要解释的变量x;
②对数据集中的数据进行扰动得到新的数据,同时计算出黑盒模型对这些新的数据的预测值;
③对这些新的sample求出权重,这个权重是这些数据点与我们要解释的数据之间的距离;
④根据上面新的数据集,预测值和权重,训练出模型 g g g
⑤通过对模型 g g g来对模型 f f f在x点附近进行解释。

那么我们如何对数据集进行扰动来得到新的数据, 对于表格数据, 我们可以分别扰动每一个特征, 从一个正态分布(均值和方差为这个特征的均值和方差)中进行随机抽样. 这样做会有一个问题, 即不是从我们要解释的数据为中心进行采样, 而是从整个数据集的中心进行采样. (LIME samples are not taken around the instance of interest, but from the training data’s mass center, which is problematic.)

通过一张图片来对上面的过程进行解释(这张图片是上面第一个参考链接中的).
对数据集进行扰动
图A表示: 随机森林的分类结果, 颜色深的为一类, 颜色浅的为一类; 图B表示: 通过对特征进行扰动得到的新的数据集;图C表示: 对每一个数据进行加权;图D表示: 对简单的模型进行求解。

存在的问题: 定义我们要解释的点的周围是困难的.

LIME算法要点

LIME (Ribeiro et al。2016)是一种解释黑盒估计量预测的算法:

  1. 根据我们将要解释的例子生成一个假的数据集。
  2. 使用黑盒估计器为生成的数据集中的每个示例获取目标值(例如,类概率)。
  3. 训练一个新的白盒估计器,使用生成的数据集和生成的标签作为训练数据。这意味着我们正在尝试创建一个估计器,它的工作原理与黑盒估计器相同,但是更容易检查。它不必在全局范围内工作得很好,但是它必须在接近原始示例的区域内很好地近似黑盒模型。
    要表示“接近原始示例的区域”,用户必须为生成的数据集中的示例提供距离/相似度度量。然后根据训练数据与原始样本之间的距离进行加权——样本越远,训练数据对白盒估计器权值的影响越小
  4. 通过这个白盒估计器的权重来解释原来的例子。
  5. 白盒分类器的预测质量显示了它对黑盒分类器的近似程度。如果质量低,那么解释就不可信。

LIME的注意事项

  1. 如果白盒估计器在生成的数据集上获得高分,并不一定意味着它可以被信任——这也可能意味着生成的数据集过于简单和统一,或者用户提供的相似性度量为大多数示例分配了非常低的值,因此“接近原始示例的区域”太小而不有趣。
  2. 假数据集生成是主要问题; 它在很大程度上是特定于任务的。 所以 LIME 可以与任何黑盒分类器一起工作,但用户可能需要为每个数据集编写特定的代码。 检查模型权重有一个相反的权衡:它适用于任何任务,但必须为每种估计器类型编写检查代码
    eli5.lime为文本数据(删除随机词)和任意数据(使用核密度估计采样)提供数据集生成实用程序。
    对于文本数据,eli5 还提供了 eli5.lime.TextExplainer,它汇集了所有 LIME 步骤并允许解释文本分类器; 它仍然需要对分类器做出假设才能生成有效的假数据集
  3. 相似性度量对结果有巨大影响。 通过选择不同大小的邻域,可以得到相反的解释。

LIME的代码实现

有一个由 LIME 作者实现的 LIME:https://github.com/marcotcr/lime,所以 eli5.lime 应该被视为替代品。 在撰写本文时,eli5.lime 与规范的 LIME 实现有一些差异:

  1. eli5 支持来自多个库的许多白盒分类器,可以将它们中的任何一个与 LIME 一起使用;
  2. eli5 支持使用核密度估计生成数据集,以确保生成的数据集看起来与原始数据集相似;
  3. 为了解释概率分类器的预测,eli5 默认使用另一个分类器,使用交叉熵损失进行训练,而标准库在概率输出上拟合回归模型。

对Pytorch搭建的模型进行解释

首先是描述Pytorch的完整训练的过程。这里使用Iris dataset作为数据集来搭建一个多分类的网络.

import pandas as pd
import numpy as np
from sklearn import datasets

from sklearn.preprocessing import StandardScaler
%matplotlib inline

导入Pytorch相关库:

import torch
import torch.nn as nn
import torch.utils.data as Data

数据导入:

iris_datas = datasets.load_iris()
df = pd.concat([pd.DataFrame(iris_datas.data), pd.DataFrame(iris_datas.target)], axis=1)
df.columns = iris_datas.feature_names+['target']
df = df.sample(frac=1) # 打乱顺序
df.head()

数据导入
数据处理,包括数据集划分、数据标准化,并将数据转换为tensor类型:
数据集划分
构建dataset和dataloader:

batch_size = 10
train_dataset = Data.TensorDataset(X_train, Y_train)
test_dataset = Data.TensorDataset(X_test, Y_test)

# Data Loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

构建网络模型,这里为简单起见,使用全连接网络:

# 网络的定义
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, n_layers):
        super(NeuralNet, self).__init__()
        layers = []
        for i in range(n_layers):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.Dropout(0.3))
        self.inLayer = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.hiddenLayer = nn.Sequential(*layers)
        self.outLayer = nn.Linear(hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        out = self.inLayer(x)
        out = self.relu(out)
        out = self.hiddenLayer(out)
        out = self.outLayer(out)
        out = self.softmax(out)
        return out

接着对上面的网络进行初始化:
模型初始化
网络训练,包括定义损失函数和优化函数:

# 网络的训练
num_epochs = 30
learning_rate = 0.001

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

model.train()
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, (data, labels) in enumerate(train_loader):
        outputs = model(data)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i + 1) %5 ==0:
            correct = 0
            total = 0
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += sum(predicted == labels).item()
            acc = 100 * correct/total
            print ('Epoch [{}/{}], Step [{}/{}], Accuracy: {}, Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, total_step, acc, loss.item()))
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data, labels in test_loader:
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += sum(predicted==labels).item()
        print('Accuracy of the network test dataset: {} %'.format(100 * correct / total))
        print('-'*10)
'''
----------
Epoch [19/30], Step [5/12], Accuracy: 100.0, Loss: 0.6313
Epoch [19/30], Step [10/12], Accuracy: 100.0, Loss: 0.6491
Accuracy of the network test dataset: 93.33333333333333 %
----------
Epoch [20/30], Step [5/12], Accuracy: 100.0, Loss: 0.5853
Epoch [20/30], Step [10/12], Accuracy: 90.0, Loss: 0.6876
Accuracy of the network test dataset: 90.0 %
'''

使用LIME解释Pytorch构建的模型

使用LIME来解释Pytorch的模型主要有下面的几个步骤:

  • 定义预测函数
  • 创建解释器
  • 对某一个样本给出解释

首先定义预测函数:

# 定义预测函数
def batch_predict(data, model=model):
    """
    :param data: 需要预测的数据
    :param model: Pytorch训练的模型,**这里需要有默认的模型**
    :return:
    """
    X_tensor = torch.from_numpy(data).float()
    model.eval()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    X_tensor = X_tensor.to(device)
    logits = model(X_tensor)
    probs = torch.nn.functional.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()

简单测试一下,输入数据, 出来的是每一类的概率.:
测试预测函数
创建解释器,用来对后面的样本进行解释:

from lime.lime_tabular import LimeTabularExplainer
# 创建解释器
targets = iris_datas.target_names
features_names = iris_datas.feature_names
explainer = LimeTabularExplainer(X, feature_names=features_names, class_names=targets, discretize_continuous=True)

解释某一个样本, 这里是对第5个样本进行解释.:

# 解释某个样本
exp = explainer.explain_instance(X[5], batch_predict, num_features=5, top_labels=5)
# 结果可视化
# exp.show_in_notebook(show_table=True, show_all=False) # 代码无效
exp.save_to_file('../Results/exp.html') # 保存为HTML文件,用浏览器打开即可

结果可视化。可视化的内容会包括是某一类的原因(或是不是某一类的原因), 比如对于Iris dataset来说, 会分别给出三张图. 如下所示.
结果可视化
如上图所示, 对于virginica来说, 模型给出了是这个分类的原因, 例如因为petal width>-1.23, 这就是给出了一个模型判断的原因。除了上面的画图方式, 我们还可以使用下面的画图方式, 只画出某一个类别的判断的可能性:
某一个类别的重要性

参考资料

[1] “Why Should I Trust You?”: Explaining the Predictions of Any Classifier
[2] 模型解释-LIME的原理和实现
[3] Pytorch例子演示及LIME使用例子

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/161042.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

模板进阶篇

一、非类型模板参数 模板参数分类类型形参与非类型形参。 类型形参:出现在模板参数列表中,跟在class或者typename之类的参数类型名称。如图: 非类型形参:就是用一个常量作为类(函数)模板的一个参数,在类(函数)模板中可…

Mybatis 原理之启动阶段

文章目录1.MyBatis 核心流程2.启动准备阶段流程3.创建 SQlSessionFactory4.创建XMLConfigBuilder5.创建 XPathParser6.解析并设置 configuration 中的属性7.解析Mappers标签1.MyBatis 核心流程 Mybatis的核心流程氛围两个阶段,启动准备阶段和执行SQL阶段。 加载配…

Day858.高性能网络应用框架Netty -Java 并发编程实战

高性能网络应用框架Netty Hi,我是阿昌,今天学习记录的是关于高性能网络应用框架Netty的内容。 Netty 是一个高性能网络应用框架,应用非常普遍,目前在 Java 领域里,Netty 基本上成为网络程序的标配了。 Netty 框架功…

win10录屏软件哪款比较好用?一款不限时长的录屏软件

现在大部分人的电脑都是win10系统的电脑,也有许多小伙伴会经常会问:“win10电脑怎么录屏?”录制电脑屏幕,需要使用到录屏软件,那win10录屏软件哪款比较好用?小编今天给大家分享一款试用版即可不限录制时长的…

【区间合并】洛谷 P1496 火烧赤壁

P1496 火烧赤壁 文章目录题目背景题目描述输入格式:输出格式:数据范围输入样例输出样例方法:区间合并解题思路代码复杂度分析:题目背景 曹操平定北方以后,公元 208 年,率领大军南下,进攻刘表。…

部分时变离散系统中的稳定性判据

部分时变离散系统中的稳定性判据 1.Lyapunov稳定性理论 下面先给出Lyapunov稳定性的一些基本理论(网上资源较多这里不再过多赘述): 2.一类时变离散系统的稳定性 定理 ​ 对于离散时变系统x(k1)A(k)x(k)x(k1)A(k)x(k)x(k1)A(k)x(k)&#x…

Java EE|多线程代码实例之单例模式与阻塞队列

文章目录前言设计模式介绍🔴单例模式什么是单例模式单例模式实现方式饿汉模式懒汉模式基于上述单例模式实现线程安全问题讨论重点回顾🔴阻塞队列阻塞队列是什么标准库中的阻塞队列典型应用场景:生产者消费者模型利用系统提供的BlockingQueue实…

osg fbo(三),将颜色缓冲区图片通过shader变绿

这个其实很简单, 一,写顶点着色器和片元着色器 static const char * vertexShader { “void main(void)\n” “{\n” " gl_Position ftransform();\n" “}\n” }; static const char *psShader { “uniform float alpha;” “void main(vo…

12、ThingsBoard-如何配置发送邮件

1、概述 ThingsBoard提供了系统层设置邮件配置和租户层通过设置邮件规则节点,对规则引擎产生的告警进行分发这两种邮件配置,其中系统层设置邮件配置主要是针对用于向用户分发激活和密码重置电子邮件;租户层通过设置邮件规则节点是针对告警通知的;一定要区别开这两个邮件配…

SpringBoot整合SpringSecurity实现进行认证和授权。

目录 2.在子工程通过easyCode创建项目相关包和文件 3.子项目新建Controllter层,并建立BlogLoginController.java 4.在servic 层定义login 方法,并new UsernamePasswordAuthenticationToken对象,传入对应用户名,密码 5.自定义实…

Java集合(进阶)

Java集合Collection集合体系结构CollectionCollection系列集合三种遍历方式List泛型泛型类泛型方法泛型接口泛型的继承和通配符SetHashSetTreeSet总结:Map(双列集合)HashMapLinkedHashMapTreeMap可变参数集合工具类Collections集合嵌套案例不…

打破应用孤岛,iPaaS连接全域新协作

“据全球知名的咨询平台Garner分析,集成平台将在企业数字化转型过程中扮演重要的角色,企业内外应用的打通成为推动企业快速实现数字化转型的重要因素之一。SaaS 的井喷式发展也带来了新的机遇与挑战,企业亟需新的集成方法和手段帮助解决自身问…

吴恩达【神经网络和深度学习】Week4——深层神经网络

文章目录Deep Neural Network1、Deep L-layer Neural Network2、Forward Propagation in a Deep Network3、Getting your matrix dimensions right4、Why deep representations?5、 Building blocks of deep neural networks6、 Forward and Backward Propagation7、Parameter…

【Ctfer训练计划】——(十一)

作者名:Demo不是emo主页面链接: 主页传送门创作初心: 舞台再大,你不上台,永远是观众,没人会关心你努不努力,摔的痛不痛,他们只会看你最后站在什么位置,然后羡慕或鄙夷座右…

最新版wifi营销分销流量主前后端+小程序源码+搭建教程

前端后端数据库搭建教程,无任何密码,亲测能用,避免踩坑,v:JZ716888 教程如下: 安装源码到根目录 1、网站运行目录public 2、PHP7.2,开通SSL 3、导入数据库文件 4、修改数据库文件里applic…

【十一】Netty UDP协议栈开发

Netty UDP协议栈开发介绍协议简介伪首部UDP协议的特点开发jar依赖UDP 服务端启动类服务端业务处理类客户端启动类客户端业务处理类代码说明测试服务端打印截图:客户端打印截图:测试结果总结介绍 UDP 是用户数据报协议(User Datagram Protocol) 的简称,其…

【Azure 架构师学习笔记】-Azure Logic Apps(4)-演示2

本文属于【Azure 架构师学习笔记】系列。 本文属于【Azure Logic Apps】系列。 接上文[【Azure 架构师学习笔记】-Azure Logic Apps(3)-演示1] (https://blog.csdn.net/DBA_Huangzj/article/details/128542539) 前言 上文做了简单的演示,这一…

【Flutter】关于Button 的那些知识ElevatedButton等,以及Buttonstyle

文章目录前言一、Button是什么?二、开始使用button1.ElevatedButton1.无style 的ElevatedButton2.基础功能的处理之后的button3.利用buttonstyle 来美化下button2.IconButton,TextButton基础功能都是一样的三、做几个好看点的按键总结前言 一、Button是什…

【设计模式】七大设计原则

设计模式学习之旅(二) 查看更多可关注后查看主页设计模式DayToDay专栏 在软件开发中,为了提高软件系统的可维护性和可复用性,增加软件的可扩展性和灵活性,程序员要尽量根据7条原则来开发程序,从而提高软件开发效率、节约软件开发成…

SAP 详细解析在建工程转固定资产

由固定资产归口采购部门或业务部门提交购置固定资产/在建工程的申请,经审批后,若是需要安装调试,则由财务部固定资产会计建立内部订单收集成本,月末结转在建工程。项目完工后,相关部门(公司装备部、分公司装…