YOLOv5、YOLOv8改进:gnconv 门控递归卷积

news2024/10/5 22:23:29

1.简介


论文地址:https://arxiv.org/abs/2207.14284
代码地址:https://github.com/raoyongming/HorNet

视觉Transformer的最新进展表明,在基于点积自注意力的新空间建模机制驱动的各种任务中取得了巨大成功。在本文中,作者证明了视觉Transformer背后的关键成分,即输入自适应、长程和高阶空间交互,也可以通过基于卷积的框架有效实现。作者提出了递归门卷积(g n Conv),它用门卷积和递归设计进行高阶空间交互。新操作具有高度灵活性和可定制性,与卷积的各种变体兼容,并将自注意力中的二阶交互扩展到任意阶,而不引入显著的额外计算。g nConv可以作为一个即插即用模块来改进各种视觉Transformer和基于卷积的模型。基于该操作,作者构建了一个新的通用视觉主干族,名为HorNet。在ImageNet分类、COCO对象检测和ADE20K语义分割方面的大量实验表明,HorNet在总体架构和训练配置相似的情况下,优于Swin Transformers和ConvNeXt。HorNet还显示出良好的可扩展性,以获得更多的训练数据和更大的模型尺寸。除了在视觉编码器中的有效性外,作者还表明g n Conv可以应用于任务特定的解码器,并以较少的计算量持续提高密集预测性能。本文的结果表明,g n Conv可以作为一个新的视觉建模基本模块,有效地结合了视觉Transformer和CNN的优点。

(a)标准卷积运算没有明确考虑空间交互。


(b) 动态卷积 [27, 4] 和 SE [25] 引入了动态权重,以通过额外的空间交互来提高卷积的建模能力。

(c) 自注意力操作 通过两个连续的矩阵乘法执行二阶空间交互。

 (d) gnConv 使用具有门控卷积和递归设计的高效实现来实现任意阶空间交互。在本文中,作者总结了视觉Transformers成功背后的关键因素是通过自注意力操作实现输入自适应、远程和高阶空间交互的空间建模新方法。虽然之前的工作已经成功地将元架构、输入自适应权重生成策略和视觉Transformers的大范围建模能力迁移到CNN模型,但尚未研究高阶空间交互机制。作者表明,使用基于卷积的框架可以有效地实现所有三个关键要素。作者提出了递归门卷积(g nConv),它与门卷积和递归设计进行高阶空间交互。与简单地模仿自注意力中的成功设计不同,g n Conv有几个额外的优点:1)**效率。**基于卷积的实现避免了自注意力的二次复杂度。在执行空间交互期间逐步增加通道宽度的设计也使能够实现具有有限复杂性的高阶交互;2) 可扩展。将自注意力中的二阶交互扩展到任意阶,以进一步提高建模能力。由于没有对空间卷积的类型进行假设,g n Conv与各种核大小和空间混合策略兼容;3) 平移等变性。g n Conv完全继承了标准卷积的平移等变性,这为主要视觉引入了有益的归纳偏置。
                                          

基于g n Conv,作者构建了一个新的通用视觉主干族,名为HorNet。作者在ImageNet分类、COCO对象检测和ADE20K语义分割上进行了大量实验,以验证本文模型的有效性。凭借相同的7×7卷积核/窗口和类似的整体架构和训练配置,HorNet优于Swin和ConvNeXt在不同复杂度的所有任务上都有很大的优势。通过使用全局卷积核大小,可以进一步扩大差距。HorNet还显示出良好的可扩展性,可以扩展到更多的训练数据和更大的模型尺寸,在ImageNet上达到87.7%的top-1精度,在ADE20K val上达到54.6%的mIoU,在COCO val上通过ImageNet-22K预训练达到55.8%的边界框AP。除了在视觉编码器中应用g n Conv外,作者还进一步测试了在任务特定解码器上设计的通用性。通过将g n Conv添加到广泛使用的特征融合模型FPN,作者开发了HorFPN来建模不同层次特征的高阶空间关系。作者观察到,HorFPN还可以以较低的计算成本持续改进各种密集预测模型。结果表明,g n Conv是一种很有前景的视觉建模方法,可以有效地结合视觉Transofrmer和CNN的优点。

