深度学习系列53:mmdetection上手

news2025/1/13 10:19:12

1. 安装

使用openmim安装:

pip install -U openmim
mim install "mmengine>=0.7.0"
mim install "mmcv>=2.0.0rc4"

2. 测试案例

下载代码和模型:

git clone https://github.com/open-mmlab/mmdetection.git
mkdir ./checkpoints
mim download mmdet --config rtmdet_tiny_8xb32-300e_coco --dest ./checkpoints

运行代码,核心是定义inferencer和使用inferencer进行推理两行:

from mmdet.apis import DetInferencer

# Choose to use a config
model_name = 'rtmdet_tiny_8xb32-300e_coco'
# Setup a checkpoint file to load
checkpoint = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'

# Set the device to be used for evaluation
device = 'cpu'

# Initialize the DetInferencer
inferencer = DetInferencer(model_name, checkpoint, device)

# Use the detector to do inference
img = 'demo.jpg'
result = inferencer(img, out_dir='./output')

# Show the structure of result dict
from rich.pretty import pprint
pprint(result, max_length=4)

# Show the output image
from PIL import Image
Image.open('./output/vis/demo.jpg')

3. 自定义数据进行训练

3.1 准备数据

建议使用coco格式,参见https://cocodataset.org/#format-data。文件从头至尾按照顺序分为以下段落:

{
“info”: info,
“licenses”: [license],
“images”: [image],
“annotations”: [annotation],
“categories”: [category]
}
下面是从instances_val2017.json文件中摘出的一个annotation的实例,这里的segmentation就是polygon格式:

{
“segmentation”: [[510.66,423.01,511.72,420.03,510.45…]],
“area”: 702.1057499999998,
“iscrowd”: 0,
“image_id”: 289343,
“bbox”: [473.07,395.93,38.65,28.67],
“category_id”: 18,
“id”: 1768
},
从instances_val2017.json文件中摘出的2个category实例如下所示:

{
“supercategory”: “person”,
“id”: 1,
“name”: “person”
},
{
“supercategory”: “vehicle”,
“id”: 2,
“name”: “bicycle”
},

我们来看测试案例的例子,包含三个大字段,其中categories非常简单,只有一个balloon(我们需要训练的目标)
在这里插入图片描述
images则是如下的清单:
在这里插入图片描述
annotations如下:
在这里插入图片描述

3.2 配置config文件

config文件中需要定义数据,模型,训练参数,优化器等各种参数。测试案例如下:

config_balloon = """
# Inherit and overwrite part of the config based on this config
_base_ = './rtmdet_tiny_8xb32-300e_coco.py'

data_root = 'data/balloon/' # dataset root

train_batch_size_per_gpu = 4
train_num_workers = 2

max_epochs = 20
stage2_num_epochs = 1
base_lr = 0.00008


metainfo = {
    'classes': ('balloon', ),
    'palette': [
        (220, 20, 60),
    ]
}

train_dataloader = dict(
    batch_size=train_batch_size_per_gpu,
    num_workers=train_num_workers,
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        data_prefix=dict(img='train/'),
        ann_file='train.json'))

val_dataloader = dict(
    dataset=dict(
        data_root=data_root,
        metainfo=metainfo,
        data_prefix=dict(img='val/'),
        ann_file='val.json'))

test_dataloader = val_dataloader

val_evaluator = dict(ann_file=data_root + 'val.json')

test_evaluator = val_evaluator

model = dict(bbox_head=dict(num_classes=1))

# learning rate
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=10),
    dict(
        # use cosine lr from 10 to 20 epoch
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

train_pipeline_stage2 = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='RandomResize',
        scale=(640, 640),
        ratio_range=(0.1, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=(640, 640)),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),
    dict(type='PackDetInputs')
]

# optimizer
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

default_hooks = dict(
    checkpoint=dict(
        interval=5,
        max_keep_ckpts=2,  # only keep latest 2 checkpoints
        save_best='auto'
    ),
    logger=dict(type='LoggerHook', interval=5))

custom_hooks = [
    dict(
        type='PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]

# load COCO pre-trained weight
load_from = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
visualizer = dict(vis_backends=[dict(type='LocalVisBackend'),dict(type='TensorboardVisBackend')])
"""

