DERT目标检测源码流程图main.py的执行

news2024/11/16 11:32:46

DERT目标检测源码流程图main.py的执行

官网预测脚本

补充官网提供的预测部分的代码信息。

from PIL import Image
import requests
import matplotlib.pyplot as plt

import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T
torch.set_grad_enabled(False)


class DETRdemo(nn.Module):
    """
    Demo DETR implementation.

    Demo implementation of DETR in minimal number of lines, with the
    following differences wrt DETR in the paper:
    * learned positional encoding (instead of sine)
    * positional encoding is passed at input (instead of attention)
    * fc bbox predictor (instead of MLP)
    The model achieves ~40 AP on COCO val5k and runs at ~28 FPS on Tesla V100.
    Only batch size 1 supported.
    """

    def __init__(self, num_classes, hidden_dim=256, nheads=8,
                 num_encoder_layers=6, num_decoder_layers=6):
        super().__init__()

        # create ResNet-50 backbone
        self.backbone = resnet50()
        del self.backbone.fc

        # create conversion layer
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # create a default PyTorch transformer
        self.transformer = nn.Transformer(
            hidden_dim, nheads, num_encoder_layers, num_decoder_layers)

        # prediction heads, one extra class for predicting non-empty slots
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to 256 feature planes for the transformer
        h = self.conv(x)

        # construct positional encodings
        H, W = h.shape[-2:]
        pos = torch.cat([ # 张量的顺序进行转换
            self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1),
            self.row_embed[:H].unsqueeze(1).repeat(1, W, 1),
        ], dim=-1).flatten(0, 1).unsqueeze(1)

        # propagate through the transformer
        h = self.transformer(pos + 0.1 * h.flatten(2).permute(2, 0, 1),
                             self.query_pos.unsqueeze(1)).transpose(0, 1)

        # finally project transformer outputs to class labels and bounding boxes
        return {'pred_logits': self.linear_class(h),
                'pred_boxes': self.linear_bbox(h).sigmoid()}

detr = DETRdemo(num_classes=91)
state_dict = torch.hub.load_state_dict_from_url(
    url='https://dl.fbaipublicfiles.com/detr/detr_demo-da2a99e9.pth',
    map_location='cpu', check_hash=True)
detr.load_state_dict(state_dict)
detr.eval();


# COCO classes
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# standard PyTorch mean-std input image normalization
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b


def detect(im, model, transform):
    # mean-std normalize the input image (batch-size: 1)
    img = transform(im).unsqueeze(0)

    # demo model only support by default images with aspect ratio between 0.5 and 2
    # if you want to use images with an aspect ratio outside this range
    # rescale your image so that the maximum size is at most 1333 for best results
    assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'

    # propagate through the model
    outputs = model(img)

    # keep only predictions with 0.7+ confidence
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.7

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas[keep], bboxes_scaled


url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
# url = "./test.jpg"
# im = Image.open(url)
im = Image.open(requests.get(url, stream=True).raw)

scores, boxes = detect(im, detr, transform)


def plot_results(pil_img, prob, boxes):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()


plot_results(im, scores, boxes)

在这里插入图片描述

核心流程图

  1. 整体执行流程概述
  2. 模型构建过程
  3. 前向传播与损失函数

需要代码注释部分可联系,简单原因不在提供代码注释。只关注断点调试得到的流程图信息。

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2172055.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于LangChain实现数据库操作的智能体

在 Retrieval 或者 ReACT 的一些场景中&#xff0c;常常需要数据库与人工智能结合。而 LangChain 本身就封装了许多相关的内容&#xff0c;在其官方文档-SQL 能力中&#xff0c;也有非常好的示例。 而其实现原理主要是通过 LLM 将自然语言转换为 SQL 语句&#xff0c;然后再通…

不懂性能测试,被面试官挂了...

