YOLOv11改进 | 上采样篇 | YOLOv11引入DySample轻量级动态上采样器

news2024/10/3 15:31:48

1. DySample介绍

1.1  摘要:我们提出了DySample,一个超轻量和有效的动态上采样器。虽然最近的基于内核的动态上采样器(如CARAFE、FADE和SAPA)的性能提升令人印象深刻,但它们引入了大量工作负载,主要是由于耗时的动态卷积和用于生成动态内核的额外子网。此外,FADE和SAPA对高分辨率功能指导的需求在某种程度上限制了它们的应用场景。为了解决这些问题,我们绕过了动态卷积,从点采样的角度制定了上采样,这更节省资源,而且可以很容易地用PyTorch中的标准内置函数实现。我们首先展示一个简单的设计,然后演示如何逐步增强其上采样行为,以实现我们的新上采样器DySample。与以前的基于内核的动态上采样器相比,DySample不需要定制的CUDA包,具有更少的参数、FLOP、GPU内存和延迟。除了轻量级的特性外,DySample在五个密集预测任务中的表现优于其他上采样器,包括语义分割、对象检测、实例分割、全景分割和单目深度估计。

官方论文地址:https://arxiv.org/pdf/2308.15085

1.2  简单介绍:  

           DySample模块是一种超轻量级且高效的动态上采样器,通过从点采样的角度重新定义了上采样过程。DySample模块的设计旨在解决现有动态上采样器的高计算开销和复杂性问题,同时提供更好的性能和效率。以下是对DySample模块的描述:

(1). 基本设计

         点采样的基础概念:DySample模块的核心思想是将上采样过程视为点采样。不同于传统的核方法,DySample通过生成偏移量来重新采样输入特征图中的连续区域。这种方法被称为“动态采样”,其目的是在保持高效性的同时实现灵活的上采样操作。

         初步实现:在初步实现中,首先使用线性层生成偏移量,然后通过像素混洗(Pixel Shuffle)函数将这个偏移量与原始网格位置相加,从而生成新的采样点集合。这个过程被称为“静态偏移范围因子”。

(2). 改进设计

         初始采样位置调整:研究发现,共享相同的初始偏移位置会导致采样点之间缺乏位置关系,这会影响模型的性能。因此,提出了“双线性初始化”策略,即将初始采样位置分布得更加均匀,从而提高模型的准确性。

         偏移范围控制:由于标准化层的存在,输出特征值通常在[−1, 1]范围内。为了减少偏移范围的重叠,采用了“静态偏移范围因子”,将偏移范围限制在[−0.25, 0.25]内。此外,还引入了动态偏移范围因子,通过线性投影和Sigmoid函数进一步调整偏移范围。

          分组上采样:为了进一步提高灵活性和性能,DySample采用了多组上采样策略。具体来说,将特征图沿着通道维度划分为多个组,每个组内共享相同的采样集。实验结果表明,分组上采样能够显著提升模型性能。

(3). 动态上采样过程

         偏移生成方式:DySample模块有两种偏移生成方式:“线性+像素混洗”(LP)和“像素混洗+线性”(PL)。LP方式需要更多的参数和内存开销,但更灵活;PL方式则相反,具有更少的参数和更快的推理速度。

          上采样过程可视化:通过对局部区域的放大显示,可以看到DySample如何将一个边缘点分成四个点进行上采样,以增强边界的清晰度。这种设计确保了上采样后的图像细节丰富且边界清晰。

综上DySample模块通过一系列改进设计和优化措施,实现了高效的动态上采样功能。其独特的点采样方法和灵活的偏移控制机制使其在多个任务中表现出色,同时具备较低的计算资源消耗和较高的推理速度。这些特点使得DySample成为现有动态上采样器中极具竞争力的一种选择。

1.3  模块结构图

2. 核心代码

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['Dy_Sample']


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


class Dy_Sample(nn.Module):
    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super().__init__()
        self.scale = scale
        self.style = style
        self.groups = groups
        assert style in ['lp', 'pl']
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        normal_init(self.offset, std=0.001)
        if dyscope:
            self.scope = nn.Conv2d(in_channels, out_channels, 1)
            constant_init(self.scope, val=0.)

        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset):
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    def forward_lp(self, x):
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            offset = self.offset(x) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward_pl(self, x):
        x_ = F.pixel_shuffle(x, self.scale)
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward(self, x):
        if self.style == 'pl':
            return self.forward_pl(x)
        return self.forward_lp(x)


if __name__ == '__main__':
    x = torch.rand(2, 64, 4, 7)
    dys = Dy_Sample(64)
    print(dys(x).shape)

3. YOLOv11中添加DySample

3.1 在ultralytics/nn下新建Extramodule

 3.2 在Extramodule里创建DySample

