【学习打卡05】可解释机器学习笔记之CAM+Captum代码实战

news2025/1/10 12:16:21

可解释机器学习笔记之CAM+Captum代码实战

文章目录

  • 可解释机器学习笔记之CAM+Captum代码实战
    • 代码实战介绍
    • torch-cam工具包
      • 可视化CAM类激活热力图
      • 预训练ImageNet-1000图像分类-单张图像
      • 视频以及摄像头预测
    • pytorch-grad-cam工具包
      • Grad-CAM热力图可解释性分析
      • 基于Guided Grad-CAM的高分辨率细粒度可解释性分析
    • Captum的工具包
      • 遮挡可解释性分析
        • 中等遮挡滑块
        • 大遮挡滑块
        • 小遮挡滑块
      • Integrated Gradients可解释性分析
    • 总结
    • 参考阅读

首先非常感谢同济子豪兄拍摄的可解释机器学习公开课,并且免费分享,这门课程,包含人工智能可解释性、显著性分析领域的导论、算法综述、经典论文精读、代码实战、前沿讲座。由B站知名人工智能科普UP主“同济子豪兄”主讲。 课程主页: https://github.com/TommyZihao/zihao_course/blob/main/XAI 一起打开AI的黑盒子,洞悉AI的脑回路和注意力,解释它、了解它、改进它,进而信赖它。知其然,也知其所以然。这里给出链接,倡导大家一起学习, 别忘了给子豪兄点个关注哦。

学习GitHub 内容链接:
https://github.com/TommyZihao/zihao_course/tree/main/XAI

B站视频合集链接:
https://space.bilibili.com/1900783/channel/collectiondetail?sid=713364

代码实战介绍

在前面经过4个知识的学习之后,已经对可解释机器学习有了一定的了解,但是这些有什么用呢,最重要的当然是代码实战,所以这一部分学习的就是CAM和Captum的一些可视化的代码实战,能将理论和代码结合起来,方便我们理解和学习。

所有的代码都已经分享,都在子豪兄的Github中,这是代码的Github:https://github.com/TommyZihao/Train_Custom_Dataset,可以用pytorch训练自己的图像分类模型,基于torch-cam实现各个类别、单张图像、视频文件、摄像头实时画面的CAM可视化

torch-cam工具包

这里介绍一些主要的一些可视化的用法,具体的操作和方法,在视频和代码中都有体现

可视化CAM类激活热力图

预训练ImageNet-1000图像分类-单张图像

首先可以可视化出我们的类激活图

activation_map = cam_extractor(pred_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()
plt.imshow(activation_map)
plt.show()

png

后续根据类激活图,和原有的图片进行叠加,就能得到最后的图片

from torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)
result

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0OIPlknr-1671639191799)(https://relph1119.github.io/my-team-learning/Interpretable_machine_learning44/task05/output_20_0.png)]

除此之外,我们还能固定可视化的类别,这样就可以展示出来我们想要的类别了。

img_path = './test_img/cat_dog.jpg'
# 可视化热力图的类别ID,如果为 None,则为置信度最高的预测类别ID

# 边牧犬
show_class_id = 231

# 是否显示中文类别
Chinese = True
def get_cam(img_pil, test_transform, model, cam_extractor, 
            show_class_id, pred_id, device):
    # 前向预测
    input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_tensor)
    pred_top1 = torch.topk(pred_logits, 1)
    pred_id = pred_top1[1].detach().cpu().numpy().squeeze().item()
    
    # 可视化热力图的类别ID,如果不指定,则为置信度最高的预测类别ID
    if show_class_id:
        show_id = show_class_id
    else:
        show_id = pred_id
        show_class_id = pred_id
    
    # 生成可解释性分析热力图
    activation_map = cam_extractor(show_id, pred_logits)
    activation_map = activation_map[0][0].detach().cpu().numpy()
    result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)
    return result, pred_id, show_class_idCopy to clipboardErrorCopied
img_pil = Image.open(img_path)
result, pred_id, show_class_id = get_cam(img_pil, test_transform, model, cam_extractor, 
                                show_class_id, pred_id, device)
