【深度学习可视化系列】—— 特征图可视化(支持Vit系列模型的特征图可视化,包含使用Tensorboard对可视化结果进行保存)

news2024/11/26 18:33:42

【深度学习可视化系列】—— 特征图可视化(支持Vit系列模型的特征图可视化,包含使用Tensorboard对可视化结果进行保存)

import sys
import os
import torch
import cv2
import timm
import numpy as np 
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from model.MitUnet import  MitUnet
from collections import OrderedDict
from typing import Dict, Iterable, Callable
from torch import nn, Tensor
from PIL import Image
from pprint import pprint


# --------------------------------------------------------------------------------------------------------------------------
# 构建模型特征图提取模型,输入参数为模型、以及需提取特征图层的key名称,该名称可通过model.named_modules()或model.named_children()获取
# --------------------------------------------------------------------------------------------------------------------------
class FeatureExtractor(nn.Module):
    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        # assert layers is not None
        self.model = model
        self.layers = layers
        self._features = OrderedDict({layer: torch.empty(0) for layer in layers})
        self.hook = []

        for layer_id in layers:
            layer = dict([*self.model.named_modules()])[layer_id]
            self.hook = layer.register_forward_hook(self.hook_func(layer_id))
            # self.hook.append(self.layer_id)

    def hook_func(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            # print("_____{}".format(output.dim()))   
            if output.dim() == 3:
                output = self.reshape_transform(in_tensor=output) 
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)
        self.remove()
        return self._features
    
    def remove(self):
        # for hook in self.hook:
        self.hook.remove()

    def reshape_transform(self, in_tensor):
        result = in_tensor.reshape(in_tensor.size(0),
            int(np.sqrt(in_tensor.size(1))), int(np.sqrt(in_tensor.size(1))), in_tensor.size(2))

        result = result.transpose(2, 3).transpose(1, 2)
        return result
    
    
# --------------------------------------------------------------------------------------------------------------------------
# 构建模型,并进行特征提取
# --------------------------------------------------------------------------------------------------------------------------
img_mask_size = 256
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
model = UNet(....)
# map_location={'cuda:0': 'cpu'}
state_dict = torch.load('./state_dict/model.pth')
model.load_state_dict(state_dict['model'])
print('网络设置完毕 :成功载入了训练完毕的权重。')
model.to(device=device)
transformer = A.Compose([
    A.Resize(img_mask_size, img_mask_size),
    A.Normalize(
        mean=(0.5835, 0.5820, 0.5841),
        std=(0.1149, 0.1111, 0.1064),
        max_pixel_value=255.0
    ),
    ToTensorV2()
])
return_layers = ["encoder.norm1"]
e_model = FeatureExtractor(model=model, layers=return_layers)
image_file = ".\images"
image_file_path = os.path.join(image_file, str("15") + (".jpg"))
img = Image.open(image_file_path)
img_width, img_height = img.size
image_np = np.array(img)
augmented = transformer(image=image_np)
augmented_img = augmented['image'].to(device)  
# 由于模型中存在BN层,其不允许推理的batchsize小于2,所以生成一个和原始影像相同大小尺度的虚拟图像使得batchsize=2。
virual_image = torch.randn(size=(3, img_mask_size, img_mask_size), dtype=torch.float32).to(device=device)
augmented_img = torch.stack([augmented_img, virual_image], dim=0)
print(augmented_img.shape)
output = e_model(augmented_img)
for keys, values in output.items():
    output[keys] = values[0].unsqueeze(0) 
pprint({keys : torch.sigmoid(values[0]).detach().shape for keys, values in output.items()})


# --------------------------------------------------------------------------------------------------------------------------
# 使用tensorboard保存特征图可视化结果
# --------------------------------------------------------------------------------------------------------------------------
from torchvision.utils import make_grid
from torch.utils.tensorboard.writer import SummaryWriter

