【Pytorch】六行代码实现:特征图提取与特征图可视化

news2024/11/24 7:53:29

前言

之前记录过特征图的可视化:Pytorch实现特征图可视化,当时是利用IntermediateLayerGetter 实现的,但是有很大缺陷,只能获取到一级的子模块的特征图输出,无法获取内部二级子模块的输出。今天补充另一种Pytorch官方实现好的特征提取方式,非常好用!


特征图提取

  • 前言
  • 一、Torch FX
  • 二、特征提取
    • 1.使用get_graph_node_names提取各个节点
    • 2.使用create_feature_extractor提取输出
    • 3.六行代码可视化特征图
  • 三、Reference


一、Torch FX

首先是Torch FX的介绍:FX Blog(具体可参考Reference)

FX based feature extraction is a new TorchVision utility that lets us access intermediate transformations of an input during the forward pass of a PyTorch Module. It does so by symbolically tracing the forward method to produce a graph where each node represents a single operation. Nodes are named in a human-readable manner such that one may easily specify which nodes they want to access.
Did that all sound a little complicated? Not to worry as there’s a little in this article for everyone. Whether you’re a beginner or an advanced deep-vision practitioner, chances are you will want to know about FX feature extraction. If you still want more background on feature extraction in general, read on. If you’re already comfortable with that and want to know how to do it in PyTorch, skim ahead to Existing Methods in PyTorch: Pros and Cons. And if you already know about the challenges of doing feature extraction in PyTorch, feel free to skim forward to FX to The Rescue.


也就是我们后面调用的特征提取函数是基于Torch FX实现的。总之一句话:基于FX的特征提取是一种新的TorchVision实用程序,它允许我们在PyTorch模块的前向传递过程中访问输入的中间值。


二、特征提取

1.使用get_graph_node_names提取各个节点

首先依然是查看各个网络的子层

#首先定义一个模型,这里直接加载models里的预训练模型
model = torchvision.models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
#查看模型的各个层,
for name in model.named_children():
    print(name[0])
#输出,相当于把ResNet的分成了10个层
"""
conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc"""

在这里插入图片描述


之前是利用IntermediateLayerGetter 实现的,但是有很大缺陷,只能获取到一级的子模块的特征图输出,无法获取内部二级子模块的输出。比如不能获取layer2内部第一个BasicBlock的特征图输出。现在可以利用 get_graph_node_names获取任意前向传播的子节点。

import torchvision
import torch
from torchvision.models.feature_extraction import get_graph_node_names

model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
nodes, _ = get_graph_node_names(model)
nodes
# 输出如下
"""
['x',
 'conv1',
 'bn1',
 'relu',
 'maxpool',
 'layer1.0.conv1',
 'layer1.0.bn1',
 'layer1.0.relu',
 'layer1.0.conv2',
 'layer1.0.bn2',
 'layer1.0.add',
 'layer1.0.relu_1',
 'layer1.1.conv1',
 'layer1.1.bn1',
 'layer1.1.relu',
 'layer1.1.conv2',
 'layer1.1.bn2',
 'layer1.1.add',
 'layer1.1.relu_1',
 'layer2.0.conv1',
 'layer2.0.bn1',
 'layer2.0.relu',
 'layer2.0.conv2',
 'layer2.0.bn2',
 'layer2.0.downsample.0',
 'layer2.0.downsample.1',
 'layer2.0.add',
 'layer2.0.relu_1',
 'layer2.1.conv1',
 'layer2.1.bn1',
 'layer2.1.relu',
 'layer2.1.conv2',
 'layer2.1.bn2',
 'layer2.1.add',
 'layer2.1.relu_1',
 'layer3.0.conv1',
 'layer3.0.bn1',
 'layer3.0.relu',
 'layer3.0.conv2',
 'layer3.0.bn2',
 'layer3.0.downsample.0',
 'layer3.0.downsample.1',
 'layer3.0.add',
 'layer3.0.relu_1',
 'layer3.1.conv1',
 'layer3.1.bn1',
 'layer3.1.relu',
 'layer3.1.conv2',
 'layer3.1.bn2',
 'layer3.1.add',
 'layer3.1.relu_1',
 'layer4.0.conv1',
 'layer4.0.bn1',
 'layer4.0.relu',
 'layer4.0.conv2',
 'layer4.0.bn2',
 'layer4.0.downsample.0',
 'layer4.0.downsample.1',
 'layer4.0.add',
 'layer4.0.relu_1',
 'layer4.1.conv1',
 'layer4.1.bn1',
 'layer4.1.relu',
 'layer4.1.conv2',
 'layer4.1.bn2',
 'layer4.1.add',
 'layer4.1.relu_1',
 'avgpool',
 'flatten',
 'fc']
"""