2.YOLOv5代码修改

2.1 修改yaml文件  

我这边只是提供参考,你可以修改任意位置

# YOLOAir 🚀 by iscyy, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 3, gnconv, [256]], # 修改示例
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

2.2 common添加代码

class HorLayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. 
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs 
    with shape (batch_size, channels, height, width).# https://ar5iv.labs.arxiv.org/html/2207.14284
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError # by iscyy/air
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
 
class GlobalLocalFilter(nn.Module):
    def __init__(self, dim, h=14, w=8):
        super().__init__()
        self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)
        self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
        trunc_normal_(self.complex_weight, std=.02)
        self.pre_norm = HorLayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.post_norm = HorLayerNorm(dim, eps=1e-6, data_format='channels_first')
 
    def forward(self, x):
        x = self.pre_norm(x)
        x1, x2 = torch.chunk(x, 2, dim=1)
        x1 = self.dw(x1)
 
        x2 = x2.to(torch.float32)
        B, C, a, b = x2.shape
        x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
 
        weight = self.complex_weight
        if not weight.shape[1:3] == x2.shape[2:4]:
            weight = F.interpolate(weight.permute(3,0,1,2), size=x2.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)
 
        weight = torch.view_as_complex(weight.contiguous())
 
        x2 = x2 * weight
        x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')
 
        x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, 2 * C, a, b)
        x = self.post_norm(x)
        return x
 
class gnconv(nn.Module):
    def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
        super().__init__()
        self.order = order
        self.dims = [dim // 2 ** i for i in range(order)]
        self.dims.reverse()
        self.proj_in = nn.Conv2d(dim, 2*dim, 1)
 
        if gflayer is None:
            self.dwconv = get_dwconv(sum(self.dims), 7, True)
        else:
            self.dwconv = gflayer(sum(self.dims), h=h, w=w)
        
        self.proj_out = nn.Conv2d(dim, dim, 1)
 
        self.pws = nn.ModuleList(
            [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]
        )
        self.scale = s
 
    def forward(self, x, mask=None, dummy=False):
        # B, C, H, W = x.shape gnconv [512]by iscyy/air
        fused_x = self.proj_in(x)
        pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
        dw_abc = self.dwconv(abc) * self.scale
        dw_list = torch.split(dw_abc, self.dims, dim=1)
        x = pwa * dw_list[0]
        for i in range(self.order -1):
            x = self.pws[i](x) * dw_list[i+1]
        x = self.proj_out(x)
 
        return x
 
def get_dwconv(dim, kernel, bias):
    return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2 ,bias=bias, groups=dim)
 
 
