YOLOv8添加AIFI(Attention-based Intrascale Feature Interaction模块替换SPPF模块)

news2024/12/28 19:38:35

1. 引言

1.1 相关介绍

模块名称:Attention-based Intrascale Feature Interaction
论文名称:RT-DETR: DETRs Beat Yolos on Real-time Object Detection
这是论文中的图,此处将其中的AIFI模块拿过来改进YOLOv8。
在这里插入图片描述

1.2 其他可改进SPPF模块

  1. 如何修改:YOLOv8修改特征金字塔(替换SPPF模块)
  2. 或者看此贴:yolov8改进——SFFP特征金字塔池化修改(详细版)
  3. 常见特征金字塔模块代码实现:常见特征金字塔模块代码实现

2.改进

2.1 AIFI代码

在YOLOv8新版中,已经集成了这个模块,因此,这里不展示如何放置到yolov8中。


如果使用的是老版的YOLOV8代码,nn模块下新建一个AIFI.py即可。
代码如下:

class TransformerEncoderLayer(nn.Module):
    """Defines a single layer of the transformer encoder."""

    def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
        """Initialize the TransformerEncoderLayer with specified parameters."""
        super().__init__()
        self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
        # Implementation of Feedforward model
        self.fc1 = nn.Linear(c1, cm)
        self.fc2 = nn.Linear(cm, c1)

        self.norm1 = nn.LayerNorm(c1)
        self.norm2 = nn.LayerNorm(c1)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.act = act
        self.normalize_before = normalize_before

    @staticmethod
    def with_pos_embed(tensor, pos=None):
        """Add position embeddings to the tensor if provided."""
        return tensor if pos is None else tensor + pos

    def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Performs forward pass with post-normalization."""
        q = k = self.with_pos_embed(src, pos)
        src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
        src = src + self.dropout2(src2)
        return self.norm2(src)

    def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Performs forward pass with pre-normalization."""
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
        return src + self.dropout2(src2)

    def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
        """Forward propagates the input through the encoder module."""
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class AIFI(TransformerEncoderLayer):
    """Defines the AIFI transformer layer."""

    def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
        """Initialize the AIFI instance with specified parameters."""
        super().__init__(c1, cm, num_heads, dropout, act, normalize_before)

    def forward(self, x):
        """Forward pass for the AIFI transformer layer."""
        c, h, w = x.shape[1:]
        pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
        # Flatten [B, C, H, W] to [B, HxW, C]
        x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
        return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()

    @staticmethod
    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
        """Builds 2D sine-cosine position embedding."""
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)

        out_w = grid_w.flatten()[..., None] @ omega[None]
        out_h = grid_h.flatten()[..., None] @ omega[None]

        return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]

2.2 task.py

这里新版YOLOv8也帮我们写好了,因此,不需要改动。


如果是老版的代码,在parse_model方法下,找到一堆elif的地方添加以下代码。

        elif m is AIFI:
            args = [ch[f], *args]

老版如下。并没有AIFI的代码。
在这里插入图片描述

2.3 模型改进

将yolov8.yaml复制一份,新建yolov8-AIFI.yaml,将SPPF模块替换为AIFI即可,如下。
SPPF那一行修改如下: - [-1, 1, AIFI, [1024, 8]] # 9

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, AIFI, [1024, 8]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

3. 运行图

运行效果如下,没有报错。
在这里插入图片描述
提醒:这个对torch版本要求比较高!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1163291.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

6-7 二叉树的非递归遍历 分数 10

