pytorch生成CAM热力图-单张图像

news2024/9/23 0:40:20

利用ImageNet预训练模型生成CAM热力图-单张图像

  • 一、环境搭建
  • 二、主要代码
  • 三、结果展示

代码和图片等资源均来源于哔哩哔哩up主:同济子豪兄
讲解视频:CAM可解释性分析-算法讲解

一、环境搭建

1,安装所需的包

pip install numpy pandas matplotlib requests tqdm opencv-python pillow -i https://pypi.tuna.tsinghua.edu.cn/simple

2,安装 Pytorch

pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

3,安装 mmcv-full

pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.10.0/index.html

4,下载中文字体文件(用于显示和打印汉字文字)

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/SimHei.ttf

5,下载 ImageNet 1000类别信息

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/dataset/meta_data/imagenet_class_index.csv

6,创建 test_img 文件夹,并下载测试图像到该文件夹

import os
os.mkdir('test_img')

wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/border-collie.jpg -P test_img
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/cat_dog.jpg -P test_img
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220716-mmclassification/test/0818/room_video.mp4 -P test_img

7,下载安装 torchcam

git clone https://github.com/frgfm/torch-cam.git
pip install -e torch-cam/.

二、主要代码

from PIL import Image

import torch
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)

# 导入ImageNet预训练模型
from torchvision.models import resnet18
model = resnet18(pretrained=True).eval().to(device)

# 导入自己训练的模型
# model = torch.load('自己训练的模型.pth')
# model = model.eval().to(device)

# 可解释性分析方法有:CAM GradCAM GradCAMpp ISCAM LayerCAM SSCAM ScoreCAM SmoothGradCAMpp XGradCAM

# 方法一:导入可解释性分析方法SmoothGradCAMpp
# from torchcam.methods import SmoothGradCAMpp 
# cam_extractor = SmoothGradCAMpp(model)

# 方法二:导入可解释性分析方法GradCAM
from torchcam.methods import GradCAM
target_layer = model.layer4[-1]    # 选择目标层
cam_extractor = GradCAM(model, target_layer)

# 图片预处理
from torchvision import transforms
# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
                                    ])

# 图片分类预测
img_path = 'test_img/border-collie.jpg'
img_pil = Image.open(img_path)
input_tensor = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
pred_logits = model(input_tensor)
# topk()方法用于返回输入数据中特定维度上的前k个最大的元素
pred_top1 = torch.topk(pred_logits, 1)
# pred_id 为图片所属分类对应的索引号,分类和索引号存储在imagenet_class_index.csv
pred_id = pred_top1[1].detach().cpu().numpy().squeeze().item()

# 生成可解释性分析热力图
activation_map = cam_extractor(pred_id, pred_logits)
activation_map = activation_map[0][0].detach().cpu().numpy()

# 可视化
from torchcam.utils import overlay_mask

# overlay_mask 用于构建透明的叠加层
# fromarray 实现array到image的转换
result = overlay_mask(img_pil, Image.fromarray(activation_map), alpha=0.7)

# 为图片添加中文类别显示

# 载入ImageNet 1000 类别中文释义
import pandas as pd
df = pd.read_csv('imagenet_class_index.csv')
idx_to_labels = {}
idx_to_labels_cn = {}
for idx, row in df.iterrows():
    idx_to_labels[row['ID']] = row['class']
    idx_to_labels_cn[row['ID']] = row['Chinese']

# 显示所有中文类别
# idx_to_labels_cn

# 可视化热力图的类别ID,如果为 None,则为置信度最高的预测类别ID
# show_class_id = 231		# 例如 牧羊犬:231 虎猫:281
show_class_id = None

# 可视化热力图的类别ID,如果不指定,则为置信度最高的预测类别ID
if show_class_id:
    show_id = show_class_id
else:
    show_id = pred_id
    show_class_id = pred_id

# 是否显示中文类别
Chinese = True
# Chinese = False

from PIL import ImageDraw
# 在图像上写字
draw = ImageDraw.Draw(result)

