YOLO11改进|上采样篇|引入DySample轻量级动态上采样器

news2024/12/24 20:53:24

在这里插入图片描述

目录

    • 一、DySample轻量级动态上采样器
      • 1.1DySample上采样模块介绍
      • 1.2DySample核心代码
    • 五、添加DySample上采样器
      • 5.1STEP1
      • 5.2STEP2
      • 5.3STEP3
      • 5.4STEP4
    • 六、yaml文件与运行
      • 6.1yaml文件
      • 6.2运行成功截图

一、DySample轻量级动态上采样器

1.1DySample上采样模块介绍

在这里插入图片描述

DySample是一种基于采样的动态上采样(Sampling based dynamic upsampling)机制,用于提升图像或特征图的分辨率。主要分为三部分:采样点生成(sampling point generator)、静态和动态作用域因子(Static Scope Factor 和 Dynamic Scope Factor)。下面我将对每个部分的工作流程和其优势做一个简要介绍。

  1. 采样点生成和网格采样
    采样点生成器(Sampling point generator) 会生成一个采样点集,用于决定在哪些点上进行上采样。接着,使用 网格采样(Grid Sample) 方法对原始特征图进行采样,生成一个新的高分辨率特征图 𝑋′。这种基于采样点的上采样方法具有动态性,意味着它可以根据输入特征图的不同自动调整采样点,避免了传统插值方法的局限性,使得上采样更加灵活和精确。

  2. 静态作用域因子(Static Scope Factor)
    在这个模块中,输入特征图 𝑋,先通过一个线性变换生成低维特征。然后乘以一个固定因子(例如0.25),再进行 Pixel Shuffle 操作。最终,将生成的高分辨率特征 𝐺 和一个由固定范围因子 𝑂计算出来的偏移特征图相加,得到输出 𝑆。
    优势:静态因子使得上采样具有稳定性和可控性,适合处理那些具有一致性尺度变化的场景。

  3. 动态作用域因子(Dynamic Scope Factor)
    与静态作用域因子类似,但动态作用域因子会根据输入的特征图动态调整缩放因子(如图中的 0.5a),然后再通过 Pixel Shuffle 操作放大特征图。同样,生成的特征图 𝐺 和动态偏移特征图 𝑂相加得到最终输出 𝑆。
    优势:动态因子能够自适应地调整不同输入的缩放尺度,特别适合处理多尺度特征或分辨率变化较大的图像。这种机制可以更好地捕捉不同尺度下的细节信息,提高模型的灵活性和适应性。
    DySample动态点采样器的结构如下:
    在这里插入图片描述

1.2DySample核心代码

import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['Dy_Sample']


