主干网络篇 | YOLOv8更换主干网络之ShuffleNetV2(包括完整代码+添加步骤+网络结构图)

news2025/1/23 6:14:59

前言:Hello大家好,我是小哥谈。ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力!~🌈  

     目录

🚀1. 基础概念

🚀2.网络结构

🚀3.添加步骤

🚀4.改进方法

🍀🍀步骤1:block.py文件修改

🍀🍀步骤2:__init__.py文件修改

🍀🍀步骤3:tasks.py文件修改

🍀🍀步骤4:创建自定义yaml文件

🍀🍀步骤5:新建train.py文件

🍀🍀步骤6:模型训练测试

🚀1. 基础概念

ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。

ShuffleNetV2的主要特点包括:

  1. 分组卷积:通过将输入通道分成多个组,并在组内进行卷积操作,减少了计算量和参数数量。
  2. 逐点卷积:使用1x1的卷积核进行逐点卷积,用于调整通道数和特征图的维度。
  3. 通道重排:通过将输入特征图按通道进行重排,实现信息的混洗和交互,增强了特征的表达能力。
  4. 瓶颈结构:采用瓶颈结构,即先降维再升维,减少了计算量和参数数量。
  5. 网络设计:ShuffleNet V2通过堆叠多个ShuffleNet单元来构建整个网络,可以根据任务的需求进行不同层数和宽度的配置。

ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力。

shuffleNetV2这篇论文比较硬核,提出了不少新的思想,推荐大家可以看看论文原文。主要思想包括:

  • 模型的计算复杂度不能只看FLOPs,还需要参考一些其他的指标
  • 作者提出了4条如何设计高效网络的准则
  • 基于该准则提出了新的block设置

FLOPS网上有两种:FLOPS和 FLOPs

FLOPS:全大写,指每秒浮点运算次数,可以理解为计算的速度,是衡量硬件性能的一个指标 (硬件)
FLOPs:s小写,指浮点运算数,理解为计算量,可以用来衡量算法/模型的复杂度,(模型)在论文中常用GFLOPs(1 GFLOPs = 10^9FLOPs)

 ShuffleNetV2网络结构:

 原理图:

其中,a、b为ShuffleNetV1原理图,c、d为ShuffleNetV2原理图。

论文题目:《ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design》

论文地址:  https://arxiv.org/pdf/1807.11164.pdf

代码实现:  GitHub - megvii-model/ShuffleNet-Series 


🚀2.网络结构

本文的改进是基于YOLOv8,关于其网络结构具体如下图所示:

YOLOv8官方仓库地址:

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite

针对本文的改进,作者将所使用的含有预训练权重文件的YOLOv8完整源码进行了上传,大家可在我的“资源”中自行下载。  


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:block.py文件修改

步骤2:__init__.py文件修改

步骤3:tasks.py文件修改

步骤4:创建自定义yaml文件

步骤5:新建train.py文件

步骤6:模型训练测试


🚀4.改进方法

🍀🍀步骤1:block.py文件修改

在源码中找到block.py文件,具体位置是ultralytics/nn/modules/block.py,然后将ShuffleNetV2模块代码添加到block.py文件末尾位置。

ShuffleNetV2模块代码:

# ShuffleNetv2核心代码
# By CSDN 小哥谈
import torch
import torch.nn as nn

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x


