YOLOv8如何添加注意力模块?

news2024/9/27 23:22:37

分为两种:有参注意力和无参注意力。
eg:
有参:

import torch
from torch import nn

class EMA(nn.Module):
    def __init__(self, channels, factor=8):
        super(EMA, self).__init__()
        self.groups = factor
        assert channels // self.groups > 0
        self.softmax = nn.Softmax(-1)
        self.agp = nn.AdaptiveAvgPool2d((1, 1))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
        self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
        self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        b, c, h, w = x.size()
        group_x = x.reshape(b * self.groups, -1, h, w)  # b*g,c//g,h,w
        x_h = self.pool_h(group_x)
        x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
        hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
        x_h, x_w = torch.split(hw, [h, w], dim=2)
        x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
        x2 = self.conv3x3(group_x)
        x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x12 = x2.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
        x22 = x1.reshape(b * self.groups, c // self.groups, -1)  # b*g, c//g, hw
        weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
        return (group_x * weights.sigmoid()).reshape(b, c, h, w)

无参:

import torch
import torch.nn as nn


class SimAM(torch.nn.Module):
    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()

        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()

        n = w * h - 1

        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5

        return x * self.activaton(y)

1、在nn文件夹下新建attention.py文件,把上面俩代码放进去
在这里插入图片描述

2、在tasks.py文件里面导入俩函数
在这里插入图片描述
3、在解析函数里面添加解析代码
在这里插入图片描述
c1:上一层的输出通道数,也是这一层的输入通道数
C2:该层的输出通道数,即将成为下一层的输入通道数
args[]:每个带参数的模块,都要指定这个东西,这个包括[c1,c2,剩下的参数],然后传给该层的模块,有些模块不需要额外参数,就只传一个输出通道数给这一层就行
切记!!!C2是这一层的输出通道数,而args[]里的输入输出通道数是给模块的
4、新建模型配置文件
在这里插入图片描述
4、快速验证配置文件,新建main.py文件,然后运行

from ultralytics import YOLO

if __name__=='__main__':
    print('11111111111')
    model=YOLO('/home/xxxxxxxx/v8/yolov8-main/ultralytics/models/v8/yolov8-att.yaml')

在这里插入图片描述
5、如果想修改这个参数,传进来
在这里插入图片描述
6、配置文件改也行,传进去
在这里插入图片描述
7、总结:放进attention.py,接着在tasks.py里注册,接着解析函数添加(有通道无通道),模型配置文件替换

8、第二种:在4、6、9后面加
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1151269.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

CondaError_ Downloaded bytes did not match Content-Length

问题 使用anaconda下载包文件时,出现了CondaError: Downloaded bytes did not match Content-Length的错误 CondaError: Downloaded bytes did not match Content-Lengthurl: https://conda.anaconda.org/pytorch/win-64/pytorch-2.1.0-py3.11_cuda11.8_cudnn8_0.…

二维码智慧门牌管理系统升级,打造高效事件处理流程

文章目录 前言一、二维码智慧门牌管理系统的升级目标二、事件处理流程优化三、升级带来的好处 前言 随着城市化的不断推进,城市管理面临越来越多的挑战。为了更好地解决这些问题,许多城市已经开始采用二维码智慧门牌管理系统。这个系统不仅可以提高城市…

操作系统第一章-第三章大题_期末考试_详细易考

1.ABC三道作业如下表所示: 作业输入CPU输出A1505050B10060100C806050 (1) 计算在单道环境下运行时CPU的利用率;(2分) (2) 假设计算机系统中具有一个CPU、三个通道,画出ABC三道作业并发执行的情况图,并计算CPU利用率。(12分) 问题分析: c p u 利用率 c p u 有效…

Python对象(Object)与类型(Type)的关系

Object与Type 1、Object与Type概述2、Object与Type的关系 1、Object与Type概述 对象(Object)和类型(Type)是Python中两个最最基本的概念,它们是构筑Python语言大厦的基石 所有的数据类型,值,变…

[BUUCTF NewStarCTF 2023 公开赛道] week4 crypto/pwn

再补完这个就基本上完了. crypto RSA Variation II Schmidt-Samoa密码系统看上去很像RSA,其中Npqq, 给的eN给了d from secret import flag from Crypto.Util.number import *p getPrime(1024) q getPrime(1024)N p*p*qd inverse(N, (p-1)*(q-1)//GCD(p-1, q-1))m bytes…

cause: java.lang.numberformatexception: for input string

一个十分粗心的错误 我本来想要写的是name不为空,并且不为空字符串,结果不小心写成了空格! 解决方案:将空格改为空字符串即可

JMeter的使用——傻瓜式学习【中】

目录 前言 1、JMeter参数化 1.1、什么是参数化 1.2、用户定义的变量 1.2.1、什么时候使用用户定义的变量 1.2.2、使用“用户定义的变量”进行参数化的步骤: 1.2.3、案例 1.3、用户参数 1.3.1、什么时候使用用户参数? 1.3.2、使用“用户参数”进…

交叉编译工具链(以STM32MP1为例)

1.什么是交叉编译工具链? 在一个系统上进行编译,在另一个系统上进行执行 2.STM32MP1交叉编译工具链 3.交叉编译器内容 4.两种工具链模式 5.两种链接模式 6.工具使用 注意:OpenSTLinux已经提供了编译框架,不需要命令行手工编译 …

Spring Cloud 实战 | 解密Feign底层原理,包含实战源码

专栏集锦,大佬们可以收藏以备不时之需 Spring Cloud实战专栏:https://blog.csdn.net/superdangbo/category_9270827.html Python 实战专栏:https://blog.csdn.net/superdangbo/category_9271194.html Logback 详解专栏:https:/…

在VM虚拟机上安装centos并了解Linux常用命令

一. centos安装 新建一个虚拟机,使用ISO映像文件(在浏览器上直接搜索阿里云镜像站,下载合适的镜像文件) 安装后设置密码然后重启 重启后输入账号和密码 查看IP 输入命令: vi ifcfg-ens33,进入编辑界面&a…

物联网AI MicroPython传感器学习 之 PAJ7620手势识别传感器

学物联网,来万物简单IoT物联网!! 一、产品简介 手势识别传感器PAJ7620u2是一款集成3D手势识别和运动跟踪为一体的交互式传感器,传感器可以在有效范围内识别手指的顺时针/逆时针转动方向和手指的运动方向等。它可以识别13种手势&a…

STM32的RTC模块的难点推导

在 S T M 32 STM32 STM32的 R e a l t i m e c l o c k , R T C Real\quad time\quad clock,RTC Realtimeclock,RTC模块中有一些功能点不太好理解,下面我根据我自己对这些功能难点的理解来做一些推导并记录如下。 首先来看一下平滑数字校准。假设我们目前的 R …

万字解析设计模式之原型模式与建造者模式

一、原型模式 1.1概述 原型模式是一种创建型设计模式,其目的是使用已有对象作为原型来创建新的对象。原型模式的核心是克隆,即通过复制已有对象来创建新对象,而不是通过创建新对象的过程中独立地分配和初始化所有需要的资源。这种方式可以节…

CMT2310A一款低功耗高性能Sub-1GHz射频收发器芯片

CMT2310A是一款超低功耗,高性能,适用于各种113至960 MHz无线应用的00K,(G)FSK 和4(G)FSK 射频收发器。它是CMOSTEK NextGenRFTM 射频产品线的一部分,这条产品线包含完整的发射器,接收器和收发器。CMT2310A的高集成度,简…

npm package.json属性详解

npm package.json属性详解 概述 package.json必须是一个严格的json文件,而不仅仅是js里边的一个对象。其中很多属性可以通过npm-config来生成 name package.json中最重要的属性是name和version两个属性,这两个属性是必须要有的,否则模块就…

【机器学习(二) 线性代数基础I(Linear Algebra Foundations)】

机器学习(二) 线性代数基础I(Linear Algebra Foundations) 这一节主要介绍一些线性代数的基础。 目录 机器学习(二) 线性代数基础I(Linear Algebra Foundations)1. 向量 Vectors2. 复杂度 Complexity3.线…

基于3D点云的语义分割模型调研(最新更新2023.10.30)

文章目录 3D点云分割数据集点云模型的评价指标3D点云语义分割方法发展PointSIFT模型的效果 https://blog.csdn.net/toCVer/article/details/126265782 基于深度学习的三维点云分割综述 3D点云分割数据集 传统的点云分割方法包括基于边缘检测的方法、基于区域增长的算法、基于特…

【Linux】:Linux开发工具之Linux编辑器vim的使用

🔫1.Linux编辑器-vim使用 📤 vi/vim的区别简单点来说,它们都是多模式编辑器,不同的是vim是vi的升级版本,它不仅兼容vi的所有指令,而且还有一些新的特性在里面。例如语法加亮,可视化操作不仅可以…

deepsort算法 卡尔曼滤波 匈牙利算法

目标追踪最核心的两个算法就是卡尔曼滤波和匈牙利算法算法。 卡尔曼滤波:根据当前帧中的轨迹预测下一帧的轨迹。匈牙利算法:将预测的目标位置与检测到的目标位置进行匹配,实现对目标的准确跟踪。Sort算法 Sort算法分为以下几个步骤: 1.卡尔曼…

AS/400-物理文件-02

物理文件 - Physical file Physical file物理文件中的条目级别相关命令 Physical file 简介物理文件 这是一个文件。包含预定义的结构化格式的数据。它是PF类型。通过使用CRTPF命令创建PF。PF中包含的字段的最大数量为8000。最多包含120个关键字段。 PF 的结构如下 TYPE SPECIF…