Pytorch可视化特征图(代码 亲测可用)

news2024/10/6 12:21:43

2013年Zeiler和Fergus发表的《Visualizing and Understanding Convolutional Networks》

 

 

早期LeCun 1998年的文章《Gradient-Based Learning Applied to Document Recognition》中的一张图也非常精彩,个人觉得比Zeiler 2013年的文章更能给人以启发。从下图的F6特征,我们可以清楚地看到原始书写差异非常大的图片如何在深层特征中体现出其不变性。
 

pytorch 有专门的接口提取特征图

拿到特征图的方法有多种,有人可以从输入开始,一个一个算子地让网络做前向运算,直到想要的特征图处将其返回,这种方法尽管也可行,但略有些麻烦。实际上PyTorch给了一个专用接口可以在前向过程中获取到特征图,这个接口是torch.nn.Module.register_forward_hook。当我们拿到特征图后,PyTorch又有专门的画图和保存图片的接口:torchvision.utils.make_grid和torchvision.utils.save_image,非常方便。
 

展示代码,复制粘贴就可以使用

# -*- coding: utf-8 -*-
import os
import shutil
import time
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import datasets
import torchvision.utils as vutil
from torch.utils.data import DataLoader
import torchsummary

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCH = 1
LR = 0.001
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 32
BASE_CHANNEL = 32
INPUT_CHANNEL = 1
INPUT_SIZE = 28
MODEL_FOLDER = './save_model'
IMAGE_FOLDER = './save_image'
INSTANCE_FOLDER = None


class Model(nn.Module):
    def __init__(self, input_ch, num_classes, base_ch):
        super(Model, self).__init__()

        self.num_classes = num_classes
        self.base_ch = base_ch
        self.feature_length = base_ch * 4

        self.net = nn.Sequential(
            nn.Conv2d(input_ch, base_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(base_ch, base_ch * 2, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(base_ch * 2, self.feature_length, kernel_size=3,
                      padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=(1, 1))
        )
        self.fc = nn.Linear(in_features=self.feature_length,
                            out_features=num_classes)

    def forward(self, input):
        output = self.net(input)
        output = output.view(-1, self.feature_length)
        output = self.fc(output)
        return output


def load_dataset():
    train_dataset = datasets.MNIST(root='./data',
                                   train=True,
                                   transform=transforms.ToTensor(),
                                   download=True)
    test_dataset = datasets.MNIST(root='./data',
                                  train=False,
                                  transform=transforms.ToTensor(),
                                  download=True)
    return train_dataset, test_dataset


def hook_func(module, input, output):
    """
    Hook function of register_forward_hook

    Parameters:
    -----------
    module: module of neural network
    input: input of module
    output: output of module
    """
    image_name = get_image_name_for_hook(module)
    data = output.clone().detach()
    data = data.permute(1, 0, 2, 3)
    vutil.save_image(data, image_name, pad_value=0.5)


def get_image_name_for_hook(module):
    """
    Generate image filename for hook function

    Parameters:
    -----------
    module: module of neural network
    """
    os.makedirs(INSTANCE_FOLDER, exist_ok=True)
    base_name = str(module).split('(')[0]
    index = 0
    image_name = '.'  # '.' is surely exist, to make first loop condition True
    while os.path.exists(image_name):
        index += 1
        image_name = os.path.join(
            INSTANCE_FOLDER, '%s_%d.png' % (base_name, index))
    return image_name


if __name__ == '__main__':
    time_beg = time.time()

    train_dataset, test_dataset = load_dataset()
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=TRAIN_BATCH_SIZE,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=TEST_BATCH_SIZE,
                             shuffle=False)

    model = Model(input_ch=1, num_classes=10, base_ch=BASE_CHANNEL).cuda()
    torchsummary.summary(
        model, input_size=(INPUT_CHANNEL, INPUT_SIZE, INPUT_SIZE))
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    train_loss = []
    for ep in range(EPOCH):
        # ----------------- train -----------------
        model.train()
        time_beg_epoch = time.time()
        loss_recorder = []
        for data, classes in train_loader:
            data, classes = data.cuda(), classes.cuda()
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, classes)
            loss.backward()
            optimizer.step()

            loss_recorder.append(loss.item())
            time_cost = time.time() - time_beg_epoch
            print('\rEpoch: %d, Loss: %0.4f, Time cost (s): %0.2f' % (
                ep, loss_recorder[-1], time_cost), end='')

        # print train info after one epoch
        train_loss.append(loss_recorder)
        mean_loss_epoch = torch.mean(torch.Tensor(loss_recorder))
        time_cost_epoch = time.time() - time_beg_epoch
        print('\rEpoch: %d, Mean loss: %0.4f, Epoch time cost (s): %0.2f' % (
            ep, mean_loss_epoch.item(), time_cost_epoch), end='')

        # save model
        os.makedirs(MODEL_FOLDER, exist_ok=True)
        model_filename = os.path.join(MODEL_FOLDER, 'epoch_%d.pth' % ep)
        torch.save(model.state_dict(), model_filename)

        # ----------------- test -----------------
        model.eval()
        correct = 0
        total = 0
        for data, classes in test_loader:
            data, classes = data.cuda(), classes.cuda()
            output = model(data)
            _, predicted = torch.max(output.data, 1)
            total += classes.size(0)
            correct += (predicted == classes).sum().item()
        print(', Test accuracy: %0.4f' % (correct / total))

    print('Total time cost: ', time.time() - time_beg)

    # ----------------- visualization -----------------
    # clear output folder
    if os.path.exists(IMAGE_FOLDER):
        shutil.rmtree(IMAGE_FOLDER)

    model.eval()
    modules_for_plot = (torch.nn.ReLU, torch.nn.Conv2d,
                        torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d)
    for name, module in model.named_modules():
        if isinstance(module, modules_for_plot):
            module.register_forward_hook(hook_func)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=1,
                             shuffle=False)
    index = 1
    for data, classes in test_loader:
        INSTANCE_FOLDER = os.path.join(
            IMAGE_FOLDER, '%d-%d' % (index, classes.item()))
        data, classes = data.cuda(), classes.cuda()
        outputs = model(data)

        index += 1
        if index > 20:
            break

