mmdetection3增加12种注意力机制

news2025/1/13 10:15:41

在mmdetection/mmdet/models/layers/目录下增加attention_layers.py

import torch.nn as nn
from mmdet.registry import MODELS
#自定义注意力机制算法
from .attention.CBAM import CBAMBlock as _CBAMBlock
from .attention.BAM import BAMBlock as _BAMBlock
from .attention.SEAttention import SEAttention as _SEAttention
from .attention.ECAAttention import ECAAttention as _ECAAttention
from .attention.ShuffleAttention import ShuffleAttention as _ShuffleAttention
from .attention.SGE import SpatialGroupEnhance as _SpatialGroupEnhance
from .attention.A2Atttention import DoubleAttention as _DoubleAttention
from .attention.PolarizedSelfAttention import SequentialPolarizedSelfAttention as _SequentialPolarizedSelfAttention
from .attention.CoTAttention import CoTAttention as _CoTAttention
from .attention.TripletAttention import TripletAttention as _TripletAttention
from .attention.CoordAttention import CoordAtt as _CoordAtt
from .attention.ParNetAttention import ParNetAttention as _ParNetAttention


@MODELS.register_module()
class CBAMBlock(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CBAMBlock, self).__init__()
        print("======激活注意力机制模块【CBAMBlock】======")
        self.module = _CBAMBlock(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)
    
    
@MODELS.register_module()
class BAMBlock(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(BAMBlock, self).__init__()
        print("======激活注意力机制模块【BAMBlock】======")
        self.module = _BAMBlock(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SEAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SEAttention, self).__init__()
        print("======激活注意力机制模块【SEAttention】======")
        self.module = _SEAttention(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)   
 

@MODELS.register_module()
class ECAAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ECAAttention, self).__init__()
        print("======激活注意力机制模块【ECAAttention】======")
        self.module = _ECAAttention(**kwargs)

    def forward(self, x):
        return self.module(x)  


@MODELS.register_module()
class ShuffleAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ShuffleAttention, self).__init__()
        print("======激活注意力机制模块【ShuffleAttention】======")
        self.module = _ShuffleAttention(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SpatialGroupEnhance(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SpatialGroupEnhance, self).__init__()
        print("======激活注意力机制模块【SpatialGroupEnhance】======")
        self.module = _SpatialGroupEnhance(**kwargs)

    def forward(self, x):
        return self.module(x)   
    

@MODELS.register_module()
class DoubleAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(DoubleAttention, self).__init__()
        print("======激活注意力机制模块【DoubleAttention】======")
        self.module = _DoubleAttention(in_channels, 128, 128,True)

    def forward(self, x):
        return self.module(x)  


@MODELS.register_module()
class SequentialPolarizedSelfAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SequentialPolarizedSelfAttention, self).__init__()
        print("======激活注意力机制模块【Polarized Self-Attention】======")
        self.module = _SequentialPolarizedSelfAttention(channel=in_channels)

    def forward(self, x):
        return self.module(x)   
    
    
@MODELS.register_module()
class CoTAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CoTAttention, self).__init__()
        print("======激活注意力机制模块【CoTAttention】======")
        self.module = _CoTAttention(dim=in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)  

    
@MODELS.register_module()
class TripletAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(TripletAttention, self).__init__()
        print("======激活注意力机制模块【TripletAttention】======")
        self.module = _TripletAttention()

    def forward(self, x):
        return self.module(x)      


@MODELS.register_module()
class CoordAtt(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CoordAtt, self).__init__()
        print("======激活注意力机制模块【CoordAtt】======")
        self.module = _CoordAtt(in_channels, in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)    


@MODELS.register_module()
class ParNetAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ParNetAttention, self).__init__()
        print("======激活注意力机制模块【ParNetAttention】======")
        self.module = _ParNetAttention(channel=in_channels)

    def forward(self, x):
        return self.module(x)  

与attention_layers.py同级目录下创建attention文件夹,在attention文件中放12种注意力机制算法文件。

下载地址:mmdetection3的12种注意力机制资源-CSDN文库icon-default.png?t=N7T8https://download.csdn.net/download/lanyan90/89513979

使用方法:

以faster-rcnn_r50为例,创建faster-rcnn_r50_fpn_1x_coco_attention.py

_base_ = 'configs/detection/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'

custom_imports = dict(imports=['mmdet.models.layers.attention_layers'], allow_failed_imports=False)

