Pytorch注意力机制应用到具体网络方法(闭眼都会版)

news2024/12/26 3:49:17

文章目录

  • 以YoloV4-tiny为例
    • 要加入的注意力机制代码
    • 模型中插入注意力机制

以YoloV4-tiny为例

在这里插入图片描述
解释一下各个部分:

  • 最左边这部分为主干提取网络,功能为特征提取
  • 中间这边部分为FPN,功能是加强特征提取
  • 最后一部分为yolo head,功能为获得我们具体的一个预测结果

需要明白几个点:

  • 注意力机制模块是一个即插即用的模块,理论上是可以添加到任何一个特征图后面
  • 但是,不建议添加到主干部分(即最左边的那部分),主干部分所用的特征是我们后面处理所用的基础,故不建议添加到主干部分
  • 如果添加到主干部分,由于注意力机制模块 它的权值模块是随机初始化的,那主干部分的权值就被破坏了,最开始提取出来的特征就不好用了。
  • 故建议把注意力机制模块添加到主干以外的部分

本节把注意力机制添加到加强网络里面,即上图的中间部分。
添加注意力机制可以添加到上图标注的部分。

要加入的注意力机制代码

这一部分为要加入的注意力机制模块,文件名为attention.py

import torch
from torch import nn
# 通道注意力机制
class channel_attention(nn.Module):
    def __init__(self,channel,ration=16):   #因为要进行全连接,故需要传入通道数量,及缩放比例
        super(channel_attention,self).__init__()  #初始化
        #定义最大池化层
        self.max_pool = nn.AdaptiveMaxPool2d(1) #输出层的高和宽是1
        #定义平均池化
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.fc = nn.Sequential(
            #定义第一次全连接
            nn.Linear(channel,channel // ration ,False),
            nn.ReLU(),
            # 定义第二次全连接
            nn.Linear(channel//ration,channel,False)
        )
        #由于图中的通道注意力机制是连个全连接层相加之后再取sigmoid
        self.sigmoid=nn.Sigmoid()

    #前传部分
    def forward(self,x):
        b,c,h,w=x.size()
        #首先对输入进来的x先进行一个全局最大池化 在进行一个全局平均池化
        max_pool_out=self.max_pool(x).view([b,c])
        avg_pool_out=self.avg_pool(x).view([b,c])
        #然后对两次池化后的结果用共享的全连接层fc进行处理
        max_fc_out=self.fc(max_pool_out)
        avg_fc_out=self.fc(avg_pool_out)
        #最后将上面的两个结果进行相加
        out=max_fc_out + avg_fc_out
        out=self.sigmoid(out).view([b,c,1,1])
        #print(out)
        return out * x
# 空间注意力机制
class spacial_attention(nn.Module):
    def __init__(self,kernel_size=7):   #空间注意力没有通道数,故不用传入channel和ration
        #但是空间注意力会进行一次卷积,故我们需要关注卷积核大小,一般为3或7
        super(spacial_attention,self).__init__()  #初始化
        padding=7//2  #卷积核大小整除输入通道数
        self.conv=nn.Conv2d(2,1,kernel_size,1,padding,bias=False)
        #由图可知输入通道数是2,输出通道数为1,卷积核大小默认设置为7,步长为1,因为不需要压缩特征层阿高和宽

        #由于图中的通道注意力机制是连个全连接层相加之后再取sigmoid
        self.sigmoid=nn.Sigmoid()
    #空间注意力机制前传部分
    def forward(self,x):
        b,c,h,w=x.size()
        max_pool_out,_= torch.max(x,dim=1,keepdim=True)#需要把通道这一维度保留下来,故设置keepdim为True
        #对于pytorch来讲,它的通道是在第一维度,也就是batchsize后面的那个维度故定义dim为1
        mean_pool_out = torch.mean(x,dim = 1,keepdim=True)
        #对最大值和平均值进行一个堆叠
        pool_out = torch.cat([max_pool_out, mean_pool_out],dim=1)
        #对堆叠后的结果取一个卷积
        out=self.conv(pool_out)
        out=self.sigmoid(out)
        print(out)
        return out * x

#把空间注意力机制和通道注意力机制进行一个融合
class Cbam(nn.Module):
    def __init__(self,channel,ratio=16,kernel_size=7):
        super(Cbam,self).__init__()
        #调用已经定义好的2个注意力机制
        self.channel_attention=channel_attention(channel,ratio)
        self.spacial_attention = spacial_attention(kernel_size)
    #融合后机制的前传部分
    def forward(self,x):
        x=self.channel_attention(x)
        x=self.spacial_attention(x)
        return x

在模型文件(yolo.py)中,首行添加如下部分

from .attention import se_block,cbam_block,eca_block
attention_blocks=[se_block,cbam_block,eca_block]
为何要设置成上面的形式?
为了方便调用,到时候可以直接编写下面的代码调用具体的注意力机制模块
attention_blocks[0]

之后,需要找到yolo.py里面的模型主体部分,大概形式如下代码

class YoloBody(nn.Module):
	def __init__(self,anchors_mask,num_classes,phi=0)
	#在原来的代码上只是添加了phi,代表我们选用的注意力机制模块,默认情况下为0
		super(YoloBody, self).__init__()
	        self.backbone       = darknet53_tiny(None)
	
	        self.conv_for_P5    = BasicConv(512,256,1)
	        self.yolo_headP5    = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)
	
	        self.upsample       = Upsample(256,128)
	        self.yolo_headP4    = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)
	        #下面这部分为自己填写
	     	self.phi    = phi  #这个是自己添加的
	        if 1 <= self.phi and self.phi <= 3:
            self.feat1_att      = attention_block[self.phi - 1](256)  #通道数为256
            self.feat2_att      = attention_block[self.phi - 1](512)#通道数为512
            self.upsample_att   = attention_block[self.phi - 1](128)#通道数为128
            #通道数到底是多少看这个模型的前传部分的通道数为多少
    def forward(self, x):
		#---------------------------------------------------#
		#   生成CSPdarknet53_tiny的主干模型
		#   feat1的shape为26,26,256
		#   feat2的shape为13,13,512
		#---------------------------------------------------#
		feat1, feat2 = self.backbone(x)
		#下面代码为自己填写
		if 1 <= self.phi and self.phi <= 3:#如果满足条件就添加具体的注意力机制
		    feat1 = self.feat1_att(feat1)
		    feat2 = self.feat2_att(feat2)
		#下面代码模型自带
		# 13,13,512 -> 13,13,256
		P5 = self.conv_for_P5(feat2)
		# 13,13,256 -> 13,13,512 -> 13,13,255
		out0 = self.yolo_headP5(P5) 
		
		# 13,13,256 -> 13,13,128 -> 26,26,128
		P5_Upsample = self.upsample(P5)
		# 26,26,256 + 26,26,128 -> 26,26,384
		#上面代码模型自带,下面代码自己编写
		if 1 <= self.phi and self.phi <= 3:
		    P5_Upsample = self.upsample_att(P5_Upsample)
		 #下面代码模型自带
		P4 = torch.cat([P5_Upsample,feat1],axis=1)
		
		# 26,26,384 -> 26,26,256 -> 26,26,255
		out1 = self.yolo_headP4(P4)
		
		return out0, out1


