昇思学习打卡-10-ShuffleNet图像分类

news2024/10/6 2:47:10

文章目录

  • 网络介绍
  • 网络结构
    • 部分实现
    • 对应网络结构
  • 模型训练
  • shuffleNet的优缺点总结
    • 优点
    • 不足

网络介绍

ShuffleNet主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。ShuffleNet在保持不低的准确率的前提下,将参数量几乎降低到了最小,因此其运算速度较快,单位参数量对模型准确率的贡献非常高。

网络结构

部分实现

class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x

对应网络结构

在这里插入图片描述

模型训练

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]

    print("============== Starting Training ==============")
    start_time = time.time()
    # 由于时间原因,epoch = 5,可根据需求进行调整
    model.train(5, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

if __name__ == '__main__':
    train()

在这里插入图片描述

shuffleNet的优缺点总结

优点

  • 轻量化设计:ShuffleNet通过减少参数和计算量,使得模型更适合在资源受限的设备上运行。
  • 高效的通道重排:ShuffleNet引入了一种高效的通道重排机制,称为"Shuffle",以提高模型的表达能力,同时保持参数数量较低。
  • 性能与速度的平衡:ShuffleNet在保持较低延迟的同时,提供了与较大模型相当的准确性,这使得它在实时应用中非常有用。
  • 适用于移动设备:由于其轻量化的特性,ShuffleNet非常适合在移动设备和嵌入式系统中使用。
  • 参数共享:ShuffleNet通过在网络中使用参数共享技术,进一步减少了模型的大小和计算成本。
  • 灵活性:ShuffleNet提供了不同版本的模型,允许研究者和开发者根据特定应用的需求选择合适的模型大小。

不足

  • 准确率折损:由于模型的轻量化设计,ShuffleNet在某些复杂任务上可能无法达到最高精度。
  • 特定任务的局限性:在一些需要非常高精度或对模型容量要求较高的任务中,ShuffleNet可能不是最佳选择。
  • 对输入尺寸敏感:ShuffleNet的性能可能对输入图像的尺寸较为敏感,这可能需要在设计网络时进行额外的考虑。

此章节学习到此结束,感谢昇思平台。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1904011.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

20、matlab信号波形生成:狄利克雷函数、高斯脉冲和高斯脉冲序列

1、名词说明 狄利克雷函数(Dirac Delta Function) 狄利克雷函数,也称为单位冲激函数或δ函数,是一个在数学和信号处理中常用的特殊函数。狄利克雷函数通常用符号δ(t)表示,其定义为: δ(t) { ∞, t 0{…

美股交易相关知识点 持续完善中

美股交易时间 美东时间:除了凌晨 03:50 ~ 04:00 这10分钟时间不可交易以外,其他时间都是可以交易的。 如果是在香港或者北京时间下交易要区分两种: 美东夏令时:除了下午 15:50 ~ 16:00 这10分钟时间不可交易以外,其他时间都是可…

springboot公寓租赁系统-计算机毕业设计源码03822

摘要 1 绪论 1.1 研究背景与意义 1.2选题背景 1.3论文结构与章节安排 2 公寓租赁系统系统分析 2.1 可行性分析 2.1.1 技术可行性分析 2.1.2 经济可行性分析 2.1.3 法律可行性分析 2.2 系统功能分析 2.2.1 功能性分析 2.2.2 非功能性分析 2.3 系统用例分析 2.4 系…

GRPC使用之ProtoBuf

1. 入门指导 1. 基本定义 Protocol Buffers提供一种跨语言的结构化数据的序列化能力,类似于JSON,不过更小、更快,除此以外它还能用用接口定义(IDL interface define language),通protoc编译Protocol Buffer定义文件,…

拆分Transformer注意力,韩国团队让大模型解码提速20倍|大模型AI应用开始小规模稳步爆发|周伯文:大模型也有幻觉,全球AI创新指数公布

拆分Transformer注意力,韩国团队让大模型解码提速20倍AI正在颠覆AI上市不到两年,蜗牛游戏可能要退市了?世界人工智能大会结束了,百花齐放,但也群魔乱舞“串联OLED”被苹果带火了,比OLED强在哪里&#xff1f…

赚钱小思路,送给没有背景的辛辛苦苦努力的我们!

我是一个没有背景的普通人,主要靠勤奋和一股钻劲,这十几年来我的日常作息铁打不变,除了睡觉,不是在搞钱,就是在琢磨怎么搞钱。 ​ 可以说打拼了十几年,各种小生意都做过,以前一直是很乐观的&…

SSM养老院管理系统-计算机毕业设计源码02221

摘要 本篇论文旨在设计和实现一个基于SSM的养老院管理系统,旨在提供高效、便捷的养老院管理服务。该系统将包括老人档案信息管理、护工人员管理、房间信息管理、费用管理等功能模块,以满足养老院管理者和居民的不同需求。 通过引入SSM框架&#x…

动手学深度学习(Pytorch版)代码实践 -循环神经网络-54循环神经网络概述

54循环神经网络概述 1.潜变量自回归模型 使用潜变量h_t总结过去信息 2.循环神经网络概述 ​ 循环神经网络(recurrent neural network,简称RNN)源自于1982年由Saratha Sathasivam 提出的霍普菲尔德网络。循环神经网络,是指在全…

批量爬取B站网络视频信息

使用XPath爬取B站视频链接等相关信息 分析B站html框架获取内容完整代码 对于B站,目前网上的爬虫大多都是使用通过解析服务器的响应来爬取想要的内容,下面我们通过使用XPath来爬取B站上一些想要的信息 此次任务我们需要对B站搜索到的关键字,并…

Linux系统安装软件包的方法rpm和yum详解

起因: 本篇文章是记录学习Centos7的历程 关于rpm 常见命令 1)查看已经安装的软件包 rpm -q 软件包名 2)查看文件的相关信息 rpm -qi 软件包名 3)查看软件包的依赖关系 就是说要想安装这个软件包,就必须把一些前…

