使用grad-cam对ViT的输出进行可视化

news2025/1/25 4:43:07

使用grad-cam对ViT的输出进行可视化

文章目录

  • 使用grad-cam对ViT的输出进行可视化
    • 前言
    • 原理
    • 使用代码
    • Pytorch-grad-cam库的更多方法
    • 在MMpretrain中使用
    • 示例
    • 总结

前言

Vision Transformer (ViT) 作为现在CV中的主流backbone,它可以在图像分类任务上达到与卷积神经网络 (CNN) 相媲美甚至超越的性能。ViT 的核心思想是将输入图像划分为多个小块,然后将每个小块作为一个 token 输入到 Transformer 的编码器中,最终得到一个全局的类别 token 作为分类结果。

ViT 的优势在于它可以更好地捕捉图像中的长距离依赖关系,而不需要使用复杂的卷积操作。然而,这也带来了一个挑战,那就是如何解释 ViT 的决策过程,以及它是如何关注图像中的不同区域的。为了解决这个问题,我们可以使用一种叫做 grad-cam 的技术,它可以根据 ViT 的输出和梯度,生成一张热力图,显示 ViT 在做出分类时最关注的图像区域。

原理

grad-cam对ViT的输出进行可视化的原理是利用 ViT 的最后一个注意力块的输出和梯度,计算出每个 token 对分类结果的贡献度,然后将这些贡献度映射回原始图像的空间位置,形成一张热力图。具体来说,grad-cam+ViT 的步骤如下:

  1. 给定一个输入图像和一个目标类别,将图像划分为 14x14 个小块,并将每个小块转换为一个 768 维的向量。在这些向量之前,还要加上一个特殊的类别 token ,用于表示全局的分类信息。这样就得到了一个 197x768 的矩阵,作为 ViT 的输入。

  2. 将 ViT 的输入通过 Transformer 的编码器,得到一个 197x768 的输出矩阵。其中第一个向量就是类别 token ,它包含了 ViT 对整个图像的理解。我们将这个向量通过一个线性层和一个 softmax 层,得到最终的分类概率。

  3. 计算类别 token 对目标类别的梯度,即 ∂ y c ∂ A \frac{\partial y_c}{\partial A} Ayc ,其中 y c y_c yc 是目标类别的概率, A A A 是 ViT 的输出矩阵。这个梯度表示了每个 token 对分类结果的重要性。

  4. 对每个 token 的梯度求平均值,得到一个 197 维的向量 w w w ,其中 w i = 1 Z ∑ k ∂ y c ∂ A i k w_i = \frac{1}{Z}\sum_k \frac{\partial y_c}{\partial A_{ik}} wi=Z1kAikyc Z Z Z 是梯度的维度,即 768 。这个向量 w w w 可以看作是每个 token 的权重。

  5. 将 ViT 的输出矩阵和权重向量相乘,得到一个 197 维的向量 s s s ,其中 s i = ∑ k w k A i k s_i = \sum_k w_k A_{ik} si=kwkAik 。这个向量 s s s 可以看作是每个 token 对分类结果的贡献度。

  6. 将贡献度向量 s s s 除去第一个元素(类别 token ),并重塑为一个 14x14 的矩阵 M M M​ ,其中 M i j = s ( i − 1 ) × 14 + j + 1 M_{ij} = s_{(i-1) \times 14 + j + 1} Mij=s(i1)×14+j+1 。这个矩阵 M M M 可以看作是每个小块对分类结果的贡献度。

  7. 将贡献度矩阵 M M M 进行归一化和上采样,得到一个与原始图像大小相同的矩阵 H H H ,其中 H i j = M i j − min ⁡ ( M ) max ⁡ ( M ) − min ⁡ ( M ) H_{ij} = \frac{M_{ij} - \min(M)}{\max(M) - \min(M)} Hij=max(M)min(M)Mijmin(M) 。这个矩阵 H H H 就是我们要求的热力图,它显示了 ViT 在做出分类时最关注的图像区域。

  8. 将热力图 H H H 和原始图像进行叠加,得到一张可视化的图像,可以直观地看到 ViT 的注意力分布。

