YoloV8改进---注意力机制:引入瓶颈注意力模块BAM,对标CBAM

news2024/11/27 4:15:23

目录

​编辑

1.BAM介绍

 2.BAM引入到yolov8

2.1 加入modules.py中:

 2.2 加入tasks.py中:

2.3 yolov8_BAM.yaml


1.BAM介绍

 论文:https://arxiv.org/pdf/1807.06514.pdf

        摘要:提出了一种简单有效的注意力模块,称为瓶颈注意力模块(BAM),可以与任何前馈卷积神经网络集成。我们的模块沿着两条独立的路径,通道和空间,推断出一张注意力图。我们将我们的模块放置在模型的每个瓶颈处,在那里会发生特征图的下采样。我们的模块用许多参数在瓶颈处构建了分层注意力,并且它可以以端到端的方式与任何前馈模型联合训练。我们通过在CIFAR-100、ImageNet-1K、VOC 2007和MS COCO基准上进行大量实验来验证我们的BAM。我们的实验表明,各种模型在分类和检测性能上都有持续的改进,证明了BAM的广泛适用性。

        作者将BAM放在了Resnet网络中每个stage之间。有趣的是,通过可视化我们可以看到多层BAMs形成了一个分层的注意力机制,这有点像人类的感知机制。BAM在每个stage之间消除了像背景语义特征这样的低层次特征,然后逐渐聚焦于高级的语义–明确的目标。 

 

 作者提出了新的Attention模型——瓶颈注意模块,通过分离的两个路径channel和spatial得到attention map,减少计算开销和参数开销。

实验 

 BAM可以在大规模数据集中的各种模型上有很好的泛化能力,同时参数和计算的开销可以忽略不计,这表明提出的模块BAM可以有效地提高网络容量。另一个值得注意的是,改进的性能来自于只在网络中放置三个模块。

 BAM提高了所有具有两个骨干网络的强大基线的准确性.BAM的准确率提高是以可忽略不计的参数开销实现的,这表明提高不是由于天真的容量增加,而是由于我们有效的特征细化。

 2.BAM引入到yolov8

2.1 加入modules.py中:

###################### BAM  attention  ####     START   by  AI&CV  ###############################

import torch
from torch import nn
import torch.nn.functional as F


class ChannelGate(nn.Module):
    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.mlp = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel)
        )
        self.bn = nn.BatchNorm1d(channel)

    def forward(self, x):
        b, c, h, w = x.shape
        y = self.avgpool(x).view(b, c)
        y = self.mlp(y)
        y = self.bn(y).view(b, c, 1, 1)
        return y.expand_as(x)


