YOLOv8改进 添加可变形注意力机制DAttention

news2025/1/16 19:59:02

一、Deformable Attention Transformer论文

论文地址:arxiv.org/pdf/2201.00520.pdf

二、Deformable Attention Transformer注意力结构

Deformable Attention Transformer包含可变形注意力机制,允许模型根据输入的内容动态调整注意力权重。在传统的Transformer中,注意力是通过对查询和键向量之间的点积来确定的,然后将输入嵌入的加权和进行计算。然而,这种方法假设了一个刚性的注意力模式,其中每个查询都会参与固定的一组键。

在可变形注意力转换器中,注意力权重使用学习机制进行计算,该机制可以根据输入学习调整注意力模式。这使得模型能够捕捉输入的不同部分之间更复杂的关系,从而在需要建模长距离依赖或捕捉细粒度细节的任务上提供更好的性能。可变形注意力机制DAttention引入了额外的可学习参数,用于计算注意力权重。这些参数通过反向传播在训练过程中进行学习,使得模型能够调整注意力模式以更好地适应输入数据。

三、代码实现

1、在官方的yolov8包中ultralytics\ultralytics\nn\modules\__init__.py文件中的from .conv import和__all__中加入注意力机制DAttention。

2、在ultralytics\ultralytics\nn\modules\conv.py文件中上边引用包:

import einops
from timm.models.layers import trunc_normal_
import torch.nn.functional as F

__all__中同样添加DAttention:

并在该conv.py文件中输入DAttention的代码:

 #########            添加DAttention        #####################
class LayerNormProxy(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = einops.rearrange(x, 'b c h w -> b h w c')
        x = self.norm(x)
        return einops.rearrange(x, 'b h w c -> b c h w')
class DAttention(nn.Module):

    def __init__(
            self, q_size=(224, 224), kv_size=(224, 224), n_heads=8, n_head_channels=32, n_groups=1,
            attn_drop=0.0, proj_drop=0.0, stride=1,
            offset_range_factor=-1, use_pe=True, dwc_pe=True,
            no_off=False, fixed_pe=False, ksize=9, log_cpb=False
    ):

        super().__init__()
        n_head_channels = int(q_size / 8)
        q_size = (q_size, q_size)

        self.dwc_pe = dwc_pe
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels ** -0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        # self.kv_h, self.kv_w = kv_size
        self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.fixed_pe = fixed_pe
        self.no_off = no_off
        self.offset_range_factor = offset_range_factor
        self.ksize = ksize
        self.log_cpb = log_cpb
        self.stride = stride
        kk = self.ksize
        pad_size = kk // 2 if kk != stride else 0

        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )

        if self.no_off:
            for m in self.conv_offset.parameters():
                m.requires_grad_(False)

        self.proj_q = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_k = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0)

        self.proj_v = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_out = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe and not self.no_off:
            if self.dwc_pe:
                self.rpe_table = nn.Conv2d(
                    self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
            elif self.fixed_pe:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
                )
                trunc_normal_(self.rpe_table, std=0.01)
            elif self.log_cpb:
                # Borrowed from Swin-V2
                self.rpe_table = nn.Sequential(
                    nn.Linear(2, 32, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, self.n_group_heads, bias=False)
                )
            else:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
                )
                trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref

    @torch.no_grad()
    def _get_q_grid(self, H, W, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # B * g H W 2

        return ref

    def forward(self, x):
        x = x
        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
        offset = self.conv_offset(q_off).contiguous()  # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor >= 0 and not self.no_off:
            offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill_(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).clamp(-1., +1.)

        if self.no_off:
            x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
            assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
        else:
            x_sampled = F.grid_sample(
                input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
                grid=pos[..., (1, 0)],  # y, x -> x, y
                mode='bilinear', align_corners=True)  # B * g, Cg, Hg, Wg

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)
        # self.proj_k.weight = torch.nn.Parameter(self.proj_k.weight.float())
        # self.proj_k.bias = torch.nn.Parameter(self.proj_k.bias.float())
        # self.proj_v.weight = torch.nn.Parameter(self.proj_v.weight.float())
        # self.proj_v.bias = torch.nn.Parameter(self.proj_v.bias.float())
        # 检查权重的数据类型
        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)

        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # B * h, HW, Ns
        attn = attn.mul(self.scale)

        if self.use_pe and (not self.no_off):

            if self.dwc_pe:
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels,
                                                                              H * W)
            elif self.fixed_pe:
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            elif self.log_cpb:
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (
                        q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
                                                                                               n_sample,
                                                                                               2).unsqueeze(1)).mul(
                    4.0)  # d_y, d_x [-8, +8]
                displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
                attn_bias = self.rpe_table(displacement)  # B * g, H * W, n_sample, h_g
                attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
            else:
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (
                        q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups,
                                                                                               n_sample,
                                                                                               2).unsqueeze(1)).mul(
                    0.5)
                attn_bias = F.grid_sample(
                    input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads,
                                           g=self.n_groups),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True)  # B * g, h_g, HW, Ns

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)

        y = self.proj_drop(self.proj_out(out))
        h, w = pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

        return y

