昇思25天学习打卡营第18天|xiaoyushao

news2025/1/16 3:44:18

        今天分享基于MobileNetv2的垃圾分类。读取本地图像数据作为输入,对图像中的垃圾物体进行检测,并且将检测结果图片保存到文件中。

       

目录

        一、 MobileNetv2模型原理介绍

        二、 数据处理

        1. 数据准备

        2. 数据加载

        3. 数据预处理

        三、 MobileNetv2模型搭建

        四、 MobileNetv2模型的训练与测试

        1. 训练策略

        2. 模型训练与测试

        五、 模型推理

        六、 导出AIR/GEIR/ONNX模型文件


一、 MobileNetv2模型原理介绍

        MobileNet网络是由Google团队于2017年提出的专注于移动端、嵌入式或IoT设备的轻量级CNN网络,相比于传统的卷积神经网络,MobileNet网络使用深度可分离卷积(Depthwise Separable Convolution)的思想在准确率小幅度降低的前提下,大大减小了模型参数与运算量。并引入宽度系数 α和分辨率系数 β使模型满足不同应用场景的需求。

        由于MobileNet网络中Relu激活函数处理低维特征信息时会存在大量的丢失,所以MobileNetV2网络提出使用倒残差结构(Inverted residual block)和Linear Bottlenecks来设计网络,以提高模型的准确率,且优化后的模型更小。

        图中Inverted residual block结构是先使用1x1卷积进行升维,然后使用3x3的DepthWise卷积,最后使用1x1的卷积进行降维,与Residual block结构相反。Residual block是先使用1x1的卷积进行降维,然后使用3x3的卷积,最后使用1x1的卷积进行升维。详细内容可参见MobileNetV2论文。

        二、 数据处理

        1. 数据准备

        MobileNetV2的代码默认使用ImageFolder格式管理数据集,每一类图片整理成单独的一个文件夹, 数据集结构如下:

from download import download

# 下载data_en数据集
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/MindStudio-pc/data_en.zip" 
path = download(url, "./", kind="zip", replace=True)

# 下载预训练权重文件
url = "https://ascend-professional-construction-dataset.obs.cn-north-4.myhuaweicloud.com:443/ComputerVision/mobilenetV2-200_1067.zip" 
path = download(url, "./", kind="zip", replace=True)

        2. 数据加载

         导入MindSpore模块和辅助模块。

import math
import numpy as np
import os
import random

from matplotlib import pyplot as plt
from easydict import EasyDict
from PIL import Image
import numpy as np
import mindspore.nn as nn
from mindspore import ops as P
from mindspore.ops import add
from mindspore import Tensor
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision as C
import mindspore.dataset.transforms as C2
import mindspore as ms
from mindspore import set_context, nn, Tensor, load_checkpoint, save_checkpoint, export
from mindspore.train import Model
from mindspore.train import Callback, LossMonitor, ModelCheckpoint, CheckpointConfig

os.environ['GLOG_v'] = '3' # Log level includes 3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG).
os.environ['GLOG_logtostderr'] = '0' # 0:输出到文件,1:输出到屏幕
os.environ['GLOG_log_dir'] = '../../log' # 日志目录
os.environ['GLOG_stderrthreshold'] = '2' # 输出到目录也输出到屏幕:3(ERROR), 2(WARNING), 1(INFO), 0(DEBUG).
set_context(mode=ms.GRAPH_MODE, device_target="CPU", device_id=0) # 设置采用图模式执行,设备为Ascend#

        配置后续训练、验证、推理用到的参数:

# 垃圾分类数据集标签,以及用于标签映射的字典。
garbage_classes = {
    '干垃圾': ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服'],
    '可回收物': ['报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张'],
    '湿垃圾': ['菜叶', '橙皮', '蛋壳', '香蕉皮'],
    '有害垃圾': ['电池', '药片胶囊', '荧光灯', '油漆桶']
}

class_cn = ['贝壳', '打火机', '旧镜子', '扫把', '陶瓷碗', '牙刷', '一次性筷子', '脏污衣服',
            '报纸', '玻璃制品', '篮球', '塑料瓶', '硬纸板', '玻璃瓶', '金属制品', '帽子', '易拉罐', '纸张',
            '菜叶', '橙皮', '蛋壳', '香蕉皮',
            '电池', '药片胶囊', '荧光灯', '油漆桶']