使用代码

import argparse
import cv2
import numpy as np
import torch

from pytorch_grad_cam import GradCAM, \
                            ScoreCAM, \
                            GradCAMPlusPlus, \
                            AblationCAM, \
                            XGradCAM, \
                            EigenCAM, \
                            EigenGradCAM, \
                            LayerCAM, \
                            FullGrad

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
preprocess_image
from pytorch_grad_cam.ablation_layer import AblationLayerVit

# 加载预训练的 ViT 模型
model = torch.hub.load('facebookresearch/deit:main',
'deit_tiny_patch16_224', pretrained=True)
model.eval()

# 判断是否使用 GPU 加速
use_cuda = torch.cuda.is_available()
if use_cuda:
model = model.cuda()

接下来,我们需要定义一个函数来将 ViT 的输出层从三维张量转换为二维张量,以便 grad-cam 能够处理:

def reshape_transform(tensor, height=14, width=14):
    # 去掉类别标记
    result = tensor[:, 1:, :].reshape(tensor.size(0),
    height, width, tensor.size(2))

    # 将通道维度放到第一个位置
    result = result.transpose(2, 3).transpose(1, 2)
    return result

然后,我们需要选择一个目标层来计算 grad-cam。由于 ViT 的最后一层只有类别标记对预测类别有影响,所以我们不能选择最后一层。我们可以选择倒数第二层中的任意一个 Transformer 编码器作为目标层。在这里,我们选择第 11 层作为示例:


# 创建 GradCAM 对象
cam = GradCAM(model=model,
target_layer=model.blocks[5],
use_cuda=use_cuda,
reshape_transform=reshape_transform)

接下来,我们需要准备一张输入图像,并将其转换为适合 ViT 的格式:

# 读取输入图像
image_path = "cat.jpg"
rgb_img = cv2.imread(image_path, 1)[:, :, ::-1]
rgb_img = cv2.resize(rgb_img, (224, 224))

# 预处理图像
input_tensor = preprocess_image(rgb_img,
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])

# 将图像转换为批量形式
input_tensor = input_tensor.unsqueeze(0)
if use_cuda:
input_tensor = input_tensor.cuda()

最后,我们可以调用 cam 对象的 forward 方法,传入输入张量和预测类别(如果不指定,则默认为最高概率的类别),得到 grad-cam 的输出:

# 计算 grad-cam
target_category = None # 可以指定一个类别,或者使用 None 表示最高概率的类别
grayscale_cam = cam(input_tensor=input_tensor,
target_category=target_category)

# 将 grad-cam 的输出叠加到原始图像上
visualization = show_cam_on_image(rgb_img, grayscale_cam)

# 保存可视化结果
cv2.imwrite('cam.jpg', visualization)

这样,我们就完成了使用 grad-cam 对 ViT 的输出进行可视化的过程。我们可以看到,ViT 主要关注了图像中的猫的头部和身体区域,这与我们的直觉相符。通过使用 grad-cam,我们可以更好地理解 ViT 的工作原理,以及它对不同图像区域的重要性。

Pytorch-grad-cam库的更多方法

除了经典的grad-cam,库里目前支持的方法还有:

MethodWhat it does
GradCAM使用平均梯度对 2D 激活进行加权
GradCAM++类似 GradCAM,但使用了二阶梯度
XGradCAM类似 GradCAM,但通过归一化的激活对梯度进行了加权
EigenCAM使用 2D 激活的第一主成分(无法区分类别,但效果似乎不错)
EigenGradCAM类似 EigenCAM,但支持类别区分,使用了激活 * 梯度的第一主成分,看起来和 GradCAM 差不多,但是更干净
LayerCAM使用正梯度对激活进行空间加权,对于浅层有更好的效果