get_graph_node_names把前向传播的各个节点都列出来了形成了一个列表。比如列表中的x表示我们的输入;layer1.0.conv2表示layer1的第1个BasicBlock的conv2节点;layer3.1.conv2表示layer3的第2个BasicBlock的conv2节点;这些节点和我们上图方框中圈出来的是一一对应的,可以结合自己的网络结构具体分析。

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(3, 96, 11, 4, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv2 = nn.Sequential(nn.Conv2d(96, 256, 5, 1, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 256, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2))


        self.fc=nn.Sequential(nn.Linear(256*6*6, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 100),
                                )

    def forward(self, x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        output=self.fc(x.view(-1, 256*6*6))
        return output
    
model=AlexNet()
nodes, _ = get_graph_node_names(model)
nodes
# 输出如下
['x',
 'conv1.0',
 'conv1.1',
 'conv1.2',
 'conv2.0',
 'conv2.1',
 'conv2.2',
 'conv3.0',
 'conv3.1',
 'conv3.2',
 'conv3.3',
 'conv3.4',
 'conv3.5',
 'conv3.6',
 'view',
 'fc.0',
 'fc.1',
 'fc.2',
 'fc.3',
 'fc.4',
 'fc.5',
 'fc.6']

如果是自定义网络结构,在__init__中初始化了self.conv1self.conv2self.conv3self.fc与输出列表相对应。
conv3为例:

 self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 256, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2))

总共定义了7层,3个卷积层、3个激活层、1个池化层。 输出节点列表中的conv3.0就表示conv3的第一个节点即第一个卷积层nn.Conv2d(256, 384, 3, 1, 1),同理, conv3.1表示conv3的第二个节点即nn.ReLU()

2.使用create_feature_extractor提取输出

在获取节点信息之后,我么可以利用create_feature_extractor来获取对应节点层的输出。所以get_graph_node_names只是帮助我们获取节点层的信息。

比如,我只想获取layer3layer4内部的第一个卷积层的输出即layer3.0.conv1, layer4.0.conv1

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

# 根据get_graph_node_names得到的节点层信息
# 定义想要得到的输出层
features = ['layer3.0.conv1', "layer4.0.conv1"]

model = torchvision.models.resnet18(
					weights=torchvision.models.ResNet18_Weights.DEFAULT)
					
# return_nodes参数就是返回对应的输出
feature_extractor = create_feature_extractor(model, return_nodes=features)
# 定义输入
x=torch.ones(1, 3, 224, 224)
# 得到一个我们想要的输出层的字典
out = feature_extractor(x)
out

# tensor即对应的输出
"""
{'layer3.0.conv1': tensor(...),
 'layer4.0.conv1': tensor(...) }
"""

当然,并不是一定要完全按照get_graph_node_names得到的节点层信息来定义输出层。比如,我只想获取layer3整个层的输出特征图,我并不关心layer3内部子层的输出:

import torch
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

# 定义layer3即可
# 其他层同理
features = ['layer3']
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
feature_extractor = create_feature_extractor(model, return_nodes=features)
# 定义输入
x=torch.ones(1, 3, 224, 224)
# 得到一个我们想要的输出层的字典
out = feature_extractor(x)
out
"""
{'layer3': tensor(...)}
"""


return_nodes参数也可以传入一个字典,字典的键是节点层,值是自定义别名。比如{"layer3":"output1","layer4":"output2"}

features = {"layer3":"output1","layer4":"output2"}
model = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.DEFAULT)
feature_extractor = create_feature_extractor(model, return_nodes=features)
x=torch.ones(1, 3, 224, 224)
out = feature_extractor(x)
out
# 输出如下
"""
{'output1': tensor(...),
 'output2': tensor(...)}

"""

