MMROTATE 1.X特征图可视化(绘制Heat Map)

news2024/9/23 6:01:02

本文参考MMYOLO官方的特征图可视化教程,对MMROTATE相关算法进行特征图可视化

1. 新建featmap_vis_demo.py文件

在mmrotate项目文件夹下新建 featmap_vis_demo.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
from typing import Sequence

import mmcv
from mmdet.apis import inference_detector, init_detector
from mmengine import Config, DictAction
from mmengine.registry import init_default_scope
from mmengine.utils import ProgressBar

from mmrotate.registry import VISUALIZERS
from mmrotate.utils.misc import auto_arrange_images, get_file_list


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize feature map')
    parser.add_argument(
        'img', help='Image path, include image file, dir and URL.')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-dir', default='./output', help='Path to output file')
    parser.add_argument(
        '--target-layers',
        default=['backbone'],
        nargs='+',
        type=str,
        help='The target layers to get feature map, if not set, the tool will '
        'specify the backbone')
    parser.add_argument(
        '--preview-model',
        default=False,
        action='store_true',
        help='To preview all the model layers')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.3, help='Bbox score threshold')
    parser.add_argument(
        '--show', action='store_true', help='Show the featmap results')
    parser.add_argument(
        '--channel-reduction',
        default='select_max',
        help='Reduce multiple channels to a single channel')
    parser.add_argument(
        '--topk',
        type=int,
        default=4,
        help='Select topk channel to show by the sum of each channel')
    parser.add_argument(
        '--arrangement',
        nargs='+',
        type=int,
        default=[2, 2],
        help='The arrangement of featmap when channel_reduction is '
        'not None and topk > 0')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()
    return args


class ActivationsWrapper:

    def __init__(self, model, target_layers):
        self.model = model
        self.activations = []
        self.handles = []
        self.image = None
        for target_layer in target_layers:
            self.handles.append(
                target_layer.register_forward_hook(self.save_activation))

    def save_activation(self, module, input, output):
        self.activations.append(output)

    def __call__(self, img_path):
        self.activations = []
        results = inference_detector(self.model, img_path)
        return results, self.activations

    def release(self):
        for handle in self.handles:
            handle.remove()


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    init_default_scope(cfg.get('default_scope', 'mmyolo'))

    channel_reduction = args.channel_reduction
    if channel_reduction == 'None':
        channel_reduction = None
    assert len(args.arrangement) == 2

    model = init_detector(args.config, args.checkpoint, device=args.device)

    if not os.path.exists(args.out_dir) and not args.show:
        os.mkdir(args.out_dir)

    if args.preview_model:
        print(model)
        print('\n This flag is only show model, if you want to continue, '
              'please remove `--preview-model` to get the feature map.')
        return

    target_layers = []
    for target_layer in args.target_layers:
        try:
            target_layers.append(eval(f'model.{target_layer}'))
        except Exception as e:
            print(model)
            raise RuntimeError('layer does not exist', e)

    activations_wrapper = ActivationsWrapper(model, target_layers)

    # init visualizer
    visualizer = VISUALIZERS.build(model.cfg.visualizer)
    visualizer.dataset_meta = model.dataset_meta

    # get file list
    image_list, source_type = get_file_list(args.img)

    progress_bar = ProgressBar(len(image_list))
    for image_path in image_list:
        result, featmaps = activations_wrapper(image_path)
        if not isinstance(featmaps, Sequence):
            featmaps = [featmaps]

        flatten_featmaps = []
        for featmap in featmaps:
            if isinstance(featmap, Sequence):
                flatten_featmaps.extend(featmap)
            else:
                flatten_featmaps.append(featmap)

        img = mmcv.imread(image_path)
        img = mmcv.imconvert(img, 'bgr', 'rgb')

        if source_type['is_dir']:
            filename = os.path.relpath(image_path, args.img).replace('/', '_')
        else:
            filename = os.path.basename(image_path)
        out_file = None if args.show else os.path.join(args.out_dir, filename)

        # show the results
        shown_imgs = []
        visualizer.add_datasample(
            'result',
            img,
            data_sample=result,
            draw_gt=False,
            show=False,
            wait_time=0,
            out_file=None,
            pred_score_thr=args.score_thr)
        drawn_img = visualizer.get_image()

        for featmap in flatten_featmaps:
            shown_img = visualizer.draw_featmap(
                featmap[0],
                drawn_img,
                channel_reduction=channel_reduction,
                topk=args.topk,
                arrangement=args.arrangement)
            shown_imgs.append(shown_img)

        shown_imgs = auto_arrange_images(shown_imgs)

        progress_bar.update()
        if out_file:
            mmcv.imwrite(shown_imgs[..., ::-1], out_file)
            print(f'{out_file} has been saved')

        if args.show:
            visualizer.show(shown_imgs)

    if not args.show:
        print(f'All done!'
              f'\nResults have been saved at {os.path.abspath(args.out_dir)}')


