YOLOv8改进 | 主干网络 | ⭐重写星辰Rewrite the Stars⭐【CVPR2024】

news2024/11/18 10:38:21

 秋招面试专栏推荐深度学习算法工程师面试问题总结【百面算法工程师】——点击即可跳转


💡💡💡本专栏所有程序均经过测试,可成功执行💡💡💡


专栏目录 :《YOLOv8改进有效涨点》专栏介绍 & 专栏目录 | 目前已有70+篇内容,内含各种Head检测头、损失函数Loss、Backbone、Neck、NMS等创新点改进——点击即可跳转


最近的研究关注到了在网络设计中尚未充分挖掘的“星运算”(逐元素乘法)的潜力。尽管直观的解释有很多,但其应用的基础原理仍然在很大程度上未被探索。我们的研究试图揭示星运算的能力,它能够在不扩大网络的情况下,将输入映射到高维、非线性的特征空间中,类似于核技巧。我们进一步引入了StarNet,这是一个简单但强大的原型,它在紧凑的网络结构和高效的预算下展现了令人印象深刻的性能和低延迟就像天空中的星星一样,星运算看似不起眼,却蕴含着广阔的潜力宇宙。文章在介绍主要的原理后,将手把手教学如何进行模块的代码添加和修改,并将修改后的完整代码放在文章的最后,方便大家一键运行,小白也可轻松上手实践。以帮助您更好地学习深度学习目标检测YOLO系列的挑战。

专栏地址YOLOv8改进——更新各种有效涨点方法——点击即可跳转

目录

1.原理 

2. 将starNet添加到yolov8网络中

2.1 starNet代码实现

2.2 更改init.py文件

2.3 添加yaml文件

2.4 注册模块

2.5 替换函数

2.6 执行程序

3. 完整代码分享

4. GFLOPs

5. 进阶

6. 总结


1.原理 

论文地址:Rewrite the Stars⭐ ——点击即可跳转

官方代码:官方代码仓库——点击即可跳转

StarNet的主要原理可以总结如下:

StarNet的核心思想

  1. 星操作(Star Operation):

    • 星操作是指元素级乘法(element-wise multiplication),用于融合不同子空间的特征。

    • 这种操作能够将输入特征映射到高维非线性特征空间,类似于核技巧(kernel tricks),而不需要增加网络的宽度(通道数)。

  2. 高维非线性特征映射:

    • 星操作通过元素级乘法生成一个新的特征空间,这个特征空间具有大约((d\sqrt{2})^2)个线性独立的维度。

    • 这种操作与传统神经网络通过增加网络宽度来获得高维特征的方式不同,更像是多项式核函数(polynomial kernel functions)的操作。

  3. 高效的紧凑网络结构:

    • 通过堆叠多个星操作层,每一层都显著增加隐含的维度复杂度。

    • 即使在紧凑的特征空间中操作,星操作仍然能够利用隐含的高维特征。

  4. 性能和效率:

    • 星操作在性能和效率上表现出色,尤其是在网络宽度较小的情况下。

    • StarNet使用星操作设计了一个简单但高效的网络结构,证明了其在紧凑网络中的有效性。

StarNet的实现

  1. 网络结构:

    • StarNet通过堆叠多层星操作层构建

         2. 星操作层的基本形式为(W_1^TX + B_1) \ast (W_2^TX + B_2),即通过元素级乘法融合两个线性变换后的特征。

  1. 理论分析和实验验证:

    • 通过理论分析证明了星操作能够在一个层内将输入特征映射到一个高维非线性特征空间,并且在多层堆叠下能够递归地显著增加隐含的特征维度。

    • 实验结果表明,StarNet在ImageNet-1K验证集上的表现优于许多精心设计的高效模型,并且在实际应用中具有较低的延迟和较高的运行效率。

  2. 比较与优势:

    • 与现有的高效网络设计相比,StarNet没有复杂的设计和超参数调优,仅依赖于星操作的高效性。

    • StarNet的设计理念与传统方法(如卷积、线性层和非线性激活的结合)有明显的不同,强调利用隐含的高维特征来提升网络效率。

综上所述,StarNet通过星操作实现了在紧凑网络中的高效性和高性能,展示了元素级乘法在特征融合中的巨大潜力和应用前景。

