YOLOv5改进 | Head | 将yolov5的检测头替换为ASFF_Detect

news2025/2/26 5:56:59

💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡

在目标检测中,为了解决尺度变化的问题,通常采用金字塔特征表示。然而,对于基于特征金字塔的单次检测器来说,不同特征尺度之间的不一致性是一个主要限制。为此,研究人员提出了一种新颖的、基于数据的策略,用于金字塔特征融合,称为自适应空间特征融合(ASFF)。它学习了一种方法,用以在空间上过滤冲突信息,从而抑制不一致性,提高了特征的尺度不变性,并且几乎不引入额外的推理开销。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址 YOLOv5改进+入门——持续更新各种有效涨点方法 点击即可跳转

目录

1.原理

2. 将ASFF_DETECT代码实现

2.1 ASFF_DETECT添加到YOLOv5中

2.2 新增yaml文件

2.3 注册模块

2.4 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6. 总结


1.原理

论文地址:Learning Spatial Fusion for Single-Shot Object Detection——点击即可跳转

官方代码:官方代码仓库——点击即可跳转

自适应空间特征融合(ASFF)的主要原理旨在解决单次检测器中不同尺度特征的不一致性问题。具体来说,ASFF通过动态调整来自不同尺度特征金字塔层的特征贡献,确保每个检测对象的特征表示是一致且最优的。以下是ASFF的主要原理:

原理概述

  1. 多尺度特征的融合

    • 传统的特征金字塔网络(FPN)在不同尺度上提取特征,但这些特征在空间位置上可能存在不一致性,导致检测效果不佳。

    • ASFF通过一个自适应融合模块,动态地结合来自不同尺度的特征图,使得每个像素点能够获得来自各个尺度的最优特征表示。

  2. 自适应权重学习

    • ASFF在训练过程中通过一个轻量级的网络结构(如1x1卷积层)学习自适应权重,这些权重用于加权组合来自不同尺度的特征。

    • 这个学习过程是自适应的,即权重会根据输入图像的特征和目标物体的位置进行调整,从而确保融合后的特征在空间和语义上都是最优的。

  3. 特征一致性

    • 通过自适应权重,ASFF能有效地调节各尺度特征的贡献,解决了特征金字塔中不同层次特征之间的空间位置不一致性问题。

    • 这种融合方式不仅增强了特征的一致性,还提高了检测器对各种尺度目标的响应能力。

具体步骤

  1. 特征提取

    输入图像通过基础卷积神经网络(如ResNet)提取特征,并通过特征金字塔网络(FPN)生成不同尺度的特征图。
  2. 权重生成

    对每个尺度的特征图,ASFF使用一个小型网络(如1x1卷积层)生成对应的自适应权重图。
  3. 特征融合

    将不同尺度的特征图与其对应的权重图逐像素相乘,然后进行加权求和,生成最终的融合特征图。
  4. 检测输出

    最终的融合特征图输入到检测头中,生成检测结果(如边界框和类别预测)。

优势

  • 性能提升:通过自适应融合不同尺度的特征,ASFF显著提升了检测精度,特别是在复杂场景和多尺度目标检测任务中。

  • 高效性:ASFF在提高性能的同时,保持了较低的计算开销,仅增加了极少的推理时间,适合实时应用。

ASFF的方法通过动态调整特征贡献,确保每个像素点在不同尺度特征上的最优组合,从而提高了单次检测器的整体检测性能。

2. 将ASFF_DETECT代码实现

2.1 ASFF_DETECT添加到YOLOv5中

 关键步骤一:将下面代码粘贴到/yolov5-6.1/models/yolo.py文件中

