昇思25天学习打卡营第15天|应用实践之ShuffleNet图像分类

news2025/1/25 9:12:02

基本介绍

         今天的应用实践的领域是计算机视觉领域,更确切的说是图像分类任务,不过,与昨日不同的是,今天所使用的模型是ShuffleNet模型。ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。今天会简单介绍一些ShuffleNet模型,并使用CIFAR-10数据集进行训练与评估,最后进行模型预测

ShuffleNet模型简介

        ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution和Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。因此,ShuffleNetV1和MobileNet类似,都是通过设计更高效的网络结构来实现模型的压缩和加速

  • Pointwise Group Convolution

Group Convolution(分组卷积)原理如下图所示,相比于普通的卷积操作,分组卷积的情况下,每一组的卷积核大小为in_channels/g*k*k,一共有g组,所有组共有(in_channels/g*k*k)*out_channels个参数,是正常卷积参数的1/g。分组卷积中,每个卷积核只处理输入特征图的一部分通道,其优点在于参数量会有所降低,但输出通道数仍等于卷积核的数量

  • Channel Shuffle

        Group Convolution的弊端在于不同组别的通道无法进行信息交流,堆积GConv层后一个问题是不同组之间的特征图是不通信的,这就好像分成了g个互不相干的道路,每一个人各走各的,这可能会降低网络的特征提取能力。这也是Xception,MobileNet等网络采用密集的1x1卷积(Dense Pointwise Convolution)的原因。为了解决不同组别通道“近亲繁殖”的问题,ShuffleNet优化了大量密集的1x1卷积(在使用的情况下计算量占用率达到了惊人的93.4%),引入Channel Shuffle机制(通道重排)。这项操作直观上表现为将不同分组通道均匀分散重组,使网络在下一层能处理不同组别通道的信息。

以上两个结构就是ShuffleNet的主要结构,ShuffleNet的模型代码(MindSpore版)如下:

class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

数据集准备

        采用CIFAR-10数据集对ShuffleNet进行预训练。CIFAR-10共有60000张32*32的彩色图像,均匀地分为10个类别,其中50000张图片作为训练集,10000图片作为测试集。可直接使用mindspore.dataset.Cifar10Dataset接口下载并加载CIFAR-10的训练集。这部分的操作和昨天几乎一样,就不进行展示