if Chinese:
    # 在图像上写中文
    text_pred = 'Pred Class: {}'.format(idx_to_labels_cn[pred_id])
    text_show = 'Show Class: {}'.format(idx_to_labels_cn[show_class_id])
else:
    # 在图像上写英文
    text_pred = 'Pred Class: {}'.format(idx_to_labels[pred_id])
    text_show = 'Show Class: {}'.format(idx_to_labels[show_class_id])

from PIL import ImageFont, ImageDraw
# 导入中文字体,指定字体大小
font = ImageFont.truetype('SimHei.ttf', 30)

# 文字坐标,中文字符串,字体,rgba颜色
draw.text((10, 10), text_pred, font=font, fill=(255, 0, 0, 1))
draw.text((10, 50), text_show, font=font, fill=(255, 0, 0, 1))

#输出结果图
result

注意:

  1. 可解释性方法的选择有多种,代码中提供了 SmoothGradCAMpp 和 GradCAM 两种方法;
  2. 模型选择也有pytorch预训练模型和自己训练的模型两种,代码中演示了 ImageNet图像分类 模型,图片类别文件为 imagenet_class_index.csv;若为自己的模型则还需要修改 “为图片载入类别的部分代码”

三、结果展示

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1013281.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EATON XV-440-10TVB-1-13-1工业显示屏模块

EATON XV-440-10TVB-1-13-1工业显示屏模块是一款功能强大的工业显示屏,具有以下特点和功能: 1. 显示屏尺寸为10.4英寸,分辨率为800600。 2. 采用TFT液晶显示技术,显示效果清晰,色彩鲜艳。 3. 支持多种显示模式&…

ESP32-BOX的组件配置添加核心部分详细介绍

前言 (1)为了方便开发,ESP32提供了组件库方便用户进行二次开发。 github仓库;gitee仓库 (2)在学习本章之前最好有CMake或者Makefile的基础,如果没有也不要慌,有的话最好。 &#xff…

1-FPGA硬件加速-YUV_YCbCr

这是对《基于Matlab与FPGA的图像处理教程》的学习笔记,代码和内容摘取自书中。 心得: 使用FPGA进行硬件加速的重点是消除或者减少浮点数运算,转换为定点运算,然后通过pipeline流水设计转为并行实现加速。 原理和方法 RGB与&…

微信小程序——常用组件的属性介绍

常用的组件内容标签 text 文本组件类似于HTML中的span标签,是一个行内元素rich-text 富文本标签支持把HTML字符串渲染为WXML结构 text标签的基本使用 通过text组件的selectable属性,实现长按选中文本内容的效果。只有text标签支持长按选中效果&#x…

爬虫代理在数据采集中的应用详解

随着互联网技术的不断发展,数据采集已经成为了各个行业中必不可少的一项工作。在数据采集的过程中,爬虫代理的应用越来越受到了重视。本文将详细介绍爬虫代理在数据采集中的应用。 什么是爬虫代理? 爬虫代理是指利用代理服务器来隐藏真实的IP…

string的使用和模拟实现

💓博主个人主页:不是笨小孩👀 ⏩专栏分类:数据结构与算法👀 C👀 刷题专栏👀 C语言👀 🚚代码仓库:笨小孩的代码库👀 ⏩社区:不是笨小孩👀 🌹欢迎大…

Pytest系列-使用自定义标记mark(6)

简介 pytest 可以支持自定义标记,自定义标记可以把一个 web 项目划分为多个模块,然后指定模块名称执行 Pytest 里面自定义标记 用法:将pytest.mark.标记名称 放到测试函数或者类上面 使用: 执行时加上 -m 标记名 进行用例筛选…

[交互]交互的实战问题1

[交互]交互的实战问题1 状态码 431 Request Header Fields Too LargeReferrer Policy: no-referrer-when-downgrade路径参数高并发问题使用场景使用的方法异常情况 状态码 431 Request Header Fields Too Large 最近做项目,遇到一个问题,后台导出表格时…

牛客: BM4 合并两个排序的链表

牛客: BM4 合并两个排序的链表 文章目录 牛客: BM4 合并两个排序的链表题目描述题解思路题解代码 题目描述 题解思路 以链表一为主链表,遍历两条链表 若当前链表二的节点val小于当前链表一的下一个节点val,则将链表链表二的该节点连到链表一的节点的下一个,链表一的当前节点往…

