使用自己的数据集,测试mmrotate新网络rotated_rtmdet,旋转目标检测

news2024/12/23 17:59:54

1.安装

!!!!一定不要安装mmrotate
1.版本需求
mmcv 2.0.0rc2
mmdet 3.0.0rc5
mmengine 0.5.0
不用安装mmcv-full

  1. 下载mmrotate 1.x 源码 (不要下载默认的master,因为新的网络只在1.x版本中)

2.制作数据集

因为需要的是dota格式的数据集
DOTA 格式的注解 txt 文件:

184 2875 193 2923 146 2932 137 2885 plane 0
66 2095 75 2142 21 2154 11 2107 plane 0
...

每行代表一个对象,并将其记录为一个 10 维数组 A

  • A[0:8]: 多边形的格式 (x1, y1, x2, y2, x3, y3, x4, y4)
  • A[8]: 类别
  • A[9]: 困难

我的数据是用labelme标注的json文件,所以先把json转txt

import json
import os

file_path='hailuo'
json_names=os.listdir(file_path)

for name in json_names:
    if name.endswith('json'):
        json_path=os.path.join(file_path,name)

        path2 = name.split('.')[0]+'.txt'
        # path2 = '1.txt'
        file2 = open(path2, 'w+')
        with open(json_path) as f:
            data = json.load(f)
            for box in data["shapes"]:
                _one=box['points']
                res = [str(x) for j in _one for x in j]
                if len(res)==10:
                    continue
                rest =' '.join(res)
                rest=rest+' box 0\n'
                file2.write(rest)

把所有的txt文件放在annfiles文件夹,所有的图片放在images文件夹
!

3. 修改配置文件

  1. G:\research\mmrotate-1.x\configs\rotated_rtmdet_base_\dota_rr.py
# dataset settings
dataset_type = 'DOTADataset'
data_root = '../data/'    # 路径
file_client_args = dict(backend='disk')

train_pipeline = [
    dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(
        type='mmdet.RandomFlip',
        prob=0.75,
        direction=['horizontal', 'vertical', 'diagonal']),
    dict(
        type='RandomRotate',
        prob=0.5,
        angle_range=180,
        rect_obj_labels=[9, 11]),
    dict(type='mmdet.PackDetInputs')
]
val_pipeline = [
    dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
    # avoid bboxes being resized
    dict(type='mmdet.LoadAnnotations', with_bbox=True, box_type='qbox'),
    dict(type='ConvertBoxType', box_type_mapping=dict(gt_bboxes='rbox')),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'
                   ))  #  删除了'scale_factor'
] 
test_pipeline = [
    dict(type='mmdet.LoadImageFromFile', file_client_args=file_client_args),
    dict(
        type='mmdet.PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape'))  #  删除了'scale_factor'
]
train_dataloader = dict(
    batch_size=2,
    num_workers=8,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    batch_sampler=None,
    pin_memory=False,
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='trainval/annfiles/',  # txt路径
        data_prefix=dict(img_path='trainval/images/'),  # 图片路径
        img_shape=(512, 640),
        filter_cfg=dict(filter_empty_gt=True),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='trainval/annfiles/',   # txt路径
        data_prefix=dict(img_path='trainval/images/'),   # 图片路径
        img_shape=(512, 640),
        test_mode=True,
        pipeline=val_pipeline))
test_dataloader = val_dataloader

val_evaluator = dict(type='DOTAMetric', metric='mAP')
test_evaluator = val_evaluator
  1. G:\research\mmrotate-1.x\mmrotate\datasets\dota.py
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os.path as osp
from typing import List, Tuple

from mmengine.dataset import BaseDataset

from mmrotate.registry import DATASETS


