YOLOv8添加注意力模块并测试和训练

news2024/9/21 0:49:58

YOLOv8添加注意力模块并测试和训练

参考bilibili视频

yolov8代码库中写好了注意力模块,但是yolov8的yaml文件中并没用使用它,如下图的通道注意力和空间注意力以及两者的结合CBAM,打开conv.py文件可以看到,其中包含了各种卷积块的定义,因此yolov8是把通道注意力和空间注意力以及两者的结合CBAM当作卷积块来处理:
在这里插入图片描述

在这里插入图片描述

2 逐层写入自定义的注意力模块

(1)ultralytics/nn/modules/conv.py中写入自定义的注意力模块:
在这里插入图片描述

(2)ultralytics/nn/modules/init.py中添加自定义的注意力模块名:
在这里插入图片描述
在这里插入图片描述
只有逐层添加模块名,才能封装成ultralytics.nn.modules的内部模块
(3)ultralytics/nn/tasks.py中添加自定义的注意力模块名,以便任务执行时调用自定义的注意力模块。
在这里插入图片描述
接着在ultralytics/nn/tasks.py–>parse_model函数中解析yaml文件时,判断是否有自定义的注意力模块:
在这里插入图片描述

由于CBAM可以看成只是给卷积块Conv加权重,并不会改变输入、输出通道数,因此可以仿照Conv块的处理,在下面判断的语句中它只会执行以下几句:

c1,c2为输入输出通道数,if 后面的语句是的作用是除了最后一层类别输出通道数,其它层的通道数都要是8的整数倍。args存放了c1,c2和args[1]之后的所有参数组成新的args,需要注意,args至少要两个元素,如果只有一个元素,agrs[1:]时会报错超出范围,因此模型的yaml文件中args位置,必须至少2个元素,如:
在这里插入图片描述

- [-1, 3, CBAM, [1024, 7]]	# 输入1024个通道数,kenel size=7

3 修改模型的yaml文件

在ultralytics/cfg/models/v8中复制一个yolov8-seg.yaml文件新建yaml文件命名为yolov8CBAM-seg.yaml:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8-seg instance segmentation model. For Usage examples see https://docs.ultralytics.com/tasks/segment

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n-seg.yaml' will call yolov8-seg.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]
  s: [0.33, 0.50, 1024]
  m: [0.67, 0.75, 768]
  l: [1.00, 1.00, 512]
  x: [1.00, 1.25, 512]

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]     #-->2
  - [-1, 1, CBAM, [128, 7]] #CBAM 3
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8-->4
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, CBAM, [256, 7]]   #CBAM 6
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16-->7
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, CBAM, [512, 7]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32-->10
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, CBAM, [1024, 7]]
  - [-1, 1, SPPF, [1024, 5]] # 9-->13

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 8], 1, Concat, [1]] #[[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12    -->16

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 5], 1, Concat, [1]] #[[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)--->19

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 16], 1, Concat, [1]]  #[[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)-->22

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] #[[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)--->25

#  - [[15, 18, 21], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)
  - [[19, 22, 25], 1, Segment, [nc, 32, 256]] # Segment(P3, P4, P5)

这里在主干backbone中的c2f块后面添加了重复一次的CBAM共添加了四个。由于head层需要Concat backbone的相应层,因此,原来的层序号需要逐一修改,注释中 " -->x "表示新的序号,将原来的序号替换成新的即可。

4 测试是否修改成功

复制一份tests/test_python.py文件中的测试代码,新建文件命名为test_yolov8_CBAM_model.py,只保留下方代码:

# Ultralytics YOLO 🚀, AGPL-3.0 license

import contextlib
import urllib
from copy import copy
from pathlib import Path

import cv2
import numpy as np
import pytest
import torch
import yaml
from PIL import Image