运行结果;

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 28, 28]             320
              ReLU-2           [-1, 32, 28, 28]               0
         MaxPool2d-3           [-1, 32, 14, 14]               0
            Conv2d-4           [-1, 64, 14, 14]          18,496
              ReLU-5           [-1, 64, 14, 14]               0
         MaxPool2d-6             [-1, 64, 7, 7]               0
            Conv2d-7            [-1, 128, 7, 7]          73,856
              ReLU-8            [-1, 128, 7, 7]               0
 AdaptiveAvgPool2d-9            [-1, 128, 1, 1]               0
           Linear-10                   [-1, 10]           1,290
================================================================
Total params: 93,962
Trainable params: 93,962
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.74
Params size (MB): 0.36
Estimated Total Size (MB): 1.10
----------------------------------------------------------------
Epoch: 0, Mean loss: 0.6502, Epoch time cost (s): 5.53, Test accuracy: 0.9360
Total time cost:  13.36158013343811

Process finished with exit code 0

特征图被保存在当前文件的文件夹下:

MODEL_FOLDER = './save_model'
IMAGE_FOLDER = './save_image'

这两个文件夹要提前创建好

展示一下  在我电脑上运行的结果

打开一个文件夹  看一下

 

展示上面的一张图

 

 

代码中的重要的内容在visualization部分:

  1. test_loader的batch_size设置为1。
  2. 下面一段代码用来设置register_forward_hook,只要我们在跑前向之前将register_forward_hook设置好,那么在它就会在前向的时候被调用。

 

modules_for_plot = (torch.nn.ReLU, torch.nn.Conv2d,
                    torch.nn.MaxPool2d, torch.nn.AdaptiveAvgPool2d)
