主干网络篇 | YOLOv8更换主干网络之ShuffleNetV2

news2024/12/24 21:43:11

前言:Hello大家好,我是小哥谈。ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力!~🌈  

     目录

🚀1. 基础概念

🚀2.网络结构

🚀3.添加步骤

🚀4.改进方法

🍀🍀步骤1:block.py文件修改

🍀🍀步骤2:__init__.py文件修改

🍀🍀步骤3:tasks.py文件修改

🍀🍀步骤4:创建自定义yaml文件

🍀🍀步骤5:新建train.py文件

🍀🍀步骤6:模型训练测试

🚀1. 基础概念

ShuffleNetV2是一种轻量级的神经网络架构,用于图像分类和目标检测任务。它是ShuffleNet的改进版本,旨在提高模型的性能和效率。

ShuffleNetV2的主要特点包括:

  1. 分组卷积:通过将输入通道分成多个组,并在组内进行卷积操作,减少了计算量和参数数量。
  2. 逐点卷积:使用1x1的卷积核进行逐点卷积,用于调整通道数和特征图的维度。
  3. 通道重排:通过将输入特征图按通道进行重排,实现信息的混洗和交互,增强了特征的表达能力。
  4. 瓶颈结构:采用瓶颈结构,即先降维再升维,减少了计算量和参数数量。
  5. 网络设计:ShuffleNet V2通过堆叠多个ShuffleNet单元来构建整个网络,可以根据任务的需求进行不同层数和宽度的配置。

ShuffleNetV2相比于之前的版本,在保持模型轻量化的同时,提高了模型的准确性和性能。它在计算资源有限的设备上具有较好的应用潜力。

shuffleNetV2这篇论文比较硬核,提出了不少新的思想,推荐大家可以看看论文原文。主要思想包括:

  • 模型的计算复杂度不能只看FLOPs,还需要参考一些其他的指标
  • 作者提出了4条如何设计高效网络的准则
  • 基于该准则提出了新的block设置

FLOPS网上有两种:FLOPS和 FLOPs

FLOPS:全大写,指每秒浮点运算次数,可以理解为计算的速度,是衡量硬件性能的一个指标 (硬件)
FLOPs:s小写,指浮点运算数,理解为计算量,可以用来衡量算法/模型的复杂度,(模型)在论文中常用GFLOPs(1 GFLOPs = 10^9FLOPs)

 ShuffleNetV2网络结构:

 原理图:

其中,a、b为ShuffleNetV1原理图,c、d为ShuffleNetV2原理图。

论文题目:《ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design》

论文地址:  https://arxiv.org/pdf/1807.11164.pdf

代码实现:  GitHub - megvii-model/ShuffleNet-Series 


🚀2.网络结构

本文的改进是基于YOLOv8,关于其网络结构具体如下图所示:

YOLOv8官方仓库地址:

GitHub - ultralytics/ultralytics: NEW - YOLOv8 🚀 in PyTorch > ONNX > OpenVINO > CoreML > TFLite

针对本文的改进,作者将所使用的含有预训练权重文件的YOLOv8完整源码进行了上传,大家可在我的“资源”中自行下载。  


🚀3.添加步骤

针对本文的改进,具体步骤如下所示:👇

步骤1:block.py文件修改

步骤2:__init__.py文件修改

步骤3:tasks.py文件修改

步骤4:创建自定义yaml文件

步骤5:新建train.py文件

步骤6:模型训练测试


🚀4.改进方法

🍀🍀步骤1:block.py文件修改

在源码中找到block.py文件,具体位置是ultralytics/nn/modules/block.py,然后将ShuffleNetV2模块代码添加到block.py文件末尾位置。

ShuffleNetV2模块代码:

# ShuffleNetv2核心代码
# By CSDN 小哥谈
import torch
import torch.nn as nn

def channel_shuffle(x, groups):
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups
    x = x.view(batchsize, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batchsize, -1, height, width)
    return x