writer = SummaryWriter("runs/test")
for keys, values in output.items():
    values = torch.sigmoid(values[0]).cpu().detach().numpy()
    imgs_ = np.empty(shape=(values.shape[0], 3, values.shape[1], values.shape[2])) 
    for index, batch_img in enumerate(values):
        imgs_[index] =  cv2.applyColorMap(np.uint8(batch_img * 255), cv2.COLORMAP_JET).transpose(2, 0, 1)
    imgs_grid = make_grid(torch.from_numpy(imgs_), nrow=5, padding=2, pad_value=0)
    cv2.namedWindow("imgs_grid", cv2.WINDOW_FULLSCREEN)
    cv2.imshow("imgs_grid", imgs_grid.permute(1, 2, 0).numpy())
    cv2.waitKey()
	cv2.destroyAllWindows()
    
    writer.add_images(keys + "_TEST", imgs_, 0, dataformats="NCHW")
writer.close()

可视化结果如下(以地表裂缝图像为例):
请添加图片描述
​ 地裂缝图像以及分割结果
请添加图片描述

​ 裂缝提取模型部分特征图可视化结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/847825.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ubuntu 20.04 安装 Stable Diffusionn

步骤 1:安装 wget、git、Python3 和 Python3虚拟环境(如果已安装可忽略这步骤) sudo apt install wget git python3 python3-venv步骤 2:克隆 SD 项目到本地 git clone https://github.com/AUTOMATIC1111/stable-diffusion-webu…

亚信科技AntDB数据库与库瀚存储方案完成兼容性互认证,联合方案带来约20%性能提升

近日,亚信科技AntDB数据库与苏州库瀚信息科技有限公司自主研发的RISC-V数据库存储解决方案进行了产品兼容测试。经过双方团队的严格测试,亚信科技AntDB数据库与库瀚数据库存储解决方案完全兼容、运行稳定。除高可用性测试外,双方进一步开展TP…

nacos升级开启鉴权后,微服务无法连接的解决方案

版本: 软件版本号备注spring boot2.2.5.RELEASEspring-cloudHoxton.SR3spring-cloud-alibaba2.2.1.RELEASEnacos2.0.1从1.4.2版本进行升级。同时作为注册中心和配置中心 一、升级nacos版本,开启鉴权 1.在application.properties配置文件开启鉴权&…

【资料分享】全志科技T507-H开发板规格书

1 评估板简介 创龙科技TLT507-EVM是一款基于全志科技T507-H处理器设计的4核ARM Cortex-A53国产工业评估板,主频高达1.416GHz,由核心板和评估底板组成。核心板CPU、ROM、RAM、电源、晶振等所有器件均采用国产工业级方案,国产化率100%。同时,评估底板大部分元器件亦采用国产…

sentinel---滑动窗口的实现原理

sentinel有多种规则,包括:降级、限流、热点等等规则,这些规则均会涉及到时间因素,既在单位时间内的请求量满足各种条件之后的各种动作。 这里我们一起来探针一下sentinel中滑动窗口的实现 如上是一个滑动窗口的示意图。 这里先不…

电脑IP地址错误无法上网怎么办?

电脑出现IP地址错误后就将无法连接网络,从而无法正常访问互联网。那么当电脑出现IP地址错误时该怎么办呢? 确认是否禁用本地连接 你需要先确定是否禁用了本地网络连接,如果发现禁用,则将其启用即可。 启用方法:点击桌…

设计实现数据库表扩展的7种方式

设计实现数据库表扩展的7种方式 在软件开发过程中,数据库是一项关键技术,用于存储、管理和检索数据。数据库表设计是构建健壮数据库系统的核心环节之一。然而,随着业务需求的不断演变和扩展,数据库表中的字段扩展变得至关重要。 …

【TensorFlow】P0 Windows GPU 安装 TensorFlow、CUDA Toolkit、cuDNN

Windows 安装 TensorFlow、CUDA Toolkit、cuDNN 整体流程概述TensorFlow 与 CUDA ToolkitTensorFlow 是一个基于数据流图的深度学习框架CUDA 充分利用 NIVIDIA GPU 的计算能力CUDA Toolkit cuDNN 安装详细流程整理流程一:安装 CUDA Toolkit步骤一:获取CU…

