60分钟吃掉detectron2

news2024/12/23 18:04:26

本范例演示使用非常有名的目标检测框架detectron2 🤗🤗

在自己的数据集(balloon数据)上训练实例分割模型MaskRCNN的方法。

detectron2框架的设计有以下一些优点:

  • 1,强大:提供了包括目标检测、实例分割、全景分割等非常广泛的视觉任务模型库。

  • 2,灵活:可以通过注册机制自定义模块或模型结构,从而进行扩展和改进。

  • 3,易用:通过list of dict格式定义自己的数据集, 简单好用。

公众号算法美食屋后台回复关键词: torchkeras,获取本文源代码和balloon数据集下载链接。

我们首先需要安装并导入detectron库~

!pip install 'git+https://github.com/facebookresearch/detectron2.git'
!pip install torchkeras
import numpy as np
import os, json, cv2, random
from PIL import Image 

import torch 

import detectron2
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor

#from detectron2.utils.logger import setup_logger
#setup_logger()

def cv2_show(arr):
    img = Image.fromarray(cv2.cvtColor(arr, cv2.COLOR_BGR2RGB))
    return img

0,预训练模型

from torchkeras import data 
#下载测试图片
img = data.get_example_image('park.jpg')
img.save('park.jpg')
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
im = cv2.imread("park.jpg")
outputs = predictor(im)
print(outputs["instances"].pred_classes)
print(outputs["instances"].pred_boxes)
v = Visualizer(im[:, :, ::-1], 
            MetadataCatalog.get(cfg.DATASETS.TRAIN[0]), scale=1.2)
out = v.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2_show(out.get_image()[:, :, ::-1])

98799cea10afaddc3fb0ebc83032f022.png

一,准备数据

detectron准备数据集,需要先注册。

如果是coco格式的数据,可以用以下方法快速注册:

from detectron2.data.datasets import register_coco_instances
register_coco_instances("my_dataset_train",{},"json_annotation_train.json","path/to/image/dir")
register_coco_instances("my_dataset_val", {}, "json_annotation_val.json","path/to/image/dir")

非coco格式的数据,可以用以下步骤进行注册:

  • 1,先将数据集整理成字典组成的列表形式

  • 2,使用DatasetCatalog注册数据集

from detectron2.structures import BoxMode

def get_balloon_dicts(img_dir):
    json_file = os.path.join(img_dir, "via_region_data.json")
    with open(json_file) as f:
        imgs_anns = json.load(f)

    dataset_dicts = []
    for idx, v in enumerate(imgs_anns.values()):
        record = {}
        
        filename = os.path.join(img_dir, v["filename"])
        height, width = cv2.imread(filename).shape[:2]
        
        record["file_name"] = filename
        record["image_id"] = idx
        record["height"] = height
        record["width"] = width
      
        annos = v["regions"]
        objs = []
        for _, anno in annos.items():
            assert not anno["region_attributes"]
            anno = anno["shape_attributes"]
            px = anno["all_points_x"]
            py = anno["all_points_y"]
            poly = [(x + 0.5, y + 0.5) for x, y in zip(px, py)]
            poly = [p for x in poly for p in x]

            obj = {
                "bbox": [np.min(px), np.min(py), np.max(px), np.max(py)],
                "bbox_mode": BoxMode.XYXY_ABS,
                "segmentation": [poly],
                "category_id": 0,
            }
            objs.append(obj)
        record["annotations"] = objs
        dataset_dicts.append(record)
    return dataset_dicts



try:
    #DatasetCatalog.remove('balloon_train')
    #DatasetCatalog.remove('balloon_val')
    
    DatasetCatalog.register("balloon_train", lambda : get_balloon_dicts("./data/balloon/train"))
    MetadataCatalog.get("balloon_train" ).set(thing_classes=["balloon"])
    
    DatasetCatalog.register("balloon_val", lambda : get_balloon_dicts("./data/balloon/val"))
    MetadataCatalog.get("balloon_val" ).set(thing_classes=["balloon"])
    
except Exception as err:
    pass 
    
balloon_metadata = MetadataCatalog.get("balloon_train")

我们来可视化一下数据,看看是否正确。

dicts_train = DatasetCatalog.get('balloon_train') #get_balloon_dicts("./data/balloon/train")  
dicts_val = DatasetCatalog.get('balloon_val') #get_balloon_dicts("./data/balloon/val")
dic = dicts_train[3]
img = cv2.imread(dic["file_name"])
visualizer = Visualizer(img[:, :, ::-1], metadata=balloon_metadata, scale=0.5)
out = visualizer.draw_dataset_dict(dic)
cv2_show(out.get_image()[:, :, ::-1])

84b94e0020b5cf04514e999919639e8d.png

二,定义模型

