Yolov5-v7.0使用CBAM注意力机制记录

news2024/11/13 15:29:53

Yolov5-v7.0使用CBAM注意力机制记录

一、CBAM实现代码

在model/common.py文件中加入如下代码:

#############CBAM注意力机制##############
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        # (特征图的大小-算子的size+2*padding)/步长+1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        # 2*h*w
        x = self.conv(x)
        # 1*h*w
        return self.sigmoid(x)


class CBAM(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

加在代码最后就行。

二、注册CBAM机制

在model/yolo.py文件中修改。
第一处:

from models.common import

引入CBAM。

在这里插入图片描述
第二处:

def parse_model(d, ch):

继续增加CBAM
在这里插入图片描述

三、更改yolov5s.yaml文件

yolov5s.yaml文件是模型结构文件,增加CBAM机制

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license

# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors:
  - [10,13, 16,30, 33,23]  # P3/8
  - [30,61, 62,45, 59,119]  # P4/16
  - [116,90, 156,198, 373,326]  # P5/32

# YOLOv5 v6.0 backbone
backbone:
  # [from, number, module, args]
  [
   [-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2
   [-1, 1, Conv, [128, 3, 2]],  # 1-P2/4
   [-1, 3, C3, [128]],
   [-1, 1, Conv, [256, 3, 2]],  # 3-P3/8
   [-1, 6, C3, [256]],
   [-1, 1, Conv, [512, 3, 2]],  # 5-P4/16
   [-1, 9, C3, [512]],
   [-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32
   [-1, 3, C3, [1024]],
   [-1, 1, CBAM, [1024]],
   [-1, 1, SPPF, [1024, 5]],  # 10
  ]

# YOLOv5 v6.0 head
head:
  [[-1, 1, Conv, [512, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 6], 1, Concat, [1]],  # cat backbone P4
   [-1, 3, C3, [512, False]],  # 14

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [[-1, 4], 1, Concat, [1]],  # cat backbone P3
   [-1, 3, C3, [256, False]],  # 18 (P3/8-small)

   [-1, 1, Conv, [256, 3, 2]],
   [[-1, 15], 1, Concat, [1]],  # cat head P4
   [-1, 3, C3, [512, False]],  # 21 (P4/16-medium)

   [-1, 1, Conv, [512, 3, 2]],
   [[-1, 11], 1, Concat, [1]],  # cat head P5
   [-1, 3, C3, [1024, False]],  # 24 (P5/32-large)

   [[18, 21, 24], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]

四、修改train.py文件

增加注意力机制后,运行train.py文件会报错:

RuntimeError: adaptive_max_pool2d_backward_cuda does not have a deterministic implementation, but you set 'torch.use_deterministic_algorithms(True)'. You can turn off determinism just for this operation, or you can use the 'warn_only=True' option, if that's acceptable for your application. You can also file an issue at https://github.com/pytorch/pytorch/issues to help us prioritize adding deterministic support for this operation.

需要修改train.py文件。在代码:

scaler.scale(loss).backward()

前面加上:

torch.use_deterministic_algorithms(False)

修改后如下:
在这里插入图片描述

五、训练

训练时留意一下CBAM是否生效
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1953174.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Windows】激活补丁被误删,怎么办?如何关闭Windows11安全中心中的“病毒和威胁保护”!

按下“win(徽标键)i”快捷键,选择隐私与安全性-Windows安全中心。 选择防火墙和网络保护-域保护。 将开关闭,专业网络和公用网络防火墙也同样关闭,如下图所示: 关闭防火墙后,左边菜单…

改进向量搜索-使用PostgresML和LlamaIndex重新排名

改进向量搜索-使用PostgresML和LlamaIndex重新排名 搜索和重新排名:提高结果相关性 搜索系统通常采用两种主要方法:关键字和语义。关键字搜索将精确的查询词与索引数据库内容匹配,而语义搜索使用 NLP 和机器学习来理解查询上下文和意图。许多…

【踩坑系列-Docker】基于Alibaba Cloud Linux3基础镜像安装Nginx

Author:赵志乾 Date:2024-07-26 Declaration:All Right Reserved!!! 1. 问题描述 使用Alibaba Cloud Linux3作为基础镜像,在其上安装Nginx,对应的Dockerfile内容如下: …

使用 From File 模块加载数据

目录 检查模型 创建时间和信号数据 加载 timeseries 数据 加载数组数据 加载总线数据 此示例说明如何使用 From File 模块从 MAT 文件加载仿真输入数据,包括如何创建和格式化输入数据。可以通过编程方式创建您加载的数据,加载从另一个仿真中记录的数据,或加载从…

栈和队列<数据结构 C版>

目录 栈(Stack) 栈的结构体 初始化 销毁 入栈 判空 出栈 取栈顶元素 获取栈个数 测试: 队列(Queue) 队列的结构体 单个结点 队列 初始化 销毁 入队列,队尾 判空 出队列,队头 …

【YashanDB知识库】开源调度框架Quartz写入Boolean值到YashanDB报错

问题现象 Quartz 是一个广泛应用于企业级应用中的开源作业调度框架,它主要用于在Java环境中管理和执行任务。 为了任务调度,Quartz的数据模型中使用了大量的布尔值记录任务、流程的各种状态,如: Quartz使用JDBC写入任务状态代码…

【资料分享】2024第三届钉钉杯大学生大数据挑战赛B题思路解析+双语言代码

2024钉钉杯大学生大数据挑战赛,B题解题思路和双语言代码分享,资料预览:

制作excel模板,用于管理后台批量导入船舶数据

文章目录 引言I 数据有效性:基于WPS在Excel中设置下拉框选择序列内容II 数据处理:基于easyexcel工具实现导入数据的持久化2.1 自定义枚举转换器2.2 ExcelDataConvertExceptionIII 序列格式化: 基于Sublime Text 文本编辑器进行批量字符操作引言 需求: excel数据导入模板制…

【MySQL进阶之路 | 高级篇】表级锁之S锁,X锁,意向锁

1. 从数据操作的粒度划分:表级锁,页级锁,行锁 为了尽可能提高数据库的并发度,每次锁定的数据范围越小越好,理论上每次只锁定当前操作的数据的方案会得到最大的并发度,但是管理锁是很耗资源的事情&#xff…

前端开发:HTML与CSS

文章目录 前言1.1、CS架构和BS架构1.2、网页构成 HTML1.web开发1.1、最简单的web应用程序1.2、HTTP协议1.2.1 、简介1.2.2、 http协议特性1.3.3、http请求协议与响应协议 2.HTML概述3.HTML标准结构4.标签的语法5.基本标签6.超链接标签6.1、超链接基本使用6.2、锚点 7.img标签8.…

【网络安全的神秘世界】文件包含漏洞

🌝博客主页:泥菩萨 💖专栏:Linux探索之旅 | 网络安全的神秘世界 | 专接本 | 每天学会一个渗透测试工具 一、概述 文件包含:重复使用的函数写在文件里,需要使用某个函数时直接调用此文件,而无需再…

【学习日记】函数调用 和 全局变量 如何实现 位置无关码

问题来源 在 I.MX6ull 的启动流程中,u-boot会将自身从内存一开始的位置拷贝到其他位置,以便给linux留出内存空间,防止 u-boot被覆盖如果代码中包含直接引用其链接时地址的指令,那么当代码被移动到新的地址时,这些引用…

聊聊RNNLSTM

RNN 用于解决输入数据为,序列到序列(时间序列)数据,不能在传统的前馈神经网络(FNN)很好应用的问题。时间序列数据是指在不同时间点上收集到的数据,这类数据反映了某一事物、现象等随时间的变化状态或程度,即输入内容的上下文关联…

工业现场实测,焦化厂导烟车与装煤车风机实现无人作业

一、项目背景 作为我国重要的能源行业之一,焦化行业在国民经济中扮演着重要角色,焦化工艺是高温、高压、有毒物质等因素共同作用下进行的,因此存在着安全隐患,并伴有环境污染,改善焦化工艺的安全和环保问题是当前亟待…

优选算法之前缀和(下)

目录 一、和为 k 的子数组 1.题目链接:560. 和为 K 的子数组 2.题目描述: 3.解法(前缀和 哈希表) 🌻算法思路: 🌻算法代码: 二、和可被 k 整除的子数组 1.题目链接&#xff…

MySQL中多表查询之外连接

首先先来介绍一下我做的两个表,然后再用他们两个举例说明。 -- 创建教师表 create table teachers( id_t int primary key auto_increment, -- 老师编号 name_t varchar(5) -- 姓名 ); -- 创建学生表 create table students( id_s int primary key auto_increment,…

Android APK混淆处理方案分析

这里写目录标题 一、前言1.1 相关工具二、Apk 分析2.1 apk 解压文件2.2 apk 签名信息2.3 apk AndroidManifest.xml2.4 apk code三、Apk 处理3.1 添加垃圾文件3.2 AndroidManifest.xml 处理3.3 dex 混淆处理3.4 zipalign对齐3.5 apk 重新签名3.6 apk 安装测试四、总结一、前言 提…

使用Astro+Vercel+Cloudflare一天时间开发部署上线一个知识博客网站,简直简简单单

大家好,这里是程序猿代码之路。在当今数字化时代,拥有一个个人博客网站对于分享知识、展示个人品牌变得越来越重要。然而,许多非技术背景的用户对于搭建和维护一个网站可能会感到望而却步。幸运的是,随着低代码和无代码平台的兴起…

Spring高手之路21——深入剖析Spring AOP代理对象的创建

文章目录 创建代理对象核心动作的三个步骤1. 判断 Bean 是否需要增强(源码分析时序图说明)2. 匹配增强器 Advisors(源码分析时序图说明)3. 创建代理对象(源码分析时序图说明) 创建代理对象核心动作的三个步…

C++模版基础知识与STL基本介绍

目录 一. 泛型编程 二. 函数模板 1. 概念 2. 函数模版格式 3. 函数模版的原理 4. 模版函数的实例化 (1). 隐式实例化 (2.) 显式实例化 5. 模版参数的匹配原则 三. 类模板 1. 类模板的定义格式 2. 类模板的实例化 四. STL的介绍 1. 什么是STL? 2. STL的版…