YOLOv5修改注意力机制CBAM

news2024/12/24 8:57:27

直接上干货

CBAM注意力机制是由通道注意力机制(channel)和空间注意力机制(spatial)组成。

传统基于卷积神经网络的注意力机制更多的是关注对通道域的分析,局限于考虑特征图通道之间的作用关系。CBAM从 channel 和 spatial 两个作用域出发,引入空间注意力和通道注意力两个分析维度,实现从通道到空间的顺序注意力结构。空间注意力可使神经网络更加关注图像中对分类起决定作用的像素区域而忽略无关紧要的区域,通道注意力则用于处理特征图通道的分配关系,同时对两个维度进行注意力分配增强了注意力机制对模型性能的提升效果。
 

CBAM中的通道注意力机制模块流程图如下。先将输入特征图分别进行全局最大池化和全局平均池化,对特征映射基于两个维度压缩,获得两张不同维度的特征描述。池化后的特征图共用一个多层感知器网络,先通过一个全连接层下降通道数,再通过另一个全连接恢复通道数。将两张特征图在通道维度堆叠,经过 sigmoid 激活函数将特征图的每个通道的权重归一化到0-1之间。将归一化后的权重和输入特征图相乘。

yaml 配置文件如下

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 6  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, CBAM, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 15], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)

   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

common加入以下代码

# CBAM
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        # (特征图的大小-算子的size+2*padding)/步长+1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        # 2*h*w
        x = self.conv(x)
        # 1*h*w
        return self.sigmoid(x)


class CBAM(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

YOLO 的

parse_model 注册

到此完成

后续会给大家讲解YOLOv8怎么修改

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/861606.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Opencv将数据保存到xml、yaml / 从xml、yaml读取数据

Opencv将数据保存到xml、yaml / 从xml、yaml读取数据 Opencv提供了读写xml、yaml的类实现: 本文重点参考:https://blog.csdn.net/cd_yourheart/article/details/122705776?spm1001.2014.3001.5506,并将给出文件读写的具体使用实例。 1. 官…

Android多渠道打包+自动签名工具 [原创]

多渠道打包自动签名工具 [原创] github源码:github.com/G452/apk-packer 如果觉得有帮助可以点个小星星支持一下,万分感谢! 使用步骤: 1、在apk-packer.exe目录内放入打包需要的配置: 1)签名文件.jks2&am…

undefined reference to `dlopen‘ ‘SSL_library_init‘ `X509_certificate_type‘

使用Crow的时候需要注意crow依赖asio依赖OpenSSL,asio要求1.22以上版本,我使用的是1.26.0; 这个版本的asio要求OpenSSL是1.0.2,其他版本我得机器上编不过,ubuntu上默认带的OpenSSL是1.1.1; 所以我下载了OPENSSL1.2.0重…

【Linux】TCP协议的相关实验——深入理解

TCP相关实验 理解CLOSE_WAIT状态 当客户端和服务器在进行TCP通信时,如果客户端调用close函数关闭对应的文件描述符,此时客户端底层操作系统就会向服务器发起FIN请求,服务器收到该请求后会对其进行ACK响应。 但如果当服务器收到客户端的FIN…

【LeetCode每日一题】——205.同构字符串

文章目录 一【题目类别】二【题目难度】三【题目编号】四【题目描述】五【题目示例】六【题目提示】七【解题思路】八【时间频度】九【代码实现】十【提交结果】 一【题目类别】 哈希表 二【题目难度】 简单 三【题目编号】 205.同构字符串 四【题目描述】 给定两个字符…

Appium - 移动端自动测试框架,如何入门?

Appium是一个开源跨平台移动应用自动化测试框架。 既然只是想学习下Appium如何入门,那么我们就直奔主题。文章结构如下: 1、为什么要使用Appium? 2、如何搭建Appium工具环境?(超详细) 3、通过demo演示Appium的使用 4、Appium如何…

【C++常见八股1】内存布局 | 参数压栈 | 构造析构调用 | 空类大小

内存布局 .text 代码段:存放二进制代码、字符串常量.data 段:存放已初始化全局变量、静态变量、常量.bss 段:未初始化全局变量,未初始化静态变量heap 堆区:new/malloc 手动分配的内存,需要手动释放stack 栈…

竞赛项目 深度学习图像风格迁移