2. 将starNet添加到yolov8网络中

2.1 starNet代码实现

关键步骤一在/ultralytics/ultralytics/nn/modules下新建backbone/starNet.py,并将下面代码粘贴到文件内

"""
Implementation of Prof-of-Concept Network: StarNet.

We make StarNet as simple as possible [to show the key contribution of element-wise multiplication]:
    - like NO layer-scale in network design,
    - and NO EMA during training,
    - which would improve the performance further.

Created by: Xu Ma (Email: ma.xu1@northeastern.edu)
Modified Date: Mar/29/2024
"""

import torch
import torch.nn as nn
from timm.models.layers import DropPath, trunc_normal_

__all__ = [
    "starnet_s050",
    "starnet_s100",
    "starnet_s150",
    "starnet_s1",
    "starnet_s2",
    "starnet_s3",
    "starnet_s4",
]

model_urls = {
    "starnet_s1": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s1.pth.tar",
    "starnet_s2": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s2.pth.tar",
    "starnet_s3": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s3.pth.tar",
    "starnet_s4": "https://github.com/ma-xu/Rewrite-the-Stars/releases/download/checkpoints_v1/starnet_s4.pth.tar",
}


class ConvBN(torch.nn.Sequential):
    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size=1,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        with_bn=True,
    ):
        super().__init__()
        self.add_module(
            "conv",
            torch.nn.Conv2d(
                in_planes, out_planes, kernel_size, stride, padding, dilation, groups
            ),
        )
        if with_bn:
            self.add_module("bn", torch.nn.BatchNorm2d(out_planes))
            torch.nn.init.constant_(self.bn.weight, 1)
            torch.nn.init.constant_(self.bn.bias, 0)


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=3, drop_path=0.0):
        super().__init__()
        self.dwconv = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=True)
        self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x1, x2 = self.f1(x), self.f2(x)
        x = self.act(x1) * x2
        x = self.dwconv2(self.g(x))
        x = input + self.drop_path(x)
        return x


class StarNet(nn.Module):
    def __init__(
        self,
        base_dim=32,
        depths=[3, 3, 12, 5],
        mlp_ratio=4,
        drop_path_rate=0.0,
        num_classes=1000,
        **kwargs
    ):
        super().__init__()
        self.num_classes = num_classes
        self.in_channel = 32
        # stem layer
        self.stem = nn.Sequential(
            ConvBN(3, self.in_channel, kernel_size=3, stride=2, padding=1), nn.ReLU6()
        )
        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth
        # build stages
        self.stages = nn.ModuleList()
        cur = 0
        for i_layer in range(len(depths)):
            embed_dim = base_dim * 2**i_layer
            down_sampler = ConvBN(self.in_channel, embed_dim, 3, 2, 1)
            self.in_channel = embed_dim
            blocks = [
                Block(self.in_channel, mlp_ratio, dpr[cur + i])
                for i in range(depths[i_layer])
            ]
            cur += depths[i_layer]
            self.stages.append(nn.Sequential(down_sampler, *blocks))

        self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear or nn.Conv2d):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm or nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        features = []
        x = self.stem(x)
        features.append(x)
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        return features


