C3与C2f模块介绍与代码

news2024/9/21 16:47:22

C3与C2f模块介绍与代码

微信公众号:幼儿园的学霸

目录

文章目录

  • C3与C2f模块介绍与代码
  • 目录
  • 简介
  • CSP/C3模块概述
  • C2f模块概述
  • C3与C2f结构对比
  • 参考资料

简介

顺序:CSPNet->C3->C2f
C2 module refers to the CSP (Cross Stage Partial) Bottleneck with 2 convolutions
C2f module is a faster implementation of the C2 module. It improves the execution speed of the model while maintaining similar performance. This optimization is achieved by making certain modifications to the original C2 module

CSP/C3模块概述

CSP(Cross Stage Partial-connections,跨阶段部分连接) 模块是一种跨阶段部分连接的模块,它能够有效地整合不同阶段的特征表示,并使模型在训练过程中更加关注重要的部分.特点:降低计算量的同时保证精度

看着这跨阶段三个字,肯定又是Skip操作,目的为了解决梯度消失问题,同时丰富多尺度特征,提高检测等任务的效果。
CSP结构通过将输入特征分为两部分,然后在这两个部分之间进行交叉连接的方法来提高神经网络的性能。CSP结构能有效的提高模型的特征表示能力,从而提高模型的准确性和泛化能力。

CSP 模块主要由两个部分组成:cross connection 和 partial connection。在 cross connection 部分,输入的 feature maps 会被分为两部分,分别进行不同的预测和处理,然后再将其合并起来;在 partial connection 部分,输入的 feature maps 会经过一个卷积层和一个残差连接,来提取更高层次的特征信息。

在CSPNet中,Partial Connection通常与Cross Connection相结合来实现。具体而言,基础层的特征图被分割成两部分,其中一部分直接绕过某些层(实现Cross Connection),而另一部分则进入这些层进行处理(实现Partial Connection)。通过这种方式,网络既能够保持较强的表征能力,又能够降低计算复杂度和内存使用。

通过 cross connection 和 partial connection 的结合,CSP 模块可以在保证模型深度的同时,提高了网络的计算效率和特征表示能力,从而在目标检测、图像分类等任务中达到更好的表现。具体来说,CSP 模块可以帮助网络有效地利用不同尺度的特征信息,增强模型对于输入图像的感知能力,同时减少了由于深度网络引入的梯度消失和过拟合等问题。

具体来说,BottleneckCSP 模块会首先使用一个 1x1 的卷积层来减少 feature maps 的通道数,然后再通过一个 bottleneck 层进一步压缩 feature maps 的深度。接下来,feature maps 会被分为两个部分并进行 shuffle 操作,然后再进行 concat、BN、ReLU 等操作得到新的 feature maps。这样处理后,BottleneckCSP 模块可以有效地利用空间信息和通道信息,加深网络的层数,从而提高检测精度。同时,由于使用了 bottleneck 和 shuffle 操作,BottleneckCSP 模块的计算复杂度较低,网络也较为轻量级。

C3模块:YOLOv5网络结构的核心就是CSPBlock模块,用YOLOv5的的语言来说,就是"C3"模块,相关代码如下所示

#!/usr/bin/env python3
# coding=utf-8

# ============================#
# Program:test.py
#       C3模块展示及网络结构查看
# Date:24-7-21
# Author:liheng
# Version:V1.0
# ============================#

import torch
import torch.nn as nn
# from torchsummary import summary

# 定义 Conv 类和 Bottleneck 类
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation
    # default_act = nn.ReLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))

class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))

# 创建一个示例模型
c1 = 64  # 输入通道数
c2 = 256  # 输出通道数
model = C3(c1, c2,3)

# 创建一个示例输入张量(假设输入大小为 [batch_size, c1, height, width])
batch_size = 3
height, width = 224, 224
input_tensor = torch.randn(batch_size, c1, height, width)

# 进行前向传播
output = model(input_tensor)

