RT-DETR融合[ECCV2024]自调制特征聚合SMFA模块及相关改进思路

news2024/11/15 16:16:51


RT-DETR使用教程: RT-DETR使用教程

RT-DETR改进汇总贴:RT-DETR更新汇总贴


《SMFANet: A Lightweight Self-Modulation Feature Aggregation Network for Efficient Image Super-Resolution》

一、 模块介绍

        论文链接:https://link.springer.com/chapter/10.1007/978-3-031-72973-7_21

        代码链接:https://github.com/Zheng-MJ/SMFANet?tab=readme-ov-file

论文速览:

        基于 Transformer 的修复方法取得了显着的性能,因为 Transformer 的自注意力 (SA) 可以探索非局部信息以获得更好的高分辨率图像重建。然而,关键的点积 SA 需要大量的计算资源。此外,SA 机制的低通特性限制了其捕获局部细节的能力,从而导致平滑的重建结果。为了解决这些问题,作者提出了一个自调制特征聚合 (SMFA) 模块,以协同利用局部和非局部特征交互来实现更准确的重建。具体来说,SMFA 模块采用高效的自我注意近似 (EASA) 分支来对非局部信息进行建模,并使用局部细节估计 (LDE) 分支来捕获局部细节。此外,作者进一步引入了基于部分卷积的前馈网络 (PCFN) 来改进从 SMFA 派生的代表性特征。大量实验表明,所提出的 SMFANet 系列在公共基准数据集上实现了更好的重建性能和计算效率之间的权衡。特别是,与×4 SwinIR-light,SMFANet+ 在五个公共测试集中平均实现了 0.14 dB 的性能提升,并且×运行速度提高 10 倍,模型复杂度仅为 43% 左右(例如 FLOPs)。

总结:一种基于自调制特征聚合模块(SMFA)的高分辨率图像重建方法,实测与其他模块融合有提升。


二、 加入到RT-DETR中

2.1 创建脚本文件

        首先在ultralytics->nn路径下创建blocks.py脚本,用于存放模块代码。

2.2 复制代码        

        复制代码粘到刚刚创建的blocks.py脚本中,如下图所示:

import torch
import torch.nn as nn
import torch.nn.functional as F
 
 
class DMlp(nn.Module):
    def __init__(self, dim, growth_rate=2.0):
        super().__init__()
        hidden_dim = int(dim * growth_rate)
        self.conv_0 = nn.Sequential(
            nn.Conv2d(dim,hidden_dim,3,1,1,groups=dim),
            nn.Conv2d(hidden_dim,hidden_dim,1,1,0)
        )
        self.act =nn.GELU()
        self.conv_1 = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
 
    def forward(self, x):
        x = self.conv_0(x)
        x = self.act(x)
        x = self.conv_1(x)
        return x
 
 
