python基于DETR(DEtection TRansformer)开发构建人员手持物品检测识别分析系统

news2025/1/11 19:44:20

PyTorch训练代码和DETR(DEDetection-TRansformer)的预训练模型。我们用Transformer替换了完全复杂的手工制作的对象检测管道,并将Faster R-CNN与ResNet-50匹配,使用一半的计算能力(FLOP)和相同数量的参数在COCO上获得42个AP。

官方项目地址在这里,如下所示:
 

可以看到目前已经收获了超过1.2w的star量,还是很不错的了。

DETR整体数据流程示意图如下所示:

官方也提供了对应的预训练模型,可以自行使用:

首先按照README基础操作按照配置环境,如下所示:

相关的预处理内容在我前面的博文中都有涉及,这里就不再展开介绍了。

DETR (DEtection TRansformer) 是一种基于Transformer架构的端到端目标检测模型。与传统的基于区域提议的目标检测方法(如Faster R-CNN)不同,DETR采用了全新的思路,将目标检测问题转化为一个序列到序列的问题,通过Transformer模型实现目标检测和目标分类的联合训练。

DETR的工作流程如下:

  1. 输入图像通过卷积神经网络(CNN)提取特征图。
  2. 特征图作为编码器输入,经过一系列的编码器层得到图像特征的表示。
  3. 目标检测问题被建模为一个序列到序列的转换任务,其中编码器的输出作为解码器的输入。
  4. 解码器使用自注意力机制(self-attention)对编码器的输出进行处理,以获取目标的位置和类别信息。
  5. 最终,DETR通过一个线性层和softmax函数对解码器的输出进行分类,并通过一个线性层预测目标框的坐标。

DETR的优点包括:

  1. 端到端训练:DETR模型能够直接从原始图像到目标检测结果进行端到端训练,避免了传统目标检测方法中复杂的区域提议生成和特征对齐的过程,简化了模型的设计和训练流程。
  2. 不受固定数量的目标限制:DETR可以处理变长的输入序列,因此不受固定数量目标的限制。这使得DETR能够同时检测图像中的多个目标,并且不需要设置预先确定的目标数量。
  3. 全局上下文信息:DETR通过Transformer的自注意力机制,能够捕捉到图像中不同位置的目标之间的关系,提供了更大范围的上下文信息。这有助于提高目标检测的准确性和鲁棒性。

然而,DETR也存在一些缺点:

  1. 计算复杂度高:由于DETR采用了Transformer模型,它在处理大尺寸图像时需要大量的计算资源,导致其训练和推理速度相对较慢。
  2. 对小目标的检测性能较差:DETR模型在处理小目标时容易出现性能下降的情况。这是因为Transformer模型在处理小尺寸目标时可能会丢失细节信息,导致难以准确地定位和分类小目标。

现在我们来对比一下DETR和YOLO系列以及SSD等知名目标检测模型的优劣:

YOLO系列(包括YOLOv1、YOLOv2、YOLOv3和YOLOv4)和SSD是基于锚框的目标检测方法。它们的优点包括:

  1. 实时性能较好:YOLO系列和SSD通过使用锚框和特征金字塔网络,能够在保持较高检测准确性的同时,实现实时目标检测。
  2. 对小目标的检测效果较好:锚框的使用使得YOLO系列和SSD对小目标的检测能力相对较强。
  3. 计算效率高:相对于DETR的Transformer模型,YOLO系列和SSD的计算复杂度较低,因此训练和推理速度更快。

然而,YOLO系列和SSD也存在一些缺点:

  1. 定位精度相对较低:由于采用了固定数量的锚框,YOLO系列和SSD在目标定位方面的精度相对较低。特别是对于小尺寸目标,容易出现边界框偏移或不完整的情况。
  2. 对密集目标的处理困难:由于锚框的固定尺寸和位置,YOLO系列和SSD在处理密集目标(多个目标在空间上重叠)时可能存在困难,容易发生目标漏检或重叠框的问题。

综上所述,DETR相对于YOLO系列和SSD等基于锚框的目标检测模型具有端到端训练、不受固定目标数量限制和全局上下文信息等优点。然而,DETR在计算复杂度和对小目标的检测性能方面存在一些限制。对于实时性能要求高且注重目标定位精度的场景,YOLO系列和SSD可能是更好的选择。而对于需要全局上下文信息和不受固定目标数量限制的场景,DETR可能更适合。选择适合的目标检测模型应根据具体应用场景和需求进行评估。

首先看下整体效果:

接下来看下数据集:

需要借助于脚本转化处理为coco格式的:

这块网上现成的教程自行百度即可。

默认是100次epoch的迭代计算,看下结果详情,如下所示:

训练完成截图如下所示:

整体训练过程可视化核心实现如下所示:

