Grad-CAM,即梯度加权类激活映射 (Gradient-weighted Class Activation Mapping)

news2025/1/17 18:00:55

Grad-CAM,即梯度加权类激活映射 (Gradient-weighted Class Activation Mapping),是一种用于解释卷积神经网络决策的方法。它通过可视化模型对于给定输入的关注区域来提供洞察。

原理:

Grad-CAM的关键思想是将输出类别的梯度(相对于特定卷积层的输出)与该层的输出相乘,然后取平均,得到一个“粗糙”的热力图。这个热力图可以被放大并叠加到原始图像上,以显示模型在分类时最关注的区域。

具体步骤如下:

  1. 选择一个卷积层作为解释的来源。通常,我们会选择网络的最后一个卷积层,因为它既包含了高级特征,也保留了空间信息。
  2. 前向传播图像到网络,得到你想解释的类别的得分。
  3. 计算此得分 相对于我们选择的卷积层 输出的梯度。
  4. 对于该卷积层的每个通道,使用上述梯度的全局平均值对该通道进行加权
  5. 结果是一个与卷积层的空间维度相同的加权热力图

优势

Grad-CAM的优点是它可以用于任何卷积神经网络,无需进行结构修改或重新训练。它为我们提供了一个简单但直观的方式来理解模型对于特定输入的决策。

Code

import torch
import cv2
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.feature_maps = None
        self.gradients = None
        
        # Hook layers
        target_layer.register_forward_hook(self.save_feature_maps)
        target_layer.register_backward_hook(self.save_gradients)

    def save_feature_maps(self, module, input, output):
        self.feature_maps = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate_cam(self, image, class_idx=None):
        # Set model to evaluation mode
        self.model.eval()
        
        # Forward pass
        output = self.model(image)
        if class_idx is None:
            class_idx = torch.argmax(output).item()

        # Zero out gradients
        self.model.zero_grad()

        # Backward pass for target class
        one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot.cuda(), retain_graph=True)

        # Get pooled gradients and feature maps
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
        activation = self.feature_maps.squeeze(0)
        for i in range(activation.size(0)):
            activation[i, :, :] *= pooled_gradients[i]
        
        # Create heatmap
        heatmap = torch.mean(activation, dim=0).squeeze().cpu().numpy()
        heatmap = np.maximum(heatmap, 0)
        heatmap /= torch.max(heatmap)
        heatmap = cv2.resize(heatmap, (image.size(3), image.size(2)))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        
        # Superimpose heatmap on original image
        original_image = self.unprocess_image(image.squeeze().cpu().numpy())
        superimposed_img = heatmap * 0.4 + original_image
        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
        
        return heatmap, superimposed_img

    def unprocess_image(self, image):
        # Reverse the preprocessing step
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = (((image.transpose(1, 2, 0) * std) + mean) * 255).astype(np.uint8)
        return image

def visualize_gradcam(model, input_image_path, target_layer):
    # Load image
    img = Image.open(input_image_path)
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_tensor = preprocess(img).unsqueeze(0).cuda()

    # Create GradCAM
    gradcam = GradCAM(model, target_layer)
    heatmap, result = gradcam.generate_cam(input_tensor)

    plt.figure(figsize=(10,10))
    plt.subplot(1,2,1)
    plt.imshow(heatmap)
    plt.title('Heatmap')
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.imshow(result)
    plt.title('Superimposed Image')
    plt.axis('off')
    plt.show()

# Load your model (e.g., resnet20 in this case)
# model = resnet20()
# model.load_state_dict(torch.load("path_to_your_weights.pth"))
# model.to('cuda')

# Visualize GradCAM
# visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])

中文注释详细版

import torch
import cv2
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image