def normal_init(module, mean=0, std=1, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.normal_(module.weight, mean, std)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def constant_init(module, val, bias=0):
    if hasattr(module, 'weight') and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


class Dy_Sample(nn.Module):
    def __init__(self, in_channels, scale=2, style='lp', groups=4, dyscope=False):
        super().__init__()
        self.scale = scale
        self.style = style
        self.groups = groups
        assert style in ['lp', 'pl']
        if style == 'pl':
            assert in_channels >= scale ** 2 and in_channels % scale ** 2 == 0
        assert in_channels >= groups and in_channels % groups == 0

        if style == 'pl':
            in_channels = in_channels // scale ** 2
            out_channels = 2 * groups
        else:
            out_channels = 2 * groups * scale ** 2

        self.offset = nn.Conv2d(in_channels, out_channels, 1)
        normal_init(self.offset, std=0.001)
        if dyscope:
            self.scope = nn.Conv2d(in_channels, out_channels, 1)
            constant_init(self.scope, val=0.)

        self.register_buffer('init_pos', self._init_pos())

    def _init_pos(self):
        h = torch.arange((-self.scale + 1) / 2, (self.scale - 1) / 2 + 1) / self.scale
        return torch.stack(torch.meshgrid([h, h])).transpose(1, 2).repeat(1, self.groups, 1).reshape(1, -1, 1, 1)

    def sample(self, x, offset):
        B, _, H, W = offset.shape
        offset = offset.view(B, 2, -1, H, W)
        coords_h = torch.arange(H) + 0.5
        coords_w = torch.arange(W) + 0.5
        coords = torch.stack(torch.meshgrid([coords_w, coords_h])
                             ).transpose(1, 2).unsqueeze(1).unsqueeze(0).type(x.dtype).to(x.device)
        normalizer = torch.tensor([W, H], dtype=x.dtype, device=x.device).view(1, 2, 1, 1, 1)
        coords = 2 * (coords + offset) / normalizer - 1
        coords = F.pixel_shuffle(coords.view(B, -1, H, W), self.scale).view(
            B, 2, -1, self.scale * H, self.scale * W).permute(0, 2, 3, 4, 1).contiguous().flatten(0, 1)
        return F.grid_sample(x.reshape(B * self.groups, -1, H, W), coords, mode='bilinear',
                             align_corners=False, padding_mode="border").view(B, -1, self.scale * H, self.scale * W)

    def forward_lp(self, x):
        if hasattr(self, 'scope'):
            offset = self.offset(x) * self.scope(x).sigmoid() * 0.5 + self.init_pos
        else:
            offset = self.offset(x) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward_pl(self, x):
        x_ = F.pixel_shuffle(x, self.scale)
        if hasattr(self, 'scope'):
            offset = F.pixel_unshuffle(self.offset(x_) * self.scope(x_).sigmoid(), self.scale) * 0.5 + self.init_pos
        else:
            offset = F.pixel_unshuffle(self.offset(x_), self.scale) * 0.25 + self.init_pos
        return self.sample(x, offset)

    def forward(self, x):
        if self.style == 'pl':
            return self.forward_pl(x)
        return self.forward_lp(x)


if __name__ == '__main__':
    x = torch.rand(2, 64, 4, 7)
    dys = Dy_Sample(64)
    print(dys(x).shape)

五、添加DySample上采样器

5.1STEP1

首先找到ultralytics/nn文件路径下新建一个Add-module的python文件包【这里注意一定是python文件包,新建后会自动生成_init_.py】,如果已经跟着我的教程建立过一次了可以省略此步骤,随后新建一个DySample.py文件并将上文中提到的注意力机制的代码全部粘贴到此文件中,如下图所示在这里插入图片描述

5.2STEP2

在STEP1中新建的_init_.py文件中导入增加改进模块的代码包如下图所示在这里插入图片描述

5.3STEP3

找到ultralytics/nn文件夹中的task.py文件,在其中按照下图添加在这里插入图片描述

5.4STEP4

定位到ultralytics/nn文件夹中的task.py文件中的def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)函数添加如图代码,【如果不好定位可以直接ctrl+f搜索定位】

在这里插入图片描述

   elif m in {Dy_Sample}:
            c2=ch[f]
            args=[c2,*args]

六、yaml文件与运行

6.1yaml文件

以下是添加DySample动态上采样的yaml文件,大家可以注释自行调节,效果以自己的数据集结果为准

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLO11 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolo11n.yaml' will call yolo11.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.50, 0.25, 1024] # summary: 319 layers, 2624080 parameters, 2624064 gradients, 6.6 GFLOPs
  s: [0.50, 0.50, 1024] # summary: 319 layers, 9458752 parameters, 9458736 gradients, 21.7 GFLOPs
  m: [0.50, 1.00, 512] # summary: 409 layers, 20114688 parameters, 20114672 gradients, 68.5 GFLOPs
  l: [1.00, 1.00, 512] # summary: 631 layers, 25372160 parameters, 25372144 gradients, 87.6 GFLOPs
  x: [1.00, 1.50, 512] # summary: 631 layers, 56966176 parameters, 56966160 gradients, 196.0 GFLOPs

# YOLO11n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 2, C3k2, [256, False, 0.25]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 2, C3k2, [512, False, 0.25]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 2, C3k2, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 2, C3k2, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 2, C2PSA, [1024]] # 10

