YOLOv8改进 - 注意力篇 - 引入SK网络注意力机制

news2024/9/27 6:42:04

一、本文介绍

作为入门性篇章,这里介绍了SK网络注意力在YOLOv8中的使用。包含SK原理分析,SK的代码、SK的使用方法、以及添加以后的yaml文件及运行记录。

二、SK原理分析

SK官方论文地址:SK注意力文章

SK注意力机制:SK网络中的神经元可以捕获具有不同比例的目标对象,实验验证了神经元根据输入自适应地调整其感受野大小的能力。其SK模块的原理结构如下图所示。该方法由三个部分组成:Split,Fuse,Select。

Split就是一个multi-branch的操作,用不同的卷积核进行卷积得到不同的特征;

Fuse部分就是用SE的结构获取通道注意力的矩阵(N个卷积核就可以得到N个注意力矩阵,这步操作对所有的特征参数共享),这样就可以得到不同kernel经过SE之后的特征;

Select操作就是将这几个特征进行相加。

相关代码:

SK注意力的代码,如下。

from torch.nn import init
from collections import OrderedDict

class SKAttention(nn.Module):

    def __init__(self, channel=512, kernels=[1, 3, 5, 7], reduction=16, group=1, L=32):
        super().__init__()
        self.d = max(L, channel // reduction)
        self.convs = nn.ModuleList([])
        for k in kernels:
            self.convs.append(
                nn.Sequential(OrderedDict([
                    ('conv', nn.Conv2d(channel, channel, kernel_size=k, padding=k // 2, groups=group)),
                    ('bn', nn.BatchNorm2d(channel)),
                    ('relu', nn.ReLU())
                ]))
            )
        self.fc = nn.Linear(channel, self.d)
        self.fcs = nn.ModuleList([])
        for i in range(len(kernels)):
            self.fcs.append(nn.Linear(self.d, channel))
        self.softmax = nn.Softmax(dim=0)

    def forward(self, x):
        bs, c, _, _ = x.size()
        conv_outs = []
        ### split
        for conv in self.convs:
            conv_outs.append(conv(x))
        feats = torch.stack(conv_outs, 0)  # k,bs,channel,h,w

        ### fuse
        U = sum(conv_outs)  # bs,c,h,w

        ### reduction channel
        S = U.mean(-1).mean(-1)  # bs,c
        Z = self.fc(S)  # bs,d

        ### calculate attention weight
        weights = []
        for fc in self.fcs:
            weight = fc(Z)
            weights.append(weight.view(bs, c, 1, 1))  # bs,channel
        attention_weughts = torch.stack(weights, 0)  # k,bs,channel,1,1
        attention_weughts = self.softmax(attention_weughts)  # k,bs,channel,1,1

        ### fuse
        V = (attention_weughts * feats).sum(0)
        return V

四、YOLOv8中SK使用方法

1.YOLOv8中添加SK模块,首先在ultralytics/nn/modules/conv.py最后添加SK模块的代码。

2.在conv.py的开头__all__ = 内添加SK模块的类别名(SK的类别名在本文中为SKAttention)

3.在同级文件夹下的__init__.py内添加SKAttention的相关内容:(分别是from .conv import SKAttention ;以及在__all__内添加SKAttention)

4.在ultralytics/nn/tasks.py进行SK注意力机制的注册,以及在YOLOv8的yaml配置文件中添加SK即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:(本文续接上篇文章,加在了CBAM、ECA的位置)

        elif m in {CBAM,ECA,SKAttention}:#添加注意力模块,没有CBAM、ECA的,将CBAM、ECA删除即可
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

然后,就是新建一个名为YOLOv8_SK.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_SK.yaml)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SKAttention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

在根目录新建一个train.py文件,内容如下

from ultralytics import YOLO


# 加载一个模型
model = YOLO('ultralytics/cfg/models/v8/YOLOv8_SK.yaml')  # 从YAML建立一个新模型
# 训练模型
results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:

五、总结

以上就是SK的原理及使用方式,但具体SK注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2169263.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

音视频通话 SDK

腾讯云视立方音视频通话 SDK 是音视频终端 SDK(腾讯云视立方)的子产品 SDK 之一,基于音视频通话场景,提供专属含 UI 快速接入方案,仅需三步即可快速集成上线,轻松实现1对1视频聊天、多人视频通话和聊天应用…

内网基础知识

内网基础知识 寄了,最后net time /domain命令还是运行不了 内网也指局域网(Local Area Network,LAN),是指在某一区域内由多台计算机互连而成的计算机组,组网范围通常在数千米以内。 工作组 work group 一种资源管理模式&#…

[SAP ABAP] PARAMETERS

PARAMETERS定义用户可以输入值的输入字段(单值Input) 基本语法 PARAMETERS PNAME. PNAME命名长度不能超过8位 PARAMETERS创建一个单一的输入域且最多只能输入一行,定义后的PNAME可作为变量在程序中运用 示例1 p_1的数据类型为CHAR1 输出结果: 补…

6.使用 VSCode 过程中的英语积累 - Run 菜单(每一次重点积累 5 个单词)

前言 学习可以不局限于传统的书籍和课堂,各种生活的元素也都可以做为我们的学习对象,本文将利用 VSCode 页面上的各种英文元素来做英语的积累,如此做有 3 大利 这些软件在我们工作中是时时刻刻接触的,借此做英语积累再合适不过&a…

.NET 6.0 使用log4net配置日志记录方法

1.包管理器引入相关包 2.添加Log4net文件夹和log4net.config配置文件(配置文件属性设为始终复制)。 3.替换 log4net.config的内容(3.1与3.2选择一个就好,只是创建日志文件有所区别) 3.1: <?xml version"1.0" encoding"utf-8"?> <configuration…

JavaWeb美食推荐管理系统

目录 1 项目介绍2 项目截图3 核心代码3.1 Controller3.2 Service3.3 Dao3.4 spring-mybatis.xml3.5 spring-mvc.xml3.5 login.jsp 4 数据库表设计5 文档参考6 计算机毕设选题推荐7 源码获取 1 项目介绍 博主个人介绍&#xff1a;CSDN认证博客专家&#xff0c;CSDN平台Java领域优…

OceanBase技术解析:自适应分布式下压技术

在《OceanBase 数据库源码解析》这本书中&#xff0c;关于SQL执行器的深入剖析相对较少&#xff0c;因此&#xff0c;希望增添一些实用且详尽的补充内容。 上一篇博客《 OceanBase技术解析&#xff1a; 执行器中的自适应技术》中&#xff0c;已初步介绍了执行器中几项典型的自适…

HarmonyOS异常处理实践

一、HarmonyOS应用异常处理框架 全面检测、精准记录异常传播路径、日志精简 二、FaultLog FaultLog是应用异常日志查询接口&#xff0c;提供QuerySelfFaultLog接口以查询自身故障。 JS_CRASH&#xff1a;ArkTS程序故障类型 CPP_CRASH&#xff1a;C程序故障类型 APP_FREEZE&…

csv导入导出

一、csv 1、介绍 CSV&#xff08;Comma-Separated Values&#xff0c;逗号分隔的值&#xff09;是一种简单、实用的文件格式&#xff0c;用于存储和表示包括文本、数值等各种类型的数据。CSV 文件通常以 .csv 作为文件扩展名。这种文件格式的一个显著特点是&#xff1a;文件内…

JavaSE——Arrays类、System类

目录 一、Arrays类 1.Arrays.toString() 2.Arrays.sort() 3.Arrays实现冒泡排序的定制排序 4.Arrays.binarySearch()——二叉查找 5.Arrays.copyOf()——数组元素的复制 6.Arrays.fill()——数组的填充 7.Arrays.equals(arr1,arr2)——比较2个数组元素内容是否完全一致…

java中的ArrayList和LinkedList的底层剖析

引入: 数据结构的分类&#xff0c;数据结构可以分成&#xff1a;线性表&#xff0c;树形结构&#xff0c;图形结构。 线性结构(线性表)包括:数组、链表、栈队列 树形结构:二叉树、AVL树、红黑树、B树、堆、Trie、哈夫曼树、并查集 图形结构:邻接矩阵、邻接表 线性表是具有存…

通信工程学习:什么是TDD时分双工

TDD:时分双工 TDD(时分双工,Time Division Duplexing)是一种在移动通信系统中广泛使用的全双工通信技术。以下是TDD的详细解释: 一、定义与原理 TDD是一种通过时间划分来实现双向通信的技术。在TDD模式中,接收和传送在同一频率信道(即载波)的不同时隙…

新品上市!智能无线接入型路由器ZX7981EP,WIFI6技术双频频段

在这个快节奏的时代 每一次点击都渴望即刻响应&#xff0c;每一份数据都期待安全传输 我们希望大家都能享有顶尖的网络体验&#xff0c;由此 启明智显ZX7891EP智能无线接入型路由器新品上市&#xff01; 2.4G/5G双频段&#xff0c;WAN口/LAN口皆齐全 最新802.1ax WiFi6技术…

【Linux】Linux工具——CMake入门

目录 1.什么是CMake 2.CMakeflie的安装和版本的查看 3.几个简单示例 3.1.编译一个.cc文件 3.2.编译一个.hpp文件和一个.cc文件 3.3.编译一个.hpp文件和两个.cc文件 3.4.编译两个.hpp文件和一个.cc文件 4.CMakeLists.txt 4.1.CMakeLists.txt常用的几条指令 4.2.变量和…

软件测试之单元测试/系统测试/集成测试详解

&#x1f345; 点击文末小卡片&#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 一、单元测试的概念 单元测试是对软件基本组成单元进行的测试&#xff0c;如函数或一个类的方法。当然这里的基本单元不仅仅指的是一个函数或者方法&#xff0…

基于densenet模型在RML201610a数据集上的调制识别【代码+数据集+python环境+GUI系统】

基于densenet模型在RML201610a数据集上的调制识别【代码数据集python环境GUI系统】 Loss曲线 背景意义 随着社会的快速发展&#xff0c;人们在通信方面的需求逐渐增加&#xff0c;特别是在无线通信领域。通信环境的复杂化催生了多种通信形式和相关应用&#xff0c;这使得调制…

最新版无忧二级域名分发源码,支持包月续费

目前版本支持&#xff0c;开通会员&#xff0c;会员组可以解析哪些域名 比如 用普通域名引流&#xff0c;免费使用&#xff0c;就可以注册就是普通用户组 会员组可以设置价格比如10块钱买永久会员&#xff0c;没有别的特权&#xff0c;只是会员才可以租备案域名&#xff0c; 设…

有源蜂鸣器(5V STM32)

目录 一、介绍 二、模块原理 1.有/无源蜂鸣器介绍 2.原理图 3.引脚描述 三、程序设计 main.c文件 beep.h文件 beep.c文件 四、实验效果 五、资料获取 项目分享 一、介绍 蜂鸣器是一种能将音频信号转化声音信号的发音器件&#xff0c;在家电器上&#xff0c;在银行…

直播 SDK

直播 SDK 是音视频终端 SDK&#xff08;腾讯云视立方&#xff09;针对移动直播场景专属打造的一体化产品&#xff0c;支持直播推拉流、主播观众互动连麦、主播跨房 PK 等能力&#xff0c;为用户提供高质量直播服务&#xff0c;快速满足手机直播的需求。更多关于直播 SDK 的文档…

Ubuntu 22.04无法连接网络(网络图标丢失)解决方案

对于Ubuntu 22.04而言&#xff1a; sudo service NetworkManager stop sudo rm /var/lib/NetworkManager/NetworkManager.state sudo service NetworkManager start