YOLOv5、v8改进:CrissCrossAttention注意力机制

news2025/1/11 14:43:05

目录

1.简介

2. yolov5添加方法:

2.1common.py构建CrissCrossAttention模块

2.2yolo.py中注册 CrissCrossAttention模块

2.3修改yaml文件。


1.简介

这是ICCV2019的用于语义分割的论文,可以说和CVPR2019的DANet遥相呼应。

和DANet一样,CCNet也是想建模像素之间的long range dependencies,来做更加丰富的contextual information,来补充特征图,以此来提升语义分割的性能。但是和DANet不一样,CCNet仅考虑空间分辨上的建模,不考虑建模通道之间的联系。作者提出的模块,criss-cross attention module,针对空间维度上的建模,对于空间位置的一个点u,仅考虑建模和u在同一行或者同一列的其他位置的像素之间的联系。相比DANet,能减少很多计算量,但是不足的是,对一个点的特征向量,尽管有同一行或者同一列的其他像素信息作为补充,对于语义分割任务,contextual information仍然是稀疏的(sparse),因为语义分割更在意一个像素和它周围的一些像素的关系。针对这个问题,作者提出了recurrent criss-cross attention module,来建模一个像素和全局所有像素的关系。方式是通过重复criss-cross attention module来实现的。这些module也是参数shared的。

同样是建模空间维度的pixel-wise contextual information,CCNet的计算量相较于self attention,可小太多了。一个CC module,要处理的是一个像素点和同一行、同一列一共(H+W-1)这么多的像素,那么应用在所有像素上,计算量就是O(HW(H+W-1))。回顾DANet的空间注意力分支(position attention module),每一个像素就要和(HW)个像素建模之间的联系,应用在所有相素,计算量就是O(HW*(H*W))。
通过递归的方式用CC module,可以对一个像素捕捉到全局的contextual information,提到了语义分割任务的效果。
个人看法,简单且有效的,就是极其优秀的方法,CCNet就属于这一类方法。
 

在这里插入图片描述

1.首先一个原图送进backbone,这个backbone是修改过的,把最后两个stage的stride改为1,同时应用空洞卷积来增大感受野。得到的特征图是原图的1/8.

2.然后经过1*1的卷积降维。得到H

3.H经过一个criss-cross attention module 得到H ′ 这个时候,H’中的每个位置都捕捉到了和u在同一行或者同一列的context information

4.H’经过一个相同结构、相同参数的cc module,得到了H’’。在H‘’中的每个位置,捕捉的是全局性的contextual information
5..最后经过一个分割层输出最后的预测结果。
在这里插入图片描述

 

之前改进增加了很多注意力机制的方法,包括比较常规的SE、CBAM等,本文加入CrissCrossAttention注意力机制,该注意力机制为应用在语义分割中的模块,用于可以让网络更加关注待检测目标,提高检测效果

基本原理:

       语义分割的Criss-Cross网络(CCNet)的细节。我们首先介绍了CCNet的总体框架。然后,将介绍在水平和垂直方向捕获上下文信息的2D交叉注意力模块。为了获取密集的全局上下文信息,我们建议对交叉注意力模块采用循环操作。为了进一步改进RCCA,我们引入了判别损失函数来驱动RCCA学习类别一致性特征。最后,我们提出了同时利用时间和空间上下文信息的三维交叉注意模块。

2. yolov5添加方法:

2.1common.py构建CrissCrossAttention模块

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Softmax


def INF(B,H,W):
     return -torch.diag(torch.tensor(float("inf")).repeat(H),0).unsqueeze(0).repeat(B*W,1,1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention,self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))


    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
        #print(concate)
        #print(att_H) 
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        #print(out_H.size(),out_W.size())
        return self.gamma*(out_H + out_W) + x

2.2yolo.py中注册 CrissCrossAttention模块

elif m is CrissCrossAttention:
    c1, c2 = ch[f], args[0]
    if c2 != no:
        c2 = make_divisible(c2 * gw, 8)
    args = [c1, *args[1:]]

2.3修改yaml文件。

# YOLOAir 🚀, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOAir v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 9
  ]