from tests import CFG, IS_TMP_WRITEABLE, MODEL, SOURCE, TMP
from ultralytics import RTDETR, YOLO
from ultralytics.cfg import MODELS, TASK2DATA, TASKS
from ultralytics.data.build import load_inference_source
from ultralytics.utils import (
    ASSETS,
    DEFAULT_CFG,
    DEFAULT_CFG_PATH,
    LOGGER,
    ONLINE,
    ROOT,
    WEIGHTS_DIR,
    WINDOWS,
    checks,
)
from ultralytics.utils.downloads import download
from ultralytics.utils.torch_utils import TORCH_1_9

CFG = 'ultralytics/cfg/models/v8/yolov8l-CBAMseg.yaml'	#使用l模型加一个l字母
SOURCE = ASSETS / "bus.jpg"
def test_model_forward():
    """Test the forward pass of the YOLO model."""
    model = YOLO(CFG)
    model(source=SOURCE, imgsz=[512,512], augment=True)  # also test no source and augment

先在ultralytics/nn/tasks.py的parse_model函数中增加一行代码用于查看模型结构:

print(f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}")

在这里插入图片描述

运行test_yolov8_CBAM_model.py的结果如下:

============================= test session starts ==============================
collected 1 item                                                               

test_yolov8_CBAM_model.py::test_model_forward PASSED                     [100%]  0                  -1  1      1856  ultralytics.nn.modules.conv.Conv             [3, 64, 3, 2]                 
  1                  -1  1     73984  ultralytics.nn.modules.conv.Conv             [64, 128, 3, 2]               
  2                  -1  3    279808  ultralytics.nn.modules.block.C2f             [128, 128, 3, True]           
  3                  -1  1     16610  ultralytics.nn.modules.conv.CBAM             [128, 7]                      
  4                  -1  1    295424  ultralytics.nn.modules.conv.Conv             [128, 256, 3, 2]              
  5                  -1  6   2101248  ultralytics.nn.modules.block.C2f             [256, 256, 6, True]           
  6                  -1  1     65890  ultralytics.nn.modules.conv.CBAM             [256, 7]                      
  7                  -1  1   1180672  ultralytics.nn.modules.conv.Conv             [256, 512, 3, 2]              
  8                  -1  6   8396800  ultralytics.nn.modules.block.C2f             [512, 512, 6, True]           
  9                  -1  1    262754  ultralytics.nn.modules.conv.CBAM             [512, 7]                      
 10                  -1  1   2360320  ultralytics.nn.modules.conv.Conv             [512, 512, 3, 2]              
 11                  -1  3   4461568  ultralytics.nn.modules.block.C2f             [512, 512, 3, True]           
 12                  -1  1    262754  ultralytics.nn.modules.conv.CBAM             [512, 7]                      
 13                  -1  1    656896  ultralytics.nn.modules.block.SPPF            [512, 512, 5]                 
 14                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 15             [-1, 8]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 16                  -1  3   4723712  ultralytics.nn.modules.block.C2f             [1024, 512, 3]                
 17                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
 18             [-1, 5]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 19                  -1  3   1247744  ultralytics.nn.modules.block.C2f             [768, 256, 3]                 
 20                  -1  1    590336  ultralytics.nn.modules.conv.Conv             [256, 256, 3, 2]              
 21            [-1, 16]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 22                  -1  3   4592640  ultralytics.nn.modules.block.C2f             [768, 512, 3]                 
 23                  -1  1   2360320  ultralytics.nn.modules.conv.Conv             [512, 512, 3, 2]              
 24            [-1, 13]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 25                  -1  3   4723712  ultralytics.nn.modules.block.C2f             [1024, 512, 3]                
 26        [19, 22, 25]  1   7950688  ultralytics.nn.modules.head.Segment          [80, 32, 256, [256, 512, 512]]

image 1/1 /XXXXXXXXXXXXXXXXX/ultralyticsv8_2-main/ultralytics/assets/bus.jpg: 640x480 (no detections), 116.5ms
Speed: 2.7ms preprocess, 116.5ms inference, 0.7ms postprocess per image at shape (1, 3, 640, 480)


======================== 1 passed, 4 warnings in 7.04s =========================

进程已结束,退出代码0

至此,注意力模块添加完成。