def starnet_s1(pretrained=False, **kwargs):
    model = StarNet(24, [2, 2, 8, 3], **kwargs)
    if pretrained:
        url = model_urls["starnet_s1"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model


def starnet_s2(pretrained=False, **kwargs):
    model = StarNet(32, [1, 2, 6, 2], **kwargs)
    if pretrained:
        url = model_urls["starnet_s2"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model


def starnet_s3(pretrained=False, **kwargs):
    model = StarNet(32, [2, 2, 8, 4], **kwargs)
    if pretrained:
        url = model_urls["starnet_s3"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model


def starnet_s4(pretrained=False, **kwargs):
    model = StarNet(32, [3, 3, 12, 5], **kwargs)
    if pretrained:
        url = model_urls["starnet_s4"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint["state_dict"], strict=False)
    return model


# very small networks #


def starnet_s050(pretrained=False, **kwargs):
    return StarNet(16, [1, 1, 3, 1], 3, **kwargs)


def starnet_s100(pretrained=False, **kwargs):
    return StarNet(20, [1, 2, 4, 1], 4, **kwargs)


def starnet_s150(pretrained=False, **kwargs):
    return StarNet(24, [1, 2, 4, 2], 3, **kwargs)

StarNet 处理图像的主要工作流程可以分为几个关键步骤,重点关注独特的星型运算及其在分层网络结构中的集成。以下是主要流程的概述:

1. 分层架构

  • StarNet 采用 4 阶段分层架构。每个阶段使用卷积层来降低分辨率,同时将通道数加倍。这种分层方法在现代神经网络设计中很常见,有助于逐步抽象多个级别的图像特征。

2. 星型运算

  • StarNet 的核心创新是星型运算,它在低维空间中计算但产生高维特征。此操作包括使用两组权重转换输入、执行元素乘法以及有选择地应用激活函数。

  • 具体来说,星型操作通常实现为act(W_1^T X) \cdot (W_2^T X),其中W_1W_2是权重矩阵,act  是应用于转换分支之一的激活函数。

3. 块设计和激活放置

  • StarNet 中的每个块都使用深度卷积、批量归一化和简化的激活函数 (ReLU6) 来提高效率。这些块设计简约,避免使用复杂的结构,以强调星型操作的有效性。

  • 激活的位置经过战略性放置,以平衡计算效率和模型性能。文档中的研究表明,与其他位置相比,仅激活一个分支可获得更好的准确性。

4. 通道扩展和网络宽度

  • 在整个网络中,通道扩展始终设置为 4 倍,网络宽度在每个阶段加倍。这种策略有助于管理计算负载,同时提高网络学习复杂特征的能力。

5. 训练和实施细节

  • StarNet 采用类似于 DeiT(数据高效图像变换器)的标准训练方案,以确保公平比较和高效训练。

  • 在推理过程中,批量规范化层融合以简化网络并减少延迟。

通过利用这些技术,StarNet 设法在简单性、计算效率和高性能之间取得平衡,使其成为图像处理任务的引人注目的方法。

2.2 更改init.py文件

关键步骤二:修改modules文件夹下的__init__.py文件,先导入函数  

然后在下面的__all__中声明函数

2.3 添加yaml文件

关键步骤三:在/ultralytics/ultralytics/cfg/models/v8下面新建文件yolov8_starNet.yaml文件,粘贴下面的内容

  • OD【目标检测】
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# 0-P1/2
# 1-P2/4
# 2-P3/8
# 3-P4/16
# 4-P5/32

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, starnet_s050, []]  # 4
  - [-1, 1, SPPF, [1024, 5]]  # 5

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 6
  - [[-1, 3], 1, Concat, [1]]  # 7 cat backbone P4
  - [-1, 3, C2f, [512]]  # 8

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 9
  - [[-1, 2], 1, Concat, [1]]  # 10 cat backbone P3
  - [-1, 3, C2f, [256]]  # 11 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]] # 12
  - [[-1, 8], 1, Concat, [1]]  # 13 cat head P4
  - [-1, 3, C2f, [512]]  # 14 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]] # 15
  - [[-1, 5], 1, Concat, [1]]  # 16 cat head P5
  - [-1, 3, C2f, [1024]]  # 17 (P5/32-large)

  - [[11, 14, 17], 1, Detect, [nc]]  # Detect(P3, P4, P5)
  • Seg【语义分割】
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# 0-P1/2
# 1-P2/4
# 2-P3/8
# 3-P4/16
# 4-P5/32

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, starnet_s050, []]  # 4
  - [-1, 1, SPPF, [1024, 5]]  # 5

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 6
  - [[-1, 3], 1, Concat, [1]]  # 7 cat backbone P4
  - [-1, 3, C2f, [512]]  # 8

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 9
  - [[-1, 2], 1, Concat, [1]]  # 10 cat backbone P3
  - [-1, 3, C2f, [256]]  # 11 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]] # 12
  - [[-1, 8], 1, Concat, [1]]  # 13 cat head P4
  - [-1, 3, C2f, [512]]  # 14 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]] # 15
  - [[-1, 5], 1, Concat, [1]]  # 16 cat head P5
  - [-1, 3, C2f, [1024]]  # 17 (P5/32-large)

  - [[11, 14, 17], 1,  Segment, [nc, 32, 256]]  # Detect(P3, P4, P5)
  • OBB【旋转检测】
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# 0-P1/2
# 1-P2/4
# 2-P3/8
# 3-P4/16
# 4-P5/32

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, starnet_s050, []]  # 4
  - [-1, 1, SPPF, [1024, 5]]  # 5

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 6
  - [[-1, 3], 1, Concat, [1]]  # 7 cat backbone P4
  - [-1, 3, C2f, [512]]  # 8

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 9
  - [[-1, 2], 1, Concat, [1]]  # 10 cat backbone P3
  - [-1, 3, C2f, [256]]  # 11 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]] # 12
  - [[-1, 8], 1, Concat, [1]]  # 13 cat head P4
  - [-1, 3, C2f, [512]]  # 14 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]] # 15
  - [[-1, 5], 1, Concat, [1]]  # 16 cat head P5
  - [-1, 3, C2f, [1024]]  # 17 (P5/32-large)

  - [[11, 14, 17], 1, OBB, [nc, 1]]  # Detect(P3, P4, P5)

