yolo改进替换VanillaNet backbone

news2024/11/26 4:33:18

论文地址:https://arxiv.org/pdf/2305.12972.pdf

代码地址:GitHub - huawei-noah/VanillaNet

VanillaNet简介

基础模型的核心是“更多不同”的哲学,计算机视觉和自然语言处理的惊人成功就是例证。 然而,优化的挑战和Transformer模型固有的复杂性要求范式向简单转变。 在本研究中,我们介绍了VanillaNet,一个在设计中包含优雅的神经网络架构。 通过避免高深度、快捷方式和复杂的操作,如自注意力,VanillaNet是令人耳目一新的简洁,但非常强大。 每一层都被精心制作成紧凑和简单的结构,非线性激活函数在训练后被剪枝,以恢复原始的架构。 VanillaNet克服了固有复杂性的挑战,使其成为资源紧张环境的理想选择。 其易于理解和高度简化的体系结构为高效部署打开了新的可能性。 大量实验表明,VanillaNet提供了与著名的深度神经网络和视觉Transformer相当的性能,展示了极简主义在深度学习中的力量。 VanillaNet的这一远见之旅具有重大的潜力,可以重新定义并挑战基础模型的现状,为优雅有效的模型设计开辟一条新的道路。

        在过去的几十年里,研究人员在神经网络的基本设计上达成了一些共识。大多数最先进的图像分类网络架构应该由三部分组成:

  1. 主干块,用于将输入图像从3个通道转换为多个通道,并进行下采样,一个学习有用的信息主题
  2. 主体,通常有四个阶段,每个阶段都是通过堆叠相同的块来派生的。在每个阶段之后,特征的通道将扩展,而高度和宽度将减小。不同的网络利用和堆叠不同种类的块来构建深度模型。
  3. 全连接层分类输出。

        尽管现有的深度网络取得了成功,但它们利用大量复杂层来为以下任务提取高级特征。例如,著名的ResNet需要34或50个带shortcat的层才能在ImageNet上实现超过70%的top-1精度。Vit的基础版本由62层组成,因为自注意力中的K、Q、V需要多层来计算。随着AI芯片雨来越大,神经网络推理速度的瓶颈不再是FLOPs或参数,因为现代GPU可以很容易地进行并行计算。相比之下,它们复杂的设计和较大的深度阻碍了它们的速度。为此我们提出了Vanilla网络,即VanillaNet,其框架图如图1所示。我们遵循流行的神经网络设计,包括主干、主体和全连接层。与现有的深度网络不同,我们在每个阶段只使用一层,以建立一个尽可能少的层的极其简单的网络。该网络的特点是不采用shortcut(shortcut会增加访存时间),同时没有复杂的模块如自注意力等。

在深度学习中,通过在训练阶段引入更强的容量来增强模型的性能是很常见的。为此,我们建议利用深度训练技术来提高所提出的VanillaNet在训练期间的能力。

优化策略1: 深度训练,浅层推理         为了提升VanillaNet这个架构的非线性,我们提出首先提出了深度训练(Deep training)策略,在训练过程中把一个卷积层拆成两个卷积层,并在中间插入如下的非线性操作:

        其中, A 是传统的非线性激活函数,最简单的还是 ReLU, λ 会随着模型的优化逐渐变为1,两个卷积层就可以合并成为一层,不改变VanillaNet的结构。

优化策略2:换激活函数         既然我们想提升VanillaNet的非线性,一个更直接的方案是有没有非线性更强的激活函数,并且这个激活函数好并行速度快?为了实现这个既要又要的宪法,我们提出一种基于级数启发的激活函数,把多个ReLU加权加偏置堆叠起来:

        然后再进行微调,提升这个激活函数对信息的感知能力。

YOLO中改进

以yolov7-tiny的backbone为例,想用v5v8的可以把backbone复制到v5v8的配置中就行

首先配置文件yolov7-tiny-vanilla.yaml:

# parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

