昇思25天学习打卡营第21天|CV-Shufflenet图像分类

news2025/1/17 14:11:00

打卡

目录

打卡

ShuffleNet 网络介绍

ShuffleNet 模型架构

Pointwise Group Convolution

Channel Shuffle

ShuffleNet模块

ShuffleNet 模块代码

构建ShuffleNet网络

模块代码

模型训练和评估

模型训练

模型评估

模型预测


ShuffleNet 网络介绍

ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到最好的模型精度。ShuffleNetV1的设计核心是引入了两种操作:Pointwise Group Convolution Channel Shuffle,这在保持精度的同时大大降低了模型的计算量。因此,ShuffleNetV1和MobileNet类似,都是通过设计更高效的网络结构来实现模型的压缩和加速。

ShuffleNet在保持不低的准确率的前提下,将参数量几乎降低到了最小,因此其运算速度较快,单位参数量对模型准确率的贡献非常高。了解ShuffleNet更多详细内容,详见论文ShuffleNet。

ShuffleNet 模型架构

ShuffleNet最显著的特点在于对不同通道进行重排来解决Group Convolution带来的弊端。通过改进ResNet 的 Bottleneck单元,在较小的计算量的情况下达到了较高的准确率。

Pointwise Group Convolution

  • Group Convolution(分组卷积)原理如下图。分组卷积的每一组的卷积核大小为in_channels/g*k*k,一共有g组,所有组共有 (in_channels/g*k*k)*out_channels 个参数,是正常卷积参数的1/g。分组卷积中,每个卷积核只处理输入特征图的一部分通道,其优点在于参数量会有所降低,但输出通道数仍等于卷积核的数量
  • Depthwise Convolution(深度可分离卷积)将组数g分为和输入通道相等的in_channels,然后对每一个in_channels做卷积操作,每个卷积核只处理一个通道,记卷积核大小为1*k*k,则卷积核参数量为:in_channels*k*k,得到的feature maps通道数与输入通道数相等
  • Pointwise Group Convolution(逐点分组卷积)在分组卷积的基础上,令每一组的卷积核大小为 1×1 ,卷积核参数量为(in_channels/g*1*1)*out_channels。

Channel Shuffle

Group Convolution的弊端在于不同组别的通道无法进行信息交流,堆积 GConv 层后一个问题是不同组之间的特征图是不通信的,这就好像分成了g个互不相干的道路,每一个人各走各的,这可能会降低网络的特征提取能力。这也是Xception,MobileNet等网络采用密集的1x1卷积(Dense Pointwise Convolution)的原因。

为了解决不同组别通道“近亲繁殖”的问题,ShuffleNet优化了大量密集的1x1卷积(在使用的情况下计算量占用率达到了惊人的93.4%),引入Channel Shuffle机制(通道重排)。这项操作直观上表现为将不同分组通道均匀分散重组,使网络在下一层能处理不同组别通道的信息。

如下图所示,对于g组,每组有n个通道的特征图,首先reshape成g行n列的矩阵,再将矩阵转置成n行g列,最后进行flatten操作,得到新的排列。这些操作都是可微分可导的且计算简单,在解决了信息交互的同时符合了ShuffleNet轻量级网络设计的轻量特征。

ShuffleNet模块

如下图,ShuffleNet对ResNet中的Bottleneck结构进行由(a)到(b), (c)的更改:

  1. 将开始和最后的1×1 卷积模块(降维、升维)改成 Point Wise Group Convolution;

  2. 为了进行不同通道的信息交流,再降维之后进行Channel Shuffle;

  3. 降采样模块中,3×3 Depth Wise Convolution的步长设置为2,长宽降为原来的一般,因此shortcut中采用步长为2的3×3 平均池化,并把相加改成拼接。