def print_image_label(result, pred_id, show_class_id, 
                      idx_to_labels, idx_to_labels_cn=None, Chinese=False):
    # 在图像上写字
    draw = ImageDraw.Draw(result)

    if Chinese:
        # 在图像上写中文
        text_pred = 'Pred Class: {}'.format(idx_to_labels_cn[pred_id])
        text_show = 'Show Class: {}'.format(idx_to_labels_cn[show_class_id])
    else:
        # 在图像上写英文
        text_pred = 'Pred Class: {}'.format(idx_to_labels[pred_id])
        text_show = 'Show Class: {}'.format(idx_to_labels[show_class_id])
    # 文字坐标,中文字符串,字体,rgba颜色
    draw.text((50, 100), text_pred, font=font, fill=(255, 0, 0, 1))
    draw.text((50, 200), text_show, font=font, fill=(255, 0, 0, 1))
    
    return result
result = print_image_label(result, pred_id, show_class_id,
                           idx_to_labels, idx_to_labels_cn, Chinese)
result

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-8ZGXtE9O-1671639191800)(https://relph1119.github.io/my-team-learning/Interpretable_machine_learning44/task05/output_27_0.png)]

视频以及摄像头预测

除此之外,我们还可以检测视频或者是摄像头,实际上就是一帧一帧的图片而已,原理是一样的,具体可以去看代码,这里就不多介绍了

pytorch-grad-cam工具包

Grad-CAM热力图可解释性分析

from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
img_path = './test_img/cat_dog.jpg'
from torchvision import transforms

# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(512),
                                     # transforms.CenterCrop(512),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device)
# Grad-CAM
from pytorch_grad_cam import GradCAM
# 指定要分析的层
target_layers = [model.layer4[-1]]
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
# 如果 targets 为 None,则默认为最高置信度类别
targets = [ClassifierOutputTarget(232)]
cam_map = cam(input_tensor=input_tensor, targets=targets)[0] # 不加平滑
plt.imshow(cam_map)
plt.title('Grad-CAM')
plt.show()

png

import torchcam
from torchcam.utils import overlay_mask

result = overlay_mask(img_pil, Image.fromarray(cam_map), alpha=0.7)
result

在这里插入图片描述

基于Guided Grad-CAM的高分辨率细粒度可解释性分析

Guided Backpropagation算法

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, deprocess_image, preprocess_image
# 初始化算法
gb_model = GuidedBackpropReLUModel(model=model, use_cuda=True)
# 生成 Guided Backpropagation热力图
gb_origin = gb_model(input_tensor, target_category=None)
gb_show = deprocess_image(gb_origin)
plt.imshow(gb_show)
plt.title('Guided Backpropagation')
plt.show()

png

将Grad-CAM热力图与Gudied Backpropagation热力图逐元素相乘

# Grad-CAM三通道热力图
cam_mask = cv2.merge([cam_map, cam_map, cam_map])
# 逐元素相乘
guided_gradcam = deprocess_image(cam_mask * gb_origin)
plt.imshow(guided_gradcam)
plt.title('Guided Grad-CAM')
plt.show()

png

Captum的工具包

遮挡可解释性分析

这里介绍一部分Captum的方法,也就是遮挡可解释性分析-ImageNet图像分类

在输入图像上,用遮挡滑块,滑动遮挡不同区域,探索哪些区域被遮挡后会显著影响模型的分类决策。

提示:因为每次遮挡都需要分别单独预测,因此代码运行可能需要较长时间。

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model = model.eval().to(device)
occlusion = Occlusion(model)

中等遮挡滑块

# 获得输入图像每个像素的 occ 值
attributions_occ = occlusion.attribute(input_tensor,
                                       strides = (3, 8, 8), # 遮挡滑动移动步长
                                       target=pred_id, # 目标类别
                                       sliding_window_shapes=(3, 15, 15), # 遮挡滑块尺寸
                                       baselines=0) # 被遮挡滑块覆盖的像素值

# 转为 224 x 224 x 3的数据维度
attributions_occ_norm = np.transpose(attributions_occ.detach().cpu().squeeze().numpy(), (1,2,0))
viz.visualize_image_attr_multiple(attributions_occ_norm, # 224 224 3
                                  rc_img_norm,           # 224 224 3
                                  ["original_image", "heat_map"],
                                  ["all", "positive"],
                                  show_colorbar=True,
                                  outlier_perc=2)
plt.show()


png

大遮挡滑块

# 更改遮挡滑块的尺寸
attributions_occ = occlusion.attribute(input_tensor,
                                       strides = (3, 50, 50), # 遮挡滑动移动步长
                                       target=pred_id, # 目标类别
                                       sliding_window_shapes=(3, 60, 60), # 遮挡滑块尺寸
                                       baselines=0)

# 转为 224 x 224 x 3的数据维度
attributions_occ_norm = np.transpose(attributions_occ.detach().cpu().squeeze().numpy(), (1,2,0))

viz.visualize_image_attr_multiple(attributions_occ_norm, # 224 224 3
                                  rc_img_norm,           # 224 224 3
                                  ["original_image", "heat_map"],
                                  ["all", "positive"],
                                  show_colorbar=True,
                                  outlier_perc=2)
plt.show()

png

小遮挡滑块

# 更改遮挡滑块的尺寸
attributions_occ = occlusion.attribute(input_tensor,
                                       strides = (3, 2, 2), # 遮挡滑动移动步长
                                       target=pred_id, # 目标类别
                                       sliding_window_shapes=(3, 4, 4), # 遮挡滑块尺寸
                                       baselines=0)

# 转为 224 x 224 x 3的数据维度
attributions_occ_norm = np.transpose(attributions_occ.detach().cpu().squeeze().numpy(), (1,2,0))

viz.visualize_image_attr_multiple(attributions_occ_norm, # 224 224 3
                                  rc_img_norm,           # 224 224 3
                                  ["original_image", "heat_map"],
                                  ["all", "positive"],
                                  show_colorbar=True,
                                  outlier_perc=2)
plt.show()

png

Integrated Gradients可解释性分析

Integrated Gradients原理:输入图像像素由空白变为输入图像像素的过程中,模型预测为某一特定类别的概率相对于输入图像像素的梯度积分。

from captum.attr import IntegratedGradients
from captum.attr import NoiseTunnel
# 初始化可解释性分析方法
integrated_gradients = IntegratedGradients(model)
# 获得输入图像每个像素的 IG 值
attributions_ig = integrated_gradients.attribute(input_tensor, target=pred_id, n_steps=50)
# 转为 224 x 224 x 3的数据维度
attributions_ig_norm = np.transpose(attributions_ig.detach().cpu().squeeze().numpy(), (1,2,0))
from matplotlib.colors import LinearSegmentedColormap

# 设置配色方案
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)