class SpatialGate(nn.Module):
    def __init__(self, channel, reduction=16, kernel_size=3, dilation_val=4):
        super().__init__()
        self.conv1 = nn.Conv2d(channel, channel // reduction, kernel_size=1)
        self.conv2 = nn.Sequential(
            nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
                      dilation=dilation_val),
            nn.BatchNorm2d(channel // reduction),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // reduction, channel // reduction, kernel_size, padding=dilation_val,
                      dilation=dilation_val),
            nn.BatchNorm2d(channel // reduction),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Conv2d(channel // reduction, 1, kernel_size=1)
        self.bn = nn.BatchNorm2d(1)

    def forward(self, x):
        b, c, h, w = x.shape
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        y = self.bn(y)
        return y.expand_as(x)


class BAM(nn.Module):
    def __init__(self, channel):
        super(BAM, self).__init__()
        self.channel_attn = ChannelGate(channel)
        self.spatial_attn = SpatialGate(channel)

    def forward(self, x):
        attn = F.sigmoid(self.channel_attn(x) + self.spatial_attn(x))
        return x + x * attn

###################### BAM  attention  ####     END   by  AI&CV  ###############################

 2.2 加入tasks.py中:

def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)

添加以下内容 

        elif m is BAM:
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

2.3 yolov8_BAM.yaml

仅供参考,加入网络位置不同在不同数据集表现不一致是正常现场

# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 2  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
  - [-1, 1, BAM, [1024]]  # 24 

  - [[15, 18, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

3.YOLOv8魔术师专栏介绍

 💡💡💡YOLOv8魔术师,独家首发创新(原创),持续更新,最终完结篇数≥100+,适用于Yolov5、Yolov7、Yolov8等各个Yolo系列,专栏文章提供每一步步骤和源码,轻松带你上手魔改网络

💡💡💡重点:通过本专栏的阅读,后续你也可以自己魔改网络,在网络不同位置(Backbone、head、detect、loss等)进行魔改,实现创新!!!

 专栏介绍:

✨✨✨原创魔改网络、复现前沿论文,组合优化创新

🚀🚀🚀小目标、遮挡物、难样本性能提升

🍉🍉🍉持续更新中,定期更新不同数据集涨点情况

本专栏提供每一步改进步骤和源码,开箱即用,在你的数据集下轻松涨点

通过注意力机制、小目标检测、Backbone&Head优化、 IOU&Loss优化、优化器改进、卷积变体改进、轻量级网络结合yolo等方面进行展开点,

专栏链接如下:

https://blog.csdn.net/m0_63774211/category_12289773.html?spm=1001.2014.3001.5482

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/722249.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

06、Nginx反向代理与负载均衡

反向代理: 这种代理方式叫做,隧道代理。有性能瓶颈,因为所有的数据都经过Nginx,所以Nginx服务器的性能至关重要 负载均衡: 把请求,按照一定算法规则,分配给多台业务服务器(即使其中…

【论文笔记】Skill-based Meta Reinforcement Learning

【论文笔记】Skill-based Meta Reinforcement Learning 文章目录 【论文笔记】Skill-based Meta Reinforcement LearningAbstract1 INTRODUCTION2 RELATED WORKMeta-Reinforcement LearningOffline datasetsOffline Meta-RLSkill-based Learning 3 PROBLEM FORMULATION AND PRE…

CPU大小端和网络序的理解

引子 Big/Little Endian是Host CPU如何去理解在内存中的数据,内存中的数据是没有Big/Little Endian之分的(内存仅仅作为存储介质),而Host CPU才有Big/Little Endian之分。 不同Endian的CPU,从内存读取数据的时候&#…

Linux进度条

Linux进度条 一.基本概念1.回车和换行2.缓冲区2.实现倒计时 二.进度条1.前置工作2.代码实现 一.基本概念 1.回车和换行 回车:指光标移到该行的起始位置(\r)。 换行:换到下一行(\n)。 在c语音里\n将回车和换…

spring.boot 随笔0 springFactoriesInstance入门

0. 其实也没有那么入门 明天还要上班,速度write,直接放一张多样性比较好的 spring.factories 文件(取自 spring-boot-2.3.4.RELEASE.jar) # PropertySource Loaders org.springframework.boot.env.PropertySourceLoader\ org.springframework.boot.env…

知识图谱实战应用19-基于Py2neo的英语单词关联记忆知识图谱项目

大家好,我是微学AI,今天给大家介绍一下知识图谱实战应用19-基于Py2neo的英语单词关联记忆知识图谱项目。基于Py2neo的英语单词关联记忆知识图谱项目可以帮助用户更系统地学习和记忆英语单词,通过图谱的可视化展示和智能推荐功能,提供了一种全新的、更高效的记忆方法,并促进…

AtomicInteger使用详解

AtomicInteger使用详解 1、get():获取当前AtomicInteger对象的值。2、set(int newValue):将AtomicInteger对象的值设置为指定的newValue。3、getAndSet(int newValue):先获取当前AtomicInteger对象的值,然后将对象的值设置为指定的…

ROS1/2机器人课程的价值和规模

价值用价格,规模用销量。 免费的ROS1/2课程也很多。 2023版,15元,24人。 2022版,1.99元,21人。 价格不贵,人数很少,店家也很少。 当然,有朋友说,有免费冲击&#xff0…

宏晶微 音频处理芯片 MS7124

MS7124是一款高性能24bit数字立体声音频DAC,该DAC采用Sigma-Delta结构,支持标准的I2S数字信号输入,输出支持立体声和单声道。

比对Excel数据

以a个为准绳比对b表数据,添加比对结果列输出。 (本笔记适合初通 Python 的 coder 翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖免费“圣经”教程《 python 完全自学教程》,不仅仅是基础那么…

大数据Doris(五十七):RECOVER数据删除恢复

文章目录 RECOVER数据删除恢复 一、Recover语法 二、数据恢复案例 RECOVER数据删除恢复 Doris为了避免误操作造成的灾难,支持对误删除的数据库/表/分区进行数据恢复,在drop table或者 drop database之后,Doris不会立刻对数据进行物理删除…

C语言进阶--文件操作

目录 一.何为文件 二.文件名 三.文件的打开和关闭 3.1.流 3.2.文件指针 3.3.文件的打开与关闭 打开文件: 模式: 关闭文件: 四.文件的顺序读写 4.1.常见的顺序读写函数 4.2.字符的输入输出fgetc/fputc 输出函数: 输入…

计算机存储层次及常用存储简介

计算机存储层次(Memory hierarchy) 存储层次是在计算机体系结构下存储系统层次结构的排列顺序。 每一层于下一层相比 都拥有 较高的速度 和 较低延迟性 ,以及 较小的容量 (也有少量例外,如AMD早期的Duron CPU&#xf…

pod 控制器 3

简单回顾 之前我们学习过的的 docker ,例如我们运行 docker run busybox echo "hello wrold" 他的实际内在逻辑是这个样子的 程序将指令推送给 dockerdocker 会检查本地是否有 busybox 镜像,若没有则去 docker hub 上面拉取镜像,…

windows下mysql配置文件my.ini在这个位置

windows下mysql配置文件my.ini在这个位置 选中服务邮件 右键----属性,弹出下图 一般默认路径: "D:\Program Files\MySQL\MySQL Server 8.0\bin\mysqld.exe" --defaults-file"C:\ProgramData\MySQL\MySQL Server 8.0\my.ini" MySQ…

信息安全管理与评估赛题第4套

全国职业院校技能大赛 高等职业教育组 信息安全管理与评估 赛题四 模块一 网络

simulink 常用模块

add sub pro div加减乘除 Relational Operator 数值比较模块 < < > > ! Compare To Constant 直接和某数字比较&#xff0c;是上面的封装 Logical Operator 逻辑运算 & | ~ switch 相当于c语言的if 中间是条件&#xff0c;满足走上&#xff0c;否则走下…

上门按摩小程序开发|同城预约上门小程序定制

上门按摩小程序对实体按摩商家来说是非常适合的。下面是对上门按摩小程序适合实体按摩商家开发的简单介绍&#xff1a;   扩展服务范围&#xff1a;上门按摩小程序可以让实体按摩商家将服务范围扩展到用户的家中或办公场所。用户可以通过小程序预约上门按摩服务&#xff0c;无…

【openGauss数据库】--运维指南03--数据导出

【openGauss数据库】--运维指南03--数据导出 &#x1f53b; 一、openGauss导出数据&#x1f530; 1.1 概述&#x1f530; 1.2 导出单个数据库&#x1f537; 1.2.1 导出数据库&#x1f537; 1.2.2 导出模式&#x1f537; 1.2.3 导出表 &#x1f530; 1.3 导出所有数据库&#x1…

计算两个二维数组arr1和arr2中对应位置元素的商

代码实现 &#xff1a;一个嵌套循环&#xff0c;用于计算两个二维数组arr1和arr2中对应位置元素的商&#xff0c;并将结果存储在result数组中。首先&#xff0c;定义了一个空数组result用于存储结果。然后&#xff0c;通过两个for循环遍历arr1数组的每一行和每一列。在内层循环…