YoloV8改进策略:卷积改进|DOConv轻量卷积,即插即用|适用各种场景

news2024/11/15 23:21:47

摘要

本文使用DOConv卷积,替换YoloV8的常规卷积,轻量高效,即插即用!改进方法非常简单。
在这里插入图片描述

DO-Conv(Depthwise Over-parameterized Convolutional Layer)是一种深度过参数化的卷积层,用于提高卷积神经网络(CNN)的性能。它的核心思想是在训练阶段使用额外的深度卷积来增强卷积层,其中每个输入通道与不同的二维核进行卷积。这两个卷积的组合构成了一个过度参数化,因为它增加了可学习的参数,而结果的线性操作可以用单个卷积层来表示。在推理阶段,DO-Conv可以融合到常规卷积层中,使得计算量与常规卷积层的计算量完全相同。

DO-Conv可以作为一种即插即用的模块,用于替代CNN中的常规卷积层,以提高在各种计算机视觉任务(如图像分类、语义分割和对象检测)上的性能。通过实验证明,使用DO-Conv不仅可以加速网络的训练过程,还能在多种计算机视觉任务中取得比使用传统卷积层更好的结果。

论文链接:https://arxiv.org/pdf/2006.12030.pdf

代码

# coding=utf-8
import math
import torch
import numpy as np
from torch.nn import init
from itertools import repeat
from torch.nn import functional as F
from torch._jit_internal import Optional
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import collections