温馨提示:因为本文只是对yolov8基础上添加模块,如果要对yolov8n/l/m/x进行添加则只需要指定对应的depth_multiple 和 width_multiple。 


# YOLOv8n
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.25  # layer channel multiple
max_channels: 1024 # max_channels
 
# YOLOv8s
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
max_channels: 1024 # max_channels
 
# YOLOv8l 
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple
max_channels: 512 # max_channels
 
# YOLOv8m
depth_multiple: 0.67  # model depth multiple
width_multiple: 0.75  # layer channel multiple
max_channels: 768 # max_channels
 
# YOLOv8x
depth_multiple: 1.33  # model depth multiple
width_multiple: 1.25  # layer channel multiple
max_channels: 512 # max_channels

2.4 注册模块

关键步骤四:在task.py的parse_model函数替换为下面的内容

def parse_model(
    d, ch, verbose=True, warehouse_manager=None
):  # model_dict, input_channels(3)
    """Parse a YOLO model.yaml dictionary into a PyTorch model."""
    import ast

    # Args
    max_channels = float("inf")
    nc, act, scales = (d.get(x) for x in ("nc", "activation", "scales"))
    depth, width, kpt_shape = (
        d.get(x, 1.0) for x in ("depth_multiple", "width_multiple", "kpt_shape")
    )
    if scales:
        scale = d.get("scale")
        if not scale:
            scale = tuple(scales.keys())[0]
            LOGGER.warning(
                f"WARNING ⚠️ no model scale passed. Assuming scale='{scale}'."
            )
        depth, width, max_channels = scales[scale]

    if act:
        Conv.default_act = eval(
            act
        )  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print

    if verbose:
        LOGGER.info(
            f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}"
        )
    ch = [ch]
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    is_backbone = False
    for i, (f, n, m, args) in enumerate(
        d["backbone"] + d["head"]
    ):  # from, number, module, args
        try:
            if m == "node_mode":
                m = d[m]
                if len(args) > 0:
                    if args[0] == "head_channel":
                        args[0] = int(d[args[0]])
            t = m
            m = getattr(torch.nn, m[3:]) if "nn." in m else globals()[m]  # get module
        except:
            pass
        for j, a in enumerate(args):
            if isinstance(a, str):
                with contextlib.suppress(ValueError):
                    try:
                        args[j] = locals()[a] if a in locals() else ast.literal_eval(a)
                    except:
                        args[j] = a

        n = n_ = max(round(n * depth), 1) if n > 1 else n  # depth gain
        if m in (
            Classify,
            Conv,
            ConvTranspose,
            GhostConv,
            Bottleneck,
            GhostBottleneck,
            SPP,
            SPPF,
            DWConv,
            Focus,
            BottleneckCSP,
            C1,
            C2,
            C2f,
            C3,
            C3TR,
            C3Ghost,
            nn.Conv2d,
            nn.ConvTranspose2d,
            DWConvTranspose2d,
            C3x,
            RepC3,
        ):
            if args[0] == "head_channel":
                args[0] = d[args[0]]
            c1, c2 = ch[f], args[0]
            if (
                c2 != nc
            ):  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)

            args = [c1, c2, *args[1:]]
            if m in (
                RepNCSPELAN4,
            ):
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)
                args[3] = make_divisible(min(args[3], max_channels) * width, 8)

            if m in (
                BottleneckCSP,
                C1,
                C2,
                C2f,
                C3,
                C3TR,
                C3Ghost,
                C3x,
                RepC3,
            ):
                args.insert(2, n)  # number of repeats
                n = 1
        elif m is AIFI:
            args = [ch[f], *args]
        elif m in (HGStem, HGBlock):
            c1, cm, c2 = ch[f], args[0], args[1]
            if (
                c2 != nc
            ):  # if c2 not equal to number of classes (i.e. for Classify() output)
                c2 = make_divisible(min(c2, max_channels) * width, 8)
                cm = make_divisible(min(cm, max_channels) * width, 8)
            args = [c1, cm, c2, *args[2:]]
            if m in (HGBlock):
                args.insert(4, n)  # number of repeats
                n = 1
        elif m is ResNetLayer:
            c2 = args[1] if args[3] else args[1] * 4
        elif m is nn.BatchNorm2d:
            args = [ch[f]]
        elif m is Concat:
            c2 = sum(ch[x] for x in f)
        elif m in (
            Detect,
            Segment,
            Pose,
            OBB,
        ):
            args.append([ch[x] for x in f])
            if m in (
                Segment,
            ):
                args[2] = make_divisible(min(args[2], max_channels) * width, 8)

        elif m is RTDETRDecoder:  # special case, channels arg must be passed in index 1
            args.insert(1, [ch[x] for x in f])
        elif m is CBLinear:
            c2 = make_divisible(min(args[0][-1], max_channels) * width, 8)
            c1 = ch[f]
            args = [
                c1,
                [make_divisible(min(c2_, max_channels) * width, 8) for c2_ in args[0]],
                *args[1:],
            ]
        elif m is CBFuse:
            c2 = ch[f[-1]]
        elif isinstance(m, str):
            t = m
            if len(args) == 2:
                m = timm.create_model(
                    m,
                    pretrained=args[0],
                    pretrained_cfg_overlay={"file": args[1]},
                    features_only=True,
                )
            elif len(args) == 1:
                m = timm.create_model(m, pretrained=args[0], features_only=True)
            c2 = m.feature_info.channels()
        elif m in {
            starnet_s050,
            starnet_s100,
            starnet_s150,
            starnet_s1,
            starnet_s2,
            starnet_s3,
            starnet_s4,
        }:
            m = m(*args)
            c2 = m.channel


        else:
            c2 = ch[f]

        if isinstance(c2, list):
            is_backbone = True
            m_ = m
            m_.backbone = True
        else:
            m_ = (
                nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)
            )  # module
            t = str(m)[8:-2].replace("__main__.", "")  # module type
        m.np = sum(x.numel() for x in m_.parameters())  # number params
        m_.i, m_.f, m_.type = (
            i + 4 if is_backbone else i,
            f,
            t,
        )  # attach index, 'from' index, type
        if verbose:
            LOGGER.info(
                f"{i:>3}{str(f):>20}{n_:>3}{m.np:10.0f}  {t:<45}{str(args):<30}"
            )  # print
        save.extend(
            x % (i + 4 if is_backbone else i)
            for x in ([f] if isinstance(f, int) else f)
            if x != -1
        )  # append to savelist
        layers.append(m_)
        if i == 0:
            ch = []
        if isinstance(c2, list):
            ch.extend(c2)
            for _ in range(5 - len(ch)):
                ch.insert(0, 0)
        else:
            ch.append(c2)
    return nn.Sequential(*layers), sorted(save)

