在yolov5源码中添加注意力机制

news2024/11/24 8:45:09

yolov5源码中添加注意力机制

  • 1 项目环境配置
    • 1.1 yolov5 源码下载
    • 1.2 创建虚拟环境
    • 1.3 安装依赖
  • 2 常用的注意力机制
    • 2.1 SE 注意力机制
    • 2.2 CBAM 注意力机制
    • 2.3 ECA 注意力机制
    • 2.4 CA 注意力机制
  • 3 添加方式
    • 3.1 修改 common.py 文件
    • 3.2 修改 yolo.py 文件
    • 3.3 修改 yolov5s.yaml 文件
    • 3.4 修改 train.py 文件

1 项目环境配置

1.1 yolov5 源码下载

点击下载

1.2 创建虚拟环境

win+r打开Windows终端界面输入(其中yolov5为我命名的虚拟环境名称):

mkvirtualenv yolov5

进入虚拟环境

workon yolov5

没有此模块无法创建虚拟环境的请移步:Python 的虚拟环境

1.3 安装依赖

  1. 依赖前提:有python环境以及pytorch

本人环境:python3.9,cuda11.7
安装 pytorch 移步官网

在这里插入图片描述

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117

避免不必要的错误,建议使用 pip 安装

  1. 安装项目依赖

进入项目文件夹,终端键入:

pip install -r requirements.txt

环境搭建完成!

2 常用的注意力机制

2.1 SE 注意力机制