# YOLO11n head
head:
  - [-1, 1, Dy_Sample, []]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 2, C3k2, [512, False]] # 13

  - [-1, 1, Dy_Sample, []]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 2, C3k2, [256, False]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 2, C3k2, [512, False]] # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 2, C3k2, [1024, True]] # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]] # Detect(P3, P4, P5)

以上添加位置仅供参考,具体添加位置以及模块效果以自己的数据集结果为准

6.2运行成功截图

在这里插入图片描述

OK 以上就是添加DySample上采样的全部过程了,后续将持续更新尽情期待

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2186129.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Koa2+Vue2的简书后台管理系统

文章目录 项目实战:前(vue)后(koa)端分离1、创建简书项目2、创建数据库2.1 创建数据库2.2 连接数据库3、模型对象3.1 设计用户模块的Schema3.2 实现用户增删改查3.2.1 增加用户3.2.2 修改用户3.2.3 删除用户3.2.4 查询用户4、封装业务逻辑层5、封装CRUD6、创建Vue项目7、配…

“衣依”服装销售平台:Spring Boot技术架构剖析

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常适…

TCP四次挥手过程详解

TCP四次挥手全过程 有几点需要澄清: 1.首先,tcp四次挥手只有主动和被动方之分,没有客户端和服务端的概念 2.其次,发送报文段是tcp协议栈的行为,用户态调用close会陷入到内核态 3.再者,图中的情况前提是双…

【CKA】十、统计node节点ready状态的数量

10、统计node节点ready状态的数量 1. 考题内容: 2. 答题思路: 1、检查有多个node状态ready 2、去除有Taint和NoSchedule的节点数量 3、将结果写入到指定文件中 3. 官网地址: https://kubernetes.io/zh-cn/docs/reference/node/node-statu…

LeetCode[中等] 45. 跳跃游戏 II

给定一个长度为 n 的 0 索引整数数组 nums。初始位置为 nums[0]。 每个元素 nums[i] 表示从索引 i 向前跳转的最大长度。换句话说&#xff0c;如果你在 nums[i] 处&#xff0c;你可以跳转到任意 nums[i j] 处: 0 < j < nums[i] i j < n 返回到达 nums[n - 1] 的最…

在Docker中运行微服务注册中心Eureka

1、Docker简介&#xff1a; 作为开发者&#xff0c;经常遇到一个头大的问题&#xff1a;“在我机器上能运行”。而将SpringCloud微服务运行在Docker容器中&#xff0c;避免了因环境差异带来的兼容性问题&#xff0c;能够有效的解决此类问题。 通过Docker&#xff0c;开发者可…

五子棋双人对战项目(4)——匹配模块(解读代码)

目录 一、约定前后端交互接口的参数 1、websocket连接路径 2、构造请求、响应对象 二、用户在线状态管理 三、房间管理 1、房间类&#xff1a; 2、房间管理器&#xff1a; 四、匹配器(Matcher) 1、玩家实力划分 2、加入匹配队列&#xff08;add&#xff09; 3、移除…

golang grpc初体验

grpc 是一个高性能、开源和通用的 RPC 框架&#xff0c;面向服务端和移动端&#xff0c;基于 HTTP/2 设计。目前支持c、java和go&#xff0c;分别是grpc、grpc-java、grpc-go&#xff0c;目前c版本支持c、c、node.js、ruby、python、objective-c、php和c#。grpc官网 grpc-go P…

Linux相关概念和重要知识点(11)(进程调度、Linux内核链表)

1.Linux调度算法 上篇文章我粗略讲过queue[140]的结构&#xff0c;根据哈希表&#xff0c;我们可以将40个不同优先级的进程借助哈希桶链入queue[140]中。调度器会根据queue的下标来进行调度。但这个具体的调度过程是怎样的呢&#xff1f;以及runqueue和queue[140]的关系是什么…

谷歌给到的185个使用生成式AI的案例

很多公司从利用AI回答问题&#xff0c;进而使用AI进行预测&#xff0c;向使用生成式AI Agent转变。AI Agent的独特之处在于它们可以采取行动以实现特定目标&#xff0c;比如引导购物者找到合适的鞋子&#xff0c;帮助员工寻找合适的健康福利&#xff0c;或在护理人员交接班期间…