文章目录 0 前言1 VGG网络2 风格迁移3 内容损失4 风格损失5 主代码实现6 迁移模型实现7 效果展示8 最后 0 前言 🔥 优质竞赛项目系列,今天要分享的是 🚩 深度学习图像风格迁移 - opencv python 该项目较为新颖,适合作为竞赛课题…

Redis_概述

1.redis概述 1.1 简介 截止到2021年12月 数据库排名https://db-engines.com/en/ranking redis(Remote Dictionary Server) 一个开源的key-value存储系统它支持存储的Value类型:包括String(字符串),list(链表),set(集合),zset(sorted set 有序集合),hash(哈希类型…

【C++11】列表初始化 | decltype操作符 | nullptr | STL的更新

文章目录 一.列表初始化1. 花括号初始化2. initializer_list 二.decltype三.nullptr四.STL的更新1.STL新增容器2.字符串转换函数3.容器中的一些新方法 一.列表初始化 1. 花括号初始化 { }的初始化 C98中,标准允许使用大括号{}对数组或者结构体元素进行统一的列表初…

Unity游戏源码分享-俄罗斯方块unity2017

Unity游戏源码分享-俄罗斯方块unity2017 工程地址: https://download.csdn.net/download/Highning0007/88204011

STM32 F103C8T6学习笔记3:串口配置—串口收发—自定义Printf函数

今日学习使用STM32 C8T6的串口,我们在经过学习笔记2的总结归纳可知,STM32 C8T6最小系统板上有三路串口,如下图: 今日我们就着手学习如何配置开通这些串口进行收发,这里不讲串口通信概念与基础,可以自行网上…

突破笔试:力扣全排列(medium)

1. 题目链接:46. 全排列 2. 题目描述:给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1: 输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[…

Redis_安装、启动以及基本命令

2.Redis安装 2.1前置处理环境 VMware安装安装centOS的linux操作系统xshellxftp 2.2 配置虚拟机网络 按ctrlaltf2 切换到命令行 cd (/)目录 修改/etc/sysconfig/network-scripts/ifcfg-ens3 vi 命令 按insert表示插入 按ctrlesc退出修改状态 :wq 写入并退出 此文件必须保持一…

linux鲁班猫代码初尝试[编译镜像][修改根文件系统重编译]

编译镜像 官方百度云盘资料:https://doc.embedfire.com/linux/rk356x/quick_start/zh/latest/quick_start/baidu_cloud/baidu_cloud.html 解压虚拟机压缩包:"鲁班猫\8-SDK源码压缩包\开发环境虚拟机镜像\ubuntu20.04.7z"后既可以用VMware打开,打开后可以看到已经有…

探索数据之美:初步学习 Python 柱状图绘制

文章目录 一 基础柱状图1.1 创建简单柱状图1.2 反转x和y轴1.3 数值标签在右侧1.4 演示结果 二 基础时间线柱状图2.1 创建时间线2.2 时间线主题设置取值表2.3 演示结果 三 GDP动态柱状图绘制3.1 需求分析3.2 数据文件内容3.3 列表排序方法3.4 参考代码3.5 运行结果 一 基础柱状图…

libheif—— 1、Vs2017搭建libheif开发环境

HEIF(高效图像文件格式) 一种图片有损压缩格式,它的后缀名通常为".heic"或".heif"。 HEIF 是由运动图像专家组 (MPEG) 标准化的视觉媒体容器格式,用于存储和共享图像和图像序列。它基于…

第二章:CSS基础进阶-part1:CSS高级选择器

文章目录 一、 组合选择器二、属性选择器三、伪类选择器1、动态伪类选择器2、状态伪类选择器3、结构性伪类选择器4、否定伪类选择器 一、 组合选择器 后代选择器:E F子元素选择器: E>F相邻兄弟选择器:EF群组选择器:多个选择器…

模板的进阶

目录 1.非类型模板参数 2.模板特化 2.1概念 2.2函数模板特化 2.3类模板特化 2.3.1全特化 2.3.2偏特化 3.模板分离编译 3.1什么是分离编译 3.2 模板的分离编译 3.3解决方法 4. 模板总结 1.非类型模板参数 模板参数分类类型形参与非类型形参。 类型形参即&#xff1a…

smtplib.SMTPHeloError: (500, b‘Error: bad syntax‘)

如果你编写邮件收发工具的时候,有可能会遇到这个问题。这里直接给出解决办法。 目录 1、检查系统版本 2、点击右侧的更改适配器选项