ShuffleNetV2 结构(附源码)

news2024/11/24 23:01:36

本文不细看paper,只看网络结构和源码实现。

看下ShuffleNetV2的结构吧。
在这里插入图片描述
image是3通道进去,经过conv1和maxpool,
然后stage2~4则是主题,里面stride = 2和 stride = 1的shuffleBlock分别重复几次。

shuffleBlock如下,左边是stride = 1的,右边是stride = 2的。
举个栗子,stage2的in_channel为24, out_channel为116,
每个block是有2个branch的,这个channel要分配一下,比如左边右边各58,经过最后的Concat, 就是116.
stride = 2时channel会加倍,stride = 1时channel不变。
后面代码里会看到。
在这里插入图片描述

class ShuffleNetV2(nn.Module):
    def __init__(
        self,
        model_size="1.5x",
        out_stages=(2, 3, 4),
        with_last_conv=False,
        kernal_size=3,
        activation="ReLU",
        pretrain=True,
    ):
        super(ShuffleNetV2, self).__init__()
        # out_stages can only be a subset of (2, 3, 4)
        assert set(out_stages).issubset((2, 3, 4))

        print("model size is ", model_size)  #1.0x

        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        self.out_stages = out_stages
        self.with_last_conv = with_last_conv #False
        self.kernal_size = kernal_size
        self.activation = activation #LeakyReLU
        if model_size == "0.5x":
            self._stage_out_channels = [24, 48, 96, 192, 1024]
        elif model_size == "1.0x":
            self._stage_out_channels = [24, 116, 232, 464, 1024] 
        elif model_size == "1.5x":
            self._stage_out_channels = [24, 176, 352, 704, 1024]
        elif model_size == "2.0x":
            self._stage_out_channels = [24, 244, 488, 976, 2048]
        else:
            raise NotImplementedError

        # building first layer
        input_channels = 3
        output_channels = self._stage_out_channels[0]  #24
        #conv3x3,s=2.3->24,p=1,BN,LeakyReLU
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
            nn.BatchNorm2d(output_channels),
            act_layers(activation),
        )
        input_channels = output_channels  #24

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        stage_names = ["stage{}".format(i) for i in [2, 3, 4]] #paper中的stage2~4
        #zip是把几个数组的元素打包,以最短的数组为基准
        for name, repeats, output_channels in zip(
            stage_names, self.stage_repeats, self._stage_out_channels[1:]
        ):
        #看paper中的表格,stride=2的repeat一次,stride=1的repeat多少次
            seq = [
                ShuffleV2Block(
                    input_channels, output_channels, 2, activation=activation
                )
            ]
            for i in range(repeats - 1):
                seq.append(
                    ShuffleV2Block(
                        output_channels, output_channels, 1, activation=activation
                    )
                )
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels
        output_channels = self._stage_out_channels[-1]
        if self.with_last_conv:
            conv5 = nn.Sequential(
                nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
                nn.BatchNorm2d(output_channels),
                act_layers(activation),
            )
            self.stage4.add_module("conv5", conv5)
        self._initialize_weights(pretrain)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)
        output = []
        for i in range(2, 5):
            stage = getattr(self, "stage{}".format(i))
            x = stage(x)
            if i in self.out_stages:
                output.append(x)
        return tuple(output)

ShuffleNetV2 block

class ShuffleV2Block(nn.Module):
    def __init__(self, inp, oup, stride, activation="ReLU"):
        super(ShuffleV2Block, self).__init__()

        if not (1 <= stride <= 3):
            raise ValueError("illegal stride value")
        self.stride = stride

        branch_features = oup // 2  #每个branch分配一半的channel
        assert (self.stride != 1) or (inp == branch_features << 1)

        if self.stride > 1:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(
                    inp, inp, kernel_size=3, stride=self.stride, padding=1
                ),
                nn.BatchNorm2d(inp),
                nn.Conv2d(
                    inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False
                ),
                nn.BatchNorm2d(branch_features),
                act_layers(activation),
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(
                inp if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            act_layers(activation),
            self.depthwise_conv(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
            ),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(branch_features),
            act_layers(activation),
        )

    @staticmethod
    def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
        return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)

    def forward(self, x):
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1) #在dim=1(channel)上分成2块
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out

channel shuffle在这篇文章里面说过。
channel数转为矩阵,矩阵转置再压平。

