【Pytorch】Visualization of Feature Maps(4)——Saliency Maps

news2025/1/23 23:30:07

在这里插入图片描述

学习参考来自

  • Saliency Maps的原理与简单实现(使用Pytorch实现)
  • https://github.com/wmn7/ML_Practice/tree/master/2019_07_08/Saliency%20Maps

Saliency Maps 原理

《Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps》(arXiv-2013)

在这里插入图片描述

A saliency map tells us the degree to which each pixel in the image affects the classification score for that image.
To compute it, we compute the gradient of the unnormalized score corresponding to the correct class (which is a scalar)
with respect to the pixels of the image. If the image has shape (3, H, W) then this gradient will also have shape (3, H, W);
for each pixel in the image, this gradient tells us the amount by which the classification score will change if the pixel
changes by a small amount. To compute the saliency map, we take the absolute value of this gradient, then take the maximum value over the 3 input channels; the final saliency map thus has shape (H, W) and all entries are non-negative.

Saliency Maps相当于是计算图像的每一个pixel是如何影响一个分类器的, 或者说分类器对图像中每一个pixel哪些认为是重要的.

会计算图像每一个像素点的梯度。如果图像的形状是(3, H, W),这个梯度的形状也是(3, H, W);对于图像中的每个像素点,
这个梯度告诉我们当像素点发生轻微改变时,正确分类分数变化的幅度。

计算 saliency map 的时候,需要计算出梯度的绝对值,然后再取三个颜色通道的最大值;

因此最后的 saliency map的形状是(H, W)为一个通道的灰度图。


直接来代码,先载入些数据,用的是 cs231n 作业里面的 imagenet_val_25.npz,含有 imagenet 数据中验证集的 25 张图片

import torch
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

SQUEEZENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
SQUEEZENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def load_imagenet_val(num=None):
    """Load a handful of validation images from ImageNet.
    Inputs:
    - num: Number of images to load (max of 25)
    Returns:
    - X: numpy array with shape [num, 224, 224, 3]
    - y: numpy array of integer image labels, shape [num]
    - class_names: dict mapping integer label to class name
    """
    imagenet_fn = 'imagenet_val_25.npz'
    if not os.path.isfile(imagenet_fn):
      print('file %s not found' % imagenet_fn)
      print('Run the following:')
      print('cd cs231n/datasets')
      print('bash get_imagenet_val.sh')
      assert False, 'Need to download imagenet_val_25.npz'
    f = np.load(imagenet_fn, allow_pickle=True)
    X = f['X']  # (25, 224, 224, 3)
    y = f['y']  # (25, )
    class_names = f['label_map'].item()  # 999
    if num is not None:
        X = X[:num]
        y = y[:num]
    return X, y, class_names

图像的前处理,resize,变成向量,减均值除以方差

# 辅助函数
def preprocess(img, size=224):
    transform = T.Compose([
        T.Resize(size),
        T.ToTensor(),
        T.Normalize(mean=SQUEEZENET_MEAN.tolist(),
                    std=SQUEEZENET_STD.tolist()),
        T.Lambda(lambda x: x[None]),
    ])
    return transform(img)

在这里插入图片描述

数据集和实验的模型

链接:https://pan.baidu.com/s/1vb2Y0IiHdH_Fb9wibTta4Q?pwd=zuvw
提取码:zuvw


核心代码,计算 saliency maps

def compute_saliency_maps(X, y, model):
    """
    X表示图片, y表示分类结果, model表示使用的分类模型
    
    Input : 
    - X : Input images : Tensor of shape (N, 3, H, W)
    - y : Label for X : LongTensor of shape (N,)
    - model : A pretrained CNN that will be used to computer the saliency map
    
    Return :
    - saliency : A Tensor of shape (N, H, W) giving the saliency maps for the input images
    """
    # 确保model是test模式
    model.eval()
    
    # 确保X是需要gradient
    X.requires_grad_() # 仅开启了输入图片的梯度
    
    saliency = None
    
    logits = model.forward(X)  # torch.Size([5, 1000]), 前向获取 logits
    logits = logits.gather(1, y.view(-1, 1)).squeeze()  # torch.Size([5]) 得到正确分类 logits (5张图片标签相应类别的 logits)
    logits.backward(torch.FloatTensor([1., 1., 1., 1., 1.]))  # 只计算正确分类部分的loss(正确类别梯度为 1 回传)
    
    saliency = abs(X.grad.data)  # 返回X的梯度绝对值大小, torch.Size([5, 3, 224, 224])
    saliency, _ = torch.max(saliency, dim=1)  # torch.Size([5, 224, 224]),取 rgb 3通道的最大值
    
    return saliency.squeeze()

显示 saliency maps

