昇思25天学习打卡营第13天 | mindspore 实现 ShuffleNet 图像分类

news2025/1/22 21:12:13

1. 背景:

使用 mindspore 学习神经网络,打卡第 13 天;主要内容也依据 mindspore 的学习记录。

2. 迁移学习介绍:

mindspore 实现 ShuffleNet 图像分类;

  • ShuffleNet 基本介绍:
    ShuffleNetV1 是旷视科技提出的一种计算高效的 CNN 模型,设计目标是利用有限资源达到最好的模型精度;An Extremely Efficient Convolutional Neural Network for MobileDevices 文章链接 一文中提出的一种网络框架。

  • 解决的问题:
    降低模型的计算量,同时达到最好的模型精度,可以应用到移动端;

  • 创新点:
    a. 逐点分组卷积 (Pointwise Group Convolution):
    将输入的特征分组卷积;这样每个卷积核只处理输入特征图的一部分通道;这样,降低了参数量,同时,输出通道数等于卷积核数量;
    Group Convolution
    Pointwise Group Convolution:在分组卷积基础上,令每一组卷积核都为 1*1;
    b. 通道重排 (Channel Shuffle):
    不同通道均匀分散重组,使网络在下一层处理不同通道信息;
    Channel Shuffle 通道重排

Channel Shuffle 的逻辑:
Channel Shuffle 的逻辑

3. 具体实现:

3.1 数据下载:

使用 CIFAR-10 数据集,共有60000张32*32的彩色图像,分为10个类别,每类有6000张图,数据集一共有50000张训练图片和10000张评估图片;

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"

download(url, "./datasets-cifar10-bin", kind="tar.gz", replace=True)

3.2 数据前处理:

对 cifar10 数据集做处理

import mindspore as ms
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
import mindspore.dataset.transforms as transforms
from mindspore import dtype as mstype

data_dir = "./datasets-cifar10-bin/cifar-10-batches-bin"  # 数据集根目录
batch_size = 256  # 批量大小
image_size = 32  # 训练图像空间大小
workers = 4  # 并行线程个数
num_classes = 10  # 分类数量

def create_dataset_cifar10(dataset_dir, usage, resize, batch_size, workers):

    data_set = ds.Cifar10Dataset(dataset_dir=dataset_dir,
                                 usage=usage,
                                 num_parallel_workers=workers,
                                 shuffle=True)

    trans = []
    if usage == "train":
        trans += [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5)
        ]

    trans += [
        vision.Resize(resize),
        vision.Rescale(1.0 / 255.0, 0.0),
        vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
        vision.HWC2CHW()
    ]

    target_trans = transforms.TypeCast(mstype.int32)

    # 数据映射操作
    data_set = data_set.map(operations=trans,
                            input_columns='image',
                            num_parallel_workers=workers)

    data_set = data_set.map(operations=target_trans,
                            input_columns='label',
                            num_parallel_workers=workers)

    # 批量操作
    data_set = data_set.batch(batch_size)

    return data_set


# 获取处理后的训练与测试数据集
dataset_train = create_dataset_cifar10(dataset_dir=data_dir,
                                       usage="train",
                                       resize=image_size,
                                       batch_size=batch_size,
                                       workers=workers)
step_size_train = dataset_train.get_dataset_size()

dataset_val = create_dataset_cifar10(dataset_dir=data_dir,
                                     usage="test",
                                     resize=image_size,
                                     batch_size=batch_size,
                                     workers=workers)
step_size_val = dataset_val.get_dataset_size()

3.3 构建ShuffleNet 模块单元:

对于 ShuffleNet 模块单元,主要是 ShuffleNet 模块单元;
如论文中图所示:
在这里插入图片描述
相对于 ResNet 中的 Bottleneck 结构,有如下修改:
a. 将开始和最后的1 * 1 卷积模块(降维、升维)改成Point Wise Group Convolution;
b. 为了进行不同通道的信息交流,再降维之后进行Channel Shuffle;
c. 降采样模块中,3 * 3 的 Depth Wise Convolution的步长设置为2,长宽降为原来的一般,因此shortcut中采用步长为 2 的 3 * 3 平均池化,并把相加改成拼接。

  • ShuffleV1 Block 代码如下:
class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x

3.4 构建 ShuffleNet V1 网络结构:

如 Table 1 所示:在这里插入图片描述
代码如下:

class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

3.5 模型训练与评估:

  • 模型训练:
    本节用随机初始化的参数做预训练。首先调用ShuffleNetV1定义网络,参数量选择"2.0x",并定义损失函数为交叉熵损失,学习率经过4轮的warmup后采用余弦退火,优化器采用Momentum。最后用train.model中的Model接口将模型、损失函数、优化器封装在model中,并用model.train()对网络进行训练。将ModelCheckpointCheckpointConfigTimeMonitorLossMonitor传入回调函数中,将会打印训练的轮数、损失和时间,并将ckpt文件保存在当前目录下。
