【MindSpore学习打卡】应用实践-计算机视觉-ShuffleNet图像分类:从理论到实践

news2024/12/23 3:07:53

在当今的深度学习领域,卷积神经网络(CNN)已经成为图像分类任务的主流方法。然而,随着网络深度和复杂度的增加,计算资源的消耗也显著增加,特别是在移动设备和嵌入式系统中,这种资源限制尤为突出。ShuffleNet作为一种高效的卷积神经网络,通过引入Pointwise Group Convolution和Channel Shuffle两种操作,大大降低了计算量,同时保持了较高的分类精度。在本篇博客中,我们将详细探讨ShuffleNet的设计原理,并通过MindSpore框架实现ShuffleNet在CIFAR-10数据集上的训练与评估,帮助读者更好地理解和应用这一高效的网络结构。

ShuffleNet网络介绍

ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,主要应用在移动端。其设计核心在于引入了两种操作:Pointwise Group Convolution和Channel Shuffle。这些操作在保持精度的同时,大大降低了模型的计算量。

Pointwise Group Convolution

Pointwise Group Convolution:我们在代码中定义了一个GroupConv类,用于实现逐点分组卷积。这种卷积操作通过将输入特征图分成多个组,每组单独进行卷积操作,从而显著减少了参数量和计算量。具体来说,逐点分组卷积的卷积核大小为 1 × 1 1 \times 1 1×1,这使得每个卷积核只作用于一个通道,进一步降低了计算复杂度。

Group Convolution