@DATASETS.register_module()
class DOTADataset(BaseDataset):
    """DOTA-v1.0 dataset for detection.

    Note: ``ann_file`` in DOTADataset is different from the BaseDataset.
    In BaseDataset, it is the path of an annotation file. In DOTADataset,
    it is the path of a folder containing XML files.

    Args:
        img_shape (tuple[int]): The shape of images. Due to the huge size
            of the remote sensing image, we will cut it into slices with
            the same shape. Defaults to (1024, 1024).
        diff_thr (int): The difficulty threshold of ground truth. Bboxes
            with difficulty higher than it will be ignored. The range of this
            value should be non-negative integer. Defaults to 100.
    """

    METAINFO = {
        'classes':
        ('box',),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(165, 42, 42),]
    }  #  class 修改

    def __init__(self,
                 img_shape: Tuple[int, int] = (1024, 1024),
                 diff_thr: int = 100,
                 **kwargs) -> None:
        self.img_shape = img_shape
        self.diff_thr = diff_thr
        super().__init__(**kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from an annotation file named as ``self.ann_file``
        Returns:
            List[dict]: A list of annotation.
        """  # noqa: E501
        cls_map = {c: i
                   for i, c in enumerate(self.metainfo['classes'])
                   }  # in mmdet v2.0 label is 0-based
        data_list = []
        if self.ann_file == '':
            img_files = glob.glob(
                osp.join(self.data_prefix['img_path'], '*.png'))
            for img_path in img_files:
                data_info = {}
                data_info['img_path'] = img_path
                img_name = osp.split(img_path)[1]
                data_info['file_name'] = img_name
                img_id = img_name[:-4]
                data_info['img_id'] = img_id
                data_info['height'] = self.img_shape[0]
                data_info['width'] = self.img_shape[1]
                instance = dict(bbox=[], bbox_label=[], ignore_flag=0)
                data_info['instances'] = [instance]
                data_list.append(data_info)

            return data_list
        else:
            txt_files = glob.glob(osp.join(self.ann_file, '*.txt'))
            if len(txt_files) == 0:
                raise ValueError('There is no txt file in '
                                 f'{self.ann_file}')
            for txt_file in txt_files:
                data_info = {}
                img_id = osp.split(txt_file)[1][:-4]
                data_info['img_id'] = img_id
                img_name = img_id + '.png'
                data_info['file_name'] = img_name
                data_info['img_path'] = osp.join(self.data_prefix['img_path'],
                                                 img_name)
                data_info['height'] = self.img_shape[0]
                data_info['width'] = self.img_shape[1]

                instances = []
                with open(txt_file) as f:
                    s = f.readlines()
                    for si in s:
                        instance = {}
                        bbox_info = si.split()
                        instance['bbox'] = [float(i) for i in bbox_info[:8]]
                        cls_name = bbox_info[8]
                        if len(cls_name)>3:
                            print(11)
                            continue
                        instance['bbox_label'] = cls_map[cls_name]
                        difficulty = int(bbox_info[9])
                        if difficulty > self.diff_thr:
                            instance['ignore_flag'] = 1
                        else:
                            instance['ignore_flag'] = 0
                        instances.append(instance)
                data_info['instances'] = instances
                data_list.append(data_info)

            return data_list

    def filter_data(self) -> List[dict]:
        """Filter annotations according to filter_cfg.

        Returns:
            List[dict]: Filtered results.
        """
        if self.test_mode:
            return self.data_list

        filter_empty_gt = self.filter_cfg.get('filter_empty_gt', False) \
            if self.filter_cfg is not None else False

        valid_data_infos = []
        for i, data_info in enumerate(self.data_list):
            if filter_empty_gt and len(data_info['instances']) == 0:
                continue
            valid_data_infos.append(data_info)

        return valid_data_infos

    def get_cat_ids(self, idx: int) -> List[int]:
        """Get DOTA category ids by index.

        Args:
            idx (int): Index of data.
        Returns:
            List[int]: All categories in the image of specified index.
        """

        instances = self.get_data_info(idx)['instances']
        return [instance['bbox_label'] for instance in instances]


@DATASETS.register_module()
class DOTAv15Dataset(DOTADataset):
    """DOTA-v1.5 dataset for detection.

    Note: ``ann_file`` in DOTAv15Dataset is different from the BaseDataset.
    In BaseDataset, it is the path of an annotation file. In DOTAv15Dataset,
    it is the path of a folder containing XML files.
    """

    METAINFO = {
        'classes':
        ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
         'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
         'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
         'harbor', 'swimming-pool', 'helicopter', 'container-crane'),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
                    (138, 43, 226), (255, 128, 0), (255, 0, 255),
                    (0, 255, 255), (255, 193, 193), (0, 51, 153),
                    (255, 250, 205), (0, 139, 139), (255, 255, 0),
                    (147, 116, 116), (0, 0, 255), (220, 20, 60)]
    }


@DATASETS.register_module()
class DOTAv2Dataset(DOTADataset):
    """DOTA-v2.0 dataset for detection.

    Note: ``ann_file`` in DOTAv2Dataset is different from the BaseDataset.
    In BaseDataset, it is the path of an annotation file. In DOTAv2Dataset,
    it is the path of a folder containing XML files.
    """

    METAINFO = {
        'classes':
        ('plane', 'baseball-diamond', 'bridge', 'ground-track-field',
         'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
         'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
         'harbor', 'swimming-pool', 'helicopter', 'container-crane', 'airport',
         'helipad'),
        # palette is a list of color tuples, which is used for visualization.
        'palette': [(165, 42, 42), (189, 183, 107), (0, 255, 0), (255, 0, 0),
                    (138, 43, 226), (255, 128, 0), (255, 0, 255),
                    (0, 255, 255), (255, 193, 193), (0, 51, 153),
                    (255, 250, 205), (0, 139, 139), (255, 255, 0),
                    (147, 116, 116), (0, 0, 255), (220, 20, 60), (119, 11, 32),
                    (0, 0, 142)]
    }

  1. G:\research\mmrotate-1.x\configs\rotated_rtmdet\rotated_rtmdet_l-3x-dota_ms.py
    num_classes=1 修改
_base_ = [
    './_base_/default_runtime.py', './_base_/schedule_3x.py',
    './_base_/dota_rr_ms.py'
]
checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-l_8xb256-rsb-a1-600e_in1k-6a760974.pth'  # noqa

angle_version = 'le90'
model = dict(
    type='mmdet.RTMDet',
    data_preprocessor=dict(
        type='mmdet.DetDataPreprocessor',
        mean=[103.53, 116.28, 123.675],
        std=[57.375, 57.12, 58.395],
        bgr_to_rgb=False,
        boxtype2tensor=False,
        batch_augments=None),
    backbone=dict(
        type='mmdet.CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=1,
        widen_factor=1,
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained', prefix='backbone.', checkpoint=checkpoint)),
    neck=dict(
        type='mmdet.CSPNeXtPAFPN',
        in_channels=[256, 512, 1024],
        out_channels=256,
        num_csp_blocks=3,
        expand_ratio=0.5,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU')),
    bbox_head=dict(
        type='RotatedRTMDetSepBNHead',
        num_classes=1,
        in_channels=256,
        stacked_convs=2,
        feat_channels=256,
        angle_version=angle_version,
        anchor_generator=dict(
            type='mmdet.MlvlPointGenerator', offset=0, strides=[8, 16, 32]),
        bbox_coder=dict(
            type='DistanceAnglePointCoder', angle_version=angle_version),
        loss_cls=dict(
            type='mmdet.QualityFocalLoss',
            use_sigmoid=True,
            beta=2.0,
            loss_weight=1.0),
        loss_bbox=dict(type='RotatedIoULoss', mode='linear', loss_weight=2.0),
        with_objectness=False,
        exp_on_reg=True,
        share_conv=True,
        pred_kernel_size=1,
        use_hbbox_loss=False,
        scale_angle=False,
        loss_angle=None,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU')),
    train_cfg=dict(
        assigner=dict(
            type='mmdet.DynamicSoftLabelAssigner',
            iou_calculator=dict(type='RBboxOverlaps2D'),
            topk=13),
        allowed_border=-1,
        pos_weight=-1,
        debug=False),
    test_cfg=dict(
        nms_pre=2000,
        min_bbox_size=0,
        score_thr=0.05,
        nms=dict(type='nms_rotated', iou_threshold=0.1),
        max_per_img=2000),
)

# batch_size = (2 GPUs) x (4 samples per GPU) = 8
train_dataloader = dict(batch_size=4, num_workers=4)

4. 报错

  1. 找不到’scale_factor’
    除了删除配置文件的 这个配置外还要修改 其他关于scale_factor的错误,因为修改了后可以运行,所以忘记修改了哪里,但是效果炸裂

5.效果

左边的是标注的,右边的是检测结果
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
完全可以直接应用!!!!!!!!!!!!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/340052.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【前端】Vue项目:旅游App-(23)detail:房东介绍、热门评论、预定须知组件

文章目录目标过程与代码房东介绍landlord热门评论HotComment预定须知Book效果总代码修改或添加的文件detail.vuedetail-book.vuedetail-hotComment.vuedetail-landlord.vue参考本项目博客总结:【前端】Vue项目:旅游App-博客总结 目标 根据detail页面获…

我用python/C++调用ChatGPT自制了一个聊天问答机器人

目录1 ChatGPT完整版2 Python/C调用ChatGPT2.1 获取API秘钥2.2 测试API功能2.3 设计简单UI3 聊天问答1 ChatGPT完整版 2015年,OpenAI由马斯克、美国创业孵化器Y Combinator总裁阿尔特曼、全球在线支付平台PayPal联合创始人彼得蒂尔等硅谷科技大亨创立,公…

【脚本开发】运维人员必备技能图谱

脚本(Script)语言是一种动态的、解释性的语言,依据一定的格式编写的可执行文件,又称作宏或批处理文件。脚本语言具有小巧便捷、快速开发的特点;常见的脚本语言有Windows批处理脚本bat、Linux脚本语言shell以及python、…

Spring缓存Demo

Spring中的缓存用法:有说错请指正 启动类加EnableCache 注解很多,这里举例几个实用的 第一组 value和key都没什么特别的含义,随你自己取,注意key里面是包了一层的 Cacheable(value"user",key "findUsers") 第一次查询的时候,会查数据库,然后将返回结果…

【GlobalMapper精品教程】045:空间分析工具(2)——相交

GlobalMapper提供的空间分析(操作)的方法有:交集、并集、单并集、差异、对称差集、相交、重叠、接触、包含、等于、内部、分离等,本文主要讲述相交工具的使用。 文章目录 一、实验数据二、符号化设置三、相交运算四、结果展示五、心灵感悟一、实验数据 加载配套实验数据(…

Hadoop安装 --- 简易安装Hadoop

目录 1、使用xftp工具 在opt目录下创建install和soft文件 2、使用xftp工具 将压缩包上传到install文件 3、编写shell脚本 3.1、创建目录来放shell脚本 3.2、创建autoinsatll.sh文件并修改权限 3.3、编写autoinsatll.sh 文件 刷新资源 运行文件 格式化 启动所有进程 Ha…

ChatGPT到底有多牛?博主带你亲测

文章目录论文项目代码算法学习情感职业回答知乎ChatGpt网页版与客户端版个人评价论文 问他毕设框架: 让他帮我写一段毕设背景部分: 项目代码 我让他帮我用Django写一个demo网站: 算法 matlab写遗传算法: 问一个数据结构&…

Java是如何创建线程的(二)从glibc到kernel thread

Java是如何创建线程的(二)从glibc到kernel thread 背景 上一节我们讨论了java线程是如何创建的,看了看从java代码层面到jvm层面的源码里都干了什么。 整个流程还是比较复杂的,我将上一节总结的调用时序图贴在下面,方…

Flutter Widget - Container 容器

Container 属性 child 容器包含的子组件color 容器背景色padding 内边距margin 外边距decoration 定义容器形状、颜色alignment 子组件在容器内的对齐方式constraints 定义宽高width 宽(可选)height 高(可选)transform 变换transformAlignment 变换原点的相对位置foregroundDe…

【Unity3D 常用插件】Haste插件

一,Haste介绍 Haste插件是一款针对 Unity 3D 的 Everthing软件,可以实现基于名称快速定位对象的功能。Unity 3D 编辑器也自带了搜索功能,但是在 project视图 和 Hierarchy视图 中的对象需要分别查找,不支持模糊匹配。Haste插件就…

MySQL-InnoDB数据页结构浅析

在MySQL-InnoDB行格式浅析中,们简单提了一下 页 的概念,它是 InnoDB 管理存储空间的基本单位,一个页的大小一般是 16KB 。 InnoDB 为了不同的目的而设计了许多种不同类型的 页: 存放表空间头部信息的页存放 Insert Buffer信息的…

Maven专题总结

1. 什么是Maven Maven 是一个项目管理工具,它包含了一个项目对象模型 (POM: Project Object Model),一组标准集合,一个项目生命周期(Project Lifecycle),一个依赖管理系统(Dependency Management System),和…

3分钟,学会了一个调试CSS的小妙招

Ⅰ. 作用 用于调试CSS , 比控制台添更加方便,不需要寻找 ;边添加样式,边可以查看效果,适合初学者对CSS 的理解和学习; Ⅱ. 快速实现(两边) ① 显示这个样式眶 给 head 和 style 标签添加一个…

Power BI 筛选器函数---Window实例详解

一、Window函数 语法&#xff1a; Window ( <起始位置>,<起始位置类型>,<结束位置>,<结束位置类型>, [<关系>], [<OrderBy>],[空白],[PartitionBy] ) 含义&#xff1a; 对指定分区&#xff08;PartitioinBy)中的行&#xff08;关系表&…

getchar()的用法

getchar的功能 它的作用是从stdin流中读入一个字符&#xff0c;也就是说&#xff0c;如果stdin有数据的话不用输入它就可以直接读取了&#xff0c;第一次getchar()时&#xff0c;确实需要人工的输入&#xff0c;但是如果你输了多个字符&#xff0c;以后的getchar()再执行时就会…

python+django高校师生健康信息管理系统pycharm

管理员功能模块 4.1登录页面 管理员登录&#xff0c;通过填写注册时输入的用户名、密码、角色进行登录&#xff0c;如图所示。 4.2系统首页 管理员登录进入师生健康信息管理系统可以查看个人中心、学生管理、教师管理、数据收集管理、问卷分类管理、疫情问卷管理、问卷调查管理…

大数据框架之Hadoop:HDFS(一)HDFS概述

1.1HDFS产出背景及定义 HDFS 产生背景 随着数据量越来越大&#xff0c;在一个操作系统存不下所有的数据&#xff0c;那么就分配到更多的操作系统管理的磁盘中&#xff0c;但是不方便管理和维护&#xff0c;迫切需要一种系统来管理多台机器上的文件&#xff0c;这就是分布式文件…

查缺补漏三:事务隔离级别

什么是事务&#xff1f; 事务就是一组操作的集合&#xff0c;事务将整组操作作为一个整体&#xff0c;共同提交或者共同撤销 这些操作只能同时成功或者同时失败&#xff0c;成功即可提交事务&#xff0c;失败就执行事务回滚 MySQL的事务默认是自动提交的&#xff0c;一条语句执…

【LeetCode第 332 场周赛】

传送门 文章目录6354. 找出数组的串联值6355. 统计公平数对的数目6356. 子字符串异或查询6357. 最少得分子序列6354. 找出数组的串联值 题目 思路 前后指针 代码 class Solution { public:long long findTheArrayConcVal(vector<int>& nums) {long long res 0;i…

多线程相关面试题

讲解下你自己理解的 CAS 机制 ? 全称 Compare and swap, 即 “比较并交换”. 相当于通过一个原子的操作, 同时完成 “读取内存, 比较是否相等, 修改内存” 这三个步骤. 本质上需要 CPU 指令的支撑. ABA问题怎么解决&#xff1f; 给要修改的数据引入版本号. 在 CAS 比较数据…