5 训练

在这里插入图片描述
如上图,这里使用x超大模型,只需yolov8-CBAMseg.yaml中加一个x变成yolov8x-CBAMseg.yaml,优化器为上一篇博客yolov8更改的Lion优化器。可以看到arguments参数按照x模型发生了调整,模型开始训练。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1974516.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ClinicalAgent:结合大模型的临床试验多智能体系统

ClinicalAgent&#xff1a;结合大模型的临床试验多智能体系统 提出背景ClinicalAgent 框架规划智能体功效智能体安全智能体 解法解法 子解法1&#xff08;因为需要处理复杂的数据和多变量&#xff09; 子解法2&#xff08;因为需要及时反馈临床试验中的变化&#xff09; 子解…

海信聚好看的DBDocter软件使用心得

在墨天轮大会看到这个软件,好称是内核级别的诊断工具, 工作空闲下载免费看看 结果要1.7GB还TAR. DBdoctor是一款内核级数据库性能诊断软件。可以对数据库做细粒度的扫描&#xff0c;帮助您一分钟内找到数据库性能问题&#xff0c;实现性能诊断百倍提效。针对数据库性能诊断门…

ICML 2024:从历史数据中挖掘最优策略,高效完成50+任务,“离线策略提升的在线演员-评论家”研究工作

长期以来&#xff0c;如何提升数据利用效率被认为是强化学习落地应用的一大桎梏。过去非策略&#xff08;off-policy&#xff09;的强化学习虽然能反复利用收集到的数据来进行策略优化&#xff0c;然而这些方法未能最大限度地利用重放缓冲区&#xff08;Replay buffer&#xff…

新手小白学习PCB设计,立创EDA专业版

本教程有b站某UP主的视频观后感 视频链接&#xff1a;http://【【教程】零基础入门PCB设计-国一学长带你学立创EDA专业版 全程保姆级教学 中文字幕&#xff08;持续更新中&#xff09;】https://www.bilibili.com/video/BV1At421h7Ui?vd_sourcefedb10d2d09f5750366f83c1e0d4a…

JAVA进阶学习13

文章目录 2.2.3 综合输入和输出方法进行文件拷贝2.2.4 字节流读取时乱码的问题 2.3 字符流的方法概述2.3.1 FileReader方法2.3.2 FileWriter方法2.3.3 小结 三、高级IO流3.1 缓冲流3.1.1 字节缓冲流3.1.2 字符缓冲流 3.2 转换流3.3 序列化流3.3.1 序列化流3.3.2 反序列化流 3.4…

亚马逊自养号测评一直被砍单封号怎么解决

亚马逊是一个大数据公司&#xff0c;可以检测出你的购买行为是否正常&#xff0c;如每次都是直接用链接购买产品而从来不用搜索栏&#xff0c;每次购买产品单一而且时间快速&#xff0c;买家留评比例过高或者评论内容太假&#xff0c;产品还没签收就上评论&#xff0c;某个list…

vxtable行转列