class ASFF_Detect(nn.Module):   #add ASFFV5 layer and Rfb 
    stride = None  # strides computed during build
    onnx_dynamic = False  # ONNX export parameter
    export = False  # export mode

    def __init__(self, nc=80, anchors=(), ch=(), multiplier=0.5,rfb=False,inplace=True):  # detection layer
        super().__init__()
        self.nc = nc  # number of classes
        self.no = nc + 5  # number of outputs per anchor
        self.nl = len(anchors)  # number of detection layers
        self.na = len(anchors[0]) // 2  # number of anchors
        self.grid = [torch.zeros(1)] * self.nl  # init grid
        self.l0_fusion = ASFFV5(level=0, multiplier=multiplier,rfb=rfb)
        self.l1_fusion = ASFFV5(level=1, multiplier=multiplier,rfb=rfb)
        self.l2_fusion = ASFFV5(level=2, multiplier=multiplier,rfb=rfb)
        self.anchor_grid = [torch.zeros(1)] * self.nl  # init anchor grid
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))  # shape(nl,na,2)
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)  # output conv
        self.inplace = inplace  # use in-place ops (e.g. slice assignment)

    def forward(self, x):
        z = []  # inference output
        result=[]
       
        result.append(self.l2_fusion(x))
        result.append(self.l1_fusion(x))
        result.append(self.l0_fusion(x))
        x=result      
        for i in range(self.nl):
            x[i] = self.m[i](x[i])  # conv
            bs, _, ny, nx = x[i].shape  # x(bs,255,20,20) to x(bs,3,20,20,85)
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            if not self.training:  # inference
                if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
                    self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

                y = x[i].sigmoid() # https://github.com/iscyy/yoloair
                if self.inplace:
                    y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i]  # xy
                    y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i]  # wh
                else:  # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
                    xy, wh, conf = y.split((2, 2, self.nc + 1), 4)  # y.tensor_split((2, 4, 5), 4)  # torch 1.8.0
                    xy = (xy * 2 + self.grid[i]) * self.stride[i]  # xy
                    wh = (wh * 2) ** 2 * self.anchor_grid[i]  # wh
                    y = torch.cat((xy, wh, conf), 4)
                z.append(y.view(bs, -1, self.no))

        return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
    
    def _make_grid(self, nx=20, ny=20, i=0):
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2  # grid shape
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        if check_version(torch.__version__, '1.10.0'):  # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
            yv, xv = torch.meshgrid(y, x, indexing='ij')
        else:
            yv, xv = torch.meshgrid(y, x)
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5  # add grid offset, i.e. y = 2.0 * x - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        #print(anchor_grid)
        return grid, anchor_grid

2.2 新增yaml文件

关键步骤二在下/yolov5-6.1/models下新建文件 yolov5_ASFF.yaml并将下面代码复制进去

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)

   [[17, 20, 23], 1, ASFF_Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

2.3 注册模块

关键步骤三:在yolo.py中注册,

首先在model的类下面添加下面内容,位置如图所示

if isinstance(m, ASFF_Detect):
            s = 256  # 2x min stride
            m.inplace = self.inplace
            m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))])  # forward
            m.anchors /= m.stride.view(-1, 1, 1)
            check_anchor_order(m)
            self.stride = m.stride
            try:
                self._initialize_biases()  # only run once    
                LOGGER.info('initialize_biases done')
            except:
                LOGGER.info('decoupled no biase ')

 然后修改_profile_one_layer函数下的代码为

c = isinstance(m, Detect) or isinstance(m, ASFF_Detect) # is final layer, copy input as inplace fix

 修改后如下图所示

修改_apply的内容

if isinstance(m, Detect) or isinstance(m, ASFF_Detect):

修改后如下

 在parse_model函数中注册模块

内容如下位置如下

elif m is ASFF_Detect:
            args.append([ch[x] for x in f])
            if isinstance(args[1], int):  # number of anchors
                args[1] = [list(range(args[1] * 2))] * len(f)

2.4 执行程序

在train.py中,将cfg的参数路径设置为yolov5_ASFF.yaml的路径

建议大家写绝对路径,确保一定能找到

  🚀运行程序,如果出现下面的内容则说明添加成功🚀

3. 完整代码分享

https://pan.baidu.com/s/1C98TemcSlia0n4ngAb9guQ?pwd=z6n4

提取码: z6n4 

4. GFLOPs

关于GFLOPs的计算方式可以查看:百面算法工程师 | 卷积基础知识——Convolution

未改进的GFLOPs

改进后的GFLOPs

5. 进阶

现在的代码只能适配yolov5s版本,你能将他们扩展为更大的模型吗?

6. 总结