在这里插入图片描述

def channel_shuffle(x, groups):
    # type: (torch.Tensor, int) -> torch.Tensor
    batchsize, num_channels, height, width = x.data.size()
    channels_per_group = num_channels // groups

    # reshape
    #channel变为groups x channels_per_group
    x = x.view(batchsize, groups, channels_per_group, height, width)

	#转置为channels_per_group x groups
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batchsize, -1, height, width)

    return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/86789.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

搭建Kubord管理k8s/EKS以及Harbor私有仓库教程

eks首先要去aws后台进行创建&#xff0c;这里不再讲解详细的过程&#xff0c;下面讲解如果通过命令行以及kuboard调度esk服务。 安装docker以及docker-compose yum install docker service docker start curl https://get.daocloud.io/docker/compose/releases/download/1.24…

零食商城小程序开发,建立商家良好品牌形象

相信很多人都无法拒绝来自零食的诱惑&#xff0c;尤其是在闲暇刷剧时&#xff0c;一边看剧一边享受着味蕾的满足&#xff0c;简直不要太幸福。现在人们对于零食的要求越来越高&#xff0c;不仅注重口感&#xff0c;更讲究包装&#xff0c;这就让零食行业逐渐走向精细化。而零食…

ssm+Vue计算机毕业设计校园统一网络授课平台(程序+LW文档)

ssmVue计算机毕业设计校园统一网络授课平台&#xff08;程序LW文档&#xff09; 项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项…

SpringMVC-狂神

SpringMVC优点&#xff1a; 轻量级&#xff0c;简单易学 高效&#xff0c;基于请求响应的MVC框架 与Spring无缝结合 功能强大&#xff1a;RESTful风格&#xff0c;数据验证&#xff0c;格式化&#xff0c;本地化&#xff0c;主题等 简单灵活 SpringMVC全部围绕DispatchSer…

AI(人工智能),时代的风口

你知道AI并非一个新词吗&#xff1f; 你知道 AI 正在影响着包括数学、物理学、生命科学等诸多领域的前沿科学研究吗&#xff1f; “AI是一个具有魅力的词&#xff0c;也是一个很古老的词”。 我们通常所说的AI &#xff08;Artificial intelligence&#xff09; 翻译为“人工…

安卓玩机搞机技巧综合资源-----不亮屏幕导资料 有屏幕锁保数据刷机等 多种方式【十五】

接上篇 安卓玩机搞机技巧综合资源------如何提取手机分区 小米机型代码分享等等 【一】 安卓玩机搞机技巧综合资源------开机英文提示解决dm-verity corruption your device is corrupt. 设备内部报错 AB分区等等【二】 安卓玩机搞机技巧综合资源------EROFS分区格式 小米红…

C#打开摄像头后获取图片,调用face_recognition进行人脸识别

运行效果如截图&#xff1a;左边和保存的图片做对比&#xff0c;打印相似度&#xff0c;部分打印内容为python中的打印输出&#xff0c;可以用来做结果判断。右边打开摄像头后&#xff0c;可以单张图片进行人脸识别&#xff0c;或者一直截图镜头中的图片进行比对。期中python是…

ReSharper添加对最新C#11特性的支持

ReSharper添加对最新C#11特性的支持 C#11 UTF-8文字-增加了对UTF-8文字的基本支持。代码分析现在建议对文字使用u8后缀&#xff0c;而不是System.Text.Encoding.UTF8.GetBytes()方法或具有适当UTF8符号的字节数组。还有一组UTF-8文本的编译器警告和错误。 文件本地类型-添加了对…

服务器公网带宽1M能同时接受多少人访问?

文章目录1、什么是服务器的带宽?2、服务器带宽多少?3、服务器带宽1M能同时接受多少人访问?1、什么是服务器的带宽? 在服务器托管中&#xff0c;服务器带宽指在特定时间段从或向网站/服务器传输的数据量&#xff0c;例如&#xff0c;单月内的累积消耗“带宽”&#xff0c;实…

【开源掌机】百问网DShanMCU-Mio开源掌机(爻-澪)项目,完美支持运行10多个模拟器!

众筹说明 定金翻倍&#xff0c;即定金19.9元&#xff0c;在付尾款时可抵40元(成品售价不会超过120元)&#xff01;达标当天就开搞&#xff0c;满100人加速搞尽量在年前发货&#xff0c;让大家先玩起来&#xff01;如果不达标则原路退款&#xff0c;项目取消。 众筹时间&#…

