【保姆级教程|YOLOv8添加注意力机制】【2】在C2f结构中添加ShuffleAttention注意力机制并训练

news2024/10/6 10:31:37

《博主简介》

小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。
更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~
👍感谢小伙伴们点赞、关注!

《------往期经典推荐------》

一、AI应用软件开发实战专栏【链接】

项目名称项目名称
1.【人脸识别与管理系统开发】2.【车牌识别与自动收费管理系统开发】
3.【手势识别系统开发】4.【人脸面部活体检测系统开发】
5.【图片风格快速迁移软件开发】6.【人脸表表情识别系统】
7.【YOLOv8多目标识别与自动标注软件开发】8.【基于YOLOv8深度学习的行人跌倒检测系统】
9.【基于YOLOv8深度学习的PCB板缺陷检测系统】10.【基于YOLOv8深度学习的生活垃圾分类目标检测系统】
11.【基于YOLOv8深度学习的安全帽目标检测系统】12.【基于YOLOv8深度学习的120种犬类检测与识别系统】
13.【基于YOLOv8深度学习的路面坑洞检测系统】14.【基于YOLOv8深度学习的火焰烟雾检测系统】
15.【基于YOLOv8深度学习的钢材表面缺陷检测系统】16.【基于YOLOv8深度学习的舰船目标分类检测系统】
17.【基于YOLOv8深度学习的西红柿成熟度检测系统】18.【基于YOLOv8深度学习的血细胞检测与计数系统】
19.【基于YOLOv8深度学习的吸烟/抽烟行为检测系统】20.【基于YOLOv8深度学习的水稻害虫检测与识别系统】
21.【基于YOLOv8深度学习的高精度车辆行人检测与计数系统】22.【基于YOLOv8深度学习的路面标志线检测与识别系统】
22.【基于YOLOv8深度学习的智能小麦害虫检测识别系统】23.【基于YOLOv8深度学习的智能玉米害虫检测识别系统】
24.【基于YOLOv8深度学习的200种鸟类智能检测与识别系统】25.【基于YOLOv8深度学习的45种交通标志智能检测与识别系统】
26.【基于YOLOv8深度学习的人脸面部表情识别系统】27.【基于YOLOv8深度学习的苹果叶片病害智能诊断系统】

二、机器学习实战专栏【链接】,已更新31期,欢迎关注,持续更新中~~
三、深度学习【Pytorch】专栏【链接】
四、【Stable Diffusion绘画系列】专栏【链接】

《------正文------》

## 搜索C2f源码位置并新建C2f类

在项目目录中全局搜索class c2f即可找到c2f的源码位置。然后打开源码位置,进行相应修改。源码路径为:ultralytics/nn/modules/block.py

在原文件中直接copy一份c2f类的源码,然后命名为c2f_Attention,如下所示:

在这里插入图片描述

在不同文件导入新建的C2f类

ultralytics/nn/modules/block.py顶部,all中添加刚才创建的类的名称:c2f_Attention,如下图所示:

在这里插入图片描述

同样需要在ultralytics/nn/modules/__init__.py文件,相应位置导入刚出创建的c2f_Attention类。如下图:

在这里插入图片描述

还需要在ultralytics/nn/tasks.py中导入创建的c2f_Attention类,,如下图:

在这里插入图片描述

parse_model解析函数中添加C2f类

ultralytics/nn/tasks.pyparse_model解析网络结构的函数中,加入c2f_Attention类,如下图:
在这里插入图片描述

创建新的配置文件c2f_att_yolov8.yaml

ultralytics/cfg/models/v8目录下新建c2f_att_yolov8.yaml配置文件,内容如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f_Attention, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f_Attention, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f_Attention, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

新的c2f_att_yolov8.yaml配置文件与原yolov8.yaml文件的对比如下:

在这里插入图片描述

在C2f中添加注意力:ShuffleAttention

注意:对于有通道数参数的注意力机制,其输入通道数为其上层的输出通道数。这个注意力添加的位置有关。

在路径ultralytics/nn下新建注意力模块,ShuffleAttention.py文件。内容如下:

import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):

    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # flatten
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w

        # channel_split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w

        # channel attention
        x_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1
        x_channel = x_0 * self.sigmoid(x_channel)

        # spatial attention
        x_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w

        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,w
        out = out.contiguous().view(b, -1, h, w)

        # channel shuffle
        out = self.channel_shuffle(out, 2)
        return out

ultralytics/nn/tasks.py中导入,并修改在parse_model解析网络结构的函数中,添加解析代码:

在这里插入图片描述

在这里插入图片描述

