RT-DETR中的CCFF结构代码详解(Pytorch)

news2024/9/20 10:30:28

代码链接

lyuwenyu/RT-DETR: [CVPR 2024] Official RT-DETR (RTDETR paddle pytorch), Real-Time DEtection TRansformer, DETRs Beat YOLOs on Real-time Object Detection. 🔥 🔥 🔥 (github.com)icon-default.png?t=N7T8https://github.com/lyuwenyu/RT-DETR

模型图

CCFF是作者提出的一种类似于特征金字塔的特征融合模块,S3,S4,S5是backbone的后三层,作者在论文中证明了只对S5进行尺度内交互,而不对更低级别的特征进行尺度内交互,并对次做法的合理性进行了证明,再次不多赘述

“基于上述分析,我们重新思考编码器的结构,提出了一种有效的混合编码器,由基于注意力的尺度内特征交互(AIFI)和基于 CNN 的跨尺度特征融合(CCFF)两个模块组成。具体来说,AIFI 通过使用单尺度 Transformer 编码器仅在 S5 上执行尺度内交互,进一步降低了基于变体 D 的计算成本。原因是将自注意力操作应用于具有更丰富语义概念的高级特征可以捕获概念实体之间的连接,这有助于后续模块对对象的定位和识别。然而,由于缺乏语义概念以及重复和与高级特征交互混淆的风险,低级特征的尺度内交互是不必要的。为了验证这一观点,我们仅在变体 D 中的 S5 上执行尺度内交互,实验结果如表 3 所示(参见第 DS5 行)。与D相比,DS5 不仅显着减少了延迟(快 35%),而且提高了准确性(AP 高 0.4%)。CCFF是基于跨尺度融合模块优化的,该模块将多个由卷积层组成的融合块插入到融合路径中。融合块的作用是将两个相邻的尺度特征融合到一个新的特征中,其结构如图5所示。融合块包含两个1 × 1卷积来调整通道数,还使用了N个RepBlock(由RepConv[8]组成的)进行特征融合,两条路径输出通过元素相加被融合。我们将混合编码器的计算表述为:”

本文要探讨的是这个CCFF结构,图像他的线画的有点乱,要是不看代码的话,完全理解不了,或者说,不敢吓理解,但是看过代码之后,你会发现他画的很正确:下面直入正题:

代码

首先直接找到Hybird Encoder类,CCFF结构被包含在内

@register
class HybridEncoder(nn.Module):
    def __init__(self,
                 in_channels=[512, 1024, 2048],#S3,S4,S5分别对应的通道数
                 feat_strides=[8, 16, 32],#如果对S3,S4,S5进行位置编码所需要的步长数字
                 hidden_dim=256,
                 nhead=8,
                 dim_feedforward = 1024,
                 dropout=0.0,
                 enc_act='gelu',
                 use_encoder_idx=[2],#一共传进来三层,编号是0,1,2,它只在最后一层进行AIFI(作者的Attention操作)所以这里就是个2
                 num_encoder_layers=1,
                 pe_temperature=10000,
                 expansion=1.0,
                 depth_mult=1.0,
                 act='silu',
                 eval_spatial_size=None):
        super().__init__()
        self.in_channels = in_channels
        self.feat_strides = feat_strides
        self.hidden_dim = hidden_dim
        self.use_encoder_idx = use_encoder_idx
        self.num_encoder_layers = num_encoder_layers
        self.pe_temperature = pe_temperature
        self.eval_spatial_size = eval_spatial_size

        self.out_channels = [hidden_dim for _ in range(len(in_channels))]
        self.out_strides = feat_strides

一个对三层特征的初始化映射层,以1*1卷积核将三层的通道数都映射为hidden_dim

        # channel projection
        self.input_proj = nn.ModuleList()
        for in_channel in in_channels:
            self.input_proj.append(
                nn.Sequential(
                    nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
                    nn.BatchNorm2d(hidden_dim)
                )
            )

下面是AIFI,其实就是一个单层的TransformerEncoderLayer,不多说

        # encoder transformer
        encoder_layer = TransformerEncoderLayer(
            hidden_dim, 
            nhead=nhead,
            dim_feedforward=dim_feedforward, 
            dropout=dropout,
            activation=enc_act)

        self.encoder = nn.ModuleList([
            TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers) for _ in range(len(use_encoder_idx))
        ])

