ShuffleNet系列 网络结构

news2025/1/18 13:53:33

文章目录

  • ShuffleNet V1
    • Channel Shuffle:通道打散
    • SuffleNet Unit
    • Model Architecture
    • 实验结果
  • ShuffleNet V2
    • Guideline 1
    • Guideline 2
    • Guideline 3
    • Guideline 4
    • 模型结构
    • 代码

论文:ShuffleNet:
ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices
论文:ShuffleNet V2:
ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design

ShuffleNet V1

Channel Shuffle:通道打散

在这里插入图片描述

SuffleNet Unit

在这里插入图片描述

Model Architecture

在这里插入图片描述

实验结果

在这里插入图片描述

ShuffleNet V2

  • 计算速度,5 个轻量级网络要领

Guideline 1

  • 卷积输入和输出是相同的通道数时,将最小化内存访问成本,每秒处理图片最多
    在这里插入图片描述

Guideline 2

  • 过多的分组卷积会加大内存访问成本,导致训练速度下降,时间上升,但是参数量更低,准确率更高。
    在这里插入图片描述

Guideline 3

  • 碎片操作减少网络的并行度,碎片操作:将大的卷积分成小卷积操作进行。
    在这里插入图片描述

在这里插入图片描述

Guideline 4

  • 不要忽略元素级操作,元素级指的是 Relu,TensorAdd,BiasAdd等等的矩阵元素操作,可以推测这些操作基本没有被算到 FLOPs 中,但是对内存访问成本(MAC)的影响比较大。
    在这里插入图片描述
    在这里插入图片描述

模型结构

在这里插入图片描述
在这里插入图片描述

代码

from typing import List, Callable

import torch
from torch import Tensor
import torch.nn as nn
 
def channel_shuffle(x: Tensor, groups: int):

    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    # [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
    x = x.view(batch_size, groups, channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batch_size, -1, height, width)

    return x 

class InvertedResidual(nn.Module):
    def __init__(self, input_c: int, output_c: int, stride: int):
        super(InvertedResidual, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride

        assert output_c % 2 == 0
        branch_features = output_c // 2
        # 当stride为1时,input_channel应该是branch_features的两倍
        # python中 '<<' 是位运算,可理解为计算×2的快速方法
        assert (self.stride != 1) or (input_c == branch_features << 1)

        if self.stride == 2:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def depthwise_conv(input_c: int,output_c: int,kernel_s: int,stride: int = 1, padding: int = 0,bias: bool = False):
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x: Tensor):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        out = channel_shuffle(out, 2)

        return out


class ShuffleNetV2(nn.Module):
    def __init__(self, stages_repeats: List[int],stages_out_channels: List[int],num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # input RGB image
        input_channels = 3
        output_channels = self._stage_out_channels[0]

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations 
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential

        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats, self._stage_out_channels[1:]):
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor):
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global pool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor):
        return self._forward_impl(x)


"""
Constructs a ShuffleNetV2 with 1.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`.
weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

:param num_classes:
:return:
"""
def shufflenet_v2_x1_0(num_classes=1000):

    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 116, 232, 464, 1024],
                         num_classes=num_classes)

    return model

"""
Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
<https://arxiv.org/abs/1807.11164>`.
weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth

:param num_classes:
:return:
"""
def shufflenet_v2_x0_5(num_classes=1000):

    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 48, 96, 192, 1024],
                         num_classes=num_classes)

    return model

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1187348.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2023年的低代码:数字化、人工智能、趋势及未来展望

本文由葡萄城技术团队发布。转载请注明出处&#xff1a;葡萄城官网&#xff0c;葡萄城为开发者提供专业的开发工具、解决方案和服务&#xff0c;赋能开发者。 前言 正如许多专家预测的那样&#xff0c;低代码平台在2023年将展现更加强劲的势头。越来越多的企业正在纷纷转向低代…

ArcGIS 气象风场等示例 数据制作、服务发布及前端加载

1. 原始数据为多维数据 以nc数据为例。 首先在pro中需要以多维数据的方式去添加多维数据&#xff0c;这里的数据包含uv方向&#xff1a; 加载进pro的效果&#xff1a; 这里注意 数据属性需要为矢量uv&#xff1a; 如果要发布为服务&#xff0c;需要导出存储为tif格式&…

spring 中 @Validated/@Valid

超级好的链接 添加链接描述

Vue实现面经基础版案例(路由+组件缓存)

一、面经基础版-案例效果分析 1.面经效果演示 2.功能分析 通过演示效果发现&#xff0c;主要的功能页面有两个&#xff0c;一个是列表页&#xff0c;一个是详情页&#xff0c;并且在列表页点击时可以跳转到详情页底部导航可以来回切换&#xff0c;并且切换时&#xff0c;只有…

AI:69-基于深度学习的音乐推荐

🚀 本文选自专栏:AI领域专栏 从基础到实践,深入了解算法、案例和最新趋势。无论你是初学者还是经验丰富的数据科学家,通过案例和项目实践,掌握核心概念和实用技能。每篇案例都包含代码实例,详细讲解供大家学习。 📌📌📌在这个漫长的过程,中途遇到了不少问题,但是…