# 可视化 IG 值
viz.visualize_image_attr(attributions_ig_norm, # 224,224,3
                         rc_img_norm,          # 224,224,3
                         method='heat_map',
                         cmap=default_cmap,
                         show_colorbar=True,
                         sign='positive',
                         outlier_perc=1)
plt.show()

png

加入高斯噪声的多张图像,平滑输出

在输入图像中加入高斯噪声,构造nt_samples个噪声样本,分别计算IG值,再使用smoothgrad_sq(先平均再平方)平滑。

noise_tunnel = NoiseTunnel(integrated_gradients)

# 获得输入图像每个像素的 IG 值
attributions_ig_nt = noise_tunnel.attribute(input_tensor, nt_samples=3, nt_type='smoothgrad_sq', target=pred_id)

# 转为 224 x 224 x 3的数据维度
attributions_ig_nt_norm = np.transpose(attributions_ig_nt.squeeze().cpu().detach().numpy(), (1,2,0))
# 设置配色方案
default_cmap = LinearSegmentedColormap.from_list('custom blue', 
                                                 [(0, '#ffffff'),
                                                  (0.25, '#000000'),
                                                  (1, '#000000')], N=256)

viz.visualize_image_attr_multiple(attributions_ig_nt_norm, # 224 224 3
                                  rc_img_norm, # 224 224 3
                                  ["original_image", "heat_map"],
                                  ["all", "positive"],
                                  cmap=default_cmap,
                                  show_colorbar=True)
plt.show()

png

总结

在这次任务中,主要学习到了CAM和Captum工具包的使用,在图像分类的基础上去解释他,知其然还要知其所以然。使用CAM和Captum工具包,可以减少我们很多很多的代码量,并且能快速使用,快速应用在自己的任务中、