ASFF检测头的核心在于自适应地融合来自不同尺度的特征,以提高单次检测器的精度和鲁棒性。ASFF检测头首先通过基础卷积神经网络提取输入图像的基本特征,并通过特征金字塔网络(FPN)生成多个尺度的特征图。然后,ASFF模块在每个尺度上使用一个轻量级的网络(例如1x1卷积层)生成自适应权重图,这些权重图用来表示各个尺度特征对最终融合特征的贡献。接下来,不同尺度的特征图与对应的权重图逐像素相乘,再进行加权求和,生成一个融合后的特征图,该特征图在空间和语义上都更加一致。最后,这个融合特征图输入到检测头中,用于生成检测结果,包括物体的边界框和类别预测。通过这种自适应的特征融合方法,ASFF检测头有效地解决了不同尺度特征之间的不一致性问题,显著提高了检测精度,同时保持了较低的计算开销,使其适用于实时应用场景。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1820532.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NestJS学习笔记

一、安装NestJS CLI工具 环境检查 //查看node版本 node -v//查看npm版本 npm -v 安装nest/cli 使用npm全局安装nestjs/cli npm i -g nestjs/cli 查看nest版本 nest -v 结果如图: 创建nest项目 //命令行创建nest项目 nest new 【项目名】 VScode扩展下载 1、…

体验版小程序访问不到后端接口请求失败问题解决方案

文章目录 解决方案一:配置合法域名解决方案二:开发调试模式第一步:进入开发调试模式第二步:启用开发调试 注意事项结语 🎉欢迎来到Java面试技巧专栏~探索Java中的静态变量与实例变量 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&…

银行业信息技术外包(ITO)深度解析:现状、挑战、业务分类与协同策略

一、引言 最近有朋友在咨询关于银行业信息技术外包(ITO)这块业务,同时也在网上看到了关于银行业信息技术外包(ITO)的相关信息,今天正好有时间,通过采集的相关信息结合自己的相关工作接触到的相关…

推挽与开漏输出

一般来说,微控制器的引脚都会有一个驱动电路,可以配置不同类型的数字和模拟电路接口。输出模式一般会有推挽与开漏输出。 推挽输出 推挽输出(Push-Pull Output),故名思意能输出两种电平,一种是推&#xf…

STM32开发过程中碰到的问题总结 - 1

文章目录 前言1. 怎么生成keil下可以使用的文件和gcc下编译使用的makefile2. STM32的时钟树3.怎么查看keil5下的编译工具链用的是哪个4. Arm编译工具链和GCC编译工具链有什么区别吗?5. 怎么查看Linux虚拟机是x86的还是aarch646. 怎么下载gcc-arm的编译工具链7.怎么修…

JavaScript的数组排序

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

css设置滚动条样式;滚动条设置透明

