如何用DETR(detection transformer)训练自己的数据集,并推理图片

news2025/1/12 2:43:56

前言

这里我分享一下官方写的DETR的预测代码,其实就是github上DETR官方写的一个Jupyter notebook,可能需要梯子才能访问,这里我贴出来。因为DETR比较难训练,我觉得可以用官方的预训练模型来看看效果。

如果要用自己训练的模型来推理,可以参考我的这篇文章:DETR训练自己的数据集

顺便说一下,这个模型是简易版,其实就是DETR论文最后一页贴出来的代码,mAP大概要比源码的低2-3个点,但已经很强了。只用了50行代码就把DETR模型架构展示出来了,真的是非常简洁的一个模型!

一、jupyter notebook运行代码

如果用的是jupyter notebook,用这个代码,修改代码末尾的图片路径即可

from PIL import Image
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'svg'
import ipywidgets as widgets
from IPython.display import display, clear_output
import cv2
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)
from torchvision import models


# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

def plot_results(pil_img, prob, boxes):
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()
    
    
def detect(im, model, transform):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)

    # demo model only support by default images with aspect ratio between 0.5 and 2
    # if you want to use images with an aspect ratio outside this range
    # rescale your image so that the maximum size is at most 1333 for best results
    assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'

    # propagate through the model
    outputs = model(img)

    # keep only predictions with 0.7+ confidence
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.7

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    
    return probas[keep], bboxes_scaled
    
    
    