一下是CCFF模块的定义,从上到下路径以及从下到上的路径

        # top-down fpn
        self.lateral_convs = nn.ModuleList()
        self.fpn_blocks = nn.ModuleList()
        for _ in range(len(in_channels) - 1, 0, -1):#2,1
            self.lateral_convs.append(ConvNormLayer(hidden_dim, hidden_dim, 1, 1, act=act))
            self.fpn_blocks.append(
                CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
            )

        # bottom-up pan
        self.downsample_convs = nn.ModuleList()
        self.pan_blocks = nn.ModuleList()
        for _ in range(len(in_channels) - 1):
            self.downsample_convs.append(
                ConvNormLayer(hidden_dim, hidden_dim, 3, 2, act=act)
            )
            self.pan_blocks.append(
                CSPRepLayer(hidden_dim * 2, hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
            )

省略一部分位置编码的代码

快进到forward函数

forword函数首先验证输入特征的shape应该是[3,B,C,H,W],在将特征进行初始映射,再对S5进行尺度内交互,也就是将S5层过单层的TransformerEncoderLayer

代码如下:

    def forward(self, feats):
        assert len(feats) == len(self.in_channels)
        proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)]
        
        # encoder
        if self.num_encoder_layers > 0:#1
            for i, enc_ind in enumerate(self.use_encoder_idx):#实际上这里只有一次取值,就是i=0,enc_ind=2
                h, w = proj_feats[enc_ind].shape[2:]#获得backbone最后一层特征的高和宽分别存储在变量h和w里
                # flatten [B, C, H, W] to [B, HxW, C]
                src_flatten = proj_feats[enc_ind].flatten(2).permute(0, 2, 1)
                if self.training or self.eval_spatial_size is None:
                    pos_embed = self.build_2d_sincos_position_embedding(
                        w, h, self.hidden_dim, self.pe_temperature).to(src_flatten.device)
                else:
                    pos_embed = getattr(self, f'pos_embed{enc_ind}', None).to(src_flatten.device)

                memory = self.encoder[i](src_flatten, pos_embed=pos_embed)#经过Transformer的编码器,事实上是只有一层encoderlayer的encoder
                proj_feats[enc_ind] = memory.permute(0, 2, 1).reshape(-1, self.hidden_dim, h, w).contiguous()#经过编码之后再把它的形状恢复到之前的形状即[B,C,H,W],contiguous使连续存储
                # print([x.is_contiguous() for x in proj_feats ])

下面是最重点的CCFF流程,仔细看代码,注释很清楚,因此不再前解释了