with open('../configs/rtmdet/rtmdet_tiny_1xb4-20e_balloon.py', 'w') as f:
    f.write(config_balloon)

3.3 开始训练

使用Mac M2芯片需要修改3个地方。首先是需要设置

export PYTORCH_ENABLE_MPS_FALLBACK=1

其次是mmcv中的nms需要转到cpu上计算,打开mmcv/ops/nms.py,将class NMSop(torch.autograd.Function)中的inds = ext_module.nms(bboxes, scores…)改为inds = ext_module.nms(bboxes.cpu(), scores.cpu()…)
运行后会出现一个assert报错,找到源代码,把那一行assert删掉即可。
运行完成后,可以查看tensorboard:

%load_ext tensorboard

# see curves in tensorboard
%tensorboard --logdir ./work_dirs

然后查看测试结果

from mmdet.apis import DetInferencer
import glob

# Choose to use a config
config = '../configs/rtmdet/rtmdet_tiny_1xb4-20e_balloon.py'
# Setup a checkpoint file to load
checkpoint = glob.glob('./work_dirs/rtmdet_tiny_1xb4-20e_balloon/best_coco*.pth')[0]

# Set the device to be used for evaluation
device = 'cpu'

# Initialize the DetInferencer
inferencer = DetInferencer(config, checkpoint, device)

# Use the detector to do inference
img = './data/balloon/val/4838031651_3e7b5ea5c7_b.jpg'
result = inferencer(img, out_dir='./output')
# Show the output image
Image.open('./output/vis/4838031651_3e7b5ea5c7_b.jpg')

在这里插入图片描述

4. 其他

MMYOLO:传统的目标检测库
MMRotate:旋转检测库
MMDetection3D:三维检测库
下面几期一一介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1225947.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python 计算最大回撤

1. 什么是最大回撤 最大回撤是评估金融产品收益的一个非常重要的风险指标,它指的是在选定历史周期内任一历史时点往后推,产品净值走到最低点时的收益率回撤幅度的最大值。 以上图为例, 最大回撤 ( V a l u e A − V a l u e B ) V a l u e …

取数游戏2(动态规划java)

取数游戏2 题目描述 给定两个长度为n的整数列A和B,每次你可以从A数列的左端或右端取走一个数。假设第i次取走的数为ax,则第i次取走的数的价值vibi⋅ax,现在希望你求出∑vi的最大值。 输入格式 第一行一个数T ,表示有T 组数据。…

如何使用websocket+node.js实现pc后台与小程序端实时通信

如何使用websocketnode.js实现pc后台与小程序端实时通信 一、使用node.js创建一个服务器二、pc后台连接ws三、小程序端连接ws四、实现效果 实现功能:实现pc后台与小程序端互发通信能够实时检测到 一、使用node.js创建一个服务器 1.安装ws依赖 npm i ws2.创建index.js const…

(免费领源码)python+django+mysql线上兼职平台系统83320-计算机毕业设计项目选题推荐

摘 要 信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对线上兼职等问题,对线上兼职进行…

【狂神说Java】Docker概述 | Docker安装 | Docker的常用命令

✅作者简介:CSDN内容合伙人、信息安全专业在校大学生🏆 🔥系列专栏 :【狂神说Java】 📃新人博主 :欢迎点赞收藏关注,会回访! 💬舞台再大,你不上台&#xff0c…

STL的介绍

STL 是 C 标准模板库(Standard Template Library)的缩写,是 C 标准库中的一个重要组成部分。STL 提供了一组通用的模板类和函数,用于实现常用的数据结构和算法,如向量(vector)、链表&#xff08…

【算法挨揍日记】day29——139. 单词拆分、467. 环绕字符串中唯一的子字符串

139. 单词拆分 139. 单词拆分 题目描述: 给你一个字符串 s 和一个字符串列表 wordDict 作为字典。请你判断是否可以利用字典中出现的单词拼接出 s 。 注意:不要求字典中出现的单词全部都使用,并且字典中的单词可以重复使用。 解题思路&am…

