YOLOv8改进 - 注意力篇 - 引入SCAM注意力机制

news2024/12/25 14:00:31

一、本文介绍

作为入门性篇章,这里介绍了SCAM注意力在YOLOv8中的使用。包含SCAM原理分析,SCAM的代码、SCAM的使用方法、以及添加以后的yaml文件及运行记录。

二、SCAM原理分析

SCAM官方论文地址:SCAM文章

SCAM官方代码地址:SCAM代码

SCAM注意力机制(空间上下文感知模块):

空间上下文感知模块(SCAM)在FEM和FFM之后,特征映射已经考虑了局部上下文信息,并且能够很好地表示小对象特征。在此阶段对小目标和背景之间的全局关系进行建模比在主干阶段更有效。利用全局上下文信息来表示像素之间的跨空间关系,可以抑制无用背景,增强目标和背景的区分能力。受GCNet和SCP的启发,SCAM由三个分支组成。第一个部分使用GAP和GMP整合全球信息。第二个分支使用1 × 1卷积生成特征映射的线性变换结果,该特征映射在图4中称为value。第三个分支使用1 × 1卷积来简化查询和键的倍数。这个卷积在图4中称为QK。随后,将第一分支和第三分支分别与第二分支矩阵相乘。得到的两个分支分别表示跨通道和空间的上下文信息。最后,利用广播Hadamard积在这两个分支上得到了SCAM的输出。

相关代码:

SCAM注意力的代码,如下。

class SCAM(nn.Module):
    def __init__(self, in_channels, reduction=1):
        super(SCAM, self).__init__()
        self.in_channels = in_channels
        self.inter_channels = in_channels

        self.k = Conv(in_channels, 1, 1, 1)
        self.v = Conv(in_channels, self.inter_channels, 1, 1)
        self.m = Conv_withoutBN(self.inter_channels, in_channels, 1, 1)
        self.m2 = Conv(2, 1, 1, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # GAP
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # GMP

    def forward(self, x):
        n, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3)

        # avg max: [N, C, 1, 1]
        avg = self.avg_pool(x).softmax(1).view(n, 1, 1, c)
        max = self.max_pool(x).softmax(1).view(n, 1, 1, c)

        # k: [N, 1, HW, 1]
        k = self.k(x).view(n, 1, -1, 1).softmax(2)

        # v: [N, 1, C, HW]
        v = self.v(x).view(n, 1, c, -1)

        # y: [N, C, 1, 1]
        y = torch.matmul(v, k).view(n, c, 1, 1)

        # y2:[N, 1, H, W]
        y_avg = torch.matmul(avg, v).view(n, 1, h, w)
        y_max = torch.matmul(max, v).view(n, 1, h, w)

        # y_cat:[N, 2, H, W]
        y_cat = torch.cat((y_avg, y_max), 1)

        y = self.m(y) * self.m2(y_cat).sigmoid()

        return x + y

四、YOLOv8中SCAM使用方法

1.YOLOv8中添加SCAM模块:

首先在ultralytics/nn/modules/conv.py最后添加SCAM模块的代码。

2.在conv.py的开头__all__ = 内添加SCAM模块的类别名:

3.在同级文件夹下的__init__.py内添加SCAM的相关内容:(分别是from .conv import SCAM ;以及在__all__内添加SCAM)

4.在ultralytics/nn/tasks.py进行LSKA注意力机制的注册,以及在YOLOv8的yaml配置文件中添加SCAM即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:

        elif m is SCAM:
            c2 = ch[f]
            args = [c2]

然后,就是新建一个名为YOLOv8_SCAM.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_SCAM.yaml)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SCAM, [1024]]#11代表卷积核大小,可以填写7、11、23、35、41、53
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

在根目录新建一个train.py文件,内容如下

from ultralytics import YOLO

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
# 加载一个模型
    model = YOLO('ultralytics/cfg/models/v8/YOLOv8_SCAM.yaml')  # 从YAML建立一个新模型