# broadcasting and fusion
        #先从上到下,
        inner_outs = [proj_feats[-1]]#这个inner_outs存储的是最后一层经过Transformer编码器处理后的特征,与此同时从backbone提取的前两层都没有被处理过
        for idx in range(len(self.in_channels) - 1, 0, -1):#idx=2,1(不包括0),初始为2
            feat_high = inner_outs[0]#在以上这种情况下,初始inner_outs中只有一个元素,所以又给最后一层经过TransformerEncoder过后的特征提出来了,即处理过后的proj_feats[-1]
            feat_low = proj_feats[idx - 1]#idx-1=1,0,初始为1,对应着S4层特征
            feat_high = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_high)#len(self.in_channels)-1-idx=0,1初始为0,过一个1*1卷积(laterral_convs中全是1*1的带bn的卷积块)
            inner_outs[0] = feat_high#把feat_high存回去
            upsample_feat = F.interpolate(feat_high, scale_factor=2., mode='nearest')#feat_high[B,C,H,W]——>feat_high[B,C,2H,2W]
            inner_out = self.fpn_blocks[len(self.in_channels)-1-idx](torch.concat([upsample_feat, feat_low], dim=1))#len(self.in_channels)-1-idx:0,1,初始为0,初始将上采样的S5和S4特征在特征维度拼接起来送入CSP进行融合
            inner_outs.insert(0, inner_out)#将融合后的特征inner_out插入到列表的开头(索引 0 的位置)
            #该循环循环两遍
            #在第二次进行该循环的时候 feat_high是上一次循环融合好的特征 feat_low是S3的特征,将feat_high进行卷积上采样与feats_low进行融合,将结果再次插入到inner_outs中
            #经过两次循环之后 inner_outs中将包含三个元素 [0]是三层融合后的结构 [1]是S4和S5融合后的结果 [2]是S5自我卷积后的结果
        #再从下到上
        outs = [inner_outs[0]]#取的是三层从上到下融合后的结果再加一个维度(多套了一层[])
        for idx in range(len(self.in_channels) - 1):#len(self.in_channels)-1=2,idx=0,1
            feat_low = outs[-1]#取出三层从上到下融合后的特征
            feat_high = inner_outs[idx + 1]#idx+1=1,2;第一次循环取出S4和S5的融合结果
            downsample_feat = self.downsample_convs[idx](feat_low)#idx=0,1 采用卷积对低级特征进行下采样
            out = self.pan_blocks[idx](torch.concat([downsample_feat, feat_high], dim=1))#将下采样后的低级特征与高级特征在第通道维度上进行拼接,采用CSP进行融合
            outs.append(out)#将融合后的特征out拼接到outs的尾部
            #该循环会循环两边
            #在第二次进行该循环的时候 feat_low是S3和S4(经过从上到下处理)融合后结果,feats_high是S5的特征,将feat_low进行下采样与feat_high融合,将结果再次拼接到outs的末尾
            #经过两次循环时候 outs中包含三个元素 [0]是三个层自上而下融合后的结果 [1]是三个层自上而下融合后的结果再与上两层自上而下融合后的结果相容和的结果
            # [2]是五个层融合后的结果在和S5的初代卷积融合后的结果,称为六个层融合后的结果,也是S3,S4,S5自上而下再自下而上融合后的结果,也是S3,S4,S5各自被融合两次的结果

        return outs#最终返回outs数组

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2033503.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络408考研 2015

计算机网络408考研2015年真题解析_哔哩哔哩_bilibili 1 1线路编码(NRZ,NRZI,8B/10B,Manchester)与加扰_nrz编码-CSDN博客 1 1 11

sunspec协议储能电能计量装置

电网公司通常要求光伏并网系统为不可逆流发电系统,即光伏并网系统所发的电由本地负荷消耗,多余的电不允许通过低压配电变压器向上级电网逆向送电。在并网发电系统中,由于外部环境是不断变化的,为了防止光伏并网系统逆向发电&#…

DLL修复工具免费版本推荐:有效修复DLL文件问题

在Windows系统中,DLL(动态链接库)文件扮演着至关重要的角色。它们为多个程序共享代码和资源,节省内存并促进程序之间的高效运行。然而,DLL文件的损坏或丢失可能导致各种问题,如程序崩溃、系统不稳定甚至蓝屏…

大数据技术——实战项目:广告数仓(第五部分)

目录 第9章 广告数仓DIM层 9.1 广告信息维度表 9.2 平台信息维度表 9.3 数据装载脚本 第10章 广告数仓DWD层 10.1 广告事件事实表 10.1.1 建表语句 10.1.2 数据装载 10.1.2.1 初步解析日志 10.1.2.2 解析IP和UA 10.1.2.3 标注无效流量 10.2 数据装载脚本 第9章 广…

Ubuntu中设置环境变量 PATH 的命令,不生效的问题“PATH=~/bin:$PATH”