def show_saliency_maps(X, y):
    # Convert X and y from numpy arrays to Torch Tensors
    X_tensor = torch.cat([preprocess(Image.fromarray(x)) for x in X], dim=0) # torch.Size([5, 3, 224, 224])
    y_tensor = torch.LongTensor(y)

    # Compute saliency maps for images in X
    saliency = compute_saliency_maps(X_tensor, y_tensor, model)

    # Convert the saliency map from Torch Tensor to numpy array and show images
    # and saliency maps together.
    saliency = saliency.numpy()
    N = X.shape[0]  # 5
    for i in range(N):
        plt.subplot(2, N, i + 1)
        plt.imshow(X[i])
        plt.axis('off')
        plt.title(class_names[y[i]])
        plt.subplot(2, N, N + i + 1)
        plt.imshow(saliency[i], cmap=plt.cm.hot)
        plt.axis('off')
        plt.gcf().set_size_inches(12, 5)
    plt.show()

下面开始调用,首先载入模型,使其梯度冻结,仅打开输入图片的梯度,这样反向传播的时候会更新图片,得到我们想要的 saliency maps

# Download and load the pretrained SqueezeNet model.
model = torchvision.models.squeezenet1_1(pretrained=True)

# We don't want to train the model, so tell PyTorch not to compute gradients
# with respect to model parameters.
for param in model.parameters():
    param.requires_grad = False

加载一些图片看看,25 张中抽出来 5 张

X, y, class_names = load_imagenet_val(num=5)  # X: (5, 224, 224, 3) | y: (5,) | class_names: 999

"show images"

plt.figure(figsize=(12, 6))
for i in range(5):
    plt.subplot(1, 5, i + 1)
    plt.imshow(X[i])
    plt.title(class_names[y[i]])
    plt.axis('off')
plt.gcf().tight_layout()
plt.show()

显示图片
在这里插入图片描述
把五张图片的 saliency maps 画出来

show_saliency_maps(X, y)

我把 25 张都画出来了
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述


核心代码中涉及到了 gather 函数,下面来个简单的例子就明白了

# Example of using gather to select one entry from each row in PyTorch
# 用来返回matrix指定行某个位置的值
import torch

def gather_example():
    N, C = 4, 5
    s = torch.randn(N, C) # 随机生成 4 行 5 列的 tensor
    y = torch.LongTensor([1, 2, 1, 3])
    print(s)
    print(y)
    print(torch.LongTensor(y).view(-1, 1))
    print(s.gather(1, y.view(-1, 1)).squeeze()) # 抽取每行相应的列数位置上的数值


gather_example()

"""
tensor([[ 0.8119,  0.2664, -1.4168, -0.1490, -0.0675],
        [ 0.5335,  0.6304, -0.7200, -0.0974, -0.9934],
        [-0.8305,  0.5189,  0.7359,  1.5875,  0.0505],
        [ 0.4335, -1.1389, -0.7771,  0.5779,  0.3515]])
tensor([1, 2, 1, 3])
tensor([[1],
        [2],
        [1],
        [3]])
tensor([ 0.2664, -0.7200,  0.5189,  0.5779])
"""

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1268325.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何设置Linux终端提示信息

如何设置Linux终端提示信息 1 方法一:只能在VSCode或者Pycharm显示2 方法二:只能在MobaXterm等远程软件上显示,但全用户都会显示3 方法三:避免用户没看到上面的提示,上面两种都设置一下 在使用远程终端时,由…

基于Qt QChart和QChartView实现正弦、余弦、正切图表

# 源码地址 https://gitcode.com/m0_45463480/QChartView/tree/main# .pro QT += charts​​HEADERS += \ chart.h \ chartview.h​​SOURCES += \ main.cpp \ chart.cpp \ chartview.cpp​​target.path = $$[QT_INSTALL_EXAMPLES]/charts/zoomlinechartINSTAL…

L1-004:计算摄氏温度

题目描述 给定一个华氏温度F,本题要求编写程序,计算对应的摄氏温度C。计算公式:C5(F−32)/9。题目保证输入与输出均在整型范围内。 输入格式:输入在一行中给出一个华氏温度。 输出格式:在一行中按照格式“Celsius C”…

如何使用录屏软件在电脑录制PDF文件

我有一个PDF文件,想用录屏软件将它录制下来并添加上详细的注释,然后发给客户看,请问应该如何录制呢?有没有推荐的录屏软件呢? 不用担心,本文将会详细的为您讲解如何使用录屏软件在电脑端录制PDF文件&#…

GoLang切片

一、切片基础 1、切片的定义 切片(Slice)是一个拥有相同类型元素的可变长度的序列它是基于数组类型做的一层封装它非常灵活,支持自动扩容切片是一个引用类型,它的内部结构包含地址、长度和容量声明切片类型的基本语法如下&#…

Mac单独修改应用语言