class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model  # 要进行Grad-CAM处理的模型
        self.target_layer = target_layer  # 要进行特征可视化的目标层
        self.feature_maps = None  # 存储特征图
        self.gradients = None  # 存储梯度
        
        # 为目标层添加钩子,以保存输出和梯度
        target_layer.register_forward_hook(self.save_feature_maps)
        target_layer.register_backward_hook(self.save_gradients)

    def save_feature_maps(self, module, input, output):
        """保存特征图"""
        self.feature_maps = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """保存梯度"""
        self.gradients = grad_output[0].detach()

    def generate_cam(self, image, class_idx=None):
        """生成CAM热力图"""
        # 将模型设置为评估模式
        self.model.eval()
        
        # 正向传播
        output = self.model(image)
        if class_idx is None:
            class_idx = torch.argmax(output).item()

        # 清空所有梯度
        self.model.zero_grad()

        # 对目标类进行反向传播
        one_hot = torch.zeros((1, output.size()[-1]), dtype=torch.float32)
        one_hot[0][class_idx] = 1
        output.backward(gradient=one_hot.cuda(), retain_graph=True)

        # 获取平均梯度和特征图
        pooled_gradients = torch.mean(self.gradients, dim=[0, 2, 3])
        activation = self.feature_maps.squeeze(0)
        for i in range(activation.size(0)):
            activation[i, :, :] *= pooled_gradients[i]
        
        # 创建热力图
        heatmap = torch.mean(activation, dim=0).squeeze().cpu().numpy()
        heatmap = np.maximum(heatmap, 0)
        heatmap /= torch.max(heatmap)
        heatmap = cv2.resize(heatmap, (image.size(3), image.size(2)))
        heatmap = np.uint8(255 * heatmap)
        heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
        
        # 将热力图叠加到原始图像上
        original_image = self.unprocess_image(image.squeeze().cpu().numpy())
        superimposed_img = heatmap * 0.4 + original_image
        superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)
        
        return heatmap, superimposed_img

    def unprocess_image(self, image):
        """反预处理图像,将其转回原始图像"""
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = (((image.transpose(1, 2, 0) * std) + mean) * 255).astype(np.uint8)
        return image

def visualize_gradcam(model, input_image_path, target_layer):
    """可视化Grad-CAM热力图"""
    # 加载图像
    img = Image.open(input_image_path)
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    input_tensor = preprocess(img).unsqueeze(0).cuda()

    # 创建GradCAM
    gradcam = GradCAM(model, target_layer)
    heatmap, result = gradcam.generate_cam(input_tensor)

    # 显示图像和热力图
    plt.figure(figsize=(10,10))
    plt.subplot(1,2,1)
    plt.imshow(heatmap)
    plt.title('热力图')
    plt.axis('off')
    plt.subplot(1,2,2)
    plt.imshow(result)
    plt.title('叠加后的图像')
    plt.axis('off')
    plt.show()

# 以下是示例代码,显示如何使用上述代码。
# 首先,你需要加载你的模型和权重。
# model = resnet20()
# model.load_state_dict(torch.load("path_to_your_weights.pth"))
# model.to('cuda')

# 然后,调用`visualize_gradcam`函数来查看结果。
# visualize_gradcam(model, "path_to_your_input_image.jpg", model.layer3[-1])

论文链接:https://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/983203.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

初阶三子棋(超详解)

✨博客主页:小钱编程成长记 🎈博客专栏:C语言小游戏 初阶三子棋 1.游戏介绍2.基本思路3.实现前的准备4.实现步骤4.1 打印菜单4.2 初始化棋盘4.3 打印棋盘4.4 玩家下棋4.5 电脑下棋4.6 判断本局游戏继续还是结束4.7 优化棋盘的显示 5.游戏代码…

汽车技术发展趋势及我国节能与新能源汽车技术

一、世界汽车技术发展趋势 汽车技术正向着低碳化、信息化、智能化方向发展;“三化”趋势成为世界主要汽车强国、主要车企共同的战略选择。 主要汽车战略及方向 在“三化”趋势下,各汽车强国在汽车节能技术、新能源汽车技术、智能网联汽车技术等方面持续…

算法训练营day42|动态规划 part04:0-1背包 (01背包问题基础(两种解决方案)、LeetCode 416.分割等和子集)

文章目录 01背包----二维dp数组01背包----滚动数组416.分割等和子集思路分析背包解法思考总结 有n件物品和一个最多能背重量为w 的背包。第i件物品的重量是weight[i],得到的价值是value[i] 。每件物品只能用一次,求解将哪些物品装入背包里物品价值总和最…

深入解析Spring Boot中最常用注解的使用方式(下篇)

摘要:本文是《深入解析Spring Boot中最常用注解的使用方式》的下篇内容,将继续介绍Spring Boot中其他常用的注解的使用方式,并通过代码示例进行说明,帮助读者更好地理解和运用Spring Boot框架。 目录 第二部分:常见的容…

浏览器开发者模式下只显示 XHR 请求应该怎么办

浏览器开发者模式下只显示 XHR 请求应该怎么办 问题分析 问题 F12打开浏览器的开发者模式,然后点击 Network,只显示 XHR 请求应该怎么办 分析 打开漏斗,选择All 模式

怎么给视频加背景音乐?学会这三种方法轻松配乐

给视频添加配乐可以带来多重好处。首先,配乐可以增强视频的氛围和情感,帮助观众更好地投入其中,感受视频所要表达的情感。不同的音乐可以传达不同的情感,例如悲伤、欢乐、紧张等等,可以让观众更深入地体验视频内容。教…

