基于transformer的解码decode目标检测框架(修改DETR源码)

news2024/11/7 21:01:53

提示:transformer结构的目标检测解码器,包含loss计算,附有源码

文章目录

  • 前言
  • 一、main函数代码解读
    • 1、整体结构认识
    • 2、main函数代码解读
    • 3、源码链接
  • 二、decode模块代码解读
    • 1、decoded的TransformerDec模块代码解读
    • 2、decoded的TransformerDecoder模块代码解读
    • 3、decoded的DecoderLayer模块代码解读
  • 三、decode模块训练demo代码解读
    • 1、解码数据输入格式
    • 2、解码训练demo代码解读
  • 四、decode模块预测demo代码解读
    • 1、预测数据输入格式
    • 2、解码预测demo代码解读
  • 五、losses模块代码解读
    • 1、matcher初始化
    • 2、二分匹配matcher代码解读
    • 3、num_classes参数解读
    • 4、losses的demo代码解读


前言

最近重温DETR模型,越发感觉detr模型结构精妙之处,不同于anchor base 与anchor free设计,直接利用100框给出预测结果,使用可学习learn query深度查找,使用二分匹配方式训练模型。为此,我基于detr源码提取解码decode、loss计算等系列模块,并重构、修改、整合一套解码与loss实现的框架,该框架可适用任何backbone特征提取接我框架,实现完整训练与预测,我也有相应demo指导使用我的框架。那么,接下来,我将完整介绍该框架源码。同时,我将此源码进行开源,并上传github中,供读者参考。


一、main函数代码解读

1、整体结构认识

在介绍main函数代码前,我先说下整体框架结构,该框架包含2个文件夹,一个losses文件夹,用于处理loss计算,一个是obj_det文件,用于transformer解码模块,该模块源码修改于detr模型,也包含main.py,该文件是整体解码与loss计算demo示意代码,如下图。

在这里插入图片描述

2、main函数代码解读

该代码实际是我随机创造了标签target数据与backbone特征提取数据及位置编码数据,使其能正常运行的demo,其代码如下:

import torch
from obj_det.transformer_obj import TransformerDec
from losses.matcher import HungarianMatcher
from losses.loss import SetCriterion

if __name__ == '__main__':


    Model = TransformerDec(d_model=256, output_intermediate_dec=True, num_classes=4)

    num_classes = 4   #  类别+1
    matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)  # 二分匹配不同任务分配的权重
    losses = ['labels', 'boxes', 'cardinality']  # 计算loss的任务
    weight_dict = {
   'loss_ce': 1, 'loss_bbox': 5, 'loss_giou': 2}  # 为dert最后一个设置权重
    criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=0.1, losses=losses)

    # 下面使用iter,我构造了虚拟模型编码数据与数据加载标签数据
    src = torch.rand((391, 2, 256))
    pos_embed = torch.ones((391, 1, 256))

    # 创造真实target数据
    target1 = {
   'boxes':torch.rand((5,4)),'labels':torch.tensor([1,3,2,1,2])}
    target2 = {
   'boxes': torch.rand((3, 4)), 'labels': torch.tensor([1, 1, 2])}
    target = [target1, target2]

    res = Model(src, pos_embed)
    losses = criterion(res, target)
    print(losses)

如下图:

在这里插入图片描述

3、源码链接

源码链接:点击这里

二、decode模块代码解读

该模块主要是使用transform方式对backbone提取特征的解码,主要使用learn query等相关trike与transform解码方式内容。
我主要介绍TransformerDec、TransformerDecoder、DecoderLayer模块,为依次被包含关系,或说成后者是前者组成部分。

1、decoded的TransformerDec模块代码解读

该类大意是包含了learn query嵌入、解码transform模块调用、head头预测logit与boxes等内容,是实现解码与预测内容,该模块参数或解释已有注释,读者可自行查看,其代码如下:

class TransformerDec(nn.Module):
    '''
    d_model=512, 使用多少维度表示,实际为编码输出表达维度
    nhead=8, 有多少个头
    num_queries=100, 目标查询数量,可学习query
    num_decoder_layers=6, 解码循环层数
    dim_feedforward=2048, 类似FFN的2个nn.Linear变化
    dropout=0.1,
    activation="relu",
    normalize_before=False,解码结构使用2种方式,默认False使用post解码结构
    output_intermediate_dec=False, 若为True保存中间层解码结果(即:每个解码层结果保存),若False只保存最后一次结果,训练为True,推理为False
    num_classes: num_classes数量与数据格式有关,若类别id=1表示第一类,则num_classes=实际类别数+1,若id=0表示第一个,则num_classes=实际类别数

    额外说明,coco类别id是1开始的,假如有三个类,名称为[dog,cat,pig],batch=2,那么参数num_classes=4,表示3个类+1个背景,
    模型输出src_logits=[2,100,5]会多出一个预测,target_classes设置为[2,100],其值为4(该值就是背景,而有类别值为123),
    那么target_classes中没有值为0,我理解模型不对0类做任何操作,是个无效值,模型只对1234进行loss计算,然4为背景会比较多,
    作者使用权重0.1避免其背景过度影响。

    forward return: 返回字典,包含{
   
    'pred_logits':[],  # 为列表,格式为[b,100,num_classes+2]
    'pred_boxes':[],  # 为列表,格式为[b,100,4]
    'aux_outputs'[{
   },...] # 为列表,元素为字典,每个字典为{
   'pred_logits':[],'pred_boxes':[]},格式与上相同

    }

    '''

    def __init__(self, d_model=512, nhead=8, num_queries=100, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False, output_intermediate_dec=False, num_classes=1):
        super().__init__()

        self.num_queries = num_queries
        self.query_embed = nn.Embedding(num_queries, d_model)  # 与编码输出表达维度一致
        self.output_intermediate_dec = output_intermediate_dec

        decoder_layer = DecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1164961.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件测试面试最经典的5个问题

软件测试面试灵魂五问! 请做一下自我介绍?你为什么从上家公司离职?为什么转行做测试? 你对测试行业的认识?你的期望薪资是多少?最后,你要问我什么? 一、请做一下自我介绍 简历上有的可以一两…

VLAN与配置

VLAN与配置 什么是VLAN 以最简单的形式为例。如下图,此时有4台主机处于同一局域网中,很明显这4台主机是能够直接通讯。但此时我需要让处于同一局域网中的PC3和PC4能通讯,PC5和PC6能通讯,并且PC3和PC4不能与PC5和PC6通讯。 为了实…

【工具】【IDE】Qt Creator社区版

Qt Creator社区版下载地址:https://download.qt.io/archive/qt/ 参考:https://cloud.tencent.com/developer/article/2084698?areaSource102001.8&traceIduMchNghqp8gWPdFHvSOGg MAC安装并配置Qt(超级简单版) 1.安装brew&…

单链表的详解实现

单链表 结构 单链表结构中有两个数据,一个是存储数据的,还有一个指针指向下一个节点。 该图就是一个简单单链表的结构图。 接口实现 SLNode* CreateNode(SLNDataType x);//申请节点 void SLTprint(SLNode* head);//打印链表 void SLTPushBack(SLNode*…

【Echarts】玫瑰饼图数据交互

在学习echarts玫瑰饼图的过程中,了解到三种数据交互的方法,如果对您也有帮助,不胜欣喜。 一、官网教程 https://echarts.apache.org/examples/zh/editor.html?cpie-roseType-simple (该教程数据在代码中) import *…

springboot-2.7.3+ES-7.10.0

跟着官网走,能干99。一年几次变,次次不一样。刚部署好ES-6.8,又买阿里云Es-7.10.0根本忙不完。 做为JDK1.8最后一个版本。今天就拿新技术部署一套。致辞:大家以后就用这套好了。别轻易触发springboot3.0了 学习无止境&#xff1…

【使用Python编写游戏辅助工具】第三篇:鼠标连击器的实现

前言 这里是【使用Python编写游戏辅助工具】的第三篇:鼠标连击器的实现。本文主要介绍使用Python来实现鼠标连击功能。 鼠标连击是指在很短的时间内多次点击鼠标按钮,通常是鼠标左键。当触发鼠标连击时,鼠标按钮会迅速按下和释放多次&#xf…

言情小说怎么推广?如何推广网络小说?

网络小说是一种文学形式,它的受众群体相当广泛,其实也面临着很强的竞争,因此,网络推广是小说宣传的一项重要工作,这里小马识途营销顾问就分享一下小说推广的渠道和方法。 1、软文推广 在推广小说的过程中,…