2.5 替换函数

关键步骤五:在task.py的BaseModel类下的_predict_once函数替换为下面的内容

    def _predict_once(self, x, profile=False, visualize=False, embed=None):
        """
        Perform a forward pass through the network.

        Args:
            x (torch.Tensor): The input tensor to the model.
            profile (bool):  Print the computation time of each layer if True, defaults to False.
            visualize (bool): Save the feature maps of the model if True, defaults to False.
            embed (list, optional): A list of feature vectors/embeddings to return.

        Returns:
            (torch.Tensor): The last output of the model.
        """
        y, dt, embeddings = [], [], []  # outputs
        for m in self.model:
            if m.f != -1:  # if not from previous layer
                x = (
                    y[m.f]
                    if isinstance(m.f, int)
                    else [x if j == -1 else y[j] for j in m.f]
                )  # from earlier layers
            if profile:
                self._profile_one_layer(m, x, dt)
            if hasattr(m, "backbone"):
                x = m(x)
                for _ in range(5 - len(x)):
                    x.insert(0, None)
                for i_idx, i in enumerate(x):
                    if i_idx in self.save:
                        y.append(i)
                    else:
                        y.append(None)
                # for i in x:
                #     if i is not None:
                #         print(i.size())
                x = x[-1]
            else:
                x = m(x)  # run
                y.append(x if m.i in self.save else None)  # save output
            if visualize:
                feature_visualization(x, m.type, m.i, save_dir=visualize)
            if embed and m.i in embed:
                embeddings.append(
                    nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)
                )  # flatten
                if m.i == max(embed):
                    return torch.unbind(torch.cat(embeddings, 1), dim=0)
        return x