# 训练模型
    results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:​

​​

五、总结

以上就是SCAM的原理及使用方式,但具体SCAM注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2181197.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

解决 Macos下 Orbstack docker网络问题

两种解决方法,第一种开代理 参考 —— 但是我这一种没成功,第二种方法是换镜像源 { "registry-mirrors": ["http://hub-mirror.c.163.com","https://docker.mirrors.ustc.edu.cn","https://mirrors.tencent.com&q…

安防监控/视频系统EasyCVR视频汇聚平台如何过滤134段的告警通道?

视频汇聚/集中存储EasyCVR安防监控视频系统采用先进的网络传输技术,支持高清视频的接入和传输,能够满足大规模、高并发的远程监控需求。平台支持国标GB/T 28181协议、部标JT808、GA/T 1400协议、RTMP、RTSP/Onvif协议、海康Ehome、海康SDK、大华SDK、华为…

大麦演唱会门票

切勿再令您所爱的人耗费高昂的价格去购置黄牛票 ⚠️核心内容参考: 据悉,于购票环节,大麦凭借恶意流量清洗技术,于网络层实时甄别并阻拦凭借自动化手段发起下单请求的流量,强化对刷票脚本、刷票软件以及虚拟设备的识别能力&#…

开源 AI 智能名片 2+1 链动模式 S2B2C 商城小程序的数据运营策略与价值创造

一、引言 1.1 研究背景 在当今数字化时代,数据运营已成为企业发展的核心驱动力。开源 AI 智能名片 21 链动模式 S2B2C 商城小程序作为一种创新的营销工具,与数据运营紧密相连。该小程序通过集成人工智能、大数据分析等先进技术,能够实时收集…

【问题解决】Xshell终端双击或者选中文字自动发送Ctrl+C

问题 在xshell终端,当鼠标双击或者选中一行文字时,xshell会自动发送一个 CtrlC 的命令。如下图: 原因 已知可能会导致这个问题的软件,关掉就没问题了: 有道词典金山词典词霸秒译bing翻译钉钉AI助理360极速搜索… …

Python保留数据删除Excel单元格的函数和公式

在分析处理Excel表格时,我们可能需要使用各种公式或函数对表格数据进行计算,从而分析出更多的信息。但在展示、分享或再利用分析结果时,我们可能需要将含有公式的单元格转换为静态数值,从而简化数据、保护计算结果不被更改&#x…

(c++)内存四区:1.代码区2.全局区(静态区)3.栈区4.堆区

//内存四区:1.代码区 2.全局区 3.栈区 4.堆区 1.放在代码区的有:1.写的代码:只读的、共享的、存放的二进制机器指令、由操作系统直接管理 2.放在全局区的有:1.全局的(变量或常量) 2.静态的&#xff0…

【毕业/转行】想从事GIS开发工程师?如何规划?

既然是GIS开发,那就离不开学习编程 那如何学习才能掌握呢?如何才能达到企业的用人标准? 给大家梳理了学习的路线,想从事gis开发的小伙伴可以直接按这个路线学习! 共分为6大阶段,让你从纯小白到成熟的三维GIS开发工程师! 大纲&#xff1a…

Python:import语句的使用(详细解析)(一)

相关阅读 Pythonhttps://blog.csdn.net/weixin_45791458/category_12403403.html?spm1001.2014.3001.5482 import语句是Python中一个很重要的机制,允许在一个文件中访问另一个文件的函数、类、变量等,本文就将进行详细介绍。 在具体谈论import语句前&a…

linux驱动编程——等待队列

一、等待队列 可实现调用read函数时阻塞等。 1、流程 (1)初始化等待队列头(带参宏) init_waitqueue_head(q) 等待队列头wq数据类型: wait_queue_head_t,等待条件condition:int型变量。 &…

Actor 并发控制模型

目录 一、模型概述 二、模型特点 三、模型组成 四、模型优势 五、应用实例 一般来说,我们有两种策略来在并发线程中实现通信:共享内存和消息传递。大多数传统语言,并发线程之间的通信使用的都是共享内存,共享内存最大的问题就…

分糖果C++

题目&#xff1a; 样例解释&#xff1a; 样例1解释 拿 k20 块糖放入篮子里。 篮子里现在糖果数 20≥n7&#xff0c;因此所有小朋友获得一块糖&#xff1b; 篮子里现在糖果数变成 13≥n7&#xff0c;因此所有小朋友获得一块糖&#xff1b; 篮子里现在糖果数变成 6<n7&#xf…

为本地生活赛道从业者赋能,易播易赚开启“抖音直播分享会”

9月22日&#xff0c;由杭州易播易赚科技有限公司主办的“抖音直播分享会”在杭州市富阳区召开&#xff0c;此次会议吸引了来自全国各地的抖音直播从业者、有志于加入抖音直播事业的创业者以及行业内知名专家齐聚一堂&#xff0c;共同探讨行业发展趋势、分享实战经验&#xff0c…

tomcat版本升级导致的umask问题

文章目录 1、问题背景2、问题分析3、深入研究4、umask4.1、umask的工作原理4.2、umask的计算方式4.3、示例4.4、如何设置umask4.5、注意事项 1、问题背景 我们的java服务是打成war包放在tomcat容器里运行的&#xff0c;有一天我像往常一样去查看服务的日志文件&#xff0c;却提…

Mysql高级篇(中)——多版本并发控制 MVCC

多版本并发控制 MVCC 一、概述二、基本原理三、实现原理四、示例解释五、MVCC 优点六、现实中的实现七、MVCC 三剑客1. ReadView2. Undo Log3. Purge4. 三者之间的关系&#xff1a;5. 示例6. 总结 八、MVCC 整体操作流程⭐、readview1. 作用2. 工作机制3. 数据版本的可见性判断…

[云服务器15] 全网最全!手把手搭建discourse论坛,100%完成

首先&#xff0c;由我隆重地介绍Discourse&#xff1a; 这是一个优秀的论坛部署平台&#xff0c;相较于flarum Discuz!&#xff0c;有着更加简洁的画面、完全开源等优点&#xff0c;同时资源占用也不高&#xff01; 并且&#xff0c;这和我们亲爱的雨云论坛是有几分相似的哦&…

国庆偷偷卷!小众降维!POD-Transformer多变量回归预测(Matlab)

目录 效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现POD-Transformer多变量回归预测&#xff0c;本征正交分解数据降维融合Transformer多变量回归预测&#xff0c;使用SVD进行POD分解&#xff08;本征正交分解&#xff09;&#xff1b; 2.运行环境Matlab20…

Windows——解除Windows系统中文件名和目录路径的最大长度限制

第一步&#xff1a;打开本地组策略编辑器 按下Win R键打开运行窗口&#xff0c;输入 gpedit.msc 并回车&#xff0c;打开本地组策略编辑器。 第二步&#xff1a;开启 长路径设置 第三步&#xff1a;重启计算机

Windows环境Apache httpd 2.4 web服务器加载PHP8:Hello,world!

Windows环境Apache httpd 2.4 web服务器加载PHP8&#xff1a;Hello&#xff0c;world&#xff01; &#xff08;1&#xff09;首先需要安装apache httpd 2.4 web服务器&#xff1a; Windows安装启动apache httpd 2.4 web服务器-CSDN博客文章浏览阅读222次&#xff0c;点赞5次&…

Spark“数字人体”AI挑战赛_脊柱疾病智能诊断大赛_GPU赛道亚军比赛攻略_triple-Z团队

关联比赛: Spark“数字人体”AI挑战赛——脊柱疾病智能诊断大赛 triple-Z团队答题攻略 1 赛题分析 1.1 赛题回顾 本次比赛的任务是采用模型对核磁共振的脊柱图像进行智能检测。首先需要对5个椎体和6个椎间盘进行定位&#xff0c;这部分实际上就是11个关键点的检测任务&…