YOLOv8改进教程|加入可改变核卷积AKConv模块,效果远超DSConv!

news2024/9/20 22:47:54


⭐⭐ YOLOv8改进专栏|包含主干、模块、注意力机制、检测头等前沿创新 ​ ⭐⭐


 一、 论文介绍

        论文链接:https://arxiv.org/abs/2311.11587

        代码链接:GitHub - CV-ZhangXin/AKConv

论文速览::AKConv是2023年11月发表的一种可变卷积核,赋予卷积核任意数量的参数和任意采样形状,以解决具有固定样本形状和正方形的卷积核不能很好地适应不断变化的目标的问题点可以为网络开销和性能之间的权衡提供更丰富的选择。 AKConv的核心思想在于它为卷积核提供了任意数量的参数和任意采样形状,能够使用任意数量的参数(如1,2,3,4,5,6,7等)来提取姝征,这在标准卷积和可变形卷积中并未实现。AKConv能够根据硬件环境,使卷积参数的数星呈线性增减((非常适用于轻量化模型)。

总结:AKConv是一种具有任意数量的参数和任意采样形状的可变卷积核,对不规则特征有更好的提取效果。


二、 加入到RT-DETR中

2.1 复制代码

        复制代码粘到ultralytics->nn->modules->conv.py文件中,在顶部导入torch.nn.functional包,(torch.nn.functional as F),将代码粘贴于下方,并在__all__中声明,如下图所示:

# Ultralytics YOLO 🚀, AGPL-3.0 license
"""Convolution modules."""

import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


__all__ = (
    "Conv",
    "Conv2",
    "LightConv",
    "DWConv",
    "DWConvTranspose2d",
    "ConvTranspose",
    "Focus",
    "GhostConv",
    "ChannelAttention",
    "SpatialAttention",
    "CBAM",
    "Concat",
    "RepConv",
    "AKConv",
)


class AKConv(nn.Module):
    def __init__(self, inc, outc, num_param, stride=1, bias=None):
        super(AKConv, self).__init__()
        self.num_param = num_param
        self.stride = stride
        self.conv = nn.Sequential(nn.Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias),
                                  nn.BatchNorm2d(outc),
                                  nn.SiLU())  # the conv adds the BN and SiLU to compare original Conv in YOLOv5.
        self.p_conv = nn.Conv2d(inc, 2 * num_param, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_full_backward_hook(self._set_lr)

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    def forward(self, x):
        # N is num_param.
        offset = self.p_conv(x)
        dtype = offset.data.type()
        N = offset.size(1) // 2
        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # resampling the features based on the modified coordinates.
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # bilinear
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        x_offset = self._reshape_x_offset(x_offset, self.num_param)
        out = self.conv(x_offset)

        return out

    # generating the inital sampled shapes for the AKConv with different sizes.
    def _get_p_n(self, N, dtype):
        base_int = round(math.sqrt(self.num_param))
        row_number = self.num_param // base_int
        mod_number = self.num_param % base_int
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(0, row_number),
            torch.arange(0, base_int))
        p_n_x = torch.flatten(p_n_x)
        p_n_y = torch.flatten(p_n_y)
        if mod_number > 0:
            mod_p_n_x, mod_p_n_y = torch.meshgrid(
                torch.arange(row_number, row_number + 1),
                torch.arange(0, mod_number))

            mod_p_n_x = torch.flatten(mod_p_n_x)
            mod_p_n_y = torch.flatten(mod_p_n_y)
            p_n_x, p_n_y = torch.cat((p_n_x, mod_p_n_x)), torch.cat((p_n_y, mod_p_n_y))
        p_n = torch.cat([p_n_x, p_n_y], 0)
        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
        return p_n

    # no zero-padding
    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(0, h * self.stride, self.stride),
            torch.arange(0, w * self.stride, self.stride))

        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N] * padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    #  Stacking resampled features in the row direction.
    @staticmethod
    def _reshape_x_offset(x_offset, num_param):
        b, c, h, w, n = x_offset.size()
        # using Conv3d
        # x_offset = x_offset.permute(0,1,4,2,3), then Conv3d(c,c_out, kernel_size =(num_param,1,1),stride=(num_param,1,1),bias= False)
        # using 1 × 1 Conv
        # x_offset = x_offset.permute(0,1,4,2,3), then, x_offset.view(b,c×num_param,h,w)  finally, Conv2d(c×num_param,c_out, kernel_size =1,stride=1,bias= False)
        # using the column conv as follow, then, Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias)

        x_offset = rearrange(x_offset, 'b c h w n -> b c (h n) w')
        return x_offset

