YOLOv5、YOLOv8改进:ConvNeXt(backbone改为ConvNextBlock)

news2024/11/23 22:32:31

 

目录

 

1.介绍

2. YOLOv5修改backbone为ConvNeXt

2.1修改common.py

2.2 修改yolo.py

2.3修改yolov5.yaml配置


1.介绍

  • 论文地址:https://arxiv.org/abs/2201.03545
  • 官方源代码地址:https://github.com/facebookresearch/ConvNeXt.git

自从ViT(Vision Transformer)在CV领域大放异彩,越来越多的研究人员开始拥入Transformer的怀抱。回顾近一年,在CV领域发的文章绝大多数都是基于Transformer的,比如2021年ICCV 的best paper Swin Transformer,而卷积神经网络已经开始慢慢淡出舞台中央。卷积神经网络要被Transformer取代了吗?也许会在不久的将来。今年(2022)一月份,Facebook AI Research和UC Berkeley一起发表了一篇文章A ConvNet for the 2020s,在文章中提出了ConvNeXt纯卷积神经网络,它对标的是2021年非常火的Swin Transformer,通过一系列实验比对,在相同的FLOPs下,ConvNeXt相比Swin Transformer拥有更快的推理速度以及更高的准确率,在ImageNet 22K上ConvNeXt-XL达到了87.8%的准确率,参看下图(原文表12)。看来ConvNeXt的提出强行给卷积神经网络续了口命。

ConvNeXt是一种由Facebook AI Research和UC Berkeley共同提出的卷积神经网络模型。它是一种纯卷积神经网络,由标准卷积神经网络模块构成,具有精度高、效率高、可扩展性强和设计非常简单的特点。ConvNeXt在2022年的CVPR上发表了一篇论文,题为“面向2020年代的卷积神经网络”。ConvNeXt已在ImageNet-1K和ImageNet-22K数据集上进行了训练,并在多个任务上取得了优异的表现。ConvNeXt的训练代码和预训练模型均已在GitHub上公开。
ConvNeXt是基于ResNet50进行改进的,其与Swin Transformer一样,具有4个Stage;不同的是ConvNeXt将各Stage中Block的数量比例从3:4:6:3改为了与Swin Transformer一样的1:1:3:1。 此外,在进行特征图降采样方面,ConvNeXt采用了与Swin Transformer一致的步长为4,尺寸为4×4的卷积核。
ConvNeXt的优点包括:
ConvNeXt是一种纯卷积神经网络,由标准卷积神经网络模块构成,具有精度高、效率高、可扩展性强和设计非常简单的特点。
ConvNeXt在ImageNet-1K和ImageNet-22K数据集上进行了训练,并在多个任务上取得了优异的表现。
ConvNeXt采用了Transformer网络的一些先进思想对现有的经典ResNet50/200网络做一些调整改进,将Transformer网络的最新的部分思想和技术引入到CNN网络现有的模块中从而结合这两种网络的优势,提高CNN网络的性能表现.
ConvNeXt的缺点包括:
ConvNeXt并没有在整体的网络框架和搭建思路上做重大的创新,它仅仅是依照Transformer网络的一些先进思想对现有的经典ResNet50/200网络做一些调整改进.
ConvNeXt相对于其他CNN模型而言,在某些情况下需要更多计算资源.

2. YOLOv5修改backbone为ConvNeXt

2.1修改common.py

将以下代码,添加进common.py。

############## ConvNext ##############
import torch.nn.functional as F
class LayerNorm_s(nn.Module):

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class ConvNextBlock(nn.Module):

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm_s(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)


def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class CNeB(nn.Module):
    # CSP ConvNextBlock with 3 convolutions by iscyy/yoloair
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)
        self.m = nn.Sequential(*(ConvNextBlock(c_) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
############## ConvNext ##############

2.2 修改yolo.py

        if m in [Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,
                 BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, CNeB]:

 

2.3修改yolov5.yaml配置

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, CNeB, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, CNeB, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, CNeB, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, CNeB, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, CNeB, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, CNeB, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, CNeB, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

后续会更新v7、v8版本

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1044407.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux 权限相关例题练习

目录 一、前期准备工作: 1)新建redhat用户 2)新建testdir目录及其file1 二、例题详解 1、当用户redhat对/testdir目录无写权限时,该目录下的只读文件file1是否可修改和删除? 2、复制/etc/fstab文件到/var/tmp下&…

C++11之新的类功能

这里写目录标题 新的类功能默认成员函数类成员变量初始化强制生成默认函数的关键字default禁止生成默认函数的关键字deletefinal与override关键字override 新的类功能 默认成员函数 原来C类中,有6个默认成员函数: 构造函数析构函数拷贝构造函数拷贝赋…

Spring学习笔记13 Spring对事务的支持

Spring学习笔记12 面向切面编程AOP-CSDN博客 什么是事务:在一个业务流程当中,通常需要多条DML(insert delete update)语句共同联合才能完成,这多条DML语句必须同时成功,或者同时失败,这样才能保证数据的安全. 多条DML要么同时成功,要么同时失败,叫做事务(Transaction) 事务四…

安防视频平台EasyCVR视频调阅全屏播放显示异常是什么原因?

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快,可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等,以及支持厂家私有协议与SDK接入,包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…

Apache DolphinScheduler在中国信通院“2023 OSCAR开源尖峰案例”评选中荣获「尖峰开源项目奖」!

在近日由中国信息通信研究院(以下简称“中国信通院”)和中国通信标准化协会联合主办的“2023 OSCAR 开源产业大会”上,主办方公布了 2023 年“OSCAR 开源尖峰案例”评选结果,包括“开源人物”“开源项目”“开源社区”“开源企业”…