模型中插入注意力机制

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2265605.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

交通控制系统中的 Prompt工程:引导LLMs实现高效交叉口管理 !

本研究提出了一种新型的交通控制系统方法&#xff0c;通过使用大型语言模型&#xff08;LLMs&#xff09;作为交通控制器。该研究利用它们的逻辑推理、场景理解和决策能力&#xff0c;实时优化通行能力并提供基于交通状况的反馈。LLMs将传统的分散式交通控制过程集中化&#xf…

产品升级!Science子刊同款ARGs-HOST分析,get!

凌恩生物明星chanpin 抗性宏基因-宿主分析 Science子刊同款分析 数据挖掘更进一步&#xff01; 抗生素的大量使用与滥用使微生物体内编码抗生素抗性的基因在环境中选择性富集&#xff0c;致病菌通过基因突变或者水平基因转移获得抗生素抗性基因后&#xff0c;导致抗生素治疗…

Python8-写一些小作业

记录python学习&#xff0c;直到学会基本的爬虫&#xff0c;使用python搭建接口自动化测试就算学会了&#xff0c;在进阶webui自动化&#xff0c;app自动化 python基础8-灵活运用顺序、选择、循环结构 写一些小练习题目1、给一个半径&#xff0c;求圆的面积和周长&#xff0c;…