model = dict(
    backbone=dict(
        plugins = [
            dict(
                position='after_conv3',
                #cfg = dict(type='CBAMBlock', reduction=16, kernel_size=7)
                #cfg = dict(type='BAMBlock', reduction=16, dia_val=1)
                #cfg = dict(type='SEAttention', reduction=8)
                #cfg = dict(type='ECAAttention', kernel_size=3)
                #cfg = dict(type='ShuffleAttention', G=8)
                #cfg = dict(type='SpatialGroupEnhance', groups=8)
                #cfg = dict(type='DoubleAttention')
                #cfg = dict(type='SequentialPolarizedSelfAttention')
                #cfg = dict(type='CoTAttention', kernel_size=3)
                #cfg = dict(type='TripletAttention')
                #cfg = dict(type='CoordAtt', reduction=32)
                #cfg = dict(type='ParNetAttention')
            )
        ]
    )
)

想使用哪种注意力机制,放开plugins中的注释即可。

以mask-rcnn_r50为例,创建mask-rcnn_r50_fpn_1x_coco_attention.py

_base_ = 'configs/segmentation/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py'
custom_imports = dict(imports=['mmdet.models.layers.attention_layers'], allow_failed_imports=False)

model = dict(
    backbone=dict(
        plugins = [
            dict(
                position='after_conv3',
                #cfg = dict(type='CBAMBlock', reduction=16, kernel_size=7)
                #cfg = dict(type='BAMBlock', reduction=16, dia_val=1)
                #cfg = dict(type='SEAttention', reduction=8)
                #cfg = dict(type='ECAAttention', kernel_size=3)
                #cfg = dict(type='ShuffleAttention', G=8)
                #cfg = dict(type='SpatialGroupEnhance', groups=8)
                #cfg = dict(type='DoubleAttention')
                #cfg = dict(type='SequentialPolarizedSelfAttention')
                #cfg = dict(type='CoTAttention', kernel_size=3)
                #cfg = dict(type='TripletAttention')
                #cfg = dict(type='CoordAtt', reduction=32)
                #cfg = dict(type='ParNetAttention')
            )
        ]
    )
)

用法一样!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1895803.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

语音声控灯:置入NRK3301离线语音识别ic 掌控的灯具新风尚

一、语音声控灯芯片开发背景 我们不难发现,传统的灯具控制方式已难以满足现代人对便捷性和智能化的追求。传统的开关控制方式需要人们手动操作,不仅繁琐且不便,特别是在夜晚或光线昏暗的环境下,更容易造成不便甚至安全隐患。而语音…

Spring学习03-[Spring容器核心技术IOC学习进阶]

IOC学习进阶 Order使用Order改变注入顺序实现Ordered接口,重写getOrder方法来改变自动注入顺序 DependsOn使用 Lazy全局设置-设置所有bean启动时候懒加载 Scopebean是单例的,会不会有线程安全问题 Order 可以改变自动注入的顺序 比如有个animal的接口&a…

海外仓一件代发功能自动化:海外仓WMS系统配置方法

根据数据显示,2014-2019年短短几年之间,跨境电商销售总额增长了160%以上。这为跨境电商商家和海外仓,国际物流等服务端企业都提供了巨大的发展机遇。 然而,作为海外仓,要想服务好跨境电商,仓库作业的每一个…

JAVA进阶学习10

文章目录 一、创建不可变集合二、Stream流2.1 Stream流的获取2.1 Stream流的中间方法2.2 Stream流的终结方法 一、创建不可变集合 意义:如果一个集合中的数据在复制或使用过程中不能修改,或者被其他对象调用时不能改变内部数据,即增加数据的安…

【C++ 】解决 C++ 语言报错:Null Pointer Dereferenc

文章目录 引言 在 C 编程中,空指针解引用(Null Pointer Dereference)是一种常见且危险的错误。当程序试图通过空指针访问内存时,会导致程序崩溃或产生不可预期的行为。本文将详细探讨空指针解引用的成因、检测方法及其预防和解决…

Unity之VS脚本自动添加头部注释Package包开发

内容将会持续更新,有错误的地方欢迎指正,谢谢! Unity之VS脚本自动添加头部注释Package包开发 TechX 坚持将创新的科技带给世界! 拥有更好的学习体验 —— 不断努力,不断进步,不断探索 TechX —— 心探索、心进取&…

【靶机实战】Apache Log4j2命令执行漏洞复现