2.2 更改modules.__init__.py文件 

       打开ultralytics->nn->modules->__init__.py,在第64行与81行加入AKConv进行声明。

​2.3 更改task.py文件 

        打开ultralytics->nn路径下的tasks.py文件,首先在第51行加入AKConv导入模块,然后在第928行(或其他合适的位置)加入下方代码:

        elif m is AKConv:
            c2 = args[0]
            c1 = ch[f]
            args = [c1, c2, *args[1:]]

 2.4 更改yaml文件 

        创建yaml文件,使用AKConv替换yaml文件中原有的Conv模块。

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024] # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024] # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768] # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512] # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512] # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]] # 7-P5/32
  - [-1, 1, AKConv, [256, 3]]
  - [-1, 1, SPPF, [1024, 5]] # 9

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 12

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 15 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 18 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2f, [1024]] # 21 (P5/32-large)

  - [[15, 18, 21], 1, Detect, [nc]] # Detect(P3, P4, P5)

 2.5 修改train.py文件

        在train.py脚本中填入创建好的yaml路径,运行即可训练。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1679254.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Linux——Centos7安装RabbitMQ】 RabbitMQ无法连接

到这一步是基本已经装好了,现在是在开放端口,我这个报错是因为我的防火墙是处于关闭状态,所以在开放端口时会报防火墙为运行,把防火墙打开,在开放端口,就可以访问到了 重启防火墙: systemctl …

白酒:酒精度数对白酒风味的影响与品鉴技巧

云仓酒庄豪迈白酒作为品质的白酒品牌,其酒精度数对白酒风味的影响与品鉴技巧是品鉴爱好者关注的重点。酒精度数作为衡量白酒质量的一项重要指标,不仅决定了白酒的口感和风格,更在一定程度上体现了白酒的品质和价值。本文将探讨酒精度数对云仓…

正点原子[第二期]Linux之ARM(MX6U)裸机篇学习笔记-15.6讲 GPIO中断实验-GPIO驱动添加中断处理函数

前言: 本文是根据哔哩哔哩网站上“正点原子[第二期]Linux之ARM(MX6U)裸机篇”视频的学习笔记,在这里会记录下正点原子 I.MX6ULL 开发板的配套视频教程所作的实验和学习笔记内容。本文大量引用了正点原子教学视频和链接中的内容。…

2024年第十届中西部外语翻译大赛

2024年第十届中西部外语翻译大赛 竞赛信息 “由中西部翻译协会共同体指导发起,各省市译协共建学术指导委员会,2024年第十届中西部外语翻译大赛由中西部翻译协会共同体秘书处(武汉公仪网络科技有限公司)承办。” - 获奖证书样图 -…

MQTT_服务器的安装_1.3

