有效感受野可视化学习

news2024/11/19 10:21:32

有效感受野可视化

  • 过程记录
    • 创建环境
    • 准备数据、脚本
    • 脚本测试
  • 其他参考
    • 尝试运行

过程记录

创建环境

conda create -n ERF python=3.8 -y
conda activate ERF
pip3 install empy rospkg pyyaml catkin_pkg
conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=11.8 -c pytorch -c nvidia
pip3 install numpy opencv-contrib-python

准备数据、脚本

创建ERF_VIS目录,管理项目,根目录下创建test_imgs目录管理测试图片。
根目录下创建脚本ERF_visualizetion.py(参考代码):

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')
from torchvision.models.segmentation import deeplabv3_resnet50
import torch
from torchvision import transforms
import numpy as np
import torchvision
from PIL import Image
import cv2
import os
#这里随机拿100张测试集图像,放到一个文件夹中,img_dir是文件夹路径
img_dir = "./test_imgs"
images=os.listdir(img_dir)
model = deeplabv3_resnet50(pretrained=True, progress=False)
model = model.eval()
#定义输入图像的长宽,这里需要保证每张图像都要相同
input_H, input_W = 512, 512
#生成一个和输入图像大小相同的0矩阵,用于更新梯度
heatmap = np.zeros([input_H, input_W])
#打印一下模型,选择其中的一个层
print(model)

#这里选择骨干网络的最后一个模块
layer = model.backbone.layer4[-1]
print(layer)


def farward_hook(module, data_input, data_output):
    fmap_block.append(data_output)
    input_block.append(data_input)
    
