模型解释与可解释AI实战

news2025/3/28 14:05:23

一、为什么需要模型解释?

模型解释技术帮助:

  1. 理解模型决策依据(特征重要性)
  2. 调试模型错误预测
  3. 满足监管合规要求(金融/医疗)
  4. 提升用户对AI的信任
    本章使用Captum实现CV/NLP模型的可视化解释

二、环境准备与工具安装

!pip install captum torchvision matplotlib
import torch
import numpy as np
from captum.attr import IntegratedGradients, LayerGradCam
import matplotlib.pyplot as plt

三、图像分类解释实战(CIFAR-10)

1. 加载预训练模型

from torchvision.models import resnet18

model = resnet18(pretrained=True)
model.eval()

2. 准备测试图像

from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 加载示例图像(类别:ship)
from PIL import Image
img = Image.open("test_ship.jpg").convert("RGB")
input_tensor = transform(img).unsqueeze(0)

3. 集成梯度解释

def visualize_attr(attr, title):
    attr = attr.squeeze().cpu().detach().numpy()
    plt.imshow(attr, cmap='hot')
    plt.colorbar()
    plt.title(title)
    plt.show()

# 计算特征重要性
integrated_grad = IntegratedGradients(model)
attr_ig = integrated_grad.attribute(input_tensor, target=8)  # ship类别ID为8
visualize_attr(attr_ig.mean(dim=1), "Integrated Gradients")

4. Grad-CAM可视化

# 选择目标卷积层
target_layer = model.layer4.conv2

# 计算Grad-CAM
layer_gradcam = LayerGradCam(model, target_layer)
attr_gc = layer_gradcam.attribute(input_tensor, target=8)

# 可视化叠加效果
heatmap = np.clip(attr_gc.squeeze().cpu().detach().numpy(), 0, None)
heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min())

orig_img = input_tensor.squeeze().permute(1,2,0).cpu().detach().numpy()
plt.imshow(orig_img * 0.5 + heatmap * 0.5)
plt.title("Grad-CAM Visualization")
plt.show()

四、文本分类解释实战(IMDB情感分析)

1. 加载情感分析模型

from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
model = AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")

2. 构建解释器

from captum.attr import LayerIntegratedGradients

# 定义输入处理函数
def model_forward(input_ids, attention_mask=None):
    return model(input_ids, attention_mask).logits

# 初始化解释器
lig = LayerIntegratedGradients(
    model_forward,
    model.bert.embeddings
)

3. 计算词元重要性

text = "This movie is a complete disaster, full of terrible acting and pointless scenes."
inputs = tokenizer(text, return_tensors="pt")

# 计算基准值(空输入)
ref_input_ids = torch.tensor([tokenizer.cls_token_id] + [tokenizer.pad_token_id]*(inputs.input_ids.shape-2) + [tokenizer.sep_token_id], 
                            device='cpu').unsqueeze(0)

# 计算归因值
attributions, delta = lig.attribute(
    inputs=inputs.input_ids,
    baselines=ref_input_ids,
    additional_forward_args=(inputs.attention_mask,),
    return_convergence_delta=True,
    target=0  # 负面情感对应的类别
)

# 可视化结果
token_attributions = attributions.sum(dim=2).squeeze(0)
tokens = tokenizer.convert_ids_to_tokens(inputs.input_ids)

plt.figure(figsize=(12, 3))
plt.bar(range(len(tokens)), token_attributions.detach().numpy())
plt.xticks(range(len(tokens)), tokens, rotation=90)
plt.title("Token Importance Scores")
plt.show()

五、高级解释技巧

1. 对比解释(对比不同类别)

# 对比飞机(类别0)与鸟类(类别2)的解释差异
attr_plane = integrated_grad.attribute(input_tensor, target=0)
attr_bird = integrated_grad.attribute(input_tensor, target=2)

