python基于DETR(DEtection TRansformer)开发构建钢铁产业产品智能自动化检测识别系统

news2024/11/18 10:41:24

在前文中我们基于经典的YOLOv5开发构建了钢铁产业产品智能自动化检测识别系统,这里本文的主要目的是想要实践应用DETR这一端到端的检测模型来开发构建钢铁产业产品智能自动化检测识别系统。

DETR (DEtection TRansformer) 是一种基于Transformer架构的端到端目标检测模型。与传统的基于区域提议的目标检测方法(如Faster R-CNN)不同,DETR采用了全新的思路,将目标检测问题转化为一个序列到序列的问题,通过Transformer模型实现目标检测和目标分类的联合训练。

DETR的工作流程如下:

输入图像通过卷积神经网络(CNN)提取特征图。
特征图作为编码器输入,经过一系列的编码器层得到图像特征的表示。
目标检测问题被建模为一个序列到序列的转换任务,其中编码器的输出作为解码器的输入。
解码器使用自注意力机制(self-attention)对编码器的输出进行处理,以获取目标的位置和类别信息。
最终,DETR通过一个线性层和softmax函数对解码器的输出进行分类,并通过一个线性层预测目标框的坐标。
DETR的优点包括:

端到端训练:DETR模型能够直接从原始图像到目标检测结果进行端到端训练,避免了传统目标检测方法中复杂的区域提议生成和特征对齐的过程,简化了模型的设计和训练流程。
不受固定数量的目标限制:DETR可以处理变长的输入序列,因此不受固定数量目标的限制。这使得DETR能够同时检测图像中的多个目标,并且不需要设置预先确定的目标数量。
全局上下文信息:DETR通过Transformer的自注意力机制,能够捕捉到图像中不同位置的目标之间的关系,提供了更大范围的上下文信息。这有助于提高目标检测的准确性和鲁棒性。
然而,DETR也存在一些缺点:

计算复杂度高:由于DETR采用了Transformer模型,它在处理大尺寸图像时需要大量的计算资源,导致其训练和推理速度相对较慢。
对小目标的检测性能较差:DETR模型在处理小目标时容易出现性能下降的情况。这是因为Transformer模型在处理小尺寸目标时可能会丢失细节信息,导致难以准确地定位和分类小目标。

首先看下实例效果:
 

简单看下数据集:

PyTorch训练代码和DETR(DEDetection-TRansformer)的预训练模型。我们用Transformer替换了完全复杂的手工制作的对象检测管道,并将Faster R-CNN与ResNet-50匹配,使用一半的计算能力(FLOP)和相同数量的参数在COCO上获得42个AP。

官方项目地址在这里,如下所示:

可以看到目前已经收获了超过1.2w的star量,还是很不错的了。

DETR整体数据流程示意图如下所示:

官方也提供了对应的预训练模型,可以自行使用:

本文选择的预训练官方权重是detr-r50-e632da11.pth,首先需要基于官方的预训练权重开发能够用于自己的 个性化数据集的权重,如下所示:

pretrained_weights = torch.load("./weights/detr-r50-e632da11.pth")
num_class = 10 + 1
pretrained_weights["model"]["class_embed.weight"].resize_(num_class+1,256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_class+1)
torch.save(pretrained_weights,'./weights/detr_r50_%d.pth'%num_class)

因为这里我的类别数量为10,所以num_class修改为:10+1,根据自己的实际情况修改即可。生成后如下所示:

之后按照官方说明准备好数据集即可,启动训练模型命令如下所示:

python main.py --dataset_file "coco" --coco_path "/0000" --epoch 100 --lr=1e-4 --batch_size=32 --num_workers=0 --output_dir="outputs" --resume="weights/detr_r50_11.pth"

借助于plot_util.py模块可以实现对模型的评估和可视化,如下:

def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
    '''
    Function to plot specific fields from training log(s). Plots both training and test results.

    :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
              - fields = which results to plot from each log file - plots both training and test for each field.
              - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots
              - log_name = optional, name of log file if different than default 'log.txt'.

    :: Outputs - matplotlib plots of results in fields, color coded for each log file.
               - solid lines are training results, dashed lines are test results.

    '''
    func_name = "plot_utils.py::plot_logs"

    # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
    # convert single Path to list to avoid 'not iterable' error

    if not isinstance(logs, list):
        if isinstance(logs, PurePath):
            logs = [logs]
            print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
        else:
            raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \
            Expect list[Path] or single Path obj, received {type(logs)}")

    # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir
    for i, dir in enumerate(logs):
        if not isinstance(dir, PurePath):
            raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
        if not dir.exists():
            raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
        # verify log_name exists
        fn = Path(dir / log_name)
        if not fn.exists():
            print(f"-> missing {log_name}.  Have you gotten to Epoch 1 in training?")
            print(f"--> full path of missing log file: {fn}")
            return

    # load log file(s) and plot
    dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]

    fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))

    for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
        for j, field in enumerate(fields):
            if field == 'mAP':
                coco_eval = pd.DataFrame(
                    np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1]
                ).ewm(com=ewm_col).mean()
                axs[j].plot(coco_eval, c=color)
            else:
                df.interpolate().ewm(com=ewm_col).mean().plot(
                    y=[f'train_{field}', f'test_{field}'],
                    ax=axs[j],
                    color=[color] * 2,
                    style=['-', '--']
                )
    for ax, field in zip(axs, fields):
        ax.legend([Path(p).name for p in logs])
        ax.set_title(field)


def plot_precision_recall(files, naming_scheme='iter'):
    if naming_scheme == 'exp_id':
        # name becomes exp_id
        names = [f.parts[-3] for f in files]
    elif naming_scheme == 'iter':
        names = [f.stem for f in files]
    else:
        raise ValueError(f'not supported {naming_scheme}')
    fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
    for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
        data = torch.load(f)
        # precision is n_iou, n_points, n_cat, n_area, max_det
        precision = data['precision']
        recall = data['params'].recThrs
        scores = data['scores']
        # take precision for all classes, all areas and 100 detections
        precision = precision[0, :, :, 0, -1].mean(1)
        scores = scores[0, :, :, 0, -1].mean(1)
        prec = precision.mean()
        rec = data['recall'][0, :, 0, -1].mean()
        print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
              f'score={scores.mean():0.3f}, ' +
              f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
              )
        axs[0].plot(recall, precision, c=color)
        axs[1].plot(recall, scores, c=color)

    axs[0].set_title('Precision / Recall')
    axs[0].legend(names)
    axs[1].set_title('Scores / Recall')
    axs[1].legend(names)
    return fig, axs

结果如下所示:

iter 000: mAP@50= 24.0, score=0.317, f1=0.341
iter 050: mAP@50= 27.7, score=0.339, f1=0.400
iter latest: mAP@50= 26.4, score=0.348, f1=0.393
iter 000: mAP@50= 24.0, score=0.317, f1=0.341
iter 050: mAP@50= 27.7, score=0.339, f1=0.400
iter latest: mAP@50= 26.4, score=0.348, f1=0.393

可视化如下所示:

【Precision曲线】
精确率曲线(Precision-Recall Curve)是一种用于评估二分类模型在不同阈值下的精确率性能的可视化工具。它通过绘制不同阈值下的精确率和召回率之间的关系图来帮助我们了解模型在不同阈值下的表现。
精确率(Precision)是指被正确预测为正例的样本数占所有预测为正例的样本数的比例。召回率(Recall)是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。
绘制精确率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的精确率和召回率。
将每个阈值下的精确率和召回率绘制在同一个图表上,形成精确率曲线。
根据精确率曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
通过观察精确率曲线,我们可以根据需求确定最佳的阈值,以平衡精确率和召回率。较高的精确率意味着较少的误报,而较高的召回率则表示较少的漏报。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。
精确率曲线通常与召回率曲线(Recall Curve)一起使用,以提供更全面的分类器性能分析,并帮助评估和比较不同模型的性能。
【Recall曲线】
召回率曲线(Recall Curve)是一种用于评估二分类模型在不同阈值下的召回率性能的可视化工具。它通过绘制不同阈值下的召回率和对应的精确率之间的关系图来帮助我们了解模型在不同阈值下的表现。
召回率(Recall)是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。召回率也被称为灵敏度(Sensitivity)或真正例率(True Positive Rate)。
绘制召回率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的召回率和对应的精确率。
将每个阈值下的召回率和精确率绘制在同一个图表上,形成召回率曲线。
根据召回率曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
通过观察召回率曲线,我们可以根据需求确定最佳的阈值,以平衡召回率和精确率。较高的召回率表示较少的漏报,而较高的精确率意味着较少的误报。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。
召回率曲线通常与精确率曲线(Precision Curve)一起使用,以提供更全面的分类器性能分析,并帮助评估和比较不同模型的性能。

