Yolov8有效涨点,添加多种注意力机制,修改损失函数提高目标检测准确率

news2024/11/26 22:50:49

目录

简介

CBAM注意力机制原理及代码实现

原理

 代码实现

 GAM注意力机制

原理

代码实现

修改损失函数

YAML文件

完整代码


🚀🚀🚀订阅专栏,更新及时查看不迷路🚀🚀🚀

http://t.csdnimg.cn/sVHxvicon-default.png?t=N7T8http://t.csdnimg.cn/sVHxv

简介

Ultralytics 推出了最新版本的 YOLO 模型。注意力机制是提高模型性能最热门的方法之一。

本次将介绍几种常见的注意力机制,这些注意力机制在大多数的数据集上均能有效的提升目标检测的精度/召回率/准确率。

CBAM注意力机制原理及代码实现
原理
CBAM注意力机制结构图

CBAM(Convolutional Block Attention Module)是一种用于卷积神经网络(CNN)的注意力机制,它能够增强网络对输入特征的关注度,提高网络性能。CBAM 主要包含两个子模块:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。

以下是CBAM注意力机制的基本原理:

1. 通道注意力模块(Channel Attention Module):
输入:经过卷积层的特征图。
处理步骤:
对每个通道进行全局平均池化,得到通道的全局平均值。
通过两个全连接层,将全局平均值映射为两个权重向量(一个用于缩放,一个用于偏置)。
将这两个权重向量与原始特征图相乘,以加权调整每个通道的重要性。

2. 空间注意力模块(Spatial Attention Module):**
输入:通道注意力模块的输出。
处理步骤:
     对每个通道的特征图进行分别的最大池化和平均池化,得到两个空间特征图。
     将这两个空间特征图相加,通过一个卷积层产生一个权重图。
     将原始特征图与权重图相乘,以加权调整每个空间位置的重要性。

3. 整合:
   将通道注意力模块和空间注意力模块的输出相乘,得到最终的注意力增强特征图。
   将这个注意力增强的特征图传递给网络的下一层进行进一步处理。

CBAM的关键优势在于它能够同时考虑通道和空间信息,有助于网络更好地理解和利用输入特征。这种注意力机制有助于提高网络在视觉任务上的性能,使其能够更有针对性地关注重要的特征。

 代码实现

路径:"./ultralytics/nn/modules/conv.py"