plt.figure(figsize=(10,5))
plt.subplot(1,2,1)
plt.imshow(attr_plane.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Airplane Attribution')

plt.subplot(1,2,2)
plt.imshow(attr_bird.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Bird Attribution')
plt.show()

2. 层次相关性传播(LRP)

from captum.attr import LRP

lrp = LRP(model)
attr_lrp = lrp.attribute(input_tensor, target=8)

plt.imshow(attr_lrp.mean(1).squeeze().detach().cpu(), cmap='hot')
plt.title('Layer-wise Relevance Propagation')
plt.show()

六、常见问题解答

Q1:如何安装最新版Captum?

pip install git+https://github.com/pytorch/captum.git

Q2:归因结果全为0怎么办?

  • 检查输入是否经过正确的归一化
  • 尝试不同的基线值(Baseline)
  • 验证模型是否真的使用该特征
import matplotlib
matplotlib.use('Agg')  # 无GUI模式

plt.ioff()
fig = plt.figure()
# ...生成图像...
fig.savefig('explanation.png', bbox_inches='tight')
plt.close(fig)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2321304.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

制作PaddleOCR/PaddleHub的Docker镜像

背景 在落地RAG知识库过程中,遇到了图文识别、图片表格内容识别的需求。但那时(2024年4月)各开源RAG项目还没有集成成熟的解决方案,经调研我选择了百度开源的PaddleOCR。支持国产! 概念梳理 PaddleOCR 百度飞桨的OCR…

Ubuntu部署Docker搭建靶场

前言 我们需要部署Docker来搭建靶场题目,他可以提供一个隔离的环境,方便在不同的机器上部署,接下来,我会记录我的操作过程,简单的部署一道题目 Docker安装 不推荐在物理机上部署,可能会遇到一些问题&…

【DFS】羌笛何须怨杨柳,春风不度玉门关 - 4. 二叉树中的深搜

本篇博客给大家带来的是二叉树深度优先搜索的解法技巧,在后面的文章中题目会涉及到回溯和剪枝,遇到了一并讲清楚. 🐎文章专栏: DFS 🚀若有问题 评论区见 ❤ 欢迎大家点赞 评论 收藏 分享 如果你不知道分享给谁,那就分享给薯条. 你们的支持是我不断创作的…

搭建Redis主从集群

主从集群说明 单节点Redis的并发能力是有上限的,要进一步提高Redis的并发能力,就需要搭建主从集群,实现读写分离。 主从结构 这是一个简单的Redis主从集群结构 集群中有一个master节点、两个slave节点(现在叫replica)…

WSL2增加memory问题

我装的是Ubuntu24-04版本,所有的WSL2子系统默认memory为主存的一半(我的电脑是16GB,wsl是8GB),可以通过命令查看: free -h #查看ubuntu的memory和swap (改过的11GB) 前几天由于配置E…

git 合并多次提交 commit

在工作中,有时候在反复修改代码中(比如处理MR的检视意见,或者为了推送到测试环境,先 commit到自己的远程分支上)不免会有多次 commit,这样发起 MR 的时候,就会有一堆 commit 信息,看…

如何分析和解决服务器的僵尸进程问题

### 如何分析和解决服务器的僵尸进程问题 #### **一、僵尸进程的定义与影响** **僵尸进程(Zombie Process)** 是已终止但未被父进程回收资源的进程。其特点: - **状态标识**:在进程列表(如 ps 或 top)中标…

XXL-Job 二次分片是怎么做的?有什么问题?怎么去优化的?

XXL-JOB二次分片机制及优化策略 二次分片实现原理 XXL-JOB的二次分片是在分片广播策略的基础上,由开发者自行实现的更细粒度数据拆分。核心流程如下: 初次分片:调度中心根据执行器实例数量(总分片数n)分配分片索引i&…

java版嘎嘎快充玉阳软件互联互通中电联云快充协议充电桩铁塔协议汽车单车一体充电系统源码uniapp

演示: 微信小程序:嘎嘎快充 http://server.s34.cn:1888/ 系统管理员 admin/123456 运营管理员 yyadmin/Yyadmin2024 运营商 operator/operator2024 系统特色: 多商户、汽车单车一体、互联互通、移动管理端(开发中) 另…

Spatial Multiplexing Power Save

802.11n中添加的PSMP,SMPS机制。 SM 节能功能可让 STA 在大部分时间内仅通过一条活动接收链运行,从而达到节能目的。 空间复用省电(Spatial Multiplexing Power Save)模式下,节点会关闭多余的天线,仅仅使用一根天线进…

2025年渗透测试面试题总结-某360-企业蓝军面试复盘 (题目+回答)

网络安全领域各种资源,学习文档,以及工具分享、前沿信息分享、POC、EXP分享。不定期分享各种好玩的项目及好用的工具,欢迎关注。 目录 360-企业蓝军 一、Shiro绕WAF实战方案 二、WebLogic遭遇WAF拦截后的渗透路径 三、JBoss/WebLogic反序…

C语言基础—函数指针与指针函数

函数指针 定义 函数指针本质上是指针,它是函数的指针(定义了一个指针变量,变量中存储了函数的地址)。函数都有一个入口地址,所谓指向函数的指针,就是指向函数的入口地址。这里函数名就代表入口地址。 函…

用DrissionPage升级网易云音乐爬虫:更稳定高效地获取歌单音乐(附原码)

一、传统爬虫的痛点分析 原代码使用requests re的方案存在以下局限性: 动态内容缺失:无法获取JavaScript渲染后的页面内容 维护成本高:网页结构变化需频繁调整正则表达式 反爬易触发:简单请求头伪造容易被识别 资源消耗大&am…

OpenCV图像拼接(5)构建图像的拉普拉斯金字塔 (Laplacian Pyramid)

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 算法描述 cv::detail::createLaplacePyr 是 OpenCV 中的一个函数,用于构建图像的拉普拉斯金字塔 (Laplacian Pyramid)。拉普拉斯金字塔是一种多…

通俗一点介绍什么是场外期权交易 ?

场外期权是交易所以外的市场进行交易的期权,主要由期货公司、证券公司等金融机构根据客户具体要求进行设计,最终由期货公司等机构与客户签订协议的形式进行,通俗一点理解场外期权就是股票做多的玩法交易,下文为大家科普通俗一点介…

蓝桥杯备考:图的遍历

这道题乍一看好像没什么不对的&#xff0c;但是&#xff01;但是&#xff01;结点最大可以到10的5次方&#xff01;&#xff01;&#xff01;我们递归的时间复杂度是很高的&#xff0c;我们正常遍历是肯定通过不了的&#xff0c;不信的话我们试一下 #include <iostream>…

IIS漏洞攻略

一&#xff0c;PUT漏洞 1&#xff0c;在windows server 2003 中开启 WebDAV 和写权限&#xff0c;然后访问并使用BP抓包 2&#xff0c;使用PUT上传一个木马文件&#xff0c;后缀要改成其他格式 3&#xff0c;将上传的木马文件的内容写入到asp文件中&#xff0c;然后进行连接即…

C++《红黑树》

在之前的篇章当中我们已经了解了基于二叉搜索树的AVL树&#xff0c;那么接下来在本篇当中将继续来学习另一种基于二叉搜索树的树状结构——红黑树&#xff0c;在此和之前学习AVL树类似还是通过先了解红黑树是什么以及红黑树的结构特点&#xff0c;接下来在试着实现红黑树的结构…

struts2框架漏洞攻略

S2-057远程执⾏代码漏洞 环境 vulhub靶场 /struts2/s2-057 漏洞简介 漏洞产⽣于⽹站配置XML时如果没有设置namespace的值&#xff0c;并且上层动作配置中并没有设置 或使⽤通配符namespace时&#xff0c;可能会导致远程代码执⾏漏洞的发⽣。同样也可能因为url标签没有设置…

8662 234的和

8662 234的和 ⭐️难度&#xff1a;中等 &#x1f31f;考点&#xff1a;模拟、二维前缀和 &#x1f4d6; &#x1f4da; import java.util.Arrays; import java.util.LinkedList; import java.util.Queue; import java.util.Scanner;public class Main {static int[] a ne…