2.6 执行程序

在train.py中,将model的参数路径设置为yolov8_starNet.yaml的路径

建议大家写绝对路径,确保一定能找到

from ultralytics import YOLO
 
# Load a model
# model = YOLO('yolov8n.yaml')  # build a new model from YAML
# model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
 
model = YOLO(r'/projects/ultralytics/ultralytics/cfg/models/v8/yolov8_starNet.yaml')  # build from YAML and transfer weights
 
# Train the model
model.train(batch=16)

🚀运行程序,如果出现下面的内容则说明添加成功🚀 

                   from  n    params  module                                       arguments                     
  0                  -1  1    413472  starnet_s050                                 []                            
  1                  -1  1     74368  ultralytics.nn.modules.block.SPPF            [128, 256, 5]                 
  2                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
  3             [-1, 3]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
  4                  -1  1    140032  ultralytics.nn.modules.block.C2f             [320, 128, 1]                 
  5                  -1  1         0  torch.nn.modules.upsampling.Upsample         [None, 2, 'nearest']          
  6             [-1, 2]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
  7                  -1  1     35200  ultralytics.nn.modules.block.C2f             [160, 64, 1]                  
  8                  -1  1     36992  ultralytics.nn.modules.conv.Conv             [64, 64, 3, 2]                
  9             [-1, 8]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 10                  -1  1    123648  ultralytics.nn.modules.block.C2f             [192, 128, 1]                 
 11                  -1  1    147712  ultralytics.nn.modules.conv.Conv             [128, 128, 3, 2]              
 12             [-1, 5]  1         0  ultralytics.nn.modules.conv.Concat           [1]                           
 13                  -1  1    493056  ultralytics.nn.modules.block.C2f             [384, 256, 1]                 
 14        [11, 14, 17]  1    897664  ultralytics.nn.modules.head.Detect           [80, [64, 128, 256]]          
YOLOv8_starnet summary: 249 layers, 2362144 parameters, 2362128 gradients 7.2 GFLOPs

3. 完整代码分享

https://pan.baidu.com/s/1Xn-L1oK3c5jFgyCNwNMTVg?pwd=6m7k

 提取码: 6m7k 

4. GFLOPs

关于GFLOPs的计算方式可以查看百面算法工程师 | 卷积基础知识——Convolution

未改进的YOLOv8nGFLOPs

img

改进后的GFLOPs

现在手上没有卡了,等过段时候有卡了把这补上,需要的同学自己测一下

5. 进阶

可以与其他的注意力机制或者损失函数等结合,进一步提升检测效果

6. 总结

StarNet通过引入一种称为星操作(Star Operation)的核心创新,实现了在紧凑网络中高效处理图像的目标。星操作通过两个线性变换后的特征进行元素级乘法,生成一个高维非线性特征空间,类似于多项式核函数的效果。StarNet采用一个四阶段的分层架构,每个阶段通过卷积层下采样分辨率并增加通道数,从而逐步抽象图像特征。每个网络块中使用深度卷积、批量归一化和ReLU6激活函数,简化设计以提高效率。通道扩展固定为4倍,网络宽度在每个阶段翻倍,以平衡计算负载和学习能力。通过标准的训练方法和推理时融合批量归一化层,StarNet在ImageNet-1K等基准测试中展示了其简单设计下的高效性能,表现出色。综上所述,StarNet通过巧妙的特征融合策略和高效的网络结构,在保持低计算复杂度的同时,实现了卓越的图像处理性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1955453.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vulntarget-a