#为了简单,这里直接一张一张图来算,遍历文件夹中所有图像  
for img in images:
    read_img = os.path.join(img_dir,img)
    image = Image.open(read_img)
    
    #图像预处理,将图像缩放到固定分辨率,并进行标准化
    image = image.resize((input_H, input_W))
    image = np.float32(image) / 255
    input_tensor = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225))])(image)
    
    #添加batch维度
    input_tensor = input_tensor.unsqueeze(0)
    
    if torch.cuda.is_available():
        model = model.cuda()
        input_tensor = input_tensor.cuda()
        
    #输入张量需要计算梯度
    input_tensor.requires_grad = True
    fmap_block = list()
    input_block = list()
    
    #对指定层获取特征图
    layer.register_forward_hook(farward_hook)
    
    #进行一次正向传播
    output = model(input_tensor)
    
    #特征图的channel维度算均值且去掉batch维度,得到二维张量
    feature_map = fmap_block[0].mean(dim=1,keepdim=False).squeeze()
    
    #对二维张量中心点(标量)进行backward
    feature_map[(feature_map.shape[0]//2-1)][(feature_map.shape[1]//2-1)].backward(retain_graph=True)

    #对输入层的梯度求绝对值
    grad = torch.abs(input_tensor.grad)
    
    #梯度的channel维度算均值且去掉batch维度,得到二维张量,张量大小为输入图像大小
    grad = grad.mean(dim=1,keepdim=False).squeeze()
    
    #累加所有图像的梯度,由于后面要进行归一化,这里可以不算均值
    heatmap = heatmap + grad.cpu().numpy()
    
    
cam = heatmap

#对累加的梯度进行归一化
cam = cam / cam.max()

#可视化,蓝色值小,红色值大
cam = cv2.applyColorMap(np.uint8(cam*255), cv2.COLORMAP_JET)
cam = cv2.cvtColor(cam, cv2.COLOR_BGR2RGB)
map_image = Image.fromarray(cam)
map_image.show()

根目录下创建脚本Img_copy.py(从之前下载的ADE20K数据集的验证集中拷贝图片):

import os
import shutil

# 源文件夹路径
source_folder = '/media/lcy-magic/Dataset/Segment_Dataset/ade/ADEChallengeData2016/images/validation'
# 目标文件夹路径
destination_folder = '/home/lcy-magic/Segment_TEST/ERF_VIS/test_imgs'

# 获取源文件夹中所有的.png图片文件
png_files = [f for f in os.listdir(source_folder) if f.endswith('.jpg')]

# 确保目标文件夹存在
if not os.path.exists(destination_folder):
    os.makedirs(destination_folder)

# 复制前100张图片到目标文件夹
for i, png_file in enumerate(png_files):
    if i < 100:
        source_file = os.path.join(source_folder, png_file)
        destination_file = os.path.join(destination_folder, png_file)
        shutil.copyfile(source_file, destination_file)
        print(f'Copied {png_file} to {destination_folder}')

print('Copying completed.')

脚本测试

运行:

python ERF_visualizetion.py

运行结果为:
在这里插入图片描述

其他参考

尝试运行

发现一个这个项目做了可视化感受野参考博客,看看他是怎么怎么做的。
我先从github上下载了他的ERF代码:
在这里插入图片描述
先尝试运行visualize_erf.py。
看样子需要这几个参数:

  • –model:先不传,用他默认的resnet101
  • –weights:看后面代码:model = resnet101(pretrained=args.weights is None),对于resnet101并不需要提供权重
  • –data_path:看后面代码:root = os.path.join(args.data_path, 'val'),会对其中val文件夹内所有图片进行操作;我自己做个val文件夹,内含10张图片
  • –save_path:用npy文件保存ERF矩阵,指定地点;我先保存到根目录下新建result目录
  • –num_images:处理多少张图片,默认是50,我改成10

运行:

python other/visualize_erf.py --data_path test_imgs  --save_path result/ --num_images 10

报错:

ModuleNotFoundError: No module named 'timm'

安装就好:

pip3 install timm

报错:

ModuleNotFoundError: No module named 'erf'

把脚本的目录改名为erf,再添加环境变量,就能本地import了:
在这里插入图片描述

export PYTHONPATH=$PYTHONPATH:~/Segment_TEST/ERF_VIS

在这里插入图片描述
报错:

Traceback (most recent call last):
  File "erf/visualize_erf.py", line 15, in <module>
    from erf.resnet_for_erf import resnet101, resnet152
  File "/home/lcy-magic/Segment_TEST/ERF_VIS/erf/resnet_for_erf.py", line 7, in <module>
    from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock, load_state_dict_from_url, model_urls
ImportError: cannot import name 'load_state_dict_from_url' from 'torchvision.models.resnet' (/home/lcy-magic/anaconda3/envs/ERF/lib/python3.8/site-packages/torchvision/models/resnet.py)

参考参考博客改为:

from torchvision.models.resnet import ResNet, Bottleneck, BasicBlock, model_urls
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url

解决。但报错:

ImportError: cannot import name 'model_urls' from 'torchvision.models.resnet' (/home/lcy-magic/anaconda3/envs/ERF/lib/python3.8/site-packages/torchvision/models/resnet.py)

无法这样解决。应该是torchvision版本问题。查看我的版本是0.16.0:

pip show torchvision
Name: torchvision
Version: 0.16.0
Summary: image and video datasets and models for torch deep learning
Home-page: https://github.com/pytorch/vision
Author: PyTorch Core Team
Author-email: soumith@pytorch.org
License: BSD
Location: /home/lcy-magic/anaconda3/envs/ERF/lib/python3.8/site-packages
Requires: numpy, pillow, requests, torch
Required-by: timm

我不想牵就作者的版本。于是查看pytorch关于torchvision的官方文档:官方文档 。然后搜索load_state_dict_from_url关键词,点击链接进入相关文档:
在这里插入图片描述
看起来和上一篇参考博客说的吻合,也就是从torch.hub下载。可能model_urls在torch.hub里也没有。搜索下这个关键词,结果并没有搜到。查看代码中使用model_urls的部分:

def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNetForERF(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict, strict=False)
    return model

发现model_urls只是根据输入参数arch(网络架构),给load_state_dict_from_url提供输入参数。参考load_state_dict_from_url的官方说明:
在这里插入图片描述
这个参数就是URL.我看到这篇博客参考博客就没有import,直接给出的。那我也不import。只添加:

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}

果然成功了。
报错:

ModuleNotFoundError: No module named 'replknet'

发现这个replknet是那个论文自己的模型,我对这个没有兴趣。把所有用到他的地方都注释掉。
报错:

Traceback (most recent call last):
  File "erf/visualize_erf.py", line 116, in <module>
    main(args)
  File "erf/visualize_erf.py", line 54, in main
    dataset = datasets.ImageFolder(root, transform=transform)
  File "/home/lcy-magic/anaconda3/envs/ERF/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 309, in __init__
    super().__init__(
  File "/home/lcy-magic/anaconda3/envs/ERF/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 144, in __init__
    classes, class_to_idx = self.find_classes(self.root)
  File "/home/lcy-magic/anaconda3/envs/ERF/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 218, in find_classes
    return find_classes(directory)
  File "/home/lcy-magic/anaconda3/envs/ERF/lib/python3.8/site-packages/torchvision/datasets/folder.py", line 42, in find_classes
    raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
FileNotFoundError: Couldn't find any class folder in test_imgs/val.

原来他对数据集还有要求,得按照imagenet的格式来。我尝试改代码,但太麻烦了。又不想再下个数据集。就随便给我的图片归成dog和cat两类,但实际一点关系都没。
运行成功:
在这里插入图片描述
但发现result文件夹内没有结果,打印下保存文件的代码信息:
在这里插入图片描述
发现就是计数不对。改为:

if meter.count == args.num_images - 1:
            np.save(args.save_path, meter.avg)
            exit()

改了之后还是不对,发现这样不对,不应该-1。同时指定的保存地点应该是具体的npy文件。于是我给dog里加了一张图片。然后命令改为:

python erf/visualize_erf.py --data_path test_imgs  --save_path result/ERF.npy --num_images 10

成功。

接下来可视化,应该是analyze_erf.py脚本,只有两个参数:

  • –source:也就是刚生成的result/ERF.npy
  • –heatmap_save:热力图地址
    执行:
python erf/analyze_erf.py --source result/ERF.npy --heatmap_save result/heatmap.png

先后报错:

ModuleNotFoundError: No module named 'matplotlib'
pykitti 0.3.1 requires pandas, which is not installed.
ModuleNotFoundError: No module named 'seaborn'

安装就好:

pip3 install matplotlib
pip3 install pandas
pip3 install seaborn

报错:

ModuleNotFoundError: No module named 'mpl_toolkits.axes_grid1.colorbar'

搜了下,好像是把matplotlib升级下就行。我尝试了upgrade和用conda install,都不行。发现现在的matplotlib已经没这个东西了。我也没找到现在版本对应的实现是什么。算了退一下版本吧。发现seabon又必须让matplotlib在3.4以上。然后发现,seabon本来就可以绘制colorbar,代码注释里也写了,于是把color bar那个都注释掉,改用seabon:

def heatmap(data, camp='RdYlGn', figsize=(10, 10.75), ax=None, save_path=None):
    plt.figure(figsize=figsize, dpi=40)

    ax = sns.heatmap(data,
                xticklabels=False,
                yticklabels=False, cmap=camp,
                center=0, annot=False, ax=ax, cbar=True, annot_kws={"size": 24}, fmt='.2f')
    #   =========================== Add a **nicer** colorbar on top of the figure. Works for matplotlib 3.3. For later versions, use matplotlib.colorbar
    #   =========================== or you may simply ignore these and set cbar=True in the heatmap function above.
    from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
    # from matplotlib_colorbar import colorbar
    # from mpl_toolkits.axes_grid1.colorbar import colorbar
    # ax_divider = make_axes_locatable(ax)
    # cax = ax_divider.append_axes('top', size='5%', pad='2%')
    # colorbar(ax.get_children()[0], cax=cax, orientation='horizontal')
    # cax.xaxis.set_ticks_position('top')
    #   ================================================================
    #   ================================================================
    plt.savefig(save_path)

运行成功:
在这里插入图片描述
这个代码计算梯度的方法和之前的一致,没啥参考价值。主要是他中间有个保存梯度图为npy文件的过程和用seabon画更好看的图的想法不错,以及通过arg传参。借用下。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1570665.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++set和map详细介绍

文章目录 前言一、关联式容器和序列式容器二、set1.set文档介绍2.set成员函数1.构造函数2.迭代器3.容量4.修改5.其他 三.multiset四.map1.map文档介绍2.map成员函数1.构造2.insert插入3.count4.迭代器5.【】和at 五.multimap总结 前言 在本篇文章中&#xff0c;我们将会学到关…

大模型生成RAG评估数据集并计算hit_rate 和 mrr

文章目录 背景简介代码实现公开参考资料 背景 最近在做RAG评估的实验&#xff0c;需要一个RAG问答对的评估数据集。在网上没有找到好用的&#xff0c;于是便打算自己构建一个数据集。 简介 本文使用大模型自动生成RAG 问答数据集。使用BM25关键词作为检索器&#xff0c;然后…

网络编程核心概念解析:IP地址、端口号与网络字节序深度探讨

⭐小白苦学IT的博客主页 ⭐初学者必看&#xff1a;Linux操作系统入门 ⭐代码仓库&#xff1a;Linux代码仓库 ❤关注我一起讨论和学习Linux系统 本节重点 认识IP地址, 端口号, 网络字节序等网络编程中的基本概念; 1.前言 网络编程&#xff0c;作为现代信息社会中的一项核心技术&…

基于jsp+Spring boot+mybatis的图书管理系统设计和实现

基于jspSpring bootmybatis的图书管理系统设计和实现 博主介绍&#xff1a;多年java开发经验&#xff0c;专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》 欢迎点赞 收藏 ⭐留言 文末获…

kubesphere开启java服务

使用java:8作为基础镜像 1、创建持久化存储空间&#xff1a; 2、创建工作负载 &#xff08;1&#xff09;选择java镜像 &#xff08;2&#xff09;设置开启端口和启动命令&#xff08;--spring.config.location为读取jar包外部的配置文件&#xff09; &#xff08;3&#xff…

Linux--进程(2)

目录 前言 1. 进程的状态 1.1 进程排队 1.2 运行&#xff0c;阻塞&#xff0c;挂起 2.Linux下具体的进程状态 2.1僵尸和孤儿 3.进程的优先级 4.Linux的调度与切换 前言 这篇继续来学习进程的其它知识 上篇文章&#xff1a;Linux--进程&#xff08;1&#xff09;-CS…

理解Three.js的相机

大家都知道我们生活中的相机&#xff0c;可以留下美好瞬间。那Three.js的相机是什么呢&#xff1f;Three.js创建的场景是三维的&#xff0c;而我们使用的显示器显然是二维的&#xff0c;相机就是抽象的定义了三维空间到二维显示器的投影方式。Three.js常见的相机有两类&#xf…

资源分享 | 解决你的算力烦恼,平台注册送算力

前言 最近趋动云在做活动&#xff0c;新用户注册即可送价值70元的算力金&#xff0c;做满新手任务最高可领300元的算力红包&#xff0c;趋动云中租卡的费用如下&#xff1a; 1张24G的显存的卡大概是2块钱一个小时&#xff0c;48G的是4块钱一个小时&#xff0c;300算力红包能用…

【Redis】详解 Redis

Redis是一种高性能的开源键值存储数据库&#xff0c;它支持各种数据结构&#xff0c;包括字符串&#xff08;strings&#xff09;、哈希&#xff08;hashes&#xff09;、列表&#xff08;lists&#xff09;、集合&#xff08;sets&#xff09;、有序集合&#xff08;sorted se…

如何借助Idea创建多模块的SpringBoot项目

目录 1.1、前言1.2、开发环境1.3、项目多模块结构1.4、新建父工程1.5、创建子模块1.6、编辑父工程的pom.xml文件 1.1、前言 springmvc项目&#xff0c;一般会把项目分成多个包:controler、service、dao、utl等&#xff0c;但是随着项目的复杂性提高&#xff0c;想复用其他一个模…

蓝桥集训之斐波那契前n项和

蓝桥集训之斐波那契前n项和 核心思想&#xff1a;矩阵乘法 左边求和 右边求和 得到Sn fn2 – 1 因此只要求出fn2 即可 #include <iostream>#include <cstring>#include <algorithm>using namespace std;typedef long long LL;int n,m;int A[2][2] { …

论大数据服务化发展史

引言 一直想写一篇服务化相关的文章&#xff0c;那就别犹豫了现在就开始吧 正文 作为大数据基础架构工程师&#xff0c;业界也笑称“运维Boy”&#xff0c;日常工作就是在各个机器上部署以及维护服务&#xff0c;例如部署Hadoop、Kafka、Pulsar这些等等&#xff0c;用于给公…

使用python将作图并将局部放大

此程序主要特点&#xff1a; 1、使用python画实验结果图 2、想要对大图的局部进行放大 3、有两个子图 4、子图和原图的横坐标都使用标签而不是原始的数据 代码和注释如下&#xff1a; import pandas as pd import numpy as np import matplotlib.pyplot as plt import ope…

BCLinux-for-Euler配置本地yum源

稍微吐槽一句…… 在这片土地上&#xff0c;国产化软件的大潮正在滚滚而来&#xff0c;虽然都不是真正意义上的国产化&#xff0c;但是至少壳是国产的~~~ 之前使用的Centos7的系统&#xff0c;现在都要求统一换成BCLinux-for-Euler。说实话换了之后不太适应&#xff0c;好多用习…

COCO格式转YOLO格式训练

之前就转换过好几次&#xff0c;每次换设备训练&#xff0c;由于压缩包太大&#xff0c;u盘不够用。每次都要找教程从网上再下载一遍。因此这里记录一下&#xff0c;以免下次重新找教程。 在coco数据集中&#xff0c;coco2017train或coco2017val数据集中标注的目标(类别)位置在…

Spring 详细总结

文章目录 第一章 IOC容器第一节 Spring简介1、一家公司2、Spring旗下的众多项目3、Spring Framework①Spring Framework优良特性②Spring Framework五大功能模块 第二节 IOC容器概念1、普通容器①生活中的普通容器②程序中的普通容器 2、复杂容器①生活中的复杂容器②程序中的复…

传输层 --- UDP

目录 1. 传输层是什么呢&#xff1f; 2. 再谈端口号 2.1. 端口号是什么 2.2. 协议号是什么 2.3. 认识知名端口号 2.4. 端口号的相关问题 2.4.1. 一个进程可以绑定多个端口号吗&#xff1f; 2.4.2. 一个端口号可以被多个进程绑定吗&#xff1f; 2.4.3. 为什么不使用P…

数据结构进阶篇 之 【并归排序】(递归与非递归实现)详细讲解

都说贪小便宜吃大亏&#xff0c;但吃亏是福&#xff0c;那不就是贪小便宜吃大福了吗 一、并归排序 MergeSort 1.基本思想 2.实现原理 3.代码实现 4.归并排序的特性总结 二、非递归并归排序实现 三、完结撒❀ –❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀…

如何使用单片机 pwm 控制 mos 管?

目录 选择适合的硬件 连接电路 编写代码 参考示例 程序一 程序二 测试与调试 注意事项 使用单片机&#xff08;如常见的Arduino、STM32等&#xff09;通过PWM&#xff08;脉冲宽度调制&#xff09;控制MOS管&#xff08;金属氧化物半导体场效应管&#xff09;是一种常见…

Java中的集合(二)

一、回顾上期 上一篇讲到在Java中&#xff0c;集合和容器是非常重要的概念&#xff0c;用于存储和操作数据。在集合中&#xff0c;有单列集合和双列集合两种类型。我们在上一篇将单列集合中的list类讲完了&#xff0c;这一篇将会将集合中剩余部分介绍完&#xff0c;话不多说&am…