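[Deep Learning Visualization Series] Feature Map Visualization (supports feature-map visualization for ViT-family models, and includes saving the visualization results with TensorBoard)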
import sys
import os
import torch
import cv2
import timm
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model.MitUnet import MitUnet
from collections import OrderedDict
from typing import Dict, Iterable, Callable
from torch import nn, Tensor
from PIL import Image
from pprint import pprint
# --------------------------------------------------------------------------------------------------------------------------
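# Build the feature-map extractor. It takes the model and the key names of the layers whose feature maps should be
# extracted; these names can be obtained from model.named_modules() or model.named_children()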
# --------------------------------------------------------------------------------------------------------------------------
class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialize the iterable so it can be traversed more than once
        self.layers = list(layers)
        self._features = OrderedDict({layer: torch.empty(0) for layer in self.layers})
        # Keep every hook handle so that all of them can be removed later
        self.hooks = []
        modules = dict(self.model.named_modules())
        for layer_id in self.layers:
            layer = modules[layer_id]
            self.hooks.append(layer.register_forward_hook(self.hook_func(layer_id)))

    def hook_func(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            # ViT-style layers output tokens of shape (B, N, C); reshape them to (B, C, H, W)
            if output.dim() == 3:
                output = self.reshape_transform(in_tensor=output)
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        # The hooks are removed after the first forward pass, so the extractor is single-use
        self.remove()
        return self._features

    def remove(self):
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def reshape_transform(self, in_tensor):
        # Assumes N is a perfect square (no class token): (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
        side = int(np.sqrt(in_tensor.size(1)))
        result = in_tensor.reshape(in_tensor.size(0), side, side, in_tensor.size(2))
        result = result.transpose(2, 3).transpose(1, 2)
        return result
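The `reshape_transform` method above only works when the number of tokens is a perfect square. Many ViT variants prepend a class token, which breaks that assumption. The following is a minimal sketch of one way to handle this case; `tokens_to_feature_map` and its `num_prefix_tokens` parameter are hypothetical names introduced here for illustration and are not part of the original code.

```python
import math
import torch


def tokens_to_feature_map(tokens: torch.Tensor, num_prefix_tokens: int = 0) -> torch.Tensor:
    """Reshape ViT tokens (B, N, C) into a feature map (B, C, H, W).

    num_prefix_tokens (hypothetical parameter) is the number of leading
    non-patch tokens, for example a class token, that are dropped first.
    """
    patch_tokens = tokens[:, num_prefix_tokens:, :]
    batch, num_patches, channels = patch_tokens.shape
    side = math.isqrt(num_patches)
    if side * side != num_patches:
        raise ValueError(f"{num_patches} patch tokens do not form a square grid")
    # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
    return patch_tokens.reshape(batch, side, side, channels).permute(0, 3, 1, 2)


# 1 class token + 14 * 14 patch tokens, 8 channels per token
tokens = torch.arange(2 * 197 * 8, dtype=torch.float32).reshape(2, 197, 8)
fmap = tokens_to_feature_map(tokens, num_prefix_tokens=1)
print(fmap.shape)  # torch.Size([2, 8, 14, 14])
```

Raising an error instead of silently truncating with `int(np.sqrt(...))` makes a wrong token count visible immediately, at the cost of requiring the caller to know how many prefix tokens the model uses.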
# --------------------------------------------------------------------------------------------------------------------------
# Build the model and extract the features
# --------------------------------------------------------------------------------------------------------------------------
img_mask_size = 256
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = UNet(....)
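# map_location lets weights that were saved on a GPU load on a CPU-only machine
state_dict = torch.load('./state_dict/model.pth', map_location=device)
model.load_state_dict(state_dict['model'])
print('Network ready: trained weights loaded successfully.')
model.to(device=device)
# Inference mode: BatchNorm uses its running statistics and Dropout is disabled
model.eval()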
transformer = A.Compose([
    A.Resize(img_mask_size, img_mask_size),
    A.Normalize(
        mean=(0.5835, 0.5820, 0.5841),
        std=(0.1149, 0.1111, 0.1064),
        max_pixel_value=255.0
    ),
    ToTensorV2()
])
return_layers = ["encoder.norm1"]
e_model = FeatureExtractor(model=model, layers=return_layers)
image_file = "./images"
image_file_path = os.path.join(image_file, "15.jpg")
# Force 3 channels so the image matches the 3-channel Normalize statistics
img = Image.open(image_file_path).convert("RGB")
img_width, img_height = img.size
image_np = np.array(img)
augmented = transformer(image=image_np)
augmented_img = augmented['image'].to(device)
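# The model contains BN layers, which in training mode reject a batch that yields a single value per channel,
# so a random dummy image of the same size is stacked with the real one to get a batch size of 2.
# With the model in eval mode this padding is not required; it is harmless because only sample 0 is used below.
virtual_image = torch.randn(size=(3, img_mask_size, img_mask_size), dtype=torch.float32).to(device=device)
augmented_img = torch.stack([augmented_img, virtual_image], dim=0)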
print(augmented_img.shape)
output = e_model(augmented_img)
# Keep only the real image (index 0) and restore the batch dimension
for keys, values in output.items():
    output[keys] = values[0].unsqueeze(0)
pprint({keys : torch.sigmoid(values[0]).detach().shape for keys, values in output.items()})
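The names passed in `return_layers` must match keys produced by `model.named_modules()`. The sketch below shows, on a small stand-in network, how those names can be listed and how a single forward hook captures one layer's output. The `toy` model and the `save_output` function are hypothetical and exist only for this illustration; they are not the segmentation model used above.

```python
import torch
from torch import nn

# Toy stand-in model (hypothetical); replace it with your own network
toy = nn.Sequential()
toy.add_module("stem", nn.Conv2d(3, 4, kernel_size=3, padding=1))
toy.add_module("bn", nn.BatchNorm2d(4))
toy.add_module("act", nn.ReLU())
toy.add_module("head", nn.Conv2d(4, 2, kernel_size=1))

# named_modules() yields (name, module) pairs; the root module has the empty name ""
layer_names = [name for name, _ in toy.named_modules() if name]
print(layer_names)  # ['stem', 'bn', 'act', 'head']

captured = {}


def save_output(module, inputs, output):
    # Store a detached copy of the layer output
    captured["bn"] = output.detach()


handle = dict(toy.named_modules())["bn"].register_forward_hook(save_output)
toy.eval()  # BatchNorm uses running statistics, so a batch of one image is fine
with torch.no_grad():
    toy(torch.randn(1, 3, 16, 16))
handle.remove()
print(captured["bn"].shape)  # torch.Size([1, 4, 16, 16])
```

For nested models the names are dotted paths such as `encoder.norm1`, which is why the lookup goes through `named_modules()` rather than attribute access.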
# --------------------------------------------------------------------------------------------------------------------------
# Save the feature-map visualizations with TensorBoard
# --------------------------------------------------------------------------------------------------------------------------
from torchvision.utils import make_grid
from torch.utils.tensorboard.writer import SummaryWriter
writer = SummaryWriter("runs/test")
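for keys, values in output.items():
    # values: (1, C, H, W) -> (C, H, W), squashed to [0, 1] with a sigmoid
    values = torch.sigmoid(values[0]).cpu().detach().numpy()
    # One pseudo-color image per channel, stored as uint8 RGB in NCHW layout
    imgs_ = np.empty(shape=(values.shape[0], 3, values.shape[1], values.shape[2]), dtype=np.uint8)
    for index, batch_img in enumerate(values):
        color = cv2.applyColorMap(np.uint8(batch_img * 255), cv2.COLORMAP_JET)  # BGR, HWC
        imgs_[index] = cv2.cvtColor(color, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    imgs_grid = make_grid(torch.from_numpy(imgs_), nrow=5, padding=2, pad_value=0)
    # OpenCV displays BGR, so flip the channel order back for the preview window
    grid_rgb = np.ascontiguousarray(imgs_grid.permute(1, 2, 0).numpy())
    cv2.namedWindow("imgs_grid", cv2.WINDOW_FULLSCREEN)
    cv2.imshow("imgs_grid", cv2.cvtColor(grid_rgb, cv2.COLOR_RGB2BGR))
    cv2.waitKey()
    cv2.destroyAllWindows()
    # uint8 input is written as-is; float input would be treated as [0, 1] and rescaled
    writer.add_images(keys + "_TEST", imgs_, 0, dataformats="NCHW")
writer.close()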
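The loop above squashes the activations with a sigmoid before applying the color map. When activations sit in a narrow range, the resulting maps can look nearly uniform. One possible alternative, which is not what the original code does, is to stretch each channel linearly to the full range. The sketch below uses only NumPy; `to_uint8_minmax` and its `eps` parameter are hypothetical names introduced for this example.

```python
import numpy as np


def to_uint8_minmax(feature_map: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale one 2-D feature map linearly to the range [0, 255]."""
    fmin = float(feature_map.min())
    fmax = float(feature_map.max())
    # eps avoids a division by zero when the map is constant
    scaled = (feature_map - fmin) / (fmax - fmin + eps)
    return np.round(scaled * 255).astype(np.uint8)


channel = np.array([[0.0, 1.0], [2.0, 5.0]], dtype=np.float32)
gray = to_uint8_minmax(channel)
flat = to_uint8_minmax(np.zeros((2, 2), dtype=np.float32))
```

The returned array can be passed to `cv2.applyColorMap` in place of `np.uint8(batch_img * 255)`. Note that min-max scaling makes colors comparable only within one channel, not across channels.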
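The visualization results are shown below (using ground-fissure images as an example):
Ground-fissure image and its segmentation result
Visualization of some of the feature maps from the fissure-extraction model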