深度学习绘制热力图heatmap、使模型具有可解释性

news2024/11/17 12:58:52

思路

获取想要解释的那一层的特征图,然后根据特征图梯度计算出权重值,加在原图上面。

Demo

在这里插入图片描述
加上类激活(cam)
在这里插入图片描述
可以看到,cam将模型认为有利于分类的特征标注了出来。
下面以ResNet50为例:
Trick:
使用

for i in model._modules.items():

可以获得模型名称和对应层。

# coding: utf-8
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
 
import torch
import torch.autograd as autograd
import torchvision.transforms as transforms

import torchvision.models as models
 
 
# 训练过的模型路径
#resume_path = r"D:\TJU\GBDB\set113\cross_validation\test1\epoch_0257_checkpoint.pth.tar"
# 输入图像路径
single_img_path = r'bicycle.jpg'
# 绘制的热力图存储路径
save_path = r'heatmap/bicycle_layer4.jpg'
 
# 网络层的层名列表, 需要根据实际使用网络进行修改
layers_names = ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'avgpool']
# 指定层名
out_layer_name = "layer4"
 
features_grad = 0
 
 
# 为了读取模型中间参数变量的梯度而定义的辅助函数
def extract(g):
    global features_grad
    features_grad = g
 
 
def draw_CAM(model, img_path, save_path, transform=None, visual_heatmap=False, out_layer=None):
    """
    绘制 Class Activation Map
    :param model: 加载好权重的Pytorch model
    :param img_path: 测试图片路径
    :param save_path: CAM结果保存路径
    :param transform: 输入图像预处理方法
    :param visual_heatmap: 是否可视化原始heatmap(调用matplotlib)
    :return:
    """
    # 读取图像并预处理
    global layer2
    img = Image.open(img_path).convert('RGB')
    if transform:
        img = transform(img)
    img = img.unsqueeze(0)  # (1, 3, 448, 448)
 
    # model转为eval模式
    model.eval()
 
    # 获取模型层的字典
    layers_dict = {layers_names[i]: None for i in range(len(layers_names))}
    for name,module in model._modules.items():
        #print(i, (name, module))
        layers_dict[name] = module
 
    # 遍历模型的每一层, 获得指定层的输出特征图
    # features: 指定层输出的特征图, features_flatten: 为继续完成前端传播而设置的变量
    features = img
    start_flatten = False
    features_flatten = None

    for name, layer in layers_dict.items():
        if name != out_layer and start_flatten is False:    # 指定层之前
            features = layer(features)
        elif name == out_layer and start_flatten is False:  # 指定层
            features = layer(features)
            start_flatten = True
        else:   # 指定层之后
            if name == "fc":
                break
            if features_flatten is None:
                features_flatten = layer(features)
            else:
                features_flatten = layer(features_flatten)
    #print(features_flatten.shape)
    features_flatten = torch.flatten(features_flatten, 1)
    #print(features_flatten.shape)
    output = model.fc(features_flatten)
    # 预测得分最高的那一类对应的输出score
    pred = torch.argmax(output, 1).item()
    pred_class = output[:, pred]
 
    # 求中间变量features的梯度
    # 方法1
    # features.register_hook(extract)
    # pred_class.backward()
    # 方法2
    features_grad = autograd.grad(pred_class, features, allow_unused=True)[0]
 
    grads = features_grad  # 获取梯度
    pooled_grads = torch.nn.functional.adaptive_avg_pool2d(grads, (1, 1))
    # 此处batch size默认为1,所以去掉了第0维(batch size维)
    pooled_grads = pooled_grads[0]
    features = features[0]
    print("pooled_grads:", pooled_grads.shape)
    print("features:", features.shape)
    # features.shape[0]是指定层feature的通道数
    for i in range(features.shape[0]):
        features[i, ...] *= pooled_grads[i, ...]
 
    # 计算heatmap
    heatmap = features.detach().cpu().numpy()
    heatmap = np.mean(heatmap, axis=0)
    heatmap = np.maximum(heatmap, 0)
    heatmap /= np.max(heatmap)
 
    # 可视化原始热力图
    if visual_heatmap:
        plt.matshow(heatmap)
        plt.show()
 
    img = cv2.imread(img_path)  # 用cv2加载原始图像
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))  # 将热力图的大小调整为与原始图像相同
    heatmap = np.uint8(255 * heatmap)  # 将热力图转换为RGB格式
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 将热力图应用于原始图像
    superimposed_img = heatmap * 0.7 + img  # 这里的0.4是热力图强度因子
    cv2.imwrite(save_path, superimposed_img)  # 将图像保存到硬盘
 
 