记录一次ffmpeg手动编译出现的问题

前言部分 使用环境: ubuntu 22.04 最近手动编译了一次的ffmpeg(参考博客ffmpeg学习:ubuntu下编译ffmpeg(全网最懒的编译脚本)),但是过程出现了一些问题,因此在此记录一下,若有疑问,欢迎讨论~。 …

15集终于编译成功了-了个球!编译TFLite Micro语音识别工程-《MCU嵌入式AI开发笔记》

15集终于编译成功了-个球!编译TFLite Micro语音识别工程-《MCU嵌入式AI开发笔记》 还是参考这个官方文档: https://codelabs.developers.google.cn/codelabs/sparkfun-tensorflow#2 全是干货! 这里面提到的这个Micro工程已经移开了&#xff1…

Overleaf :LaTeX协作神器!【送源码】

Overleaf 是一个广受欢迎的在线 LaTeX 编辑器,专为学术写作和文档排版设计。它以其协作功能和用户友好的界面而闻名,使得 LaTeX 编辑变得更加容易和直观。 软件介绍 Overleaf 提供了一个基于云的 LaTeX 编辑环境,支持实时协作,使得…

哲讯SAP知识分享:SAP资产模块常用事务代码清单

在当今日益复杂的商业环境中,企业对于资产管理的需求日益增强。SAP作为全球领先的企业管理软件提供商,其资产模块(AM)以其高效、灵活的特性,为企业提供了全面的资产管理解决方案。本文将对SAP资产事务类型进行详细介绍…

算法的空间复杂度(C语言)

1.空间复杂度的定义 算法在临时占用储存空间大小的量度(就是完成这个算法所额外开辟的空间),空间复杂度也使用大O渐进表示法来表示 注: 函数在运行时所需要的栈空间(储存参数,局部变量,一些寄存器信息等)…

MySQL第三天作业

一、在数据库中创建一个表student,用于存储学生信息 CREATE TABLE student( id INT PRIMARY KEY, name VARCHAR(20) NOT NULL, grade FLOAT ); 1、向student表中添加一条新记录 记录中id字段的值为1,name字段的值为"monkey"…

STM32第十六课:WiFi模块的配置及应用

文章目录 需求一、WiFi模块概要二、配置流程1.配置通信串口,引脚和中断2.AT指令3.发送逻辑编写 三、需求实现代码总结 需求 完成WiFi模块的配置,使其最终能和服务器相互发送消息。 一、WiFi模块概要 本次使用的WiFi模块为ESP-12F模块(安信可&#xf…

字符串——string类的常用接口

一、string类对象的常见构造 二、string类对象的容量操作 三、string类对象的访问及遍历操作 四、string类对象的修改操作 一、string类对象的常见构造 1.string() ——构造空的string类对象,也就是空字符串 2.string(const char* s) ——用字符串来初始化stri…

Win10如何设置远程桌面?

远程桌面介绍 远程桌面是一款Windows提供的远程工具,旨在连接同一局域网内的两台计算机。如果您掌握被控端电脑的IP地址,便可直接连接到这台已启用远程桌面的计算机,通过远程桌面进行文件传输或提供远程技术支持。 在同一家公司内&#xff0…

JVM专题之垃圾收集器

JVM参数 3.1.1 标准参数 -version -help -server -cp 3.1.2 -X参数 非标准参数,也就是在JDK各个版本中可能会变动 ``` -Xint 解释执行 -Xcomp 第一次使用就编译成本地代码 -Xmixed 混合模式,JVM自己来决定 3.1.3 -XX参数 > 使用得最多的参数类型 > > 非…