# 打印输出张量的形状,以确认模型运行正常
print("Output shape:", output.shape)
# print(model)
# 使用 torchsummary 打印模型摘要
# summary(model, input_size=(c1, height, width))

# # 查看模型结构
# import netron
# import torch.onnx
# modelData='./c3.onnx'
# # 将 pytorch 模型以 onnx 格式导出并保存
# torch.onnx.export(model, input_tensor, modelData)
# # 输出网络结构
# netron.start(modelData)

C2f模块概述

C2 module refers to the CSP (Cross Stage Partial) Bottleneck with 2 convolutions

C2f块:首先由一个卷积块(Conv)组成,该卷积块接收输入特征图并生成中间特征图
特征图拆分:生成的中间特征图被拆分成两部分,一部分直接传递到最终的Concat块,另一部分传递到多个Botleneck块进行进一步处理。
Bottleneck块:输入到这些Botleneck块的特征图通过一系列的卷积、归一化和激活操作进行处理,最后生成的特征图会与直接传递的那部分特征图在Concat块进行拼接(Concat)。
模型深度控制:在C2f模块中,Botleneck模块的数量由模型的depth muliple参数定义,这意味着可以根据需求灵活调整模块的深度和计算复杂度。
最终卷积块:拼接后的特征图会输入到一个最终的卷积块进行进一步处理,生成最终的输出特征图。

yolov8使用的C2f结构同样分为两种,一种在bottleneck中有残差结构,一种没有残差结构.

C2f模块默认不使用shortcut连接,C3模块默认使用shortcut连接,C2f相比于C3模块梯度流更丰富.

新的"C2f"模块在一定程度上是受到了YOLOv7的ELAN模块的启发,加入更多的分支,丰富梯度回传时的支流。

C2f模块代码如下:

#!/usr/bin/env python3
# coding=utf-8

# ============================#
# Program:C2f.py
#       C2f模块展示及网络结构查看
# Date:24-7-21
# Author:liheng
# Version:V1.0
# ============================#

import torch
import torch.nn as nn
# from torchsummary import summary

# 定义 Conv 类和 Bottleneck 类
def autopad(k, p=None, d=1):  # kernel, padding, dilation
    """Pad to 'same' shape outputs."""
    if d > 1:
        k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]  # actual kernel-size
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p

