YOLO改进涨点,RCS-YOLO:检测头和网络结构的改进

news2024/11/20 1:32:21

目录

摘要

原理

总体结构图

 RCS模块原理

代码实现

RCS-Based One-Shot Aggregation

代码实现 

检测头改进

手动计算anchor代码

yaml文件

已详细修改的代码

程序启动命令

可论文指导  +V ------------> jiabei-545

往期推荐 


摘要

凭借速度和准确性之间的出色平衡,尖端的 YOLO 框架已成为最有效的目标检测算法之一。然而, YOLO 网络的性能还有提升的空间。我们提出了 RCS 和 RCS 的一次聚合(RCS-OSA),它将特征级联和计算效率联系起来,以提取更丰富的信息并减少时间消耗。所提出的模型在速度和准确性上超越了YOLOv6、YOLOv7和YOLOv8。值得注意的是,与YOLOv7相比,RCS-YOLO的精度提高了1%,推理速度提高了60%,每秒检测到114.8张图像(FPS)。

原理

1、首先通过将 RepVGG/RepConv 与 ShuffleNet 相结合来开发 RepVGG/RepConv ShuffleNet (RCS),它受益于重新参数化,可以在训练阶段提供更多特征信息并减少推理时间。然后,我们构建了一个基于 RCS 的一次性聚合(RCSOSA)模块,该模块不仅允许低成本的内存消耗,而且还允许语义信息提取。 

2、通过将开发的 RCS-OSA 和 RepVGG/RepConv 与路径聚合相结合,设计了 YOLO 架构的新骨干和颈部网络,以缩短特征预测层之间的信息路径。这促进了准确的定位信息快速传播到骨干网络和颈部网络中的特征层次结构。

总体结构图
 RCS -YOLO的架构主要由RCS -OSA(蓝色)和RepVGG(橙色)模块组成。 n表示堆叠的RCS模块的数量。 ncls 表示检测到的对象中的类数。
 RCS模块原理

受ShuffleNet的启发,设计了一种基于通道shuffle的结构重参数化卷积。下图为RCS的结构示意图。

RCS的结构:左边是训练阶段的 RepVGG。 右边是模型推理(或部署)期间的 RepConv。

假设输入张量的特征维度为 C×H×W ,在通道分割算子之后,它被分为两个不同的通道张量,其维度相同C×H×W 。对于其中一个张量,我们使用恒等分支、1×1 卷积和 3×3 卷积来构造训练时 RCS。在推理阶段,使用结构重参数化将恒等分支、1×1 卷积和3×3 卷积转换为3×3 RepConv。多分支拓扑架构可以在训练时学习丰富的特征信息,简化的单分支架构可以节省推理时的内存消耗,实现快速推理。对其中一个张量进行多分支训练后,它以通道方式连接到另一个张量。通道混洗算子还用于增强两个张量之间的信息融合,从而可以以较低的计算复杂度实现输入的不同通道特征之间的深度测量。