sql存储引擎

-- 查询建表语句 --可以查看引擎 show create table account; -- 可以看到默认引擎 InnoDB ENGINEInnoDB -- 查看当前数据库支持得存储引擎 show engines ; # InnoDB 默认 存储引擎 # MyISAM sql早期默认 存储引擎 # MEMORY 存储在内存中 用来做临时表和缓存 存储引擎 …

Adobe Acrobat Reader 中的漏洞

另一个流行漏洞 Adobe Acrobat 和 Acrobat Reader - 流行的便携式文档格式 (PDF) 工具 - 存在风险。该漏洞 CVE-2023-26369影响 Windows 和 macOS 安装。 攻击者创建的恶意 PDF 文档打开后,会利用与在缓冲区外写入有关的 CVE-2023-26369漏洞。因此,攻击…

数据中心液冷服务器详情说明

目录 前言 何为液冷服务器? 为什么需要液冷? 1.数据中心降低PUE的需求 2.政策导向 3.芯片热功率已经达到风冷散热极限 4.液冷比热远大于空气 液冷VS风冷,区别在哪? 1.液冷服务器跟风冷服务器的区别 2.液冷数据中心跟风冷…

linux安装常见的中间件和数据库

文章目录 一、数据库二、redis三、tomcat四、nginx五、mq六、es七、nacos八、neo4j(图数据库)九、fastdfs其他 一、数据库 linux环境上使用压缩包安装mysql【数据库】Mysql 创建用户与授权 二、redis redis是没有账号的,只能设置密码Linux…

对IP协议概念以及IP地址的概念进行简单整理

网络层重要协议 参考模型和协议栈IP协议IPv4数据报IP数据报格式IPv4地址特殊IP地址私有IP地址和公有IP地址子网划分 参考模型和协议栈 IP协议 IP协议定义了网络层数据传送的基本单元,也制定了一系列关于网络层的规则。 IPv4数据报 网络层的协议数据单元PDU 叫做分…

GeoSOS-FLUS未来土地利用变化情景模拟模型

软件简介 适用场景 GeoSOS-FLUS软件能较好的应用于土地利用变化模拟与未来土地利用情景 的预测和分析中,是进行地理空间模拟、参与空间优化、辅助决策制定的有效工 具。FLUS 模型可直接用于: 城市发展模拟及城市增长边界划定;城市内 部高分…

分布式事务解决方案之TCC

分布式事务解决方案之TCC 什么是TCC事务 TCC是Try、Confirm、Cancel三个词语的缩写,TCC要求每个分支事务实现三个操作:预处理Try、确认 Confirm、撤销Cancel。Try操作做业务检查及资源预留,Confirm做业务确认操作,Cancel实现一个…

Golang代码漏洞扫描工具介绍——govulncheck

Golang Golang作为一款近年来最火热的服务端语言之一,深受广大程序员的喜爱,笔者最近也在用,特别是高并发的场景下,golang易用性的优势十分明显,但笔者这次想要介绍的并不是golang本身,而且golang代码的漏洞…

微信小程序+echart实现点亮旅游地图

背景 最近看抖音有个很火的特效就是点亮地图,去过哪些地方,于是乎自己也想做一个,结合自己之前做的以家庭为单位的小程序,可以考虑做一个家庭一起点亮地图的功能。 效果图 过程 1,首先就是得去下微信小程序适配的ec…

react 实现拖动元素

demo使用create-react-app脚手架创建 删除一些文件,创建一些文件后 结构目录如下截图com/index import Movable from ./move import { useMove } from ./move.hook import * as Operations from ./move.opMovable.useMove useMove Movable.Operations Operationse…

ABB 1TGE120010R... Rev控制模块

ABB 1TGE120010R... Rev 控制器模块是一种高性能控制器,可用于工业自动化和过程控制应用。它具有以下主要特点: 多功能性:该控制器模块可用于多种应用,包括机器控制、过程控制和自动化系统等。 高性能:该控制器模块具…