YOLOv5改进 | 注意力篇 | 利用YOLO-Face提出的SEAM注意力机制优化物体遮挡检测(附代码 + 修改教程)

news2024/11/25 22:38:43

一、本文介绍

本文给大家带来的改进机制是由YOLO-Face提出能够改善物体遮挡检测的注意力机制SEAM,SEAM(Spatially Enhanced Attention Module)注意力网络模块旨在补偿被遮挡面部的响应损失,通过增强未遮挡面部的响应来实现这一目标,其希望通过学习遮挡面和未遮挡面之间的关系来改善遮挡情况下的损失从而达到改善物体遮挡检测的效果,本文将通过介绍其主要原理后,提供该机制的代码和修改教程,并附上运行的yaml文件和运行代码,小白也可轻松上手。。

欢迎大家订阅我的专栏一起学习YOLO! 

  专栏回顾:YOLOv5改进专栏——持续复现各种顶会内容——内含100+创新

目录

一、本文介绍

二、原理介绍

2.1 遮挡改进

2.2 SEAM模块

2.3 排斥损失 

三、核心代码

四、添加教程

 4.1 修改一

4.2 修改二 

4.3 修改三 

4.4 修改四 

五、SEAM的yaml文件和运行记录

5.1 SEAM的yaml文件

5.2 MultiSEAM的yaml文件

5.3 训练过程截图 

五、本文总结


二、原理介绍

2.1 遮挡改进

本文重点介绍遮挡改进,其主要体现在两个方面:注意力网络模块(SEAM)排斥损失(Repulsion Loss)

1. SEAM模块:SEAM(Spatially Enhanced Attention Module)注意力网络模块旨在补偿被遮挡面部的响应损失,通过增强未遮挡面部的响应来实现这一目标。SEAM模块通过深度可分离卷积和残差连接的组合来实现,其中深度可分离卷积按通道进行操作,虽然可以学习不同通道的重要性并减少参数量,但忽略了通道间的信息关系。为了弥补这一损失,不同深度卷积的输出通过点对点(1x1)卷积组合。然后使用两层全连接网络融合每个通道的信息,以增强所有通道之间的联系。这种模型希望通过学习遮挡面和未遮挡面之间的关系,来弥补遮挡情况下的损失。

2. 排斥损失(Repulsion Loss):一种设计来处理面部遮挡问题的损失函数。具体来说,排斥损失被分为两部分:RepGT和RepBox。RepGT的功能是使当前的边界框尽可能远离周围的真实边界框,而RepBox的目的是使预测框尽可能远离周围的预测框,从而减少它们之间的IOU,以避免某个预测框被NMS抑制,从而属于两个面部。


2.2 SEAM模块

下图展示了SEAM(Separated and Enhancement Attention Module)的架构以及CSMM(Channel and Spatial Mixing Module)的结构

左侧是SEAM的整体架构,包括三个不同尺寸(patch-6、patch-7、patch-8)的CSMM模块。这些模块的输出进行平均池化,然后通过通道扩展(Channel exp)操作,最后相乘以提供增强的特征表示。右侧是CSMM模块的详细结构,它通过不同尺寸的patch来利用多尺度特征,并使用深度可分离卷积来学习空间维度和通道之间的相关性。模块包括了以下元素:

(a)Patch Embedding:对输入的patch进行嵌入。
(b)GELU:Gaussian Error Linear Unit,一种激活函数。
(c)BatchNorm:批量归一化,用于加速训练过程并提高性能。
(d)Depthwise Convolution:深度可分离卷积,对每个输入通道分别进行卷积操作。
(f)Pointwise Convolution:逐点卷积,其使用1x1的卷积核来融合深度可分离卷积的特征。

这种模块设计旨在通过对空间维度和通道的细致处理,从而增强网络对遮挡面部特征的注意力和捕捉能力。通过综合利用多尺度特征和深度可分离卷积,CSMM在保持计算效率的同时,提高了特征提取的精确度。这对于面部检测尤其重要,因为面部特征的大小、形状和遮挡程度可以在不同情况下大相径庭。通过SEAM和CSMM,YOLO-FaceV2提高了模型对复杂场景中各种面部特征的识别能力。


