在 PyTorch 中实现可解释的神经网络模型

news2024/12/25 22:24:20

动动发财的小手,点个赞吧!

alt

目的

深度学习系统缺乏可解释性对建立人类信任构成了重大挑战。这些模型的复杂性使人类几乎不可能理解其决策背后的根本原因。

深度学习系统缺乏可解释性阻碍了人类的信任。

为了解决这个问题,研究人员一直在积极研究新的解决方案,从而产生了重大创新,例如基于概念的模型。这些模型不仅提高了模型的透明度,而且通过在训练过程中结合高级人类可解释的概念(如“颜色”或“形状”),培养了对系统决策的新信任感。因此,这些模型可以根据学习到的概念为其预测提供简单直观的解释,从而使人们能够检查其决策背后的原因。这还不是全部!它们甚至允许人类与学习到的概念进行交互,让我们能够控制最终的决定。

基于概念的模型允许人类检查深度学习预测背后的推理,并让我们重新控制最终决策。

这篇博文[1]中,我们将深入研究这些技术,并为您提供使用简单的 PyTorch 接口实现最先进的基于概念的模型的工具。通过实践经验,您将学习如何利用这些强大的模型来增强可解释性并最终校准人类对您的深度学习系统的信任。

概念瓶颈模型

在这个介绍中,我们将深入探讨概念瓶颈模型。这模型在 2020 年国际机器学习会议上发表的一篇论文中介绍,旨在首先学习和预测一组概念,例如“颜色”或“形状”,然后利用这些概念来解决下游分类任务:

alt

通过遵循这种方法,我们可以将预测追溯到提供解释的概念,例如“输入对象是一个{apple},因为它是{spherical}和{red}。”

概念瓶颈模型首先学习一组概念,例如“颜色”或“形状”,然后利用这些概念来解决下游分类任务。

实现

为了说明概念瓶颈模型,我们将重新审视著名的 XOR 问题,但有所不同。我们的输入将包含两个连续的特征。为了捕捉这些特征的本质,我们将使用概念编码器将它们映射为两个有意义的概念,表示为“A”和“B”。我们任务的目标是预测“A”和“B”的异或 (XOR)。通过这个例子,您将更好地理解概念瓶颈如何在实践中应用,并见证它们在解决具体问题方面的有效性。

我们可以从导入必要的库并加载这个简单的数据集开始:

import torch
import torch_explain as te
from torch_explain import datasets
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