2023年9月NPDP产品经理国际认证报名,找弘博创新

产品经理国际资格认证NPDP是新产品开发方面的认证,集理论、方法与实践为一体的全方位的知识体系,为公司组织层级进行规划、决策、执行提供良好的方法体系支撑。 【认证机构】 产品开发与管理协会(PDMA)成立于1979年,是…

vue响应式原理

vue响应式原理 vue响应式原理vue2响应式原理目标对象为数组时 vue3响应式原理Vue3和Vue2在响应式系统方面的对比数据劫持的方式支持数据劫持的数据类型Vue3响应式系统显著优点是: vue响应式原理 无论vue2和vue3响应式都是通过观察者模式(发布订阅模式&a…

技术分享 | 强化学习,让机器像人类一样自我学习

如果说近年来有什么是各行各业共通的话题,那就一定是强化学习,这是一个让机器能够像人类一样通过与环境互动来学习和改进自己决策的领域。它不仅令人兴奋,而且具有革命性的潜力,可以改变我们生活和工作的方式。 随着计算能力的不断…

perf与simpleperf

对事件进行采样,然后根据采样频率,评估各个函数的调用频率。可以用来分析CPU cache,CPU迁移,指令周期等各种硬件事件,他也可以对感兴趣的事件进行动态追踪。 效果: cat available_events | grep receive p…

YashanDB:潜心实干,数据库核心技术突破没有捷径可走

都说数据库是三大基础软件中的一块硬骨头,技术门槛高、研发周期长、工程要求高,市场长期被几大巨头所把持。 因此,实现突破一直是中国数据库产业的夙愿。自上个世纪80年代起,中国数据库产业走过艰辛坎坷的四十余载,终…

CocosCreator3.8研究笔记(九)CocosCreator 场景资源的理解

相信很多朋友都想知道, Cocos Creator 资源的定义? Cocos Creator 常见的资源包含哪些?Cocos Creator 资源的管理机制是什么样的? Cocos Creator 中所有继承自 Asset 的类型都统称资源 ,例如:Texture2D、Sp…

springboot项目实现helloworld

使用Spring官方源创建项目(推荐) 缺陷:镜像在国外下载速度有点慢 选择配置 选择版本 实现HelloWorld 删除部分不重要的文件 idea隐藏文件 使用云原生的方式创建项目(spring官方源) 访问地址:Spring Init…

基于Java+SpringBoot+Vue前后端分离科研项目验收管理系统设计和实现

博主介绍:✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专…

服务端 TCP 连接的 TIME_WAIT 过多问题的分析与解决

https://blog.csdn.net/zxlyx/article/details/120397006 本文给出一个 TIME_WAIT 状态的 TCP 连接过多的问题的解决思路,非常典型,大家可以好好看看,以后遇到这个问题就不会束手无策了。 问题描述 模拟高并发的场景,会出现批量…

CS架构和BS架构的联系与区别(零基础理解)

文章目录 网络编程CS架构BS架构CS和BS的区别C/S架构优缺点B/S架构优缺点 网络编程 首先要了解CS架构和BS架构就需要了解一下什么是网络编程? 大家刚接触编程时,往往是在自己的电脑的编辑器上进行代码的编写,说简单的就是以前我们书写的代码就像单机版游戏一样,只能自己玩,不能…

VSRS4.0 安装与配置

0 引言 介绍:VSRS的定义参阅官方论文,项目引入VSRS来解决目前亟需解决的问题(基于两视点的虚拟视点合成)。 1 下载VSRS 1.1 通过TortoiseSVN下载最新版VSRS VSRS can be accessed from SVN server server: https://svn.multimedia.edu.pl/vsrs user:…

PyTorch基础知识(1)— PyTorch框架介绍和安装步骤

前言:Hello大家好,我是小哥谈。PyTorch是一个开源的深度学习框架,它基于Python语言,并提供了高级的神经网络接口,可以用于构建和训练各种深度学习模型。它的设计理念是灵活性和易用性,并且提供了动态图的特…

C++入门介绍之“栈”

1.1栈的定义 栈(stack)是一种只能在一端进行插入或删除的线性表 下面是一些基础概念 栈顶(top) : 表中允许进行插入、删除操作的线性表栈底(bottom):表的另一端空栈 :栈中没有数据元素进栈/入栈&#xf…

如何统计网站的访问量

本文介绍的是使用redis的HyperLoglog实现uv的统计功能。 背景 首先我们先明确一下uv这个名词代表的实际意义。uv代表的是通过网页访问浏览的人数,和文章的阅读量差不多,但是需要注意的是,一个人即使是多次访问,也只算一次。 所…