import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]

    print("============== Starting Training ==============")
    start_time = time.time()
    # 由于时间原因,epoch = 5,可根据需求进行调整
    model.train(5, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

if __name__ == '__main__':
    train()
  • 模型评估:
from mindspore import load_checkpoint, load_param_into_net

def test():
    mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')

if __name__ == '__main__':
    test()
  • 开始循环运行:
# 开始循环训练
print("Start Training Loop ...")

for epoch in range(num_epochs):
    curr_loss = train(data_loader_train, epoch)
    curr_acc = evaluate(data_loader_val)

    print("-" * 50)
    print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % (
        epoch+1, num_epochs, curr_loss, curr_acc
    ))
    print("-" * 50)

    # 保存当前预测准确率最高的模型
    if curr_acc > best_acc:
        best_acc = curr_acc
        ms.save_checkpoint(network, best_ckpt_path)

print("=" * 80)
print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, "
      f"save the best ckpt file in {best_ckpt_path}", flush=True)

3.6 可视化模型预测:

import mindspore
import matplotlib.pyplot as plt
import mindspore.dataset as ds

net = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []
param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
load_param_into_net(net, param_dict)
model = Model(net)
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
image_trans = [
    vision.RandomCrop((32, 32), (4, 4, 4, 4)),
    vision.RandomHorizontalFlip(prob=0.5),
    vision.Resize((224, 224)),
    vision.Rescale(1.0 / 255.0, 0.0),
    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    vision.HWC2CHW()
        ]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}

# 推理效果展示(上方为预测的结果,下方为推理效果图片)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
index = 0
for image in show_images_lst:
    plt.subplot(2, 8, index+1)
    plt.title('{}'.format(class_dict[pred[index]]))
    index += 1
    plt.imshow(image)
    plt.axis("off")
plt.show()

4. 相关链接:

  • ShuffleNetV1 论文
  • https://xihe.mindspore.cn/events/mindspore-training-camp
  • https://gitee.com/mindspore/docs/blob/r2.3/tutorials/application/source_zh_cn/cv/shufflenet.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1936204.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RabbitMQ学习实践一:MQ的安装

文章是本人在学习springboot实现消息队列功能时所经历的过程的记录,仅供参考,如有侵权请随时指出。 参考文章地址: RabbitMQ安装与入门_rabbitmq win11配置-CSDN博客 RabbitMQ入门到实战一篇文章就够了-CSDN博客 RabbitMQ系列&#xff08…

进程通信(5):POSIX消息队列

随进程持续:IPC对象一直存在直到最后一个进程关闭该对象为止(管道和FIFO)。 随内核持续:IPC对象存在一直到内核自举(系统重新启动)或者显示删除该对象。 如System V消息队列,System V信号量,S…

固态继电器的实际使用和有效应用