detectron2通过配置文件定义模型。可以查看 detectron2目录下的configs路径,有各种各样功能的模型配置文件可以使用。

包括:Detection(检测), InstanceSegmentation(实例分割), Keypoints(关键点检测), Panoptic(全景分割) 等各种类型

cfg = get_cfg()

cfg.merge_from_file(model_zoo.get_config_file(
    "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(
   "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")  


cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 128   # The "RoIHead batch size". 128 is faster, and good enough for this toy dataset (default: 512)
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only has one class (ballon). 


#model = detectron2.modeling.build_model(cfg) # This way do not load pretrained weights

predictor = DefaultPredictor(cfg)
model = predictor.model

三,训练模型

以下代码使用detectron2原生的DefaultTrainer进行训练,比较简单。

但是这个DefaultTrainer灵活性一般,当你想在训练循环中加入自己想要的功能时比较麻烦,并且日志输出不够直观。

此外也没有earlystopping,不能够保存验证集上最优的权重。

from detectron2.engine import DefaultTrainer

cfg.DATASETS.TRAIN = ("balloon_train",)
cfg.DATASETS.TEST = ("balloon_val",)
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 4

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

    
cfg.DATALOADER.NUM_WORKERS = 2
cfg.SOLVER.IMS_PER_BATCH = 2  # This is the real "batch size" commonly known to deep learning people
cfg.SOLVER.BASE_LR = 0.00025  # pick a good LR
cfg.SOLVER.MAX_ITER = 600    
cfg.SOLVER.STEPS = []        # do not decay learning rate

os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

trainer = DefaultTrainer(cfg) 
trainer.resume_or_load(resume=False)
trainer.train()

d986b45953d50c6357601eb663823b8d.png

下面使用我们的 梦中情炉 ~ torchkeras ~ 实现最优雅的训练循环~  😋😋

ds_train = detectron2.data.DatasetFromList(dicts_train)
ds_val = detectron2.data.DatasetFromList(dicts_val)

mp = detectron2.data.DatasetMapper(cfg,is_train=True)
batch_size = 16
dl_train = detectron2.data.build_detection_train_loader(ds_train,
        mapper=mp,total_batch_size=batch_size,num_workers=2)
dl_train.size = len(ds_train)//batch_size 

dl_val = detectron2.data.build_detection_train_loader(ds_val,
        mapper=mp,total_batch_size=1,num_workers=2)
dl_val.size = len(dicts_val)
for batch in dl_val:
    break
from torchkeras import KerasModel 
from tqdm import tqdm 
from detectron2.utils.events import EventStorage 

class StepRunner:
    def __init__(self, net, loss_fn, accelerator, 
                 stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator
        
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.train() 
    
    def __call__(self, batch):
        
        #loss
        with EventStorage() as event_storage:
            loss_dict = self.net(batch)
            
        loss = sum(loss_dict.values())
        
        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses
        step_losses = {self.stage+"_loss":all_loss.item()}
        
        #metrics
        step_metrics = {}
        
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
class EpochRunner:
    def __init__(self,steprunner,quiet=False):
        self.steprunner = steprunner
        self.stage = steprunner.stage
        self.accelerator = self.steprunner.accelerator
        self.quiet = quiet
        
    def __call__(self,dataloader):
        
        try:
            n = len(dataloader)
        except Exception as err:
            n = dataloader.size 
        loop = tqdm(enumerate(dataloader,start=1), 
                    total =n,
                    file=sys.stdout,
                    disable=not self.accelerator.is_local_main_process or self.quiet,
                    ncols = 100
                   )
        epoch_losses = {}
        for step, batch in loop: 
            step_losses,step_metrics = self.steprunner(batch)   
            step_log = dict(step_losses,**step_metrics)
            for k,v in step_losses.items():
                epoch_losses[k] = epoch_losses.get(k,0.0)+v
            if step<n:
                loop.set_postfix(**step_log)
            elif step==n:
                epoch_metrics = step_metrics
                epoch_metrics.update({self.stage+"_"+name:metric_fn.compute().item() 
                                 for name,metric_fn in self.steprunner.metrics_dict.items()})
                epoch_losses = {k:v/step for k,v in epoch_losses.items()}
                epoch_log = dict(epoch_losses,**epoch_metrics)
                loop.set_postfix(**epoch_log)
                for name,metric_fn in self.steprunner.metrics_dict.items():
                    metric_fn.reset()
            else:
                break 
        return epoch_log
    
    
KerasModel.StepRunner = StepRunner 
KerasModel.EpochRunner = EpochRunner
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(params, lr=1e-4)

keras_model = KerasModel(model,
                         loss_fn = None,
                         metrics_dict=None,
                         optimizer= optimizer
                        )
ckpt_path = 'checkpoint.pt'
keras_model.fit(train_data=dl_train,val_data=dl_val,
    epochs=30,patience=10,
    monitor='val_loss',
    mode='min',
    ckpt_path =ckpt_path,
    plot=True
)

d20e28e90b0da6c1378d625a97abab6a.png

991498058a2c892eb34aefbccd1b9fcf.png

四,评估模型

from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.data import build_detection_test_loader


evaluator = COCOEvaluator("balloon_val", output_dir="./output")
dl_val = build_detection_test_loader(cfg, "balloon_val")
print(inference_on_dataset(model, dl_val, evaluator))

五,使用模型

from detectron2.engine import DefaultPredictor

#cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth") 
cfg.MODEL.WEIGHTS = ckpt_path
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7   # set a custom testing threshold
predictor = DefaultPredictor(cfg)
from detectron2.utils.visualizer import ColorMode

im = cv2.imread(dicts_val[10]['file_name'])
outputs = predictor(im)  # format is documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
vis = Visualizer(im[:, :, ::-1],
               metadata=balloon_metadata, 
               scale=0.5, 
               instance_mode=ColorMode.IMAGE_BW   # remove the colors of unsegmented pixels. This option is only available for segmentation models
)
out = vis.draw_instance_predictions(outputs["instances"].to("cpu"))
cv2_show(out.get_image()[:, :, ::-1])

3bb7048f1c15ce6e3bdec53bf9faa9ff.png

公众号算法美食屋后台回复关键词: torchkeras,获取本文源代码和balloon数据集下载链接。

万水千山总是情,点个赞赞行不行?😋😋

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/582119.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot启动流程

1 Springboot 启动流程 创建一个StopWatch实例&#xff0c;用来记录SpringBoot的启动时间。 通过SpringFactoriesLoader加载listeners&#xff1a;比如EventPublishingRunListener。 发布SprintBoot开始启动事件&#xff08;EventPublishingRunListener#starting()&#xff0…

性能测试——基本性能监控系统使用

这里写目录标题 一、基本性能监控系统组成二、环境搭建1、准备数据文件 type.db collectd.conf2、启动InfluxDB3、启动grafana4、启动collectd5、Grafana中配置数据源 一、基本性能监控系统组成 Collectd InfluxdDB Grafana Collectd 是一个守护(daemon)进程&#xff0c;用来…

【数据结构】时间复杂度与空间复杂度

目录 前言一、算法效率1. 算法效率的定义 二、时间复杂度1. 时间复杂度的定义2. 时间复杂度的计算 三、空间复杂度1. 空间复杂度的定义2. 空间复杂度的计算 四、时间复杂度曲线图结尾 前言 在学习C语言的时候&#xff0c;大多数的小伙伴们并不会对算法的效率了解&#xff0c;也…

视频采集到录制 - 音频采集到降噪

继续上篇的视频采集到录制 视频采集相对来说还是算正常&#xff0c;如果资源够用&#xff0c;使用第三方库也是种解决办法 但音频采集网上资料相对也少&#xff0c;走了一遍&#xff0c;也发现存在很多坑 1. 音频采集 一般来说&#xff0c;采用MIC采集&#xff0c;采集出来的格…

内存泄露的循环引用问题

内存泄漏一直是很多大型系统故障的根源&#xff0c;也是一个面试热点。那么在编程语言层面已经提供了内存回收机制&#xff0c;为什么还会产生内存泄漏呢&#xff1f; 这是因为应用的内存管理一直处于一个和应用程序执行并发的状态&#xff0c;如果应用程序申请内存的速度&…

希尔伯特旅馆里,住着AI的某种真相

“无穷”和“无穷1”&#xff0c;哪个更大&#xff1f; 已经吸收了不知道多少数据的AI模型&#xff0c;和比他多学习一条数据的模型&#xff0c;哪个更智能&#xff1f; 想聊聊这个问题&#xff0c;出于一个偶然的机会。很早之前我在测试ChatGPT的时候&#xff0c;突然想问他个…

简单工厂、工厂方法、抽象工厂模式-这仨货的区别

要想明白这三玩意的区别就需要知道这三玩意的优缺点&#xff1b; 之所以有三种工厂模式&#xff0c;就说明它们各有所长&#xff0c;能解决不同场景的问题&#xff1b; 一、简单工厂模式 UML图 代码 public class MobileFactory {public static Mobile getMobile(String brand)…

【Linux】线程概述、创建线程、终止线程

目录 线程概述1、创建线程函数解析代码举例 2、终止线程函数解析代码举例 橙色 线程概述 与进程类似&#xff0c;线程是允许应用程序并发执行多个任务的一种机制。一个进程可以包含多个线程。 进程是 CPU 分配资源的最小单位&#xff0c;线程是操作系统调度执行的最小单位。…

回归预测 | MATLAB实现SSA-CNN-LSTM麻雀算法优化卷积长短期记忆神经网络多输入单输出回归预测

回归预测 | MATLAB实现SSA-CNN-LSTM麻雀算法优化卷积长短期记忆神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现SSA-CNN-LSTM麻雀算法优化卷积长短期记忆神经网络多输入单输出回归预测预测效果基本介绍模型描述程序设计学习总结参考资料 预测效果 基本介绍 MATLAB实现…

【笔记整理】轻量级神经网络 MobileNetV3

【笔记整理】轻量级神经网络 MobileNetV3 文章目录 【笔记整理】轻量级神经网络 MobileNetV31、深度可分离卷积2、翻转残差块和线性瓶颈结构3、h-swish 函数和 SE 模块4、网络结构搜索 近年来关于 CNN 的研究在飞速发展&#xff0c;CNN 模型在目标检测、图像分割等领域都取得了…

力扣sql中等篇练习(二十九)

力扣sql中等篇练习(二十九) 1 计算每个销售人员的影响力 1.1 题目内容 1.1.1 基本题目信息1 1.1.2 基本题目信息2 1.1.3 示例输入输出 a 示例输入 b 示例输出 1.2 示例sql语句 # Write your MySQL query statement below SELECT s1.salesperson_id,s1.name,IFNULL(t.total…

毕业季到底是去大厂还是去小公司

(点击即可收听) 毕业季到底是去大厂还是去小公司 相信很多人在选择大小公司的时候,会比较痛苦,外面的人想进去,里面的人想出来&#xff0c;至于选择大厂还是小公司 这是因人而异的,不同的阶段都可以有不同的选择 进大厂不一定就是对的,进小公司也不一定就是错的,学习东西,增长经…

股票量化分析工具QTYX使用攻略——涨停个股挖掘热门板块(更新2.6.5)

搭建自己的量化系统 如果要长期在市场中立于不败之地&#xff01;必须要形成一套自己的交易系统。 行情不等人&#xff01;边学习边实战&#xff0c;在实战中学习才是最有效地方式。于是我们分享一个即可以用于学习&#xff0c;也可以用于实战炒股分析的量化系统——QTYX。 QTY…

软考A计划-试题模拟含答案解析-卷九

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…

JetBrains的多数据库管理和SQL工具DataGrip 2023版本在Win10系统的下载与安装配置教程

目录 前言一、DataGrip 安装二、使用配置总结 前言 DataGrip是一款多数据库管理和SQL工具&#xff0c;适用于不同类型的数据库。它提供了丰富的功能和工具&#xff0c;可以帮助开发人员更高效地管理数据库、编写SQL查询和执行数据操作。 DataGrip的主要特点&#xff1a; ——…

这里有3个Tips,也许可以帮你躲过ChatGPT大规模封号 | AIGC实践

据说&#xff0c;从昨天开始&#xff0c;ChatGPT又双叒叕开始大规模封号&#xff0c;很多注册用户收到这样一则消息&#xff1a; 大意是说&#xff1a;OpenAI 发现了你的 ChatGPT 账号存在可疑活动&#xff0c;为了保障平台安全&#xff0c;已自动退款并取消你的 ChatGPT Plus …

驱动开发:内核解析内存四级页表

当今操作系统普遍采用64位架构&#xff0c;CPU最大寻址能力虽然达到了64位&#xff0c;但其实仅仅只是用到了48位进行寻址&#xff0c;其内存管理采用了9-9-9-9-12的分页模式&#xff0c;9-9-9-9-12分页表示物理地址拥有四级页表&#xff0c;微软将这四级依次命名为PXE、PPE、P…

七年老程序员的三四月总结:三十岁、准备婚礼、三次分享

你好&#xff0c;我是 shixin&#xff0c;一名工作七年的安卓开发。 每两个月我会做一次总结&#xff0c;记下这段时间里有意义的事和值得反复看的内容&#xff0c;为的是留一些回忆、评估自己的行为、沉淀有价值的信息。 一转眼 2023 年过去了三分之一&#xff0c;这两个月经历…

【数据湖仓架构】数据湖和仓库:Databricks 和 Snowflake

是时候将数据分析迁移到云端了。我们比较了 Databricks 和 Snowflake&#xff0c;以评估基于数据湖和基于数据仓库的解决方案之间的差异。 在这篇文章中&#xff0c;我们将介绍基于数据仓库和基于数据湖的云大数据解决方案之间的区别。我们通过比较多种云环境中可用的两种流行技…

HTML+CSS+JavaScript制作弹幕效果

全屏弹幕 <!DOCTYPE html> <html> <head><meta charset"UTF-8"><title>弹幕效果</title><style>/* 设置弹幕的样式 */.bullet {position: absolute;font-size: 20px;color: white;text-shadow: 1px 1px 1px black;white-s…