性能测试旨在检查应用程序或软件在特定负载下工作时的响应性和稳定性&#xff0c;从而检测应用程序/软件在响应速度、可扩展性和稳定性方面是否达到预期的要求。 简而言之&#xff0c;性能测试目标就是为了识别并消除应用程序中的性能瓶颈。 本文将为大家详细介绍性能测试主要…

快手:从 Clickhouse 到 Apache Doris,实现湖仓分离向湖仓一体架构升级

导读&#xff1a;快手 OLAP 系统为内外多个场景提供数据服务&#xff0c;每天承载近 10 亿的查询请求。原有湖仓分离架构&#xff0c;由离线数据湖和实时数仓组成&#xff0c;面临存储冗余、资源抢占、治理复杂、查询调优难等问题。通过引入 Apache Doris 湖仓一体能力&#xf…

DAF-Net:一种基于域自适应的双分支特征分解融合网络用于红外和可见光图像融合

论文 DAF-Net: A Dual-Branch Feature Decomposition Fusion Network with Domain Adaptive for Infrared and Visible Image Fusion 提出了一种新的红外和可见光图像融合方法。该方法旨在结合红外图像和可见光图像的互补信息&#xff0c;以提供更全面的场景理解。红外图像在低…

另外知识与网络总结

一、重谈NAT&#xff08;工作在网络层&#xff09; 为什么会有NAT 为了解决ipv4地址太少问题&#xff0c;到了公网的末端就会有运营商路由器来构建私网&#xff0c;在不同私网中私有IP可以重复&#xff0c;这就可以缓解IP地址太少问题&#xff0c;但是这就导致私有IP是重复的…

【锁住精华】MySQL锁机制全攻略:从行锁到表锁,共享锁到排他锁,悲观锁到乐观锁

MySQL有哪些锁 1、按照锁的粒度划分 行锁 是最低粒度的的锁&#xff0c;锁住指定行的数据&#xff0c;加锁的开销较大&#xff0c;加锁较慢&#xff0c;可能会出现死锁的情况&#xff0c;锁的竞争度会较低&#xff0c;并发度相对较高。但是如果where条件里的字段没有加索引&…

【中级通信工程师】终端与业务(十):通信市场营销组合策略

【零基础3天通关中级通信工程师】 终端与业务(十)&#xff1a;通信市场营销组合策略 本文是中级通信工程师考试《终端与业务》科目第十章《通信市场营销组合策略》的复习资料和真题汇总。本章的核心内容涵盖了市场营销组合策略的特点、产品策略、价格策略、渠道策略和促销策略…

2206. 将数组划分成相等数对(排序/哈希)

目录 一&#xff1a;题目&#xff1a; 二&#xff1a;代码&#xff1a; 三&#xff1a;结果&#xff1a; 一&#xff1a;题目&#xff1a; 给你一个整数数组 nums &#xff0c;它包含 2 * n 个整数。 你需要将 nums 划分成 n 个数对&#xff0c;满足&#xff1a; 每个元素…

【开源项目】数字孪生智慧停车场——开源工程及源码

飞渡科技数字孪生停车场管理平台&#xff0c;基于国产数字孪生3D渲染引擎&#xff0c;结合数字孪生、物联网IOT&#xff0c;以及车牌自动识别、视频停车诱导等技术&#xff0c;实现停车场的自动化、可视化和无人化值守管理。 以3D可视化技术为基础&#xff0c;通过三维场景完整…

【原创】java+swing+mysql企业招聘管理系统设计与实现

个人主页&#xff1a;程序员杨工 个人简介&#xff1a;从事软件开发多年&#xff0c;前后端均有涉猎&#xff0c;具有丰富的开发经验 博客内容&#xff1a;全栈开发&#xff0c;分享Java、Python、Php、小程序、前后端、数据库经验和实战 文末有本人名片&#xff0c;希望和大家…

2024年9月第4周AI资讯