注意力不同位置添加方法

ultralytics/nn/modules/block.py中的c2f_Attention类中代码相应位置添加注意力机制:

1 . 方式一:在self.cv1后面添加注意力机制

在这里插入图片描述

2.方式二:在self.cv2后面添加注意力机制

在这里插入图片描述

3.方式三:在c2fbottleneck中添加注意力机制,将Bottleneck类,复制一份,并命名为Bottleneck_Attention,然后,在Bottleneck_Attention的cv2后面添加注意力机制,同时修改C2f_Attention类别中的BottleneckBottleneck_Attention。如下图所示:

在这里插入图片描述

加载配置文件并训练

加载c2f_att_yolov8.yaml配置文件,并运行train.py训练代码:

#coding:utf-8
from ultralytics import YOLO

if __name__ == '__main__':
    model = YOLO('ultralytics/cfg/models/v8/c2f_att_yolov8.yaml')
    model.load('yolov8n.pt') # loading pretrain weights
    model.train(data='datasets/TomatoData/data.yaml', epochs=150, batch=2)

注意观察,打印出的网络结构是否正常修改,如下图所示:
在这里插入图片描述

【源码免费获取】

为了小伙伴们能够,更好的学习实践,本文已将所有代码、示例数据集、论文等相关内容打包上传,供小伙伴们学习。获取方式如下:

关注下方名片G-Z-H:【阿旭算法与机器学习】,发送【yolov8改进】即可免费获取

在这里插入图片描述


结束语

关于本篇文章大家有任何建议或意见,欢迎在评论区留言交流!

觉得不错的小伙伴,感谢点赞、关注加收藏哦!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1390970.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SpringBoot框架篇】35.kafka环境搭建和收发消息

kafka环境搭建 kafka依赖java环境,如果没有则需要安装jdk yum install java-1.8.0-openjdk* -y1.下载安装kafka kafka3.0版本后默认自带了zookeeper,3.0之前的版本需要单独再安装zookeeper,我使用的最新的3.6.1版本。 cd /usr/local wget https://dlcdn.apache.…

SpringBoot Redis入门(四)——Redis单机、哨兵、集群模式

单机模式:单台缓存服务器,开发、测试环境下使用;哨兵模式:主-从模式,提高缓存服务器的高可用和安全性。所有缓存的数据在每个节点上都一致。每个节点添加监听器,不断监听节点可用状态,一旦主节点…

x-cmd pkg | public-ip-cli - 公共 IP 地址查询工具

简介 public-ip-cli 是一个用 Javascript 编写的命令行工具,用于获取当前计算机或网络所使用的公共 IP 地址。 它可以让用户在命令行界面上查询 OpenDNS、Google DNS 和 HTTPS 服务的 DNS 记录以获取与互联网通信时所分配的公共 IP 地址。 首次用户 使用 x env us…

成功解决VScode进入到内置函数中调试

主要有两个关键步骤, 第一步 将launch.json中的"justMyCode"设为false 可通过使用ctrlshiftP搜索lauch.json找到次文件 如果找不到的话,可点击debug按钮,然后找到点击create a launch.json file创建 创建得到的launch.json如下&am…

四大软件架构:掌握单体、分布式、微服务、Serverless 的精髓

四大软件架构:掌握单体、分布式、微服务、Serverless 的精髓 简介: 如果一个软件开发人员,不了解软件架构的演进,会制约技术的选型和开发人员的生存、晋升空间。这里我列举了目前主要的四种软件架构以及他们的优缺点,…

isis小实验

要求: 1.合理规划level1-2 2.r1访问r5走r6且走上面 3.全网可达 个人理解:以重发布的视角:is-level level1即L1可以看做rip,L2可以看做OSPF,L1-2可以看作是既要rip又要OSPF,优点:isis只用在每个路由器上宣告一次 缺点:isis需要每个接口上输isis enable 1(序号)特点:L1-2会自动下…

Java、C#、Python间的Battle

一、编译原理和开发效率 编译速度: C# (约大于等于) JAVA > Python python的编译原理 前提:python 3.6 python不会直接编译源码 而是把源码直接扔给解释器,这种方式 使得python非常灵活,让它的开发效…

Docker Consul详解与部署示例

目录 Consul构成 Docker Consul 概述 Raft算法 服务注册与发现 健康检查 Key/Value存储 多数据中心 部署模式 consul-template守护进程 registrator容器 consul服务部署(192.168.41.31) 环境准备 搭建Consul服务 查看集群信息 registrato…

uniCloud ---- uni-captch实现图形验证码