# Please refer to the usage tutorial:
# https://github.com/open-mmlab/mmyolo/blob/main/docs/zh_cn/user_guides/visualization.md # noqa
if __name__ == '__main__':
    main()

2. 修改或替换mmrotate的misc.py文件

mmrotate-1.x/mmrotate/utils/misc.py , 删除里面的内容,填入以下内容:
在这里插入图片描述

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

from mmengine.config import Config, ConfigDict

# Copyright (c) OpenMMLab. All rights reserved.
import os
import urllib

import numpy as np
import torch
from mmengine.utils import scandir

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif',
                  '.tiff', '.webp')


def get_test_pipeline_cfg(cfg: Union[str, ConfigDict]) -> ConfigDict:
    """Get the test dataset pipeline from entire config.

    Args:
        cfg (str or :obj:`ConfigDict`): the entire config. Can be a config
            file or a ``ConfigDict``.

    Returns:
        :obj:`ConfigDict`: the config of test dataset.
    """
    if isinstance(cfg, str):
        cfg = Config.fromfile(cfg)

    def _get_test_pipeline_cfg(dataset_cfg):
        if 'pipeline' in dataset_cfg:
            return dataset_cfg.pipeline
        # handle dataset wrapper
        elif 'dataset' in dataset_cfg:
            return _get_test_pipeline_cfg(dataset_cfg.dataset)
        # handle dataset wrappers like ConcatDataset
        elif 'datasets' in dataset_cfg:
            return _get_test_pipeline_cfg(dataset_cfg.datasets[0])

        raise RuntimeError('Cannot find `pipeline` in `test_dataloader`')

    return _get_test_pipeline_cfg(cfg.test_dataloader.dataset)



def auto_arrange_images(image_list: list, image_column: int = 2) -> np.ndarray:
    """Auto arrange image to image_column x N row.

    Args:
        image_list (list): cv2 image list.
        image_column (int): Arrange to N column. Default: 2.
    Return:
        (np.ndarray): image_column x N row merge image
    """
    img_count = len(image_list)
    if img_count <= image_column:
        # no need to arrange
        image_show = np.concatenate(image_list, axis=1)
    else:
        # arrange image according to image_column
        image_row = round(img_count / image_column)
        fill_img_list = [np.ones(image_list[0].shape, dtype=np.uint8) * 255
                         ] * (
                             image_row * image_column - img_count)
        image_list.extend(fill_img_list)
        merge_imgs_col = []
        for i in range(image_row):
            start_col = image_column * i
            end_col = image_column * (i + 1)
            merge_col = np.hstack(image_list[start_col:end_col])
            merge_imgs_col.append(merge_col)

        # merge to one image
        image_show = np.vstack(merge_imgs_col)

    return image_show


def get_file_list(source_root: str) -> [list, dict]:
    """Get file list.

    Args:
        source_root (str): image or video source path

    Return:
        source_file_path_list (list): A list for all source file.
        source_type (dict): Source type: file or url or dir.
    """
    is_dir = os.path.isdir(source_root)
    is_url = source_root.startswith(('http:/', 'https:/'))
    is_file = os.path.splitext(source_root)[-1].lower() in IMG_EXTENSIONS

    source_file_path_list = []
    if is_dir:
        # when input source is dir
        for file in scandir(
                source_root, IMG_EXTENSIONS, recursive=True,
                case_sensitive=False):
            source_file_path_list.append(os.path.join(source_root, file))
    elif is_url:
        # when input source is url
        filename = os.path.basename(
            urllib.parse.unquote(source_root).split('?')[0])
        file_save_path = os.path.join(os.getcwd(), filename)
        print(f'Downloading source file to {file_save_path}')
        torch.hub.download_url_to_file(source_root, file_save_path)
        source_file_path_list = [file_save_path]
    elif is_file:
        # when input source is single image
        source_file_path_list = [source_root]
    else:
        print('Cannot find image file.')

    source_type = dict(is_dir=is_dir, is_url=is_url, is_file=is_file)

    return source_file_path_list, source_type