在DySample.py文件里添加给出的DySample代码

添加完DySample代码后,在ultralytics/nn/Extramodule/__init__.py文件中引用

3.3 在tasks.py里引用

在ultralytics/nn/tasks.py文件里引用Extramodule

在tasks.py找到parse_model(ctrl+f可以直接搜索parse_model位置

添加如下代码:

        elif m in {Dy_Sample}:
            c2 = ch[f]
            args = [c2, *args]

4. 新建一个yolo11DySample.yaml文件

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 1 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, Dy_Sample, []]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, Dy_Sample, []]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)

大家根据自己的数据集实际情况,修改nc大小。

5. 模型训练

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO(r'D:\yolo\yolov11\ultralytics-main\datasets\yolo11DySample.yaml')
    model.train(data=r'D:\yolo\yolov11\ultralytics-main\datasets\data.yaml',
                cache=False,
                imgsz=640,
                epochs=100,
                single_cls=False,  # 是否是单类别检测
                batch=4,
                close_mosaic=10,
                workers=0,
                device='0',
                optimizer='SGD',
                amp=True,
                project='runs/train',
                name='exp',
                )

模型结构打印,成功运行:

6.本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv11改进有效涨点专栏,本专栏目前为新开的,后期我会根据各种前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

YOLOv11有效涨点专栏

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186462.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot 中的拦截器 Interceptors

​ 博客主页: 南来_北往 系列专栏:Spring Boot实战 前言 Spring Boot中的拦截器(Interceptor)是一种用于拦截和处理HTTP请求的机制,它基于Spring MVC框架中的HandlerInterceptor接口实现。拦截器允许在请求到达控制器&#…

C++函数模板、选择排序实现(从大到小)

