YOLOv8改进 - 注意力篇 - 引入(A2-Nets)Double Attention Networks注意力机制

news2024/11/16 15:36:46

一、本文介绍

作为入门性篇章,这里介绍了A2-Nets网络注意力在YOLOv8中的使用。包含A2-Nets原理分析,A2-Nets的代码、A2-Nets的使用方法、以及添加以后的yaml文件及运行记录。

二、A2-Nets原理分析

A2-Nets官方论文地址:A2-Nets文章

A2-Nets注意力机制(双重注意力机制):它从输入图像/视频的整个时空空间中聚集和传播信息全局特征,使后续卷积层能够有效地从整个空间中访问特征。采用双注意机制(包括Spatial Attention和Channel Attention。Spatial Attention用于捕获图像中不同空间位置的重要性,而Channel Attention用于捕获图像中不同通道的重要性),分两步进行设计,第一步通过二阶注意池将整个空间的特征聚集成一个紧凑的集合,第二步通过另一个注意自适应地选择特征并将其分配到每个位置。​

相关代码:

A2-Nets注意力的代码,如下。

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F


class DoubleAttention(nn.Module):

    def __init__(self, in_channels, c_m=128, c_n=128, reconstruct=True):
        super().__init__()
        self.in_channels = in_channels
        self.reconstruct = reconstruct

        self.c_m = c_m
        self.c_n = c_n
        self.convA = nn.Conv2d(in_channels, c_m, 1)
        self.convB = nn.Conv2d(in_channels, c_n, 1)
        self.convV = nn.Conv2d(in_channels, c_n, 1)
        if self.reconstruct:
            self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)
        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, h, w = x.shape
        assert c == self.in_channels
        A = self.convA(x)  # b,c_m,h,w
        B = self.convB(x)  # b,c_n,h,w
        V = self.convV(x)  # b,c_n,h,w
        tmpA = A.view(b, self.c_m, -1)
        attention_maps = F.softmax(B.view(b, self.c_n, -1))
        attention_vectors = F.softmax(V.view(b, self.c_n, -1), dim=-1)
        # step 1: feature gating
        global_descriptors = torch.bmm(tmpA, attention_maps.permute(0, 2, 1))  # b.c_m,c_n
        # step 2: feature distribution
        tmpZ = global_descriptors.matmul(attention_vectors)  # b,c_m,h*w
        tmpZ = tmpZ.view(b, self.c_m, h, w)  # b,c_m,h,w
        if self.reconstruct:
            tmpZ = self.conv_reconstruct(tmpZ)

        return tmpZ

四、YOLOv8中SK使用方法

1.YOLOv8中添加SK模块:

首先在ultralytics/nn/modules/conv.py最后添加DoubleAttention模块的代码。

2.在conv.py的开头__all__ = 内添加DoubleAttention模块的类别名(A2-Nets的类别名在本文中为DoubleAttention)

3.在同级文件夹下的__init__.py内添加A2-Nets的相关内容:(分别是from .conv import DoubleAttention ;以及在__all__内添加DoubleAttention)

4.在ultralytics/nn/tasks.py进行SK注意力机制的注册,以及在YOLOv8的yaml配置文件中添加DoubleAttention即可。

首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:(本文续接上篇文章,加在了CBAM、ECA、SKAttention的位置)

        elif m in {CBAM,ECA,SKAttention,DoubleAttention}:#添加注意力模块,没有CBAM、ECA、SKAttention的,将CBAM、ECA、SKAttention删除即可
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]

然后,就是新建一个名为YOLOv8_DoubleAttention.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_DoubleAttention.yaml)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call CPAM-yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, DoubleAttention, [1024]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。

在根目录新建一个train.py文件,内容如下

from ultralytics import YOLO

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
# 加载一个模型
    model = YOLO('ultralytics/cfg/models/v8/YOLOv8_DoubleAttention.yaml')  # 从YAML建立一个新模型
# 训练模型
    results = model.train(data='ultralytics/cfg/datasets/coco8.yaml', epochs=1,imgsz=640,optimizer="SGD")

训练输出:​

五、总结

以上就是DoubleAttention的原理及使用方式,但具体DoubleAttention注意力机制的具体位置放哪里,效果更好。需要根据不同的数据集做相应的实验验证。希望本文能够帮助你入门YOLO中注意力机制的使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2172271.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

美妆电商与AI知识库:构建智能化购物体验

在当今这个数字化时代,美妆电商行业正经历着前所未有的变革。随着人工智能(AI)技术的飞速发展,AI知识库在美妆电商领域的应用日益广泛,不仅重塑了传统的购物模式,还为消费者带来了前所未有的智能化购物体验…

手把手教你找到海外网红合作:海外红人营销渠道

在全球范围内,许多企业寻求与知名网红建立合作关系,以推广产品、共同创作内容或探索其他合作形式。以下是一些有效的方法来实现这一目标: 利用社交媒体平台:社交媒体是寻找海外网红的首选途径。平台如Instagram、YouTube和TikTok拥…

windows10使用bat脚本安装前后端环境之node环境设置

首先需要搞清楚node在本地是怎么安装配置、然后在根据如下步骤编写bat脚本: 思路 1.下载需要安装node版本zip格式包 2.配置环境变量 3.安装插件 可以根据自己需要来定义与配置(如下添加redis与node配置) bat脚本: echo off…

Node的安装和配置