activation: nn.ReLU()
# anchors
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# yolov7-tiny backbone
backbone:
  # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True
  [[-1, 1, VanillaStem, [64, 4, 4, None, 1]],  # 0-P1/4
   [-1, 1, VanillaBlock, [256, 1, 2, None, 1]],  # 1-P2/8
   [-1, 1, VanillaBlock, [512, 1, 2, None, 1]],  # 2-P3/16
   [-1, 1, VanillaBlock, [1024, 1, 2, None, 1]],  # 3-P4/32
  ]

# yolov7-tiny head
head:
  [[1, 1, Conv, [128, 1, 1, None, 1]],  # 4
   [2, 1, Conv, [256, 1, 1, None, 1]],  # 5
   [3, 1, Conv, [512, 1, 1, None, 1]],  # 6

   [-1, 1, SPPCSPCSIM, [256]], # 7

   [-1, 1, Conv, [128, 1, 1, None, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [5, 1, Conv, [128, 1, 1, None, 1]], # route backbone P3
   [[-1, -2], 1, Concat, [1]], # 11

   [-1, 1, ELAN, [128, 1, 1, None, 1]],  # 12

   [-1, 1, Conv, [64, 1, 1, None, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [4, 1, Conv, [64, 1, 1, None, 1]], # route backbone P2
   [[-1, -2], 1, Concat, [1]],  # 16

   [-1, 1, ELAN, [64, 1, 1, None, 1]],  # 17

   [-1, 1, Conv, [128, 3, 2, None, 1]],
   [[-1, 12], 1, Concat, [1]],

   [-1, 1, ELAN, [128, 1, 1, None, 1]],  # 20

   [-1, 1, Conv, [256, 3, 2, None, 1]],
   [[-1, 7], 1, Concat, [1]],

   [-1, 1, ELAN, [256, 1, 1, None, 1]],  # 23

   [17, 1, Conv, [128, 3, 1, None, 1]],
   [20, 1, Conv, [256, 3, 1, None, 1]],
   [23, 1, Conv, [512, 3, 1, None, 1]],

   [[24,25,26], 1, Detect, [nc, anchors]],   # Detect(P3, P4, P5)
  ]

其中ELAN模块为对原配置文件中相应的模块进行过简化可以

简化方式:yolov7简化网络yaml配置文件_athrunsunny的博客-CSDN博客

为了保证模型整体参数量和v5s以及v7tiny相差太多, VanillaBlock的channel数设置的比原文中要小

在common.py中添加

class activation(nn.ReLU):
    def __init__(self, dim, act_num=3, deploy=False):
        super(activation, self).__init__()
        self.act_num = act_num
        self.deploy = deploy
        self.dim = dim
        self.weight = torch.nn.Parameter(torch.randn(dim, 1, act_num * 2 + 1, act_num * 2 + 1))
        if deploy:
            self.bias = torch.nn.Parameter(torch.zeros(dim))
        else:
            self.bias = None
            self.bn = nn.BatchNorm2d(dim, eps=1e-6)
        nn.init.trunc_normal_(self.weight, std=.02)

    def forward(self, x):
        if self.deploy:
            return torch.nn.functional.conv2d(
                super(activation, self).forward(x),
                self.weight, self.bias, padding=self.act_num, groups=self.dim)
        else:
            return self.bn(torch.nn.functional.conv2d(
                super(activation, self).forward(x),
                self.weight, padding=self.act_num, groups=self.dim))

    def _fuse_bn_tensor(self, weight, bn):
        kernel = weight
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta + (0 - running_mean) * gamma / std

    def switch_to_deploy(self):
        kernel, bias = self._fuse_bn_tensor(self.weight, self.bn)
        self.weight.data = kernel
        self.bias = torch.nn.Parameter(torch.zeros(self.dim))
        self.bias.data = bias
        self.__delattr__('bn')
        self.deploy = True


class VanillaStem(nn.Module):
    def __init__(self, in_chans=3, dims=96,
                 k=0, s=0, p=None,g=0, act_num=3, deploy=False, ada_pool=None, **kwargs):
        super().__init__()
        self.deploy = deploy
        stride, padding = (4, 0) if not ada_pool else (3, 1)
        if self.deploy:
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, dims, kernel_size=k, stride=stride, padding=padding),
                activation(dims, act_num, deploy=self.deploy)
            )
        else:
            self.stem1 = nn.Sequential(
                nn.Conv2d(in_chans, dims, kernel_size=k, stride=stride, padding=padding),
                nn.BatchNorm2d(dims, eps=1e-6),
            )
            self.stem2 = nn.Sequential(
                nn.Conv2d(dims, dims, kernel_size=1, stride=1),
                nn.BatchNorm2d(dims, eps=1e-6),
                activation(dims, act_num)
            )
        self.act_learn = 1
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.deploy:
            x = self.stem(x)
        else:
            x = self.stem1(x)
            x = torch.nn.functional.leaky_relu(x, self.act_learn)
            x = self.stem2(x)

        return x

    def _fuse_bn_tensor(self, conv, bn):
        kernel = conv.weight
        bias = conv.bias
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta + (bias - running_mean) * gamma / std

    def switch_to_deploy(self):
        self.stem2[2].switch_to_deploy()
        kernel, bias = self._fuse_bn_tensor(self.stem1[0], self.stem1[1])
        self.stem1[0].weight.data = kernel
        self.stem1[0].bias.data = bias
        kernel, bias = self._fuse_bn_tensor(self.stem2[0], self.stem2[1])
        self.stem1[0].weight.data = torch.einsum('oi,icjk->ocjk', kernel.squeeze(3).squeeze(2),
                                                 self.stem1[0].weight.data)
        self.stem1[0].bias.data = bias + (self.stem1[0].bias.data.view(1, -1, 1, 1) * kernel).sum(3).sum(2).sum(1)
        self.stem = torch.nn.Sequential(*[self.stem1[0], self.stem2[2]])
        self.__delattr__('stem1')
        self.__delattr__('stem2')
        self.deploy = True


class VanillaBlock(nn.Module):
    def __init__(self, dim, dim_out,k=0 , stride=2,p=None,g=0, ada_pool=None,act_num=3 ,deploy=False):
        super().__init__()
        self.act_learn = 1
        self.deploy = deploy
        if self.deploy:
            self.conv = nn.Conv2d(dim, dim_out, kernel_size=1)
        else:
            self.conv1 = nn.Sequential(
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.BatchNorm2d(dim, eps=1e-6),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(dim, dim_out, kernel_size=1),
                nn.BatchNorm2d(dim_out, eps=1e-6)
            )

        if not ada_pool:
            self.pool = nn.Identity() if stride == 1 else nn.MaxPool2d(stride)
        else:
            self.pool = nn.Identity() if stride == 1 else nn.AdaptiveMaxPool2d((ada_pool, ada_pool))

        self.act = activation(dim_out, act_num, deploy=self.deploy)

    def forward(self, x):
        if self.deploy:
            x = self.conv(x)
        else:
            x = self.conv1(x)
            # We use leakyrelu to implement the deep training technique.
            x = torch.nn.functional.leaky_relu(x, self.act_learn)
            x = self.conv2(x)

        x = self.pool(x)
        x = self.act(x)
        return x

    def _fuse_bn_tensor(self, conv, bn):
        kernel = conv.weight
        bias = conv.bias
        running_mean = bn.running_mean
        running_var = bn.running_var
        gamma = bn.weight
        beta = bn.bias
        eps = bn.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta + (bias - running_mean) * gamma / std

    def switch_to_deploy(self):
        kernel, bias = self._fuse_bn_tensor(self.conv1[0], self.conv1[1])
        self.conv1[0].weight.data = kernel
        self.conv1[0].bias.data = bias
        # kernel, bias = self.conv2[0].weight.data, self.conv2[0].bias.data
        kernel, bias = self._fuse_bn_tensor(self.conv2[0], self.conv2[1])
        self.conv = self.conv2[0]
        self.conv.weight.data = torch.matmul(kernel.transpose(1, 3),
                                             self.conv1[0].weight.data.squeeze(3).squeeze(2)).transpose(1, 3)
        self.conv.bias.data = bias + (self.conv1[0].bias.data.view(1, -1, 1, 1) * kernel).sum(3).sum(2).sum(1)
        self.__delattr__('conv1')
        self.__delattr__('conv2')
        self.act.switch_to_deploy()
        self.deploy = True

同时在yolo.py中做如下修改

1、parse_model函数中

if m in (Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, SPPCSPC, RepConv,
                 RFEM, ELAN, SPPCSPCSIM,VanillaBlock,VanillaStem):
            c1, c2 = ch[f], args[0]
            if c2 != no:  # if not output
                c2 = make_divisible(c2 * gw, 8)

            args = [c1, c2, *args[1:]]
            if m in [BottleneckCSP, C3, C3TR, C3Ghost, C3x]:
                args.insert(2, n)  # number of repeats
                n = 1

2、BaseModel的fuse函数

if isinstance(m, (VanillaStem, VanillaBlock)):
                # print(m)
                m.deploy = True
                m.switch_to_deploy()

在yolo.py中测试yolov7-tiny-vanilla.yaml

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1217141.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

矿区安全检查VR模拟仿真培训系统更全面、生动有效

矿山企业岗位基数大,生产过程中会持续有新入矿的施工人员及不定期接待的参观人员,下井安全须知培训需求量大。传统实景拍摄的视频剪辑表达方式有限,拍摄机位受限,难以生动表达安全须知的内容,且井下现场拍摄光线不理想…

Spring Boot 项目部署方案!打包 + Shell 脚本部署详解

文章目录 概要一 、profiles指定不同环境的配置二、maven-assembly-plugin打发布压缩包三、 分享shenniu_publish.sh程序启动工具四、linux上使用shenniu_publish.sh启动程序 概要 本篇和大家分享的是springboot打包并结合shell脚本命令部署,重点在分享一个shell程…

qnx 工程目录创建工具 addvariant

文章目录 前言一、addvariant 是什么二、addvariant 使用实例1. variant names 参数说明2. 创建一个可执行文件工程3. 创建一个动态库工程 总结参考资料 前言 本文主要介绍如何在qnx 开发环境中创建工程目录及其相关的配置文件(common.mk, Makefile 文件等) 软件版本&#xff…

第四代管网水位监测仪:管网水位监测仪效果怎么样?

随着我国城市化进程的推进,随之而来的是城市规模不断扩大,各类排水管网设施需要检查的范围也日益扩大。尽管市政管理部门已安排人员定期巡查,但仍无法对井下水位进行24小时实时监控。如遇管网阻塞、窨井满溢等情况,若不及时处理将…

(七)五种元启发算法(DBO、LO、SWO、COA、LSO、KOA、GRO)求解无人机路径规划MATLAB

一、五种算法(DBO、LO、SWO、COA、GRO)简介 1、蜣螂优化算法DBO 蜣螂优化算法(Dung beetle optimizer,DBO)由Jiankai Xue和Bo Shen于2022年提出,该算法主要受蜣螂的滚球、跳舞、觅食、偷窃和繁殖行为的启…

python 最快多长时间学完?

以下是一个为零基础学员制作Python速成学习计划。这个计划包括了一些基本的Python概念和技能,以及一些实用的学习技巧。 第1周:基础入门 Python简介:了解Python的历史、特点、应用领域。 安装Python:在你的电脑上安装Python&am…

【L2GD】: 无环局部梯度下降

文章链接:Federated Learning of a Mixture of Global and Local Models 发表期刊(会议): ICLR 2021 Conference(机器学习顶会) 往期博客:FLMix: 联邦学习新范式——局部和全局的结合 目录 1.背景介绍2. …

赶快来!程序员接单必须知道的六大注意事项!!!

花花世界迷人眼,增加实力多搞钱!对于咱程序员来说,搞钱的最好办法就是网上接单了,相信也有不少小伙伴已经在尝试了吧!但是如何正确的搞钱呢?其中的注意事项你真的了解吗? 本期就和小编一起来看…

最佳实践-使用Github Actions来构建跨平台容器镜像

概述 GitHub Actions 是一种持续集成和持续交付 (CI/CD) 平台,可用于自动执行生成、测试和部署管道。 您可以创建工作流程来构建和测试存储库的每个拉取请求,或将合并的拉取请求部署到生产环境。 GitHub Actions 不仅仅是 DevOps,还允许您在存…

【云原生-Kurbernets篇】Kurbernets集群的调度策略

调度 一、Kurbernetes的list-watch机制1.1 list-watch机制简介1.2 创建pod的流程(结合list-watch机制) 二、Scheduler的调度策略2.1 简介2.2 预选策略(predicate)2.3 优选策略(priorities) 三、标签管理3.1…

C51--WiFi模块ESP8266--AT指令

ESP8266 面向物联网应用的,高性价比、高度集成的WiFi MCU 简介: 高度集成: ESP8266EX集成了32位Tensilica 处理器、标准数字外设接口、天线开关、射频balun、功率放大器、底噪放大器、过滤器和电源管理模块,可将所占的PCB空间降…

SDL2 播放音频(MP4)

1.简介 这里引入FFmpeg库,获取音频流数据,然后通过FFmpeg将视频流解码成pcm原始数据,再将pcm数据送入到SDL库中实现音频播放。 2.FFmpeg的操作流程 注册API:av_register_all()构建输入AVFormatContext上下文:avform…

超细致Python自动化测试实现的思路

Python自动化测试常用于Web应用、移动应用、桌面应用等的测试 同时,我也为大家准备了一份软件测试视频教程(含面试、接口、自动化、性能测试等),就在下方,需要的可以直接去观看,也可以直接点击文末小卡片免…

Python自动化测试之request库详解(三)

做过接口测试的都会发现,现在的接口都是HTTPS协议了,今天就写一篇如何通过request发送https请求。 什么是HTTPS HTTPS 的全称是Hyper Text Transfer Protocol over Secure Socket Layer ,是以安全为目标的HTTP通道,简单的讲是HTT…

LeetCode【41】缺失的第一个正数

题目: 分析: 第i个位置的数,如果再数组 0到length-1范围内,则将其放到对应的位置; 再遍历一遍数组,找到第一个不在位置i的正数数字,即为所求 思路:https://blog.csdn.net/weixin_45…

基于JavaWeb+SpringBoot+Vue医疗器械商城微信小程序系统的设计和实现

基于JavaWebSpringBootVue医疗器械商城微信小程序系统的设计和实现 源码获取入口前言主要技术系统设计功能截图Lun文目录订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 源码获取入口 前言 摘 要 目前医疗器械行业作为医药行业的一个分支,发展十分迅速。…

【Java 进阶篇】JQuery 遍历 —— `each()` 方法的奇妙之旅

在前端的世界里,操作元素是我们开发者最为频繁的任务之一。为了更好地操控页面上的元素,JQuery 提供了许多强大的工具,其中 each() 方法是一颗璀璨的明星。本文将深入探讨 each() 方法的原理和用法,带你踏上一场遍历之旅。 起步&…

第四代智能井盖传感器:万宾科技智能井盖位移监测方式一览

现在城市化水平不断提高,每个城市的井盖遍布在城市的街道上,是否能够实现常态化和系统化的管理,反映了一个城市治理现代化水平。而且近些年来住建部曾多次要求全国各个城市加强相关的井盖管理工作,作为基础设施重要的一个组成部分…

JAVA安全之Shrio550-721漏洞原理及复现

前言 关于shrio漏洞,网上有很多博文讲解,这些博文对漏洞的解释似乎有一套约定俗成的说辞,让人云里来云里去,都没有对漏洞产生的原因深入地去探究..... 本文从现象到本质,旨在解释清楚Shrio漏洞是怎么回事&#xff01…

Java学习之路 —— IO、特殊文件

文章目录 1. I/O1.1 常用API1.2 I/O流1.2.1 字节流1.2.2 try-catch-finally和try-with-resource1.2.3 字符流1.2.4 其他的一些流 2. I/O框架3. 特殊文件3.1. Properties3.2 XML 1. I/O 1.1 常用API // 1. 创建文件对象File file new File("E:\\ComputerScience\\java\\…