class Conv(nn.Module):
    """Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""

    default_act = nn.SiLU()  # default activation
    # default_act = nn.ReLU()  # default activation

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        """Initialize Conv layer with given arguments including activation."""
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

    def forward(self, x):
        """Apply convolution, batch normalization and activation to input tensor."""
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        """Perform transposed convolution of 2D data."""
        return self.act(self.conv(x))


class Bottleneck(nn.Module):
    # Standard bottleneck
    # 残差连接瓶颈层, Residual block
    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5):
        '''
        :param c1: 输入通道
        :param c2: 输出通道
        :param shortcut: 为True时采用残差连接
        :param g: groups 在输出通道上分组, c2 // g 分组后不同组之间的卷积核参数不同
        :param e: 中间层的通道数
        '''

    # ch_in, ch_out, shortcut, groups, kernels, expand
        super().__init__()
        c_ = int(c2 * e)  # hidden channels 中间层的通道
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C2f(nn.Module):
    # CSP Bottleneck with 2 convolutions
    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n))

    def forward(self, x):
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))

# 创建一个示例模型
c1 = 64  # 输入通道数
c2 = 256  # 输出通道数
shortcut = False
model = C2f(c1, c2,3,shortcut=shortcut)

# 创建一个示例输入张量(假设输入大小为 [batch_size, c1, height, width])
batch_size = 3
height, width = 224, 224
input_tensor = torch.randn(batch_size, c1, height, width)

# 进行前向传播
output = model(input_tensor)

# 打印输出张量的形状,以确认模型运行正常
print("Output shape:", output.shape)
# print(model)
# 使用 torchsummary 打印模型摘要
# summary(model, input_size=(c1, height, width))

# 查看模型结构
import netron
import torch.onnx
modelData='./c2f.onnx'
# 将 pytorch 模型以 onnx 格式导出并保存
torch.onnx.export(model, input_tensor, modelData)
# 输出网络结构
netron.start(modelData)

C3与C2f结构对比

基于以上代码,可以绘制两者的模型结构如下:
C3模块模型结构如下:
C3模块

C2f模块模型结构如下:
C2f模块

C2f模块默认不使用shortcut连接,C3模块默认使用shortcut连接,C2f相比于C3模块梯度流更丰富.

参考资料

1.Why the C2f module in yaml file called ‘C2f’? What does it means
2.《目标检测大杂烩》-第13章-浅析YOLOv8
3.yolov5进阶
4.YOLOv8中的C2f模块


图注:幼儿园的学霸

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2106284.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LabVIEW学习篇 - 18】:人机界面交互设计02

文章目录 错误处理函数简单错误处理器通用错误处理器清楚错误合并错误错误代码至错误簇转换查找第一个错误 鼠标指针 错误处理函数 在LabVIEW中,是通过错误输入簇和错误输出簇来传递错误信息,可以将底层错误信息传递到上层VI。设计人员需要对不同程度的…

Air780E低功耗4G模组硬件设计手册01

本文主要介绍了Air780E的硬件设计中的的应用接口部分。 一、主要性能 Air780E模块功能框图: 模块型号列表: 模块主要性能: *注: 模组工作在-40C~-35C或75C~85C温度范围时,模组可以正常工作,但…

基于51单片机的车距离警示灯proteus仿真

地址: https://pan.baidu.com/s/1tBIqTY4cCK38Z_xRKrq83g 提取码:1234 仿真图: 芯片/模块的特点: AT89C52/AT89C51简介: AT89C52/AT89C51是一款经典的8位单片机,是意法半导体(STMicroelectro…

pyqt fromlayout 布局中间空隙问

问题:当采用 form layout 布局时候,在qt designer 设计界面,如果把中间移除会在布局中间占用位置(图1、图2),需要把后续空间向前移动后保存(图3),在将界面文件打开即显示…

基于STM32景区环境监测系统的设计与实现(论文+源码)

1系统方案设计 根据系统功能的设计要求,展开基于STM32景区环境监测系统设计。如图2.1所示为系统总体设计框图。系统以STM32单片机作为系统主控模块,通过DHT11传感器、MQ传感器、声音传感器实时监测景区环境中的温湿度、空气质量以及噪音数据。系统监测环…

中国各省会、地级市到杭州球面距离的数据

环境规制是一系列政策措施,旨在解决环境问题、保护生态环境,并推动低碳可持续发展。这些措施包括法律法规、行政命令和经济激励等,目的是减少企业和个人对环境的负面影响。环境规制强度是衡量这些政策措施严格程度的指标,通常通过…

GIS十大经典问题之9.地形分析问题

本系列《GIS十大经典问题》包括: 缓冲区分析问题叠加分析问题最短路径分析问题空间插值问题泰森多边形(Voronoi 图)生成问题空间聚类问题空间数据压缩问题空间查询问题地形分析问题网络分析中的连通性问题 一、地形分析介绍 地形分析在地理…

HTTP 二、进阶

四、安全 1、TLS是什么 (1)为什么要有HTTPS ​ 简单的回答是“因为 HTTP 不安全”。由于 HTTP 天生“明文”的特点,整个传输过程完全透明,任何人都能够在链路中截获、修改或者伪造请求 / 响应报文,数据不具有可…

数字化营销:品牌知名度提升的新利器

​嘿,朋友们!在如今这个数字化高速发展的时代,企业的营销格局发生了翻天覆地的变化。使用蚓链数字化营销系统,数字化营销正成为提升品牌知名度的关键力量。 先来了解一下蚓链数字化营销的内涵与特点。它是利用数字技术和互联网平台…

重磅!微信放开公众号注册限制!只要手机号,不用实名!

重磅!微信放开公众号注册限制!只要手机号,不用实名! 随着移动互联网的发展,微信公众号已经成为了许多个人与企业传递信息、分享内容的首选平台。就在近日,微信官方再次放出大招:公众号注册无需…

Python画笔案例-033 绘制爆炸图

1、绘制蝌蚪 通过 python 的turtle 库绘制爆炸图,如下图: 2、实现代码 绘制爆炸图,以下为实现代码: 爆炸图,非函数版本 : """爆炸图.py """ import turtle import randomc…

企业如何避免六西格玛黑带培训陷入形式主义?

在开展六西格玛黑带培训的过程中,不少企业陷入了形式主义的泥潭,导致培训效果大打折扣。本文,深圳天行健企业管理咨询公司旨在分享如何避免六西格玛黑带培训陷入形式主义,确保培训成果真正转化为企业的生产力。 一、明确培训目标与…

电商企业借助精益六西格玛培训提升资产周转率——从资本困局到效率跃升

随着市场日益饱和,电商企业的增长模式被迫从粗放式扩展向精细化运营转型。这个过程许多电商企业遭遇了资产周转率低下的瓶颈,资金流动性不足直接影响企业的扩展能力与市场竞争力。面对这一困境,越来越多的电商企业开始借助精益六西格玛这一强…

认知杂谈46

今天分享 有人说的一段争议性的话 I I 强者思维的人际关系观 拥有强者思维的人在和人交往的时候,可不会粗心大意。 I I 他们在人际交往这个大舞台上,会充分考虑他人的感受,绝不会像那些在网上肆无忌惮乱喷的人。 I I 他们深知人心的复杂多变…

链表.......

从右到左 更新尾部 typedef typedef struct ListNode { int value; struct ListNode *next;(这里不能用listnode*应为还没有定义) } ListNode; #include <stdio.h> #include <stdlib.h> // 定义链表节点结构体 struct ListNode { int value; s…

红黑树总结(RbTree)——C++版

目录 红黑树的五大规则 这些规则的作用 插入和删除中的规则修正(简单了解一下) 代码实现 单纯的变色 左旋变色 右旋变色 双旋变色 其他细节 简单的数据测试 set/map进行封装 红黑树是一种自平衡的二叉搜索树&#xff0c;它通过一组规则来确保树在插入或删除操作后保…

华为手机找不到wifi调试?不急,没有wifi调试一样可以进行局域网模式调试

最近小黄在使用uniapp启动无线调试的时候突然发现华为的手机突然找不到wifi调试了&#xff0c;那么我们怎么进行无线调试呢&#xff1f; 其实他只是找不到开关而已&#xff0c;正常使用就行。 1.使用数据线连接手机。 打开cmd命令行执行&#xff1a;adb tcpip 5555 2.再执行ad…

IOS 22 自定义标题栏(Toolbar)

标题栏实现效果 实现逻辑 自定义标题栏&#xff0c;我们可以基于系统NavigationBar定制&#xff0c;也可以使用控件完全自定义。本文使用控件完全自定义来实现自定义标题栏效果。 SuperToolbarView 创建一个自定义控件SuperToolbarView&#xff0c;可以把SuperToolbarView分…

如何查找自己文件的复制记录 - 用这个方法简单

如何查看自己文件的复制记录&#xff1f;在电脑操作的过程中经常会复制文件&#xff0c;那么这些记录在哪里可以看&#xff0c;怎么查找&#xff0c;我们可以使用专门的软件工具进行查看文件的复制、剪切历史记录&#xff0c;下面推荐一款比较实用的文件复制记录查看软件。 文…

Chrome 浏览器插件获取网页 window 对象(方案三)

前言 最近有个需求&#xff0c;是在浏览器插件中获取 window 对象下的某个数据&#xff0c;当时觉得很简单&#xff0c;和 document 一样&#xff0c;直接通过嵌入 content_scripts 直接获取&#xff0c;然后使用 sendMessage 发送数据到插件就行了&#xff0c;结果发现不是这…