1. 知识点 PATH~/bin:$PATH PATH:这是一个环境变量,用于指定操作系统在哪些目录中查找可执行文件。 ~:这是一个特殊的符号,代表当前用户的主目录。 /bin:这通常是存放标准实用程序(如 ls, cp 等&#xff…

解决Openwrt 串口默认是没有密码的方法

将串口登录加入密码方法如下: 步骤一:配置busybox的登录,可以在.config文件中添加如下 CONFIG_BUSYBOX_CONFIG_LOGINy 添加后,需要重新编译busybox。 步骤二:修改target/linux/ramips/base-files/etc/inittab文件 将…

C++之类与对象(中)(上篇)

类与对象(中) 类的默认成员函数 默认成员函数就是⽤⼾没有显式实现,编译器会⾃动⽣成的成员函数称为默认成员函数。⼀个类,我 们不写的情况下编译器会默认⽣成以下6个默认成员函数,需要注意的是这6个中最重要的是前4…

ECCV 2024 | 南洋理工三维数字人生成新范式:结构扩散模型

该论文作者均来自于新加坡南洋理工大学 S-Lab 团队,包括博士后胡涛,博士生洪方舟,以及计算与数据学院刘子纬教授(《麻省理工科技评论》亚太地区 35 岁以下创新者)。S-Lab 近年来在顶级会议如 CVPR, ICCV, ECCV, NeurIP…

ICE.AI战略扩展亚太市场,创新交易模式及平台全面升级

2024年8月11日,纽约——全球金融科技领军企业,Intercontinental Exchange Inc.宣布,公司将加速在亚太市场的战略扩展,并通过进一步优化交易模式和平台功能,巩固其在全球市场的卓越地位,同时积极探索新的获利机会。 ICE.AI自推行以来,凭借前沿的人工智能技术和深度学习算法,为全…

shell编程:利用SSH实现分布式应用的一键安装部署②(脚本安装java环境、脚本安装配置zookeeper、scala、kafka)

上一节:函数封装 ②脚本安装java环境、脚本安装配置zookeeper、scala、kafka 1 脚本一键部署kafka分布式应用 1.1 脚本安装配置java环境 准备好java安装包,存放到/opt/tmp目录下。我这里使用的是jdk-8u212-linux-x64.tar.gz,在网上找对应…

excel向下合并空值

方方格子:合并转换——合并空值 选择向右或者向下

基于ssm+vue+uniapp的英语学习交流平台小程序

开发语言:Java框架:ssmuniappJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:M…

【网络】套接字(socket)编程——UDP版

1.socket 1.1.什么是socket Socket 的中文翻译过来就是“套接字”。 套接字是什么,我们先来看看它的英文含义:插座。 Socket 就像一个电话插座,负责连通两端的电话,进行点对点通信,让电话可以进行通信,端…

鸿蒙(API 12 Beta3版)【音视频解封装】 文件解析封装

开发者可以调用本模块的Native API接口,完成音视频解封装,即从比特流数据中取出音频、视频等媒体帧数据。 当前支持的数据输入类型有:远程连接(http协议、HLS协议)和文件描述符(fd)。 支持的解封装格式如下: 媒体格式封装格式码…

高效修复,2024年SD卡损坏数据恢复利器推荐

如果你也是爱记录生活的小伙伴外出游玩的时候肯定会带上带你的长枪短炮吧。如果预算充足可以直接考虑双盘位的设备,为你的图片上个保险。如果是单卡槽的设备回来的时候发现照片全无了咋办,这次我们就探讨下sd卡数据恢复要怎么进行吧。 1.福昕恢复数据 …

【递归】3.反转链表

leetcode题目连接:https://leetcode.cn/problems/reverse-linked-list/题解过程: 1.找到重复的子问题 要逆序第一个节点,就把后面的节点都逆序一遍 2.关注到具体的子问题的实现 第一步:将当前节点的后面所有节点逆置 第二步&…

【自动驾驶】ROS中自定义格式的服务通信,含命令行动态传参(c++)

目录 通信流程创建服务器端及客户端新建服务通讯文件修改service的xml及cmakelistCMakeLists.txt编辑 msg 相关配置编译消息相关头文件在cmakelist中包含头文件的路径在service包下编写service.cpp在client包下编写client.cpp测试运行查询服务的相关指令列出目前的所有服务&…

毛骨悚然,ChatGPT诡异尖叫、模仿用户说话,GPT-4o被曝行为失控

ChatGPT被曝存在失控行为,原本是用户和ChatGPT正常的语音对话,但ChatGPT却突然大喊了一声“no”,随即竟模仿起了用户的声音! 下面就是这段让人毛骨悚然的声音片段: ChatGPT失控行为首次公开很多网友表示,第…

【MySQL】2.MySQL实际操作

目录 一、数据分析基本流程 注:Navicat快捷键 二、获取数据后的代码操作 (1)探索数据,查看定义 (2)筛选有用的字段 (3)建新表(查询建表插值 三合一) 注意…

揭秘Java 8新宠儿:初识Optional,让你的代码告别空指针烦恼

文章目录 前言一、Optional基础二、使用步骤1.创建Optional实例1.常用方法 前言 Java 8 引入了一个非常有用的类 Optional,它旨在减少空指针异常(NullPointerException)的发生。Optional 类是一个可以包含也可以不包含非null值的容器对象。如…