Pytorch ——特征图的可视化

news2025/2/26 23:17:07

文章目录

  • 前言
  • 一、torchvision.models._utils.IntermediateLayerGetter
    • *注意:torcvision的最新版本0.13,已经取消了pretrained=True这个参数,并且打算在0.15版正式移除,如果用pretrained这个参数会出现warring警告。现在加载与训练权重的参数改成了**weights**,这样可以加载不同版本的预训练权重,比如models.ResNet18_Weights.DEFAULT,就加载默认最新的ResNet18权重文件,还有其他参数形式,具体参考官网*
  • 二、示例
    • 1.ResNet50特征图可视化
    • 原图
    • 特征图
    • 2.AlexNet可视化
  • 总结


前言

Pytroch中间层的特征图可视化,网上已经有很多教程,比如用hook钩子函数,但是代码都写得不是很清楚,所以还是自己去摸索一下。


一、torchvision.models._utils.IntermediateLayerGetter

IntermediateLayerGetter这个函数是在看DETR源码时发现的,它的作用很简单,记录我们想要的中间层的输出。看个官方给出的例子:

import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from matplotlib import pyplot as plt

model = torchvision.models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
new_model = torchvision.models._utils.IntermediateLayerGetter(model, {'layer1': 'feat1', 'layer3': 'feat2'})
out = new_model(torch.rand(1, 3, 224, 224))

print([(k, v.shape) for k, v in out.items()])  # 其中v是对应层的输出,也就是我们要得到的特征图Tensor

#输出
"[('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 256, 14, 14]))]"

注意:torcvision的最新版本0.13,已经取消了pretrained=True这个参数,并且打算在0.15版正式移除,如果用pretrained这个参数会出现warring警告。现在加载与训练权重的参数改成了weights,这样可以加载不同版本的预训练权重,比如models.ResNet18_Weights.DEFAULT,就加载默认最新的ResNet18权重文件,还有其他参数形式,具体参考官网

这里详细说一下

#首先定义一个模型,这里直接加载models里的预训练模型
model = torchvision.models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
#查看模型的各个层,
for name in model.named_children():
    print(name[0])
#输出,相当于把ResNet的分成了10个层
"""
conv1
bn1
relu
maxpool
layer1
layer2
layer3
layer4
avgpool
fc"""

可以看到ResNet18的结构被分为了10个部分,和下图的网络结构是一一对应的,conv1、bn1、relu、maxpool这四个对应第一层的卷积conv1,layer1对应图中的conv2_x,也就是一个残差结构,同理layer2对应conv3_x,以此类推。

在这里插入图片描述
比如,我想要layer1(conv2_x)和layer2(conv3_x)的输出,那么只需要构建一个字典,{‘layer1’: ‘feat1’, ‘layer2’: ‘feat2’},feat1、feat2是我们的重命名,可以随意输入自己想要的名字。

#现在我们把model传进IntermediateLayerGetter
new_model = torchvision.models._utils.IntermediateLayerGetter(model, {'layer1': 'feat1', 'layer2': 'feat2'})
out = new_model(torch.rand(1, 3, 224, 224))
print([(k,v.shape) for  k,v in out.items()])

#输出
"""
[('feat1', torch.Size([1, 64, 56, 56])), ('feat2', torch.Size([1, 128, 28, 28]))]
"""

二、示例

1.ResNet50特征图可视化

代码如下:

# 返回输出结果
import cv2
import torchvision
import torch
from matplotlib import pyplot as plt
import numpy as np


#定义函数,随机从0-end的一个序列中抽取size个不同的数
def random_num(size,end):
    range_ls=[i for i in range(end)]
    num_ls=[]
    for i in range(size):
        num=random.choice(range_ls)
        range_ls.remove(num)
        num_ls.append(num)
    return num_ls
    


path = "test.jpg"
transformss = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((224, 224)),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

#注意如果有中文路径需要先解码,最好不要用中文
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

#转换维度
img = transformss(img).unsqueeze(0)

model = torchvision.models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
new_model = torchvision.models._utils.IntermediateLayerGetter(model, {'layer1': '1', 'layer2': '2',"layer3":"3"})
out = new_model(img)

tensor_ls=[(k,v) for  k,v in out.items()]

#这里选取layer2的输出画特征图
v=tensor_ls[1][1]

#取消Tensor的梯度并转成三维tensor,否则无法绘图
v=v.data.squeeze(0)

print(v.shape)  # torch.Size([512, 28, 28])


#随机选取25个通道的特征图
channel_num = random_num(25,v.shape[0])
plt.figure(figsize=(10, 10))
for index, channel in enumerate(channel_num):
    ax = plt.subplot(5, 5, index+1,)
    plt.imshow(v[channel, :, :])
plt.savefig("feature.jpg",dpi=300)

原图

请添加图片描述

特征图

从特征图中可以看到,layer2确实已经学习到了某些特征,比如第二行第二列的特征图已经把狗的形状勾勒出来了,说明这个卷积核学习的可能是狗的颜色。

请添加图片描述
这里再展示一下ResNet第一部分(conv1)卷积层的特征图(灰度图):
在这里插入图片描述