for name, module in model.named_modules():
    if isinstance(module, modules_for_plot):
        module.register_forward_hook(hook_func)

 3.hook_func(module, input, output)需要自己实现。该函数的参数是固定的,不能自己随意加减,这就带来了一个比较麻烦的事情,我们没法把要保存的文件名等信息作为参数传给hook_func,所以这里我使用了一种比较粗暴的方式。文件名的前缀是module的类型名,后面会跟一个序号,序号是同类module在网络中出现的顺序。然而这个序号也没法传给hook_func,所以我直接从硬盘上从序号1开始,通过判断文件的存在性来决定下一个序号,由get_image_name_for_hook函数实现。比较粗暴和无奈的一种方法,如果哪位老兄有更好的方式欢迎留言指教。
4. hook_func中,在保存图片之前,需要将input或output的维度从[1, C, H, W]调整为[C, 1, H, W]
save_image()每一行的特征图数量默认是8,可以通过nrow参数进行修改。pad_value是特征图中间填充的间隙的颜色,尽管PyTorch将其定义为整数,但是由于Tensor一般是float型,所以还是输入个[0, 1]之间的浮点数才能真正生效。
 

感谢  点赞 收藏  + 关注

(25条消息) Pytorch可视化特征图_拜阳的博客-CSDN博客_pytorch特征图可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/124195.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

会议OA项目-首页

目录一、Flex布局简介什么是flex布局?flex属性学习地址:案例演示二、轮播图组件及mockjs三、会议OA小程序首页布局一、Flex布局简介 布局的传统解决方案,基于盒状模型,依赖 display属性 position属性 float属性 什么是flex布局…

简单有效的Mac内存清理方法,不用收藏也能记住

Mac电脑使用的时间越久,系统的运行就会变的越卡顿,这是Mac os会出现的正常现象,卡顿的原因主要是系统缓存文件占用了较多的磁盘空间,或者Mac的内存空间已满。如果你的Mac运行速度变慢,很有可能是因为磁盘内存被过度占用…

如何理解并记忆DataFrame中的Axis参数

当我们遇到有axis参数的方法时,脑子里的第一反应应该是:这个方法一定是沿着某一方向进行某种“聚合”或者“过滤”操作。在此场景下,Axis参数就是用来设定操作方向的:是垂直方向还是水平方向? axis0: 一行一行推进&…

【微服务架构实战】第1篇之API网关概述

1.网关概述 采用分布式、微服务的架构模式开发系统时,API 网关是整个系统中必不可少的一环。 1.1 没有网关会有什么问题? 在微服务架构模式下,1个系统会被拆分成多个微服务,如果每个微服务都直接暴露给调用方,会有以…

MySQL主键和唯一键的区别

主键和唯一键基本知识参考这篇文章 MySQL表的约束 ,本篇文章主要是谈一谈主键和唯一键的区别从而更好的理解唯一键和主键。 在上篇文章中已经提到 主键: primary key 用来唯一的约束该字段里面的数据,不能重复,不能为空&#x…

vue父页面调用子页面及方法及传参,鼠标光标定位

项目场景: vue父页面调用子页面及方法 问题描述 vue中父界面调用子界面及方法时界面可以调用,但是调用方法的时候第一次报错,但是关掉界面再次重新打开就没问题了 原因分析: 在我之前添加鼠标指针定位的时候,如果在…

记录scoped属性的使用和引发的问题

背景 在对表格数据进行样式处理时,通过业务逻辑判断,进行对符合要求的表格填充背景色,没有符合预期的效果。反复排查校验代码和判断逻辑,都没有什么问题,可能还是样式上出现问题。再通过F12 选取元素对表格设置背景色时…

获取树形结构中,父节点下所有子/孙节点(递归方式)

获取树形结构中,父节点下所有子/孙节点(递归方式)1 树形结构(TreeItem类)2 测试代码(main函数)3 运行效果1 树形结构(TreeItem类) 这里通用型树形结构为TreeItem类&…

初学Java web(七)RequestResponse