方法1: 方法2: defaults write com.microsoft.Excel AppleLanguages ("zh-cn") defaults write com.microsoft.Word AppleLanguages ("zh-cn")参考:https://www.zhihu.com/question/24976020

Javaweb之Vue组件库Element案例的详细解析

4.4 案例 4.4.1 案例需求 参考 资料/页面原型/tlias智能学习辅助系统/首页.html 文件,浏览器打开,点击页面中的左侧栏的员工管理,如下所示: 需求说明: 制作类似格式的页面 即上面是标题,左侧栏是导航&…

vue高频面试题(2023),有回答思路,并且让你回答清晰

一、对MVC,MVP,MVVM的理解 三者都是项目的架构模式(不是类的设计模式),即:一个项目的结构,如何分层,不同层负责不同的职责。 1、MVC: MVC的出现是用在后端(…

SpringMVC—拦截器

1 拦截器概念 1.1 简介 拦截器是一种动态拦截方法调用的机制,在 SpringMVC 中动态拦截控制器方法的执行 【注】拦截器底层实现为AOP 作用: 在指定的方法调用前后执行预先设定的代码阻止原始方法的执行 1.2 拦截器和过滤器的区别 ① 归属不同&#…

高效的将两个文件夹中多余的文件删除

高效的将两个文件夹中多余的文件删除 解决方案 之前使用的是这个方法,但是图像太多,需要删除的有70W张,得删10多天。。 将两个文件夹中重复的图象删除 解决方案 先将image图像复制一份,然后改名为txt import osdef change_file…

SpringBoot——Swagger2 接口规范

优质博文:IT-BLOG-CN 如今,REST和微服务已经有了很大的发展势头。但是,REST规范中并没有提供一种规范来编写我们的对外REST接口API文档。每个人都在用自己的方式记录api文档,因此没有一种标准规范能够让我们很容易的理解和使用该…

【JavaWeb】会话过滤器监听器

会话&过滤器&监听器 文章目录 会话&过滤器&监听器一、会话1.1 Cookie1.2 Session1.3 三大域对象 二、过滤器三、监听器3.1 application域监听器3.2 session域监听器3.3 request域监听器3.4 session域的两个特殊监听器3.4.1 session绑定监听器3.4.2 钝化活化监听…

【Vulnhub 靶场】【Coffee Addicts: 1】【简单-中等】【20210520】

1、环境介绍 靶场介绍:https://www.vulnhub.com/entry/coffee-addicts-1,699/ 靶场下载:https://download.vulnhub.com/coffeeaddicts/coffeeaddicts.ova 靶场难度:简单 - 中等 发布日期:2021年5月20日 文件大小:1.3 …

SparkSQL远程调试(IDEA)

启动Intellij IDEA,打开spark源码项目,配置远程调试 Run->Edit Configuration 启动远程spark-sql spark-sql --verbose --driver-java-options "-Xdebug -Xrunjdwp:transportdt_socket,servery,suspendy,address5005"运行远程调试&#xf…

【面试】css预处理器之sass(scss)

目录 为什么引入css预处理器 可读性 嵌套:关系明朗 选择器 属性 伪类‘’ 变量:语义明确 默认变量:美元符号 $ 变量名:值 !default 全局变量::global { $global-x: } 变量插值:#{} map键值对:$…

【Java SE】带你在String类世界中遨游!!!

🌹🌹🌹我的主页🌹🌹🌹 🌹🌹🌹【Java SE 专栏】🌹🌹🌹 🌹🌹🌹上一篇文章:带你走近Java的…

C++初阶(十三)vector

📘北尘_:个人主页 🌎个人专栏:《Linux操作系统》《经典算法试题 》《C》 《数据结构与算法》 ☀️走在路上,不忘来时的初心 文章目录 一、vector的介绍二、vector的模拟实现1、模拟实现2、测试结果 一、vector的介绍 vector的文…

分割掩模 VS 掩膜

掩膜 Mask分割掩模 Segmentation Mask总结示例 掩膜 Mask “掩膜” 是指一种用于 标识或遮蔽图像中特定区域 的 图像。 在图像处理中,掩膜通常是一个 二值图像,其中的 像素值为 0 或 1。binary Mask 叫做二元掩膜,如下图所示: 这…

九、hdfs中Namenode元数据处理

1、元数据的由来 在hdfs文件系统中,用户的每一次操作,都会对文件系统产生响应的影响,那么谁来记录这些影响呢? 在hdfs文件系统中,edits文件记录了hdfs中的每一次操作,以及本次操作影响的文件其对应的block。…

为啥网络安全缺口这么大,还是这么缺网络安全工程师?(网络安全行业前景到底如何)

为啥网安领域缺口多达300多万人,但网安工程师也就是白帽黑客却很少,难道又是砖家在忽悠人? 原因主要为这三点: 首先是学校的原因,很多学校网络安全课程用的还都是十年前的老教材,教学脱离社会需求,实操技能…