<script setup lang"ts"> import dayjs from "dayjs"; import {Search} from "element-plus/icons-vue"; import {ElMessage} from "element-plus"; class SearchModel{startTime?: Date | stringendTime?: Date | stringcons…

react-native从入门到实战系列教程一ScrollView组件吸顶效果

在ScrollView组件里面把第一元素固定在视图顶部的效果&#xff0c;ScrollView在手机上自带了bounce回弹的效果&#xff0c;不管内容是不是超出组件高度还是宽度 实现效果 代码实现 import {View,Text,StyleSheet,Dimensions,TextInput,Button,Alert,ScrollView,StatusBar,Saf…

[windows10]win10永久禁用系统自动更新操作方法

WinR打开运行 输入regedit打开注册表 点击确定打开注册表 按照如下路径找到UX 计算机\HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\WindowsUpdate\UX\Settings 在空白处点击鼠标右键&#xff0c;新建选择DWORD&#xff0c;然后重命名为FlightSettingsMaxPauseDays 双击FlightSet…

图论:1203. 项目管理(以小组为单位进行拓扑排序)

文章目录 1.问题分析2.思路整理3.官解思路 LeetCode&#xff1a;1203. 项目管理 建议直接看思路整理 1.问题分析 仔细读题可以发现&#xff0c;如果不考虑小组项目彼此相邻&#xff0c;则项目之间的依赖关系就是一个拓扑排序。 但是如果要考虑小组项目彼此相邻&#xff0c;问…

【机器人学】6-3.六自由度机器人运动学参数辨识- 机器人辨识参数耦合性分析

前言 上一章我们用两步优化方法求解了辨识参数&#xff0c; 【机器人学】6-2.六自由度机器人运动学参数辨识-优化方法求解辨识参数 我们给机器人的几何参数进行了数学建模&#xff0c;其中使用高斯牛顿法求解出了激光仪相对于机器人基座的坐标变换和机器人末端执行器相对于靶球…

【RTT-Studio】详细使用教程七:SGM5352外部DAC使用

文章目录 一、简介二、RTT时钟配置三、初始化配置四、完整代码五、测试验证 一、简介 本文主要介绍使用RTT-ThreadStudio来驱动SGM5352芯片的使用&#xff0c;该芯片主要是一个低功率&#xff0c;4通道&#xff0c;16位&#xff0c;电压输出DAC。它从2.7V到5.5V&#xff0c;设…

短视频矩阵系统设计:抖音短视频平台的最佳选择

随着移动互联网的快速发展&#xff0c;短视频行业异军突起&#xff0c;抖音短视频平台凭借其丰富的内容、便捷的创作工具和智能推荐算法&#xff0c;吸引了大量用户。在这个背景下&#xff0c;短视频矩阵系统应运而生&#xff0c;成为抖音短视频平台的最佳选择。本文将详细介绍…

左手坐标系、右手坐标系、坐标轴方向

一、右手坐标系 1、y轴朝上&#xff1a;webgl、Threejs、Unity、Unreal、Maya、3D Builder x&#xff1a;向右y&#xff1a;向上z&#xff1a;向前&#xff08;朝向观察者、指向屏幕外&#xff09; 2、z轴朝上&#xff1a;cesium、blender x&#xff1a;向右y&#xff1a;向前…

C# 方法的重载(Overload)

在C#中&#xff0c;方法的重载&#xff08;Overloading&#xff09;是指在一个类中可以有多个同名的方法&#xff0c;只要这些方法具有不同的方法签名&#xff08;即参数的数量、类型或顺序不同&#xff09;。这使得你可以使用相同的方法名称来执行相似但参数不同的操作&#x…

GEE必须会教程——基于Landsat影像构建NDVI时间序列

很久很久以前&#xff0c;小编写了一篇基于MODIS影像构建归一化植被指数的文章&#xff0c;不知道大家还有没有印象&#xff0c;有一段时间没有更新时间序列分析相关的文章了。 今天&#xff0c;我们来看看基于Lansat影像&#xff0c;我们来看看在GEE上如何构建NDVI的时间序列。…

AI跟踪报道第50期-新加坡内哥谈技术-本周AI新闻: 听听没有Scarlett Johansson的GPT-4o更加震撼

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

shellcode加密免杀

通过加密shellcode方式过安全软件拦截 先说结论&#xff0c;笔者没成功 shellcode&#xff1a; Shellcode 是一段用于在目标系统上执行特定操作的机器码。它通常被用于利用软件漏洞&#xff0c;以获取对目标系统的控制权或执行特定的恶意行为。 Shellcode 可以执行诸如创建进程…

MySQL 预处理、如何在 [Node.js] 中使用 MySQL?

前面文章我们已经总结了mysql下载安装配置启动以及如何用 Navicat 连接&#xff0c;还有MySQL的入门基础知识 、Node.js的基本知识、Express框架基于Node.js基础知识、下面我们总结如何在Node.js中使用MySQL数据库以及MySQL预处理基本知识。 目录 一、MySQL预处理 二、如何在…