Request&Response Request:获取请求数据 Response:设置响应数据 一.Request对象 1.Request继承体系 Tomcat需要解析请求数据,封装为requestx对象并且创建requestx对象传递到service方法中 使用request对象,查阅JavaEE API文档的HttpServletReque…

rocketMq架构原理精华分析(一)

rocketMq架构原理精华分析是我们这篇文章的核心,从消息中间件的对比、架构模型、消息模型、常见问题等逐一分析: 一、中间件对比: RabbitMq 集群效果不太好,底层不是java 语言,研究原理比较困难; Kafka是…

前端面试题之计算机网络篇 OSI七层网络参考模型

互联网数据传输原理 |OSI七层网络参考模型 OSI七层网络参考模型 应用层:产生网络流量的程序表示层:传输之前是否进行加密或者压缩处理会话层:查看会话,查木马 netstat-n传输层:可靠传输、流量控制、不可…

亿级流量的互联网项目如何快速构建?手把手教你构建思路

一. 大流量的互联网项目 1.项目背景 索尔老师之前负责的一个项目,业务背景是这样的。城市的基础设施建设是每个城市和地区都会涉及到的,如何在基建工地中实现人性化管理,是当前项目的主要诉求。该项目要实现如下目标: 工地工人的…

C语言实现http下载器(附代码)

C语言实现http的下载器。 例:做OTA升级功能时,我们能直接拿到的往往只是升级包的链接,需要我们自己去下载,这时候就需要用到http下载器。 这里分享一个: 功能: 1、支持chunked方式传输的下载 2、被重定…

Apollo开放平台8.0发布:多维升级“为开发者而生”

Apollo开放平台8.0重磅发布:多维升级“为开发者而生” Apollo开放平台迎来8.0版本,百度自动驾驶开放平台迈向易用性时代 百度Apollo EDU计划进展公布:已覆盖自动驾驶技术人才33.5万、700多所院校 Apollo Studio学习实践社区上线,新…

剑指offer----C语言版----第一天

目录 1. 数组中重复的数字Ⅰ 1.1 题目描述 1.2 思路一 1.3 思路二 1.4 思路三(最优解) 1. 数组中重复的数字Ⅰ 原题:剑指 Offer 03. 数组中重复的数字 - 力扣(LeetCode)https://leetcode.cn/problems/shu-zu-zhong-…

Python语言快速入门上

目录 1、前言 2、变量和常量 1)Python对象模型 2)Python变量 二、运算符和表达式 【运算符和表达式】 【位运算符】 【逻辑运算符】 【成员运算符】 【身份运算符】 【常用内置函数】 【基本输入输出】 【模块导入与使用】 【Python代码编…

【PCB专题】Allegro导出3D文件

在PCB布局时,已经决定了大部分器件要放置的位置。如接口、主要的芯片、模块等。因为放置好器件后可能与结构干涉,如果没有发现,那么不得不在Layout的后期调整器件位置,增加工作量。所以前期布局基本确定后就需要导出3D文件给结构工程师,由他查看是否有器件与结构、螺丝孔等…

全志Tina Linux Display 开发指南支持百问网T113 D1-H哪吒DongshanPI-D1s V853-Pro等开发板

1 概述 让显示应用开发人员了解显示驱动的接口及使用流程,快速上手,进行开发;让新人接手工作时能快速地了解驱动接口,进行调试排查问题。sunxi 平台DE1.0/DE2.0。与显示相关的应用开发人员,及与显示相关的其他模块的开…

操作系统期末考试必会题库1——引言+用户界面

1.请简要描述操作系统的定义及其功能。 操作系统定义: 是计算机系统中的一个系统软件,是一些程序模块的集合 ,它们管理和控制计算机系统中的软硬件资源,合理的组织计算机的工作流程,以便有效的利用这些资源为用户提供一…

Linux用户权限详解

为什么有人冲了钱就能享受至尊VIP待遇?为什么冲了黄钻、绿钻、紫钻就会享受一些特殊活动呢?我们起初都是一群普通用户,为什么有些人就能通过某些手段得到一些异于常人的服务呢?这其中的奥秘是什么呢?接下来带大家了解这…