ShuffleNet 模块代码
class ShuffleV1Block(nn.Cell):
    def __init__(self, inp, oup, group, first_group, mid_channels, ksize, stride):
        super(ShuffleV1Block, self).__init__()
        self.stride = stride
        pad = ksize // 2
        self.group = group
        if stride == 2:
            outputs = oup - inp
        else:
            outputs = oup
        self.relu = nn.ReLU()
        branch_main_1 = [
            GroupConv(in_channels=inp, out_channels=mid_channels,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=1 if first_group else group),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
        ]
        branch_main_2 = [
            nn.Conv2d(mid_channels, mid_channels, kernel_size=ksize, stride=stride,
                      pad_mode='pad', padding=pad, group=mid_channels,
                      weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(mid_channels),
            GroupConv(in_channels=mid_channels, out_channels=outputs,
                      kernel_size=1, stride=1, pad_mode="pad", pad=0,
                      groups=group),
            nn.BatchNorm2d(outputs),
        ]
        self.branch_main_1 = nn.SequentialCell(branch_main_1)
        self.branch_main_2 = nn.SequentialCell(branch_main_2)
        if stride == 2:
            self.branch_proj = nn.AvgPool2d(kernel_size=3, stride=2, pad_mode='same')

    def construct(self, old_x):
        left = old_x
        right = old_x
        out = old_x
        right = self.branch_main_1(right)
        if self.group > 1:
            right = self.channel_shuffle(right)
        right = self.branch_main_2(right)
        if self.stride == 1:
            out = self.relu(left + right)
        elif self.stride == 2:
            left = self.branch_proj(left)
            out = ops.cat((left, right), 1)
            out = self.relu(out)
        return out

    def channel_shuffle(self, x):
        batchsize, num_channels, height, width = ops.shape(x)
        group_channels = num_channels // self.group
        x = ops.reshape(x, (batchsize, group_channels, self.group, height, width))
        x = ops.transpose(x, (0, 2, 1, 3, 4))
        x = ops.reshape(x, (batchsize, num_channels, height, width))
        return x

构建ShuffleNet网络

ShuffleNet网络结构如下图。以输入图像 224×224,组数3(g = 3)为例,首先通过数量24,卷积核大小为3×3 ,stride为2的卷积层,输出特征图大小为112×112 ,channel为24;然后通过stride为2的最大池化层,输出特征图大小为56×56 ,channel数不变;再堆叠3个ShuffleNet模块(Stage2, Stage3, Stage4),三个模块分别重复4次、8次、4次,其中每个模块开始先经过一次下采样模块(上图(c)),使特征图长宽减半,channel翻倍(Stage2的下采样模块除外,将channel数从24变为240);随后经过全局平均池化,输出大小为1×1×960 ,再经过全连接层和softmax,得到分类概率。

模块代码

class ShuffleNetV1(nn.Cell):
    def __init__(self, n_class=1000, model_size='2.0x', group=3):
        super(ShuffleNetV1, self).__init__()
        print('model size is ', model_size)
        self.stage_repeats = [4, 8, 4]
        self.model_size = model_size
        if group == 3:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 12, 120, 240, 480]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 240, 480, 960]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 360, 720, 1440]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 480, 960, 1920]
            else:
                raise NotImplementedError
        elif group == 8:
            if model_size == '0.5x':
                self.stage_out_channels = [-1, 16, 192, 384, 768]
            elif model_size == '1.0x':
                self.stage_out_channels = [-1, 24, 384, 768, 1536]
            elif model_size == '1.5x':
                self.stage_out_channels = [-1, 24, 576, 1152, 2304]
            elif model_size == '2.0x':
                self.stage_out_channels = [-1, 48, 768, 1536, 3072]
            else:
                raise NotImplementedError
        input_channel = self.stage_out_channels[1]
        self.first_conv = nn.SequentialCell(
            nn.Conv2d(3, input_channel, 3, 2, 'pad', 1, weight_init='xavier_uniform', has_bias=False),
            nn.BatchNorm2d(input_channel),
            nn.ReLU(),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
        features = []
        for idxstage in range(len(self.stage_repeats)):
            numrepeat = self.stage_repeats[idxstage]
            output_channel = self.stage_out_channels[idxstage + 2]
            for i in range(numrepeat):
                stride = 2 if i == 0 else 1
                first_group = idxstage == 0 and i == 0
                features.append(ShuffleV1Block(input_channel, output_channel,
                                               group=group, first_group=first_group,
                                               mid_channels=output_channel // 4, ksize=3, stride=stride))
                input_channel = output_channel
        self.features = nn.SequentialCell(features)
        self.globalpool = nn.AvgPool2d(7)
        self.classifier = nn.Dense(self.stage_out_channels[-1], n_class)

    def construct(self, x):
        x = self.first_conv(x)
        x = self.maxpool(x)
        x = self.features(x)
        x = self.globalpool(x)
        x = ops.reshape(x, (-1, self.stage_out_channels[-1]))
        x = self.classifier(x)
        return x

模型训练和评估

采用CIFAR-10数据集对ShuffleNet进行预训练。

  • CIFAR-10共有60000张32*32的彩色图像,均匀地分为10个类别,其中50000张图片作为训练集,10000图片作为测试集。
  • 使用 mindspore.dataset.Cifar10Dataset 接口下载并加载CIFAR-10的训练集。目前仅支持二进制版本(CIFAR-10 binary version)。
from download import download
import mindspore as ms
from mindspore.dataset import Cifar10Dataset
from mindspore.dataset import vision, transforms

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz"

download(url, "./dataset", kind="tar.gz", replace=True)


def get_dataset(train_dataset_path, batch_size, usage):
    image_trans = []
    if usage == "train":
        image_trans = [
            vision.RandomCrop((32, 32), (4, 4, 4, 4)),
            vision.RandomHorizontalFlip(prob=0.5),
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    elif usage == "test":
        image_trans = [
            vision.Resize((224, 224)),
            vision.Rescale(1.0 / 255.0, 0.0),
            vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
            vision.HWC2CHW()
        ]
    label_trans = transforms.TypeCast(ms.int32)
    dataset = Cifar10Dataset(train_dataset_path, usage=usage, shuffle=True)
    dataset = dataset.map(image_trans, 'image')
    dataset = dataset.map(label_trans, 'label')
    dataset = dataset.batch(batch_size, drop_remainder=True)
    return dataset

dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "train")
batches_per_epoch = dataset.get_dataset_size()  ## 390

模型训练

用随机初始化的参数做预训练。

1)首先调用ShuffleNetV1定义网络,参数量选择"2.0x",并定义损失函数为交叉熵损失,学习率经过4轮的warmup后采用余弦退火,优化器采用Momentum

2)最后用train.model中的Model接口将模型、损失函数、优化器封装在model中,并用model.train()对网络进行训练。将ModelCheckpointCheckpointConfigTimeMonitorLossMonitor传入回调函数中,将会打印训练的轮数、损失和时间,并将ckpt文件保存在当前目录下。

import time
import mindspore
import numpy as np
from mindspore import Tensor, nn
from mindspore.train import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor, Model, Top1CategoricalAccuracy, Top5CategoricalAccuracy

def train():
    mindspore.set_context(mode=mindspore.PYNATIVE_MODE, device_target="Ascend")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    min_lr = 0.0005
    base_lr = 0.05
    lr_scheduler = mindspore.nn.cosine_decay_lr(min_lr,
                                                base_lr,
                                                batches_per_epoch*250,
                                                batches_per_epoch,
                                                decay_epoch=250)
    lr = Tensor(lr_scheduler[-1])
    optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.00004, loss_scale=1024)
    loss_scale_manager = ms.amp.FixedLossScaleManager(1024, drop_overflow_update=False)
    model = Model(net, loss_fn=loss, optimizer=optimizer, amp_level="O3", loss_scale_manager=loss_scale_manager)
    callback = [TimeMonitor(), LossMonitor()]
    save_ckpt_path = "./"
    config_ckpt = CheckpointConfig(save_checkpoint_steps=batches_per_epoch, keep_checkpoint_max=5)
    ckpt_callback = ModelCheckpoint("shufflenetv1", directory=save_ckpt_path, config=config_ckpt)
    callback += [ckpt_callback]

    print("============== Starting Training ==============")
    start_time = time.time()
    # 由于时间原因,epoch = 5,可根据需求进行调整
    model.train(5, dataset, callbacks=callback)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    print("total time:" + hour + "h " + minute + "m " + second + "s")
    print("============== Train Success ==============")

