Yolov8改进CoTAttention注意力机制,效果秒杀CBAM、SE

news2024/11/27 7:38:11

1.CoTAttention

论文地址:2107.12292.pdf (arxiv.org)

CoTAttention网络是一种用于多模态场景下的视觉问答(Visual Question Answering,VQA)任务的神经网络模型。它是在经典的注意力机制(Attention Mechanism)上进行了改进,能够自适应地对不同的视觉和语言输入进行注意力分配,从而更好地完成VQA任务。

CoTAttention网络中的“CoT”代表“Cross-modal Transformer”,即跨模态Transformer。在该网络中,视觉和语言输入分别被编码为一组特征向量,然后通过一个跨模态的Transformer模块进行交互和整合。在这个跨模态的Transformer模块中,Co-Attention机制被用来计算视觉和语言特征之间的交互注意力,从而实现更好的信息交换和整合。在计算机视觉和自然语言处理紧密结合的VQA任务中,CoTAttention网络取得了很好的效果。

新加坡国立大学的Qibin Hou等人提出了一种为轻量级网络设计的新的注意力机制,该机制将位置信息嵌入到了通道注意力中,称为coordinate attention

京东AI Research提出新的主干网络CoTNet,在CVPR上2021获得开放域图像识别竞赛冠军

京东AI Research提出一种Contextual Transformer Networks结构,就是将Transformer捕捉全局信息的能力与CNN捕捉临近局部信息能力相结合,从而提高网络模型的特征表达能力。值得注意的是,该方法可以实现模块的“即插即用”,将ResNet网络中的3x3模块替换成CoTNet的核心模块即可使用,Res2Net网络也是基于这种思想实现的。

在相同深度(50层或101层)下,top-1和top-5结果都表明本文的方法比卷积网络和Attention-based网络性能更好。

2. 基于Yolov8的CoordAttention实现

2.1 CoTAttention 加入 modules.py

核心代码:

######################  CoTAttention   ####     start   by  AI&CV  ###############################
import torch
from torch import flatten, nn
from torch.nn import functional as F