class SMFA(nn.Module):
    def __init__(self, dim=36):
        super(SMFA, self).__init__()
        self.linear_0 = nn.Conv2d(dim,dim*2,1,1,0)
        self.linear_1 = nn.Conv2d(dim,dim,1,1,0)
        self.linear_2 = nn.Conv2d(dim,dim,1,1,0)
 
        self.lde = DMlp(dim,2)
 
        self.dw_conv = nn.Conv2d(dim,dim,3,1,1,groups=dim)
 
        self.gelu = nn.GELU()
        self.down_scale = 8
 
        self.alpha = nn.Parameter(torch.ones((1,dim,1,1)))
        self.belt = nn.Parameter(torch.zeros((1,dim,1,1)))
 
    def forward(self, f):
        _,_,h,w = f.shape
        y, x = self.linear_0(f).chunk(2, dim=1)
        x_s = self.dw_conv(F.adaptive_max_pool2d(x, (h // self.down_scale, w // self.down_scale)))
        x_v = torch.var(x, dim=(-2,-1), keepdim=True)
        x_l = x * F.interpolate(self.gelu(self.linear_1(x_s * self.alpha + x_v * self.belt)), size=(h,w), mode='nearest')
        y_d = self.lde(y)
        return self.linear_2(x_l + y_d)

2.3 更改task.py文件 

       打开ultralytics->nn->modules->task.py,在脚本空白处导入函数。

from ultralytics.nn.blocks import *

        之后找到模型解析函数parse_model(约在tasks.py脚本中940行左右位置,可能因代码版本不同变动),在该函数的最后一个else分支上面增加相关解析代码。

        elif m is SMFA:
            c2 = ch[f]
            args = [ch[f]]

2.4 更改yaml文件 

yam文件解读:YOLO系列 “.yaml“文件解读_yolo yaml文件-CSDN博客

       打开更改ultralytics/cfg/models/rt-detr路径下的rtdetr-l.yaml文件,替换原有模块。(放在该位置仅能插入该模块,具体效果未知。博主精力有限,仅完成与其他模块二次创新融合的测试,结构图见文末,代码见群文件更新。)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# RT-DETR-l object detection model with P3-P5 outputs. For details see https://docs.ultralytics.com/models/rtdetr

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-cls.yaml' will call yolov8-cls.yaml with scale 'n'
  # [depth, width, max_channels]
  l: [1.00, 1.00, 1024]

backbone:
  # [from, repeats, module, args]
  - [-1, 1, HGStem, [32, 48]] # 0-P2/4
  - [-1, 6, HGBlock, [48, 128, 3]] # stage 1

  - [-1, 1, DWConv, [128, 3, 2, 1, False]] # 2-P3/8
  - [-1, 6, HGBlock, [96, 512, 3]] # stage 2

  - [-1, 1, DWConv, [512, 3, 2, 1, False]] # 4-P3/16
  - [-1, 2, SMFA, []] # cm, c2, k, light, shortcut
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]]
  - [-1, 6, HGBlock, [192, 1024, 5, True, True]] # stage 3

  - [-1, 1, DWConv, [1024, 3, 2, 1, False]] # 8-P4/32
  - [-1, 6, HGBlock, [384, 2048, 5, True, False]] # stage 4

head:
  - [-1, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 10 input_proj.2
  - [-1, 1, AIFI, [1024, 8]]
  - [-1, 1, Conv, [256, 1, 1]] # 12, Y5, lateral_convs.0

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [7, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 14 input_proj.1
  - [[-2, -1], 1, Concat, [1]]
  - [-1, 3, RepC3, [256]] # 16, fpn_blocks.0
  - [-1, 1, Conv, [256, 1, 1]] # 17, Y4, lateral_convs.1

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [3, 1, Conv, [256, 1, 1, None, 1, 1, False]] # 19 input_proj.0
  - [[-2, -1], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, RepC3, [256]] # X3 (21), fpn_blocks.1

  - [-1, 1, Conv, [256, 3, 2]] # 22, downsample_convs.0
  - [[-1, 17], 1, Concat, [1]] # cat Y4
  - [-1, 3, RepC3, [256]] # F4 (24), pan_blocks.0

  - [-1, 1, Conv, [256, 3, 2]] # 25, downsample_convs.1
  - [[-1, 12], 1, Concat, [1]] # cat Y5
  - [-1, 3, RepC3, [256]] # F5 (27), pan_blocks.1

  - [[21, 24, 27], 1, RTDETRDecoder, [nc]] # Detect(P3, P4, P5)


 2.5 修改train.py文件

       创建Train_RT脚本用于训练。

from ultralytics.models import RTDETR
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

if __name__ == '__main__':
    model = RTDETR(model='ultralytics/cfg/models/rt-detr/rtdetr-l.yaml')
    # model.load('yolov8n.pt')
    model.train(data='./data.yaml', epochs=2, batch=1, device='0', imgsz=640, workers=2, cache=False,
                amp=True, mosaic=False, project='runs/train', name='exp')

         在train.py脚本中填入修改好的yaml路径,运行即可训。

三、相关改进思路(2024/11/16日群文件)

        根据SMFA模块特性,可如图加入到HGBlock、RepNCSPELAN4、RepC3等模块中,代码见群文件,结构如图。自研模块与该模块融合代码及yaml文件见群文件。

 ⭐另外,融合上百种深度学习改进模块的YOLO项目仅79.9(含百种改进的v9),RTDETR79.9,含高性能自研模型,更易发论文,代码每周更新,欢迎点击下方小卡片加我了解。⭐

⭐⭐平均每个文章对应4-6个二创及自研融合模块⭐⭐


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2240940.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

WebAPI性能监控-MiniProfiler与Swagger集成

Net8_WebAPI性能监控-MiniProfiler与Swagger集成 要在.NET Core项目中集成MiniProfiler和Swagger,可以按照以下步骤操作: 安装NuGet包: 安装MiniProfiler.AspNetCore.Mvc包以集成MiniProfiler。安装MiniProfiler.EntityFrameworkCore包以监…

第十五章 Spring之假如让你来写AOP——Joinpoint(连接点)篇

Spring源码阅读目录 第一部分——IOC篇 第一章 Spring之最熟悉的陌生人——IOC 第二章 Spring之假如让你来写IOC容器——加载资源篇 第三章 Spring之假如让你来写IOC容器——解析配置文件篇 第四章 Spring之假如让你来写IOC容器——XML配置文件篇 第五章 Spring之假如让你来写…

喜讯 | 科东软件荣获广东省工业软件科学技术进步奖一等奖

工业软件是制造业数字化、智能化转型升级的核心支撑,贯穿于工业生产的全过程,包括研发设计、测试,智能装备与操作系统嵌入式,系统与平台,算法、模型与工具等类型。通过开展工业软件科学技术奖评选活动,激励…

SystemVerilog学习笔记(二):数组

数组是元素的集合,所有元素都具有相同的类型,并使用其名称和一个或多个索引进行访问。 Verilog 2001 要求数组的下限和上限必须是数组声明的一部分。 System Verilog 引入了紧凑数组声明样式,只需给出数组大小以及数组名称声明就足够了。 下…

批量从Excel某一列中找到符合要求的值并提取其对应数据

本文介绍在Excel中,从某一列数据中找到与已知数据对应的字段,并提取这个字段对应数值的方法。 首先,来明确一下我们的需求。现在已知一个Excel数据,假设其中W列包含了上海市全部社区的名称,而其后的Y列则是这些社区对应…

握手协议是如何在SSL VPN中发挥作用的?

SSL握手协议:客户端和服务器通过握手协议建立一个会话。会话包含一组参数,主要有会话ID、对方的证书、加密算法列表(包括密钥交换算法、数据加密算法和MAC算法)、压缩算法以及主密钥。SSL会话可以被多个连接共享,以减少…

数字化转型:基于价值流的业务架构战略解析

在当前数字化浪潮下,企业纷纷转向数字化转型,以适应市场需求的快速变化和技术革新。数字化转型不仅仅是技术层面的变革,更是对企业业务模式、文化以及价值创造方式的全面重新思考和重塑。《价值流(Value Streams)》为企…

Mac终端字体高亮、提示插件

一、安装配置“oh my zsh” 1.1 安装brew /bin/zsh -c "$(curl -fsSL https://gitee.com/cunkai/HomebrewCN/raw/master/Homebrew.sh)" 按照步骤安装即可,安装完成查看版本 brew -v 1.2 安装zsh brew install zsh 安装完成后查看版本 zsh --version 1.3 …

什么是CRM系统?

越来越多的企业意识到:如何有效管理与客户的关系、提升客户满意度,并通过这些提升推动销售增长,已经成为许多公司亟待解决的问题。为此,客户关系管理(Customer Relationship Management,简称CRM&#xff09…

Ilya Sutskever AI行业将进入一个新的“探索时代”

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

MySQL(5)【数据类型 —— 字符串类型】

阅读导航 引言一、char🎯基本语法🎯使用示例 二、varchar🎯基本语法🎯使用示例 三、char 和 varchar 比较四、日期和时间类型1. 基本概念2. 使用示例 五、enum 和 set🎯基本语法 引言 之前我们聊过MySQL中的数值类型&…

湾区聚力 开源启智 | 2024 CCF中国开源大会暨第五届OpenI/O启智开发者大会闪耀深圳

当下,全球数字化浪潮席卷而来,开源技术已成为科技创新和产业升级的关键驱动力。11月9-10日,以“湾区聚力 开源启智”为主题的2024 CCF中国开源大会在深圳隆重举行。本届大会由中国计算机学会主办,CCF开源发展委员会、鹏城实验室、…

Linux基本指令(中)(2)

文章目录 前言一、echo二、cat三、more四、less五、head六、tail七、date八、cal九、find十、whoami十一、clear总结 前言 承上启下,我们再来看看另外一些常用的基础指令吧! 一、echo 语法:echo [选项] [字符串] 功能:在终端设备上…

MYSQL中JDBC的使用

一、JDBC基础概念 JDBC 是Java 中的一组API,用于执行SQL 操作(例如CRUD 操作:增、删、改、关系),同时可以和各种类型的数据库类型进行连接(MySQL、Oracle、SQL Server 等)。 JDBC是Java标准库的…

UnixBench和Geekbench进行服务器跑分

1 概述 服务器的基准测试,常见的测试工具有UnixBench、Geekbench、sysbench等。本文主要介绍UnixBench和Geekbench。 1.1 UnixBench UnixBench是一款开源的测试UNIX系统基本性能的工具(https://github.com/kdlucas/byte-unixbench)&#x…

基于Java Springboot人力资源管理系统

一、作品包含 源码数据库设计文档万字全套环境和工具资源部署教程 二、项目技术 前端技术:Html、Css、Js、Vue 数据库:MySQL 后端技术:Java、Spring Boot、MyBatis 三、运行环境 开发工具:IDEA 数据库:MySQL8.0…

使用OpenCV(C++)通过鼠标点击操作获取图像的像素坐标和像素值

使用OpenCV(C)通过鼠标点击操作获取图像的像素坐标和像素值 在这篇博客中,我们将介绍如何使用OpenCV库在C中实现鼠标点击操作,以获取图像的像素坐标和像素值。代码分为两个部分:一个是鼠标事件处理的回调函数&#xff…

Windows VSCode .NET CORE WebAPI Debug配置

1.安装C#插件 全名C# for Visual Studio Code,选择微软的 2. 安装C# Dev Kit插件 全名C# Dev Kit for Visual Studio Code,同样是选择微软的 3.安装Debugger for Unity 4.配置launch.json 文件 {"version": "0.2.0","config…

AI斩获6枚金牌!华为Kaggle大师级智能体诞生,自主解决数据科学难题

继 OpenAI o1 成为首个达到 Kaggle 特级大师的人工智能(AI)模型后,另一个 Kaggle 大师级 AI 也诞生了。 根据 Kaggle 的晋级系统,由华为诺亚方舟实验室和伦敦大学学院团队联合推出的端到端自主数据科学智能体(agent&a…

[Mysql基础] 表的操作

一、创建表 1.1 语法 CREATE TABLE table_name ( field1 datatype, field2 datatype, field3 datatype ) character set 字符集 collate 校验规则 engine 存储引擎; 说明: field 表示列名 datatype 表示列的类型 character set 字符集,如果没有指定字符集…