实际部署之后的win7 ip: 192.168.127.128 具体攻击过程如下 win7 扫描服务 使用fscan扫描win 7中的服务以及漏洞 ./fscan -h 192.168.127.128 扫出来一个ms17-010以及通达oa的漏洞&#xff0c;既然有永恒之蓝的&#xff0c;直接上MSF就行了 msf6 > search ms17-010 msf6…

架子鼓鼓谱制谱什么软件好 Guitar Pro怎么编鼓谱

对那些想学架子鼓又缺乏预算的朋友来说&#xff0c;通过架子鼓谱子软件来模拟学习也是个不错的选择。不过&#xff0c;由于市面上类似的软件过多&#xff0c;许多人都挑花了眼&#xff01;那么&#xff0c;架子鼓谱子软件哪个比较好用呢&#xff1f;下面小编就来推荐几款好用的…

HTML--JavaScript操作DOM对象

目录 本章目标 一.DOM对象概念 ​编辑 二.节点访问方法 常用方法&#xff1a; 层次关系访问节点 三.节点信息 四.节点的操作方法 操作节点的属性 创建节点 删除替换节点 五.节点操作样式 style属性 class-name属性 六.获取元素位置 总结 本章目标 了解DOM的分类和节…

【论文精读】 | 基于图表示的视频抑郁症识别的两阶段时间建模框架

文章目录 0、Description1、Introduction2、Related work2.1 Relationship between depression and facial behaviours2.2 Video-based automatic depression analysis2.3 Facial graph representation 3、The proposed two-stage approach3.1 Short-term depressive behaviour…

18746 逆序数

这个问题可以使用归并排序的思想来解决。在归并排序的过程中&#xff0c;我们可以统计逆序数的数量。当我们合并两个已经排序的数组时&#xff0c;如果左边的数组中的元素&#xfffd;&#xfffd;于右边的数组中的元素&#xff0c;那么就存在逆序&#xff0c;逆序数的数量就是…

力扣SQL50 按分类统计薪水 条件统计

Problem: 1907. 按分类统计薪水 文章目录 思路Code 思路 &#x1f468;‍&#x1f3eb; 参考题解 Code SELECT Low Salary AS category,SUM(CASE WHEN income < 20000 THEN 1 ELSE 0 END) AS accounts_count FROM AccountsUNION SELECT Average Salary category,SUM(C…

DBus快速入门

DBus快速入门 参考链接&#xff1a; 中文博客&#xff1a; https://www.e-learn.cn/topic/1808992 https://blog.csdn.net/u011942101/article/details/123383195 https://blog.csdn.net/weixin_44498318/article/details/115803936 https://www.e-learn.cn/topic/1808992 htt…

Vue3 Pinia的创建与使用代替Vuex 全局数据共享 同步异步

介绍 提供跨组件和页面的共享状态能力&#xff0c;作为Vuex的替代品&#xff0c;专为Vue3设计的状态管理库。 Vuex&#xff1a;在Vuex中&#xff0c;更改状态必须通过Mutation或Action完成&#xff0c;手动触发更新。Pinia&#xff1a;Pinia的状态是响应式的&#xff0c;当状…

JAVA中的异常:异常的分类+异常的处理

文章目录 1. 异常的分类1.1 Error1.2 Exception1.2.1 运行时异常1.2.2 非运行时异常 2.异常的处理2.1 try catch用法2.2 throws和throw的区别 1. 异常的分类 Throwable类&#xff08;表示可抛&#xff09;是所有异常和错误的超类&#xff0c;两个直接子类为Error和Exception分…

C++11新特性——智能指针——参考bibi《 原子之音》的视频以及ChatGpt