在经过一个多星期的学习,也是需要这种代码实战告诉我们,这些应用是全面且方方面面的,这样就不会空读理论,这样可以让我们有机会将理论和实践结合起来,希望后续能够将XAI和CAM运用到我的领域中,学习到更多的知识。

参考阅读

  • 可以根据按照代码教程:https://github.com/TommyZihao/Train_Custom_Dataset,用pytorch训练自己的图像分类模型,基于torch-cam实现各个类别、单张图像、视频文件、摄像头实时画面的CAM可视化

  • Grad-CAM官方代码:https://github.com/ramprs/grad-cam

  • torch-cam代码库:https://github.com/frgfm/torch-cam

  • pytorch-grad-cam代码库:https://github.com/jacobgil/pytorch-grad-cam

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/108106.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySql 根据中文拼音首字母排序、 分组排序

如地域表信息: 如果我们想根据NAME 字段 的值, 按照中文拼音首字母排序 : sql SELECT CODE, NAME FROM district_info ORDER BY CONVERT(name USING gbk) COLLATE gbk_chinese_ci ASC 效果很OK: 那么如果我要整成想电话簿那样&am…

【正版软件】Navicat for Oracle 数据库数管理和开发工具

前言 Navicat for Oracle 透过精简的工作环境,提高 Oracle 开发人员和管理员的效率和效率。 Navicat for Oracle 透过精简的工作环境,提高 Oracle 开发人员和管理员的效率和效率。专业化 Oracle 的开发-快速安全地创建、组织、访问和共享信息…

微信小程序自定义顶部状态栏

因为工作需要,要在微信小程序中自定义顶部导航栏,通过这篇文章来记录一下自己所得~ 第一步: 需要在json文件中配置"navigation" : "custom",完成自定义导航栏,只保留胶囊按钮,效果如下图&#x…

养殖废水生化后氨氮400mg/L做到15mg/L,有什么降氨氮的工艺?

水产养殖过程中,鱼的排泄物和没有被消耗的饲料降解均会使水中的氨氮剧增,当氨氮浓度大于0.2mg/L时,鱼类摄食就会受到严重影响,造成生长不良或停止生长;达到 2mg/L时,则会造成生物的死亡,严重影响水产的养殖…

【Maven实战技巧】「插件使用专题」Maven-Archetype插件创建自定义maven项目骨架

技术推荐 自定义Archetype Maven骨架/以当前项目为模板创建maven骨架,可以参考http://maven.apache.org/archetype/maven-archetype-plugin/advanced-usage.html,详细介绍了如何快速创建和使用Archetype。 技术背景 在工作过程中必然会遇到创建项目的蛋…

最新版Crack:Xceed Ultimate Suite

Xceed Ultimate Suite 包括 160 多个适用于所有 Windows 平台的自适应、可靠和高性能控件和库的重要集合。程序前端的 UI 控件和后端的数据处理库。经常更新,并得到反应支持和开发人员的认可。 适用于所有 Windows 平台的 160 个自适应、可靠和高性能控件和库的重要…

vue-elementUI后台管理系统,已实现用户管理、菜单管理、角色管理、公司管理、权限管理、支付管理等

vue搭建后台管理界面模版(PC端) 完整代码下载地址:vue-elementUI后台管理系统 技术栈 vue2 vuex vue-router webpack ES6/7 axios elementUI 阿里图标iconfont 项目预览 http://nmgwap.gitee.io/vueproject/#/login 说明 本项目主…

Python使用pandas导入csv文件内容

使用pandas导入csv文件内容使用pandas导入csv文件内容1. 默认导入2. 指定分隔符3. 指定读取行数4. 指定编码格式5. 列标题与数据对齐使用pandas导入csv文件内容 1. 默认导入 在Python中导入.csv文件用的方法是read_csv()。 使用read_csv()进行导入时,指定文件名即…

jQuery 过滤方法

文章目录jQuery 过滤方法hasClass() 类名过滤eq() 下标过滤is() 判断过滤not() 反向过滤filter() 表达式过滤has() 表达式过滤后代元素jQuery 过滤方法 过滤方法说明hasClass()类名过滤eq()下标过滤is()判断过滤not()反向过滤filter()表达式过滤has()表达式过滤后代元素 hasCl…

微软发现macOS漏洞可让恶意软件绕过安全检查