利用Matlab进行图像分割和边缘检测

本文章包含以下内容&#xff1a; 1、灰度阀值分割 (1)单阈值分割图像 先将一幅彩色图像转换为灰度图像&#xff0c;显示其直方图&#xff0c;参考直方图中灰度的分布&#xff0c;尝试确定阈值&#xff1b;应反复调节阈值的大小&#xff0c;直至二值化的效果最为满意…

LDR6035PD快充快放带数据还要啥莲花清翁

随着Type-C的普及和推广&#xff0c;目前市面上的移动电源正在慢慢淘汰micro-USB接口&#xff0c;逐渐都更新成了Type-C接口&#xff0c;micro-USB接口从2007年上市&#xff0c;已经陪伴我们走过十多个年头&#xff0c;自从2015年Type-C登场&#xff0c;micro-USB也开始渐渐淡出…

写给前端开发者的「Promise备忘手册」

前言 大家好&#xff0c;我是HoMeTown&#xff0c;Promise想必大家都知道&#xff0c;在平时的开发工程中也经常会有用到&#xff0c;但是Promise作为ES6的重要特性&#xff0c;其实还拥有很多丰富的知识&#xff0c;本文面向比较初级一些的同学&#xff0c;可以帮你搞懂Promi…

金庸群侠传3DUnity重置入门-Mods开发

金庸3DUnity重置入门系列文章 金庸3dUnity重置入门 - lua 语法 金庸3dUnity重置入门 - UniTask插件 金庸3dUnity重置入门 - Mods开发 金庸3dUnity重置入门 - Cinemachine 动画 金庸3dUnity重置入门 - 大世界实现方案 金庸3dUnity重置入门 - 素材极限压缩 (部分可能放到付…

[附源码]Nodejs计算机毕业设计基于web的社团管理系统Express(程序+LW)

该项目含有源码、文档、程序、数据库、配套开发软件、软件安装教程。欢迎交流 项目运行 环境配置&#xff1a; Node.js Vscode Mysql5.7 HBuilderXNavicat11VueExpress。 项目技术&#xff1a; Express框架 Node.js Vue 等等组成&#xff0c;B/S模式 Vscode管理前后端分…

机器学习——01基础知识

机器学习——01基础知识 github地址&#xff1a;https://github.com/yijunquan-afk/machine-learning 参考资料 [1] 庞善民.西安交通大学机器学习导论2022春PPT [2] 周志华. 机器学习.北京:清华大学出版社,2016 [3] AIlearning 一、机器学习算法的应用 目前&#xff0c;机…

【Redis】集合Set和底层实现

文章目录Redis 集合(Set)Set简介常用命令应用场景共同关注实例整数集合整数集合介绍整数集合的升级哈希表哈希表的原理和实现Redis中的哈希表rehash渐进式rehashRedis 集合(Set) Set简介 Redis set对外提供的功能与list类似是一个列表的功能&#xff0c;特殊之处在于set是可以…

多维时序 | MATLAB实现GRU多变量时间序列预测

✅作者简介&#xff1a;热爱科研的Matlab仿真开发者&#xff0c;修心和技术同步精进&#xff0c;matlab项目合作可私信。 &#x1f34e;个人主页&#xff1a;Matlab科研工作室 &#x1f34a;个人信条&#xff1a;格物致知。 更多Matlab仿真内容点击&#x1f447; 智能优化算法 …

c语言中fread,fgets等取文件字符的缓存空间小出现问题

一种奇怪现象 #include <stdio.h> #include <stdlib.h> #include<windows.h>int main(void){int i;printf("hello\n");fflush(stdout); //当没有这部刷新&#xff0c;hello会和end等到时间一起输出Sleep(2000); //windowsa.h中的Sleep&#…

某研究生不写论文竟研究起了算命?

起因 大约一个月前&#xff0c;在学校大病一场&#xff08;不知道是不是&#x1f411;了&#xff0c;反正在学校每天核酸没检测出来&#xff09;在宿舍休息了整整一周。当时因为发烧全身疼所以基本一直躺着刷刷视频。看了一周倪海厦老师讲的天纪&#xff0c;人纪感悟颇多&…