class CBRM(nn.Module):  # Conv BN ReLU Maxpool2d
    def __init__(self, c1, c2):
        super(CBRM, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(c1, c2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(c2),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)

    def forward(self, x):
        return self.maxpool(self.conv(x))


class Shuffle_Block(nn.Module):
    def __init__(self, ch_in, ch_out, stride):
        super(Shuffle_Block, self).__init__()

        if not (1 <= stride <= 2):
            raise ValueError('illegal stride value')
        self.stride = stride

        branch_features = ch_out // 2
        assert (self.stride != 1) or (ch_in == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(ch_in, ch_in, kernel_size=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(ch_in),

                nn.Conv2d(ch_in, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )

        self.branch2 = nn.Sequential(
            nn.Conv2d(ch_in if (self.stride > 1) else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),

            self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),

            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out

再然后,在block.py文件最上方下图所示位置加入CBRMShuffle_Block

🍀🍀步骤2:__init__.py文件修改

在源码中找到__init__.py文件,具体位置是ultralytics/nn/modules/__init__.py

修改1:加入CBRMShuffle_Block,具体如下图所示:

修改2:加入CBRMShuffle_Block,具体如下图所示:

🍀🍀步骤3:tasks.py文件修改

在源码中找到tasks.py文件,具体位置是ultralytics/nn/tasks.py

修改1:在下图所示位置导入类名CBRMShuffle_Block

修改2:找到parse_model函数(736行左右),在下图中所示位置添加如下代码。

 # -------ShuffleNetv2------------
        elif m in [CBRM, Shuffle_Block]:
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, c2, *args[1:]]
        # --------------------------------

具体添加位置如下图所示:

🍀🍀步骤4:创建自定义yaml文件

在源码ultralytics/cfg/models/v8目录下创建yaml文件,并命名为:yolov8_ShuffleNetV2.yaml。具体如下图所示:

yolov8_ShuffleNetV2.yaml文件完整代码如下所示:

# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [ -1, 1, CBRM, [ 32 ] ] # 0-P2/4
  - [ -1, 1, Shuffle_Block, [ 128, 2 ] ]  # 1-P3/8
  - [ -1, 3, Shuffle_Block, [ 128, 1 ] ]  # 2
  - [ -1, 1, Shuffle_Block, [ 256, 2 ] ]  # 3-P4/16
  - [ -1, 7, Shuffle_Block, [ 256, 1 ] ]  # 4
  - [ -1, 1, Shuffle_Block, [ 512, 2 ] ]  # 5-P5/32
  - [ -1, 3, Shuffle_Block, [ 512, 1 ] ]  # 6


# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 3], 1, Concat, [1]]  # cat backbone P4
  - [-1, 3, C2f, [512]]  # 9

  - [-1, 1, nn.Upsample, [None, 2, 'nearest']]
  - [[-1, 2], 1, Concat, [1]]  # cat backbone P3
  - [-1, 3, C2f, [256]]  # 12 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 9], 1, Concat, [1]]  # cat head P4
  - [-1, 3, C2f, [512]]  # 15 (P4/16-medium)

  - [-1, 1, Conv, [512, 3, 2]]
  - [[-1, 6], 1, Concat, [1]]  # cat head P5
  - [-1, 3, C2f, [1024]]  # 18 (P5/32-large)

  - [[12, 15, 18], 1, Detect, [nc]]  # Detect(P3, P4, P5)
🍀🍀步骤5:新建train.py文件

在源码根目录下新建train.py文件,文件完整代码如下所示:

from ultralytics import YOLO

# Load a model
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml')  # build a new model from YAML
model = YOLO('yolov8n.pt')  # load a pretrained model (recommended for training)
model = YOLO(r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\models\v8\yolov8_ShuffleNetV2.yaml').load('yolov8n.pt')  # build from YAML and transfer weights

# Train the model
model.train(data=r'C:\Users\Lenovo\PycharmProjects\ultralytics-main\ultralytics\cfg\datasets\helmet.yaml', epochs=100, imgsz=640)

注意:一定要用绝对路径,以防发生报错。

🍀🍀步骤6:模型训练测试

train.py文件,点击“运行”,在作者自制的安全帽佩戴检测数据集上,模型可以正常训练。

模型训练过程: 

模型训练结果: 

 关于本次改进所使用的安全帽佩戴检测数据集,已上传至我的“资源”中,大家可免费下载。


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1523529.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode刷题小记 八、【回溯算法】

1.回溯算法 文章目录 1.回溯算法写在前面1.1回溯算法基本知识1.2组合问题1.3组合问题的剪枝操作1.4组合总和III1.5电话号码的字母组合1.6组合总和1.7组合总和II1.8分割回文串1.9复原IP地址1.10子集问题1.11子集II1.12非递减子序列1.13全排列1.14全排列II1.15N皇后1.16解数独 写…

WanAndroid(鸿蒙版)开发的第三篇

前言 DevEco Studio版本&#xff1a;4.0.0.600 WanAndroid的API链接&#xff1a;玩Android 开放API-玩Android - wanandroid.com 其他篇文章参考&#xff1a; 1、WanAndroid(鸿蒙版)开发的第一篇 2、WanAndroid(鸿蒙版)开发的第二篇 3、WanAndroid(鸿蒙版)开发的第三篇 …

三级等保技术建议书

1信息系统详细设计方案 1.1安全建设需求分析 1.1.1网络结构安全 1.1.2边界安全风险与需求分析 1.1.3运维风险需求分析 1.1.4关键服务器管理风险分析 1.1.5关键服务器用户操作管理风险分析 1.1.6数据库敏感数据运维风险分析 1.1.7“人机”运维操作行为风险综合分析 1.2…

EPICS和Arduino Uno之间基于串行文本协议的控制开发

Arduino Uno的串口服务程序设置如文本的串口通信协议设计以及在Arduino上的应用-CSDN博客中所示。通过在串口上发送约定的文本协议&#xff0c;它实现的功能如下&#xff1a; 实现功能&#xff1a; 读取三路0.0V~5.0V模拟量输入&#xff0c;读取端口A0~A2设置三路0.0V~5.0V的模…

Github 2024-03-16 开源项目日报Top10

根据Github Trendings的统计,今日(2024-03-16统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Python项目5非开发语言项目2TypeScript项目1C++项目1Lua项目1Swift项目1《Hello 算法》:动画图解、一键运行的数据结构与算法教程 创建周期:4…

spring boot3登录开发-微信小程序用户登录设计与实现

⛰️个人主页: 蒾酒 &#x1f525;系列专栏&#xff1a;《spring boot实战》 &#x1f30a;山高路远&#xff0c;行路漫漫&#xff0c;终有归途 目录 写在前面 登录流程 流程解析 具体实现 相关代码 说明 服务端 小程序端 写在最后 写在前面 本文介绍了springb…

BigDL-LLM 安装指南——在iGPU集成显卡下使用BigDL-LLM大模型库加速LLM

文章目录 iGPU是什么&#xff1f;一、环境准备1.1 Visual Studio 2022 Community 安装1.2 安装或更新最新版本的GPU驱动程序1.3 安装英特尔oneAPI工具包2024.0版本1.4 安装Anaconda 二、BigDL -LLM 安装2.1 创建虚拟环境2.2 激活虚拟环境2.3 安装bigdl-llm[xpu] 三、运行环境配…

Etcd 介绍与使用(入门篇)

etcd 介绍 etcd 简介 etc &#xff08;基于 Go 语言实现&#xff0c;&#xff09;在 Linux 系统中是配置文件目录名&#xff1b;etcd 就是配置服务&#xff1b; etcd 诞生于 CoreOS 公司&#xff0c;最初用于解决集群管理系统中 os 升级时的分布式并发控制、配置文件的存储与分…

哔哩哔哩后端Java一面

前言 作者&#xff1a;晓宜 个人简介&#xff1a;互联网大厂Java准入职&#xff0c;阿里云专家博主&#xff0c;csdn后端优质创作者&#xff0c;算法爱好者 最近各大公司的春招和实习招聘都开始了&#xff0c;这里分享下去年面试B站的的一些问题&#xff0c;希望对大家有所帮助…

PLC_博图系列☞基本指令“RESET_BF”复位位域

PLC_博图系列☞基本指令“RESET_BF”复位位域 文章目录 PLC_博图系列☞基本指令“RESET_BF”复位位域背景介绍RESET_BF&#xff1a;复位位域说明类型为 PLC 数据类型、STRUCT 或 ARRAY 的位域参数示例 关键字&#xff1a; PLC、 西门子、 博图、 Siemens 、 RESET_BF 背景…

java 开发工具

新建项目 打开idea选择 New Project 新建一个项目 左边选择 Java项目&#xff0c;右边选择Java版本 接着next 修改项目名称和保存路径&#xff0c;然后点击下面的 Finish 最终页面; 在 src 目录右键&#xff0c;新建一个包 在 src 目录右键&#xff0c;新建java 文件 有时候会需…

Git全套教程一套精通git.跟学黑马笔记

Git全套教程一套精通git.跟学黑马笔记 文章目录 Git全套教程一套精通git.跟学黑马笔记1.版本管理工具概念2. 版本管理工具介绍2.1版本管理发展简史(维基百科)2.1.1 SVN(SubVersion)2.1.2 Git 3. Git 发展简史4. Git 的安装4.1 git 的下载4.2 安装4.3 基本配置4.4 为常用指令配置…

智能工具柜-RFID智能工具柜管理系统

RFID工具柜管理系统是一种便捷化的工具管理系统&#xff0c;它采用RFID技术实现信息化&#xff0c;可以大大提高工具管理的效率和准确性。 日常的工具管理也确实存在一定的管理问题&#xff0c;如工具管理效率低、管理不准确等。因此&#xff0c;采用RFID技术实现信息化已经成…

【深度学习】深度估计,Depth Anything Unleashing the Power of Large-Scale Unlabeled Data

论文标题&#xff1a;Depth Anything Unleashing the Power of Large-Scale Unlabeled Data 论文地址&#xff1a;https://arxiv.org/pdf/2401.10891.pdf 项目主页&#xff1a;https://depth-anything.github.io/ 演示地址&#xff1a;https://huggingface.co/spaces/LiheYoung…

【Elasticsearch】windows安装elasticsearch教程及遇到的坑

一、安装参考 1、安装参考&#xff1a;ES的安装使用(windows版) elasticsearch的下载地址&#xff1a;https://www.elastic.co/cn/downloads/elasticsearch ik分词器的下载地址&#xff1a;https://github.com/medcl/elasticsearch-analysis-ik/releases kibana可视化工具下载…

火车订票管理系统|基于springboot框架+ Mysql+Java+B/S结构的火车订票管理系统设计与实现(可运行源码+数据库+设计文档)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 目录 前台功能效果图 管理员功能登录前台功能效果图 用户功能模块 系统功能设计 数据库E-R图设计 lunwen…

【深度学习目标检测】二十三、基于深度学习的行人检测计数系统-含数据集、GUI和源码(python,yolov8)

行人检测计数系统是一种重要的智能交通监控系统&#xff0c;它能够通过图像处理技术对行人进行实时检测、跟踪和计数&#xff0c;为城市交通规划、人流控制和安全管理提供重要数据支持。本系统基于先进的YOLOv8目标检测算法和PyQt5图形界面框架开发&#xff0c;具有高效、准确、…

[WUSTCTF2020]颜值成绩查询 --不会编程的崽

这题也是一个很简单的盲注题目&#xff0c;这几天sql与模板注入做麻了&#xff0c;也是轻松拿捏。 它已经提示&#xff0c;enter number&#xff0c;所有猜测这里后台代码并没有使用 " 闭合。没有明显的waf提示&#xff0c; 但是or&#xff0c;and都没反应。再去fuzz一…

C++17之std::variant

1. std::variant操作 如下列出了为std:: variable <>提供的所有操作。

Spring Boot整合STOMP实现实时通信

目录 引言 代码实现 配置类WebSocketMessageBrokerConfig DTO 工具类 Controller common.html stomp-broadcast.html 运行效果 完整代码地址 引言 STOMP&#xff08;Simple Text Oriented Messaging Protocol&#xff09;作为一种简单文本导向的消息传递协议&#xf…