python+vue实验室课程预约管理系统

实验室课程管理系统运用计算机完成数据收集、查询、修改和删除以及统计等工作,提高了管理者工作效率,避免了因信息量巨大,造成的人为错误.通过前面的功能分析可以将实验室课程管理系统的功能分为管理员、学生和教师三个部分&#…

Linux高性能服务器编程 学习笔记 第八章 高性能服务器程序框架

TCP/IP协议在设计和实现上没有客户端和服务器的概念,在通信过程中所有机器都是对等的。但由于资源(视频、新闻、软件等)被数据提供者所垄断,所以几乎所有网络应用程序都采用了下图所示的C/S(客户端/服务器)…

LeetCode_BFS_中等_1926.迷宫中离入口最近的出口

目录 1.题目2.思路3.代码实现(Java) 1.题目 给你一个 m x n 的迷宫矩阵 maze (下标从 0 开始),矩阵中有空格子(用 ‘.’ 表示)和墙(用 ‘’ 表示)。同时给你迷宫的入口 …

电脑提示vcruntime140.dll缺失重新安装的修复方法

电脑出现 vcruntime140.dll 丢失的情况,通常是由于系统缺失了 Microsoft Visual C Redistributable 的运行库文件。这个文件是许多应用程序在运行时所需的依赖库,如果丢失了该文件,可能会导致某些软件无法正常运行。 下面是关于 vcruntime140…

免费录音软件推荐,告别杂音,音质更清晰!

“求推荐一款免费的录音软件!最近下载了好多的录音软件,不是音质太差,就是需要收费解锁新的功能,根本不好用,有没有人知道一款免费优秀的录音软件呀,告诉我一下。” 录音已成为现代人们学习和工作中的一项…

DataExcel控件读取和保存excel xlsx 格式文件

需要引用NPOI库 https://github.com/dotnetcore/NPOI 调用Read 函数将excel读取到dataexcel控件 调用Save 函数将dataexcel控件文件保存为excel文件 using NPOI.HSSF.UserModel; using NPOI.HSSF.Util; using NPOI.SS.UserModel; using NPOI.SS.Util; using System; using …

pytho实例--pandas读取表格内容

前言:由于运维反馈帮忙计算云主机的费用,特编写此脚本进行运算 如图,有如下excel数据 计算过程中需用到数据库中的数据,故封装了一个读取数据库的类 import MySQLdb from sshtunnel import SSHTunnelForwarderclass SSHMySQL(ob…

Java BigDecimal 详解

目录 一、BigDecimal 1、简介 2、构造器描述 3、方法描述 4、使用 一、BigDecimal float和double类型的主要设计目标是为了科学计算和工程计算。他们执行二进制浮点运算,这是为了在广域数值范围上提供较为精确的快速近似计算而精心设计的。然而,它…

25841-2017 1000kV电力系统继电保护技术导则

声明 本文是学习GB-T 25841-2017 1000kV电力系统继电保护技术导则. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了交流1000 kV 系统及1000 kV 变电站相关电压等级具有特别要求的继电保护装置 的基本准则。 本标准适用于1000 k…

大厂秋招真题【BFS+DP】华为20230921秋招T3-PCB印刷电路板布线【欧弟算法】全网最全大厂秋招题解

题目描述与示例 题目描述 在PCB印刷电路板设计中,器件之间的连线,要避免线路的阻抗值增大,而且器件之间还有别的器任和别的干扰源,在布线时我们希望受到的干扰尽量小。 现将电路板简化成一个M N的矩阵,每个位置&am…

如何快速学习AdsPower RPA(1)——简单、进阶部分

你是否刚开始学习使用AdsPower的RPA功能? 你是否对着这些操作选项头皮发麻,不知所措? 你是否想快速学会RPA? 你是否想编写出满足各种业务场景的RPA流程? 以上这些,Tool哥统统都帮你搞定! Too…

科技成果鉴定测试有多重要?可出具专业测试报告的软件测评机构推荐

科技成果鉴定测试在现代社会中具有重要意义,它不仅可以评估科技成果的价值和可行性,还可以为科技创新提供决策依据,推动科技进步和社会发展,那么科技成果鉴定测试究竟重要在哪呢? 1、对于科技项目的投资决策至关重要。鉴定测试可…

YOLOV8-DET转ONNX和RKNN

目录 1. 前言 2.环境配置 (1) RK3588开发板Python环境 (2) PC转onnx和rknn的环境 3.PT模型转onnx 4. ONNX模型转RKNN 6.测试结果 1. 前言 yolov8就不介绍了,详细的请见YOLOV8详细对比,本文章注重实际的使用,从拿到yolov8的pt检测模型&…

玩转gpgpu-sim 04记—— __cudaRegisterBinary() of gpgpu-sim 到底做了什么

官方文档&#xff1a; GPGPU-Sim 3.x Manual __cudaRegisterBinary(void*) 被执行到的代码逻辑如下&#xff1a; void** CUDARTAPI __cudaRegisterFatBinary( void *fatCubin ) { #if (CUDART_VERSION < 2010)printf("GPGPU-Sim PTX: ERROR ** this version of GPGPU…

查看Linux系统信息的常用命令

文章目录 1. 机器配置查看2. 常用分析工具3. 常用指令解读3.1 lscpu 4. 定位僵尸进程5. 参考 1. 机器配置查看 # 总核数物理CPU个数x每颗物理CPU的核数 # 总逻辑CPU数物理CPU个数x每颗物理CPU的核数x超线程数 cat /proc/cpuinfo| grep "physical id"| sort| uniq| w…