爆改YOLOv8|利用可改变核卷积AKConv改进yolov8-轻量涨点

news2024/9/22 21:23:26

1,本文介绍

AKConv(可改变核卷积)是一种改进的卷积操作方法,其核心在于动态调整卷积核的形状和大小。与传统卷积层固定核大小不同,AKConv 通过引入可学习的机制,使卷积核在训练过程中能够自适应地调整,从而更好地适应不同的数据特征和任务需求。

核心特点:

  1. 可变核尺寸:AKConv 允许卷积核在不同的层和位置上具有不同的尺寸,这有助于捕捉更多的局部特征。

  2. 动态调整:卷积核的形状和大小可以在训练过程中进行调整,使得模型能够根据输入数据的特性自动优化卷积操作。

  3. 提高表达能力:通过自适应地调整核的参数,AKConv 可以提高网络的表达能力和性能,特别是在处理复杂或变化多端的输入数据时。

应用场景:

  • 计算机视觉:在图像分类、目标检测等任务中,AKConv 能够有效提升模型对各种尺度和形状特征的敏感度。
  • 特征提取:适用于需要捕捉多种尺度特征的应用,例如医学影像分析和高分辨率图像处理。

AKConv 提供了一种灵活且强大的卷积操作方式,能够在多个任务中提高模型的适应性和性能。

关于AKConv的详细介绍可以看论文:https://arxiv.org/pdf/2311.11587.pdf

本文将讲解如何将AKConv融合进yolov8

话不多说,上代码!

2, 将AKConv融合进yolov8

2.1 步骤一

找到如下的目录'ultralytics/nn/modules',然后在这个目录下创建一个AKConv.py文件,文件名字可以根据你自己的习惯起,然后将AKConv的核心代码复制进去

import torch.nn as nn
import torch
from einops import rearrange
import math
 
 
class AKConv(nn.Module):
    def __init__(self, inc, outc, num_param, stride=1, bias=None):
        super(AKConv, self).__init__()
        self.num_param = num_param
        self.stride = stride
        self.conv = nn.Sequential(nn.Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias),
                                  nn.BatchNorm2d(outc),
                                  nn.SiLU())  # the conv adds the BN and SiLU to compare original Conv in YOLOv5.
        self.p_conv = nn.Conv2d(inc, 2 * num_param, kernel_size=3, padding=1, stride=stride)
        nn.init.constant_(self.p_conv.weight, 0)
        self.p_conv.register_full_backward_hook(self._set_lr)
 
    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
 
    def forward(self, x):
        # N is num_param.
        offset = self.p_conv(x)
        dtype = offset.data.type()
        N = offset.size(1) // 2
        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)
 
        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1
 
        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
                         dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
 
        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)
 
        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
 
        # resampling the features based on the modified coordinates.
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)
 
        # bilinear
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt
 
        x_offset = self._reshape_x_offset(x_offset, self.num_param)
        out = self.conv(x_offset)
 
        return out
 
    # generating the inital sampled shapes for the AKConv with different sizes.
    def _get_p_n(self, N, dtype):
        base_int = round(math.sqrt(self.num_param))
        row_number = self.num_param // base_int
        mod_number = self.num_param % base_int
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(0, row_number),
            torch.arange(0, base_int), indexing='xy')
        p_n_x = torch.flatten(p_n_x)
        p_n_y = torch.flatten(p_n_y)
        if mod_number > 0:
            mod_p_n_x, mod_p_n_y = torch.meshgrid(
                torch.arange(row_number, row_number + 1),
                torch.arange(0, mod_number),indexing='xy')
 
            mod_p_n_x = torch.flatten(mod_p_n_x)
            mod_p_n_y = torch.flatten(mod_p_n_y)
            p_n_x, p_n_y = torch.cat((p_n_x, mod_p_n_x)), torch.cat((p_n_y, mod_p_n_y))
        p_n = torch.cat([p_n_x, p_n_y], 0)
        p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
        return p_n
 
    # no zero-padding
    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(0, h * self.stride, self.stride),
            torch.arange(0, w * self.stride, self.stride),indexing='xy')
 
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
 
        return p_0
 
    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
 
        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p
 
    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)
 
        # (b, h, w, N)
        index = q[..., :N] * padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
 
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
 
        # 根据实际情况调整
        index = index.clamp(min=0, max=x.shape[-1] - 1)
 
        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
 
        return x_offset
 
    #  Stacking resampled features in the row direction.
    @staticmethod
    def _reshape_x_offset(x_offset, num_param):
        b, c, h, w, n = x_offset.size()
        # using Conv3d
        # x_offset = x_offset.permute(0,1,4,2,3), then Conv3d(c,c_out, kernel_size =(num_param,1,1),stride=(num_param,1,1),bias= False)
        # using 1 × 1 Conv
        # x_offset = x_offset.permute(0,1,4,2,3), then, x_offset.view(b,c×num_param,h,w)  finally, Conv2d(c×num_param,c_out, kernel_size =1,stride=1,bias= False)
        # using the column conv as follow, then, Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias)
 
        x_offset = rearrange(x_offset, 'b c h w n -> b c (h n) w')
        return x_offset