class_en = ['Seashell', 'Lighter','Old Mirror', 'Broom','Ceramic Bowl', 'Toothbrush','Disposable Chopsticks','Dirty Cloth',
            'Newspaper', 'Glassware', 'Basketball', 'Plastic Bottle', 'Cardboard','Glass Bottle', 'Metalware', 'Hats', 'Cans', 'Paper',
            'Vegetable Leaf','Orange Peel', 'Eggshell','Banana Peel',
            'Battery', 'Tablet capsules','Fluorescent lamp', 'Paint bucket']

index_en = {'Seashell': 0, 'Lighter': 1, 'Old Mirror': 2, 'Broom': 3, 'Ceramic Bowl': 4, 'Toothbrush': 5, 'Disposable Chopsticks': 6, 'Dirty Cloth': 7,
            'Newspaper': 8, 'Glassware': 9, 'Basketball': 10, 'Plastic Bottle': 11, 'Cardboard': 12, 'Glass Bottle': 13, 'Metalware': 14, 'Hats': 15, 'Cans': 16, 'Paper': 17,
            'Vegetable Leaf': 18, 'Orange Peel': 19, 'Eggshell': 20, 'Banana Peel': 21,
            'Battery': 22, 'Tablet capsules': 23, 'Fluorescent lamp': 24, 'Paint bucket': 25}

# 训练超参
config = EasyDict({
    "num_classes": 26,
    "image_height": 224,
    "image_width": 224,
    #"data_split": [0.9, 0.1],
    "backbone_out_channels":1280,
    "batch_size": 16,
    "eval_batch_size": 8,
    "epochs": 10,
    "lr_max": 0.05,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "save_ckpt_epochs": 1,
    "dataset_path": "./data_en",
    "class_index": index_en,
    "pretrained_ckpt": "./mobilenetV2-200_1067.ckpt" # mobilenetV2-200_1067.ckpt 
})

        3. 数据预处理

        利用ImageFolderDataset方法读取垃圾分类数据集,并整体对数据集进行处理。读取数据集时指定训练集和测试集,首先对整个数据集进行归一化,修改图像频道等预处理操作。然后对训练集的数据依次进行RandomCropDecodeResize、RandomHorizontalFlip、RandomColorAdjust、shuffle操作,以增加训练数据的丰富度;对测试集进行Decode、Resize、CenterCrop等预处理操作;最后返回处理后的数据集。

def create_dataset(dataset_path, config, training=True, buffer_size=1000):
    """
    create a train or eval dataset

    Args:
        dataset_path(string): the path of dataset.
        config(struct): the config of train and eval in diffirent platform.

    Returns:
        train_dataset, val_dataset
    """
    data_path = os.path.join(dataset_path, 'train' if training else 'test')
    ds = de.ImageFolderDataset(data_path, num_parallel_workers=4, class_indexing=config.class_index)
    resize_height = config.image_height
    resize_width = config.image_width
    
    normalize_op = C.Normalize(mean=[0.485*255, 0.456*255, 0.406*255], std=[0.229*255, 0.224*255, 0.225*255])
    change_swap_op = C.HWC2CHW()
    type_cast_op = C2.TypeCast(mstype.int32)

    if training:
        crop_decode_resize = C.RandomCropDecodeResize(resize_height, scale=(0.08, 1.0), ratio=(0.75, 1.333))
        horizontal_flip_op = C.RandomHorizontalFlip(prob=0.5)
        color_adjust = C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
    
        train_trans = [crop_decode_resize, horizontal_flip_op, color_adjust, normalize_op, change_swap_op]
        train_ds = ds.map(input_columns="image", operations=train_trans, num_parallel_workers=4)
        train_ds = train_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)
        
        train_ds = train_ds.shuffle(buffer_size=buffer_size)
        ds = train_ds.batch(config.batch_size, drop_remainder=True)
    else:
        decode_op = C.Decode()
        resize_op = C.Resize((int(resize_width/0.875), int(resize_width/0.875)))
        center_crop = C.CenterCrop(resize_width)
        
        eval_trans = [decode_op, resize_op, center_crop, normalize_op, change_swap_op]
        eval_ds = ds.map(input_columns="image", operations=eval_trans, num_parallel_workers=4)
        eval_ds = eval_ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=4)
        ds = eval_ds.batch(config.eval_batch_size, drop_remainder=True)

    return ds

        展示部分处理后的数据:

ds = create_dataset(dataset_path=config.dataset_path, config=config, training=False)
print(ds.get_dataset_size())
data = ds.create_dict_iterator(output_numpy=True)._get_next()
images = data['image']
labels = data['label']

for i in range(1, 5):
    plt.subplot(2, 2, i)
    plt.imshow(np.transpose(images[i], (1,2,0)))
    plt.title('label: %s' % class_en[labels[i]])
    plt.xticks([])
plt.show()

        运行结果:

        三、 MobileNetv2模型搭建

        使用MindSpore定义MobileNetV2网络的各模块时需要继承mindspore.nn.Cell。Cell是所有神经网络(Conv2d等)的基类。

        神经网络的各层需要预先在__init__方法中定义,然后通过定义construct方法来完成神经网络的前向构造。原始模型激活函数为ReLU6,池化模块采用是全局平均池化层。

__all__ = ['MobileNetV2', 'MobileNetV2Backbone', 'MobileNetV2Head', 'mobilenet_v2']

def _make_divisible(v, divisor, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v

class GlobalAvgPooling(nn.Cell):
    """
        定义GlobalAvgPooling.
    """

    def __init__(self):
        super(GlobalAvgPooling, self).__init__()

    def construct(self, x):
        x = P.mean(x, (2, 3))
        return x

class ConvBNReLU(nn.Cell):
    """
        定义卷积/深度与Batchnorm和ReLU块
        Examples:
            >>> ConvBNReLU(16, 256, kernel_size=1, stride=1, groups=1)
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        padding = (kernel_size - 1) // 2
        in_channels = in_planes
        out_channels = out_planes
        if groups == 1:
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad', padding=padding)
        else:
            out_channels = in_planes
            conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, pad_mode='pad',
                             padding=padding, group=in_channels)

        layers = [conv, nn.BatchNorm2d(out_planes), nn.ReLU6()]
        self.features = nn.SequentialCell(layers)

    def construct(self, x):
        output = self.features(x)
        return output

class InvertedResidual(nn.Cell):
    """
        定义Mobilenetv2剩余块.
        Examples:
            >>> ResidualBlock(3, 256, 1, 1)
    """

    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        assert stride in [1, 2]

        hidden_dim = int(round(inp * expand_ratio))
        self.use_res_connect = stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
        layers.extend([
            ConvBNReLU(hidden_dim, hidden_dim,
                       stride=stride, groups=hidden_dim),
            nn.Conv2d(hidden_dim, oup, kernel_size=1,
                      stride=1, has_bias=False),
            nn.BatchNorm2d(oup),
        ])
        self.conv = nn.SequentialCell(layers)
        self.cast = P.Cast()

    def construct(self, x):
        identity = x
        x = self.conv(x)
        if self.use_res_connect:
            return P.add(identity, x)
        return x

class MobileNetV2Backbone(nn.Cell):
    """
        MobileNetV2架构.
        Args:
            class_num (int): number of classes.
            width_mult (int): 通道乘法器为8/16和其他。默认值为1
            has_dropout (bool): 是否使用dropout。默认为false
            inverted_residual_setting (list): 反向残留设置。默认为无
            round_nearest (list): 频道转向。默认为8
        Examples:
            >>> MobileNetV2(num_classes=1000)
    """

    def __init__(self, width_mult=1., inverted_residual_setting=None, round_nearest=8,
                 input_channel=32, last_channel=1280):
        super(MobileNetV2Backbone, self).__init__()
        block = InvertedResidual
        # setting of inverted residual blocks
        self.cfgs = inverted_residual_setting
        if inverted_residual_setting is None:
            self.cfgs = [
                # t, c, n, s
                [1, 16, 1, 1],
                [6, 24, 2, 2],
                [6, 32, 3, 2],
                [6, 64, 4, 2],
                [6, 96, 3, 1],
                [6, 160, 3, 2],
                [6, 320, 1, 1],
            ]

        # building first layer
        input_channel = _make_divisible(input_channel * width_mult, round_nearest)
        self.out_channels = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
        features = [ConvBNReLU(3, input_channel, stride=2)]
        # building inverted residual blocks
        for t, c, n, s in self.cfgs:
            output_channel = _make_divisible(c * width_mult, round_nearest)
            for i in range(n):
                stride = s if i == 0 else 1
                features.append(block(input_channel, output_channel, stride, expand_ratio=t))
                input_channel = output_channel
        features.append(ConvBNReLU(input_channel, self.out_channels, kernel_size=1))
        self.features = nn.SequentialCell(features)
        self._initialize_weights()

    def construct(self, x):
        x = self.features(x)
        return x

    def _initialize_weights(self):
        """
            初始化权重.   
            Examples:
                >>> _initialize_weights()
        """
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.set_data(Tensor(np.random.normal(0, np.sqrt(2. / n),
                                                          m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
            elif isinstance(m, nn.BatchNorm2d):
                m.gamma.set_data(
                    Tensor(np.ones(m.gamma.data.shape, dtype="float32")))
                m.beta.set_data(
                    Tensor(np.zeros(m.beta.data.shape, dtype="float32")))

    @property
    def get_features(self):
        return self.features

class MobileNetV2Head(nn.Cell):
    """
        MobileNetV2 architecture.
        Args:
            class_num (int): Number of classes. Default is 1000.
            has_dropout (bool): 是否使用dropout。默认为false
        Examples:
            >>> MobileNetV2(num_classes=1000)
    """

    def __init__(self, input_channel=1280, num_classes=1000, has_dropout=False, activation="None"):
        super(MobileNetV2Head, self).__init__()
        # mobilenet head
        head = ([GlobalAvgPooling(), nn.Dense(input_channel, num_classes, has_bias=True)] if not has_dropout else
                [GlobalAvgPooling(), nn.Dropout(0.2), nn.Dense(input_channel, num_classes, has_bias=True)])
        self.head = nn.SequentialCell(head)
        self.need_activation = True
        if activation == "Sigmoid":
            self.activation = nn.Sigmoid()
        elif activation == "Softmax":
            self.activation = nn.Softmax()
        else:
            self.need_activation = False
        self._initialize_weights()

    def construct(self, x):
        x = self.head(x)
        if self.need_activation:
            x = self.activation(x)
        return x

    def _initialize_weights(self):
        """
            初始化权重.
            Examples:
                >>> _initialize_weights()
        """
        self.init_parameters_data()
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Dense):
                m.weight.set_data(Tensor(np.random.normal(
                    0, 0.01, m.weight.data.shape).astype("float32")))
                if m.bias is not None:
                    m.bias.set_data(
                        Tensor(np.zeros(m.bias.data.shape, dtype="float32")))
    @property
    def get_head(self):
        return self.head

class MobileNetV2(nn.Cell):
    """
        MobileNetV2 architecture.
        Examples:
            >>> MobileNetV2(backbone, head)
    """

    def __init__(self, num_classes=1000, width_mult=1., has_dropout=False, inverted_residual_setting=None, \
        round_nearest=8, input_channel=32, last_channel=1280):
        super(MobileNetV2, self).__init__()
        self.backbone = MobileNetV2Backbone(width_mult=width_mult, \
            inverted_residual_setting=inverted_residual_setting, \
            round_nearest=round_nearest, input_channel=input_channel, last_channel=last_channel).get_features
        self.head = MobileNetV2Head(input_channel=self.backbone.out_channel, num_classes=num_classes, \
            has_dropout=has_dropout).get_head

    def construct(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x

class MobileNetV2Combine(nn.Cell):
    """
        MobileNetV2Combine architecture.
        Args:
            backbone (Cell): 特征提取层
            head (Cell):  全链接层.
        Examples:
            >>> MobileNetV2(num_classes=1000)
    """

    def __init__(self, backbone, head):
        super(MobileNetV2Combine, self).__init__(auto_prefix=False)
        self.backbone = backbone
        self.head = head

    def construct(self, x):
        x = self.backbone(x)
        x = self.head(x)
        return x

def mobilenet_v2(backbone, head):
    return MobileNetV2Combine(backbone, head)

        四、 MobileNetv2模型的训练与测试

        1. 训练策略

        一般情况下,模型训练时采用静态学习率,如0.01。随着训练步数的增加,模型逐渐趋于收敛,对权重参数的更新幅度应该逐渐降低,以减小模型训练后期的抖动。所以,模型训练时可以采用动态下降的学习率,常见的学习率下降策略有:

        · polynomial decay/square decay

        · cosine decay

        · exponential decay

        · stage decay

        这里使用cosine decay下降策略:

def cosine_decay(total_steps, lr_init=0.0, lr_end=0.0, lr_max=0.1, warmup_steps=0):
    """
        应用余弦衰减生成学习率数组.
        Args:
           total_steps(int): 训练的所有步骤.
           lr_init(float): 初始学习率.
           lr_end(float): 最终学习率
           lr_max(float): 最大学习率.
           warmup_steps(int): warmup阶段的所有步骤.
        Returns:
           list, learning rate array.
    """
    lr_init, lr_end, lr_max = float(lr_init), float(lr_end), float(lr_max)
    decay_steps = total_steps - warmup_steps
    lr_all_steps = []
    inc_per_step = (lr_max - lr_init) / warmup_steps if warmup_steps else 0
    for i in range(total_steps):
        if i < warmup_steps:
            lr = lr_init + inc_per_step * (i + 1)
        else:
            cosine_decay = 0.5 * (1 + math.cos(math.pi * (i - warmup_steps) / decay_steps))
            lr = (lr_max - lr_end) * cosine_decay + lr_end
        lr_all_steps.append(lr)

    return lr_all_steps

        在模型训练过程中,可以添加检查点(Checkpoint)用于保存模型的参数,以便进行推理及中断后再训练使用。使用场景如下:

        · 训练后推理场景

        1. 模型训练完毕后保存模型的参数,用于推理或预测操作。

        2. 训练过程中,通过实时验证精度,把精度最高的模型参数保存下来,用于预测操作。

        · 再训练场景

        1. 进行长时间训练任务时,保存训练过程中的Checkpoint文件,防止任务异常退出后从初始状态开始训练。

        2. Fine-tuning(微调)场景,即训练一个模型并保存参数,基于该模型,面向第二个类似任务进行模型训练。

        这里加载ImageNet数据上预训练的MobileNetv2进行Fine-tuning,只训练最后修改的FC层,并在训练过程中保存Checkpoint。

def switch_precision(net, data_type):
    if ms.get_context('device_target') == "Ascend":
        net.to_float(data_type)
        for _, cell in net.cells_and_names():
            if isinstance(cell, nn.Dense):
                cell.to_float(ms.float32)

        2. 模型训练与测试

        在进行正式的训练之前,定义训练函数,读取数据并对模型进行实例化,定义优化器和损失函数。

        首先简单介绍损失函数及优化器的概念:

        · 损失函数:又叫目标函数,用于衡量预测值与实际值差异的程度。深度学习通过不停地迭代来缩小损失函数的值。定义一个好的损失函数,可以有效提高模型的性能。

        · 优化器:用于最小化损失函数,从而在训练过程中改进模型。

        定义了损失函数后,可以得到损失函数关于权重的梯度。梯度用于指示优化器优化权重的方向,以提高模型性能。

        在训练MobileNetV2之前对MobileNetV2Backbone层的参数进行了固定,使其在训练过程中对该模块的权重参数不进行更新;只对MobileNetV2Head模块的参数进行更新。

        MindSpore支持的损失函数有SoftmaxCrossEntropyWithLogits、L1Loss、MSELoss等。这里使用SoftmaxCrossEntropyWithLogits损失函数。

        训练测试过程中会打印loss值,loss值会波动,但总体来说loss值会逐步减小,精度逐步提高。每个人运行的loss值有一定随机性,不一定完全相同。每打印一个epoch后模型都会在测试集上的计算测试精度,从打印的精度值分析MobileNetV2模型的预测能力在不断提升。

from mindspore.amp import FixedLossScaleManager
import time
LOSS_SCALE = 1024

train_dataset = create_dataset(dataset_path=config.dataset_path, config=config)
eval_dataset = create_dataset(dataset_path=config.dataset_path, config=config)
step_size = train_dataset.get_dataset_size()
    
backbone = MobileNetV2Backbone() #last_channel=config.backbone_out_channels
# Freeze parameters of backbone. You can comment these two lines.
for param in backbone.get_parameters():
    param.requires_grad = False
# load parameters from pretrained model
load_checkpoint(config.pretrained_ckpt, backbone)

head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
network = mobilenet_v2(backbone, head)

# define loss, optimizer, and model
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(LOSS_SCALE, drop_overflow_update=False)
lrs = cosine_decay(config.epochs * step_size, lr_max=config.lr_max)
opt = nn.Momentum(network.trainable_params(), lrs, config.momentum, config.weight_decay, loss_scale=LOSS_SCALE)

# 定义用于训练的train_loop函数。
def train_loop(model, dataset, loss_fn, optimizer):
    # 定义正向计算函数
    def forward_fn(data, label):
        logits = model(data)
        loss = loss_fn(logits, label)
        return loss

    # 定义微分函数,使用mindspore.value_and_grad获得微分函数grad_fn,输出loss和梯度。
    # 由于是对模型参数求导,grad_position 配置为None,传入可训练参数。
    grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters)

    # 定义 one-step training函数
    def train_step(data, label):
        loss, grads = grad_fn(data, label)
        optimizer(grads)
        return loss

    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)

        if batch % 10 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

# 定义用于测试的test_loop函数。
def test_loop(model, dataset, loss_fn):
    num_batches = dataset.get_dataset_size()
    model.set_train(False)
    total, test_loss, correct = 0, 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred = model(data)
        total += len(data)
        test_loss += loss_fn(pred, label).asnumpy()
        correct += (pred.argmax(1) == label).asnumpy().sum()
    test_loss /= num_batches
    correct /= total
    print(f"Test: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

print("============== Starting Training ==============")
# 由于时间问题,训练过程只进行了2个epoch ,可以根据需求调整。
epoch_begin_time = time.time()
epochs = 2
for t in range(epochs):
    begin_time = time.time()
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(network, train_dataset, loss, opt)
    ms.save_checkpoint(network, "save_mobilenetV2_model.ckpt")
    end_time = time.time()
    times = end_time - begin_time
    print(f"per epoch time: {times}s")
    test_loop(network, eval_dataset, loss)
epoch_end_time = time.time()
times = epoch_end_time - epoch_begin_time
print(f"total time:  {times}s")
print("============== Training Success ==============")

        运行结果:

        五、 模型推理

        加载模型Checkpoint进行推理,使用load_checkpoint接口加载数据时,需要把数据传入给原始网络,而不能传递给带有优化器和损失函数的训练网络。

CKPT="save_mobilenetV2_model.ckpt"
def image_process(image):
    """Precess one image per time.
    
    Args:
        image: shape (H, W, C)
    """
    mean=[0.485*255, 0.456*255, 0.406*255]
    std=[0.229*255, 0.224*255, 0.225*255]
    image = (np.array(image) - mean) / std
    image = image.transpose((2,0,1))
    img_tensor = Tensor(np.array([image], np.float32))
    return img_tensor

def infer_one(network, image_path):
    image = Image.open(image_path).resize((config.image_height, config.image_width))
    logits = network(image_process(image))
    pred = np.argmax(logits.asnumpy(), axis=1)[0]
    print(image_path, class_en[pred])

def infer():
    backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)
    head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
    network = mobilenet_v2(backbone, head)
    load_checkpoint(CKPT, network)
    for i in range(91, 100):
        infer_one(network, f'data_en/test/Cardboard/000{i}.jpg')
infer()

        运行结果:

        六、 导出AIR/GEIR/ONNX模型文件

        导出AIR模型文件,用于后续Atlas 200 DK上的模型转换与推理。当前仅支持MindSpore+Ascend环境。

backbone = MobileNetV2Backbone(last_channel=config.backbone_out_channels)
head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=config.num_classes)
network = mobilenet_v2(backbone, head)
load_checkpoint(CKPT, network)

input = np.random.uniform(0.0, 1.0, size=[1, 3, 224, 224]).astype(np.float32)
# export(network, Tensor(input), file_name='mobilenetv2.air', file_format='AIR')
# export(network, Tensor(input), file_name='mobilenetv2.pb', file_format='GEIR')
export(network, Tensor(input), file_name='mobilenetv2.onnx', file_format='ONNX')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1958953.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis 7.x 系列【34】Spring Boot 集成

有道无术&#xff0c;术尚可求&#xff0c;有术无道&#xff0c;止于术。 本系列Redis 版本 7.2.5 源码地址&#xff1a;https://gitee.com/pearl-organization/study-redis-demo 文章目录 1. 前言2. Spring Data Redis3. Spring Boot Data Redis Starter3.1 起步依赖3.2 自动…

精品PPT | 微信云原生大数据平台构建及落地实践.pptx

一、大数据上云概述 1.为什么大数据要上云 2.微信大数据平台架构演进 二、大数据上云基础建设 1.统一编排 2.Pod 设计及大数据配套能力 3.计算组件云环境适配 三、稳定性及效率提升 1.K8S 集群稳定性与弹性配额 2.可观测性与智能运维

Java学习笔记(六)面向对象编程(基础部分)

Hi i,m JinXiang ⭐ 前言 ⭐ 本篇文章主要介绍Java面向对象编程&#xff08;基础部分&#xff09;类与对象、方法重载、作用域、构造器细节、this关键字、可变参数使用以及部分理论知识 &#x1f349;欢迎点赞 &#x1f44d; 收藏 ⭐留言评论 &#x1f4dd;私信必回哟&#x1…

C# 12 新增功能实操!

前言 今天咱们一起来探索并实践 C# 12 引入的全新功能&#xff01; C#/.NET该如何自学入门&#xff1f; 注意&#xff1a;使用这些功能需要使用最新的 Visual Studio 2022 版本或安装 .NET 8 SDK 。 主构造函数 主构造函数允许你直接在类定义中声明构造函数参数&#xff0c;…

停车共享小程序的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;停车场管理&#xff0c;停车预约管理&#xff0c;停车缴费管理&#xff0c;系统管理 微信端账号功能包括&#xff1a;系统首页&#xff0c;停车场&#xff0c;我的 开发系统&…

主图趋势交易九稳量化系统 期货指标公式大全 最准的期货指标源码 看期货涨跌最简单的方法文华财经指标公式源码

交易的动机必须来自于内心&#xff0c;一种解决问题的执着。在整个交易生涯的漫长岁月里&#xff0c;无法始终保持这种热忱。除非亲身体验&#xff0c;否则很难理解这种疯狂的热忱。这是一种高度的专注&#xff0c;其他一切好像都不存在&#xff0c;视野之内没有其他的东西。这…

【STL】之 vector 使用方法及模拟实现

前言&#xff1a; 本文主要讲在C STL库中vector容器的使用方法和底层的模拟实现~ 成员变量的定义&#xff1a; 对于vector容器&#xff0c;我们首先采用三个成员变量去进行定义&#xff0c;分别是&#xff1a; private:iterator _start; // 指向数据块的开始iterator _finish…

论文解读(10)-图神经网络

加油&#xff0c;继续看论文。 这次学图神经网络&#xff0c;这个概念经常在其他论文里出现&#xff0c;所以我想先学习一下这方面的知识。 参考&#xff1a; 【图神经网络综述】一文道尽GNN原理、框架和应用-CSDN博客 【图神经网络】10分钟掌握图神经网络及其经典模型_图神经…

网络爬虫必备工具:代理IP科普指南

文章目录 1. 网络爬虫简介1.1 什么是网络爬虫&#xff1f;1.2 网络爬虫的应用领域1.3 网络爬虫面临的主要挑战 2. 代理IP&#xff1a;爬虫的得力助手2.1 代理IP的定义和工作原理2.2 为什么爬虫需要代理IP&#xff1f;2.3 代理IP如何解决爬虫的常见问题&#xff1f; 3. 代理IP的…

shapeit填充

使用shapeit软件进行填充 一&#xff0c;安装 下载地址&#xff0c; 官网里面写得很详细。 https://mathgen.stats.ox.ac.uk/genetics_software/shapeit/shapeit.html 二&#xff0c;步骤 官网里面每一个参数都很详细 1.拆分染色体 for chr in {1..24}; do plink --vcf /…

基于dcm4chee搭建的PACS系统讲解(三)服务端使用Rest API获取study等数据

文章目录 DICOMWeb Support模块主要数据结构ER查询信息基本信息metadata信息统计信息 实践查询API及参数解析API返回的json数组定义VRObjectNodeObjectMapper解析显示指定tag并解析 后记 前期预研的PACS系统&#xff0c;近期要在项目中上线了。因为PACS系统采用无权限认证&…

Embeddings 赋能 - AI智能匹配,呈现精准内容

&#x1f680;前言 在当今的 AI 时代,传统的相关内容推荐和搜索功能已经显得相对简单和低效。借助 AI 技术,我们可以实现更加智能化和个性化的内容发现体验。 本文将为大家介绍如何利用 OpenAI 的 Embedding 技术,打造出智能、高效的相关内容推荐和搜索功能。 &#x1f680;…

UCOS-III 互斥锁接口详解

在实时操作系统uC/OS-III中&#xff0c;互斥锁&#xff08;Mutex&#xff09;是一种用于管理对共享资源的访问的同步机制。互斥锁通过保证在任何时刻只有一个任务可以持有锁&#xff0c;从而防止资源竞争问题。同时&#xff0c;uC/OS-III还实现了递归锁定和优先级继承机制&…

七款公司常用的加密软件推荐|2024年公司办公加密软件推荐

在现代企业中&#xff0c;加密软件是保护敏感信息、防止数据泄露和确保通信安全的关键工具。加密软件能够对数据进行加密&#xff0c;使其在未经授权的情况下无法被读取或篡改&#xff0c;本文分享七款加密软件&#xff0c;它们各具特色&#xff0c;能够满足不同的安全需求。 1…

狂赚又吸金 身心灵赛道AI玩法全解析

想必很多初入AI的小白们&#xff0c;小白的不能在小白了&#xff0c;因为在他们眼中&#xff0c;确实对AI一无所知。 基于他们平时刷抖音、刷视频号的习惯&#xff0c;有的时候会发一些传统剪辑的作品&#xff0c;问AI怎么做&#xff1f;很多人认为AI所见的视频&#xff0c;AI…

GeoServer发布MongoDB中的shp数据全流程梳理

目录 前言1.shp转geojson2.shp导入MongoDB3.创建空间索引4.GeoServer安装MongoDB插件5.发布6.注意事项6.1 geojson要去掉头尾6.2 MongoDB4.4以上的mongoimport工具需要额外安装6.3 空间索引是必须项 7.总结 前言 网上搜到的GeoServer发布MongoDB中的矢量数据或shp数据的文章比较…

http协议与nginx

动态页面与静态页面的差别&#xff1a; &#xff08;1&#xff09;URL不同 静态⻚⾯链接⾥没有“?” 动态⻚⾯链接⾥包含“&#xff1f;” &#xff08;2&#xff09;后缀不同 (开发语⾔不同) 静态⻚⾯⼀般以 .html .htm .xml 为后缀 动态⻚⾯⼀般以 .php .jsp .py等为后…

我国工业大模型发展中的四个反差现象

以大模型为代表的新一代人工智能技术正加速推进新型工业化的变革进程。2024年1月&#xff0c;国务院常务会议研究部署推动人工智能赋能新型工业化有关工作&#xff0c;强调以人工智能和制造业深度融合为主线&#xff0c;以智能制造为主攻方向&#xff0c;以场景应用为牵引&…

【Git从入门到精通】——知识概述及Git安装

&#x1f3bc;个人主页&#xff1a;【Y小夜】 &#x1f60e;作者简介&#xff1a;一位双非学校的大二学生&#xff0c;编程爱好者&#xff0c; 专注于基础和实战分享&#xff0c;欢迎私信咨询&#xff01; &#x1f386;入门专栏&#xff1a;&#x1f387;【MySQL&#xff0…

Google Test的使用

Google Test支持的操作系统包含下面这些&#xff1a; 1、Linux 2、Mac OS X 3、Windows 4、Cygwin 5、MinGW 6、Windows Mobile 7、Symbian一、google test的基本使用步骤 1、包含gtest/gtest.h头文件 2、使用TEST()宏定义测试case 3、在测试体中使用gooletest断言进行值检查…