【PR曲线】
精确率-召回率曲线(Precision-Recall Curve)是一种用于评估二分类模型性能的可视化工具。它通过绘制不同阈值下的精确率(Precision)和召回率(Recall)之间的关系图来帮助我们了解模型在不同阈值下的表现。
精确率是指被正确预测为正例的样本数占所有预测为正例的样本数的比例。召回率是指被正确预测为正例的样本数占所有实际为正例的样本数的比例。
绘制精确率-召回率曲线的步骤如下:
使用不同的阈值将预测概率转换为二进制类别标签。通常,当预测概率大于阈值时,样本被分类为正例,否则分类为负例。
对于每个阈值,计算相应的精确率和召回率。
将每个阈值下的精确率和召回率绘制在同一个图表上,形成精确率-召回率曲线。
根据曲线的形状和变化趋势,可以选择适当的阈值以达到所需的性能要求。
精确率-召回率曲线提供了更全面的模型性能分析,特别适用于处理不平衡数据集和关注正例预测的场景。曲线下面积(Area Under the Curve, AUC)可以作为评估模型性能的指标,AUC值越高表示模型的性能越好。
通过观察精确率-召回率曲线,我们可以根据需求选择合适的阈值来权衡精确率和召回率之间的平衡点。根据具体的业务需求和成本权衡,可以在曲线上选择合适的操作点或阈值。 

感兴趣的话可以自行动手实践尝试下!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1252913.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

字符串转换成十进制整数

编程要求 输入一个以#结束的字符串,本题要求滤去所有的非十六进制字符(不分大小写),组成一个新的表示十六进制数字的字符串,然后将其转换为十进制数后输出。如果在第一个十六进制字符之前存在字符“-”,则…

vue2:组件中extends的使用

上一篇文章中我对mixin的使用进行了一个使用和测试,这里对extend进行一个使用,其实extend和mixin还是有区别的。 上一篇文章:vue2:mixin混入的使用-CSDN博客 不过也是看实际的业务场景,我们也可以使用extend完成和mixin几乎一摸一样的操作。 不废话,上代码 创建extendTest.…

MySQL表连接

文章目录 MySQL内外连接1.内连接2.外连接(1)左外连接(2)右外连接 3.简单案例 MySQL内外连接 1.内连接 内连接的SQL如下: SELECT ... FROM t1 INNER JOIN t2 ON 连接条件 [INNER JOIN t3 ON 连接条件] ... AND 其他条件;说明一下…

vulnhub靶机Aragog-1.0.2

靶机下载:https://download.vulnhub.com/harrypotter/Aragog-1.0.2.ova 主机发现 目标133 端口扫描 端口服务扫描 漏洞扫描 去看web 就一个图片 直接目录扫描 有点东西一个一个看(blog里面有东西) 这里面直接能看到wordpress 直接用wpscan…

Go 语言之 Maps 详解:创建、遍历、操作和注意事项

Maps用于以键值对的形式存储数据值。Maps中的每个元素都是一个键值对。Maps是一个无序且可更改的集合,不允许重复。Maps的长度是其元素的数量。您可以使用 len() 函数来查找长度。Maps的默认值是 nil。Maps保存对底层哈希表的引用。 Go语言有多种方法来创建Maps。 …

VS中如何使用Halcon

使用Halcon的本质就是调用Halcon的库,其主要步骤有: 1、将Halcon代码导出为C的.cpp文件 2、获取.cpp文件中的action函数的函数体 3、添加Halcon的动态库和静态库 4、添加action函数需要的头文件 导出halcon中的代码 a)导出代码 b&#x…

运行软件报错找不到vcruntime140_1.dll无法继续执行代码如何解决?-常见问题

关于vcruntime140_1.dll丢失的6个解决方法。在我们使用电脑的过程中,有时候会遇到一些错误提示,其中之一就是“vcruntime140_1.dll丢失”。那么,究竟什么是vcruntime140_1.dll文件呢?又是什么原因导致了它的丢失?接下来…

基于uniapp+vue微信小程序的健康饮食管理系统 907m6