这里给出MMpretrain提供的对比示例:

image-20230502234318715

在MMpretrain中使用

image-20230502234419413

如果你刚好在用MMpretrain,那么有着方便的脚本文件来帮助你更加方便的进行上面的工作,具体可见:类别激活图(CAM)可视化 — MMPretrain 1.0.0rc7 文档

image-20230502234349894

示例

这里也放一些我自己试过的例子:

bc264783b1bfe850f5b8236c619f8df

下载

总结

通过使用 grad-cam,我们可以更好地理解 ViT 的工作原理,以及它是如何从图像中提取有用的特征的。grad-cam 也可以用于其他基于 Transformer 的模型,例如DeiT、Swin Transformer 等,只需要根据不同的模型结构和输出,调整相应的计算步骤即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1210776.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mysql数据库超过最大连接数

mysql 超过数据库最大连接数解决办法 1、报错信息 首先无论是navicat 执行sql还是 用idea启动多的服务都会有如下报错信息: 2、解决办法 2.1命令方式修改 这种方法是由其他资料提供的。这种修改方式是临时的,如果mysql服务重启设置就会还原&#xff…

弱类型和强类型自定义UDAF函数

目录 使用自带的avg函数弱类型自定义UDAF函数(AVG)强类型自定义UDAF函数(AVG) 弱类型:3.x过期 2.x有 强类型:3.x 2.x没有 使用自带的avg函数 import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SparkSession}object UserDefine…

双极膜电渗析设备

#双极膜电渗析设备 双极膜(bipolar membrance,简称BPM)是一种新型的离子交换复合膜,它通常由阳离子交换层(N型膜)、界面亲水层(催化层)和阴离子交换膜(P型膜)…

【计算思维】少儿编程蓝桥杯青少组计算思维题考试真题及解析C

【科技素养】少儿编程蓝桥杯青少组计算思维题考试真题及解析 1.天平的左右两端分别放有一些砝码,如下图所示,右边的砝码不变,从左边最多拿走几个砝码,可以使天平左右两边平衡: A、1 B、2 C、3 D、4 2.把下面的图形…

sqli-labs(Less-3)

1. 通过构造id1’ 和id1’) 和id1’)–确定存在注入 可知原始url为 id(‘1’) 2.使用order by 语句猜字段数 http://127.0.0.1/sqlilabs/Less-3/?id1) order by 4 -- http://127.0.0.1/sqlilabs/Less-3/?id1) order by 3 --3. 使用联合查询union select http://127.0.0.1…

Window MongoDB安装

三种NOSQL的一种,Redis MongoDB ES 应用场景: 1.社交场景:使用Mongodb存储用户信息,以及用户发表的朋友圈信息,通过地理位置索引实现附近的人,地点等功能 2.游戏场景:使用Mongodb存储游戏用户信息,用户的装备,积分等直接以内嵌文档的形式存储,方便查询,高效率存储和访问…

IDEA创建JavaFX项目

1、New -> Project 2、选择JavaFX 配置项目名,包名,lib包管理工具,JDK版本(注,JDK版本最低需要11) 3、选择lib包 根据自己需求选择 lib包介绍 BootstrapFX:BootstrapFX 是一个为 JavaFX 提…

怎样正确选择等保测评机构开展等保测评工作?

随着大家对网络安全的重视,越来越多的企业需要做等保测评了。很多小伙伴想知道怎样正确选择等保测评机构开展等保测评工作?这里就给大家简单说说。 怎样正确选择等保测评机构开展等保测评工作? 【回答】:正确选择等保测评机构开展…

Java魔法解密:HashMap底层机制大揭秘