2.2 步骤二

在task.py导入我们的模块

2.3 步骤三

在task.py的parse_model方法里面注册我们的模块

到此注册成功,复制后面的yaml文件直接运行即可

yaml文件


# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect
 
# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs
 
# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2
  - [-1, 1, AKConv, [128, 3, 2]]  # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, AKConv, [256, 3, 2]]  # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, AKConv, [512, 3, 2]]  # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, AKConv, [1024, 3, 2]]  # 7-P5/32
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]  # 9
 
 
# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 6], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 12
 
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 4], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 15 (P3/8-small)
 
  - [-1, 1, AKConv, [256, 3, 2]]
  - [[-1, 12], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 18 (P4/16-medium)
 
  - [-1, 1, AKConv, [512, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 21 (P5/32-large)
 
  - [[15, 18, 21], 1, Detect, [nc]]  # Detect(P3, P4, P5)

不知不觉已经看完了哦,动动小手留个点赞收藏吧--_--

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2096065.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

优雅谈大模型:白话ZeRO 下

机器学习模型的复杂性和规模不断增长,分布式训练变得比以往任何时候都更加重要。训练具有数千亿参数的大型语言模型( LLMs )将是机器学习基础设施面临的挑战。与传统的分布式计算框架不同的地方在于GPU的分布式训练需要将数据传递给GPU芯片等…

JAVAEE初阶第二节——多线程基础(下)

系列文章目录 JAVAEE初阶第二节——多线程基础(下) 多线程基础(下) 单例模式阻塞式队列定时器线程池 文章目录 系列文章目录JAVAEE初阶第二节——多线程基础(下) 多线程基础(下) 一.多线程案例 1.单例模式1.1 饿汉模式 1.2 懒汉模式 1.2.1 懒汉模式-单线程版1.2.3 懒汉模式…

[Tools: LoRA] Diffusers中Stable Diffusion的实现

实现底层原理 Diffusers中的Attention操作实现在AttnProcessor类(diffusers.models.attention_processor.py),里面定义了单次Attention操作。添加LoRA,本质上是用LoRAAttnProcessor类替换AttnProcessor类。LoRAAttnProcessor中新…

github和gitlab的区别是什么

区别:github如果使用私有仓库,是需要付费的;而gitlab可以在上面搭建私人的免费仓库。gitlab让开发团队对他们的代码仓库拥有更多的控制,相对于github,它有不少的特色:允许免费设置仓库权限;可以…

自然语言处理-词向量转换

文章目录 一、简介1.含义2.基本原理3.常见转换方法1). 独热编码(One-Hot Encoding)2). 词袋模型(Bag of Words, BoW)3). TF-IDF(Term Frequency-Inverse Document Frequency&#xf…

网络工程师学习笔记——局域网和城域网

传统局域网(LAN) 局域网的主要特征:由网络拓扑结构所采用的协议类型以及介质访问的控制方法 分组广播式网络,所有的工作站都连接到共享的传输介质上,共享信道的分配技术是局域网的核心技术 局域网常见的设备&#x…

Centos Stream9网卡驱动重置无法找到网卡解决办法

1.问题原因 使用Centos Stream9系统时,我们正常在/etc/NetworkManager/system-connections目录下修改网络配置文件保存后,重置网卡会发现提示无法连接或没有找到该网卡,此问题有以下几点原因: linux系统重管理网络连接的有netwo…

巧妙的数(逐倍数判断)

cin>>s; 若s串=1236 lens=s.size(),pd=1,ys=0,p[10]={} 0<=i< l 开始运算: P[1]=p[2]=p[3]=p[6]=true; //下标做标记 若 p[6]=ture,则p[2]=p[3]=ture,p[6]=false pd=1 9>=k>1 若pd%k!=0&&p[k]=ture时,则pd*=k;

开学季老师如何发布分班?

开学啦&#xff0c;老师们又要开始忙碌了。但是&#xff0c;别担心&#xff0c;现在有个超方便的工具&#xff0c;让分班这件事变得简单又快速。以前分班可是个大工程&#xff0c;得一个个手动处理&#xff0c;现在不一样了&#xff0c;有了易查分这个小程序&#xff0c;一切都…

不可思议!分享6款AI论文大纲提纲自动生成器,导师直夸好

在当今学术研究和写作领域&#xff0c;人工智能&#xff08;AI&#xff09;技术的迅速发展为论文写作带来了革命性的变化。AI论文大纲生成器作为其中的重要工具&#xff0c;能够显著提高论文撰写效率和质量。本文将介绍六款AI论文大纲生成器&#xff0c;这些工具不仅能够帮助学…

如何使用Docker部署MySQL

一、查询镜像 使用如下命令“docker search mysql”即可查看docker仓库中所有的mysql的镜像。 使用了 docker search mysql 命令来搜索 MySQL 相关的 Docker 镜像。结果中列出了许多与 MySQL 相关的镜像&#xff0c;每个镜像都有名称、描述、星级评分&#xff08;表示受欢迎程…

yolo8 目标检测、鉴黄

省流 看前必读 别浪费时间 &#xff1a;本文只是一个记录&#xff0c;防止自己下次被改需求时浪费时间&#xff0c;在这里就随意的写了一下文章记录整个步骤&#xff0c;但是文章想必肯定没有对应的教程讲的详细&#xff0c;该文章只适合想要快速按照步骤完成一个简单的 demo 的…

存储系统总结

内存物理组成 SAM&#xff1a;顺序存取存储器&#xff0c;按照某种顺序存取&#xff0c;存取时间和在存储体上的物理位置有关系 DAM&#xff1a;直接存取存储器&#xff0c;先寻找一块小区域&#xff0c;接着顺序查找 RAM&#xff1a;随机存取存储器&#xff0c;存取时间与物理…

第8讲 ,ISP 串口程序下载

1 硬件的连接 需要使用 串口下载软件。 flymcu 这是 正点原子的 自启动电路。 2 stm32 的串口下载的原理 stm32 下载 只能是 串口一 &#xff0c; 也就是 PA9&#xff0c; PA10 3 然后是 stm32 的启动顺序 这里使用的是 第二种的 启动模式&#xff0c; 也就是 通过 串口进行烧…

Java 入门指南:Java 并发编程 —— Condition 灵活管理线程间的同步

Condition Condition 是 Java 并发编程中的一种高级同步工具&#xff0c;它可以协助线程之间进行等待和通信。提供了一种比传统的 wait() 和 notify() 更加灵活的方式来管理线程间的同步。Condition 接口通常与 Lock 接口一起使用&#xff0c;允许更细粒度的控制线程的等待和唤…

idea插件开发的第一天-写一个小Demo

介绍 Demo说明 本文基于maven项目开发,idea版本为2022.3以上,jdk为1.8本文在Tools插件之上进行开发 Tools插件说明 Tools插件是一个Idea插件,此插件提供统一Spi规范,极大的降低了idea插件的开发难度,并提供开发者模块,可以极大的为开发者开发此插件提供便利Tools插件安装需…

Python爬虫案例五:将获取到的文本生成词云图

基础知识&#xff1a; # 词云图 wordcloud # 1、导包 jieba wordcloud import jieba from wordcloud import WordCloud data 全年经济社会发展主要目标任务圆满完成 data_list list(jieba.cut(data)) # print(data_list) # generator数据类型# 2、构造词云图样式 》虚拟的…

LabVIEW与Python联合图像处理

LabVIEW可以将图片作为参数传递给Python进行处理。可以通过LabVIEW调用Python脚本&#xff0c;并传递图片数据。以下是如何实现这个功能的基本思路&#xff1a; 1. 在LabVIEW中读取图像 首先&#xff0c;使用LabVIEW中的图像处理函数&#xff08;如NI Vision Development Modu…

多态【C++】

文章目录 概念概念虚函数 定义及实现构成条件虚函数的重写override和final重载/重定义&#xff08;隐藏&#xff09;/重写&#xff08;覆盖&#xff09;的区别 抽象类概念接口继承和实现继承 多态的原理虚函数表 多继承关系的虚函数表 概念 概念 通俗来说&#xff0c;就是多种…

用TCC来解决多个第三方系统数据一致性问题

对于做集成的公司来说&#xff0c;会集成各种第三方系统&#xff0c;要么是通过第三方系统的api&#xff0c;要么直接集成第三方系统的设备。如果是通过api集成&#xff0c;单次请求只调用一个三方系统没问题&#xff0c;同步调用就行&#xff0c;但如果同时要调用多个三方系统…