目录 用途说明 组成部分 目录结构 原理时序 云端一体组件介绍 验证码配置(可选): 普通验证码组件 公共模块 云函数公用模块 项目实战 创建云函数 创建注册页 创建云函数 关联公用模块 uni-captcha 刷新验证码 自定义实现 验…

Go新项目-为何选Gin框架?(0)

先说结论:我们选型Gin框架 早在大概在2019年下旬,由于内部一个多线程上传的需求,考虑到Go协程的优势; 内部采用Gin框架编写了内部的数据上传平台BAP,采用GinVue开发,但前期没考虑到工程化思维,导…

用LED数码显示器伪静态显示数字1234

#include<reg51.h> // 包含51单片机寄存器定义的头文件 void delay(void) //延时函数&#xff0c;延时约0.6毫秒 { unsigned char i; for(i0;i<200;i) ; } void main(void) { while(1) //无限循环 { P20xfe; …

.Net 8.0 Web API Controllers 添加到 windows 服务

示例源码下载&#xff1a;https://download.csdn.net/download/hefeng_aspnet/88747022 创建 Windows 服务的方法之一是从工作线程服务模板开始。 但是&#xff0c;如果您希望能够让它托管 API 控制器&#xff08;也许是为了查看它正在运行的进程的状态&#xff09;&#xff0…

IC验证——perl脚本ccode_standard——c代码寄存器配置标准化

目录 1 脚本名称 2 脚本路径 3 脚本参数说明 4 脚本操作说明 5 脚本代码 1 脚本名称 ccode_standard 2 脚本路径 /scripts/bin/ccode_standard 3 脚本参数说明 次序 参数名 说明 1 address (./rfdig&#xff1b;.&#xff1b;..&#xff1b;./boot) 指定脚本执行路…

如何避免知识付费小程序平台的陷阱?搭建平台的最佳实践

随着知识经济的兴起&#xff0c;知识付费已经成为一种趋势。越来越多的人开始将自己的知识和技能进行变现&#xff0c;而知识付费小程序平台则成为了一个重要的渠道。然而&#xff0c;市面上的知识付费小程序平台琳琅满目&#xff0c;其中不乏一些不良平台&#xff0c;让老实人…

【零基础入门Python数据分析】Anaconda3 JupyterNotebookseaborn版

目录 一、安装环境 python介绍 anaconda介绍 jupyter notebook介绍 anaconda3 环境安装 解决JuPyter500&#xff1a;Internal Server Error问题-CSDN博客 Jupyter notebook快捷键操作大全 二、Python基础入门 数据类型与变量 数据类型 变量及赋值 布尔类型与逻辑运算…

flutter报错Cannot hit test a render box that has never been laid out

出现这个问题的原因可能是因为你把一个ListView或者GridView放到了一个没有设置大小的容器里面导致的&#xff0c;所以意思是不能渲染那一个没有布局过的容器。我这里遇到的错误是因为我把GridView放到了一个Container里面&#xff0c;并且我没有设置Container宽高。 就导致了那…

linux如何排查cpu持续飙高原因

一、检查CPU使用率 首先在Linux系统中检查CPU使用率。可以通过在命令行中输入top或htop命令来查看当前系统中各个进程的CPU使用率。如果CPU使用率大于80%&#xff0c;则可以考虑进行排查。 $ top二、检查系统负载 另外可以使用uptime命令来查看系统的平均负载情况。 $ upti…

elasticsearch6.6.0设置访问密码

elasticsearch6.6.0设置访问密码 第一步 x-pack-core-6.6.0.jar第二步 elasticsearch.yml第三步 设置密码 第一步 x-pack-core-6.6.0.jar 首先破解 x-pack-core-6.6.0.jar 破解的方式大家可以参考 https://codeantenna.com/a/YDks83ZHjd 中<5.破解x-pack> 这部分 , 也可…

Zookeeper安装教程

文章目录 前言一、选择安装包二、使用wget下载并安装zookeeper 前言 Linux下Zookeeper安装步骤 一、选择安装包 Zookeeper下载地址&#xff1a;https://zookeeper.apache.org/releases.html 选择一个稳定版本即可&#xff0c;我这里选择的是3.7.2 点击“Apache ZooKeeper 3.…

考研C语言刷题篇之分支循环结构一

目录 第一题 第二题 方法一&#xff1a;要循环两次&#xff0c;一次求阶乘&#xff0c;一次求和。 注意&#xff1a;在求和时&#xff0c;如果不将sum每次求和的初始值置为1&#xff0c;那么求和就会重复。 方法二&#xff1a; 第三题 方法一&#xff1a;用数组遍历的思想…