mindspore框架保存及加载模型
-
详细流程:昇思-保存及加载模型
-
关键步骤
-
关键代码
from mindspore import export, load_checkpoint, load_param_into_net
from mindspore import Tensor
import numpy as np
from MobileNet2GarbageCls.MobileNetv2 import *
# 有了CheckPoint文件后,可导出模型file_format='MINDIR','AIR','ONNX'
backbone = MobileNetV2Backbone() # last_channel=config.backbone_out_channels
head = MobileNetV2Head(input_channel=backbone.out_channels, num_classes=7) # num_classes=7
network = mobilenet_v2(backbone, head) # 我的模型
pretrained_ckpt = './sleepClassify/Models/save_mobilenetV2_sleepCls0.824.ckpt' # 睡岗分类
load_checkpoint(pretrained_ckpt, network) # 将参数加载到网络中
# 1. 导出MindIR格式
input = np.random.uniform(0.0, 1.0, size=[32, 3, 224, 224]).astype(np.float32)
# export(network, Tensor(input), file_name='MobileNet2_sleep_0824', file_format='MINDIR')
# 2. 导出AIR格式:ValueError: Only support export file in 'AIR' format with Ascend backend.
# export(network, Tensor(input), file_name='MobileNet2_sleep_082', file_format='AIR')
# 3. 导出ONNX格式;mindspore框架目前ONNX格式导出仅支持ResNet系列、BERT网络。
export(network, Tensor(input), file_name='MobileNet2_sleep_0824', file_format='ONNX')