class HorBlock(nn.Module):
    r""" HorNet block
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, gnconv=gnconv):
        super().__init__()
 
        self.norm1 = HorLayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.gnconv = gnconv(dim)
        self.norm2 = HorLayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
 
        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
 
        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
    def forward(self, x):
        B, C, H, W  = x.shape # [512]
        if self.gamma1 is not None:
            gamma1 = self.gamma1.view(C, 1, 1)
        else:
            gamma1 = 1
        x = x + self.drop_path(gamma1 * self.gnconv(self.norm1(x)))
 
        input = x
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm2(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma2 is not None:
            x = self.gamma2 * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
 
        x = input + self.drop_path(x)
        return x

 

2.3 YOLO注册

找到 parse_model并注册

 if m in [Conv, ,gnconv,GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, C3HB, C3RFEM, MultiSEAM, SEAM, C3STR, MobileOneBlock]:
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/957184.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【进程间通信】IPC对象(进程间通信的精髓)

(꒪ꇴ꒪ ),Hello我是祐言QAQ我的博客主页:C/C语言,数据结构,Linux基础,ARM开发板,网络编程等领域UP🌍快上🚘,一起学习,让我们成为一个强大的攻城狮&#xff0…

2023年8月随笔之有顾忌了

1. 回头看 日更坚持了243天。 读《发布!设计与部署稳定的分布式系统》终于更新完成 选读《SQL经典实例》也更新完成 读《高性能MySQL(第4版)》开更,但目前暂缓 读《SQL学习指南(第3版)》开更并持续更新…

3、QT 的基础控件的使用

一、qFileDialog 文件窗体 Header: #include <QFileDialog> qmake: QT widgets Inherits: QDialog静态函数接口&#xff1a; void Widget::on_pushButton_clicked() {//获取单个文件的路径名QString filename QFileDialog :: getOpenFileName(this, tr("Open Fi…

【taro react】(游戏) ---- 贪吃蛇

1. 预览 2. 实现思路 实现食物类&#xff0c;食物坐标和刷新食物的位置&#xff0c;以及获取食物的坐标点&#xff1b;实现计分面板类&#xff0c;实现吃食物每次的计分以及积累一定程度的等级&#xff0c;实现等级和分数的增加&#xff1b;实现蛇类&#xff0c;蛇类分为蛇头和…

16 Linux之JavaEE定制篇-搭建JavaEE环境

16 Linux之JavaEE定制篇-搭建JavaEE环境 文章目录 16 Linux之JavaEE定制篇-搭建JavaEE环境16.1 概述16.2 安装JDK16.3 安装tomcat16.4 安装idea2020*16.5 安装mysql5.7 学习视频来自于B站【小白入门 通俗易懂】2021韩顺平 一周学会Linux。可能会用到的资料有如下所示&#xff0…

Matlab图像处理-灰度插值法

最近邻法 最近邻法是一种最简单的插值算法&#xff0c;输出像素的值为输入图像中与其最邻近的采样点的像素值。是将(u0,v0)(u_0,v_0)点最近的整数坐标u,v(u,v)点的灰度值取为(u0,v0)(u_0,v_0)点的灰度值。 在(u0,v0)(u_0,v_0)点各相邻像素间灰度变化较小时&#xff0c;这种方…

使用ELK收集解析nginx日志和kibana可视化仪表盘

文章目录 ELK生产环境配置filebeat 配置logstash 配置 kibana仪表盘配置配置nginx转发ES和kibanaELK设置账号和密码 ELK生产环境配置 ELK收集nginx日志有多种方案&#xff0c;一般比较常见的做法是在生产环境服务器搭建filebeat 收集nginx的文件日志并写入到队列&#xff08;k…

图解 STP

网络环路 现在我们的生活已经离不开网络&#xff0c;如果我家断网&#xff0c;我会抱怨这什么破网络&#xff0c;影响到我刷抖音、打游戏&#xff1b;如果公司断网&#xff0c;那老板估计会骂娘&#xff0c;因为会影响到公司正常运转&#xff0c;直接造成经济损失。网络通信中&…

传统分拣弊端明显,AI机器视觉赋能物流行业包裹分类产线数智化升级

随着电子商务的快速发展&#xff0c;物流行业的包裹数量持续增长&#xff0c;给物流企业带来了巨大的运营压力。目前&#xff0c;国内大型物流运转中心已开始采用机器视觉自动化设备&#xff0c;但多数快递公司处于半自动化状态&#xff0c;中小型物流分拣中心目前仍靠人工录入…

iOS系统修复软件 Fix My iPhone for Mac

Fix My iPhone for Mac是一款iOS系统恢复工具。修复您的iPhone卡在Apple徽标&#xff0c;黑屏&#xff0c;冻结屏幕&#xff0c;iTunes更新/还原错误和超过20个iOS 12升级失败。这个macOS桌面应用程序提供快速&#xff0c;即时的解决方案来修复您的iOS系统问题&#xff0c;而不…

基于单片机的遥控器设计

一、项目介绍 随着科技的不断发展&#xff0c;红外遥控器已经成为我们日常生活中普遍使用的一种电子设备。它能够给我们带来便捷和舒适&#xff0c;减少人工操作的繁琐性。然而&#xff0c;在实际应用中&#xff0c;有时候我们可能需要制作一个自己的红外遥控器&#xff0c;以…

【云原生进阶之PaaS中间件】第一章Redis-2.4缓存更新机制

1 缓存和数据库的数据一致性分析 1.1 Redis 中如何保证缓存和数据库双写时的数据一致性&#xff1f; 无论先操作db还是cache&#xff0c;都会有各自的问题&#xff0c;根本原因是cache和db的更新不是一个原子操作&#xff0c;因此总会有不一致的问题。想要彻底解决这种问题必须…

目标检测YOLO算法,先从yolov1开始

学习资源 有一套配套的学习资料&#xff0c;才能让我们的学习事半功倍。 yolov1论文原址&#xff1a;You Only Look Once: Unified, Real-Time Object Detection 代码地址&#xff1a;darknet: Convolutional Neural Networks (github.com) 深度学习经典检测方法 one-stag…

React 18 对 state 进行保留和重置

参考文章 对 state 进行保留和重置 各个组件的 state 是各自独立的。根据组件在 UI 树中的位置&#xff0c;React 可以跟踪哪些 state 属于哪个组件。可以控制在重新渲染过程中何时对 state 进行保留和重置。 UI 树 浏览器使用许多树形结构来为 UI 建立模型。DOM 用于表示 …

3DCAT携手华为,打造XR虚拟仿真实训实时云渲染解决方案

2023年5月8日-9日&#xff0c;以 因聚而生 众志有为 为主题的 华为中国合作伙伴大会2023 在深圳国际会展中心隆重举行。本次大会汇聚了ICT产业界的广大新老伙伴朋友&#xff0c;共同探讨数字化转型的新机遇&#xff0c;共享数字化未来的新成果。 华为中国合作伙伴大会2023现场&…

Python小知识 - 使用Python进行数据分析

使用Python进行数据分析 数据分析简介 数据分析&#xff0c;又称为信息分析&#xff0c;是指对数据进行综合处理、归纳提炼、概括总结的过程&#xff0c;是数据处理的第一步。 数据分析的目的是了解数据的内在规律&#xff0c;为数据挖掘&#xff0c;并应用于商业决策、科学研究…

04ShardingSphere-JDBC垂直分片

1、准备服务器 比如商城项目中&#xff0c;有用户、订单等系统&#xff0c;数据库在设计时用户信息与订单信息在同一表中。这里创建用户服务、订单服务实现数据库的用户信息和订单信息垂直分片 服务器规划&#xff1a;使用docker方式创建如下容器 服务器&#xff1a;容器名se…

P13 VUE 二级menu实现

主要修改以下几个点&#xff1a; CommonAside.vue中 外层便利有孩子节点&#xff0c;关键词key是对应的标签&#xff0c;class动态图表渲染 内层遍历不能再用item&#xff0c;用subitem&#xff0c;遍历该item.childeren&#xff0c;关键词是path&#xff0c; <templat…

java 批量下载将多个文件(minio中存储)压缩成一个zip包

我的需求是将minio中存储的文件按照查询条件查询出来统一压成一个zip包然后下载下来。 思路&#xff1a;针对这个需求&#xff0c;其实可以有多个思路&#xff0c;不过也大同小异&#xff0c;一般都是后端返回流文件前端再处理下载&#xff0c;也有少数是压缩成zip包之后直接给…

【C++设计模式】详解装饰模式

2023年8月31日&#xff0c;周四上午 这是我目前碰到的最难的设计模式..... 非常难以理解而且比较灵活多半&#xff0c;学得贼难受&#xff0c;写得贼费劲..... 2023年8月31日&#xff0c;周四晚上19:48 终于写完了&#xff0c;花了一天的时间来学习装饰模式和写这篇博客。 …