JavaScript脚本操作CSS

脚本化CSS就是使用JavaScript脚本操作CSS&#xff0c;配合HTML5、Ajax、jQuery等技术&#xff0c;可以设计出细腻、逼真的页面特效和交互行为&#xff0c;提升用户体验&#xff0c;如网页对象的显示/隐藏、定位、变形、运动等动态样式。 1、CSS脚本化基础 CSS样式有两种形式&…

OpenCV 在ImShow窗体上选择感兴趣的区域

窗体上选择感兴趣ROI区域 在计算机视觉处理中, 通常是针对图像中的一个特定区域进行处理, 有时候这个特定区域需要人来选择, OpenCV 也提供了窗口选择ROI机制. 窗体支持两种选择ROI区域的方法, 一个是单选, 一个是多选, 操作方法如下: 单选: 通过鼠标在屏幕上选择区域, 然后通过…

【Linux系统编程十六】:(基础IO3)--用户级缓冲区

【Linux系统编程十六】&#xff1a;基础IO3--用户级缓冲区 一.用户级缓冲区二.缓冲区刷新策略1.验证&#xff1a; 三.缓冲区意义 一.用户级缓冲区 我们首先理解上面的代码&#xff0c;分别使用printf和fprintf&#xff0c;fwrite往1号文件描述符里输出&#xff0c;也就是往显示…

论文阅读——InternImage(cvpr2023)

arxiv&#xff1a;https://arxiv.org/abs/2211.05778 github&#xff1a;https://github.com/OpenGVLab/InternImage 一、介绍 大部分大模型都是基于transformer的&#xff0c;本文是一个基于CNN的视觉基础模型。使用可变性卷积deformable convolution作为核心操作&…

docker复制镜像文件

一、复制镜像 #1. 查找本机已有的镜像docker images |grep xxxx#2. 将镜像复制出来指向到xxxx.tar的文件中 docker save 343cca04e31d > xxxx.tareg: 二、加载镜像 直接将拷贝好的镜像包直接加载即可 docker load < myimage.tar

【C++】一文简练总结【多态】及其底层原理&具体应用(21)

前言 大家好吖&#xff0c;欢迎来到 YY 滴C系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过C的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; 目录 一.多态的概念二.多态的实现1&#xff…

Codeforces Round 908 (Div. 2)视频详解

Educational Codeforces Round 157 &#xff08;A--D&#xff09;视频详解 视频链接A题代码B题代码C题代码D题代码 视频链接 Codeforces Round 908 (Div. 2)视频详解 A题代码 #include<bits/stdc.h> #define endl \n #define deb(x) cout << #x << "…

mac M2 anaconda 解决装不了python3.7

今天发现一个很奇怪的问题 但是我一换成 conda create -n DCA python3.8.12就是成功的 这个就很奇怪, 解决如下 https://towardsdatascience.com/how-to-manage-conda-environments-on-an-apple-silicon-m1-mac-1e29cb3bad12 998 conda search pythonconda search python …

C++之函数中实现类与调用总结(二百五十四)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 优质专栏&#xff1a;Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 人生格言&#xff1a; 人生…

SpringMVC使用AOP监听方法推送数据

导入aop的maven依赖 <dependency><groupId>org.aspectj</groupId><artifactId>aspectjweaver</artifactId><version>1.6.12</version> </dependency>创建一个spring的XML文件编写aop配置 <?xml version"1.0" …

基于springboot的二次元商品销售网站的设计与开发

大家好我是玥沐春风&#xff0c;今天分享一个基于springboot的二次元商品销售网站的设计与开发&#xff0c;项目源码以及部署相关请联系我&#xff0c;文末附上联系信息 。 开发工具及技术 2.3.1 Spring Boot框架 SpringBoot是一个全新的开源的轻量级框架。简化了Spring应用的…

Linux中的权限

目录 一、shell命令以及运行原理 二、Linux权限的概念 三、权限的八进制表示 四、修改文件的拥有者和所属组 五、权限常见的问题 1、目录的权限 2、umask 3、粘滞位 一、shell命令以及运行原理 首先&#xff0c;我们先来看看这个问题&#xff1a;我们使用命令是直接操…

electron安装报错:Electron failed to install correctly...解决方案

问题描述&#xff1a; 按照官方文档在yarn dev时报错&#xff1a; 一般遇到Electron failed to install correctly&#xff0c;please delete node_moules/electron and try installing again这种错误时&#xff0c;就是electron本体没有下载成功 解决方案&#xff1a; 1、…

GaN HEMT 电容的分析建模,包括寄生元件

标题&#xff1a;Analytical Modeling of Capacitances for GaN HEMTs, Including Parasitic Components 来源&#xff1a;IEEE TRANSACTIONS ON ELECTRON DEVICES&#xff08;14年&#xff09; 摘要&#xff1a;本文提出了一种基于表面势的终端电荷和电容模型&#xff0c;包…