3. 输出热力图

首先要配置好并且切换到mmrotate虚拟环境,然后运行以下命令:

这里是引用

python featmap_vis_demo.py <path to your photo> \
                    <path to your config file> \
                    <path to your weight file> \
                    --target-layers <想要输出特征图的位置 backbone or neck ...> \
                    --channel-reduction select_max \
                    --out-dir '<path to your output dir>'

具体的例子为:

python featmap_vis_demo.py demo/heatMap.png \
                    configs/rotated_rtmdet_tiny-9x-hrsc.py \
                    weights/155647.pth \
                    --target-layers neck \
                    --channel-reduction select_max \
                    --out-dir 'output'

然后就可以在output文件夹中看到输出的热力图:
在这里插入图片描述

参考MMYOLO教程:

玩转 MMYOLO 之工具篇(一):特征图可视化

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2156787.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

java intellij idea开发步骤,使用指南,工程创建与背景色字体配置,快捷键

intellij idea2021 配置背景色&#xff0c;字体大小&#xff0c;主题 快捷键

STM32系统时钟

时钟为单片机提供了稳定的机器周期&#xff0c;从而使我们的系统能够正常的运行 时钟就像我们人的心脏&#xff0c;一旦有问题就整个都会崩溃 stm32有很多外设&#xff0c;但不是所有的外设都使用同一种时钟频率工作&#xff0c;比如我们的内部看门狗和RTC 只要30几k的频率就…

计算机毕业设计推荐-基于python的白酒销售数据可视化分析

精彩专栏推荐订阅&#xff1a;在下方主页&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; &#x1f496;&#x1f525;作者主页&#xff1a;计算机毕设木哥&#x1f525; &#x1f496; 文章目录 一、白酒销售数据…

react:React Hook函数

使用规则 只能在组件中或者其他自定义的Hook函数中调用 只能在组件的顶层调用&#xff0c;不能嵌套在if、for、 其他函数中 基础Hook 函数 useState useState是一个hook函数&#xff0c;它允许我们向组件中添加一个状态变量&#xff0c;从而控制影响组件的渲染结果 示例1…

[Excel VBA]如何使用VBA自动生成图表

在Excel中&#xff0c;图表是可视化数据的重要工具。以下是一个VBA代码示例&#xff0c;帮助你自动生成图表。 1. 代码说明 该代码会根据指定数据范围创建一个柱状图&#xff0c;并设置图表的基本属性。 2. VBA代码 Sub CreateChart()Dim ws As WorksheetDim chartObj As Ch…

百度营销转化追踪(网页JS布码)

引言&#xff1a;使用百度营销api配置网站上各个模块组件的转化追踪&#xff0c;统计网站上的各组件模块点击等信息。 一、选择接入方式&#xff08;本文选择的是网页JS布码&#xff09; 参考文档&#xff1a;百度营销-商业开发者中心百度开发者中心是一个面向开发者的知识分享…

Java启动Tomcat: Can‘t load IA 32-bit .dll on a AMD 64-bit platform报错问题解决

&#x1f3ac; 鸽芷咕&#xff1a;个人主页 &#x1f525; 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想&#xff0c;就是为了理想的生活! 专栏介绍 在软件开发和日常使用中&#xff0c;BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

Java-数据结构-排序-(二) (๑¯∀¯๑)

文本目录&#xff1a; ❄️一、交换排序&#xff1a; ➷ 1、 冒泡排序&#xff1a; ▶ 代码&#xff1a; ➷ 2、 快速排序&#xff1a; ☞ 基本思想&#xff1a; ☞ 方法一&#xff1a;Hoare法 ▶ 代码&#xff1a; ☞ 方法二&#xff1a;挖坑法 ▶ 代码&#xff1a; ☞ 方法三…

GNU编译器(GCC):编译的4个过程及.elf、.list、.map文件功能说明

0 参考资料 GNU-LD-v2.30-中文手册.pdf GNU linker.pdf1 前言 一个完整的编译工具链应该包含以下4个部分&#xff1a; &#xff08;1&#xff09;编译器 &#xff08;2&#xff09;汇编器 &#xff08;3&#xff09;链接器 &#xff08;4&#xff09;lib库 在GNU工具链中&…

Linux-文件的压缩、解压