文章目录 1.非递归前序遍历1.1C写法及解析1.2本题ac答案 2.非递归中序遍历2.非递归后序遍历2.1栈模拟实现非递归C写法本题ac答案本题flag标记法 2.2逆序思想2.3整体代码 1.非递归前序遍历 1.1C写法及解析 vector<int> preorderTraversal(TreeNode* root) {vector<in…

数据结构(超详细讲解!!)第十九节 块链串及串的应用

1.定义 由于串也是一种线性表&#xff0c;因此也可以采用链式存储。由于串的特殊性&#xff08;每个元素只有一个字符&#xff09;&#xff0c;在具体实现时&#xff0c;每个结点既可以存放一个字符&#xff0c;也可以存放多个字符。每个结点称为块&#xff0c;整个链表称为块链…

linux杀毒软件ClamAV下载、安装(在线安装、离线安装)

流程图 下载 ClamAVNet 离线安装脚本 #扫描文件路径&#xff0c;程序安装路径&#xff0c;当然也可以全盘扫描&#xff0c;全盘扫描的时候路径设置为"/"即可 scanfile"/home" #分钟 小时 日 月 年, 例:0 0 * * * 表示每天0时0秒 scantime"0 0 * * *…

2023/11/2 JAVA学习

接口里面只有这两个东西,无构造器,代码块之类的 私有方法可以在接口里的其他默认方法,或私有方法中访问 静态方法,类持有,可直接调用 接口多继承,可以一个接口继承其他几个接口把几个接口合并成一个接口 先创建外部类,再创建成员内部类 在外部类中无法直接访问内部类的方法变量…

04 训练 windows环境下调用GPU资源做模型训练加速示例

笔者有一台windows电脑,要想在训练yolo模型的时候提升速度,可以按照笔者本文的示例进行。 1、检查可用GPU资源 可以在设备管理器中检查电脑中是否含有GPU设备,如下图所示,可以在设备管理器中检查显卡信息,证明我们有GPU资源可以在训练模型的时候调用。 2、核对显卡算力 …

TIME_WAIT相关知识

四次挥手 这是TCP四次握手的过程图。 TCP 连接终止时&#xff0c;主机 1 先发送 FIN 报文&#xff0c;主机 2 进入 CLOSE_WAIT 状态&#xff0c;并发送一个 ACK 应答&#xff0c;同时&#xff0c;主机 2 通过 read 调用获得 EOF&#xff0c;并将此结果通知应用程序进行主动关闭…

uniapp app端选取(上传)多种类型文件

这里仅记录本人一些遇到办法&#xff0c;后台需要file对象&#xff0c;而App端运行在jsCore内&#xff0c;并非浏览器环境&#xff0c;并没有File类&#xff0c;基本返回的都是blob路径&#xff0c;uni-file-picker得app端只支持图片和视频&#xff0c;我这边需求是音视频都要支…

浏览器请求http地址,自动跳转成https

谷歌浏览器&#xff1a; 点击url地址左侧的锁&#xff0c;选择【网站设置】 点击【隐私和安全】&#xff0c;将【不安全内容】改为允许&#xff0c;然后刷新即可

协力共创智能未来:乐鑫 ESP RainMaker 云方案线下研讨会圆满落幕

近日&#xff0c;乐鑫 ESP RainMaker 云方案线下研讨会&#xff08;深圳&#xff09;在亚马逊云科技与合作伙伴嘉宾的支持下成功举办&#xff0c;吸引了众多来自智能家电、照明电工、能源和宠物等行业的品牌客户、方案商和制造商。研讨会围绕如何基于乐鑫 ESP RainMaker 硬件连…

关于客户旅程地图:你需要知道的一切

客户旅程地图&#xff08;Customer Journey Map&#xff09;是一种工具&#xff0c;用于可视化和理解客户与品牌或产品互动的全过程。它以客户的角度展示了他们的互动&#xff0c;从最初的意识阶段到购买决策&#xff0c;再到购买和售后支持。客户旅程地图有助于企业深入了解客…

什么是物流RPA?物流RPA解决什么问题?物流RPA实施难点在哪里?

RPA指的是机器人流程自动化&#xff0c;它是一套模拟人类在计算机、平板电脑、移动设备等界面执行任务的软件。通过RPA&#xff0c;可以自动完成重复性、繁琐的工作&#xff0c;提高工作效率和质量&#xff0c;降低人力成本。RPA适用于各种行业和场景&#xff0c;例如财务、人力…

Redis基础教程

Redis基础教程 Redis介绍 官方网站&#xff1a;https://redis.io/ Redis是一种键值型的NoSql数据库&#xff0c;这里有两个关键字&#xff1a; 键值型&#xff1a;Redis中存储的数据都是以key、value对的形式存储NoSql&#xff1a;相对于传统关系型数据库而言&#xff0c;有…

关键词搜索亚马逊商品数据接口(标题|主图|SKU|价格|优惠价|掌柜昵称|店铺链接|店铺所在地)

亚马逊提供了API接口来获取商品数据。其中&#xff0c;关键词搜索亚马逊商品接口&#xff08;item_search-按关键字搜索亚马逊商品接口&#xff09;可以用于获取按关键字搜索到的商品数据。 通过该接口&#xff0c;您可以使用API Key和API Secret来认证身份&#xff0c;并使用…

Seata 四种事务模式

Seata 是一款开源的分布式事务解决方案&#xff0c;致力于提供高性能和简单易用的分布式事务服务。Seata 将为用户提供了 AT、TCC、SAGA 和 XA 事务模式&#xff0c;为用户打造一站式的分布式解决方案。 全文参考文献&#xff1a;中文文档 TC (Transaction Coordinator) - 事务…

Redis入门指南学习笔记(2):常用数据类型解析

一.前言 本文主要介绍Redis中包含几种主要数据类型&#xff1a;字符串类型、哈希类型、列表类型、集合类型和有序集合类型。 二.字符串类型 字符串类型是Redis中最基本的数据类型&#xff0c;它是其他4种数据类型的基础&#xff0c;其他数据类型与字符串类型的差别从某种角度…

HALCON的综合应用案例【01】: 3D 算法处理在 Visual Studio 2019 C# 环境中的集成实例

前言: HALCON 为一款比较流行的商业视觉处理软件,他提供了多种开发的模式,可以在HALCON中开发,也可以将HALCON的设计通过导出库的形式集成到其他开发环境里面,以方便系统集成。本文为笔者自己的一个3D 视觉检测项目,利用HALCON的3D 库开发算法,然后,将算法集成到 MS-V…

指挥通信车360度3d虚拟互动展示系统的优势及特点

通信车是装有通信装备&#xff0c;用于保障通信联络的专用车辆&#xff0c;用于偏僻/特殊环境下的机动通信。并且机动通信局装备通常分为应急综合通信车、网络管理车、程控电话车、自适应跳频电台车、数字扩频接力车、散射通信车、卫星通信车、光缆引接车、线缆收放车和通信电源…

医疗数据可视化大屏:重构医疗决策的未来

医疗行业一直是信息密集型领域之一&#xff0c;它的复杂性不仅在于患者病历和医疗数据的海量积累&#xff0c;还包括了病情诊断、医疗资源分配、病患治疗等多层次的挑战。随着信息技术的不断发展&#xff0c;医疗数据可视化大屏成为了一种创新性的工具&#xff0c;它为医疗管理…

Linux学习笔记之一(计算机网络基础)

Linux learning note 1、计算机网络1.1、IP地址和MAC地址1.2、NAT、端口1.3、动态IP、静态IP、DHCP1.4、子网掩码、网关地址、DNS服务器1.5、TCP、UDP、ftp、http 2、虚拟机的网络管理2.1、桥接模式2.2、NAT模式2.3、仅主机模式2.4、总结 1、计算机网络 1.1、IP地址和MAC地址 …

python:将多个9波段影像tif文件转成numpy格式保存

作者:CSDN @ _养乐多_ 最近有粉丝问,如何将多个9波段的Aster影像tif文件转成numpy格式保存,然后输入网络进去训练。本文提供了两种思路和代码。 结果如下图所示, 文章目录 一、简单方法(分两步)二、端到端方法(一步到位)一、简单方法(分两步) 先将所有的多波段影像…