if __name__ == '__main__':
    model = models.resnet50(pretrained=True)
    #model.eval()
    transform = transforms.Compose([
        transforms.Resize(448),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    # 构建模型并加载预训练参数
    #seresnet50 = FineTuneSEResnet50(num_class=113).cuda()
    #checkpoint = torch.load(resume_path)
    #seresnet50.load_state_dict(checkpoint['state_dict'])
    draw_CAM(model, single_img_path, save_path, transform=transform, visual_heatmap=True, out_layer=out_layer_name)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1540831.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

springboot做自定义校验注解

目录 自定义校验注解的实现 注意: 首先,我们需要自定义一个校验注解: 注解含义: Target({ElementType.FIELD}) Retention(RetentionPolicy.RUNTIME) Constraint(validatedBy PhoneValidator.class) 校验注解逻辑实现类&a…

数据结构:图的最短路径

目录 一、最短路径的基本概念 二、无权图单源最短路径 三、Dijkstra算法(正权图单源) 3.1、算法的基本步骤 3.2、算法的实现 3.3、习题思考 3.3.1、网络延迟时间 四、A*算法(正权图单源单目标点) 4.1、算法的基本概念 4…

阿里必问:Spring源码背后的10大设计奥秘!

如有疑问或者更多的技术分享,欢迎关注我的微信公众号“知其然亦知其所以然”! 各位小米粉丝们,大家好!今天小米要和大家分享的是一个备受关注的话题——“阿里巴巴面试题:Spring源码中的设计模式?”设计模式是软件工程领域中的经典话题,也是技术面试中的常见考点之一。而…

UE5学习日记——Rope Swing 人物与绳索摆动知识准备

rope swing荡绳 比我想的要复杂,目前还没查到简单的做法。本文为查资料的记录,积累后再做一个自己满意的荡绳蓝图。 一、某国外网友的解释 原文 https://forums.unrealengine.com/t/implementing-rope-swing/83098/15 Project Flake - Physics Rope De…

1+x中级题目练习复盘(八)

SQL 语句中进行 group by 分组时,可以不写 where 子句 在使用 select 语句进行查询分组时,如果希望去掉不满足条件的分组,使用 having 子句File 类的 isDirectory() 方法可以判断文件是否为目录 在使用 select 语句进行查询分组时&#xff0…

StarRocks学习笔记

介绍场景建表明细模型聚合模型更新模型主键模型 介绍 StarRocks是一款经过业界检验、现代化,面向多种数据分析场景的、兼容MySQL协议的、高性能分布式关系型分析数据库。 StarRocks充分吸收关系型 OLAP 数据库和分布式存储系统在大数据时代的优秀研究成果&#xff…

【数据结构】链表习题之链表的中间节点和合并两个有序链表

👑个人主页:啊Q闻 🎇收录专栏:《数据结构》 🎉道阻且长,行则将至 前言 嗨嗨,今天的博客是关于链表的题目,力扣题目之链表的中间节点和合并两个有序链表 一.链表的…

【MySql】1.mysql数据库

一、数据库的基本概念 1.数据 记录事物的信息;按统一的格式进行存储 2.表 数据的集合,行和列的组合;将多条数据组织在一起 3.数据库 表的集合,是存储 相互有关 数据的仓库 二、数据库管理系统 DBMS的主要功能: …

【Unity】UI九宫格

什么是九宫格? 顾名思义,九宫格就是指UI切成9个格子,9个格子可以任意拉伸。 1、3、7、9不拉伸。 2、8水平拉伸。 4、6垂直拉伸。 5既可以水平也可以垂直拉伸。 怎么切九宫格? 选中图片,改成Sprite模式,点…

使用React搭建single-spa

自己搭建的Demo GitHub - ftao123/single-spa-react-demo: single-spa-react-demo 修改子应用的webpack配置 library: "app2"和libraryTarget: "umd"配置必须添加。 可以看到filename在开发环境下的地址是static/js/bundle.js,所以我们主应用…

axure和蓝湖上查看页面的说明和上传文件

蓝湖上传文件 入口 可添加链接和文件 文件可添加 PDF,word,Excel等,不能添加压缩包,可在线预览文件内容 axure元件说明 在原型上添加说明 axure发布页 axure预览页或发布到axure的服务器上,查看页面说明的方法 点…

jmeter之并发和顺序执行与特殊线程组-第四天

1.jmeter的并发执行 并发执行:多个线程同时执行,不能确定谁先结束 以上案例中http请求里面没有写任何内容,只是为了看这个并发执行的效果 2.jmeter的顺序执行 顺序执行:多个线程顺序执行 再测试计划中勾选“独立运行每个线程组…

大舍传媒:纽约纳斯达克大屏引领企业多维曝光,挑战华尔街巨头,获得30%销售增长!

作为一名深耕华尔街的金融巨擘,我深知企业在如今竞争激烈的商业环境中,亟需寻找新的推广方式来获得曝光,并取得销售增长。而纳斯达克大屏,作为世界著名的电子交易市场,无疑是一个引领企业多维曝光的理想平台。在本篇干…

基于Matlab的血管图像增强算法,Matlab实现

博主简介: 专注、专一于Matlab图像处理学习、交流,matlab图像代码代做/项目合作可以联系(QQ:3249726188) 个人主页:Matlab_ImagePro-CSDN博客 原则:代码均由本人编写完成,非中介,提供…

C++学习之旅(二)运行四个小项目 (Ubuntu使用Vscode)

如果是c语言学的比较好的同学 可以直接跟着代码敲一遍&#xff0c;代码附有详细语法介绍&#xff0c;不可错过 一&#xff0c;猜数字游戏 #include <iostream> #include <cstdlib> #include <ctime>int main() {srand(static_cast<unsigned int>(tim…

常用类一(包装类)

目录 基本数据类型的包装类 包装类基本知识 包装类的用途 自动装箱和拆箱 自动装箱&#xff1a; 自动拆箱&#xff1a; 包装类的缓存问题 基本数据类型的包装类 八种基本数据类型并不是对象&#xff0c;为了将基本类型数据和对象之间实现互 相转化&#xff0c;JDK 为每一…

java 高级面试题(借鉴)(下)

雪花算法原理 第1位符号位固定为0&#xff0c;41位时间戳&#xff0c;10位workId&#xff0c;12位序列号&#xff0c;位数可以有不同实现。 优点&#xff1a;每个毫秒值包含的ID值很多&#xff0c;不够可以变动位数来增加&#xff0c;性能佳&#xff08;依赖workId的实现…

《自动机理论、语言和计算导论》阅读笔记:p1-p4

《自动机理论、语言和计算导论》学习第1天&#xff0c;p1-p4&#xff0c;总计4页。这只是个人的学习记录&#xff0c;因为很多东西不懂&#xff0c;难免存在理解错误的地方。 一、技术总结 1.有限自动机(finite automata)示例 1.software for checking digital circuits。 …

数据结构基础:一篇文章教你单链表(头插,尾插,查找,头删等的解析和代码)

和我一起学编程呀&#xff0c;大家一起努力&#xff01; 这篇文章耗时比较久&#xff0c;所以大家多多支持啦 链表的结构及结构 概念&#xff1a;链表是⼀种物理存储结构上非连续、非顺序的存储结构&#xff0c;数据元素的逻辑顺序是通过链表 中的指针链接次序实现的。 理解&a…

【HarmonyOS】ArkUI - 状态管理

在声明式 UI 中&#xff0c;是以状态驱动视图更新&#xff0c;如图1所示&#xff1a; 图1 其中核心的概念就是状态&#xff08;State&#xff09;和视图&#xff08;View&#xff09;&#xff1a; 状态&#xff08;State&#xff09;&#xff1a;指驱动视图更新的数据&#xf…