代码实现
class RepVGG(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
        super(RepVGG, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.SiLU()
        # self.nonlinearity = nn.ReLU()

        if use_se:
            self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=True,
                                         padding_mode=padding_mode)

        else:
            self.rbr_identity = nn.BatchNorm2d(
                num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                     stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride,
                                   padding=padding_11, groups=groups)
            # print('RepVGG Block, identity = ', self.rbr_identity)

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    def fusevggforward(self, x):
        return self.nonlinearity(self.rbr_dense(x))
RCS-Based One-Shot Aggregation

此外,还设计了单次聚合(OSA)模块来克服 DenseNet 中密集连接的低效率,通过用多感受野表示多样化特征并在最后的特征图中仅聚合所有特征一次。 VoVNet V1 [14 ] 和 V2 [15 ] 在其架构中使用 OSA 模块来构建轻量级和大规模目标检测器,其性能优于广泛使用的 ResNet 主干网络,具有更快的速度和更好的能源效率。

RCS-OSA的结构。 n表示堆叠的RCS模块的数量。
代码实现 
class RCSOSA(nn.Module):
    # VoVNet with Res Shuffle RepVGG
    def __init__(self, c1, c2, n=1, se=False, e=0.5, stackrep=True):
        super().__init__()
        n_ = n // 2
        c_ = make_divisible(int(c1 * e), 8)
        # self.conv1 = Conv(c1, c_)
        self.conv1 = RepVGG(c1, c_)
        self.conv3 = RepVGG(int(c_ * 3), c2)
        self.sr1 = nn.Sequential(*[SR(c_, c_) for _ in range(n_)])
        self.sr2 = nn.Sequential(*[SR(c_, c_) for _ in range(n_)])

        self.se = None
        if se:
            self.se = SEBlock(c2)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.sr1(x1)
        x3 = self.sr2(x2)
        x = torch.cat((x1, x2, x3), 1)
        return self.conv3(x) if self.se is None else self.se(self.conv3(x))

检测头改进

为了进一步减少推理时间,我们将 RepVGG 和 Detect 组成的检测头数量从 3 个减少到 2 个。YOLOv5、YOLOv6、YOLOv7 和 YOLOv8 有 3 个检测头。然而,我们只使用两个特征层进行预测,将原来九个不同尺度的anchor数量减少到四个,并使用K-means无监督聚类方法重新生成不同尺度的anchor。这不仅减少了RCS-YOLO的卷积层数和计算复杂度,而且减少了推理阶段网络的整体计算要求和后处理非极大值抑制的计算时间。

Yolov5的检测目标分别从检测输出17,20,23

YOLOv5结构图
手动计算anchor代码

YOLO 手动计算anchor的值-CSDN博客不需要运行 kmeans.py,运行 clauculate_anchors.py 即可。创建程序两个程序 kmeans.py 以及 clauculate_anchors.py。会调用 kmeans.py 聚类生成新anchors的文件。kmeans.py 程序如下:这不需要运行,也不需要更改。如果报错,可以查看第 13 行内容。会生成anchors文件。https://blog.csdn.net/m0_67647321/article/details/136315355?spm=1001.2014.3001.5501

yaml文件
# RCS-YOLO v1.0  (Two heads)

# Parameters
nc: 1  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 0.75  # layer channel multiple
anchors:
  - [87,90,  127,139]  # P4/16
  - [154,171,  191,240]  # P5/32

# backbone
backbone: # 12462
  # [from, number, module, args]
  [[-1, 1, RepVGG, [64, 3, 2]],  # 0-P1/2
   [-1, 1, RepVGG, [128, 3, 2]],  # 1-P2/4
   [-1, 2, RCSOSA, [128]],
   [-1, 1, RepVGG, [256, 3, 2]],  # 3-P3/8
   [-1, 2, RCSOSA, [256]],
   [-1, 1, RepVGG, [512, 3, 2]],  # 5-P4/16
   [-1, 4, RCSOSA, [512, True]],
   [-1, 1, RepVGG, [1024, 3, 2]],  # 7-P5/32
   [-1, 2, RCSOSA, [1024, True]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# head
head:
  [[-1, 1, Conv, [512, 1, 1]], # 10
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [-1, 2, RCSOSA, [512]],  # 12

   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 2, RCSOSA, [512]],  # 14

   [-1, 1, RepVGG, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 2, RCSOSA, [768]],  # 17

   [14, 1, RepVGG, [512, 3, 1]],
   [17, 1, RepVGG, [768, 3, 1]],

   [[18, 19], 1, IDetect, [nc, anchors]],  # Detect(P4, P5)
  ]
已详细修改的代码

链接: https://pan.baidu.com/s/1EOeZP5kJ92tVjZgFtM0Wdg?pwd=zk88 提取码: zk88 

程序启动命令

单GPU

python train.py --workers 8 --device 0 --batch-size 32 --data data/br35h.yaml --img 640 640 --cfg cfg/training/rcs-yolo.yaml --weights '' --name rcs-yolo --hyp data/hyp_training.yaml

多GPU

python -m torch.distributed.launch --nproc_per_node 4 --master_port 9527 train.py --workers 8 --device 0,1,2,3 --sync-bn --batch-size 128 --data data/br35h.yaml --img 640 640 --cfg cfg/training/rcs-yolo.yaml --weights '' --name rcs-yolo --hyp data/hyp_training.yaml
可论文指导  +V ------------> jiabei-545
往期推荐 

Yolov8有效涨点:YOLOv8-AM,添加多种注意力模块提高检测精度,含代码,超详细-CSDN博客

有效涨点,增强型 YOLOV8 与多尺度注意力特征融合,附代码,详细步骤-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1475513.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

串联所有单词的子串

题目链接 串联所有单词的子串 题目描述 注意点 words[i] 和 s 由小写英文字母组成1 < words.length < 5000可以以 任意顺序 返回答案words中所有字符串长度相同 解答思路 根据滑动窗口哈希表解决本题&#xff0c;哈希表存储words中所有的单词及单词的出现次数&#…

【mysql 数据库事务】开启事务操作数据库,写入失败后,不回滚,会有问题么? 这里隐藏着大坑,复试,面试时可以镇住面试老师!!!!

建表字段: CREATE TABLE user (id INT(11) NOT NULL AUTO_INCREMENT,nickname VARCHAR(32) NOT NULL COLLATE utf8mb4_general_ci,email VARCHAR(32) NOT NULL COLLATE utf8mb4_general_ci,status SMALLINT(6) UNSIGNED NULL DEFAULT NULL,password VARCHAR(256) NULL DEFAULT…

【Vue3】3-5 :组件内容的组合与分发

文章目录 前言问题本节内容 插槽 slot>> 使用>> 效果 (前后相同) 插槽的特性实战> 实例 1&#xff1a;【作用域】根组件中渲染> 效果> 原因>> 实例 2&#xff1a;【具名插槽】即多个插槽> 效果>> 实例 3&#xff1a;【作用域插槽】插槽能访…

kubectl 声明式资源管理方式

目录 介绍 YAML 语法格式 命令 应用yaml文件指定的资源 删除yaml文件指定的资源 查看资源的yaml格式信息 查看yaml文件字段说明 查看 api 资源版本标签 修改yaml文件指定的资源 离线修改 在线修改 编写yaml文件 创建资源对象 查看创建的pod资源 创建service服务对…

WordPress 插件存在漏洞,500 万网站面临严重安全风险

网络安全研究人员近期发现 WordPress LiteSpeed Cache 插件中存在一个安全漏洞&#xff0c;该漏洞被追踪为 CVE-2023-40000&#xff0c;未经身份验证的威胁攻击者可利用该漏洞获取超额权限。 LiteSpeed Cache 主要用于提高网站性能&#xff0c;据不完全统计已经有 500 多万安装…

K8s Pod资源管理组件

目录 Pod基础概念 在Kubrenetes集群中Pod有如下两种使用方式 pause容器使得Pod中的所有容器可以共享两种资源 网络 存储 总结 kubernetes中的pause容器主要为每个容器提供功能 Kubernetes设计这样的Pod概念和特殊组成结构的用意 通常把Pod分为以下几类 自主式Pod 控…

幻兽帕鲁(Palworld 1.4.11.5.0)私有服务器搭建(docker版)

文章目录 说明客户端安装服务器部署1Panel安装和配置docker服务初始化设置设置开机自启动设置镜像加速 游戏服务端部署游戏服务端参数可视化配置 Palworld连接服务器问题总结 服务端升级&#xff08;1.5.0&#xff09; 说明 服务器硬件要求&#xff1a;Linux系统/Window系统&a…

振动解调用的包络谱计算

1缘起 在振动分析中&#xff0c;对于一些高频频点的分析计算&#xff0c;使用包络谱技术&#xff0c;进而得到特化谱是最适宜的。 1.1 包络谱是什么样子的&#xff1f; 我们看matlab信号分析中提供的一个实例&#xff1a; https://www.mathworks.com/help/signal/ug/comput…

前端JS 时间复杂度和空间复杂度

时间复杂度 BigO 算法的时间复杂度通常用大 O 符号表述&#xff0c;定义为 T(n) O(f(n)) 实际就是计算当一个一个问题量级&#xff08;n&#xff09;增加的时候&#xff0c;时间T增加的一个趋势 T(n)&#xff1a;时间的复杂度&#xff0c;也就相当于所消耗的时长 O&#xff1…

纯国产轻量化数字孪生:智慧城市、智慧工厂、智慧校园、智慧社区。。。

AMRT 3D数字孪生引擎介绍 AMRT3D引擎是一款融合了眸瑞科技的AMRT格式与轻量化处理技术为基础&#xff0c;以降本增效为目标&#xff0c;支持多端发布的一站式纯国产自研的CS架构项目开发引擎。 引擎包括场景搭建、UI拼搭、零代码交互事件、光影特效组件、GIS/BIM组件、实时数据…

十四、综合项目(斗地主)

综合项目&#xff08;斗地主&#xff09; 1.准备牌、洗牌、发牌、看牌2.对每人手中的牌进行排序2.1 排序方法1&#xff08;利用序号进行排序&#xff09;2.2排序方法2&#xff08;给每一张牌计算价值&#xff09; 3.两个实体类3.1 User3.2 Poker 4.登录页面4.1 验证码代码4.2 登…

【蓝桥杯】快读|min和max值的设置|小明和完美序列|​顺子日期​|星期计算|山

目录 一、输入的三种方式 1.最常见的Scanner的输入方法 2.数据多的时候常用BufferedReader快读 3.较麻烦的StreamTokenizer快读&#xff08;用的不多&#xff09; StreamTokenizer常见错误&#xff1a; 二、min和max值的设置 三、妮妮的翻转游戏 四、小明和完美序列 五…

如何删除视频中不想要的部分?分享实用工具和步骤!

在数字化时代&#xff0c;视频已成为我们生活中不可或缺的一部分。无论是观看电影、记录生活&#xff0c;还是制作专业的广告、教学材料&#xff0c;我们都需要对视频进行编辑处理。其中&#xff0c;删除视频中不想要的部分是最常见的需求之一。那么&#xff0c;如何轻松实现这…

Pytorch添加自定义算子之(5)-配置GPU形式的简单add自定义算子

参考:https://zhuanlan.zhihu.com/p/358778742 一、头文件 命名为:add2.h void launch_add2(float *c,const float *a,const float *b,int n);

Jvm之内存泄漏

1 内存溢出 1.1 概念 java.lang.OutOfMemoryError&#xff0c;是指程序在申请内存时&#xff0c;没有足够的内存空间供其使用&#xff0c;出现OutOfMemoryError。产生该错误的原因主要包括&#xff1a;JVM内存过小。程序不严密&#xff0c;产生了过多的垃圾。 程序体现: 内…

Win UI3开发笔记(四)设置主题续

上文讲到过关于界面和标题栏以及普通文本的主题设置&#xff0c;这篇说一下关于对话框的主题设置。 我最终没找到办法&#xff0c;寻求办法的朋友可以不用接着看了&#xff0c;以下只是过程。 一个对话框包括标题部分、内容部分和按钮部分&#xff0c;其中&#xff0c;在Cont…

论文笔记:A survey on zero knowledge range proofs and applications

https://link.springer.com/article/10.1007/s42452-019-0989-z 描述了构建零知识区间证明&#xff08;ZKRP&#xff09;的不同策略&#xff0c;例如2001年Boudot提出的方案&#xff1b;2008年Camenisch等人提出的方案&#xff1b;以及2017年提出的Bulletproofs。 Introducti…

Python 全栈系列227 部署chatglm3-API接口

说明 上一篇介绍了基于算力租用的方式部署chatglm3, 见文章&#xff1b;本篇接着看如何使用API方式进行使用。 内容 1 官方接口 详情可见接口调用文档 调用有两种方式&#xff0c;SDK包和Http。一般来说&#xff0c;用SDK会省事一些。 以下是Python SDK包的git项目地址 安…

ChatGPT 正测试Android屏幕小组件;联想ThinkBook 推出透明笔记本电脑

▶ ChatGPT 测试屏幕小组件 近日 ChatGPT 正在测试 Android 平台上的屏幕小组件&#xff0c;类似于手机中的悬浮窗&#xff0c;按住 Android 手机主屏幕上的空白位置就可以调出 ChatGPT 的部件菜单。 菜单中提供了许多选项&#xff0c;包括文本、语音和视频查询的快捷方式&…

vue3的echarts从后端获取数据,用于绘制图表

场景需求&#xff1a;后端采用flask通过pymysql从数据库获取数据&#xff0c;并返回给前端。前端vue3利用axios获取数据并运用到echarts绘制图表。 第一步&#xff0c;vue中引入echarts 首先vue下载echarts npm install echarts 然后在main.js文件写如下代码 import {create…