class GroupConv(nn.Cell):
    def __init__(self, in_channels, out_channels, kernel_size, stride, pad_mode="pad", pad=0, groups=1, has_bias=False):
        super(GroupConv, self).__init__()
        self.groups = groups
        self.convs = nn.CellList()
        for _ in range(groups):
            self.convs.append(nn.Conv2d(in_channels // groups, out_channels // groups,
                                        kernel_size=kernel_size, stride=stride, has_bias=has_bias,
                                        padding=pad, pad_mode=pad_mode, group=1, weight_init='xavier_uniform'))

    def construct(self, x):
        features = ops.split(x, split_size_or_sections=int(len(x[0]) // self.groups), axis=1)
        outputs = ()
        for i in range(self.groups):
            outputs = outputs + (self.convs[i](features[i].astype("float32")),)
        out = ops.cat(outputs, axis=1)
        return out

Channel Shuffle

Channel Shuffle:为了克服分组卷积带来的不同组别通道无法进行信息交流的问题,ShuffleNet引入了Channel Shuffle机制。我们在代码中实现了一个channel_shuffle方法,通过对通道进行重排,使得不同组别的通道能够进行信息交互。这一步骤在保持网络高效性的同时,增强了特征的表达能力。
Channel Shuffle

def channel_shuffle(self, x):
    batchsize, num_channels, height, width = ops.shape(x)
    group_channels = num_channels // self.group
    x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
    x = ops.transpose(x, (0, 2, 1, 3, 4))
    x = ops.reshape(x, (batchsize, num_channels, height, width))
    return x

ShuffleNet模块

ShuffleNet模块:在ShuffleNet模块中,我们结合了Pointwise Group Convolution和Channel Shuffle,并在降采样模块中引入了步长为2的Depth Wise Convolution。这种设计不仅提高了网络的计算效率,还保证了特征提取的有效性。在代码实现中,我们通过ShuffleV1Block类定义了ShuffleNet的基本模块,并在其中实现了上述操作。

ShuffleNet模块

class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

构建ShuffleNet网络

ShuffleNet网络结构如下图所示。以输入图像 224 × 224 224 \times 224 224×224,组数3(g = 3)为例,经过多个ShuffleNet模块和全局平均池化,最终得到分类结果。

ShuffleNet网络结构

class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

模型训练和评估

模型训练和评估:在训练部分,我们使用了CIFAR-10数据集,并通过数据增强技术(如随机裁剪和水平翻转)来提高模型的泛化能力。我们定义了ShuffleNet网络,并使用交叉熵损失函数和Momentum优化器进行训练。在评估部分,我们加载训练好的模型,并在测试集上进行评估,计算模型的Top-1和Top-5准确率,以全面衡量模型的性能。

训练集准备与加载

首先下载并加载CIFAR-10数据集。CIFAR-10共有60000张32x32的彩色图像,分为10个类别,其中50000张图片作为训练集,10000张图片作为测试集。

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"
download(url, "./dataset", kind="tar.gz", replace=True)

import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms

def get_dataset(train_dataset_path, batch_size, usage):
    image_trans = []
    if usage == "train":
        image_trans = [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    elif usage == "test":
        image_trans = [
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    label_trans = transforms.TypeCast(ms.int32)
    dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)
    dataset = dataset.map(image_trans, 'image')
    dataset = dataset.map(label_trans, 'label')
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset

dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
batches_per_epoch = dataset.get_dataset_size()

模型训练

定义ShuffleNet网络,并使用交叉熵损失函数和Momentum优化器进行训练。

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]

    print("============== Starting Training ==============")
    start_time = time.time()
    model.train(5, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

if __name__ == '__main__':
    train()

在这里插入图片描述

模型评估

在CIFAR-10测试集上对训练好的模型进行评估。

from mindspore import load_checkpoint, load_param_into_net

def test():
    mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')

if __name__ == '__main__':
    test()

在这里插入图片描述

模型预测

在CIFAR-10测试集上对模型进行预测,并将预测结果可视化。

import mindspore
import matplotlib.pyplot as plt
import mindspore.dataset as ds

net = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []
param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
load_param_into_net(net, param_dict)
model = Model(net)
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
image_trans = [
    vision.RandomCrop((32, 32), (4, 4, 4, 4)),
    vision.RandomHorizontalFlip(prob=0.5),
    vision.Resize((224, 224)),
    vision.Rescale(1.0 / 255.0, 0.0),
    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    vision.HWC2CHW()
        ]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
# 推理效果展示(上方为预测的结果,下方为推理效果图片)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
index = 0
for image in show_images_lst:
    plt.subplot(2, 8, index+1)
    plt.title('{}'.format(class_dict[pred[index]]))
    index += 1
    plt.imshow(image)
    plt.axis("off")
plt.show()

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1888513.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构(JAVA)—代码题

01-数据结构—判断题 02-数据结构—选择题 03 数据结构—多选填空程序填空 ​ 01-顺序表的建立及遍历 import java.util.Iterator; import java.util.LinkedList; import java.util.ListIterator; import java.util.Scanner;public class Main {public static void main(St…

有哪些好的 Stable Diffusion 提示词(Prompt)可以参考?

Docker 作图咒语生成器 docker-prompt-generator 是一个开源项目,可以利用模型反推出提示词,让你偷偷懒,无需琢磨怎么写prompt,只需要找一个差不多的模型反推一下,直接用就好了,支持支持 MidJourney、Stab…

《Linux开发笔记》C语言编译

C语言编译过程 编译过程主要分为四步:预处理、编译、汇编、链接 预处理:主要用于查找头文件、展开宏 编译:把.i文件编译成.s文件 汇编:把.s文件汇编为.o文件 链接:把多个.o文件链接成一个app 以上四个步骤主要由3个命…

颅内感染性疾病患者就诊指南

颅内感染性疾病,即病原体侵入中枢神经系统,导致脑部或脑膜发生炎症的疾病。这些病原体可能是细菌、病毒、真菌或寄生虫等。颅内感染不仅会对脑组织造成损害,还可能引发一系列严重的并发症,如癫痫发作、意识障碍等 颅内感染性疾病的…

电子邮件OTP验证身份认证接口API服务商比较

电子邮件OTP验证身份认证接口API服务商如何正确选择? 电子邮件OTP验证是一种广泛应用且安全的身份认证方式。AokSend将比较几家主要的电子邮件OTP验证身份认证接口API服务商,帮助企业选择合适的解决方案。 电子邮件OTP:验证优势 可以为用户…

UE4_材质_材质节点_DepthFade

一、DepthFade参数 DepthFade(深度消退)表达式用来隐藏半透明对象与不透明对象相交时出现的不美观接缝。 项目说明属性消退距离(Fade Distance)这是应该发生消退的全局空间距离。未连接 FadeDistance(FadeDistance&a…

数字时代的巨变!互联网思维与按效果付费的超强联合!

数字时代企业转型的重大变革 家人们,今天来聊聊数字时代企业数字化转型那些事儿! 在这个时代,企业数字化转型已经是必然趋势啦。蚓链数字化生态系统在帮助企业在转型过程中采用互联网思维,再加上按效果付费的商业交付模式&#xf…

小程序开发平台版源码系统——万能门店小程序功能 前后端分离 带完整的安装代码包以及搭建教程

系统概述 在移动互联网的浪潮中,小程序以其轻量、便捷、无需下载即可使用的特点,迅速成为连接用户与商家的新桥梁。为了满足广大商家快速搭建个性化、高效运营的小程序需求,我们精心打造了“小程序开发平台版源码系统——万能门店小程序功能…

【经典算法题】两数之和

暴力解法 两层for循环&#xff0c;O(n*n) 优化解法 哈希&#xff0c;O(n) class Solution {public int[] twoSum(int[] nums, int target) {Map<Integer, Integer> hashtable new HashMap<Integer, Integer>();for (int i 0; i < nums.length; i) {if (ha…

【Hec-Ras】第一期:软件安装

Hec-Ras软件安装 1 HEC-RAS软件介绍2 HEC-RAS软件下载3 HEC-RAS软件安装4 HEC-RAS软件界面介绍参考 1 HEC-RAS软件介绍 HEC-RAS 是美国陆军工程兵团工程水文中心&#xff08; Hydrologic Engineering Centers, HEC&#xff09;开发的河道水力计算程序&#xff08;River Analys…

前后端分离django-restframework——解决跨域请求

为什么会出现跨域问题&#xff1f; 出于浏览器的同源策略限制。同源策略&#xff08;Sameoriginpolicy&#xff09;是一种约定&#xff0c;它是浏览器最核心也最基本的安全功能&#xff0c;如果缺少了同源策略&#xff0c;则浏览器的正常功能可能都会受到影响。可以说Web是构建…

Redis 7.x 系列【16】持久化机制之 AOF

有道无术&#xff0c;术尚可求&#xff0c;有术无道&#xff0c;止于术。 本系列Redis 版本 7.2.5 源码地址&#xff1a;https://gitee.com/pearl-organization/study-redis-demo 1. 概述 默认情况下&#xff0c;Redis 使用RDB持久化&#xff0c;但如果进程出现问题或电源中断…

企业如何在当前形势下实现精益生产?

近年来&#xff0c;企业面临诸多挑战&#xff0c;从成本控制到效率提升&#xff0c;每一项决策都关乎着企业的生死存亡。尤其是在全球经济一体化和市场竞争日益激烈的大背景下&#xff0c;企业如何实现精益生产&#xff0c;成为了摆在众多企业家面前的一道难题。今天&#xff0…

STM32 看门狗 HAL

由时钟图可以看出看门狗采用的是内部低速时钟&#xff0c;频率为40KHz 打开看门狗&#xff0c;采用32分频&#xff0c;计数1250。 结合设置的分频系数和重载计数值&#xff0c;我们可以计算出看门狗的定时时间&#xff1a; 32*1250/40kHz 1s 主函数中喂狗就行 HAL_IWDG_Ref…

ISP IC/FPGA设计-第一部分-SC130GS摄像头分析(0)

1.介绍 SC130GS是一款国产的Global shutter CMOS图像传感器&#xff0c;最高支持1280Hx1024V240fps的传输速率&#xff1b;SC130GS有黑白和彩色款&#xff0c;作为ISP开发选择彩色的&#xff0c;有效像素窗口为1288Hx1032V&#xff0c;支持复杂的片上操作&#xff0c;选择他理…

2023-2024华为ICT大赛中国区 实践赛网络赛道 全国总决赛 理论部分真题

Part1 数通模块(10题)&#xff1a; 1、如图所示&#xff0c;某园区部署了IPv6进行业务测试&#xff0c;该网络中有4台路由器&#xff0c;运行OSPFv3实现网络的互联互通&#xff0c;以下关于该OSPFv3网络产生的LSA的描述&#xff0c;错误的是哪一项?(单选题) A.R1的LSDB中将存在…

若依 ruoyi vue上传控件 el-upload上传文件 判断是否有文件 判断文件大小

console.info(this.$refs.upload.uploadFiles.length)//this.$refs.upload.uploadFiles.length 获取当前上传控件中已选择的文件大小//判断是否存在已上传文件 if(this.$refs.upload.uploadFiles.length 0){this.$modal.msgWarning("请上传文件");return; }

高质量免费数字人(整合包),是的,我又出手了~

在数字时代&#xff0c;我们不断追求更加逼真和高效的技术应用&#xff0c;特别是在数字人领域。腾讯推出了一款名为MuseTalk的革命性数字人产品&#xff0c;支持实时音频驱动的唇部同步数字人&#xff0c;迅速成为行业的新宠。&#xff08;先来一张美图镇楼&#xff09; Muse…

【SPIE独立出版】第四届智能交通系统与智慧城市国际学术会议(ITSSC 2024)

第四届智能交通系统与智慧城市国际学术会议&#xff08;ITSSC 2024&#xff09;将于2024年8月23-25日在中国西安举行。本次会议主要围绕智能交通、交通新能源、无人驾驶、智慧城市、智能家居、智能生活等研究领域展开讨论&#xff0c; 旨在为该研究领域的专家学者们提供一个分享…

鼠标点击器免费版?详细介绍鼠标连点器的如何使用

随着科技的发展&#xff0c;鼠标连点器逐渐成为了我们生活和工作中不可或缺的工具。它不仅能够帮助我们完成频繁且重复的点击任务&#xff0c;还能在很大程度上减少我们的手部疲劳&#xff0c;提高工作效率。本文将详细介绍鼠标连点器的使用方法&#xff0c;并推荐三款好用的免…