class CBRM(nn.Module):  # Conv BN ReLU Maxpool2d
    def __init__(self, c1, c2):
        super(CBRM, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        return self.maxpool(self.conv(x))


class Shuffle_Block(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(Shuffle_Block, self).__init__()

        if not (1 <= stride <= 2):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = ch_out // 2
        assert (self.stride != 1) or (ch_in == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(ch_in, ch_in, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(ch_in),

                nn.Conv2d(ch_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        self.branch2 = nn.Sequential(
            nn.Conv2d(ch_in if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),

            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),

            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out

再然后,在block.py文件最上方下图所示位置加入CBRMShuffle_Block

🍀🍀步骤2:__init__.py文件修改

在源码中找到__init__.py文件,具体位置是ultralytics/nn/modules/__init__.py

修改1:加入CBRMShuffle_Block,具体如下图所示:

修改2:加入CBRMShuffle_Block,具体如下图所示:

🍀🍀步骤3:tasks.py文件修改

在源码中找到tasks.py文件,具体位置是ultralytics/nn/tasks.py

修改1:在下图所示位置导入类名CBRMShuffle_Block

修改2:找到parse_model函数(736行左右),在下图中所示位置添加如下代码。

 # -------ShuffleNetv2------------
        elif m in [CBRM, Shuffle_Block]:
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]
        # --------------------------------

具体添加位置如下图所示:

🍀🍀步骤4:创建自定义yaml文件

在源码ultralytics/cfg/models/v8目录下创建yaml文件,并命名为:yolov8_ShuffleNetV2.yaml。具体如下图所示:

yolov8_ShuffleNetV2.yaml文件完整代码如下所示:

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [ -1, 1, CBRM, [ 32 ] ] # 0-P2/4
  - [ -1, 1, Shuffle_Block, [ 128, 2 ] ]  # 1-P3/8
  - [ -1, 3, Shuffle_Block, [ 128, 1 ] ]  # 2
  - [ -1, 1, Shuffle_Block, [ 256, 2 ] ]  # 3-P4/16
  - [ -1, 7, Shuffle_Block, [ 256, 1 ] ]  # 4
  - [ -1, 1, Shuffle_Block, [ 512, 2 ] ]  # 5-P5/32
  - [ -1, 3, Shuffle_Block, [ 512, 1 ] ]  # 6


# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 3], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 9

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 2], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 12 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 15 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 6], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 18 (P5/32-large)

  - [[12, 15, 18], 1, Detect, [nc]]  # Detect(P3, P4, P5)
🍀🍀步骤5:新建train.py文件

在源码根目录下新建train.py文件,文件完整代码如下所示:

from ultralytics import YOLO

# Load a model
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml')  # build a new model from YAML
model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml').load('yolov8n.pt')  # build from YAML and transfer weights

# Train the model
model.train(data=r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\datasets\helmet.yaml', epochs=100, imgsz=640)

注意:一定要用绝对路径,以防发生报错。

🍀🍀步骤6:模型训练测试

train.py文件,点击“运行”,在作者自制的安全帽佩戴检测数据集上,模型可以正常训练。

模型训练过程: 

模型训练结果: 

 关于本次改进所使用的安全帽佩戴检测数据集,已上传至我的“资源”中,大家可免费下载。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1522197.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spc x-bar 正态分布 echarts demo

使用echarts,elementUi,vue编写的spc分析的demo示例. 含x-bar和正态分布图,同一数据可以互转 chart.vue <template><div class"app-container"><el-row><el-col :span"4" class"button-container"><el-button clic…

Linux学习——线程池的创建

一&#xff0c;线程池的作用和优点 线程池使用的是一种池化技术&#xff0c;当我们要使用线程时采用线程池创建就一次创建多个线程&#xff0c;在调用当前线程时就让其它的线程进行等待。这样做的优点有如下几点&#xff1a; 1&#xff0c;提高响应速度。线程池提前把线程创建好…

使用Laravel框架创建项目

1.使用Composer创建项目 composer create-project --prefer-dist laravel/laravel blog "5.5.*" 如下图所以&#xff0c;Laravel框架就安装完成了 安装完成后&#xff0c;进入到项目文件夹根目录&#xff0c;打开终端&#xff0c;输入php artisan serve运行项目 p…

Linux操作系统裸机开发-环境搭建

一、配置SSH服务 1、下载安装ssh服务输入以下命令 sudo apt-get install nfs-kernel-server portmap2、建立一个供SSH服务使用的文件夹如以下命令 mkdir linux 3、完成前两步之后需要将其文件路径放到/etc/exports文件里输入以下命令&#xff1a; sudo vi /etc/esports 4.打…

天天说微服务,天天开发RESTful API,那你知道RESTful API是什么东东吗?

RESTful API&#xff08;Representational State Transfer&#xff09;是一种基于网络的架构风格&#xff0c;用于设计和构建Web服务。它是一种轻量级的架构&#xff0c;可以通过HTTP协议进行通信&#xff0c;并支持各种数据格式&#xff0c;例如JSON和XML。 在现代的Web应用程…

三极管工作原理及典型电路

一、三极管的工作原理 三极管&#xff0c;也被称为双极型晶体管或晶体三极管&#xff0c;是一种电流控制元件。主要功能是将微弱的电信号放大成幅度值较大的电信号&#xff0c;工作在饱和区和截止区时同时也被用作无触点开关。 根据结构和工作原理的不同&#xff0c;三极管可以…

Jmeter---分布式

分布式&#xff1a;多台机协作&#xff0c;以集群的方式完成测试任务&#xff0c;可以提高测试效率。 分布式架构&#xff1a;控制机&#xff08;分发任务&#xff09;与多台执行机&#xff08;执行任务&#xff09; 环境搭建&#xff1a; 不同的测试机上安装 Jmeter 配置基…

代码随想录|Day22|回溯02|216.组合总和III、17.电话号码的字母组合

216.组合总和III 本题思路和 77. 组合 类似&#xff0c;在此基础上多了一个和为 n 的判断。 class Solution:def combinationSum3(self, k: int, n: int) -> List[List[int]]:def backtrack(start, path, currentSum):# 递归终止条件&#xff1a;到达叶子节点# 如果和满足条…

HTTPS证书很贵吗?

首先&#xff0c;我们需要明确一点&#xff0c;HTTPS证书的价格并不是一成不变的&#xff0c;它受到多种因素的影响。其中最主要的因素包括证书的类型、颁发机构以及所需的验证级别。 从类型上来看&#xff0c;HTTPS证书主要分为单域名证书、多域名证书和通配符证书。单域名证书…

mmz批量多页抓取数据-AES.CBC算法-爬虫

目标&#xff1a;mmz多页下载 方法&#xff1a;加一个for循环实现多页的下载 问题&#xff1a;浏览器传输服务器时对页码参数做了加密处理 解决方法&#xff1a; 1、判断加密算法模式&#xff08;mmz是AES-CBC算法&#xff09; 2、找到加密的key和iv 代码&#xff1a; i…

基于springboot+vue实现疫情防控物资调配系统项目【项目源码】计算机毕业设计

基于springbootvue实现疫情防控物资调配系统演示 B/S结构的介绍 在确定了项目的主题和研究背景之后&#xff0c;就要确定本系统的架构了。主流的架构有两种&#xff0c;一种是B/S架构&#xff0c;一种是C/S架构。C/S的全称是Client/Server&#xff0c;Client是客户端的意思&am…

HarmonyOS NEXT应用开发—Grid和List内拖拽交换子组件位置

介绍 本示例分别通过onItemDrop()和onDrop()回调&#xff0c;实现子组件在Grid和List中的子组件位置交换。 效果图预览 使用说明&#xff1a; 拖拽Grid中子组件&#xff0c;到目标Grid子组件位置&#xff0c;进行两者位置互换。拖拽List中子组件&#xff0c;到目标List子组件…

插入排序:一种简单而有效的排序算法

插入排序&#xff1a;一种简单而有效的排序算法 一、什么是插入排序&#xff1f;二、插入排序的步骤三、插入排序的C语言实现四、插入排序的性能分析五、插入排序的优化六、总结 在我们日常生活和工作中&#xff0c;排序是一种非常常见的操作。比如&#xff0c;我们可能需要对一…

MasterPDF 强大的多功能软件

哈喽呀&#xff0c;我是苏音今天给大家带来一期免费PDF的工具&#xff0c;可以实现你的大部分需求。 最近有PDF文档相关的的需求&#xff0c;但是之前一直在用WPS&#xff0c;就看能不能实现下面两个功能 1.导出指定页的PDF 2.在某一页PDF中加入指定图片 虽然WPS可以实现将…

免费接口调用 招标信息自动抽取|招标信息|招标数据解析接口

一、开源项目介绍 一款多模态AI能力引擎&#xff0c;专注于提供自然语言处理&#xff08;NLP&#xff09;、情感分析、实体识别、图像识别与分类、OCR识别和语音识别等接口服务。该平台功能强大&#xff0c;支持本地化部署&#xff0c;并鼓励用户体验和开发者共同完善&#xf…

SpringBoot整合Seata注册到Nacos服务

项目引入pom文件 <!-- SpringCloud Seata 组件--> <dependency><groupId>com.alibaba.cloud</groupId><artifactId>spring-cloud-alibaba-seata</artifactId><version>${alibaba.seata}</version><exclusions><exc…

Postman接口测试之断言,全网最细教程没有之一!

一、断言 在 postman 中我们是在Tests标签中编写断言&#xff0c;同时右侧封装了常用的断言&#xff0c;当然 Tests 除了可以作为断言&#xff0c;还可以当做后置处理器来编写一些后置处理代码&#xff0c;经常应用于&#xff1a; 【1】获取当前接口的响应&#xff0c;传递给…

智慧城市革命,物联网技术如何改变城市治理与生活方式

随着科技的不断进步&#xff0c;智慧城市已经成为现代城市发展的重要方向之一。物联网技术作为智慧城市的重要支撑&#xff0c;正深刻改变着城市的治理模式和居民的生活方式。本文将探讨智慧城市革命&#xff0c;以及物联网技术如何改变城市治理与生活方式&#xff0c;同时介绍…

c++入门学习⑨——STL(万字总结,超级超级详细版)看完这一篇就够了!!!

目录 &#x1f384;前言 &#x1f384;概念 引入 定义 优点 &#x1f384;六大组件 容器 算法 迭代器 仿函数 适配器 空间配置器 &#x1f384;三大组件 迭代器&#xff08;iterator&#xff09; 定义 分类&#xff1a; 正向迭代器&#xff1a; 常量正向迭代…

c语言:操作符详解(上)

目录 一、操作符的分类二、二进制和进制转换1.2进制转10进制2.10进制转2进制3.2进制转8进制4.2进制转16进制 三、原码、反码、补码四、算术操作符、-、*、/、%1.**和-**2.*3./4.% 五、移位操作符1.左移操作符2.右移操作符 六、位操作符&#xff1a;&、|、^、~七、赋值操作符…