固态继电器(SSR)已成为现代电气和电子系统中不可或缺的组件,与传统的机电继电器相比具有众多优势。在本文中,我们将深入探讨SSR的实际方面、其应用以及有效部署的关键考虑因素。 什么是固态继电器? 固态继电器是使用半导体器件(如…

ORB_SLAM2 ORBSLAM2 Ubuntu20.04 ROS Noetic虚拟机镜像下载

下图是build.sh 和 build_ros.sh 编译完成截图: slam测试视频: orbslam2 ubuntu20.04 test 下载地址(付费使用,不能接受请勿下载): 链接:https://pan.baidu.com/s/16R7Pb6LjgR5SeoeBSZfgaQ?pwdu05r 提取…

前端小知识点——按钮之间出现很小的空隙如何规避

前端小知识点——按钮之间出现很小的空隙如何规避 文章介绍问题再现总结 文章介绍 本文主要介绍页面中两个按钮相邻时会出现一点空隙,导致在后续自定义填充的时候出现换行或其它问题,特此记录。 问题再现 这个图片能看到我们给外面的div设置的是300的宽…

stack模拟实现【C++】

文章目录 全部的实现代码放在了文章末尾什么是适配器模式?stack准备工作包含头文件定义命名空间类的成员变量 默认成员函数emptysizetoppushpop全部代码 全部的实现代码放在了文章末尾 stack的模拟实现我采用了C适配器模式 stack的适配器一般是deque,也…

SpringBoot系列—2.SpringBoot拦截器篇

SpringBoot系列—1.IDEA搭建SpringBoot框架 SpringBoot系列—2.SpringBoot拦截器篇 SpringBoot系列—3.SpringBoot Redis篇 SpringBoot系列—4.SpringBoot 整合Mybatis、MP(MyBatis-Plus) SpringBoot系列—5.SpringBoot 整合Mybatis-Plus分页 1.新建拦截…

C# 之工控机数据类型 高低位(大小端)、BitConverter、IsLittleEndian、字节数组转换(高低位)

八种基本数据类型:byte、short、int、long、float、double、boolean、char byte 8位、有符号的以二进制补码表示的整数 min : -128(-2^7) max: 127(2^7-1) default: 0 对应包装类:Byte short 16位、有符号的以二进制补码表示…

将微信聊天记录导出成html等格式

通过github上的开源项目WechatMsg,可以将微信中的聊天记录(包括文字、图片、语音、表情包甚至拍一拍)导出,方便我们随时分享和查看,此外还有聊天记录分析等有趣的功能,感兴趣的小伙伴可以研究一下。我个人认…

静态网站怎么更新数据

今天看到个问题 我不是行业从业者,但目前遇到一个问题 我公司网站为纯静态,除了直接从html里修改文字外能不能这样 建立一个xml或者txt文档,其中有很多信息,例如网站名称,电话,备案号等,一行一行…

AI赋能项目集成:我的实战经验与洞见

背景 在传统的教学模式中,教师往往难以兼顾每位学生的个性化需求,学习信息的收集与分析也受限于时间和精力的限制,难以做到全面而深入。然而,每位学生都是独一无二的个体,他们拥有不同的学习风格、兴趣偏好以及理解能…

【IC前端虚拟项目】sanity_case的编写与通包测试

【IC前端虚拟项目】数据搬运指令处理模块前端实现虚拟项目说明-CSDN博客 在花了大力气完成reference model之后,整个验证环境的搭建就完成了,再多看一下这个结构然后就可以进行sanity_case和通包测试: 关于sanity_case和通包测试我在很多篇文章中说过好多次了在这里就不赘述…

如何安装Visual Studio Code

Visual Studio Code(简称 VS Code) Visual Studio Code 是一款由微软开发的免费、开源的现代化轻量级代码编辑器。 主要特点包括: 跨平台:支持 Windows、Mac 和 Linux 等主流操作系统,方便开发者在不同平台上保持一…

ETL电商项目总结

ETL电商项目总结 ETL电商业务简介及各数据表关系 业务背景 ​ 本案例围绕某个互联网小型电商的订单业务来开发。某电商公司,每天都有一些的用户会在线上采购商品,该电商公司想通过数据分析,查看每一天的电商经营情况。例如:电商…

科普文:字节码class文件和字节码增强技术

1. 引言 1.1. 什么是字节码 Java字节码是指Java语言编译后生成的一种二进制文件格式,它包含了Java程序的所有信息,包括类信息、方法信息、变量信息等。字节码是Java程序执行的基础,它被用于实现Java虚拟机(JVM)的加载…

随手记:vsCode修改主题色为自定义颜色

因为工作需要长时间面对vscode,视力不好,想要把工具改成护眼色,于是就把vscode改成了自定义的护眼色 效果图: 操作步骤: 快捷键打开设置页面: 按住ctrlshiftp 选择Open setting 按回车键 打开setting页面编…

【STM32CubeMX】一 TIME定时器Mode and Configuration的详解

使用STM32CubeMX软件学习配置定时器,对Mode and Configuration进行分析各部分选项的功能。本次以TIM2为例进行分析。 一、 Slave Mode 可以配置的选项有: Disable External Clock Mode 1 外部时钟源模式1 Reset Mode 复位模式 Gated Mode 门控模式 Tri…

采用T网络反馈电路的运算放大器(运放)反相放大器

运算放大器(运放)反相放大器电路 设计目标 输入电压ViMin输入电压ViMax输出电压VoMin输出电压VoMaxBW fp电源电压Vcc电源电压Vee-2.5mV2.5mV–2.5V2.5V5kHz5V–5V 设计说明1 该设计将输入信号 Vin 反相并应用 1000V/V 或 60dB 的信号增益。具有 T 反馈网络的反相放大器可用…

scanf` 和 `printf` 通常比 `cin` 和 `cout` 在处理数据时的分析

#include<bits/stdc.h> using namespace std; int x[5000005],k; int main() {int n;scanf("%d%d",&n,&k);for(int i0;i<n;i)scanf("%d",&x[i]);sort(x,xn);//快排printf("%d",x[k]); }和#include<vector> #include&…

redis数据库(下)

集合键值对 集合的每一个元素也是字符串格式数据,是无序集合,并且元素不可重复(自动去重) 1.集合的创建和添加命令 sadd命令:无责创建有责添加 sadd 键名 元素1 元素2......... 注意:再次添加元素时,如果触发了集合的唯一性,那么命令执行结果就为0,表示执行失败…