def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.

    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field.
              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
              - log_name = optional, name of log file if different than default 'log.txt'.

    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
               - solid lines are training results, dashed lines are test results.

    '''
    func_name = "plot_utils.py::plot_logs"

    # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
    # convert single Path to list to avoid 'not iterable' error

    if not isinstance(logs, list):
        if isinstance(logs, PurePath):
            logs = [logs]
            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
        else:
            raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
            Expect list[Path] or single Path obj, received {type(logs)}")

    # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
    for i, dir in enumerate(logs):
        if not isinstance(dir, PurePath):
            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
        if not dir.exists():
            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
        # verify log_name exists
        fn = Path(dir / log_name)
        if not fn.exists():
            print(f"-> missing {log_name}.  Have you gotten to Epoch 1 in training?")
            print(f"--> full path of missing log file: {fn}")
            return

    # load log file(s) and plot
    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]

    fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))

    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
        for j, field in enumerate(fields):
            if field == 'mAP':
                coco_eval = pd.DataFrame(
                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
                ).ewm(com=ewm_col).mean()
                axs[j].plot(coco_eval, c=color)
            else:
                df.interpolate().ewm(com=ewm_col).mean().plot(
                    y=[f'train_{field}', f'test_{field}'],
                    ax=axs[j],
                    color=[color] * 2,
                    style=['-', '--']
                )
    for ax, field in zip(axs, fields):
        ax.legend([Path(p).name for p in logs])
        ax.set_title(field)


def plot_precision_recall(files, naming_scheme='iter'):
    if naming_scheme == 'exp_id':
        # name becomes exp_id
        names = [f.parts[-3] for f in files]
    elif naming_scheme == 'iter':
        names = [f.stem for f in files]
    else:
        raise ValueError(f'not supported {naming_scheme}')
    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
    for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
        data = torch.load(f)
        # precision is n_iou, n_points, n_cat, n_area, max_det
        precision = data['precision']
        recall = data['params'].recThrs
        scores = data['scores']
        # take precision for all classes, all areas and 100 detections
        precision = precision[0, :, :, 0, -1].mean(1)
        scores = scores[0, :, :, 0, -1].mean(1)
        prec = precision.mean()
        rec = data['recall'][0, :, 0, -1].mean()
        print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
              f'score={scores.mean():0.3f}, ' +
              f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
              )
        axs[0].plot(recall, precision, c=color)
        axs[1].plot(recall, scores, c=color)

    axs[0].set_title('Precision / Recall')
    axs[0].legend(names)
    axs[1].set_title('Scores / Recall')
    axs[1].legend(names)
    return fig, axs

结果如下所示:

对应模型评估指标如下所示:

感兴趣的话可以自行动手实践尝试下!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1219254.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

React升级到18版本

前言 升级前react版本是16.9.0,react-dom版本为16.9.0,react-router-dom为5.1.2版本。 安装 // npm npm install react react-dom// yarn yarn add react react-dom// pnpm pnpm install react react-dom启动项目 此时,项目可以正常运行&…

【python】——控制语句和组合数据类型(其二)

🎃个人专栏: 🐬 算法设计与分析:算法设计与分析_IT闫的博客-CSDN博客 🐳Java基础:Java基础_IT闫的博客-CSDN博客 🐋c语言:c语言_IT闫的博客-CSDN博客 🐟MySQL&#xff1a…

【开源】基于JAVA的服装店库存管理系统

项目编号: S 052 ,文末获取源码。 \color{red}{项目编号:S052,文末获取源码。} 项目编号:S052,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 角色管理模块2.3 服…

OpenCV技术应用(3)— 把.png图像保存为.jpg图像

前言:Hello大家好,我是小哥谈。本节课就手把手教你如何把.png图像保存为.jpg图像,希望大家学习之后能够有所收获~!🌈 目录 🚀1.技术介绍 🚀2.实现代码 🚀1.技术介绍 如果在电脑某…

image is being used by stopped container 7d2ff8620f3b 删除镜像失败怎么办

这个错误信息表明,镜像 55860ee0cd73 正被一个已停止的容器 7d2ff8620f3b 使用,因此无法正常删除。要解决这个问题,你有两个选择: 删除使用该镜像的容器:首先删除引用该镜像的容器,然后再删除镜像。这可以通…

到站上海!见证这座零碳园区的绿色低碳新选择

不知不觉中,科士达新能源的零碳足迹已遍布五洲四海,为全球各地,千行百业、千家万户,带去了源源不断的绿色能源和低碳新选择。再次启航,这一站,抵达上海世博园。 小机身,大配置,灵活适…

扬帆未来,成就架构之路:十本书籍助力你的架构师梦想 | 文末送书

相信大家都对未来的职业发展有着憧憬和规划,要做架构师、要做技术总监、要做CTO。对于如何实现自己的职业规划也都信心满满,努力工作、好好学习、不断提升自己。 规划职业发展 当涉及未来职业发展时,我们都怀揣着远大的目标和野心。对许多人…

移动硬盘打不开?正确操作方法分享!

“我的移动硬盘用了好几年了,但是不知道为什么,最近每次把移动硬盘插入电脑都显示无法打开。我还有一些很重要的数据在里面呢,有什么比较好的方法可以解决这个问题吗?” 作为一个便捷的存储工具,移动硬盘给我们带来了很…

【汇编】内存中字的存储、用DS和[address]实现字的传送、DS与数据段

文章目录 前言一、内存中字的存储1.1 8086cpu字的概念1.2 16位的字存储在一个16位的寄存器中,如何存储?1.3 字单元 二、用DS和[address]实现字的传送2.1 字的传送是什么意思?2.2 要求原理解决方案:DS和[address]配合8086传送16字节…

【LeetCode刷题-滑动窗口】--340.至多包含K个不同字符的最长子串

340.至多包含K个不同字符的最长子串 class Solution {public int lengthOfLongestSubstringKDistinct(String s, int k) {int len s.length();if(len < k){return len;}//滑动窗口的左右指针int left 0,right 0;//定义一个哈希映射HashMap<Character,Integer> hash…

外汇天眼:世界级的交流碰撞!Wiki Finance EXPO悉尼2023圆满落幕

11月16日对于外汇天眼来说是个特殊的日子&#xff0c;而对于整个世界金融界来说也是一个意义非凡的日子&#xff0c;由WikiGloba展会品牌WiKiEXPO举办的2023 Wiki Finance Expo在悉尼富丽敦酒店顺利举办并圆满落幕。 金融科技作为金融业创新变革的重要引擎&#xff0c;在全球数…

【Python基础篇】字符串的拼接

博主&#xff1a;&#x1f44d;不许代码码上红 欢迎&#xff1a;&#x1f40b;点赞、收藏、关注、评论。 格言&#xff1a; 大鹏一日同风起&#xff0c;扶摇直上九万里。 文章目录 一 Python中的字符串拼接二 join函数拼接三 os.path.join函数拼接四 号拼接五 &#xff0c;号…

【开源】基于JAVA的高校宿舍调配管理系统

项目编号&#xff1a; S 051 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S051&#xff0c;文末获取源码。} 项目编号&#xff1a;S051&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能需求2.1 学生端2.2 宿管2.3 老师端 三、系统…

⑩② 【MySQL索引】详解MySQL`索引`:结构、分类、性能分析、设计及使用规则。

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ MySQL索引 ⑩② 【MySQL索引】1. 索引2. 索引的…

Spring Cloud Netflix微服务组件-Hystrix

目录 Hystrix的主要功能 传统容错手段 超时机制 应用容错三板斧 超时机制 舱壁隔离 熔断降级 侵入式Command用法 改进版一&#xff1a;ribbon与hystrix组合 改进版二&#xff1a;feign与hystrix组合 Hystrix三态转换图 源码分析 流程图 原理流程图 核心实现流程…

【开源】基于Vue.js的社区买菜系统的设计和实现

项目编号&#xff1a; S 011 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S011&#xff0c;文末获取源码。} 项目编号&#xff1a;S011&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、系统设计2.1 功能模块设计2.1.1 数据中心模块2.1…

11.13ASM图,FSM的一些verliog实现,串并转换

ASM图 除法器 FSM的verilog实现 状态机的状态就不用编码&#xff0c;而是用参数的定义进行转换 1. 在次态确定当中&#xff0c;只要w和y发生变化&#xff0c;就进行操作 在次态当中&#xff0c;只要时钟上升沿来临或者复位信号&#xff0c;就进行操作 2 Mealy型状态机 串…

CNN(卷积神经网络)、RNN(循环神经网络)、DNN(深度神经网络)的内部网络结构有什么区别?

【导师不教&#xff1f;我来教&#xff01;】同济计算机博士半小时就教会了我五大深度神经网络&#xff0c;CNN/RNN/GAN/transformer/LSTM一次学会&#xff0c;简直不要太强&#xff01;_哔哩哔哩_bilibili了解的五大神经网络&#xff0c;整理笔记如下&#xff1a; 视频是唐宇…

瑞萨RZ/G2L平台 初起动(SD卡启动)

文章目录 一 准备条件1 工具2 硬件3 镜像 二 烧录SD卡启动盘三 写Bootloader1 烧录文件2 启动烧录3 烧录 四 启动设置 一 准备条件 1 工具 ** BalenaEtcher&#xff08;俗称“ Etcher”&#xff09;&#xff0c;是一款快速将系统镜像文件&#xff08; .iso 或 .img 或 .zip或…

数据库实验报告(六)

实验报告&#xff08;六&#xff09; 1、实验目的 &#xff08;1&#xff09; 掌握关联查询的用法 &#xff08;2&#xff09; 掌握集合查询的区别和用法 &#xff08;3&#xff09; 掌握EXISTS的用法 2、实验预习与准备 &#xff08;1&#xff09; 了解ANY&…