Linux系统常见有两种压缩格式&#xff0c;后缀分别是&#xff1a; .tar 称之为tarball&#xff0c;简单的将文件组装到一个.tar的文件内&#xff0c;并没有太多的文件体积减少&#xff0c;仅仅是简单的封装.gz gzip格式压缩文件&#xff0c;可以极大的减少压缩后的体积 针对这…

Lua中..和...的使用区别

一. .. 的用法 二. ... 的用法 在 Lua 中&#xff0c;... 是一个特殊符号&#xff0c;它用于表示不定数量的参数。当你在函数定义或调用中使用 ... 时&#xff0c;它可以匹配任意数量的参数&#xff0c;并将它们作为列表传递。在您的代码示例中&am…

基于SSD的RAG技术方案,推动LLM规模扩展

随着大型语言模型&#xff08;LLM&#xff09;的不断发展&#xff0c;它们在虚拟助手、聊天机器人和对话系统等应用中发挥着重要作用。然而&#xff0c;LLM面临的挑战之一是它们可能会生成虚假或误导性的信息&#xff0c;即所谓的“幻觉”。为了解决这一问题&#xff0c;检索增…

Java数据库连接——JDBC

目录 1、JDBC简介 2、JDBC应用 2.1 建立数据库连接 2.1.1 DriverManager静态方法获取连接 2.1.2 DataSource对象获取 2.2 获取SQL执行对象 2.2.1 SQL注入 2.2.2 Statement(执行静态SQL) 2.2.3 PreparedStatement(预处理的SQL执行对象) 2.3 执行SQL并返回结果 2.4 关…

Error when custom data is added to Azure OpenAI Service Deployment

题意&#xff1a;在向 Azure OpenAI 服务部署添加自定义数据时出现错误。 问题背景&#xff1a; I receive the following error when adding my custom data which is a .txt file (it doesnt matter whether I add it via Azure Cognitive Search, Azure Blob Storage, or F…

证书学习(五)Java实现RSA、SM2证书颁发

目录 一、知识回顾1.1 X.509 证书1.2 X509Certificate 类二、代码实现2.1 Maven 依赖2.2 RSA 证书颁发1)PfxGenerateUtil 证书文件生成工具类2)CertDTO 证书中间类3)RSACertGenerateTest RSA证书生成测试类4)执行结果2.3 SM2 证书颁发1)SM2Utils 国密SM2算法工具类2)SM2C…

查询一条 SQL 语句的流程

查询一条sql语句的流程 连接器:建立连接&#xff0c;管理连接、校验用户身份查询缓存:查询语句如果命中查询缓存则直接返回&#xff0c;否则继续往下执行&#xff08;MSQL8.0 已删除&#xff09;解析 SQL&#xff1a;通过解析器对 SQL 查询语句进行词法分析、语法分析&#xf…

【RH124】解释Linux文件系统权限

RH124教材中控制对文件的访问一章中有一道解释Linux文件系统权限的测验题&#xff0c;可以一起来看看&#xff1a; 一、权限解释 这是通过 ls -l 命令查看的结果。它显示了文件或目录的权限、拥有者、所属组等信息。 1、长列表的第一个字符表示文件类型&#xff1a; -是常…

(done) 声音信号处理基础知识(6) (How to Extract Audio Features)

参考&#xff1a;https://www.youtube.com/watch?v8A-W1xk7qs8&t2s 先复习之前分类的声学特征 时域特征流水线 如下是 441Khz 下一个采样点播放的时间。这比人类耳朵分辨率(10ms)还低。 所以&#xff0c;把多个采样点组合成一个 frame 的原因有&#xff0c;这是一个人…

计算机的错误计算(一百零一)

摘要 展示 在0附近数的函数值的计算精度问题。 计算机的错误计算&#xff08;一百&#xff09;探讨了 在一般情形下的计算精度问题。本节讨论其在0附近的数的函数值的计算精度问题。 例1. 已知 计算 不妨在Python 3.12.5下计算&#xff0c;则有 若在线运行R代码&#x…

阿⾥编码规范⾥⾯Manager分层介绍-专⽤名词和POJO实体类约定

开发⼈员&#xff1a;张三、李四、王五 ⼀定要避免单点故障 ⼀个微服务起码两个⼈熟悉&#xff1a;⼀个是主程⼀个是技术leader 推荐是团队⾥⾯两个开发⼈员 N⽅库说明 ⼀⽅库: 本⼯程内部⼦项⽬模块依赖的库(jar 包)⼆⽅库: 公司内部发布到中央仓库&#xff0c;可供公司…