文章目录 一、 源码深度解析1.1 窥探Java集合框架中的设计思想1.2 逐行解读HashMap的源代码1.2.1 类信息1.2.2 常量属性1.2.3 变量属性1.2.4 节点信息1.2.5 构造方法1.2.6 put方法1.2.6.1 putVal方法1.2.6.2 putTreeVal方法1.2.6.3 tieBreakOrder方法1.2.6.4 treeifyBin方法1.2…

【联邦学习+区块链】TORR: A Lightweight Blockchain for Decentralized Federated Learning

文章目录 I.CONTRIBUTIONII. ASSUMPTIONS AND THREAT MODELA. AssumptionsB. Threat Model III. SYSTEM DESIGNA. Design OverviewB. Block DesignC. InitializationD. Role SelectionE. Storage ProtocolF. Aggregation ProtocolG. Proof of ReliabilityH. Blockchain Consens…

Hive的安装部署

目录 1.修改hadoop相关参数2.Hive解压安装3.Hive元数据的三种部署方式3.1 元数据库之Derby3.2 元数据库之Mysql3.3 元数据之MetaStore Server 4.hive的两种访问方式4.1 命令行的方式4.2 HiveServer2模式 1.修改hadoop相关参数 1)修改core-site.xml [roothadoop102…

Android 12 intent-filter添加android:exported后任然报错解决方法

Android 12 或更高版本为目标平台,且包含使用intent-filter 过滤器的 activity、service或receiver,您必须为这些应用组件显式声明 android:exported 属性。 常规操作 查看AndroidManifest.xml文件,搜索intent-filter,然后添加好…

灰度图处理方法

做深度学习项目图像处理的时候常常涉及到灰度图处理,这里对自己处理灰度图的方式做一个记录,后续有更新的话会在此更新 一,多维数组可视化 将多维数组可视化为灰度图 img_gray Image.fromarray(img, modeL) # 实现array到image的转换,m…

原来有 K8s 认证,找工作这么吃香!

随着容器的快速发展,容器管理工具Kubernetes(下文简称K8s)也应运而生,目前不仅百度、京东、阿里、Google等大公司在使用K8s,一些中小企业也开始把业务迁移到K8s中。 K8s在人工智能、大数据、5G、区块链、智能家居、航…

【postgresql】CentOS7 安装Pgweb

Pgweb Pgweb是PostgreSQL的一个基于web的数据库浏览器,用Go编写,可在Mac、Linux和Windows机器上运行。以零依赖性的简单二进制形式分布。非常易于使用,并具有适当数量的功能。简单的基于web和跨平台的PostgreSQL数据库浏览器。 特点 跨平台…

【用户实践】openGauss5.0在某省医保局实时数仓应用

一、项目背景 采用数据同步软件将各系统的数据库下的数据实时同步到openGauss数据库中;建立实时数仓;可以在实时数仓自行查询、分析、统计数据及报表;同时横向集成公共服务区和核心业务区生产库数据、集成其他委办局数据。纵向集成市级的生产…

使用Zoho Projects软件进行高效项目管理的方法

本文将通过Zoho Projects项目管理软件的功能,结合一般情况下使用Zoho Projects进行项目管理的过程,为大家展示如何充分利用Zoho Projects项目管理软件进行项目管理。我们将详细介绍任务看板、文档管理、甘特图与报表、里程碑、issue等功能,基…

PieCloudDB Database 自研内存管理器 ASanAlloc:为产品质量保驾护航

内存管理是计算机科学中至关重要的一部分,它涉及到操作系统、硬件和软件应用之间的动态交互。有效的内存管理可以确保系统的稳定性和安全性,提高系统运行效率,帮助我们最大限度地利用有效的内存资源,合理分配和回收内存&#xff0…

TikTok:传承文化多样性,扬播全球声音

在数字时代,社交媒体平台已经成为了传播文化多样性和全球声音的重要渠道。其中,TikTok无疑是最引人注目的之一。 这个短视频应用在短短几年内迅速崭露头角,吸引了全球数亿用户,成为一个独特的文化传媒工具,通过短视频…