阅读时间&#xff1a;3-4min 更新时间&#xff1a;2024.9.23-2024.9.27 目录 o1 处于OpenAI的AGI5阶段的第2阶段 微软使用核燃料推动AI发展 阿里巴巴和英伟达在自动驾驶方向合作 Meta 推出 AR xAI 眼镜、新型号 o1 处于OpenAI的AGI5阶段的第2阶段 概要 OpenAI 首席执行官 …

怎么查看网站是否被谷歌收录,哪些因素影响着网站是否被谷歌收录

一、怎么查看网站是否被谷歌收录 查看网站是否被谷歌收录&#xff0c;有多种方法可供选择&#xff0c;以下是几种常用的方式&#xff1a; 1.使用“site:”指令&#xff1a; 在谷歌搜索引擎的搜索框中输入“site:你的域名网址”&#xff08;注意使用英文冒号&#x…

【GUI设计】基于Matlab的图像去噪GUI系统(8),matlab实现

博主简介&#xff1a; 如需获取设计的完整源代码或者有matlab图像代码项目需求/合作&#xff0c;可联系主页个人简介提供的联系方式或者文末的二维码。 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 本次案例是基于Matlab的图像去噪GUI系统&am…

CUDA error: no kernel image is available for execution on the device

记录一下出现上述问题的一个原因&#xff1a; pytorch所依赖的cuda版本不满足显卡(GPU)的算力要求&#xff01; 举例来说&#xff0c;显卡是3090&#xff0c;并按照以下命令安装Pytorch: conda install pytorch1.7.0 torchvision0.8.0 torchaudio0.7.0 cudatoolkit10.1 -c p…

STM32转AT32代码转换

1. 引言 在嵌入式开发中&#xff0c;我们经常会遇到更换单片机芯片的事情&#xff0c;若芯片是同一厂家的还好说&#xff0c;若是不同厂家的则需要重新写&#xff0c;重新调&#xff0c;重新去学习其底层驱动程序&#xff0c;比较费时费力。如&#xff1a;ST32转AT32、ST32转G…

数论——数数(找质因数个数),三位出题人(组合数学,快速幂)

数数&#xff08;找质因数个数&#xff09; 题目描述 登录—专业IT笔试面试备考平台_牛客网 运行代码&#xff08;通过率一半&#xff09; #include <iostream> #include <vector> using namespace std; const int N5e66; int n; vector<bool>vis; vo…

自定义认证过滤器和自定义授权过滤器

目录 通过数据库动态加载用户信息 具体实现步骤 一.创建数据库 二.编写secutity配置类 三.编写controller 四.编写服务类实现UserDetailsService接口类 五.debug springboot启动类 认证过滤器 SpringSecurity内置认证流程 自定义认证流程 第一步:自定义一个类继承Abstra…

AFSim仿真系统 --- 系统简解_01(任务模块)

任务 任务是AFSIM的基线可执行文件。通过任务&#xff0c;用户可以访问世界仿真框架&#xff08;WSF&#xff09;。该可执行文件&#xff08;mission.exe&#xff09;解释文本格式的仿真输入文件&#xff08;场景&#xff09;&#xff0c;以生成仿真&#xff0c;并可选择以多种…

c语言中的杨氏矩阵的介绍以及元素查找的方法

杨氏矩阵&#xff1a;是一个二维数组 特点&#xff1a;数组的每行从左到右都是递增的 数组的每列从上到下都是递增的 这种矩阵结构使得在查找特定元素时&#xff0c;可以利用其 递增性质来缩小范围&#xff0c;提高查找效率。 从杨氏矩阵中对元素进行查找 1&#xff0c;要求…

沟通技巧对班组长的工作效率有什么影响?

在企业管理的细微脉络中&#xff0c;班组长作为基层管理的中坚力量&#xff0c;其工作效率直接影响着整个生产线的流畅度与团队士气。而沟通技巧&#xff0c;作为人际交往的桥梁&#xff0c;不仅在日常交流中扮演着重要角色&#xff0c;更在班组管理中发挥着不可替代的作用。本…