3.六行代码可视化特征图

import torch
import torchvision
from PIL import Image
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from torchvision.models.feature_extraction import create_feature_extractor


transform = transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 

model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)

feature_extractor = create_feature_extractor(model, return_nodes={"conv1":"output"})

original_img = Image.open("dog.jpg")

img=transform(original_img).unsqueeze(0)

out = feature_extractor(img) 

# 这里没有分通道可视化
plt.imshow(out["output"][0].transpose(0,1).sum(1).detach().numpy())

在这里插入图片描述

在这里插入图片描述

三、Reference

Torch FX官方文档:Torch FX官方文档介绍
Torch FX Blog:Feature Extraction in TorchVision using Torch FX
在这里插入图片描述
官方对四种获取特征输出的方式进行了对比,这篇Blog写的比较详细,可以仔细看看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/467521.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ChatGPT如何写作-ChatGPT写作程序

ChatGPT如何写作 ChatGPT是一款自然语言处理模型,它无法像人类一样进行“写作”。但是,您可以利用ChatGPT的生成文本功能来帮助您生成文字。以下是一些使用ChatGPT写作的建议: 确定主题和目标受众。在开始写作之前,请确保您清楚知…

凝心聚力,携“源”出海:开源社顾问委员会2023年第一季度会议圆满举办!

2023 年 3 月 25 日,开源社顾问委员会(以下简称"顾问委员会")第一季度会议在北京圆满召开。这是顾问委员会自 2018 年成立以来的第 17 次全体委员会议。 为增进顾问委员会成员交流,加强开源社社区建设,实现开…

第五章 资源包使用

游戏开发中会大量使用模型文件,图片文件,这些资源都需要事先导入到项目中去。导入的方式非常简单,将这些文件直接复制到项目中的Assets目录下即可。Unity 会在文件添加到 Assets 文件夹时自动检测到这些文件并同步显示在Project视图中。 Uni…

内分泌失调对身体有什么影响?

体内各种荷尔蒙的平衡,可以维持内分泌的稳定,当生活节奏被打乱,就会导致熬夜、入睡困难、压力过大、不按时就餐、久坐、情绪不稳定等。 对此,内分泌失调都是不小的问题,都是会影响身体的各个部位的。 内分泌对身体有什…

【U8+】用友U8+产品-操作系统、数据库、浏览器推荐支持一览表

【业务场景】 大家平时在服务、实施过程中, 经常被问到各个版本的产品支持什么版本操作系统、数据库、浏览器? 根据各个版本发版说明, 总结了操作系统、数据库、浏览器推荐使用一览表。 软件版本与电脑操作系统版本相辅相承, 一方…

Redis 缓存穿透、缓存雪崩、缓存击穿

缓存穿透 缓存穿透是指客户端请求的数据在缓存中和数据库中都不存在,这样缓存永远不会生效,这些请求都会打到数据库。 常见的解决方案有两种: 缓存空对象 优点:实现简单,维护方便 缺点: 额外的内存消耗 可…

chatGPT写文章一半不写了-如何让chatGPT写完整文章

chatGPT不生成内容的原因有哪些 当ChatGPT不生成内容时,可能有如下原因: 数据限制:ChatGPT的生成能力是建立在其训练数据的基础上的。如果输入的内容领域、主题和题材不在其数据范围内,ChatGPT将无法生成非常有意义和具体的内容。…

图像修补论文阅读:MAT算法笔记

标题:MAT: Mask-Aware Transformer for Large Hole Image Inpainting 会议:CVPR2022 论文地址:https://ieeexplore.ieee.org/document/9879508/ 官方代码:https://github.com/fenglinglwb/MAT 作者单位:香港中文大学、…

第十一讲 常用数据结构之字符串

第二次世界大战促使了现代电子计算机的诞生,世界上的第一台通用电子计算机名叫 ENIAC(电子数值积分计算机),诞生于美国的宾夕法尼亚大学,占地167平米,重量约27吨,每秒钟大约能够完成约5000次浮点…

Ansible基础和命令行模块操作