面试10000次依然会问的【synchronized】,你还不会?

引言 synchronized 关键字是实现线程同步的核心工具,它能够确保在任一时刻,只有一个线程能够访问被同步的方法或代码块。 这不仅保证了操作的原子性,即这些操作要么完全执行,要么完全不执行;同时也确保了操作的可见性…

高效操作,轻松打造企业百度百科,展现实力形象

百度百科已经成为企业提升形象的重要渠道,拥有自己的百科词条意味着企业在互联网上拥有更高的知名度和可信度。接下来,将为大家介绍企业百度百科的创建过程和一些技巧,帮助企业更好地在百度百科上展现自身实力。 首先,创建企业百度…

基于Tensorflow卷积神经网络玉米病害识别系统(UI界面)

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 Tensorflow是一个流行的机器学习框架,可用于训练和部署各种人工智能模型。玉米病害识别系统基于Tensorf…

明明用的不是自己机器视觉软件,甚至是盗版,机器视觉公司为什么还要申请那么多专利?

我首先看下专利是什么? 专利分为发明、实用新型、外观设计三种类型。 发明是指对产品、方法或者其改进所提出的新的技术方案。 实用新型是指对产品的形状构造或者其结合所提出的适于实用的新的技术方案。一般对日用品、机械、电器等产品的简单改进比较适用于申请…

Mysql数据目录结构以及文件类型解析

目录 1. 数据目录 2. Data目录 3. 数据库目录 1)db.opt 2).frm 3).MYD和.MYI 4).ibd 5).ibd和.ibdata 在 MySQL 中,物理文件存放在数据目录中。数据目录与安装目录不同,安装目录用来存储…

NLP之Bert介绍和简单示例

文章目录 1. Bert 介绍2. 代码示例 1. Bert 介绍 2. 代码示例

Express框架开发接口之轮播图API

1.获取所有轮播图、 // 处理轮播图 const handleDB require(../handleDB/index) // 获取所有轮播图 exports.allCarousel (req, res) > {(async function () {let results await handleDB(res, "book_carousel", "find", "查询数据出错&#xf…

Python 生成Android不同尺寸的图标

源代码 # -*- coding: utf-8 -*- import sys import os import shutil from PIL import Imagedef generateAndroidIcons():imageSource icon.pngicon Image.open(imageSource)sizes [(android/drawable,512),(android/drawable-hdpi,72),(android/drawable-ldpi,36),(andro…

C# 发送邮件

1.安装 NuGet 包 2.代码如下 SendMailUtil using MimeKit; using Srm.CMER.Application.Contracts.CmerInfo; namespace Srm.Mail { public class SendMailUtil { public async static Task<string> SendEmail(SendEmialDto sendEmialDto,List<strin…

11月2日星期四今日早报简报微语报早读

11月2日星期四&#xff0c;农历九月十九&#xff0c;早报微语早读分享。 1、茅台深夜提价&#xff1a;11月1日起飞天、五星出厂价格平均上调约20&#xff05;&#xff0c;贵州茅台&#xff1a;市场指导价不变&#xff1b; 2、杭州拟发文规范直播电商业&#xff1a;不得要求商…

2015年亚太杯APMCM数学建模大赛C题识别网络中的错误连接求解全过程文档及程序

2015年亚太杯APMCM数学建模大赛 C题 识别网络中的错误连接 原题再现 网络是描述真实系统结构的强大工具——社交网络描述人与人之间的关系&#xff0c;万维网描述网页之间的超链接关系。随着现代技术的发展&#xff0c;我们积累了越来越多的网络数据&#xff0c;但这些数据部…

vs2013/2015/2019扩展-联机提示“未能建立到服务器的连接“/“基础连接已经关闭: 发送时发生错误“/“远程主机强迫关闭了一个现有的连接“

VS2013\VS2015 输入命令 [Net.ServicePointManager]::SecurityProtocol[Net.ServicePointManager]::SecurityProtocol-bOR [Net.SecurityProtocolType]::Tls12 采用上述方法偶尔可以有效&#xff0c;重新启动VS就没用了 VS2019 怎么样都不行 最终解决办法&#xff1a;换一…