# SE
class SE(nn.Module):
    def __init__(self, c1, c2, ratio=16):
        super(SE, self).__init__()
        #c*1*1
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.l1 = nn.Linear(c1, c1 // ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.l2 = nn.Linear(c1 // ratio, c1, bias=False)
        self.sig = nn.Sigmoid()
    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avgpool(x).view(b, c)
        y = self.l1(y)
        y = self.relu(y)
        y = self.l2(y)
        y = self.sig(y)
        y = y.view(b, c, 1, 1)
        return x * y.expand_as(x)

2.2 CBAM 注意力机制

# CBAM
class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        out = self.sigmoid(avg_out + max_out)
        return out
    
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        # (特征图的大小-算子的size+2*padding)/步长+1
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # 1*h*w
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        #2*h*w
        x = self.conv(x)
        #1*h*w
        return self.sigmoid(x)
    
class CBAM(nn.Module):
    def __init__(self, c1, c2, ratio=16, kernel_size=7):  # ch_in, ch_out, number, shortcut, groups, expansion
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(c1, ratio)
        self.spatial_attention = SpatialAttention(kernel_size)
    def forward(self, x):
        out = self.channel_attention(x) * x
        # c*h*w
        # c*h*w * 1*h*w
        out = self.spatial_attention(out) * out
        return out

2.3 ECA 注意力机制

class ECA(nn.Module):
    """Constructs a ECA module.
    Args:
        channel: Number of channels of the input feature map
        k_size: Adaptive selection of kernel size
    """

    def __init__(self, c1,c2, k_size=3):
        super(ECA, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # feature descriptor on the global spatial information
        y = self.avg_pool(x)
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
        # Multi-scale information fusion
        y = self.sigmoid(y)

        return x * y.expand_as(x)

2.4 CA 注意力机制

# CA
class h_sigmoid(nn.Module):
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.relu = nn.ReLU6(inplace=inplace)
    def forward(self, x):
        return self.relu(x + 3) / 6
class h_swish(nn.Module):
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)
    def forward(self, x):
        return x * self.sigmoid(x)

class CoordAtt(nn.Module):
    def __init__(self, inp, oup, reduction=32):
        super(CoordAtt, self).__init__()
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        mip = max(8, inp // reduction)
        self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(mip)
        self.act = h_swish()
        self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
        self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
    def forward(self, x):
        identity = x
        n, c, h, w = x.size()
        #c*1*W
        x_h = self.pool_h(x)
        #c*H*1
        #C*1*h
        x_w = self.pool_w(x).permute(0, 1, 3, 2)
        y = torch.cat([x_h, x_w], dim=2)
        #C*1*(h+w)
        y = self.conv1(y)
        y = self.bn1(y)
        y = self.act(y)
        x_h, x_w = torch.split(y, [h, w], dim=2)
        x_w = x_w.permute(0, 1, 3, 2)
        a_h = self.conv_h(x_h).sigmoid()
        a_w = self.conv_w(x_w).sigmoid()
        out = identity * a_w * a_h
        return out

3 添加方式

3.1 修改 common.py 文件

修改 yolov5-master/models/common.py文件,将上述提供的注意力机制代码块直接加到 common.py 文件夹的末尾,此处以SE注意力机制为例

在这里插入图片描述

3.2 修改 yolo.py 文件

修改 yolov5-master/models/yolo.py文件,将注意力机制类名SE添加到 yolo.py 文件的 parse_model方法中如下集合里

在这里插入图片描述

3.3 修改 yolov5s.yaml 文件

修改 yolov5-master/models/yolov5s.yaml文件,将SE注意力机制模块添加到你想添加的位置,常见的有C3模块的后面,以及在主干网络 backboneSPPF 的前一层,这里我将SE注意力机制模块添加在主干网络 backboneSPPF 的前一层

修改前:
在这里插入图片描述

修改后:

在这里插入图片描述

另外,由于我将SE注意力机制模块添加在了第 9 层(层索引为 9,起始层索引为 0),那么,原来的第 9 层,以及第 9 层之后的层数都要加 1

加1前:

在这里插入图片描述

加1后:

在这里插入图片描述

3.4 修改 train.py 文件

修改 yolov5-master/train.py 文件,在默认参数 --cfg后面的 default中添加我们前面修改过的 yolov5s.yaml文件

修改前:
在这里插入图片描述
修改后:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/721578.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

易查分如何导入数据?这个最关键的要点别忽略

我们在使用易查分制作查询系统时,偶尔会遇到Excel文件没有办法正常上传的情况。这个问题困扰着许多老师,他们不知道该如何解决。今天我想和大家讨论一下,易查分导入数据时最常出现错误的原因,其中这个要点最关键,但很多…

谷歌Bard入门指南

文章目录 谷歌Bard入门指南一、简介二、使用指南三、中文化3.1 中文提问3.2 中文回答 四、Hello Game五、亮点 谷歌Bard入门指南 一、简介 Bard 是一个大型语言模型,也称为对话式 AI 或聊天机器人,经过训练,内容丰富且全面。Bard 接受过大量…

Mysql——》哈希索引

推荐链接: 总结——》【Java】 总结——》【Mysql】 总结——》【Redis】 总结——》【Kafka】 总结——》【Spring】 总结——》【SpringBoot】 总结——》【MyBatis、MyBatis-Plus】 总结——》【Linux】 总结——》【MongoD…

接口测试-postman,JMeter与LoadRunner比较

目录 JMeter与LoadRunner比较 JMeter缺点 一.创建测试用例集、子集 二.创建测试用例 三.设置变量 四.添加响应处理 五.批量执行测试用例 总结: postman是一个谷歌出的轻量级的专门测试接口的小工具~(PS:postman包括两种:C…

PostgreSQL如何根据执行计划进行性能调优?

EXPLAIN命令 PG中EXPLAIN命令语法格式如下: EXPLAIN [(option[,...])] statement EXPLAIN [ANALYZE] [VERBOSE] statement该命令的options如下: ANALYZE [boolean]VERBOSE [boolean]COSTS [boolean]BUFFERS [boolean]FORMAT {TEXT | XML | JSON | YAM…

【UnityDOTS 十】DynamicBufferComponent介绍

DynamicBufferComponent 前言 DynamicBufferComponent 作为一种特殊的组件存在,可以作为一种非托管内存下可动态调整带下的数组容器组件。 一、DynamicBufferComponent是什么? DynamicBufferComponent也是组件的一种。 需要关注的是内部指针&#xf…

spring使用01

① 导入 Spring 开发的基本包坐标 ② 编写 Dao 接口和实现类 ③ 创建 Spring 核心配置文件 ④ 在 Spring 配置文件中配置 UserDaoImpl ⑤ 使用 Spring 的 API 获得 Bean 实例 第一步&#xff1a;创建maven的web骨架 然后&#xff0c;导入 Spring 开发的基本包坐标 <depe…

工资10K,副业20K,这届程序员搞副业真野

最近刚完成了一个远程外包项目工作&#xff0c;钱刚到账&#xff0c;小金库又添了一笔&#xff1a; 从一开始的15K死工资&#xff0c;到现在的主业副业一共25K收入&#xff0c;最近的经济压力小了很多&#xff0c;终于也有闲钱和老婆去旅旅游&#xff0c;升级一下外设&#xff…

平板电脑的触控笔有必要买吗?平价电容笔排行榜

伴随着ipad的流行&#xff0c;部分学习党开始从传统的纸笔教学向无纸化教学转变。于是&#xff0c;原本属于苹果专利的电容笔&#xff0c;一下子就火了起来&#xff0c;不少人都对这个价格接近一千块钱的电容笔产生了浓厚的兴趣。我想&#xff0c;苹果电容笔特有的的“重力压感…

MySQL数据库基础(二):DDL,DML,DQL

六、DDL数据库操作 1、MySQL的组成结构 注&#xff1a;我们平常说的MySQL&#xff0c;其实主要指的是MySQL数据库管理软件。 一个MySQL DBMS可以同时存放多个数据库&#xff0c;理论上一个项目就对应一个数据库。如博客项目blzhujianog数据库、商城项目shop数据库、微信项目wec…

OpenCV创建一张类型为CV_8UC3的3通道彩色图像

#include <iostream> #include <opencv2/imgcodecs.hpp> #include <opencv2/opencv.hpp> #include <opencv2/highgui.hpp>int

[MySQL]MySQL库的操作

[MySQL]MySQL库的操作 文章目录 [MySQL]MySQL库的操作1. 创建数据库2. 字符集和校验规则2.1. 基本概念2.2. 查看系统默认字符集以及校验规则2.3. 查看数据库支持的字符集2.4 查看数据库支持的校验规则2.5 指明字符集和校验规则创建数据库2.6 校验规则对数据库的影响 3. 删除数据…

经典轻量级神经网络(1)MobileNet V1及其在Fashion-MNIST数据集上的应用

经典轻量级神经网络(1)MobileNet V1及其在Fashion-MNIST数据集上的应用 1 MobileNet V1的简述 自从2017年由谷歌公司提出&#xff0c;MobileNet可谓是轻量级网络中的Inception&#xff0c;经历了一代又一代的更新。 MobileNet 应用了Depthwise 深度可分离卷积来代替常规卷积…

【hadoop】Google的基本思想

Google的基本思想 三架马车GFS分布式文件系统的核心架构和原理机架感知 MapReduce计算模型PageRank问题MapReduce BigTable 三架马车 Google的基本思想主要有三个&#xff0c;称之为三架马车&#xff0c;分别是GFS&#xff08;Google File System&#xff09;、MapReduce计算模…

gitlab/gerrit

gitlab/gerrit 1. gitlab2. gerrit2.1 环境准备2.2 下载软件2.3 创建启动账户2.4 安装gerrit2.5 创建登录账户2.6 启动服务2.7 修改配置文件2.8 配置反向代理(nginx)2.9 gerrit主页 3. gitlabgerrit3.1 配置gerrit replication功能&#xff08;用于复制具体项目&#xff09;3.2…

深入浅出讲解Stable Diffusion原理,新手也能看明白

说明 最近一段时间对多模态很感兴趣&#xff0c;尤其是Stable Diffusion&#xff0c;安装了环境&#xff0c;圆了自己艺术家的梦想。看了这方面的一些论文&#xff0c;也给人讲过一些这方面的原理&#xff0c;写了一些文章&#xff0c;具体可以参考我的文章&#xff1a; 北方…

51单片机驱动 mg996r金属舵机 STC89C52单片机直接驱动金属大舵机

/*无论是大舵机&#xff0c;还是小舵机&#xff0c;控制方法都一样会区别在 大舵机只能接P0口&#xff08;此口外接上拉&#xff0c;驱动电流最大&#xff09;小舵机任意口 */ //#include<reg51.h> //#define uint unsigned int //#define uchar unsigned char //sbit S…

10、架构:组件通信设计

通信是一个应用中不可或缺的一个功能&#xff0c;现如今前端视图类框架大多数都是由数据驱动&#xff0c;通过数据来进行视图层的展示渲染。举个简单的例子如下&#xff0c;这是一个常见的 React 列表渲染&#xff1a; // each const numbers [1, 2, 3, 4, 5]; const listIte…

应用级监控方案Spring Boot Admin

1.简介 Spring Boot Admin为项目常用的监控方式&#xff0c;可以动态的监控服务是否运行和运行的参数&#xff0c;如类的调用情况、流量等。其中分为server与client&#xff1a; server&#xff1a; 提供展示UI与监控服务。client&#xff1a;加入server&#xff0c;被监控的…

C语言王国探险记之函数的简单概念

王国探险记系列 文章目录&#xff08;5&#xff09; 目录 王国探险记系列 文章目录&#xff08;5&#xff09; 前言 一&#xff0c;函数的基本概念 二&#xff0c;调用外部函数和main()函数区别 2.1如果我们将函数的定义放到后面&#xff0c;可不可以呢&#xff1f; 总结…