四相机设计实现全向视觉感知的开源空中机器人无人机

开源空中机器人 基于深度学习的OmniNxt全向视觉算法OAK-4p-New 全景硬件同步相机 机器人的纯视觉避障定位建图一直是个难题&#xff1a; 系统实现复杂 纯视觉稳定性不高 很难选到实用的视觉传感器 为此多数厂家还是采用激光雷达的定位方案。 OAK-4p-New 为了弥合这一差距…

Diagramming AI: 使用自然语言来生成各种工具图

前言 在画一些工具图时&#xff08;流程图、UML图、架构图&#xff09;&#xff0c;你还在往画布上一个个的拖拽组件来进行绘制么&#xff1f;今天介绍一款更有效率的画图工具&#xff0c;它能够通过简单的自然语言描述来完成一个个复杂的图。 首页 进入官网之后&#xff0c;我…

springboot启动不了 因一个spring-boot-starter-web底下的tomcat-embed-core依赖丢失

这个包丢失了 启动不了 起因是pom中加入了 <tomcat.version></tomcat.version>版本指定&#xff0c;然后idea自动编译后&#xff0c;包丢了&#xff0c;删除这个配置后再也找不回来&#xff0c; 这个包正常在 <dependency><groupId>org.springframe…

“笃威尔数字技术”受邀出席2024 H-Tech Data创新情报论坛!

​ 2024年12月20日&#xff0c;以“创新情报 向新而行”为主题的2024 H-Tech Data创新情报论坛暨创新情报专业委员会成立仪式在深圳成功举办。本次大会由中国科学技术情报学会主办&#xff0c;由深圳国家高新技术产业创新中心牵头承办&#xff0c;旨在围绕技术赋能、场景应用、…

Android Studio 的革命性更新:Project Quartz 和 Gemini,开启 AI 开发新时代!

&#x1f31f; Android Studio 的革命性更新&#xff1a;Project Quartz 和 Gemini&#xff0c;开启 AI 开发新时代&#xff01; 在这个技术飞速发展的时代&#xff0c;Android 开发者们迎来了两项重大更新&#xff1a;Project Quartz 和 Gemini。这不仅仅是更新&#xff0c;而…

kkfileview代理配置,Vue对接kkfileview实现图片word、excel、pdf预览

kkfileview部署 官网&#xff1a;https://kkfileview.keking.cn/zh-cn/docs/production.html 这个是官网部署网址&#xff0c;这里推荐大家使用docker镜像部署&#xff0c;因为我是直接找运维部署的&#xff0c;所以这里我就不多说明了&#xff0c;主要说下nginx代理配置&am…

RT-DETR学习笔记(2)

七、IOU-aware query selection 下图是原始DETR。content query 是初始化为0的label embedding, position query 是通过nn.Embedding初始化的一个嵌入矩阵&#xff0c;这两部分没有任何的先验信息&#xff0c;导致DETR的收敛慢。 RT-DETR则提出要给这两部分&#xff08;conten…

iOS 苹果开发者账号: 查看和添加设备UUID 及设备数量

