【目标检测实验系列】YOLOv5/YOLOv8改进:CARAFE轻量级上采样算子,聚合上下文信息,助力模型涨点(文内附源码)

news2024/9/20 0:51:15

1. 文章主要内容

       本篇博客主要涉及轻量级上采样算子CARAFE,将YOLOv5/YOLOv8模型中最近邻上采样算子改为CARAFE算子,使模型聚合上下文信息,助力模型涨点。

2. 简要概括

       论文地址:CARAFE论文地址
       论文Github代码:Github代码

       CARAFE具有以下特点:
       1.感受野大。不同于以往只利用亚像素邻域的工作(如双线性插值),CARAFE可以在一个大的接收域中聚合上下文信息。
       2.内容感知。CARAFE不是为所有的样本使用一个固定的内核(例如反卷积,也就是transposed conv),而是支持特定于实例的内容感知处理,它可以动态地生成自适应的内核。
       3.轻量级、计算速度快。CARAFE引入了很少的计算开销,可以很容易地集成到现有的网络架构中。其结构如下图所示:
在这里插入图片描述
       分析:CARAFE作为上采样算子,其轻量级、计算资源小并且高效的特点能够助力YOLOv5/YOLOv8模型涨点。YOLOv5/YOLOv8模型的金字塔结构,也就是Neck结构部分,是需要通过上采样的算子将图片的分辨率进行扩大,以便层级的进行融合,从而融合层级的特征。然而,在上采样的过程中,原模型使用的是最近邻的静态上采样方法,容易丢失信息,尤其是小目标的信息特征(因为小目标本来在图像中的占比像素就比较少),这样就会导致模型识别目标的精度会降低,出现漏检、误检等问题。而CARAFE算子动态的进行上采样,而且有一个大的感受野聚合上下文信息,从而能够抑制小目标信息过多丢失,助力模型涨点!

3. 详细代码改进流程

       接下来手把手将CARAFE算子添加到YOLOv5/YOLOv8模型中某一个地方的全实验过程。首先给出CARAFE算子的源码,新建一个CARAFE.py存放源代码,如下所示:

import torch
from torch import nn

from models.common import Conv