class ChannelAttention(nn.Module):
    """Channel-attention module https://github.com/open-mmlab/mmdetection/tree/v3.0.0rc1/configs/rtmdet."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, 1, 1, 0, bias=True)
        self.act = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.act(self.fc(self.pool(x)))


class SpatialAttention(nn.Module):
    """Spatial-attention module."""

    def __init__(self, kernel_size=7):
        """Initialize Spatial-attention module with kernel size argument."""
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.cv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.act = nn.Sigmoid()

    def forward(self, x):
        """Apply channel and spatial attention on input for feature recalibration."""
        return x * self.act(self.cv1(torch.cat([torch.mean(x, 1, keepdim=True), torch.max(x, 1, keepdim=True)[0]], 1)))


class CBAM(nn.Module):
    """Convolutional Block Attention Module."""

    def __init__(self, c1, kernel_size=7):  # ch_in, kernels
        super().__init__()
        self.channel_attention = ChannelAttention(c1)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        """Applies the forward pass through C1 module."""
        return self.spatial_attention(self.channel_attention(x))

添加完代码以后需要在"./ultralytics/nn/tasks.py"进行注册

 GAM注意力机制
原理

目标的设计是一种减少信息缩减并放大全局维度交互特征的机制。我们采用 CBAM 的顺序通道空间注意力机制并重新设计子模块。整个过程如图所示。

GAM结构图


通道注意力机制
通道注意力子模块使用 3D 排列来保留三个维度的信息。然后,它使用两层 MLP(多层感知器)放大跨维度通道空间依赖性。 (MLP是一种编码器-解码器结构,其缩减比为r,与BAM相同。)通道注意子模块如图所示。 

通道注意力子模块


空间注意力机制
在空间注意力子模块中,为了关注空间信息,我们使用两个卷积层进行空间信息融合。我们还使用与 BAM 相同的通道注意子模块的缩减率 r。同时,最大池化会减少信息并产生负面影响。我们删除池化以进一步保留特征图。因此,空间注意力模块有时会显着增加参数的数量。为了防止参数显着增加,我们在 ResNet50 中采用带有通道洗牌的组卷积。没有组卷积的空间注意力子模块如图所示。 

空间注意力子模块
代码实现

代码添加在 ./ultralytics/nn/modules/conv.py 中,同样需要在task.py中注册

class GAM_Attention(nn.Module):
    def __init__(self, c1, c2, group=True, rate=4):
        super(GAM_Attention, self).__init__()
 
        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1)
        )
 
        self.spatial_attention = nn.Sequential(
 
            nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(c1, int(c1 / rate),
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate) if group else nn.Conv2d(int(c1 / rate), c2,
                                                                                                     kernel_size=7,
                                                                                                     padding=3),
            nn.BatchNorm2d(c2)
        )
 
    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        # x_channel_att=channel_shuffle(x_channel_att,4) #last shuffle
        x = x * x_channel_att
 
        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        out = x * x_spatial_att
        # out=channel_shuffle(out,4) #last shuffle
        return out
修改损失函数

WIoU是一种新型的损失函数,代码实现

def WIoU(cls, pred, target, self=None):
        self = self if self else cls(pred, target)
        dist = torch.exp(self.l2_center / self.l2_box.detach())
        return self._scaled_loss(dist * self.iou)

 这个其实就是修改了loss.py中的BboxLoss,在本段代码的第十二行,将type改成了WIoU

class BboxLoss(nn.Module):
 
    def __init__(self, reg_max, use_dfl=False):
        """Initialize the BboxLoss module with regularization maximum and DFL settings."""
        super().__init__()
        self.reg_max = reg_max
        self.use_dfl = use_dfl
 
    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
        """IoU loss."""
        weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)
        loss,iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False,type_='WIoU')
        loss_iou=loss.sum()/target_scores_sum
 
        # DFL loss
        if self.use_dfl:
            target_ltrb = bbox2dist(anchor_points, target_bboxes, self.reg_max)
            loss_dfl = self._df_loss(pred_dist[fg_mask].view(-1, self.reg_max + 1), target_ltrb[fg_mask]) * weight
            loss_dfl = loss_dfl.sum() / target_scores_sum
        else:
            loss_dfl = torch.tensor(0.0).to(pred_dist.device)
 
        return loss_iou, loss_dfl
YAML文件
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 9  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
  - [-1, 1, GAM_Attention, [512,512]]
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 16 (P3/8-small)
  - [-1, 1, GAM_Attention, [256,256]]
 
  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 20 (P4/16-medium)
  - [-1, 1, GAM_Attention, [512,512]]
 
  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 24 (P5/32-large)
  - [-1, 1, GAM_Attention, [1024,1024]]
 
  - [[17, 21, 25], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在head部分,可以将GAM_attention改成不同的注意力机制,来改变网络结构,从而提升目标检测 的精度

完整代码

链接: https://pan.baidu.com/s/1IDnEZxpcaEgBowlTxX2iNA?pwd=vdrs 提取码: vdrs 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1492859.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Claude 3 Sonnet 模型现已在亚马逊云科技的 Amazon Bedrock 正式可用!

今天,我们宣布一个激动人心的里程碑:Anthropic 的 Claude 3 Sonnet 模型现已在亚马逊云科技的 Amazon Bedrock 正式可用。 下一代 Claude (Claude 3) 的三个模型 Claude 3 Opus、Claude 3 Sonnet 和 Claude 3 Haiku 将陆续登陆 Amazon Bedrock。Amazon …

AIGC——Layer Diffusion使用潜在透明度的透明图像层扩散

前言 ControlNet的作者Lvmin Zhang大佬在新的一年又发布了新的工作LayerDiffusion,这个工作再次让人眼前一亮,和ControlNet一样,LayerDiffusion也是解决文生图中比较实际的问题,那就是生成透明的4通道RGBA图像,而且效…

在Mac上安装nginx+rtmp 本地服务器

需要使用终端命令,如果没有Homebrew,要安装Homebrew,执行: ruby -e "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install)" 一、安装Nginx 1、先clone Nginx项目到本地: brew tap de…

目标检测5:采用yolov8, RK3568上推理实时视频流

上一个效果图,海康球机对着电脑屏幕拍,清晰度不好。 RK3568接取RTSP视频流,通过解码,推理,编码,最终并把结果推出RTSP视频流。 数据集采用coco的80个种类集,通过从yovo8.pt,转换成R…

在 Flutter 中使用 flutter_gen 简化图像资产管理

你是否厌倦了在 Flutter 项目中手动管理图像资产的繁琐任务? 告别手工输入资源路径的痛苦,欢迎使用“Flutter Gen”高效资源管理的时代。在本文中,我将带您从手动处理图像资源的挫折到动态生成它们的便利。 选择1:痛苦手动添加–…

网络编程(3/4)

广播 ​ #include<myhead.h>int main(int argc, const char *argv[]) {//1、创建套接字int sfd socket(AF_INET, SOCK_DGRAM, 0);if(sfd -1){perror("socket error");return -1;}//2、将套接字设置成允许广播int broadcast 1;if(setsockopt(sfd, SOL_SOC…

基于FastAPI构造一个AI模型部署应用

前言 fastapi是目前一个比较流行的python web框架&#xff0c;在大模型日益流行的今天&#xff0c;其云端部署和应用大多数都是基于fastapi框架。所以掌握和理解fastapi框架基本代码和用法尤显重要。 需要注意的是&#xff0c;fastapi主要是通过app对象提供了web服务端的实现代…

SINAMICS V90 PN 指导手册 第7章 回参考点功能

如果伺服是增量式编码器&#xff0c;共有三种回参考点模式&#xff0c;分别是 通过数字量输入信号REF设置回参考点通过外部参考挡块和编码器零脉冲回参考点仅通过编码器零脉冲回参考点 如果伺服是绝对值编码器&#xff0c;除了这三种以外&#xff0c;还可以通过“ABS”调整绝…

利用 Redis 和 Lua 实现高效的限流功能

简介 在现代系统中&#xff0c;限流是一种重要的机制&#xff0c;用于控制服务端的流量并保护系统免受恶意攻击或请求泛滥的影响。本文将介绍如何利用 Redis 和 Lua 结合实现高效的限流功能。 一、什么是限流 限流指的是对系统中的请求进行控制和调节&#xff0c;确保系统在…

Nginx 可视化管理软件 Nginx Proxy Manager

一、简介 Nginx Proxy Manager 是一款开源的 Nginx 可视化管理界面&#xff0c;基于 Nginx 具有漂亮干净的 Web UI 界面。他允许用户通过浏览器界面轻松地管理和监控 Nginx 服务器&#xff0c;可以获得受信任的 SSL 证书&#xff0c;并通过单独的配置、自定义和入侵保护来管理…

ARM单片机中程序在ROM空间和RAM空间的分布(分散加载文件,Scatter-Loading Description File)

对于 K e i l u V i s i o n I D E Keil\quad uVision\quad IDE KeiluVisionIDE&#xff0c;程序编译好之后&#xff0c;代码的下载位置&#xff08; R O M ROM ROM空间&#xff09;以及代码运行的时候使用的 R A M RAM RAM空间&#xff08; R A M RAM RAM空间&#xff09;默认…

第二证券|重大转变!全球资金正重回中国股市!

外资巨头最新发声。 摩根士丹利在最新发布的陈述中称&#xff0c;全球资金正在重返我国股市。跟着部分基金对我国商场的看跌心情有所平缓&#xff0c;全球长时间出资者撤出我国股票商场&#xff08;A股和港股&#xff09;的举动现已按下暂停键。 按下暂停键或许是一个前期痕迹…

leetcode刷题(javaScript)——二叉树、平衡二叉树相关场景题总结

二叉树的知识点很多&#xff0c;在算法刷题中需要有想象力的数据结构了。主要是用链表存储&#xff0c;没有数组更容易理解。在刷二叉树相关算法时&#xff0c;需要注意以下几点&#xff1a; 掌握二叉树的基本概念&#xff1a;了解二叉树的基本概念&#xff0c;包括二叉树的定义…

pytorch什么是梯度

目录 1.导数、偏微分、梯度1.1 导数1.2 偏微分1.3 梯度 2. 通过梯度求极小值3. learning rate 1.导数、偏微分、梯度 1.1 导数 对于yx 2 2 2 的导数&#xff0c;描述了y随x值变化的一个变化趋势&#xff0c;导数是个标量反应的是变化的程度&#xff0c;标量的长度反应变化率的…

next/future/image图片根据外部容器100%填充

官网文档地址&#xff1a; https://nextjs.org/docs/pages/api-reference/components/image 主要需要使用属性fill。外部元素需要设置好position:relative <Imagexsrc"/images/1.jpg"fillsizes100vw />

浅谈块存储、文件存储、对象存储

**块存储、文件存储和对象存储各自有其独特的特点和适用场景**。具体来说&#xff1a; 1. **块存储**&#xff1a; - 描述&#xff1a;块存储将存储空间分割成固定大小的块&#xff0c;这些块可以直接映射到主机操作系统。它提供的是原始的存储空间&#xff0c;不带文件系统…

hive实战项目:旅游集市数仓建设

旅游集市数仓建设 文章目录 旅游集市数仓建设为什么要设计数据分层&#xff1f;分层设计ODS&#xff08;Operational Data Store&#xff09;&#xff1a;数据运营层DW&#xff08;Data Warehouse&#xff09;&#xff1a;数据仓库层DWD&#xff08;Data Warehouse Detail&…

软考63-上午题-【面向对象技术】-面向对象的基本概念2

一、动态绑定、静态绑定 1-1、动态绑定、静态绑定的定义 编译时进行的绑定 静态绑定 运行时进行的绑定 动态绑定 动态绑定&#xff0c;和类的继承和多态想关联的。 在运行过程中&#xff0c;当一个对象发送消息请求服务时&#xff0c;要根据接受对象的具体情况将请求的操作…

gitlab的安装

1、下载rpm 安装包 (1)直接命令下载 wget https://mirrors.tuna.tsinghua.edu.cn/gitlab-ce/yum/el7/gitlab-ce-11.6.10-ce.0.el7.x86_64.rpm&#xff08;2&#xff09;直接去服务器上下载包 Index of /gitlab-ce/yum/el7/ | 清华大学开源软件镜像站 | Tsinghua Open Source…

html标签之列表标签,含爱奇艺,小米,腾讯,阿里

什么是css块元素&#xff1f; 块级元素是独占一行显示的。它的兄弟元素必定不会与其在同一行中&#xff08;除非脱离了文档流&#xff09;。通俗点来说&#xff0c;就是块元素(block element)一般是其他元素的容器元素&#xff0c;能容纳其他块元素或内联元素。 css块元素的三…