class DOConv2d(Module):
    """
       DOConv2d can be used as an alternative for torch.nn.Conv2d.
       The interface is similar to that of Conv2d, with one exception:
            1. D_mul: the depth multiplier for the over-parameterization.
       Note that the groups parameter switchs between DO-Conv (groups=1),
       DO-DConv (groups=in_channels), DO-GConv (otherwise).
    """
    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']
    __annotations__ = {'bias': Optional[torch.Tensor]}

    def __init__(self, in_channels, out_channels, kernel_size, D_mul=None, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(DOConv2d, self).__init__()

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
                valid_padding_modes, padding_mode))
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))

        #################################### Initailization of D & W ###################################
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        self.D_mul = M * N if D_mul is None or M * N <= 1 else D_mul
        self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, self.D_mul))
        init.kaiming_uniform_(self.W, a=math.sqrt(5))

        if M * N > 1:
            self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul))
            init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32)
            self.D.data = torch.from_numpy(init_zero)

            eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N))
            d_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N)))
            if self.D_mul % (M * N) != 0:  # the cases when D_mul > M * N
                zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)])
                self.d_diag = Parameter(torch.cat([d_diag, zeros], dim=2), requires_grad=False)
            else:  # the case when D_mul = M * N
                self.d_diag = Parameter(d_diag, requires_grad=False)
        ##################################################################################################

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        super(DOConv2d, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
        if M * N > 1:
            ######################### Compute DoW #################
            # (input_channels, D_mul, M * N)
            D = self.D + self.d_diag
            W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))

            # einsum outputs (out_channels // groups, in_channels, M * N),
            # which is reshaped to
            # (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
        else:
            # in this case D_mul == M * N
            # reshape from
            # (out_channels, in_channels // groups, D_mul)
            # to
            # (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(self.W, DoW_shape)
        return self._conv_forward(input, DoW)


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse
_pair = _ntuple(2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1615716.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

52 文本预处理【动手学深度学习v2】

将文本作为字符串加载到内存中。 将字符串拆分为词元&#xff08;如单词和字符&#xff09;。 建立一个词表&#xff0c;将拆分的词元映射到数字索引;将文本转换为数字索引序列&#xff0c;方便模型操作。

清华新突破,360°REA重塑多智能体系统:全方位提升复杂任务表现

引言&#xff1a;多智能体系统的新篇章——360REA框架 在多智能体系统的研究领域&#xff0c;最新的进展揭示了一种全新的框架——360REA&#xff08;Reusable Experience Accumulation with 360 Assessment&#xff09;。这一框架的提出&#xff0c;不仅是对现有系统的一次重大…

模块三——二分:704.二分查找

文章目录 前言二分查找算法简介特点学习中的侧重点算法原理模板 题目描述算法原理解法一&#xff1a;暴力解法解法二&#xff1a;二分查找算法算法流程细节问题循环结束的条件为什么是正确的&#xff1f;时间复杂度 代码实现 前言 本系列博客是逐渐深入的过程&#xff0c;建议…

函数的内容

一&#xff0c;概念 封装一份可以被重复执行的代码块&#xff0c;让大量代码重复使用 二&#xff0c;函数使用 大体分两步&#xff1a;声明函数&#xff0c;调用函数 声明函数有关键字&#xff1a;function 函数名&#xff08;&#xff09;{ 函数体 } 为基本格式&#xf…

代码随想录算法训练营第四十八天| 198.打家劫舍,213.打家劫舍II,337.打家劫舍III

题目与题解 198.打家劫舍 题目链接&#xff1a;198.打家劫舍 代码随想录题解&#xff1a;​​​​​​​198.打家劫舍 视频讲解&#xff1a;动态规划&#xff0c;偷不偷这个房间呢&#xff1f;| LeetCode&#xff1a;198.打家劫舍_哔哩哔哩_bilibili 解题思路&#xff1a; 这道…

阿里巴巴Java开发规范——编程规约(3)

# 阿里巴巴Java开发规范——编程规约&#xff08;3&#xff09; 编程规约 &#xff08;四&#xff09; OOP规约 1.【强制】构造方法里面禁止加入任何业务逻辑&#xff0c;如果有初始化逻辑&#xff0c;请放在 init 方法中 这条编程规范的目的是为了保持代码的清晰性、可读性…

非计算机专业考软考高项有必要吗?

我认为这非常重要。 看了你的介绍&#xff0c;如果你已经考取了会计证书&#xff0c;而且想要考取计算机专业的证书&#xff0c;或者你的职业规划涉及到计算机岗位&#xff0c;又或者你对计算机感兴趣&#xff0c;我建议你优先考虑软考&#xff0c;因为这个证书的含金量是有保…

问题带来多少成长,看你挖得有多深多痛

原文: 一次Redis访问超时的“捉虫”之旅 力是相互的&#xff0c;成长与痛苦也是相互的。 01-引言 最近在对一个老项目使用的docker镜像版本升级过程中碰到一个奇怪的问题&#xff0c;发现项目升级到高版本镜像后&#xff0c;访问Redis会出现很多超时错误&#xff0c;而降回之…

【数学建模】虫子追击问题(仿真)

已知 有四个虫子,分别是 A , B , C , D A,B,C,D A,B,C,D A , B , C , D A,B,C,D A,B,C,D分别在 ( 0 , 0 ) , ( 0 , 1 ) , ( 1 , 1 ) , ( 1 , 0 ) (0,0),(0,1),(1,1),(1,0) (0,0),(0,1),(1,1),(1,0)四个虫子A追B&#xff0c;B追C&#xff0c;C追D&#xff0c;D追A四个速度相同 …

XTuner 微调 LLM:1.8B、多模态、Agent——笔记

XTuner 微调 LLM&#xff1a;1.8B、多模态、Agent——笔记 一、Finetune 简介1.1、两种 Finetune 范式1.2、一条数据的一生1.2.1、标准格式数据1.2.2、添加对话模板1.2.3、LoRA & QLoRA 二、XTuner2.1、XTuner 简介2.2、LLaMA-Factory vs XTuner2.3、XTuner 数据引擎2.3.1、…

【InternLM 实战营第二期笔记04】XTuner微调LLM:1.8B、多模态、Agent

一、微调的原因 大模型微调&#xff08;Fine-tuning&#xff09;的原因主要有以下几点&#xff1a; 适应特定任务&#xff1a;预训练的大模型往往是在大量通用数据上训练的&#xff0c;虽然具有强大的表示学习能力&#xff0c;但可能并不直接适用于特定的下游任务。通过微调&…

低代码新时代:6款免费开发平台助你畅行编码之路

本篇文章为您介绍的六款免费又好用的低代码开发平台有&#xff1a;Zoho creator、Baserow、OS.bee、nuBuilder、JHipster、Appian。 一、Zoho creator Zoho Creator是一款国际化的低代码开发平台&#xff0c;有超17年低代码经验。近些年&#xff0c;Zoho Creator以其成本低、国…

一键还原精灵 V12.1.405.701 装机版

网盘下载 个人版&#xff1a;不划分分区不修改分区表及MBR&#xff0c;安装非常安全&#xff0c;备份文件自动隐藏&#xff0c;不适用于WIN98系统。 装机版&#xff1a;需用PQMAGIC划分分区作隐藏的备份分区&#xff0c;安装过程中有一定的风险&#xff0c;安装后就非常安全。…

基于SpringBoot的宠物领养网站管理系统

基于SpringBootVue的宠物领养网站管理系统的设计与实现~ 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringBootMyBatis工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 主页 宠物领养 宠物救助站 宠物论坛 登录界面 管理员界面 摘要 基于Spr…

[RTOS 学习记录] 复杂工程项目的管理

[RTOS 学习记录] 复杂工程项目的管理 这篇文章是我阅读《嵌入式实时操作系统μCOS-II原理及应用》后的读书笔记&#xff0c;记录目的是为了个人后续回顾复习使用。 前置内容&#xff1a; 工程管理工具make及makefile 文章目录 1 批处理文件与makefile的综合使用1.1 批处理文件…

[第一届 帕鲁杯 CTF挑战赛 2024] Crypto/PWN/Reverse

被一个小题整坏了&#xff0c;后边就没认真打。赛后把没作的复盘一下。 比赛有52个应急响应&#xff0c;猜是取证&#xff0c;都是队友在干&#xff0c;我也不大关心。前边大多题是比赛的原题。这是后来听说的&#xff0c;可都没见过&#xff0c;看来打的比赛还是少了。 Cryp…

基于RT-Thread摄像头车牌图像采集系统

一、项目简介 使用基于RT-thread操作系统的AB32VG1开发板作为主控&#xff0c;对ov7670摄像头进行图像采集&#xff0c;并使用串口发送图片RGB565格式到PC供opencv进行图像识别。 原项目设想在开发板上进行采集的同时并通过简单的二值算法和插值算法实现车牌号识别&#xff0c…

Obsidian 快速安装

看网上Obsidian 很好用&#xff0c;但自己下载总是中断&#xff0c;烦的要死&#xff0c;一度以为要开魔法…… 直到我找到了这个网站Thoughts (teambition.com) yeah~ 亲测有效&#xff0c;大概不到2min吧. 快速开始~&#xff0c;成功水了一片

(二)Servlet教程——我的第一个Java程序

首先打开记事本&#xff0c;输入如下的代码&#xff0c;请注意字母的大小写 public class MyFirst{ public static void main(String[] args){ System.out.println("This is My first Java..."); } } 将该txt文件命名为MyFirst.java 打开cmd命令行窗口&#xff0…

【STM32】嵌入式实验二 GPIO 实验:数码管

实验内容&#xff1a; 编写程序&#xff0c;在数码管上显示自己的学号。 数码管相关电路&#xff1a; PA7对应的应该是段码&#xff0c;上面的图写错了。 注意&#xff1a;选中数码管是低电平选中&#xff1b;并且用74HC595模块驱动输出的段码&#xff0c; 这个模块的学习可以…