【算法挨揍日记】day27——152. 乘积最大子数组、1567. 乘积为正数的最长子数组长度

152. 乘积最大子数组 152. 乘积最大子数组 题目描述: 给你一个整数数组 nums ,请你找出数组中乘积最大的非空连续子数组(该子数组中至少包含一个数字),并返回该子数组所对应的乘积。 测试用例的答案是一个 32-位 整…

【LLM】基于LLM的agent应用(上)

note 在未来,Agent 还会具备更多的可扩展的空间。 就 Observation 而言,Agent 可以从通过文本输入来观察来理解世界到听觉和视觉的集成;就 Action 而言,Agent 在具身智能的应用场景下,对各种器械进行驱动和操作。 Age…

BGP联盟和团体属性实验

目录 一、实验拓扑 二、实验要求 三、实验步骤 1、IP地址配置 2、ospf配置 3、BGP建邻 4、宣告网段 5、配置团体属性 一、实验拓扑 二、实验要求 1、按照图示配 IP 地址,R2,R3,R4,R5分别配 Loopbacke 口地址作为OSPF的Ro…

Synchronized 关键字的底层原理

目录 synchronized 同步语句块的情况 synchronized 修饰方法的的情况 synchronized 关键字底层原理属于JVM 层面 synchronized 同步语句块的情况 public class SynchronizedDemo {public void method() {synchronized (this) {System.out.println("synchronized 代码块…

hive sql 行列转换 开窗函数 炸裂函数

hive sql 行列转换 开窗函数 炸裂函数 准备原始数据集 学生表 student.csv 讲师表 teacher.csv 课程表 course.csv 分数表 score.csv 员工表 emp.csv 雇员表 employee.csv 电影表 movie.txt 学生表 student.csv 001,彭于晏,1995-05-16,男 002,胡歌,1994-03-20,男 003,周杰伦,…

037、目标检测-算法速览

之——常用算法速览 目录 之——常用算法速览 杂谈 正文 1.区域卷积神经网络 - R-CNN 2.单发多框检测SSD,single shot detection 3.yolo 杂谈 快速过一下目标检测的各类算法。 正文 1.区域卷积神经网络 - R-CNN region_based CNN,奠基性的工作。…

【AI】行业消息精选和分析(23-11-19)

行业动态 1、对标GPTs,微软连夜发布100多项更新!微软CEO:Copilot时代来了 2、英伟达联手微软推出AI代工服务 3、全新雅虎搜索将于 2024 年上线,未来还会推出更多 AI 和高级功能 4、Instagram 推出定制 AI 贴纸和滤镜功能&#xff…

【教3妹学编程-算法题】三个无重叠子数组的最大和

2哥 : 3妹,咋啦?一副苦大仇深的样子? 3妹:不开心呀不开心,羽生结弦宣布离婚。 2哥 : 羽生什么? 3妹:羽生结弦! 2哥 : 什么结弦? 3妹:羽生结弦!&am…

战神传奇【我本沉默精修版】win服务端+双端+充值后台+架设教程

搭建资源下载:战神传奇【我本沉默精修版】win服务端双端充值后台架设教程-海盗空间

安卓手机投屏到电视,跨品牌、跨地域同样可以实现!

在手机网页上看到的视频,也可以投屏到电视上看! 长时间使用手机,难免脖子会酸。这时候,如果你将手机屏幕投屏到大电视屏幕,可以减缓脖子的压力,而且大屏的视觉体验更爽。 假设你有一台安卓手机,…

TG Pro v2.87(mac温度风扇速度控制工具)

TG Pro 是适用于 macOS 的温度和风扇速度控制工具,可让您监控 Mac 组件(例如 CPU 和 GPU)的温度和风扇速度。如果您担心 Mac 过热或想要手动调整风扇速度以降低噪音水平,这将特别有用。 除了温度和风扇监控,TG Pro 还…

解锁数据安全之门:探秘迅软DSE的文件权限控制功能

企业管理者在进行数据安全管控时通常只关注到文件的加密方式,却忽略了以下问题:对于企业内部文档,根据其所承载的涉密程度不同,重要程度也不相同,需要由不同涉密等级的的人员进行处理,这就需要对涉密文档和…