python之输入输出

1、输入 Python在控制台输入内容&#xff0c;需要使用input函数。input函数会在控制台等待用户输入&#xff0c;直到用户按下了回车键才算完成输入。 注意&#xff1a;input函数接收的内容为字符串。 str1 input("请输入内容\n") print(str1) print(type(str1))1…

Python酷库之旅-第三方库Pandas(132)

目录 一、用法精讲 591、pandas.DataFrame.plot方法 591-1、语法 591-2、参数 591-3、功能 591-4、返回值 591-5、说明 591-6、用法 591-6-1、数据准备 591-6-2、代码示例 591-6-3、结果输出 592、pandas.DataFrame.plot.area方法 592-1、语法 592-2、参数 592-…

9.28学习笔记

1.ping 网址 2.ssh nscc/l20 3.crtl,打开vscode的setting 4.win 10修改ssh配置文件及其密钥权限为600 - 晴云孤魂 - 博客园 整体来看&#xff1a; 使用transformer作为其主干网络&#xff0c;代替了原先的UNet 在latent space进行训练&#xff0c;通过transformer处理潜…

查缺补漏----该不该考虑不可屏蔽中断

可以看看这个视频&#xff1a; 讨论中断时&#xff0c;该不该考虑不可屏蔽中断&#xff1f;_哔哩哔哩_bilibili 首先要知道一个概念&#xff1a;可屏蔽中断和不可屏蔽中断 可屏蔽中断&#xff1a; 可屏蔽中断是可通过中断屏蔽字来启用或禁用的中断。对于多级中断而言&#…

①EtherCAT转ModbusTCP, EtherCAT/Ethernet/IP/Profinet/ModbusTCP协议互转工业串口网关

EtherCAT/Ethernet/IP/Profinet/ModbusTCP协议互转工业串口网关https://item.taobao.com/item.htm?ftt&id822721028899 协议转换通信网关 EtherCAT 转 ModbusTCP GW系列型号 MS-GW15 简介 MS-GW15 是 EtherCAT 和 Modbus TCP 协议转换网关&#xff0c;为用户提供一种 …

map_set的使用

map_set的使用 关联式容器树形结构的关联式容器setset的介绍set的使用 multisetmultiset的介绍multiset的使用 mapmap的介绍map的使用键值对 multimapmultimap的介绍 &#x1f30f;个人博客主页&#xff1a;个人主页 关联式容器 在初阶阶段&#xff0c;我们已经接触过STL中的部…

黑科技外绘神器:一键扩展图像边界

黑科技外绘神器&#xff1a;一键扩展图像边界 Diffusers Image Outpaint✨是一个开源工具&#xff0c;能智能扩展图像边界&#xff0c;创造完美视觉效果&#x1f3de;️。用户可自定义风格&#xff0c;生成高清图像&#x1f929;&#xff0c;应用场景广泛&#xff0c;释放你的…

大模型~合集6

我自己的原文哦~ https://blog.51cto.com/whaosoft/11566566 # 深度模型融合&#xff08;LLM/基础模型/联邦学习/微调等&#xff09; 23年9月国防科大、京东和北理工的论文“Deep Model Fusion: A Survey”。 深度模型融合/合并是一种新兴技术&#xff0c;它将多个深度学习模…

爬虫——爬取小音乐网站

爬虫有几部分功能&#xff1f;&#xff1f;&#xff1f; 1.发请求&#xff0c;获得网页源码 #1.和2是在一步的 发请求成功了之后就能直接获得网页源码 2.解析我们想要的数据 3.按照需求保存 注意&#xff1a;开始爬虫前&#xff0c;需要给其封装 headers {User-…

本地化测试对游戏漏洞修复的影响

本地化测试在游戏开发的质量保证过程中起着至关重要的作用&#xff0c;尤其是在修复bug方面。当游戏为全球市场做准备时&#xff0c;它们通常会被翻译和改编成各种语言和文化背景。这种本地化带来了新的挑战&#xff0c;例如潜在的语言错误、文化误解&#xff0c;甚至是不同地区…