目录 1.Ansible介绍 1.Ansible能做什么? 2.Ansible的特性和原理 2.Ansible部署 3.Ansible命令模块 1.command模块 2.shell模块 3.cron模块 4.user模块 5.group模块 7.file模块 8.hostname 模块 9.ping 模块 10. yum 模块 11.service/systemd 模块 1…

【私有云底层】理解OpenStack核心组件

文章目录 👹 关于作者一、Keystone 身份认证服务Keystone 架构工作流程 二、Glance 镜像服务Glance 架构磁盘与容器Glance 工作流程 三、Placement 放置服务Placement 工作流程 四、Nova 计算服务Nova 架构Nova 工作流程 五、Neutron 网络服务Neutron 架构Neutron 支…

jstat命令查看jvm内存情况及GC内存变化

命令格式 jstat [Options] pid [interval] [count] 参数说明: Options,选项,一般使用 -gc、-gccapacity查看gc情况 pid,VM的进程号,即当前运行的java进程号 interval,间隔时间(按该时间频率自动刷新当前内存…

Shell脚本之条件测试、if、case条件测试语句

目录 一、条件测试1.1test命令1.2文件测试1.2.1文件测试常见选项 1.3整数值比较1.4字符串比较1.5逻辑测试 二、if语句2.1单分支结构2.2双分支结构2.3多分支结构 三、case语句 一、条件测试 1.1test命令 测试特定的表达式是否成立,当条件成立,测试语句的…

Android Studio 2021 导出aar到Unity

1,新建一个新工程,创建一个Empty Activity 2.下面的都用默认即可 3.修改工程一些配置 修改setting.gradle maven { url https://maven.aliyun.com/repository/google } maven { url https://maven.aliyun.com/repository/public } maven { url https://maven.aliyu…

【matplotlib】可视化解决方案——如何正确设置轴长度和范围

概述 在 matplotlib 绘图时,往往需要对坐标轴进行设置,默认情况下,每一个绘图的最后都会调用 plt.autoscale() 方法,这个方法的底层是 gca().autoscale(enableenable, axisaxis, tighttight),本质是调用当前 Axes 对象…

【HDU - 6558】The Moon(概率dp)

ps:初学概率dp,所以 就算是板子也 是看了非常久,好在最后还是学会了qwq… 文章目录 题意思路代码总结 题意 思路 概率dp通常为从能够得到的状态去进行转移,在q为100%的时候,我们能够知道赢的概率为 p,那么赢的期望就是…

java的学习,刷题

先来点题目看看 1031. 两个非重叠子数组的最大和 难度中等249收藏分享切换为英文接收动态反馈 给你一个整数数组 nums 和两个整数 firstLen 和 secondLen,请你找出并返回两个非重叠 子数组 中元素的最大和,长度分别为 firstLen 和 secondLen 。 长度…

【c语言】详解const常量修饰符 | 指针变量的不同const修饰

创作不易&#xff0c;本篇文章如果帮助到了你&#xff0c;还请点赞支持一下♡>&#x16966;<)!! 主页专栏有更多知识&#xff0c;如有疑问欢迎大家指正讨论&#xff0c;共同进步&#xff01; 给大家跳段街舞感谢支持&#xff01;ጿ ኈ ቼ ዽ ጿ ኈ ቼ ዽ ጿ ኈ ቼ ዽ ጿ…

网络:DPDK复习相关知识点

1.转发模型&#xff1a; 1.1 运行至完成&#xff1a;run to complate &#xff08;RTC&#xff09; 参考笔记&#xff1a;DPDK介绍-CSDN博客 选择哪些核可以被DPDK使用&#xff0c;最后把处理对应收发队列的线程绑定到对应的核上&#xff0c;每个报文的生命周期都只能在其中一个…

电脑桌面日历怎么设置?超简单方法分享!

案例&#xff1a;电脑桌面日历怎么设置&#xff1f; 【最近因为工作的原因到了国外&#xff0c;但是电脑桌面的日历和时间一直都是错误的&#xff0c;急求一个设置电脑桌面日历的方法&#xff01;感谢大家&#xff01;】 电脑桌面日历是一种方便实用的工具&#xff0c;它可以…