此例子是以Windows系统安装开源版本的EMQX 下载 EMQX 下载并解压 解压如图 进入bin 文件夹在文件目录中输入cmd回车 启动服务器 然后在cmd中输入下面的代码(会弹出一个访问网络的选项,确认可以访问网络) emqx start 结果如图(…

半小时搞懂STM32知识点——UART

1.UART 1.1为什么要使用UART这种协议?介绍一下UART及其特点 成本低,硬件简单,数据格式灵活; 低速全双工异步串行通信 1.2 UART数据帧格式? 起始位(1)+数据位(5-8) 校验位…

百面算法工程师 | YOLOv6面试考点原理全解析

本文给大家带来的百面算法工程师是深度学习目标检测YOLOv6面试总结,文章内总结了常见的提问问题,旨在为广大学子模拟出更贴合实际的面试问答场景。在这篇文章中,我们还将介绍一些常见的深度学习目标检测面试问题,并提供参考的回答…

项目管理—需求管理规程(软件研发过程标准,管理标准,标书技术编写,资质评审,安全管理体系,项目交付,实施运维,各类建设方案)

软件资料清单列表部分文档清单:工作安排任务书,可行性分析报告,立项申请审批表,产品需求规格说明书,需求调研计划,用户需求调查单,用户需求说明书,概要设计说明书,技术解…

GPT-4o 引领人机交互新风向,向量数据库赛道沸腾了

OpenAI 发布 ChatGPT-4o,意味着人机交互进入新的时代。Chat-GPT4o 是一个跨文本、视觉和音频端到端训练的新模型,所有输入和输出都由同一个神经网络处理。这也在告诉所有人,GenAI 连接非结构化数据,非结构化数据之间跨模态的交互正…

Geoserver

Geoserver GIS工具 文章目录 Geoserver前言一、Geoserver是什么?二、概念1.Geoserver结构图2.相关概念3.Geoserver相关站点4.Geoserver安装5.PostgreSQL安装1.拉取镜像2.创建挂载卷3.安装 6 Docker 环境安装postgrespostgis扩展 总结其他参考资料 前言 GeoServer&…

从开发板导出根文件系统并修改(Ubuntu)

前面提到过基于ubuntu-base去构建根文件系统基于Ubuntu-base构建根文件系统-CSDN博客,但是有时候我们并不需要重头开始,可以基于现有的根文件系统做调整。又或者我们直接在出厂的系统上去搭建好自己的运行环境并且编译出自己想要的程序,现在要…

Web浏览器的兼容性测试需要考虑哪些测试点?

测试web网站兼容性时,可以使用各种测试用例来确保网站在不同浏览器中的良好兼容性。以下是一些常见的兼容性测试用例示例: 1. 页面加载测试: - 确保网站在不同浏览器中正常加载,没有加载错误。 - 检查页面加载时间,…

Kivy UI界面

一、版本介绍 Ubuntu:18.04.6 LTS Conda:4.5.12 Python:3.6.15 Kivy:2.0.0 二、安装Kivy # 更新系统包列表 sudo apt-get update# 安装Kivy的依赖项 sudo apt-get install -y python-pip libsdl2-dev libsdl2-image-dev li…

【机器学习】:基于决策树与随机森林对数据分类

机器学习实验报告:决策树与随机森林数据分类 实验背景与目的 在机器学习领域,决策树和随机森林是两种常用的分类算法。决策树以其直观的树形结构和易于理解的特点被广泛应用于分类问题。随机森林则是一种集成学习算法,通过构建多个决策树并…

Galxe已投资Pencils Protocol,投资者阵营正不断扩大

近日,Scroll 生态项目 Penpad 将品牌进一步升级为 Pencils Protocol,全新升级后其不仅对 LaunchPad 平台进行了功能上的升级,同时其也进一步引入了 Staking、Vault 以及 Shop 等玩法,这也让 Pencils Protocol 的叙事方向不再仅限于…

“图生视频”技术创新:剪贴画秒变动画生成的实验验证与分析

在最近的研究进展中,AniClipart系统的问世标志着文本到视频生成技术的一个重要里程碑。这一系统由香港城市大学和莫纳什大学的研究者们共同开发,旨在解决将静态剪贴画图像根据文本提示自动转换成动画序列的挑战。传统的动画制作流程繁琐且耗时&#xff0…

Python 小抄

Python 备忘单 目录 1.语法和空格 2.注释 3.数字和运算 4.字符串处理 5.列表、元组和字典 6.JSON 7.循环 8.文件处理 9.函数 10.处理日期时间 11.NumPy 12.Pandas 要运行单元格,请按 ShiftEnter 或单击页面顶部的 Run(运行)。 1.语法和空格…

「每日跟读」英语常用句型公式 第15篇

「每日跟读」英语常用句型公式 第15篇 1. It’s only logical that __ 合理的做法/结论是__ It’s only logical that we should take a break (合理的做法是我们应该休息一下) It’s only logical that we work hard to make money(合理…

如何设计知识竞赛活动中的观众互动环节

知识竞赛活动过程中有多种方式进行观众互动,达到台上台下互动的效果,让台下观众参与到竞赛活动中,增加现场气氛。下面介绍几种常用观众互动环节设计方法。 一、台上选手对抗台下观众 此方案为台下观众和台上选手一起答题,如果台…

又双叒叕新增2本SCI期刊“On Hold“,慎投,有剔除风险!

本周投稿推荐 SSCI • 2区社科经管类,3.0-4.0(录用友好) EI • 计算机工程生物医学等(领域广,录用极快) CNKI • 3天内初审录用,随即出版(急录友好) SCI&EI …