2.3 排斥损失 

排斥损失(Repulsion Loss)是一种用于处理面部检测中遮挡问题的损失函数。在面部检测中,类内遮挡可能会导致一个面部包含另一个面部的特征,从而增加错误检测率。排斥损失能够有效地通过排斥效应来缓解这一问题。排斥损失被分为两个部分:RepGTRepBox

(a)RepGT损失:其功能是使当前边界框尽可能远离周围的真实边界框。这里的“周围真实边界框”指的是与除了要预测的边界框外的面部标签具有最大IoU的那个边界框。RepGT损失的计算方法如下:

L_{\text{RepGT}} = \sum_{P \in P^+} \text{SmoothLn}(\text{IoG}(P, G_{\text{Rep}}))

其中,P代表面部预测框,G_{\text{Rep}}是周围具有最大IoU的真实边界框。这里的IoG(Intersection over Ground truth)定义为\frac{\text{area}(P \cap G)}{\text{area}(G)},且其值范围在0到1之间。SmoothLn是一个连续可导的对数函数,\sigma是一个在[0,1)范围内的平滑参数,用于调整排斥损失对异常值的敏感度。

(b)RepBox损失:其目的是使预测框尽可能远离周围的预测框,从而减少它们之间的IOU,以避免一个预测框因NMS(非最大抑制)而被压制,并归属于两个面部。预测框被分成多个组,不同组之间的预测框对应不同的面部标签。对于不同组之间的预测框p_ip_j,希望它们之间的重叠面积尽可能小。RepBox也使用SmoothLn作为优化函数。

L_{\text{RepBox}} = \sum_{i \neq j} \text{SmoothLn}(\text{IoU}(B_{p_i}, B_{p_j}))

排斥损失通过使边界框之间保持距离,减少预测框之间的重叠,从而提高面部检测在遮挡情况下的准确性。


三、核心代码

代码的使用方式看章节四!

import torch
import torch.nn as nn

__all__ = ['SEAM', 'MultiSEAM']

class Residual(nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

class SEAM(nn.Module):
    def __init__(self, c1, c2, n, reduction=16):
        super(SEAM, self).__init__()
        if c1 != c2:
            c2 = c1
        self.DCovN = nn.Sequential(
            # nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1, groups=c1),
            # nn.GELU(),
            # nn.BatchNorm2d(c2),
            *[nn.Sequential(
                Residual(nn.Sequential(
                    nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=3, stride=1, padding=1, groups=c2),
                    nn.GELU(),
                    nn.BatchNorm2d(c2)
                )),
                nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=1, stride=1, padding=0, groups=1),
                nn.GELU(),
                nn.BatchNorm2d(c2)
            ) for i in range(n)]
        )
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(c2, c2 // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c2 // reduction, c2, bias=False),
            nn.Sigmoid()
        )

        self._initialize_weights()
        # self.initialize_layer(self.avg_pool)
        self.initialize_layer(self.fc)


    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.DCovN(x)
        y = self.avg_pool(y).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        y = torch.exp(y)
        return x * y.expand_as(x)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=1)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def initialize_layer(self, layer):
        if isinstance(layer, (nn.Conv2d, nn.Linear)):
            torch.nn.init.normal_(layer.weight, mean=0., std=0.001)
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0)


def DcovN(c1, c2, depth, kernel_size=3, patch_size=3):
    dcovn = nn.Sequential(
        nn.Conv2d(c1, c2, kernel_size=patch_size, stride=patch_size),
        nn.SiLU(),
        nn.BatchNorm2d(c2),
        *[nn.Sequential(
            Residual(nn.Sequential(
                nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=kernel_size, stride=1, padding=1, groups=c2),
                nn.SiLU(),
                nn.BatchNorm2d(c2)
            )),
            nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=1, stride=1, padding=0, groups=1),
            nn.SiLU(),
            nn.BatchNorm2d(c2)
        ) for i in range(depth)]
    )
    return dcovn