2.AlexNet可视化

上面的ResNet用的是预训练模型,这里我们自己构建AlexNet。
代码如下:

class AlexNet(nn.Module):
    def __init__(self):
        super(AlexNet, self).__init__()

        self.conv1 = nn.Sequential(nn.Conv2d(3, 96, 11, 4, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv2 = nn.Sequential(nn.Conv2d(96, 256, 5, 1, 2),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2),
                                   )

        self.conv3 = nn.Sequential(nn.Conv2d(256, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 384, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.Conv2d(384, 256, 3, 1, 1),
                                   nn.ReLU(),
                                   nn.MaxPool2d(3, 2))


        self.fc=nn.Sequential(nn.Linear(256*6*6, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 4096),
                                nn.ReLU(),
                                nn.Dropout(0.5),
                                nn.Linear(4096, 100),
                                )

    def forward(self, x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.conv3(x)
        output=self.fc(x.view(-1, 256*6*6))
        return output

model=AlexNet()
for name in model.named_children():
    print(name[0])
#同理先看网络结构
#输出
"""
conv1
conv2
conv3
fc
"""


    
path = "test.jpg"
transformss = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize((224, 224)),
     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

#注意如果有中文路径需要先解码,最好不要用中文
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

#转换维度
img = transformss(img).unsqueeze(0)

model = AlexNet()


## 修改这里传入的字典即可

new_model = torchvision.models._utils.IntermediateLayerGetter(model, {"conv1":1,"conv2":2,"conv3":3})
out = new_model(img)

tensor_ls=[(k,v) for  k,v in out.items()]

#选取conv2的输出
v=tensor_ls[1][1]

#取消Tensor的梯度并转成三维tensor,否则无法绘图
v=v.data.squeeze(0)

print(v.shape)  # torch.Size([512, 28, 28])


#随机选取25个通道的特征图
channel_num = random_num(25,v.shape[0])
plt.figure(figsize=(10, 10))
for index, channel in enumerate(channel_num):
    ax = plt.subplot(5, 5, index+1,)
    plt.imshow(v[channel, :, :])  # 灰度图参数cmap="gray"
plt.savefig("feature.jpg",dpi=300)


也就是说AlexNet这里分为了4部分,三个卷积和一个全连接(其实就是我们自己定义的foward前向传播),我们想要哪层的输出改个字典就好了,new_model = torchvision.models._utils.IntermediateLayerGetter(model, {“conv1”:1,“conv2”:2,“conv3”:3}),得到的特征图如下。

在这里插入图片描述
plt.imshow(v[channel, :, :],cmap="gray") 加上cmap参数就可以显示灰度图
在这里插入图片描述


总结

IntermediateLayerGetter有一个不足就是它不能获取二级层的输出,比如ResNet的layer2,他不能获取layer2里面的卷积的输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/9821.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【项目实战】springboot+vue舞蹈课程在线学习系统-java舞蹈课程学习打卡系统的设计与实现

注意:该项目只展示部分功能,如需了解,评论区咨询即可。 本文目录1.开发环境2 系统设计2.1 背景意义2.2 技术路线2.3 主要研究内容3 系统页面展示3.1 学生3.2 教师页面3.3 管理员页面4 更多推荐5 部分功能代码5.1 查看学生打卡5.2 文件上传下载…

天翼云实时云渲染,助力打造世界VR产业大会云上之城

2022年11月12日,2022世界VR产业大会于江西南昌开幕。11月13日,以“共建元宇宙生态,点亮新数智未来”为主题的中国电信生态论坛召开。由天翼云携手新国脉数字文化股份有限公司(简称“国脉文化”)打造的元宇宙家园国脉大…

【力扣刷题】只出现一次的数字

🔗 题目链接 题目描述 给你一个 非空 整数数组 nums ,除了某个元素只出现一次以外,其余每个元素均出现两次。找出那个只出现了一次的元素。 你必须设计并实现线性时间复杂度的算法来解决此问题,且该算法只使用常量额外空间。 …

java类的练习 -- 声明一个接口(Calculability),接口中

java类的练习 – 声明一个接口(Calculability),接口中… 题目 编写一个应用程序,实现以下功能: ①声明一个接口(Calculability),接口中包含一个方法area()。 ②声明一个三角形类实现该接口,类名为Triangle&#xf…

React源码分析2-深入理解fiber

react16 版本之后引入了 fiber,整个架构层面的 调度、协调、diff 算法以及渲染等都与 fiber 密切相关。所以为了更好地讲解后面的内容,需要对 fiber 有个比较清晰的认知。本章将介绍以下内容: 为什么需要 fiberfiber 节点结构中的属性fiber 树…

AR眼镜新秀雷鸟创新,究竟能飞多远?

时隔近十年之后,消费级AR眼镜又重新高调回归大众视野。 自去年10月开始,以OPPO、小米为代表的国内大厂纷纷推出试验性AR眼镜,谷歌第二代AR眼镜更是作为压轴在I/O大会上重新回归,苹果多年来不断提及但始终“难产”的AR产品&#x…

平衡二叉树(AVL树)

1.简介 1.二叉排序树的问题: 如果原始是数据是排好序的(如1,2,3,4,5,6),那么最终创建的二叉排序树的结构就会变成一条斜线,类似于一条单链表,此时如果需要查找/插入某个元素就要一个一个元素的比较,这样就没有优势了.由于每次都要比较左子树,其查询速度甚至比单链表还慢; 2.对…

labview 写入文本到word报表(标签方法)

描述labview按预先定义的包含数个标签的模板,复制模板到新文件,写入文本到各标签位置。 图1 前面板 图2 程序框图 图3 Ms office report 图4 配置Ms office report的属性1 图5 配置Ms office report的属性2

在vue2项目中使用vue-quill-editor实现富文本编译器

1 安装 npm install vue-quill-editor --save 2 引入 有两种引入方式 (1)全局引入(main) import VueQuillEditor from vue-quill-editor//调用编辑器 // 样式 import quill/dist/quill.core.css import quill/dist/quill.snow.css import quill/dist…

浅谈无脚本自动化测试

在当今的企业环境中,软件测试不再被视为不必要的投资;相反,它已经上升到一种需要而不是奢侈品的水平。随着市场的不断变化和竞争的加剧,企业必须做一些让他们与竞争对手区分开来的事情。 为了使自己与众不同,公司必须提…

【1-系统架构演进过程】

特别说明:接下来我会和大家一起完成一个商城项目,这个项目涉及的内容以及技术不仅多,而且都是现在主流的开发技术,每天我会按时更新博客内容,详细记录学习的过程,感兴趣的同学可以和我一起完成,但是时间较长…

国际贸易详解:国际贸易主要有哪些分类标准和运输方式

国际贸易主要的分类标准包括按商品流向分为出口贸易,进口贸易和过境贸易,按商品形态分为有形贸易和无形贸易,按运输方式分为陆运贸易,海运贸易等。一、国际贸易主要有哪些分类标准 1、按商品流向分为出口贸易、进口贸易、过境贸易…

2022英特尔® FPGA中国技术周

本文图片均来自于2022英特尔 FPGA中国技术周线上会议 11.14 全新的中端和以边缘为中心的FPGA 英特尔 Agilex™ FPGA的下一代接口协议 11.15 Nios V: 基于FPGA的RISC-V处理器 英特尔 Quartus Prime开发软件 基于FPGA的人工智能开发套件 Case 使用oneAPI高级语言开发IP​ 将…

油罐清洗抽吸系统设计

目录 摘要 - 1 - Abstract - 2 - 1 引言 - 3 - 1.1课题的背景和意义 - 3 - 1.1.1课题的背景 - 3 - 1.1.2课题的意义 - 3 - 1.2国内外油罐清洗抽吸系统的研究情况 - 3 - 2 油罐清洗抽吸系统总体设计方案 - 6 - 2.1油罐清洗抽吸系统方案 - 6 - 2.2清洗对象及要求 - 6 - 2.3清洗工…

欧拉路径!

呃在昨晚的破防之下我并不想学这东西所以留到今晚?属于是懒爆了。那么我们来看,定义的话其实前面写过了 其实主要是分两个方面: 1.图是否联通,是什么图 2.这个图每个点的度数或者(出度入度)什么的若是符合就行。 偷偷总…

【C++】C++基础知识(四)---程序流程结构

C基础知识(四)1. 顺序结构2. 选择结构2.1 if语句2.2 switch语句2.3 选择结构案例3. 循环结构3.1 循环语句3.2 循环结构案例4. 跳转语句C中支持的三种流程结构:顺序结构、选择结构、循环结构顺序结构:程序按照顺序执行,…

easy-monitor3.0 nodejs性能监控和分析工具

#easy-monitor性能监控和分析工具 Easy-Monitor 3.0 https://blog.csdn.net/qq_36791889/article/details/115420116 #git地址:https://github.com/1981430140/easy-monitor-docker-compose.git 一、easy-monitor 服务器端安装(docker-compose&#xff…

我说MySQL里每张表不要超过100w数据,面试官让我回去等通知?

V-xin:ruyuanhadeng获得600页原创精品文章汇总PDF 目录 1、面试题2、面试官心理分析3、面试题剖析 1、面试题 事务的几个特点是什么?数据库事务有哪些隔离级别?MySQL的默认隔离级别? 2、面试官心里分析 用mysql开发的三个基本…

操作系统4小时速成:处理机调度,调度方法,调度准则,典型的调度算法,响应比

操作系统4小时速成:处理机调度,调度方法,调度准则,典型的调度算法,响应比 2022找工作是学历、能力和运气的超强结合体,遇到寒冬,大厂不招人,可能很多算法学生都得去找开发&#xff…

详探XSS PayIoad

详探XSS PayIoad1.Cookie劫持2.构造GET与POST请求3.XSS钓鱼4.识别用户浏览器1.Cookie劫持 一个最常见的XSS Payload,就是通过读取浏览器的Cookie对象,从而发起“Cookie劫持”攻击 Cookie中一般加密保存了当前用户的登录凭证。Cookie如果丢失&#xff0…