3、在 ultralytics\ultralytics\nn\tasks.py文件中开头引入DAttention。

并在该文件 def parse_model模块中加入DAttention注意力机制代码:

        elif m in {DAttention}:
            c2 = ch[f]
            args = [c2, *args]

4、创建yolov8+DAttention的yaml文件:

(可根据自己的需求选择DAttention注意力机制插入的位置,本文以插入yolov8结构中池化层SPPF后边为例)

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 2  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
  - [-1, 1, DAttention, [[20, 20]]]

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 13

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 19 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 22 (P5/32-large)

  - [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

四、模型验证

可以看出模型中已经包含DAttention注意力机制。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1345459.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

模型 安索夫矩阵

本系列文章 主要是 分享模型,涉及各个领域,重在提升认知。产品市场战略。 1 安索夫矩阵的应用 1.1 江小白的多样化经营策略 使用安索夫矩阵来分析江小白市场战略。具体如下: 根据安索夫矩阵,江小白的现有产品是其白酒产品&…

西北工业大学计算机组成原理实验报告——verilog前两次

说明 为了有较好的可读性,报告仅仅粘贴关键代码。该PDF带有大纲功能,点击大纲中的对应标题,可以快速跳转。 实验目标 掌握单周期CPU执行指令的流程和原理;学习使用verilog HDL语言实现单周期CPU, 并通过功能仿真;提…

leetcode 315. 计算右侧小于当前元素的个数(hard)【小林优质解法】

链接:力扣(LeetCode)官网 - 全球极客挚爱的技术成长平台 代码: class Solution {int[]counts; //用来存储结果int[]index; //用来绑定数据和原下标int[]helpNums; //用于辅助排序 nums 数组int[]helpIndex; //用于辅助排序 i…

Javaweb之Mybatis入门的详细解析

Mybatis入门 前言 在前面我们学习MySQL数据库时,都是利用图形化客户端工具(如:idea、datagrip),来操作数据库的。 在客户端工具中,编写增删改查的SQL语句,发给MySQL数据库管理系统,由数据库管理系统执行S…

分享好用稳定快递查询api接口(对接简单)

提供实时查询和自动识别单号信息。稳定高效,调用简单方便,性价比高,一条链接即可对接成功。 使用数据平台该API接口需要先注册后申请此API接口。申请成功后赠送免费次数,可直接在线请求接口数据。 接口地址:https://…

迭代归并:归并排序非递归实现解析

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《数据结构&算法》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 📋 前言 归并排序的思想上我们已经全部介绍完了,但是同时也面临和快速排序一样的问题那就是递…

『亚马逊云科技产品测评』活动征文|云服务器如何快速搭建个人博客(图文详解)

授权声明:本篇文章授权活动官方亚马逊云科技文章转发、改写权,包括不限于在 Developer Centre, 知乎,自媒体平台,第三方开发者媒体等亚马逊云科技官方渠道 文章目录 引言一、前期准备步骤1.1 准备一个亚马逊 EC2 服务器1.2 进入控…

Stimulsoft BI Designer 2024.1.2 Crack

Stimulsoft BI Designer Do you want to create reports and dashboards, but you do not need to create your application? We want to offer you a great solution – Stimulsoft Designer! All you need is to connect your data, drag it onto the template page, config…

uniapp打包Android、Ios、微信小程序

首先我们需要在我们的代码中,把我们所要用到的配置信息配置好,在检查一下我们测试的内容是否有打开(取消注释),在检查一下我们的版本信息是否正确,查看一下接口ip是否是正式线 这里的配置信息一定要配置好…

二进制、位运算和掩码运算,小白鼠测试示例

1. 二进制 二进制是一种基于两个数字0和1的数制系统。它可以表示两种状态,即开和关。所有输入电脑的任何信息最终都要转化为二进制。目前通用的是ASCII码。最基本的单位为bit。 在计算机科学中,二进制是最常用的数制系统,因为计算机内部的所…

【MySQL】主从异步复制配置

您好,我是码农飞哥(wei158556),感谢您阅读本文,欢迎一键三连哦。 💪🏻 1. Python基础专栏,基础知识一网打尽,9.9元买不了吃亏,买不了上当。 Python从入门到精…

解锁大数据世界的钥匙——Hadoop HDFS安装与使用指南

目录 1、前言 2、Hadoop HDFS简介 3、Hadoop HDFS安装与配置 4、Hadoop HDFS使用 5、结语 1、前言 大数据存储与处理是当今数据科学领域中最重要的任务之一。随着互联网的迅速发展和数据量的爆炸性增长,传统的数据存储和处理方式已经无法满足日益增长的需求。…

LabVIEW在大型风电机组状态监测系统开发中的应用

LabVIEW在大型风电机组状态监测系统开发中的应用 风电作为一种清洁能源,近年来在全球范围内得到了广泛研究和开发。特别是大型风力发电机组,由于其常常位于边远地区如近海、戈壁、草原等,面临着恶劣自然环境和复杂设备运维挑战。为了提高风电…

C语言函数篇——scanf()函数介绍

好的,让我们以输入/输出函数中的scanf()为例,来详细介绍并展示其应用案例。 scanf()函数介绍: scanf()函数是C语言中用于从标准输入(通常是键盘)读取数据的函数。它可以从用户处获取各种类型的数据,并将其…

11|代理(上):ReAct框架,推理与行动的协同

11|代理(上):ReAct框架,推理与行动的协同 在之前介绍的思维链(CoT)中,我向你展示了 LLMs 执行推理轨迹的能力。在给出答案之前,大模型通过中间推理步骤(尤其…

QT上位机开发(倒计时软件)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 倒计时软件是生活中经常遇到的一种场景。比如运动跑步,比如学校考试,比如论文答辩等等,只要有时间限制规定的地…

电影《海王2》观后感

上周看了电影《海王2》,整体特效和打斗还是非常不错的,自己在写文章的时候,看完电影已经一周了,相当于是叙事自我在描述这段经历。 (1)体验自我VS叙事自我 首先简单说明下“体验自我”和“叙事自我”&…

Redis经典五大类型源码及底层实现(二)

👏作者简介:大家好,我是爱吃芝士的土豆倪,24届校招生Java选手,很高兴认识大家📕系列专栏:Spring源码、JUC源码、Kafka原理、分布式技术原理、数据库技术🔥如果感觉博主的文章还不错的…

Cuk、Zeta和Sepic开关电源拓扑结构

Cuk、Zeta和Sepic变换器,三种拓扑结构大致类似。不同点在于电感和二极管,MOS管的位置关系的变化。 Cuk电源是一种非隔离的直流电源转换器,其基本结构包括输入滤波电容、开关管、输入电感、输出电感和输出电容等元件。Cuk电路可以看作是Boost和Buck电路的…

【CFP-专栏2】计算机类SCI优质期刊汇总(含IEEE/Top)

一、计算机区块链类SCI-IEEE 【期刊概况】IF:4.0-5.0, JCR2区,中科院2区; 【大类学科】计算机科学; 【检索情况】SCI在检; 【录用周期】3-5个月左右录用; 【截稿时间】12.31截稿; 【接收领域】区块链…