x, c, y = datasets.xor(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

接下来,我们实例化一个概念编码器以将输入特征映射到概念空间,并实例化一个任务预测器以将概念映射到任务预测:

concept_encoder = torch.nn.Sequential(
    torch.nn.Linear(x.shape[1], 10),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(108),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(8, c.shape[1]),
    torch.nn.Sigmoid(),
)
task_predictor = torch.nn.Sequential(
    torch.nn.Linear(c.shape[1], 8),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(81),
)
model = torch.nn.Sequential(concept_encoder, task_predictor)

然后我们通过优化概念和任务的交叉熵损失来训练网络:

optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
loss_form_c = torch.nn.BCELoss()
loss_form_y = torch.nn.BCEWithLogitsLoss()
model.train()
for epoch in range(2001):
    optimizer.zero_grad()

    # generate concept and task predictions
    c_pred = concept_encoder(x_train)
    y_pred = task_predictor(c_pred)

    # update loss
    concept_loss = loss_form_c(c_pred, c_train)
    task_loss = loss_form_y(y_pred, y_train)
    loss = concept_loss + 0.2*task_loss

    loss.backward()
    optimizer.step()

训练模型后,我们评估其在测试集上的性能:

c_pred = concept_encoder(x_test)
y_pred = task_predictor(c_pred)

concept_accuracy = accuracy_score(c_test, c_pred > 0.5)
task_accuracy = accuracy_score(y_test, y_pred > 0)

现在,在几个 epoch 之后,我们可以观察到概念和任务在测试集上的准确性都非常好(~98% 的准确性)!

由于这种架构,我们可以通过根据输入概念查看任务预测器的响应来为模型预测提供解释,如下所示:

c_different = torch.FloatTensor([01])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

c_equal = torch.FloatTensor([11])
print(f"f({c_different}) = {int(task_predictor(c_different).item() > 0)}")

这会产生例如 f([0,1])=1 和 f([1,1])=0 ,如预期的那样。这使我们能够更多地了解模型的行为,并检查它对于任何相关概念集的行为是否符合预期,例如,对于互斥的输入概念 [0,1] 或 [1,0],它返回的预测y=1。

概念瓶颈模型通过将预测追溯到概念来提供直观的解释。

淹没在准确性与可解释性的权衡中

概念瓶颈模型的主要优势之一是它们能够通过揭示概念预测模式来为预测提供解释,从而使人们能够评估模型的推理是否符合他们的期望。

然而,标准概念瓶颈模型的主要问题是它们难以解决复杂问题!更一般地说,他们遇到了可解释人工智能中众所周知的一个众所周知的问题,称为准确性-可解释性权衡。实际上,我们希望模型不仅能实现高任务性能,还能提供高质量的解释。不幸的是,在许多情况下,当我们追求更高的准确性时,模型提供的解释往往会在质量和忠实度上下降,反之亦然。

在视觉上,这种权衡可以表示如下:

alt

可解释模型擅长提供高质量的解释,但难以解决具有挑战性的任务,而黑盒模型以提供脆弱和糟糕的解释为代价来实现高任务准确性。

为了在具体设置中说明这种权衡,让我们考虑一个概念瓶颈模型,该模型应用于要求稍高的基准,即“三角学”数据集:

x, c, y = datasets.trigonometry(500)
x_train, x_test, c_train, c_test, y_train, y_test = train_test_split(x, c, y, test_size=0.33, random_state=42)

在该数据集上训练相同的网络架构后,我们观察到任务准确性显着降低,仅达到 80% 左右。

概念瓶颈模型未能在任务准确性和解释质量之间取得平衡。

这就引出了一个问题:我们是永远被迫在准确性和解释质量之间做出选择,还是有办法取得更好的平衡?

Reference

[1]

Source: https://towardsdatascience.com/implement-interpretable-neural-models-in-pytorch-6a5932bdb078

本文由 mdnice 多平台发布

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/668767.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

chatgpt赋能python:Python中算法的几种描述方法

Python中算法的几种描述方法 在Python中,我们可以采用不同的方法来描述和实现不同的算法。本文将介绍三种常见的描述算法的方法,希望能够帮助读者更好地理解算法和Python编程。 方法一:自然语言描述 自然语言是我们最熟悉的方式来描述算法…

三层交换器与可配置的二层交换机通信配置(华为交换机)

#三层交换器与可配置的二层交换机通信配置 三层交换机配置 #进入系统视图 <Huawei>system-view #关闭系统提示信息 [Huawei]undo info-center enable #启动DHCP功能 [Huawei]dhcp enable #创建vlan 10 并配置 vlanif 地址 作为二层交换机默认网关 [Huawei]vlan 10 …

nodejs高版本降为低版本的详细解决方案

部分老旧项目需要使用低版本的node,网上很多是无效的,高版本无法直接安装低版本node,但是低版本nodejs可以安装部分高版本node,从而达到升级效果,下面这篇文章主要给大家介绍了关于nodejs高版本降为低版本的详细解决方案,需要的朋友可以参考下 1.首先通过控制面板应用卸载当前环…

如何使用 Swagger2 自动生成 RESTful API 文档

如何使用 Swagger2 自动生成 RESTful API 文档 在开发 RESTful API 的过程中&#xff0c;文档是非常重要的一部分。它可以帮助开发者了解 API 的功能和使用方法&#xff0c;同时也是接口设计和测试的重要依据。而手动编写 API 文档往往比较耗时且容易出错&#xff0c;这时候 S…

【kubernetes】部署controller-manager与kube-scheduler

前言:二进制部署kubernetes集群在企业应用中扮演着非常重要的角色。无论是集群升级,还是证书设置有效期都非常方便,也是从事云原生相关工作从入门到精通不得不迈过的坎。通过本系列文章,你将从虚拟机准备开始,到使用二进制方式从零到一搭建起安全稳定的高可用kubernetes集…

NodeJS 配置HTTPS协议证书⑩⑤

文章目录 ✨文章有误请指正&#xff0c;如果觉得对你有用&#xff0c;请点三连一波&#xff0c;蟹蟹支持&#x1f618;前言HTTPS ❓配置证书工具 CertbotCertbot 使用步骤总结 ✨文章有误请指正&#xff0c;如果觉得对你有用&#xff0c;请点三连一波&#xff0c;蟹蟹支持&…

chatgpt赋能python:Python描述符详解

Python 描述符详解 Python 中的描述符是一种机制&#xff0c;用于控制属性&#xff08;attribute&#xff09;的访问。通过描述符&#xff0c;我们可以在属性被访问或者修改时加入特定的逻辑。 描述符是一个类&#xff0c;其中至少定义了以下三个方法&#xff1a; __get__(s…

docker入门手册

目录 1. docker基础 1. 1 docker介绍 1.2 docker架构与核心组件 1.3 docker安装和卸载 安装 卸载 docker加速器设置 -> 可选 1.4 权限问题 1.5 docker服务相关操作命令 2. docker镜像管理 2.1 镜像的搜索/获取/查看 镜像搜索 2.2 镜像别名/删除 2.3 镜像的导入…

Apache Atlas产品调研

元数据产品调研 &#x1f4a1; 思考可以构成一座桥&#xff0c;让我们通向新知识。—— 普朗克 一、什么是元数据 元数据是关于数据的数据&#xff0c;是为了描述数据的相关信息而存在的数据。 元数据是用数据管理数据&#xff0c;是快速查找数据、精确定位数据、准确理解数据…

CSS3-三大特性-继承性、层叠性、优先级

CSS三大特性 1 继承性 2 层叠性 3 优先级 1 继承性 特性&#xff1a;子元素有默认继承父元素样式的特点&#xff08;子承父业&#xff09; 可以继承的常见属性&#xff1a; 1 color 2 font-style、font-weight、font-size、font-family 3 text-indent、text-align 4 line-heigh…

io.netty学习(七)字节缓冲区 ByteBuf(下)

目录 前言 实现原理 ByteBuf 的使用案例 ByteBuf 的3种使用模式 堆缓冲模式 直接缓冲区模式 复合缓冲区模式 总结 前言 在了解了 ByteBuffer 的原理之后&#xff0c;再来理解Netty 的 ByteBuf 就比较简单了。 ByteBuf 是 Netty 框架封装的数据缓冲区&#xff0c;区别…

第5章 函数式编程

第5章 函数式编程 5.1 函数式编程思想 在之前的学习中&#xff0c;我们一直学习的就是面向对象编程&#xff0c;所以解决问题都是按照面向对象的方式来处理的。比如用户登陆&#xff0c;但是接下来&#xff0c;我们会学习函数式编程&#xff0c;采用函数式编程的思路来解决问…

【Python 随练】阶乘累加计算

题目&#xff1a; 求 12!3!…20!的和。 简介&#xff1a; 在本篇博客中&#xff0c;我们将解决一个数学问题&#xff1a;计算阶乘序列的和。我们将介绍阶乘的概念&#xff0c;并提供一个完整的代码示例来计算给定范围内阶乘数的和。 问题分析&#xff1a; 我们需要计算从1…

基于“遥感+”融合技术在碳储量、碳收支、碳循环等多领域监测与模拟

卫星遥感具有客观、连续、稳定、大范围、重复观测的优点&#xff0c;已成为监测全球碳盘查不可或缺的技术手段&#xff0c;卫星遥感也正在成为新一代 、国际认可的全球碳核查方法。本内容目的就是梳理碳中和与碳达峰对卫星遥感的现实需求&#xff0c;系统总结遥感技术在生态系统…

开源开放 | 开源知识图谱抽取工具发布大模型版DeepKE-LLM

DeepKE-LLM链接&#xff1a; https://github.com/zjunlp/DeepKE/tree/main/example/llm OpenKG地址&#xff1a; http://openkg.cn/tool/deepke Gitee地址&#xff1a; https://gitee.com/openkg/deepke/tree/main/example/llm 开放许可协议&#xff1a;Apache-2.0 license 贡献…

计算机网络核心

1、OSI开放式互联参考模型 1、物理层&#xff1a;机械、电子、定时接口通信信道上的原始比特流传输。2、数据链路层&#xff1a;物理寻址&#xff0c;同时将原始比特流转变为逻辑传输线路。3、网络层&#xff1a;控制子网的运行&#xff0c;如逻辑编址、分组传输、路由选择(IP…

chatgpt赋能python:Python捕获Ctrl+C信号

Python 捕获 CtrlC 信号 在 Python 中&#xff0c;我们通过按下键盘上的 CtrlC 快捷键可以中断程序的运行&#xff0c;但是在某些情况下&#xff0c;我们希望程序在收到 CtrlC 信号后进行一些特殊的处理&#xff0c;而非直接退出或崩溃。这就需要捕获 CtrlC 信号&#xff0c;并…

前端Vue自定义简单实用中国省市区三级联动选择器

前端Vue自定义简单实用中国省市区三级联动选择器&#xff0c; 请访问uni-app插件市场地址&#xff1a;https://ext.dcloud.net.cn/plugin?id13118 效果图如下&#xff1a; #### 使用方法 使用方法 <!-- themeColor:主题颜色 ref:设置唯一ref pickerValueDefault:默认选择…

C++技能系列 ( 6 ) - 可调用对象、std::function与std::bind【详解】

系列文章目录 C技能系列 Linux通信架构系列 C高性能优化编程系列 深入理解软件架构设计系列 高级C并发线程编程 期待你的关注哦&#xff01;&#xff01;&#xff01; 现在的一切都是为将来的梦想编织翅膀&#xff0c;让梦想在现实中展翅高飞。 Now everything is for the…

单片机如何生成周期正弦波

一&#xff0c;简介 在某些场景需要使用单片机的IIS等外设播放正弦波音频数据。本文介绍一种“笨方法”来生成固定频率和固定幅度的正弦波定点型数据&#xff0c;记录总结学习使用。 二&#xff0c;步骤简介 总体步骤概述&#xff1a; 1&#xff0c;使用Audition生成制定波形…