class MultiSEAM(nn.Module):
    def __init__(self, c1, c2, depth, kernel_size=3, patch_size=[3, 5, 7], reduction=16):
        super(MultiSEAM, self).__init__()
        if c1 != c2:
            c2 = c1
        self.DCovN0 = DcovN(c1, c2, depth, kernel_size=kernel_size, patch_size=patch_size[0])
        self.DCovN1 = DcovN(c1, c2, depth, kernel_size=kernel_size, patch_size=patch_size[1])
        self.DCovN2 = DcovN(c1, c2, depth, kernel_size=kernel_size, patch_size=patch_size[2])
        self.avg_pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(c2, c2 // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(c2 // reduction, c2, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y0 = self.DCovN0(x)
        y1 = self.DCovN1(x)
        y2 = self.DCovN2(x)
        y0 = self.avg_pool(y0).view(b, c)
        y1 = self.avg_pool(y1).view(b, c)
        y2 = self.avg_pool(y2).view(b, c)
        y4 = self.avg_pool(x).view(b, c)
        y = (y0 + y1 + y2 + y4) / 4
        y = self.fc(y).view(b, c, 1, 1)
        y = torch.exp(y)
        return x * y.expand_as(x)


四、添加教程

 4.1 修改一

第一还是建立文件,我们找到如下yolov5-master/models文件夹下建立一个目录名字呢就是'modules'文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。


4.2 修改二 

第二步我们在该目录下创建一个新的py文件名字为'__init__.py'(用群内的文件的话已经有了无需新建),然后在其内部导入我们的检测头如下图所示。


4.3 修改三 

第三步我门中到如下文件'yolov5-master/models/yolo.py'进行导入和注册我们的模块(用群内的文件的话已经有了无需重新导入直接开始第四步即可)

从今天开始以后的教程就都统一成这个样子了,因为我默认大家用了我群内的文件来进行修改!!


4.4 修改四 

按照我的添加在parse_model里添加即可。

到此就修改完成了,大家可以复制下面的yaml文件运行。


五、SEAM的yaml文件和运行记录

5.1 SEAM的yaml文件

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32


# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 20 (P4/16-medium)

    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
    [-1, 1, SEAM, [1024, 1]],

    [[17, 20, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]

5.2 MultiSEAM的yaml文件

# YOLOv5 🚀 by Ultralytics, AGPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32


# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
    [-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
    [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
    [-1, 3, C3, [128]],
    [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
    [-1, 6, C3, [256]],
    [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
    [-1, 9, C3, [512]],
    [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
    [-1, 3, C3, [1024]],
    [-1, 1, SPPF, [1024, 5]], # 9
  ]

# YOLOv5 v6.0 head
head: [
    [-1, 1, Conv, [512, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 6], 1, Concat, [1]], # cat backbone P4
    [-1, 3, C3, [512, False]], # 13

    [-1, 1, Conv, [256, 1, 1]],
    [-1, 1, nn.Upsample, [None, 2, "nearest"]],
    [[-1, 4], 1, Concat, [1]], # cat backbone P3
    [-1, 3, C3, [256, False]], # 17 (P3/8-small)
    [-1, 1, MultiSEAM, [256, 1]],

    [-1, 1, Conv, [256, 3, 2]],
    [[-1, 14], 1, Concat, [1]], # cat head P4
    [-1, 3, C3, [512, False]], # 21 (P4/16-medium)
    [-1, 1, MultiSEAM, [512, 1]], # 22


    [-1, 1, Conv, [512, 3, 2]],
    [[-1, 10], 1, Concat, [1]], # cat head P5
    [-1, 3, C3, [1024, False]], # 25 (P5/32-large)
    [-1, 1, MultiSEAM, [1024, 1]], # 26

    [[18, 22, 26], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  ]


5.3 训练过程截图 


五、本文总结

到此本文的正式分享内容就结束了,在这里给大家推荐我的YOLOv8改进有效涨点专栏,本专栏目前为新开的平均质量分98分,后期我会根据各种最新的前沿顶会进行论文复现,也会对一些老的改进机制进行补充,如果大家觉得本文帮助到你了,订阅本专栏,关注后续更多的更新~

  专栏回顾:YOLOv5改进专栏——持续复现各种顶会内容——内含100+创新

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1511875.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

el-Switch 开关二次确认

前言 最近在做毕设,有个需求是点击按钮控制用户的状态是否禁用,就看到element有个switch组件可以改造一下,就上网看了一下,结果为了这个效果忙活了很久。。。所以说记录一下,让大家少踩坑。 前置条件 先看完我的需求再…

Eclipse安装springboot

Eclipse免费,套件丰富,代码开源,功能强大…推荐! 1 下载eclipse: https://www.eclipse.org/downloads/download.php?file/technology/epp/downloads/release/2023-12/R/eclipse-jee-2023-12-R-win32-x86_64.zip 2 安装Spring框…

图论(蓝桥杯 C++ 题目 代码 注解)

目录 迪杰斯特拉模板(用来求一个点出发到其它点的最短距离): 克鲁斯卡尔模板(用来求最小生成树): 题目一(蓝桥王国): 题目二(随机数据下的最短路径&#…

Hadoop学习2:完全分布集群搭建

文章目录 Fully-Distributed Operation(完全分布模式) 重点机器环境同步集群规划配置文件修改以及同步步骤0:下面其他步骤涉及修改配置以这里为准(要不然部署使用过程会有很多问题)通用配置(三台节点机器&a…

社交革命的引领者:探索Facebook如何改变我们的生活方式

1.数字社交的兴起 随着互联网的普及,社交媒体成为我们日常生活的重要组成部分。Facebook作为其中的先驱,从最初的社交网络演变成了一个拥有数十亿用户的全球化平台。它不仅改变了我们与世界互动的方式,还深刻影响了我们的社交习惯、人际关系以…

华为机考:HJ102 字符统计

华为机考&#xff1a;HJ102 字符统计 描述 方法1 先将所有字符计算数量&#xff0c;在对比其中字符的assic码 #include<iostream> #include<vector> #include<algorithm> #include<string> using namespace std; bool cmp(pair<char, int> a,…

C++ std::list的merge()使用与分析

看到《C标准库第2版》对list::merge()的相关介绍&#xff0c;令我有点迷糊&#xff0c;特意敲代码验了一下不同情况的调用结果。 《C标准库第2版》对list::merge()的相关介绍 list::merge()定义 merge()的作用就是将两个list合并在一起&#xff0c;函数有2个版本&#xff1a;…

数字图像处理-空间滤波

空间滤波 空域滤波基础 – 离散卷积的边缘效应 平滑空间滤波器 # -*- coding: utf-8 -*- # Author: Huazhong Yang # Email: cjdxyhz163.com # Time : 2024/3/7 20:26import cv2 import numpy as np# 读取图像 image cv2.imread(a1.png)# 应用高斯滤波 # 第二个参数是高斯…

微信小程序开发系列(三十)·小程序本地存储API·同步和异步的区别

目录 1. 同步API 1.1 getStorageSync存储API 1.2 removeStorageSync获取数据API 1.3 removeStorageSync删除 1.4 clearStorageSync清空 2. 异步API 2.1 setStorage存储API 2.2 getStorage获取数据API 2.3 removeStorage删除API 2.4 clearStorage清空 3. …

qt vs 编程 字符编码 程序从源码到编译到显示过程中存在的字符编码

理解字符编码&#xff0c;请参考&#xff1a;unicode ucs2 utf16 utf8 ansi GBK GB2312 CSDN博客 汉字&#xff08;或者说多字节字符&#xff09;的存放需求&#xff0c;是计算机中各种编码问题的最直接原因。如果程序不直接使用汉字&#xff0c;或间接在所有操作步骤中统一使…

Hilt

1.使用Hilt实现快速依赖注入 1.1 导入依赖 //hilt依赖//Hiltimplementation("com.google.dagger:hilt-android:2.44")annotationProcessor("com.google.dagger:hilt-android-compiler:2.44")1.2 在build.gradle(app)中加入插件 plugins {id("com.an…

大规模自动化重构框架--OpenRewrite浅析

目录 1. OpenRewrite是什么&#xff1f;定位&#xff1f; 2. OpenWrite具体如何做&#xff1f; 3. 核心概念释义 3.1 Lossless Semantic Trees (LST) 无损语义树 3.2 访问器&#xff08;Visitors&#xff09; 3.3 配方&#xff08;Recipes&#xff09; 4. 参考链接 Open…

SpringBlade error/list SQL 注入漏洞复现

0x01 产品简介 SpringBlade 是一个由商业级项目升级优化而来的 SpringCloud 分布式微服务架构、SpringBoot 单体式微服务架构并存的综合型项目。 0x02 漏洞概述 SpringBlade 框架后台 /api/blade-log/error/list路径存在SQL注入漏洞,攻击者除了可以利用 SQL 注入漏洞获取数…

Redis应用缓存

目录 前言 关于“二八定律” 使用Redis作为缓存 为什么关系型数据库性能不高 为什么并发量高了就容易宕机 Redis就是一个用来作为数据库缓存的常见方案 缓存更新策略 定期生成 搜索引擎为例 实时生成 淘汰策略 FIFO(First In First Out) 先进先出 lRU(Least …

106. Dockerfile通过多阶段构建减小Golang镜像的大小

我们如何通过引入具有多阶段构建过程的Dockerfiles来减小Golang镜像的大小&#xff1f; 让我们从一个通用的Dockerfile开始&#xff0c;它负责处理基本的事务&#xff0c;如依赖项、构建二进制文件、声明暴露的端口等&#xff0c;以便为Go中的一个非常基础的REST API提供服务。…

YoloV8实战:YoloV8-World应用实战案例

摘要 YOLO-World模型确实是一个突破性的创新&#xff0c;它结合了YOLOv8框架的实时性能与开放式词汇检测的能力&#xff0c;为众多视觉应用提供了前所未有的解决方案。以下是对YOLO-World模型的进一步解读&#xff1a; 模型架构与功能 YOLO-World模型充分利用了YOLOv8框架的…

剑指offer面试题34:在二叉树中和为某一值的路径

面试题34&#xff1a;在二叉树中和为某一值的路径 题目&#xff1a; LCR 153. 二叉树中和为目标值的路径 - 力扣&#xff08;LeetCode&#xff09; 给你二叉树的根节点 root 和一个整数目标和 targetSum &#xff0c;找出所有 从根节点到叶子节点 路径总和等于给定目标和的路…

C语言 - 各种自定义数据类型

1.结构体 把不同类型的数据组合成一个整体 所占内存长度是各成员所占内存的总和 typedef struct XXX { int a; char b; }txxx; txxx data; typedef struct XXX { int a:1; int b:1; …

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:RichText)

富文本组件&#xff0c;解析并显示HTML格式文本。 说明&#xff1a; 该组件从API Version 8开始支持。后续版本如有新增内容&#xff0c;则采用上角标单独标记该内容的起始版本。该组件无法根据内容自适应设置宽高属性&#xff0c;需要开发者设置显示布局。 子组件 不包含子组…

封装的echarts子组件使用watch监听option失效的问题

项目场景&#xff1a; 我在项目里面封装了一个echarts组件&#xff0c;组件接收一个来自外部的option,然后我用了一个watch函数去监听这个option的变化&#xff0c;option变化之后&#xff0c;销毁&#xff0c;然后再新建一个charts表 碎碎念 问题如标题所示&#xff0c;这篇…