template <class T> void mysw (T &a , T &b) {T temp b;b a;a temp; }template <class T> void muSort( T &arr ,int len) {//该实现为选择排序(高到低)for (int i 0; i < len; i) {int max i ; //首先默认本次循环首位元素为最大for (int j …

scrapy爬取汽车、车评数据【中】

这个爬虫我想分三期来写&#xff1a; ✅ 第一期写如何爬取汽车的车型信息&#xff1b; ✅ 第二期写如何爬取汽车的车评&#xff1b; ✅ 第三期写如何对车评嵌入情感分析结果&#xff0c;以及用简单的方法把数据插入mysql中&#xff1b; 技术基于scrapy框架、BERT语言模型、mysq…

SQL Server中关于个性化需求批量删除表的做法

在实际开发中&#xff0c;我们常常会遇到需要批量删除表&#xff0c;且具有共同特征的情况&#xff0c;例如&#xff1a;找出表名中数字结尾的表之类的&#xff0c;本文我将以3中类似情况为例&#xff0c;来示范并解说此类需求如何完成&#xff1a; 第一种&#xff0c;批量删除…

【Godot4.3】图形碰撞相关函数库ShapeTests

概述 最近积累了一些图形重叠检测&#xff0c;以及求图形的轴对齐包围盒Rect2&#xff0c;还有求Rect2的外接圆等函数。感觉可以作为一个单独的函数库&#xff0c;提供日常的使用&#xff0c;所以汇总成了ShapeTests。 注意&#xff1a;函数名和写法可能会不断改进。 代码 …

基于SSM的北京冬奥会志愿者服务系统

文未可获取一份本项目的java源码和数据库参考。 本课题国内外研究现状 当前&#xff0c;国外志愿者服务活动开展的十分活跃。志愿服务正以其突出的社会效益受到越来越多国家政府的重视。许多国家的志愿服务活动起步早、规模大&#xff0c;社会效益好。他们在国内有广泛的群众…

第四届生物医学与智能系统国际学术会议(IC-BIS 2025)

在线投稿&#xff1a;学术会议-学术交流征稿-学术会议在线-艾思科蓝 2025年第四届生物医学与智能系统国际学术会议&#xff08;IC-BIS 2025&#xff09; 将于2025年4月11-13日在意大利隆重举行。 该会议旨在汇集全球学术界和工业界的研究人员、专家和从业人员&#xff0c;共…

C(十)for循环 --- 黑神话情景

前言&#xff1a; "踏过三界宝刹&#xff0c;阅过四洲繁华。笑过五蕴痴缠&#xff0c;舍过六根牵挂。怕什么欲念不休&#xff0c;怕什么浪迹天涯。步履不停&#xff0c;便是得救之法。" 国际惯例&#xff0c;开篇先喝碗鸡汤。 今天&#xff0c;杰哥写的 for 循环相…

android Activity生命周期

android 中一个 activity 在其生命周期中会经历多种状态。 您可以使用一系列回调来处理状态之间的转换。下面我们来介绍这些回调。 onCreate&#xff08;创建阶段&#xff09; 初始化组件&#xff1a;在这个阶段&#xff0c;Activity的主要工作是进行初始化操作。这包括为Ac…

【Bug】STM32F1的PB3和PB4无法正常输出

Bug 使用标准库配置STM32F103C8T6的PB3和PB4引脚输出控制LED灯时&#xff0c;发现引脚电平没有变化无法正常输出高低电平&#xff0c;配置代码如下&#xff1a; GPIO_InitTypeDef GPIO_InitStructure;RCC_APB2PeriphClockCmd( RCC_APB2Periph_GPIOB, ENABLE ); GPIO_InitStruc…

Maven项目管理入门:POM文件详解与依赖管理

目录 1、pom配置详解 2、依赖管理 2.1 Maven坐标 2.2 依赖导入 1、使用IDEA工具导入&#xff1a; 2、从远程仓库中获取坐标 3、maven插件 maven是一个项目管理工具&#xff0c;其中依靠于一个很重要的xml文件&#xff1a;pom.xml。我们接下来学习下pom.xml文件的配置。 …

论文阅读:PET/CT Cross-modal medical image fusion of lung tumors based on DCIF-GAN

摘要 背景&#xff1a; 基于GAN的融合方法存在训练不稳定&#xff0c;提取图像的局部和全局上下文语义信息能力不足&#xff0c;交互融合程度不够等问题 贡献&#xff1a; 提出双耦合交互式融合GAN&#xff08;Dual-Coupled Interactive Fusion GAN&#xff0c;DCIF-GAN&…

第十三章 集合

一、集合的概念 集合&#xff1a;将若干用途、性质相同或相近的“数据”组合而成的一个整体 Java集合中只能保存引用类型的数据&#xff0c;不能保存基本类型数据 数组的缺点&#xff1a;长度不可变 Java中常用集合&#xff1a; 1.Set(集):集合中的对象不按特定方式排序&a…

FastGPT的使用

fastGPT的介绍&#xff1a; fastGPT其实和chatGPT差不多 但是好处是可以自行搭建&#xff0c;而且很方便 链接&#xff1a;https://cloud.fastgpt.cn/app/list 首先我们可以根据红框点击&#xff0c;创建一个简易的对话引导 这个机器人就非常的简易&#xff0c;只能完成一些翻…

自动驾驶系列—LDW(车道偏离预警):智能驾驶的安全守护者

&#x1f31f;&#x1f31f; 欢迎来到我的技术小筑&#xff0c;一个专为技术探索者打造的交流空间。在这里&#xff0c;我们不仅分享代码的智慧&#xff0c;还探讨技术的深度与广度。无论您是资深开发者还是技术新手&#xff0c;这里都有一片属于您的天空。让我们在知识的海洋中…

【Linux系统编程】权限

目录 1、shell命令以及运行原理 2、Linux权限的概念 3、Linux权限管理 3.1 文件访问者的分类(人) 3.2 文件类型和访问权限(事物属性) 4、目录的权限 5、粘滞位 1、shell命令以及运行原理 首先&#xff0c;我们来了解一条指令是如何跑起来的 一般操作系统是不会让用户直…

大数据开发--1.2 Linux介绍及虚拟机网络配置

目录 一. 计算机入门知识介绍 软件和硬件的概述 硬件 软件 操作系统概述 简单介绍 常见的系统操作 学习Linux系统 二. Linux系统介绍 简单介绍 发行版介绍 常用的发行版 三. Linux系统的安装和体验 Linux系统的安装 介绍 虚拟机原理 常见的虚拟机软件 体验Li…

招联金融秋招内推2025

【投递方式】 直接扫下方二维码&#xff0c;或点击内推官网https://wecruit.hotjob.cn/SU61025e262f9d247b98e0a2c2/mc/position/campus&#xff0c;使用内推码 igcefb 投递&#xff09; 【招聘岗位】 后台开发 前端开发 数据开发 数据运营 算法开发 技术运维 软件测试 产品策…

四、网络层(下)

4.9 CIDR CIDR&#xff08;Classless Inter-Domain Routing&#xff09;&#xff0c;是IPv4地址分配和路由表选择的一种灵活且高效的方法。 1992年&#xff0c;由于分类地址中的B类地址很快就被分配完了&#xff0c;且路由表中的表项也急剧增加&#xff0c;分类的IP地址并不能…

高校实训产品:教育AI人工智能实训与科研解决方案

保持前沿、提升就业、低成本的教育AI实训全场景方案 产品概述 AIGC实训云图站解决方案为高校提供了灵活、高效的人工智能实训平台。通过弹性裸金属调度技术和GPU虚拟化&#xff0c;实现高性能与低成本的兼顾&#xff0c;为学生和教师提供不受时间和空间限制的实操机会。平台涵…