# YOLOAir v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 13

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 17 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 14], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 20 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 10], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 23 (P5/32-large)
   [-1, 1, CrissCrossAttention, [1024]], #修改

   [[17, 20, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

YOLOv8和v5的改法是一致的

有什么问题可以评论区私聊

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/936541.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DBeaver的安装和使用:windows版

DBeaver官网下载地址:https://dbeaver.io/download/ 下载完成后, 进入傻瓜式安装: 这里会进入重复界面,一样点击下一步即可 选择安装目录,尽量不要选C盘, 我的电脑只有c盘, 没办法 等待安装完成…

linux操作系统的权限的深入学习(未完)

1.Linux权限的概念 Linux下有两种用户:超级用户(root)、普通用户。 超级用户:可以再linux系统下做任何事情,不受限制 普通用户:在linux下做有限的事情。 超级用户的命令提示符是“#”,普通用户…

MVSNet 和 PatchMatchNet 的DTU数据集 几个不同之处 一定要注意

文章目录 1 测试集 数据加载不同2 训练集 数量 分辨率不同 1 测试集 数据加载不同 1.MVSNet 的DTU测试数据集和PatchmatchNet测试数据集不一样; 区别在于数据加载,前者 cams文件最后是最小深度和间隔,后者是最小深度和最大深度。 2 训练集 …

layui框架学习(41:表单模块)

之前的文章《layui框架学习》14-16中介绍了通过预设类及部分layui属性设置表单的外观样式,layui中还提供有表单模块以对表单元素进行各类动态化渲染和相关操作,本文学习并记录表单模块form的常用属性、函数及事件的用法(如果内容已在之前文章…

时序预测 | MATLAB实现SSA-XGBoost(麻雀算法优化极限梯度提升树)时间序列预测

时序预测 | MATLAB实现SSA-XGBoost(麻雀算法优化极限梯度提升树)时间序列预测 目录 时序预测 | MATLAB实现SSA-XGBoost(麻雀算法优化极限梯度提升树)时间序列预测预测效果基本介绍模型描述程序设计参考资料 预测效果 基本介绍 Matlab实现SSA-XGBoost时间序列预测,麻…

VS警告 C4819 该文件包含不能在当前代码页(936)中表示的字符。请将该文件保存为 Unicode 格式以防止数据丢失

1、问题 Microsoft visual studio 2019跑仿真的时候出现该警告,在高级保存选项设置编码为Unicode(UTF-8 无签名)还是会有该警告。 2、解决方法 右键项目,打开属性设置,选中:工程 -> 右键选择"属性" -> C/C ->…

JSON文件读写教程【jsoncpp源码编译】【结尾附三方库下载链接】

目录 1 数据下载(jsoncpp源码)2 文件编译3 测试用例4 下载链接:内容: JSON文件的读取与保存可以使用jsoncpp库来实现,这里介绍该库的下载及编译方法。 1 数据下载(jsoncpp源码) 数据下载:Github地址 图1 github源码示意图 2 文件编译 2.1 点击Download ZIP,下载源码。 …

大数据分析与AI在农业领域的应用

文章目录 数据采集与监测数据分析与预测个性化管理与优化决策支持系统结合物联网技术优势与前景 🎈个人主页:程序员 小侯 🎐CSDN新晋作者 🎉欢迎 👍点赞✍评论⭐收藏 ✨收录专栏:Java知识介绍 ✨文章内容&a…

使用ChatGPT给Python代码写单元测试

先写一个简单的python函数,找chatgpt写单元测试: 有一个python函数,请帮忙写单元测试,函数长这样: def test2(a: list, b: list) -> float:"""计算两个坐标的距离:param a list 格式如&#xff1a…

记一次蓝屏日志

记一次Win 蓝屏日志: 📲引: 虽然说,我是一个在职两年半的程序员,但是对于这个问题其实也和大部分人一样,一脸懵逼🤖 那是一个风和日丽的早上,w开开心心去上班摸鱼🐟&a…

LiteOS qemu realview-pbx-a9 环境搭建与运行

前言 最近打算移植搭建 一些常见的 RTOS 的 qemu 开发学习环境,当前 RT-Thread、FreeRTOS 已经成功运行 qemu,LiteOS 初步验证可以正常 运行 qemu realview-pbx-a9,这里做个记录 首先学习或者研究 RTOS,只是看内核源码&#xff0…

SNN论文总结

Is SNN a great work ? Is SNN a convolutional work ? ANN的量化在SNN中是怎么体现的,和threshold有关系吗,threshold可训练和这个有关吗(应该无关) 解决过发放不发放的问题。 Intuation SNN编码方式 Image to spike patter…

性能优化维度

CPU 首先检查 cpu,cpu 使用率要提升而不是降低。其次CPU 空闲并不一定是没事做,也有可能是锁或者外部资源瓶颈。常用top、vmstat命令查看信息。 vmstat 命令: top: 命令 IO iostat 命令: Memory free 命令: 温馨提示&#xff1a…

Qt应用开发(基础篇)——日历 QCalendarWidget

一、前言 QCalendarWidget类继承于QWidget,是Qt设计用来让用户更直观的选择日期的窗口部件。 时间微调输入框 QCalendarWidget根据年份和月份初始化,程序员也通过提供公共函数去改变他们,默认日期为当前的系统时间,用户通过鼠标和…

《C语言编程环境搭建》工欲善其事 必先利其器

C语言编译器 GCC 系列 GNU编译器套装(英语:GNU Compiler Collection,缩写为GCC),指一套编程语言编译器,常被认为是跨平台编译器的事实标准。原名是:GNU C语言编译器(GNU C Compiler)。 MinGW 又称mingw32 &#xff0c…

C语言练习5(巩固提升)

C语言练习5 选择题 选择题 1&#xff0c;下面代码的结果是&#xff1a;( ) #include <stdio.h> #include <string.h> int main() {char arr[] { b, i, t };printf("%d\n", strlen(arr));return 0; }A.3 B.4 C.随机值 D.5 &#x1f4af;答案解析&#…

Unittest 笔记:unittest拓展生成HTM报告

HTMLTestRunner 是一个unitest拓展可以生成HTML 报告 下载地址&#xff1a;GitHub: https://github.com/defnnig/HTMLTestRunner HTMLTestRunner是一个独立的py文件&#xff0c;可以放在Lib 作为第三方模块使用或者作为项目的一部分。 方式1&#xff1a; 验证是否安装成功&…

Navicat下载安装使用

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

基于oracle数据库存储过程的创建及调用

基于oracle数据库存储过程的创建及调用 教学大纲&#xff1a; PLSQL编程&#xff1a;Hello World、程序结构、变量、流程控制、游标.存储过程&#xff1a;概念、无参存储、有参存储&#xff08;输入、输出&#xff09;.JAVA调用存储存储过程. 1. PLSQL编程 1.1. 概念和目的…

力扣真题:无重复字符的最长子串(三种方法)

这道题我一开始使用了Set加类似滑动窗口的方法&#xff0c;最后解得出来&#xff0c;但效率不尽人意&#xff0c;最后经过几次修改&#xff0c;最终用到是滑动窗口指针数组的方式讲效果达到最优&#xff0c;超过近99%的代码。 1、第一版 class Solution {public int lengthOf…