模型训练与评估

        采用随机初始化的参数做预训练。首先调用ShuffleNetV1定义网络,参数量选择"2.0x",并定义损失函数为交叉熵损失,学习率经过4轮的warmup后采用余弦退火,优化器采用Momentum,总共训练5轮。最后用train.model中的Model接口将模型、损失函数、优化器封装在model中,并用model.train()对网络进行训练。将ModelCheckpointCheckpointConfigTimeMonitorLossMonitor传入回调函数中,将会打印训练的轮数、损失和时间,并将ckpt文件保存在当前目录下。具体训练代码如下:

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]

    print("============== Starting Training ==============")
    start_time = time.time()
    # 由于时间原因,epoch = 5,可根据需求进行调整
    model.train(5, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

评估的时候直接使用model.eval()进行评估,具体代码如下:

def test():
    mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')

模型预测

        训练完毕则可进行模型预测,并将预测结果可视化,结果如下:

可以看出,shuffleNet效果还是不错的,在轻量化的前提下也保证了一定的精度。

Jupyter运行情况

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1909234.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

柳永,市井生活的吟游者

柳永,原名柳三变,字景庄,后改名为柳永,字耆卿,约生于宋太宗雍熙元年(公元984年),卒于宋仁宗皇祐五年(公元1053年),享年69岁。他是北宋著名词人&am…

最近换工作的一些启示,清华学姐篇

最近更新频率慢下来了,一部分原因是沉迷运动不能自拔,还有一部分原因是业余分出来很大的精力来拓展个人的边界,希望在工作之外取得一些成绩,写作上耽误了不少,很难做到日更。 所以整体上今年更新频率较低,但…

重载、覆盖(重写)、重定义(同名隐藏)的区别 (超详解)

📚 重载(Overloading)、覆盖(Overriding)、重定义(Hiding)是面向对象编程中常见的概念,它们分别用于描述不同情况下函数或方法的行为。 目录 重载(Overloading&#xff…

Zynq系列FPGA实现SDI视频编解码+图像缩放+多路视频拼接,基于GTX高速接口,提供8套工程源码和技术支持

目录 1、前言工程概述免责声明 2、相关方案推荐本博已有的 SDI 编解码方案本博已有的FPGA图像缩放方案本方案的无缩放应用本方案在Xilinx--Kintex系列FPGA上的应用 3、详细设计方案设计原理框图SDI 输入设备Gv8601a 均衡器GTX 解串与串化SMPTE SD/HD/3G SDI IP核BT1120转RGB自研…

一个简单的 Vue 组件例子

https://andi.cn/page/621509.html

17.分频器设计拓展练习-任意分频通用模块

(1)Verilog代码: module divider_n(clk,reset_n,clk_out);input clk;input reset_n;output clk_out;wire clk_out1;wire clk_out2;wire [9:0]n;wire m;assign n 9;assign m n % 2;divider_even divider_even_inst(.clk(clk),.reset_n(reset_n),.n(n),.en(!m),.cl…

多租户hive数仓

1、概念 多租户对应的是单租户,本篇文章重点讲解多租户,单租户为了解内容。 1.1 多租户 多租户技术或称多重租赁技术,简称SaaS,是一种软件架构技术,是实现如何在多用户环境下(此处的多用户一般是面向企业…

解锁Playwright新技能:输入框处理技巧全解析

感谢您抽出 来阅读此文 声明:文章中引用的视频为微信群里面的山豆根大佬原创所录制哟~免费视频录制剪辑不易,请大家多多支持。 并且,大佬为这一系列的视频创作还专门购买了服务器搭建了一个实战项目和练习元素定位的网站。网站的具体信息可…

仅需10行代码,Python带你玩转编程世界!

更多Python学习内容:ipengtao.com Python作为一种简单易学且功能强大的编程语言,其简洁的语法和丰富的库可以在很少的代码行数内实现许多有趣且实用的功能。本文将展示一些通过10行以内的Python代码实现的有趣项目,从简单的数学计算到数据可视…

江协科技51单片机学习- p25 无源蜂鸣器

🚀write in front🚀 🔎大家好,我是黄桃罐头,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝​…

SpringAOP的坑

AOP中几种常见的通知类型及其基本作用: Before:前置通知,在目标方法执行之前执行。After:后置通知,无论方法执行结果如何(包括异常),都会在目标方法执行之后执行。AfterReturning&a…

暑期旅游怎么玩?开发旅游小程序让出行变简单

暑假正值旅游旺季,旅游小程序的出现为旅行带来了许多便利。随着移动互联网的发展,旅游行业也在不断寻求创新与变革。旅游小程序为游客提供了更加便捷的旅行体验,通过旅游小程序,用户可以了解旅游信息、旅游服务、在线咨询等&#…

谷粒商城学习笔记-19-快速开发-逆向生成所有微服务基本CRUD代码

文章目录 一,使用逆向工程步骤梳理1,修改逆向工程的application.yml配置2,修改逆向工程的generator.properties配置3,以Debug模式启动逆向工程4,使用逆向工程生成代码5,整合生成的代码到对应的模块中 二&am…

前端面试题25(css常用的预处理器)

在前端开发领域,CSS预处理器在面试中经常被提及,其中最流行的三种预处理器是Sass、LESS和Stylus。下面分别介绍它们的特点和优势: 1. Sass(Syntactically Awesome Style Sheets) 优势: 变量:允…

[工具类]Java 合并、拆分PPT幻灯片

本文将介绍在Java程序中如何来合并及拆分PPT文档的方法。示例大纲: 1. 合并 1.1 将指定幻灯片合并到文档 1.2 合并多个幻灯片文档为一个文档 2. 拆分 2.1 按幻灯片每一页单独拆分为一个文档 2.2 按指定幻灯片页数范围来拆分为多个文档 使用工具:F…

vite+vue3拍照上传到nodejs服务器

一:效果展示: 拍照效果 二:Nodejs后端接口代码: 三:前端完整代码:

风险评估:Tomcat的安全配置,Tomcat安全基线检查加固

「作者简介」:冬奥会网络安全中国代表队,CSDN Top100,就职奇安信多年,以实战工作为基础著作 《网络安全自学教程》,适合基础薄弱的同学系统化的学习网络安全,用最短的时间掌握最核心的技术。 这一章节我们需…

VUE+ELEMENTUI表格的表尾合计

<el-table :data"XXXX" :summary-method"getSummaries" show-summary "true" > getSummaries(param) { const { columns, data } param; const sums []; columns.forEach((column, index) > { if (index 0) { sums[index] 合计; }…

高考后的IT专业启航:暑期预习指南与学习路线图

文章目录 每日一句正能量前言&#xff1a;启航IT世界&#xff0c;高考后的暑期学习之旅基础课程预习指南基础课程预习指南&#xff1a;构建你的IT知识大厦引言一、计算机科学导论二、编程语言入门三、操作系统基础四、数据结构与算法五、网络基础六、数据库原理结语 技术学习路…

ollama教程——如何在Ollama中导入和管理GGUF与Safetensors模型

ollama教程——如何在Ollama中导入和管理GGUF与Safetensors模型 引言Ollama模型导入概述Ollama支持的模型格式Ollama的版本要求和安装安装OllamaGGUF模型导入什么是GGUF模型通过Modelfile导入GGUF模型代码示例常见问题和解决方案1. 模型文件路径错误2. 模型文件格式不正确3. Ol…