设计这个微信小程序系统能使用户实现不需出门就可以在手机或电脑前进行网上查询美食信息、 运动视频等功能。 本系统由用户和管理员两大模块组成。用户界面显示在应用程序中,管理员界面显示在后台服务中,通过小程序端与服务端间进行数据交互与数据传输实…

微服务实战系列之Nginx(技巧篇)

前言 今天北京早晨竟然飘了一些“雪花”,定睛一看,似雪非雪,像泡沫球一样,原来那叫“霰”。 自然中,雨雪霜露雾,因为出场太频繁,认识门槛较低,自然不费吹灰之力,即可享受…

电脑技巧:电脑常见蓝屏、上不了网等故障及解决办法

目录 一、电脑蓝屏 常见原因1: 病毒木马 常见原因2: 安装了不兼容的软件 二、电脑不能上网 常见原因1: 新装系统无驱动 常见原因2: DNS服务器异常 常见原因3: 硬件问题 三、电脑没声音 常见原因1: 未安装驱动 常见原因2: 硬件故障 四、电脑屏幕不显示 常见原因1: 显…

【限流配电开关】TPS2001C

🚩 WRITE IN FRONT 🚩 🔎 介绍:"謓泽"正在路上朝着"攻城狮"方向"前进四" 🔎🏅 荣誉:2021|2022年度博客之星物联网与嵌入式开发TOP5|TOP4、2021|2222年获评百大…

物联网AI 无线连接学习之蓝牙基础篇 协议的发展

学物联网,来万物简单IoT物联网!! 蓝牙由来 “蓝牙”(Bluetooth)原是一位在10世纪统一丹麦的国王哈拉尔 (HaralBluetooth),他将当时的瑞典、芬兰与丹麦统一起来。而将“蓝牙”与后来的无线通讯技术标准关联…

SpringBoot——配置及原理

优质博文:IT-BLOG-CN 一、Spring Boot全局配置文件 application.properties与application.yml配置文件的作用:可以覆盖SpringBoot配置的默认值。 ◀ YML(is not a Markup Language:不仅仅是一个标记语言)&#xff1…

plotneuralnet和netron结合绘制模型架构图

plotneuralnet和netron结合绘制模型架构图 一、plotneuralnet 本身的操作 模型结构图的可视化,能直观展示模型的结构以及各个模块之间的关系。最近借助plotneuralnet python库(windows版)绘制了一个网络结构图,有一些经验和心得…

【SpringCloud】设计原则之单一职责与服务拆分

一、设计原则之单一职责 设计原则很重要的一点就是简单,单一职责也就是所谓的专人干专事 一个单元(一个类、函数或微服务)应该有且只有一个职责 无论如何,一个微服务不应该包含多于一个的职责 职责单一的后果之一就是职责单…

产品经理和项目经理的区别在哪里?

在当今的商业环境中,产品经理和项目经理扮演着两个非常重要的角色。虽然他们都是组织成功的重要推动者,但他们的职责和所关注的角度却有着显著的不同。 1. 职责差异 产品经理负责一个或多个产品线,从产品的规划到开发、推出和市场表现。他们…

Java实现-数据结构 2.时间和空间复杂度

.如何衡量一个算法的好坏:时间复杂度和空间复杂度 算法效率分为时间效率和空间效率,时间效率称为时间复杂度,空间效率称为空间复杂度 时间复杂度 算法的时间复杂度是一个数学函数,它描述了算法的运行时间,一个算法执…

分享一个鬼~

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 先看效果: 上源码: import GUI from "https://cdn.jsdelivr.net/npm/lil-gui0.18.2/esm"const canv…

为社会做贡献的EasyDarwin 4.0.1发布了,支持视频点播、文件直播、摄像机直播、直播录像、直播回放、录像MP4合成下载

经过几个月的不懈努力和测试,最新的EasyDarwin 4.0版本总算是发布出来了,功能还是老几样:文件点播、视频直播(支持各种视频源)、直播录像与回放、录像合成MP4下载,稍稍看一下细节: 文件上传与点…

【UCAS自然语言处理作业二】训练FFN, RNN, Attention机制的语言模型,并计算测试集上的PPL

文章目录 前言前馈神经网络数据组织Dataset网络结构训练超参设置 RNN数据组织&Dataset网络结构训练超参设置 注意力网络数据组织&Dataset网络结构Attention部分完整模型 训练部分超参设置 结果与分析训练集Loss测试集PPL 前言 本次实验主要针对前馈神经网络&#xff0…