©网络研究院 苹果修复了一个漏洞,攻击者可以利用该漏洞通过能够绕过 Gatekeeper 应用程序执行限制的不受信任的应用程序;在易受攻击的 macOS 设备上部署恶意软件。 由微软首席安全研究员发现并报告的安全漏洞(称为Achilles&#xff09…

前端基础_矩阵变换

矩阵变换 在介绍矩阵变换之前,首先要介绍一下变换矩阵,这个矩阵是专门用来实现图形变形的,它与坐标一起配合使用,以达到变形的目的。当图形上下文被创建完毕时,事实上也创建了一个默认的变换矩阵,如果不对…

腾讯T4熬夜硬肝的全套微服务学习笔记,Github万星只是开始

写在前面 微服务架构被认为是 IT 软件架构的未来方向。热度虽高,但对于很多中小公司来说微服务却是遥不可及,因为团队规模和能力又反过来制约了他们采用新技术的步伐。很多人对于微服务技术也都有着一些疑虑,比如:微服务这技术虽然…

编译器原理简介(以Cortex-M3为例)

在"keil根目录\ARM\ARMCC\bin"下可以找到如下文件: 他们就是编译器内核,将工程代码转换成二进制文件,烧写进MCU中执行。 目录 C与汇编 典型的开发流程 编译工具报错举例 C与汇编 在CM3上编程,开发人员既可以使用C也…

CANoe-新型通信模式(SOA面向服务架构)

传统的以ECU为单元的整车通信架构,是面向信号的以CAN/LIN等总线为代表的经典通信模式。而以车载以太网为总线,SOME/IP或DDS等为中间件的SOA面向服务的新型通信模式,在以域控为单元的整车通信架构中被越来越多的使用 CANoe作为仿真和测试环境提供了统一的跨网络通信概念。这…

字符设备驱动_3:register_chrdev_region() 简单字符设备驱动的实现

概述&#xff1a;利用regist_chrdev_region() 函数接口注册同一类字符设备的多个子设备。 上一节一起整理了一遍注册一个简单字符设备的流程&#xff0c;接下来就来实现一个同一类字符设备的多个子设备驱动程序。 1. Demo 程序 #include <linux/module.h> #include <…

Linux篇 三、香橙派Zero2搭建Qt环境

香橙派Zero2系列文章目录 一、香橙派Zero2设置开机连接wifi 二、香橙派Zero2获取Linux SDK源码 三、香橙派Zero2搭建Qt环境 文章目录香橙派Zero2系列文章目录前言一、下载交叉编译工具二、编译QT库1.先去网站下载Qt的资源包2.解压3.开始移植&#xff1a;4.编译&#xff1a;5.安…

jQuery 查找方法

文章目录jQuery 查找方法查找祖先元素parent()parents()parentsUntil()查找后代元素children()find()contents()向前查找兄弟元素prev()prevAll()prevUnitl()向后查找兄弟元素next()nextAll()nextUntil()查找所有兄弟元素siblings()jQuery 查找方法 查找祖先元素查找后代元素向…

年度创新力十强,热点领域重要力量,典型案例报告入选!美创再获ISC安全百强多项殊荣

12月21日&#xff0c;数字安全界“奥斯卡”—ISC 2022数字安全创新能力百强&#xff08;简称“创新百强”&#xff09;重磅揭晓&#xff0c;本届评选由ISC平台发起&#xff0c;联合赛迪顾问、数世咨询、数说安全、看雪、安在等网络安全行业权威机构、媒体共同开启评选&#xff…

web开发前基础知识补充

什么是URL&#xff1f; URL是统一资源定位符&#xff0c;对可以从互联网上得到的资源的位置和访问方法的一种简洁的表示&#xff0c;是互联网上标准资源的地址&#xff1b; 互联网上的每个文件都有一个唯一的URL&#xff1b; 基本URL包含模式&#xff08;或称协议&#xff0…

Kafka使用MirrorMaker同步数据的两种方式

1.前言 MirrorMaker 是 Kafka官方提供的跨数据中心的流数据同步方案。原理是通过从 原始kafka集群消费消息&#xff0c;然后把消息发送到 目标kafka集群。操作简单&#xff0c;只要通过简单的 consumer配置和 producer配置&#xff0c;然后启动 Mirror&#xff0c;就可以实现准…