class CoTAttention(nn.Module):

    def __init__(self, dim=512, kernel_size=3):
        super().__init__()
        self.dim = dim
        self.kernel_size = kernel_size

        self.key_embed = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=4, bias=False),
            nn.BatchNorm2d(dim),
            nn.ReLU()
        )
        self.value_embed = nn.Sequential(
            nn.Conv2d(dim, dim, 1, bias=False),
            nn.BatchNorm2d(dim)
        )

        factor = 4
        self.attention_embed = nn.Sequential(
            nn.Conv2d(2 * dim, 2 * dim // factor, 1, bias=False),
            nn.BatchNorm2d(2 * dim // factor),
            nn.ReLU(),
            nn.Conv2d(2 * dim // factor, kernel_size * kernel_size * dim, 1)
        )

    def forward(self, x):
        bs, c, h, w = x.shape
        k1 = self.key_embed(x)  # bs,c,h,w
        v = self.value_embed(x).view(bs, c, -1)  # bs,c,h,w

        y = torch.cat([k1, x], dim=1)  # bs,2c,h,w
        att = self.attention_embed(y)  # bs,c*k*k,h,w
        att = att.reshape(bs, c, self.kernel_size * self.kernel_size, h, w)
        att = att.mean(2, keepdim=False).view(bs, c, -1)  # bs,c,h*w
        k2 = F.softmax(att, dim=-1) * v
        k2 = k2.view(bs, c, h, w)

        return k1 + k2

######################  CoTAttention   ####     end   by  AI&CV  ###############################

​2.2 yolov8_CoTAttention.yaml

# Ultralytics YOLO  , GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 4  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
  - [-1, 3, CoTAttention, [256]]   # 16

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)
  - [-1, 3, CoTAttention, [512]]  

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 23 (P5/32-large)
  - [-1, 3, CoTAttention, [1024]]  

  - [[16, 20, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1170389.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++ 算法:区间和的个数

涉及知识点 归并排序 题目 给你一个整数数组 nums 以及两个整数 lower 和 upper 。求数组中,值位于范围 [lower, upper] (包含 lower 和 upper)之内的 区间和的个数 。 区间和 S(i, j) 表示在 nums 中,位置从 i 到 j 的元素之和…

多技术融合提升环境、生态、水文、土地、农业、大气等领域科研技术水平

专题一、空间数据获取与制图 1.1 软件安装与应用讲解 1.2 空间数据介绍 1.3海量空间数据下载 1.4 ArcGIS软件快速入门 1.5 Geodatabase地理数据库 点击查看原文链接https://mp.weixin.qq.com/s?__bizMzg2NDYxNjMyNA&mid2247546998&idx6&sn39342c376b158eff1…

基于SSM的购物商城网站的设计与实现

末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:Vue 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目:是 目录…

多测师肖sir_高级金牌讲师__adb命令

adb指令整理: ADB常用的指令: 查看当前连接设备 : adb devices 进入到shell : adb shell 查看日志 : adb logcat 安装apk文件 : adb install xxx.apk 卸载APP : adb uninstall 包名 查看包名 &…

Cadence Virtuoso如何保存spectre仿真在cell view里

Launch ADE L 可以选择 Load State 加载上次仿真状态,但是我想保存在cell view和schematic在一起可以直接打开,可以选择Save State旁的Cellview 可以在Library Manager中看到保存成功了

YOLOv5项目实战(2)— 手把手教你租借云服务器去训练模型

前言:Hello大家好,我是小哥谈。近期由于出差在外(在新疆吐鲁番出差呢~),一直远程使用公司服务器进行算法模型训练,但是由于这几天公司VPN故障,导致无法远程训练模型,所以就想着租借服务器来进行训练。近期我研发的算法模型是工业场景烟雾明火检测,本节课就以此为例教大…

AI 引擎系列 4 - 首次运行 AI 引擎编译器和 x86simulator(2022.1 更新)

AI 引擎系列 4 - 首次运行 AI 引擎编译器和 x86simulator(2022.1 更新) 简介 在 AI 引擎系列的前 3 篇博文中,我们探讨了 AI 引擎应用所需的不同文件。在本篇中,我们将为 X86 目标运行 AI 引擎编译器,观察它生成的不…

YOLO目标检测——昏暗车辆检测数据集【含对应voc、coco和yolo三种格式标签】

实际项目应用:智能交通监控系统、驾驶辅助系统、城市安全监控、自动驾驶系统以及路况分析与规划等数据集说明:昏暗车辆检测数据集,真实场景的高质量图片数据,数据场景丰富,含有图片汽车、卡车、公共汽车标签说明&#…

C++23:多维视图(std::mdspan)

C23:多维视图(std::mdspan) 介绍 在 C23 中,std::mdspan 是一个非拥有的多维视图,用于表示连续对象序列。这个连续对象序列可以是一个简单的 C 数组、带有大小的指针、std::array、std::vector 或 std::string。 这…

YOLO目标检测——车辆分类检测数据集【含对应voc、coco和yolo三种格式标签】

实际项目应用:安全监控、智能驾驶、人机交互、智能城市数据集说明:车辆分类检测数据集,真实场景的高质量图片数据,数据场景丰富,含有图片汽车、公共汽车、摩托车、救护车和卡车等图片标签说明:使用lableimg…

SPSS两相关样本检验

前言: 本专栏参考教材为《SPSS22.0从入门到精通》,由于软件版本原因,部分内容有所改变,为适应软件版本的变化,特此创作此专栏便于大家学习。本专栏使用软件为:SPSS25.0 本专栏所有的数据文件请点击此链接下…

LED点阵显示原理(取字模软件+Keil+Proteus)

前言 写这个的时候我还是有点生气的,因为发现完全按照书上面的步骤来,结果发现不理想,后面还是自己调试才解决了。-_-说多了都是泪,直接进入正文。 软件的操作还是参考我之前的博客。 LED数码管的静态显示与动态显示&#xff0…

nvm 下载 nodejs 速度慢问题解决

1、找到 nvm 的下载目录,在目录下找到 settings.txt 文件 2、打开 settings.txt 文件 ,添加以下代码: node_mirror: https://npm.taobao.org/mirrors/node/ npm_mirror: https://npm.taobao.org/mirrors/npm/添加完成后再去下载即可。

sed 原地替换文件时遇到的趣事

哈喽大家好,我是咸鱼 在文章《三剑客之 sed》中咸鱼向大家介绍了文本三剑客中的 sed sed 全名叫 stream editor,流编辑器,用程序的方式来编辑文本 那么今天咸鱼打算讲一下我在用 sed 原地替换文件时遇到的趣事 sed 让文件属性变了&#xff…

故障诊断模型 | Maltab实现BP神经网络的故障诊断

文章目录 效果一览文章概述模型描述源码设计参考资料效果一览 文章概述 故障诊断模型 | Maltab实现BP神经网络的故障诊断 模型描述 BP(Back Propagation) 算法是神经网络深度学习中最重要的算法之一,了解BP算法可以让我们更理解神经网络深度学习模型训练的本质,属于内功修行的…

什么叫储能能量管理单元EMU?储能能量管理单元EMU功能?储能EMU是什么?储能能量管理系统如何实现一次调频AGC-AVC功能?

一:储能EMU是什么意思?什么叫储能能量管理单元EMU? EMU是能量管理单元的英文缩写 (Energy Management Unit, EMU) EmuPower3300能量管理单元EMU是由广州智昊电气研发配套EsccPower3300储能协调管理器组成对光伏电站的管理,控制,…

Java随机获取某个范围内的随机整数

随机获取某个范围内的随机整数 一、代码 /*** 随机获取某个范围内的随机整数的值* param min 最小值* param max 最大值* return*/public static int randomNum(int min,int max) {// 创建一个Random对象Random random new Random();// 生成指定范围内的随机整数int randomI…

【Linux】 passwd命令使用

passwd命令用来更改使用者的密码。 语法 passwd [选项] [用户名] passwd命令 -Linux手册页 著者 克里斯蒂安加夫顿<gaftonredhat.com> 命令选项及作用 执行令 passwd --help 执行命令结果 参数 -k, --keep-tokens 保持身份验证令牌不过期-d, --delete …

超详细docker学习笔记

关于docker 一、基本概念什么是docker?docker组件&#xff1a;我们能用docker做什么Docker与配置管理&#xff1a;Docker的技术组件Docker资源Docker与虚拟机对比 二、安装docker三、镜像命令启动命令帮助命令列出本地主机上的镜像在远程仓库中搜索镜像查看占据的空间删除镜像…

Lustre文件系统介绍

一、什么是Lustre文件系统 Lustre架构是用于集群的存储架构。Lustre架构的核心组件是Lustre文件系统&#xff0c;它在Linux操作系统上得到支持&#xff0c;并提供了一个符合POSIX *标准的UNIX文件系统接口。 Lustre存储架构用于许多不同类型的集群。它以支持世界上许多最大的…