1、安装Node 下载nodejs 链接:下载 | Node.js 中文网 官网下载最新版本:https://nodejs.org/en/download/ 一路点击Next,最后Finish。nodejs一般会下载在C盘里。 下载完成后,可以在cmd中查看安装的nodejs和npm版本,…

python单例和工厂模式

设计模式 设计模式是一种编程套路,可以极大的方便程序的开发 最常见、最经典的设计模式,就是学习的面向对象 除了面向对象之外,在编程中也有很多既定的套路可以方便开发,我们称之为设计模式: 单例、工厂模式建造者…

2022年上真题(案例分析)

一、数据流图 1. E1:商户 E2:外卖平台 E3:用户 E4:支付系统 2. D1:商户用户信息表 D2:订单表 D3:餐品信息表 D4:评价表 3. 数据流名称 …

Linux设备驱动中的异步通知与异步I/O学习s

1、异步通知的概念和作用 异步通知的意思是:一旦设备就绪,则主动通知应用程序,这样应用程序根本就不需要查询设备状态,这一点非常类似于硬件上”中断“的概念,比较准确的称谓是”信号驱动的异步I/O“。信号是在软件层次…

盘点2024年4款高效率的语音转文字工具。

语音转换文字软件真的是一种提高效率的神器,我在工作中常常因为手动记录太慢而选择录音。事后在形成记录,但效率比较低。自从知道有直接转换的工具之后,我有再多的录音都不怕了。如果大家也有跟我一样的工作时,可以试试使用这些语…

C++之STL—常用拷贝和替换算法

copy(iterator beg, iterator end, iterator dest); // 按值查找元素,找到返回指定位置迭代器,找不到返回结束迭代器位置 // beg 开始迭代器 // end 结束迭代器 // dest 目标起始迭代器 replace(iterator beg, iterator end, oldvalue, newvalue); …

儿童手抄报模板-200个(家有神兽必备)

在这个充满色彩与想象的世界里,每一位小朋友都是一位小小艺术家和梦想家。作为家长或老师,我们总是希望能为他们的学习生活增添一抹亮色,激发他们的创造力与探索欲。今天,就为大家带来一份超级实用的资源——儿童手抄报模板-200个…

Spring:强制登陆与拦截器

1.只使用session验证 (1)第一步:用户登陆时存储session ApiOperation("用户登陆") PostMapping("/login") public AppResult login(HttpServletRequest request,RequestParam("username") ApiParam("用…

将上一篇的feign接口抽取到公共api模块(包含feign接口示例)

文章目录 一、准备二、主要工作三、建立dto类四、添加多个feign接口五、测试六、目录结构6.1 父工程service-demo6.2 order-service模块6.3 product-service模块6.4 sd-api模块 一、准备 将上一篇的目录结构改造一下: 修改包名使根路径分别为com.hdl.order和com.h…

微信支付:chooseWXPay:fail, the permission value is offline verifying

在开发公众号微信支付的时候,在微信开发者工具中使用 WeixinJSBridge 唤起 微信支付,页面上看到微信支付的loading一闪而过,但是没有出现微信支付的页面。控制台log显示错误信息:“chooseWXPay:fail, the permission value is off…

实验1 Python语言基础一

目录 实验1 Python语言基础一1、下载安装Python,贴出验证安装成功截图2、建立test.py文件,运行后贴出截图,思考if __name”__main__”的意思和作用3、分别运行下面两种代码,分析运行结果产生的原因。记牢python中重要语法“tab”的作用。6、编…

企业内训|大模型/智算行业发展机会深度剖析-某数据中心厂商

北京中嘉和信通信技术有限公司于8月29日举办了一场主题为“大模型/智算行业发展机会深度剖析”的企业内训。此次培训由TsingtaoAI公司负责人汶生主讲,针对当前大模型技术的发展现状、应用场景及未来趋势进行了全面分析和解读。 汶生老师在培训中深入剖析了大模型的…

京东商品详情API(item_get)性能优化策略:提升数据抓取效率

在电商领域,快速、准确地获取商品信息对于提升用户体验、优化库存管理和市场决策至关重要。京东商品详情API(item_get)作为京东开放平台提供的一项重要服务,允许开发者获取京东平台上商品的详细信息。然而,如何高效利用…

xxl-job 适配达梦数据库

前言 在数字化转型的浪潮中,任务调度成为了后端服务不可或缺的一部分。XXL-JOB 是一个轻量级、分布式的任务调度框架,广泛应用于各种业务场景。达梦数据库(DM),作为一款国内领先的数据库产品,已经被越来越…

详解调用钉钉AI助理消息API发送钉钉消息卡片给指定单聊用户

文章目录 前言准备工作1、在钉钉开发者后台创建一个钉钉企业内部应用;2、创建并保存好应用的appKey和appSecret,后面用于获取调用API的请求token;3、了解AI助理主动发送消息API:4、应用中配置好所需权限:4.1、权限点4.…

光控资本:国庆节股市能不能继续交易?A股放量大涨

早年这个时分,商场谈论的最多论题是,持股过节仍是持币过节。 而本年却大不一样,“国庆节股市能不能继续生意”成为这两天股民之间的梗。 今天上午,A股继续暴升,创了三个纪录。一是上午成交额为9466亿元,跨…

2024 maya的散布工具sppaint3d使用指南

目前工具其实可以分为三个版本 1 最老的原版 时间还是2011年的,只支持python2版的maya 2 作者python3更新版 后来作者看maya直到2022上还是没有类似好用方便的工具,于是更新到了2022版本 这个是原作者更新的2022版本,改成了python3&#…