class CARAFE(nn.Module):
    def __init__(self, c, k_enc=3, k_up=5, c_mid=64, scale=2):
        """ The unofficial implementation of the CARAFE module.
        The details are in "https://arxiv.org/abs/1905.02188".
        Args:
            c: The channel number of the input and the output.
            c_mid: The channel number after compression.
            scale: The expected upsample scale.
            k_up: The size of the reassembly kernel.
            k_enc: The kernel size of the encoder.
        Returns:
            X: The upsampled feature map.
        """
        super(CARAFE, self).__init__()
        self.scale = scale

        self.comp = Conv(c, c_mid)
        self.enc = Conv(c_mid, (scale * k_up) ** 2, k=k_enc, act=False)
        self.pix_shf = nn.PixelShuffle(scale)

        self.upsmp = nn.Upsample(scale_factor=scale, mode='nearest')
        self.unfold = nn.Unfold(kernel_size=k_up, dilation=scale,
                                padding=k_up // 2 * scale)

    def forward(self, X):
        b, c, h, w = X.size()
        h_, w_ = h * self.scale, w * self.scale

        W = self.comp(X)  # b * m * h * w
        W = self.enc(W)  # b * 100 * h * w
        W = self.pix_shf(W)  # b * 25 * h_ * w_
        W = torch.softmax(W, dim=1)  # b * 25 * h_ * w_

        X = self.upsmp(X)  # b * c * h_ * w_
        X = self.unfold(X)  # b * 25c * h_ * w_
        X = X.view(b, c, -1, h_, w_)  # b * 25 * c * h_ * w_

        X = torch.einsum('bkhw,bckhw->bchw', [W, X])  # b * c * h_ * w_
        return X


       然后分别来讲融合到YOLOv5和YOLOv8模型的教程,这里先从YOLOv5模型开始。

3.1 YOLOv5融合CARAFE算子

3.1.1 新建一个yolov5-CARAFE.yaml文件

       然后,新建一个yolov5-CARAFE.yaml文件,同时 注意nc改为自己数据集的类别数这里是将CARAFE算子替换了Neck部分的两个上采样部分。yaml文件源代码如下所示:

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 4  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8  小目标
  - [30,61, 62,45, 59,119]  # P4/16 中目标
  - [116,90, 156,198, 373,326]  # P5/32  大目标

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2  output_channel, kernel_size, stride, padding
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, CARAFE, [3, 5]],
   #[-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, CARAFE, [3, 5]],
   #[-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
  
   [[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

3.1.2 将CARAFE引入到yolo.py文件中

       在下图的红色圈内位置处,引入CARAFE,并手动导入相应的包即可。代码和示意图如下:

 elif m is CARAFE:
            c2 = ch[f]
            args = [c2, *args]

在这里插入图片描述

3.1.3 修改train.py启动文件

       修改配置文件为yolov5-CARAFE.yaml即可,如下图所示:
在这里插入图片描述

3.2 YOLOv8融合CARAFE算子

3.2.1 新建一个yolov8-CARAFE.yaml文件

       然后,新建一个yolov8-CARAFE.yaml文件,同时 注意nc改为自己数据集的类别数这里是将CARAFE算子替换了Neck部分的两个上采样部分。yaml文件源代码如下所示:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 10 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv8.0n head
head:
  - [-1, 1, CARAFE, []]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12

  - [-1, 1, CARAFE, []]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)

3.2.2 将CARAFE源码拷贝到conv.py中并在__init__.py全局注册

       conv.py的路径在这里:ultralytics/nn/modules/conv.py
       init.py的路径在这里:ultralytics/nn/modules/init.py
       第一步:将源码拷贝到conv.py的随便一处,然后在conv.py的最上面有个__all__的地方注册CARAFE引用,如下图所示:
在这里插入图片描述
       第二步:在__init__.py中的.conv中import CARAFE算子,然后同时在此文件的最下放的__all__处也注册CARAFE算子,分别如下图所示:
在这里插入图片描述
在这里插入图片描述

3.2.3 将CARAFE引入到task.py中,并添加相应的代码

       将CARAFE的逻辑代码引入到task.py中,并导入相应的包,这一步其实和YOLOv5中的yolo.py有异曲同工之妙。另外task.py的路径在这里:ultralytics/nn/tasks.py,代码和操作如下所示:

        elif m is CARAFE:
            c1 = ch[f]
            args = [c1]

在这里插入图片描述
在这里插入图片描述

3.2.4 启动train.py即可

       将train.py中训练YOLOv8的模型文件改为YOLOv8n-CARAFE.yaml即可。这里需要注意到,我在YOLOv8后面加了一个n,程序会自动识别调用n的模型大小,我们的yaml文件不需要重命名后面多加一个n。

4. 总结

       本篇博客主要介绍了CARAFE轻量级上采样算子,聚合上下文信息,助力YOLOv5/YOLOv8模型涨点。另外,在修改过程中,要是有任何问题,评论区交流;如果博客对您有帮助,请帮忙点个赞,收藏一下;后续会持续更新本人实验当中觉得有用的点子,如果很感兴趣的话,可以关注一下,谢谢大家啦!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2035254.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Go语言 Defer(延迟)

本文主要内容为Go语言中defer(延迟)介绍及应用文件读取使用defer的示例。 目录 定义 应用场景 代码示例 改为匿名函数 总结 定义 延迟:关键字,可以用于修饰语句、函数, 确保这条语句可以在当前栈退出的时候执行。 应用场景 1.一般用于…

【leetcode】特殊数组I【(炒鸡)简单】

好像这题没啥子好说的欸&#xff0c;那就祝点进来的友友今天有好事发生叭~ AC代码见下&#xff1a; class Solution { public:bool isArraySpecial(vector<int>& nums) {for(int i1; i<nums.size(); i)if(nums[i]%2 nums[i-1]%2) return false;return true;} }…

如何妙用哈希表来优化遍历查找过程?刷题感悟总结,c++实现

先上题目 题目链接&#xff1a;题目链接 这题我最先想到的就是前缀和a&#xff0c;构造好了以后就遍历每一个[l,r]数组&#xff08;满足题目要求的连续区间数组&#xff09;&#xff0c;奈何倒数第二个样例时间超限 先给出原思路代码 class Solution { public:int subarray…

网络如何发送一个数据包

网络如何发送一个数据包 网络消息发送就是点一点屏幕。 骚瑞&#xff0c;这一点都不好笑。&#xff08;小品就是我的本质惹&#xff09; 之前我就是会被这个问题搞的不安宁。是怎么知道对方的IP地址的呢&#xff1f;怎么知道对方的MAC呢&#xff1f;世界上计算机有那么多&…

top250的电影

本次的电影排行来源于豆瓣。材料仅用于自身学习和记录自己学习过程 使用python中的requests、BeautifulSoup、xlwt&#xff0c;三者需要提前下载好。。 预处理&#xff1a; url&#xff1a;反应网页变化 其中start后面的数字变化每次加25&#xff0c;对应一页&#xff0c;故…

用exceljs和file-saver插件实现纯前端表格导出Excel(支持样式配置,多级表头)

exceljs在Jquery&#xff08;HTML&#xff09;和vue项目中实现导出功能 前言Jquery&#xff08;HTML&#xff09;中实现导出第一步&#xff0c;先在项目本地中导入exceljs和file-saver包第二步&#xff0c;封装导出Excel方法&#xff08;可直接复制粘贴使用&#xff09;第三步&…

JJ音乐,听歌自由!

林俊杰&#xff0c;这位才华横溢的音乐才子&#xff0c;用他的音符编织了一个又一个令人陶醉的梦幻世界。作为他的音乐爱好者&#xff0c;每一次倾听都是一次心灵的旅程。 他的歌声仿佛有一种魔力&#xff0c;能够穿透灵魂。从《江南》的诗意浪漫&#xff0c;到《不为谁而作的歌…

探索树莓派Pico 2:新一代RP2350芯片引领的微型开发革命

Raspberry Pi Pico 2 是由树莓派基金会推出的微处理器开发板&#xff0c;作为Pico系列的最新成员&#xff0c;它在原有的基础上进行了多项改进和扩展。这款开发板搭载了全新的RP2350芯片&#xff0c;具有更强大的处理能力和更多的功能特性。 1. Raspberry Pi Pico 2的特性和规格…

使用CUbeMX配置STM32F103C8T6 CRC校验

一、CubeMX配置 1.配置RCC 2.配置SYS 3.启用CRC校验 二、Keil添加程序 1.main.c /* USER CODE BEGIN Header */ /********************************************************************************* file : main.c* brief : Main program body*******…

LVGL——(4)标签控件

文章目录 一、介绍二、用法1、创建2、显示文本2.1 直接设置要显示的文本2.2 格式化给定要显示的文本2.3 在 label 中进行换行 3、改变字体大小4、长模式5、文本选择6、文本对齐方式7、非常长的文本8、显示内置图标字体9、事件处理 三、拓展1、修改文本颜色1.1 Palette&#xff…

研0 冲刺算法竞赛 day30 P1102 A-B 数对

P1102 A-B 数对 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 思路&#xff1a; ①map&#xff0c;键值对计数&#xff0c;将A-B->A-C ②先排序&#xff0c;找对应差值为C的第一个和最后一个计数 代码&#xff1a; #include<iostream> #include <map> #i…

Typora绿色版

1、下载安装 Typora 官网地址&#xff1a;https://typora.io/ 中文站地址&#xff1a;Typora 2、击活 Typora 鼠标右击文件所在位置查询 resources\page-dist\static\js\LicenseIndex.180dd4c7.4da8909c.chunk.chunk.js e.hasActivated"true"e.hasActivated, 替…

使用nvm切换Node.js版本

一、安装nvm nvm&#xff08;Node Version Manager&#xff09;是一个用于管理Node.js版本的工具&#xff0c;它允许你在同一台机器上安装和切换多个Node.js版本。 1.安装nvm https://github.com/coreybutler/nvm-windows 访问以上链接到github去下载 点击releases 下载下图…

优化if-else的几种方式

优化if-else的几种方式 策略模式1、创建支付策略接口2、书写不同的支付方式逻辑代码微信支付QQ支付 3、service层的实现类使用4、controller层的调用说明 枚举与策略模式结合1、创建枚举2、service层书写处理方法3、controller层调用4、说明 Lambda表达式与函数接口说明 策略模…

用于理解视频的基础视觉编码器VideoPrism

人工智能咨询培训老师叶梓 转载标明出处 如何让机器有效地理解和处理视频内容&#xff0c;一直是计算机视觉领域的一个挑战。最近&#xff0c;Google Research的研究人员提出了一种名为VideoPrism的新型视频编码器&#xff0c;旨在通过单一的冻结模型处理多样化的视频理解任务。…

风云崛起之拉氏变换和拉式逆变换

图像的分割写出来了&#xff0c;但是写的不好&#xff0c;暂时先不发了。这两天小y想在把拉式变换的内容写出来&#xff0c;小y最近再看信号和电路&#xff0c;需要复习数学&#xff0c;所以把这点写出来。 首先要推出分布积分的公式&#xff0c;我们知道积分和微分为逆运算&am…

纯css实现多行文本右下角最后一行展示全部按钮

未展开全部&#xff1a; 展开全部&#xff1a; 综上演示按钮始终保持在最下方 css代码如下&#xff1a; <div class"info-content"><div class"info-text" :class"!showAll ? mle-hidden : "><span class"show-all"…

STM32-定时器-定时器中断-PWM调光

1、TIM 定时器 定时器是一种电子设备或软件组件&#xff0c;用于在预定时间后触发一个事件或操作。它可以基于时钟信号或其他周期性信号来工作&#xff0c;并且可以用来测量时间间隔、生成延时、触发中断等。 时钟信号 时钟信号是一种周期性的电信号&#xff0c;用于同步电路中…

如何检查端口占用:netstat和lsof指令

在网络故障排查和系统管理中&#xff0c;检查端口占用情况是一项常见且重要的任务。本文将详细介绍如何使用 netstat 和 lsof 这两个强大的工具来检查端口占用和相关服务。 1. 使用 netstat 查看端口占用 netstat (network statistics) 是一个用于显示网络连接、路由表、接口…

Flutter 学习 一部分注意点记录

使用AndroidStudio进行开发 假设你已经配置好Flutter和dart的SDK. 创建一个可执行dart文件 如果需要直接新建一个dart文件来运行&#xff0c;可以点击 File->New->New Flutter Project &#xff0c;下面是接下来弹出的新建项目弹窗&#xff0c;选中左边的Dart&#xff…