# 在线靶场 可以通过访问极核官方靶场开启靶机实验:极核靶场 -> 漏洞复现靶场 -> Log4j2-RCE 原文:【靶机实战】Apache Log4j2命令执行漏洞复现 - 极核GetShell (get-shell.com) # 简介 Apache Log4j2 是一个广泛使用的 Java 日志记录库&#…

秋招突击——设计模式补充——简单工厂模式和策略模式

文章目录 引言正文简单工厂模式策略模式策略模式和工厂模式的结合策略模式解析 总结 引言 一个一个来吧,面试腾讯的时候,问了我单例模式相关的东西,自己这方面的东西,还没有看过。这里需要需要补充一下。但是设计模式有很多&…

比赛获奖的武林秘籍:01 如何看待当代大学生竞赛中“卷”“祖传老项目”“找关系”的现象?

比赛获奖的武林秘籍:01 如何看待当代大学生竞赛中“卷”“祖传老项目”“找关系”的现象? 摘要 本文主要分析了大学生电子计算机类比赛中“卷”“祖传老项目”“找关系”的现象,结合自身实践经验,给出了相应的解决方案。 正文 …

1-4 NLP发展历史与我的工作感悟

1-4 NLP发展历史与我的工作感悟 主目录点这里 第一个重要节点:word2vec词嵌入 能够将无限的词句表示为有限的词向量空间,而且运算比较快,使得文本与文本间的运算有了可能。 第二个重要节点:Transformer和bert 为预训练语言模型发…

百日筑基第十一天-看看SpringBoot

百日筑基第十一天-看看SpringBoot 创建项目 Spring 官方提供了 Spring Initializr 的方式来创建 Spring Boot 项目。网址如下: https://start.spring.io/ 打开后的界面如下: 可以将 Spring Initializr 看作是 Spring Boot 项目的初始化向导&#xff…

【Unity navigation面板】

【Unity navigation面板】 Unity的Navigation面板是一个集成在Unity编辑器中的界面,它允许开发者对导航网格(NavMesh)进行配置和管理。 Unity Navigation面板的一些关键特性和功能: 导航网格代理(NavMesh Agent&…

手动访问mongo和ES插入和查询

1、手动访问mongo 1.1、mongo连接数据库 1.2、mongo插入和查询 db.hmf_test.insert( { "aoeId": "1", "aoeAes": "吴秀梅", "aoeSm4": "北京xx网络技术有限公司.", "aoeSm4_a": "…

针对某客户报表系统数据库跑批慢进行性能分析及优化

某客户报表系统数据库跑批时间过长,超出源主库较多,故对其进行了分析调优,目前状态如下: 1、业务连接的rac的scanip,因为负载均衡将跑批的连接连接到了多个计算节点导致节点间通讯成本较高,故速率缓慢&…

Websocket通信实战项目(图片互传应用)+PyQt界面+python异步编程(async) (上)服务器端python实现

Rqtz : 个人主页 ​​ 共享IT之美,共创机器未来 ​ Sharing the Beauty of IT and Creating the Future of Machines Together 目录 项目背景 ​编辑​专有名词介绍 服务器GUI展示 功能(位置见上图序号) 客户端GUI展示(h5cssjs&#xf…

allure如何记录操作步骤,操作步骤不写在测试用例中,同样可以体现在allure报告,如何实现

嗨,我是兰若,今天写完用例,在运行用例并且生成报告的时候,发现报告里面没有具体的操作步骤,这可不行,如果没有具体的操作步骤的话,用例运行失败了,要怎么知道问题是出现在哪一个步骤…

Android studio开发入门教程详解(复习)

引言 本文为个人总结Android基础知识复习笔记。如有不妥之处,敬请指正。后续将持续更新更多知识点。 文章目录 引言UITextView文本基本用法实际应用常用属性和方法 Button按钮处理点击事件 EditText输入框基本属性高级特性 ImageView图片ImageView的缩放模式 Prog…

adobe pdf设置默认打开是滚动而不是单页视图

上班公司用adobe pdf,自己还不能安装其它软件。 每次打开pdf,总是默认单页视图,修改滚动后,下次打开又 一样,有时候比较烦。 后面打开编辑->首选项, 如下修改,下次打开就是默认滚动了

数据结构 —— 图的遍历

数据结构 —— 图的遍历 BFS(广度遍历)一道美团题DFS(深度遍历) 我们今天来看图的遍历,其实都是之前在二叉树中提过的方法,深度和广度遍历。 在这之前,我们先用一个邻接矩阵来表示一个图&#…