参考链接&#xff1a;苹果开发者账号下添加新设备UUID - 简书 如果要添加新设备到 Profiles 证书里&#xff1a; 1.登录开发者中心 Sign In - Apple 2.找到证书设置&#xff1a; Certificate&#xff0c;Identifiers&Profiles > Profiles > 选择对应证书 edit &g…

汽车IVI中控开发入门及进阶(47):CarPlay开发

概述: 车载信息娱乐(IVI)系统已经从仅仅播放音乐的设备发展成为现代车辆的核心部件。除了播放音乐,IVI系统还为驾驶员提供导航、通信、空调、电源配置、油耗性能、剩余行驶里程、节能建议和许多其他功能。 ​ 驾驶座逐渐变成了你家和工作场所之外的额外生活空间。2014年,…

Oracle、ACCSEE与TDMS的区别

Oracle、ACCSEE和TDMS都是不同类型的数据管理和存储工具&#xff0c;它们各自有独特的用途、结构和复杂性。Oracle是一个功能强大的关系型数据库管理系统&#xff0c;适用于大规模企业级应用&#xff0c;支持复杂查询和事务管理。ACCSEE主要应用于实时数据采集和过程监控&#…

商场消防电气控制系统设计(论文+源码)

1系统的功能及方案设计 如图2.1所示为本次设计的整体框图&#xff0c;其中单片机部分采用ST89C52来负责协调各个模块&#xff1b;液晶选择LCD1602液晶屏来显示信息;温度传感器选择PT1000进行温度的检测&#xff1b;烟雾传检测选择MQ2烟雾传感器&#xff1b;CO2检测选择CCS811模…

7. petalinux 根文件系统配置(package group)

根文件系统配置&#xff08;Petalinux package group&#xff09; 当使能某个软件包组的时候&#xff0c;依赖的包也会相应被使能&#xff0c;解决依赖问题&#xff0c;在配置页面的help选项可以查看需要安装的包 每个软件包组的功能: packagegroup-petalinux-audio包含与音…

2024年12月一区SCI-加权平均优化算法Weighted average algorithm-附Matlab免费代码

引言 本期介绍了一种基于加权平均位置概念的元启发式优化算法&#xff0c;称为加权平均优化算法Weighted average algorithm&#xff0c;WAA。该成果于2024年12月最新发表在中JCR1区、 中科院1区 SCI期刊 Knowledge-Based Systems。 在WAA算法中&#xff0c;加权平均位置代表当…

操作系统(23)外存的存储空间的管理

一、外存的基本概念与特点 定义&#xff1a;外存&#xff0c;也称为辅助存储器&#xff0c;是计算机系统中用于长期存储数据的设备&#xff0c;如硬盘、光盘、U盘等。与内存相比&#xff0c;外存的存储容量大、成本低&#xff0c;但访问速度相对较慢。特点&#xff1a;外存能够…

【202】仓库管理系统

-- 基于springboot仓库管理系统设计与实现 开发技术栈: 开发语言 : Java 开发软件 : Eclipse/MyEclipse/IDEA JDK版本 : JDK8 后端技术 : SpringBoot 前端技术 : Vue、Element、HTML、JS、CsS、JQuery 服务器 : Tomcat8/9 管理包 : Maven 数据库 : MySQL5.x/8 数据库工具 : …

iDP3复现代码数据预处理全流程(二)——vis_dataset.py

vis_dataset.py 主要作用在于点云数据的可视化&#xff0c;并可以做一些简单的预处理 关键参数基本都在 vis_dataset.sh 中定义了&#xff0c;需要改动的仅以下两点&#xff1a; 1. 点云图像保存位置&#xff0c;因为 dataset_path 被设置为了绝对路径&#xff0c;因此需要相…

重温设计模式--1、组合模式

文章目录 1 、组合模式&#xff08;Composite Pattern&#xff09;概述2. 组合模式的结构3. C 代码示例4. C示例代码25 .应用场景 1 、组合模式&#xff08;Composite Pattern&#xff09;概述 定义&#xff1a;组合模式是一种结构型设计模式&#xff0c;它允许你将对象组合成…