class DETRdemo(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
    """
    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # create ResNet-50 backbone
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # propagate through the transformer
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)
        
        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h), 
                'pred_boxes': self.linear_bbox(h).sigmoid()}

if __name__=="__main__":
    detr = DETRdemo(num_classes=91)
    state_dict = torch.hub.load_state_dict_from_url(url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
                                                    map_location='cpu', check_hash=True)
    detr.load_state_dict(state_dict)
    detr.eval()
    
    # 这里修改自己的图片路径即可,注意路径不要有中文,否则要加一行解码的代码
    im = cv2.imread("test4.JPG")
    im= cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im=Image.fromarray(im)
    scores, boxes= detect(im, detr, transform)
    plot_results(im, scores, boxes)

这是用我自己拍的图片跑的效果:

在这里插入图片描述

二、Pycharm(非jupyter notebook的IDE)

from PIL import Image
import matplotlib.pyplot as plt
import cv2
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T

torch.set_grad_enabled(False)

# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


def plot_results(pil_img, prob, boxes):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()


def detect(im, model, transform):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)

    # demo model only support by default images with aspect ratio between 0.5 and 2
    # if you want to use images with an aspect ratio outside this range
    # rescale your image so that the maximum size is at most 1333 for best results
    assert img.shape[-2] <= 1600 and img.shape[
        -1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'

    # propagate through the model
    outputs = model(img)

    # keep only predictions with 0.7+ confidence
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.7

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

    return probas[keep], bboxes_scaled


class DETRdemo(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
    """

    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # create ResNet-50 backbone
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        H, W = h.shape[-2:]
        pos = torch.cat([
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # propagate through the transformer
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)

        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}


if __name__ == "__main__":
    detr = DETRdemo(num_classes=91)
    state_dict = torch.hub.load_state_dict_from_url(
        url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
        map_location='cpu', check_hash=True)
    detr.load_state_dict(state_dict)
    detr.eval()
    im = cv2.imread("test4.JPG")
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    im = Image.fromarray(im)
    scores, boxes = detect(im, detr, transform)
    plot_results(im, scores, boxes)

官方github:DETR

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1313989.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MYSQl基础操作命令合集与详解

MySQL入门 先来个总结 SQL语言分类 DDL&#xff08;Data Definition Language&#xff09; - 数据定义语言: 用于定义和管理数据库结构&#xff0c;包括创建、修改和删除数据库对象。 示例&#xff1a;CREATE, ALTER, DROP等语句。 DML&#xff08;Data Manipulation Lan…

FC-13A(用于汽车应用的kHz范围晶体单元,低轮廓贴片)

FC-13A晶体非常适合用在汽车导航系统设计中的应用&#xff0c;是一种具有优异的频率性能和AEC-Q200标准认证的汽车工业级高精度晶体,FC-13A是一款尺寸为3.2 1.5 0.9mm&#xff0c;频率范围32.768KHz耐高温晶振&#xff0c;频率温度系数仅为-0.04ppm/℃&#xff0c;并且其老化…

Java基础面试题小结

基础面试题 Java语言简介 Java是1995年由sun公司推出的一门高级语言&#xff0c;该语言具备如下特点: 简单易学&#xff0c;相较于C语言和C&#xff0c;没有指针的概念&#xff0c;所以操作和使用是会相对容易一些。平台无关性&#xff0c;即Java程序可以通过Java虚拟机在不…

牛客网SQL训练2—SQL基础进阶

文章目录 一、基本查询二、数据过滤三&#xff1a;函数四&#xff1a;分组聚合五&#xff1a;子查询六&#xff1a;多表连接七&#xff1a;组合查询八&#xff1a;技能专项-case when使用九&#xff1a;多表连接-窗口函数十&#xff1a;技能专项-having子句十一&#xff1a;技能…

vue 使用 Echarts做地图及飞线效果

前言&#xff1a; 效果图 一. 项目中安装以及引入Echarts 1.1 npm 命令安装echarts库 npm install echarts --save 1.2 yarn命令安装echarts库 yarn add echarts 1.3 引用 a. 在使用页面上引入 在Vue组件的script标签中引入echarts库 使用 echarts import * as echarts f…

flowable之三 启动一个流程并跟踪

1. 背景介绍 当我们部署一个流程并启动后&#xff0c;Flowable会按照既定流程定义及进行节点处理以及自动流转&#xff0c;从一个节点执行到下一个节点&#xff0c;直至结束。在此过程中&#xff0c;系统如何处理BPMN XML文件&#xff1f;节点如何进行流转&#xff1f;本文对f…

衡兰芷若成绝响,人间不见周海媚(4k修复基于PaddleGan)

一代人有一代人的经典回忆&#xff0c;1994年由周海媚、马景涛、叶童主演的《神雕侠侣》曾经风靡一时&#xff0c;周海媚所诠释的周芷若凝聚了汉水之钟灵&#xff0c;峨嵋之毓秀&#xff0c;遇雪尤清&#xff0c;经霜更艳&#xff0c;俘获万千观众&#xff0c;成为了一代人的共…

柔性数组(结构体成员)

目录 前言&#xff1a; 柔性数组&#xff1a; 给柔性数组分配空间&#xff1a; 调整柔性数组大小&#xff1a; 柔性数组的好处&#xff1a; 前言&#xff1a; 柔性数组&#xff1f;可能你从未听说&#xff0c;但是确实有这个概念。听名字&#xff0c;好像就是柔软的数…

鸿蒙小车之多任务调度实验

说到鸿蒙我们都会想到华为mate60&#xff1a;遥遥领先&#xff01;我们一直领先&#xff01; 我们这个小车也是采用的是鸿蒙操作系统&#xff0c;学习鸿蒙小车&#xff0c;让你遥遥领先于你的同学。 文章目录 前言一、什么是任务&#xff1f;为什么要有任务二、任务的状态三、任…

Axure自定义元件

目录 1.processOne的使用 ​编辑2.自定义元件的使用、 2.1如何自定义一个元件 2.2使用自定义元件 导语&#xff1a; Axure是绘制原型图的软件&#xff0c;但是我们很多时候不知道&#xff0c;画哪一个板块&#xff0c;所以流程图的绘制也是非常重要的 1.processOne的使用…

一些好用的VSCode扩展

可以在扩展这里直接搜索需要的扩展&#xff0c;点击安装即可。 1.Chinese 中文扩展&#xff0c;就是说虽然咱们懂点英语&#xff0c;但还是中文看着方便 2.Auto Rename Tag 当你重命名一个HTML 标签时&#xff0c;会自动重命名与他配对的HTML 标签 当你选择h4这个标签时&…

如何解决Windows 11黑屏的问题,让电脑“重见光明”

本页介绍了经过测试并证明有效的常见Windows 11黑屏故障的所有修复程序。 本页上的提示和解决方案适用于所有Windows 11设备,从台式电脑和笔记本电脑到微软的Surface二合一设备。 是什么导致Windows 11黑屏死机 在使用Windows 11时,显示器或屏幕明显关闭,通常被称为Window…

YOLOv8 图片目标计数 | 特定目标进行计数

全类别计数特定类别计数如何使用 YOLOv8 进行对象计数 有很多同学留言说想学 YOLOv8 目标计数。那么今天这篇博客,我将教大家如何使用 YOLOv8 进行对象计数。YOLOv8 是一种非常强大的对象检测模型,它可以识别图像中的各种对象。我们将学习如何利用这个模型对特定对象进行计数…

【日常笔记】notepad++ 正则表达式基本用法

一、场景 二、正则表达式--语法 2.1、学习基本的匹配字符&#xff1a; 2.2、学习特殊字符和量词&#xff1a; 2.3、学习转义字符 2.4、学习分组和捕获 2.5、区分大小写 和 匹配整个单词 2.6、引用分组 三、实战 ▶ 希望把课程目录中 -- 前面的都去掉 一、场景 希望把…

如何使用ArcGIS Pro裁剪影像

对影像进行裁剪是一项比较常规的操作&#xff0c;因为到手的影像可能是多种范围&#xff0c;需要根据自己需求进行裁剪&#xff0c;这里为大家介绍一下ArcGIS Pro中裁剪的方法&#xff0c;希望能对你有所帮助。 数据来源 本教程所使用的数据是从水经微图中下载的影像和行政区…

Java EE 多线程之 JUC

文章目录 1. Callable 接口2. ReentrantLock3. 信号量4. CountDownLatch JUC这里就是指&#xff08;java.util.concurrent&#xff09; concurrent 就是并发的意思 这个包里的内容&#xff0c;主要就是一些多线程相关的组件 1. Callable 接口 Callable 也是一种创建线程的方式…

Go 与 Rust:现代编程语言的深度对比

在快速发展的软件开发领域中&#xff0c;选择合适的编程语言对项目的成功至关重要。Go 和 Rust 是两种现代编程语言&#xff0c;它们都各自拥有一系列独特的特性和优势。本文旨在深入比较 Go 和 Rust&#xff0c;从不同的角度分析这两种语言&#xff0c;包括性能、语言特性、生…

你好!赫夫曼树【JAVA】

目录 1.简单介绍 2.术语 3.构建思路 4.创建节点类 5.创建赫夫曼树 6.前序遍历 7.小玩一把 1.简单介绍 赫夫曼树&#xff08;Huffman Tree&#xff09;又称最优二叉树&#xff0c;是一种带权路径长度最短的二叉树。它的构建主要用于数据压缩算法中&#xff0c;根据字…

【vue】jenkins打前端包时报错:第 8 行:cd: dist: 没有那个文件或目录

问题描述 jenkins打前端包时报错&#xff1a;第 8 行&#x1f4bf; dist: 没有那个文件或目录 Jenkins中 “Execute shell” 配置的脚本&#xff1a; echo $PATH node -v npm -v npm config set registry http://ued.edtsoft.com/ npm install npm run build:prod cd dist rm…

数据结构(七):树介绍及面试常考算法

一、树介绍 1、定义 树形结构是一种层级式的数据结构&#xff0c;由顶点&#xff08;节点&#xff09;和连接它们的边组成。 树类似于图&#xff0c;但区分树和图的重要特征是树中不存在环路。树有以下特点&#xff1a; &#xff08;1&#xff09;每个节点有零个或多个子节点…