if __name__ == '__main__':
    train()

模型评估

设置好评估模型的路径后加载数据集,并设置Top 1, Top 5的评估标准,最后用model.eval()接口对模型进行评估。

from mindspore import load_checkpoint, load_param_into_net

def test():
    mindspore.set_context(mode=mindspore.GRAPH_MODE, device_target="Ascend")
    dataset = get_dataset("./dataset/cifar-10-batches-bin", 128, "test")
    net = ShuffleNetV1(model_size="2.0x", n_class=10)
    param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
    load_param_into_net(net, param_dict)
    net.set_train(False)
    loss = nn.CrossEntropyLoss(weight=None, reduction='mean', label_smoothing=0.1)
    eval_metrics = {'Loss': nn.Loss(), 'Top_1_Acc': Top1CategoricalAccuracy(),
                    'Top_5_Acc': Top5CategoricalAccuracy()}
    model = Model(net, loss_fn=loss, metrics=eval_metrics)
    start_time = time.time()
    res = model.eval(dataset, dataset_sink_mode=False)
    use_time = time.time() - start_time
    hour = str(int(use_time // 60 // 60))
    minute = str(int(use_time // 60 % 60))
    second = str(int(use_time % 60))
    log = "result:" + str(res) + ", ckpt:'" + "./shufflenetv1-5_390.ckpt" \
        + "', time: " + hour + "h " + minute + "m " + second + "s"
    print(log)
    filename = './eval_log.txt'
    with open(filename, 'a') as file_object:
        file_object.write(log + '\n')

if __name__ == '__main__':
    test()

模型预测

import mindspore
import matplotlib.pyplot as plt
import mindspore.dataset as ds

net = ShuffleNetV1(model_size="2.0x", n_class=10)
show_lst = []
param_dict = load_checkpoint("shufflenetv1-5_390.ckpt")
load_param_into_net(net, param_dict)
model = Model(net)
dataset_predict = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = ds.Cifar10Dataset(dataset_dir="./dataset/cifar-10-batches-bin", shuffle=False, usage="train")
dataset_show = dataset_show.batch(16)
show_images_lst = next(dataset_show.create_dict_iterator())["image"].asnumpy()
image_trans = [
    vision.RandomCrop((32, 32), (4, 4, 4, 4)),
    vision.RandomHorizontalFlip(prob=0.5),
    vision.Resize((224, 224)),
    vision.Rescale(1.0 / 255.0, 0.0),
    vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    vision.HWC2CHW()
        ]
dataset_predict = dataset_predict.map(image_trans, 'image')
dataset_predict = dataset_predict.batch(16)
class_dict = {0:"airplane", 1:"automobile", 2:"bird", 3:"cat", 4:"deer", 5:"dog", 6:"frog", 7:"horse", 8:"ship", 9:"truck"}
# 推理效果展示(上方为预测的结果,下方为推理效果图片)
plt.figure(figsize=(16, 5))
predict_data = next(dataset_predict.create_dict_iterator())
output = model.predict(ms.Tensor(predict_data['image']))
pred = np.argmax(output.asnumpy(), axis=1)
index = 0
for image in show_images_lst:
    plt.subplot(2, 8, index+1)
    plt.title('{}'.format(class_dict[pred[index]]))
    index += 1
    plt.imshow(image)
    plt.axis("off")
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1949626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

vue3+vite纯前端实现自动触发浏览器刷新更新版本内容,并在打包时生成版本号文件

前言 在前端项目中,有时候为了实现自动触发浏览器刷新并更新版本内容,可以采取一系列巧妙的措施。我的项目中是需要在打包时候生成一个version.js文件,用当前打包时间作为版本的唯一标识,然后打包发版 ,从实现对版本更…

【Golang 面试基础题】每日 5 题(八)

✍个人博客:Pandaconda-CSDN博客 📣专栏地址:http://t.csdnimg.cn/UWz06 📚专栏简介:在这个专栏中,我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话,欢迎点赞👍收藏…

【YashanDB知识库】绑定参数,同一个sql多个执行计划的问题

问题现象 同一个sql有两个执行计划,是否合理? 它的EXECUTIONS,ELAPSED_TIME等统计信息怎么看,是独立分开的还是统一计算的? 如下图: 问题影响版本 tpcc测试:23.2.1.100 问题的风险及影响 …

无人机公司销售需要什么资质

国家民航局于2024年1月1日实施了《无人驾驶航空器飞行管理暂行条例》,根据这个管理条例里面的 第十一条 使用除微型以外的民用无人驾驶航空器从事飞行活动的单位应当具备下列条件,并向国务院民用航空主管部门或者地区民用航空管理机构申请取得民用无人驾…

若依+AI项目开发(二)

后端代码分析 二次开发 开始执行 生成成功 创建子模块

电子签章-开放签应用

开放签电子签章系统开源工具版旨在将电子签章、电子合同系统开发中的前后端核心技术开源开放,适合有技术能力的个人 / 团队学习或自建电子签章 \ 电子合同功能或应用,避免研发同仁在工作过程中重复造轮子,降低电子签章技术研发要求&#xff0…

如何解决ChromeDriver 126找不到chromedriver.exe问题

引言 在使用Selenium和ChromeDriver进行网页自动化时,ChromeDriver与Chrome浏览器版本不匹配的问题时有发生。最近,许多开发者在使用ChromeDriver 126时遇到了无法找到chromedriver.exe文件的错误。本文将介绍该问题的原因,并提供详细的解决…

mysql-bin 恢复数据库

能看到这里的同学估计肯定摊上大事了吧!不要慌,一定要冷静,记录一下作者的大事件吧,黑客通过SQL注入的方式执行了一段SQL : DROP DATABASE ****** 后果就是导致整个数据库被删了,当时心是拔凉拔凉的&#x…

3.2、数据结构-数组、矩阵和广义表

数组结构 数组是定长线性表在维度上的扩展,即线性表中的元素又是一个线性表。N维数组是一种“同构”的数据结构,其每个数据元素类型相同、结构一致。 一个m行n列的数组表示如下: 其可以表示为行向量形式(一行一行的数据)或者列向量形式(一…

收银系统源码视频介绍

千呼新零售2.0系统是零售行业连锁店一体化收银系统,包括线下收银线上商城连锁店管理ERP管理商品管理供应商管理会员营销等功能为一体,线上线下数据全部打通。 适用于商超、便利店、水果、生鲜、母婴、服装、零食、百货、宠物等连锁店使用。 详细介绍请…

Haproxy 可观测性最佳实践

HAProxy 是一款广泛使用的高性能负载均衡器,支持 TCP 和 HTTP 协议,提供高可用性、负载均衡和代理服务。 HAProxy 2.0 以上版本提供了完善的指标暴露体系,方便观测云收集对应的指标信息。 版本要求 HAProxy 2.0 HAProxy Enterprise 2.0r1 HAP…

自定义协议(应用层协议)——网络版计算机基于TCP传输协议

应用层:自定义网络协议:序列化和反序列化,如果是TCP传输的:还要关心区分报文边界(在序列化设计的时候设计好)——粘包问题 1、首先想要使用TCP协议传输的网络,服务器和客户端都应该要创建自己…

AI发展下的伦理挑战:构建未来科技的道德框架

一、引言 随着人工智能(AI)技术的飞速发展,我们正处在一个前所未有的科技变革时代。AI不仅在医疗、教育、金融、交通等领域展现出巨大的应用潜力,也在日常生活中扮演着越来越重要的角色。然而,这一技术的迅猛进步也带来…

RuoYi基于SpringBoot+Vue前后端分离的Java快速开发框架学习_2_登录

文章目录 一、登录1.生成验证码2.验证码作用1.大体流程2.代码层面(我们都是从前端开始看起) 一、登录 1.生成验证码 基本思路: 后端生成一个表达式,例如34?7,显而易见后面是答案截取出来题干和答案把题干11?变成图片,变成流&a…

下属不把你当回事?就做好这3步,他们会对你唯命是从!

下属不把你当回事?就做好这3步,他们会对你唯命是从! 一:规范制度,做事有理可依 企业管理好比是满汉全席,制度才是压轴大菜,人性化说教不过是菜盘边上的点缀罢了。 千万不可舍本逐末。 事要有人干…

React间的组件通信

一、父传子&#xff08;props&#xff09; 步骤 父组件传递数据&#xff0c;子组件标签身上绑定属性子组件接收数据&#xff0c;props的参数 // 子组件 function Son(props) {return (<div>this is Son, {props.name}</div>) }// 父组件 function App() {const n…

如何使用 DSPy 构建多步骤推理的 RAG 系统

一、前言 检索增强生成 (RAG) 系统已经成为构建基于大语言模型 (LLM) 应用的强大方法。RAG 系统的工作原理是&#xff1a;首先使用检索模型从外部知识源检索相关信息&#xff0c;然后使用这些信息来提示 LLM 生成最终的响应。 然而&#xff0c;基本的 RAG 系统&#xff08;也…

谷粒商城实战笔记-47-商品服务-API-三级分类-网关统一配置跨域

文章目录 一&#xff0c;跨域问题1&#xff0c;跨域问题产生的原因2&#xff0c;预检请求3&#xff0c;跨域解决方案3.1 CORS (Cross-Origin Resource Sharing)后端配置示例&#xff08;Spring Boot&#xff09; 3.2 JSONP (JSON with Padding)3.3 代理服务器Nginx代理配置示例…

python自动化中正则表达式提取(适用于提取文本结果)

对于结果是json格式的我们经常使用jsonpath&#xff0c;但是很多时候我们需要从一些文本中提取数据&#xff0c;这个时候正则表达式的提取就很重要&#xff0c;这边主要分享一些正则表达式的提取方法和应用场景的实践&#xff0c;主要介绍两种用法re.search()跟re.findall() 1…

基于springboot+vue+uniapp的居民健康监测小程序

开发语言&#xff1a;Java框架&#xff1a;springbootuniappJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#…