滚动条透明代码 .resizable-div {resize: both;/* 允许水平和垂直调整大小 */overflow: auto;/* 确保内容超出边界时出现滚动条 */ } /* 滚动条整体样式 */ .resizable-div::-webkit-scrollbar {width: 4px; /* 竖直滚动条宽度 */height: 4px; /* 水平滚动条高度 */ }/* 滚动条…

使用uniapp设置tabbar的角标和移除tabbar的角标

使用场景描述 在一进入到小程序的时候就要将用户在购物车中添加的商品总数&#xff0c;要以角标的形式显示在tababr中。 代码实现 //index.vue<script setup> import { onLoad } from dcloudio/uni-apponLoad(()>{uni.setTabBarBadge({index: 1,text: 5 //为了实现…

VBA实现关闭Excel自动计算,关闭屏幕刷新

Excel代码提速神器 涉及到提取表格大量数据操作&#xff0c;复制粘贴多个单元格时&#xff0c;尽量避免一个个单元格提取&#xff0c;或者一行行一列列提取数值&#xff0c;设计大量IO流操作非常浪费时间。尽量找出数据之间的规律&#xff0c;批量选中复制粘贴&#xff0c;找到…

【大数据】Spark使用大全:下载安装、RDD操作、JAVA编程、SQL

目录 前言 1.下载安装 2.RDD操作 3.JAVA编程示例 4.Spark SQL 前言 本文是作者大数据系列中的一文&#xff0c;专栏地址&#xff1a; https://blog.csdn.net/joker_zjn/category_12631789.html?spm1001.2014.3001.5482 该系列会成体系的聊一聊整个大数据的技术栈&…

P3388 【模板】割点(割顶)

题目背景 割点 题目描述 给出一个 n 个点&#xff0c;m 条边的无向图&#xff0c;求图的割点。 输入格式 第一行输入两个正整数 n,m。 下面 m 行每行输入两个正整数 x,y 表示 x 到 y 有一条边。 输出格式 第一行输出割点个数。 第二行按照节点编号从小到大输出节点&am…

实验五 网络中的树

文章目录 5.1 网络中的树第一关 认识树相关知识编程要求代码文件 第2关 根节点的二阶邻居求解方法相关知识编程要求代码文件 第3关 根节点的n阶邻居求解方法相关知识 5.2 权值矩阵与环&#xff08;无向网络&#xff09;第1关 无向网络的权值矩阵相关知识编程要求代码文件 第2关…

CrossOver 2024软件下载-CrossOver 2024详细安装教程

Crossover软件是一款可以在Mac、Linux和Chromebook上运行Windows程序的软件。 它是一款商业软件&#xff0c;由CodeWeavers公司开发&#xff0c;Crossover不是一个虚拟机或模拟器&#xff0c;它使用Wine技术来将Windows程序直接转换成可以在其他操作系统上运行的程序&#xff0…

基于SpringBoot3+Vue3宠物小程序宠物医院小程序的设计与实现

大家好&#xff0c;我是程序员小孟。 最近开发了一个宠物的小程序&#xff0c;含有详细的文档、源码、项目非常的不错&#xff01; 一&#xff0c;系统的技术栈 二&#xff0c;项目的部署教程 前端部署包&#xff1a;npm i 启动程序&#xff1a;npm run dev 注意事项&…

HSP_07章 排序和查找

P96_ 冒泡排序 排序的基本介绍 冒泡排序介绍 冒泡排序思路分析 代码 # 说明,如果只是完成排序功能,我们可以直接使用list的方法sort # 排序的列表 num_list[24,69,80,57,13,0,900,-1] print("排序前".center(32,"-")) print(f"num_list: {num_list}…

Ceph入门到精通-Bucket 生命周期的作用,新手该如何设置?

存储桶(Bucket)生命周期策略的作用主要是帮助存储管理员高效地管理对象的存储周期,包括对象的转换、存档和删除。以下是关于桶生命周期的作用和配置的概述: 一、桶生命周期的作用: 存储优化:通过将对象转换到更经济的存储类别,降低存储成本。 数据管理:自动删除不再需…

交换机简介

一、 集线器的替代品—交换机 使用集线器的缺点&#xff0c;因此就设计出了交换机来代替集线器&#xff0c;交换机常见端口数量一般有4、8、16、24、32等数量。 华为交换机&#xff1a;S5720-HI系列 仅从实物图上来看&#xff0c;交换机和集线器非常的像&#xff0c;但是它们的…

Python第二语言(十一、Python面向对象(下))

目录 1. 封装 1.1 私有成员&#xff1a;__成员、__成员方法 2. 继承&#xff1a;单继承、多继承 2.1 继承的基础语法 2.2 复写 & 子类使用父类成员 3. 变量的类型注解&#xff1a;给变量标识变量类型 3.1 为什么需要类型注解 3.2 类型注解 3.3 类型注解的语法 3.…

visio添加表格

插入Excel表格&#xff1a; 打开Microsoft Visio&#xff0c;新建一个空白画布。点击菜单栏中的“插入”。在插入中点击“图表”。在弹出的插入对象设置页面中选择“Microsoft Excel工作表”。点击确定按钮&#xff0c;然后在表格中输入内容。将鼠标点击到画布的空白处&#x…

大数据在商业中的应用——Kompas.ai如何助力企业决策

引言 在现代商业中&#xff0c;大数据逐渐成为企业决策的重要工具。通过对海量数据的分析和处理&#xff0c;企业可以获得重要的市场信息和决策支持。本文将探讨大数据在商业中的应用&#xff0c;并介绍Kompas.ai如何通过AI技术助力企业决策。 大数据的发展及其重要性 大数据…