智能指针 一、内存泄露1.1 内存泄露常见原因1.2 如何避免内存泄露 二、实例Demo2.1 文件结构2.2 Dog.h2.3 Dog.cpp2.3 mian.cpp 三、独占式智能指针:unique _ptr3.1 创建方式3.1.1 ⭐从原始(裸)指针转换&#xff1a;3.1.2 ⭐⭐使用 new 关键字直接创建&#xff1a;3.1.3 ⭐⭐⭐…

敏感信息泄露wp

1.右键查看网页源代码 2.前台JS绕过&#xff0c;ctrlU绕过JS查看源码 3.开发者工具&#xff0c;网络&#xff0c;查看协议 4.后台地址在robots,拼接目录/robots.txt 5.用dirsearch扫描&#xff0c;看到index.phps,phps中有源码&#xff0c;拼接目录&#xff0c;下载index.phps …

【Go - context 速览,场景与用法】

作用 context字面意思上下文&#xff0c;用于关联管理上下文&#xff0c;具体有如下几个作用 取消信号传递&#xff1a;可以用来传递取消信号&#xff0c;让一个正在执行的函数知道它应该提前终止。超时控制&#xff1a;可以设定一个超时时间&#xff0c;自动取消超过执行时间…

Oat++ 后端实现跨域

这里记录在官方的例子中&#xff0c;加入跨域。Oat Example-CRUD 在官方的例子中&#xff0c;加入跨域。 Oat Example-CRUD 修改AppComponent.hpp文件中的代码&#xff0c;如下&#xff1a; #include "AppComponent.hpp"#include "controller/UserController…

8、从0搭建企业门户网站——网站部署

目录 正文 1、域名解析 2、云服务器端口授权 3、Mysql数据库初始化 4、上传网站软件包 5、Tomcat配置 6、运行Tomcat 7、停止Tomcat 8、部署后发现验证码无法使用 完毕! 正文 当云服务器租用、域名购买和软件开发都完成后,我们就可以开始网站部署上线,ICP备案会长…

表达式的转换

题目&#xff1a; 表达式的转换 - 洛谷 P1175 - Virtual Judge 思路&#xff1a; 这道题可以拆成两问&#xff1a; 第一问&#xff0c;将中缀表达式转成后缀表达式放入一个数组 第二问&#xff0c;将后缀表达式的数组计算&#xff0c;并输出过程 第一问思路&#xff1a; …

Python高维度大型气象矩阵存储策略分享

零、前情提要 最近需要分析全球范围多变量的数值预报数据&#xff0c;将grb格式的数据下载下来经过一通处理后需要将预处理数据先保存一遍&#xff0c;方便后续操作&#xff0c;处理完发现此时的数据维度很多&#xff0c;数据量巨大&#xff0c;使用不同的保存策略的解析难度和…

信息安全技术解析

在信息爆炸的今天&#xff0c;信息技术安全已成为社会发展的重要基石。随着网络攻击的日益复杂和隐蔽&#xff0c;保障数据安全、提升防御能力成为信息技术安全领域的核心任务。本文将从加密解密技术、安全行为分析技术和网络安全态势感知技术三个方面进行深入探讨&#xff0c;…

C++笔试强训9

文章目录 一、选择题1-5题6-10题 二、编程题题目一题目二 一、选择题 1-5题 函数形参是个int类型的引用&#xff0c;传参的时候直接传一个int类型的变量就行 故选A 1.malloc/calloc/realloc—>free 2. new / delete 3. new[] / delete[] 一定要匹配起来使用&#xff0c;否则…

猫头虎分享:Numpy知识点一文带你详细学习np.random.randn()

&#x1f42f; 猫头虎分享&#xff1a;Numpy知识点一文带你详细学习np.random.randn() 摘要 Numpy 是数据科学和机器学习领域中不可或缺的工具。在本篇文章中&#xff0c;我们将深入探讨 np.random.randn()&#xff0c;一个用于生成标准正态分布的强大函数。通过详细的代码示…

ADS 使用教程(二十九)Understanding Bounding Area Layer for FEM

上一篇&#xff1a;ADS 使用教程&#xff08;二十八&#xff09;Working with FEM Mesh & Field Data in ADS 这一节&#xff0c;我们来一起了解一下有限元法&#xff08;FEM&#xff09;中的边界区域层&#xff08;Bounding Area Layer&#xff09;&#xff0c;这是定义仿…