文章首发及后续更新:https://mwhls.top/4475.html,无图/无目录/格式错误/更多相关请至首发页查看。
新的更新内容请到mwhls.top查看。
欢迎提出任何疑问及批评,非常感谢!
摘要:绘制模型指定层的热力图
可视化环境安装
- 可用的环境版本:
- mmseg 1.0.0rc5
- mmdet 3.0.0rc6
- mmcv 2.0.0rc4
- mmengine 0.6.0
- 注:不要用在其它版本跑的文件覆盖它,我最开始一直没成功就是因为我想偷懒直接复制我的模型过去,但是模型调用了在原版本存在,但新版本不存在的方法,导致一直报错。
- 安装以上环境,参考该 issue 代码可正常推理,代码如下
- 还有其它 issue 也提到了 featmap,可以在 mmseg 的 GitHub 搜 cam 关键词,或者点这里。
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm
config_path = '../mmsegv2/configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py'
checkpoint_path = '../mmsegv2/checkpoints/pspnet_r50-d8_512x1024_80k_cityscapes_20200606_112131-2376f12b.pth'
img_path = '../mmsegv2/demo/demo.png'
register_all_modules()
model = init_model(config_path, checkpoint_path, device='cpu')
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()
ori_img = cv2.imread(img_path)
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)
logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)
cv2.imshow('cam', out)
cv2.waitKey(0)
指定位置可视化
- 修改后的可视化代码 Startup.py
# Thank xiexinch: https://github.com/open-mmlab/mmsegmentation/issues/2434#issuecomment-1441392574
import torch
import cv2
import numpy as np
from mmseg.visualization import SegLocalVisualizer
from mmseg.apis import init_model
from mmseg.utils import register_all_modules
from mmengine.model import revert_sync_batchnorm
# prefix = "mmsegmentation-1.0.0rc5/"
prefix = ""
config = prefix + r"log\7_ttpla_p2t_t_20k\ttpla_p2t_t_20k.py"
checkpoint = prefix + r"log\7_ttpla_p2t_t_20k\iter_8000.pth"
config = prefix + r"log\9_ttpla_r50_20k\ttpla_r50_20k.py"
checkpoint = prefix + r"log\9_ttpla_r50_20k\iter_8000.pth"
img_path = prefix + r"img.png"
def draw_heatmap(featmap):
vis = SegLocalVisualizer()
ori_img = cv2.imread(img_path)
out = vis.draw_featmap(featmap, ori_img)
cv2.imshow('cam', out)
cv2.waitKey(0)
def generate_featmap(config, checkpoint, img_path):
register_all_modules()
model = init_model(config, checkpoint, device='cpu')
model = revert_sync_batchnorm(model)
vis = SegLocalVisualizer()
ori_img = cv2.imread(img_path)
img = torch.from_numpy(ori_img.astype(np.single)).permute(2, 0, 1).unsqueeze(0)
logits = model(img)
out = vis.draw_featmap(logits[0], ori_img)
cv2.imshow('cam', out)
cv2.waitKey(0)
if __name__ == "__main__":
generate_featmap(config, checkpoint, img_path)
- 如下,在模型内调用
draw_heatmap()
from Startup import draw_heatmap
draw_heatmap(x[0])
def forward(self, x):
"""Forward function."""
from Startup import draw_heatmap
draw_heatmap(x[0])
if self.deep_stem:
x = self.stem(x)
else:
x = self.conv1(x)
x = self.norm1(x)
x = self.relu(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.res_layers):
res_layer = getattr(self, layer_name)
x = res_layer(x)
if i in self.out_indices:
outs.append(x)
from Startup import draw_heatmap
draw_heatmap(x[0])
return tuple(outs)