GIS和倾斜摄影的关系?

GIS(地理信息系统)和倾斜摄影是两种在地理空间数据处理和分析中扮演重要角色的技术。但是我们总是会分不清二者,本文就带大家从不同角度了解二者之间的关系。 概念 GIS是一种用来捕获、存储、分析和展示地理空间数据的技术,它可以…

Java课题笔记~ Spring 集成 MyBatis

Spring 集成 MyBatis 将 MyBatis 与 Spring 进行整合,主要解决的问题就是将 SqlSessionFactory 对象交由 Spring 来管理。所以该整合,只需要将 SqlSessionFactory 的对象生成器SqlSessionFactoryBean 注册在 Spring 容器中,再将其注入给 Dao…

Apollo让自动驾驶如此简单

前言: 最近被新能源的电价闹的不行,买了电车的直呼上当了、不香了。但电车吸引人不只是公里油耗低,还有良好的驾车使用感。比如辅助驾驶、甚至是自动驾驶。今天来介绍一个头部自动驾驶平台Apollo,Apollo是一个开源的、自动驾驶的软…

【Hystrix技术指南】(4)故障切换的运作流程

[每日一句] 也许你度过了很糟糕的一天,但这并不代表你会因此度过糟糕的一生。 [背景介绍] 分布式系统的规模和复杂度不断增加,随着而来的是对分布式系统可用性的要求越来越高。在各种高可用设计模式中,【熔断、隔离、降级、限流】是经常被使…

iperf3-性能测试

iperf3-性能测试 安装1.apt安装2.源码安装 使用方法iperf原理测试参考文档性能测试客户端服务端 官方文档:https://iperf.fr/iperf-doc.php 安装 1.apt安装 sudo apt-get install iperf32.源码安装 # 按照官方说明安装 ./configure make sudo make install执行编…

OceanBase 4.1.0 clog 目录探究

基于OceanBase 4.x 版本如何统计租户每日 clog 日志生成量的背景下,探究以及如何查看租户 clog 的使用情况。 作者:姜宇 爱可生 DBA 团队成员,擅长数据库故障排查和处理。对技术抱有热忱,实践是检验真理的唯一标准~ 本文来源&…

对docker的简单理解

一款产品从开发到上线,从操作系统,到运行环境,再到应用配置。作为开发运维之间的协作,我们需要关心很多东西,这也是很多互联网公司都不得不面对的问题,特别是各种版本的迭代之后,不同版本环境的…

MySQL安装和卸载

1.MySQL概述 MySQL概述 MySQL是一个[关系型数据库管理系统],由瑞典MySQL AB 公司开发,2008年被sun公司收购, 2009sun又被oracle收购,所以属于 Oracle 旗下产品。MySQL 是最流行的关系型数据库管理系统之一,在 WEB 应用…

springboot(4)

AOP 1.AOP与OOP OOP(Object Oriented Programming,面向对象编程) AOP(Aspect Oriented Programming,面向切面编程) POP(Process Oriented Programming,面向过程编程) …

扩展卡尔曼滤波器代码

文章目录 前言问题状态向量和观测向量加性噪声的形式状态方程及求导观测方程及求导状态初始化过程噪声和观测噪声卡尔曼滤波过程code 前言 卡尔曼滤波器在1960年被卡尔曼发明之后,被广泛应用在动态系统预测。在自动驾驶、机器人、AR领域等应用广泛。卡尔曼滤波器使…

在R中比较两个矩阵是否相等

目录 方法一:使用all.equal()比较两个R对象是否近似相等 方法二:使用identical比较两个R对象是否精确相等。 方法一:使用all.equal()比较两个R对象是否近似相等 使用函数:all.equal(x,y) 比较两个R对象x和y是否近似相等 > M1…

抽象工厂模式(C++)

定义 提供一个接口,让该接口负责创建一系列“相关或者相互依赖的对象”,无需指定它们具体的类。 使用场景 在软件系统中,经常面临着“一系列相互依赖的对象”的创建工作;同时,由于需求的变化,往往存在更多系列对象的创建工作。如何应对这种…