昇思25天学习打卡营第09天 | 保存与加载
在训练网络模型的过程中,通常希望保存中间状态和最后的结果,用于后续的模型微调、推理和部署。
文章目录
- 昇思25天学习打卡营第09天 | 保存与加载
- 定义网络
- 保存模型
- 加载模型
- 保存MindIR
- 加载MindIR
- 总结
- 打卡
定义网络
def network():
model = nn.SequentialCell(
nn.Flatten(),
nn.Dense(28*28, 512),
nn.ReLU(),
nn.Dense(512, 512),
nn.ReLU(),
nn.Dense(512, 10))
return model
model = network()
保存模型
通过save_checkpoint()
接口,传入网络和路径来保存模型:
mindspore.save_checkpoint(model, "model.ckpt")
加载模型
加载模型时首先需要创建相同的模型实例,然后通过load_checkpoint()
和load_param_into_net()
方法加载参数:
model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
保存MindIR
MindSpore提供云端(训练)和端侧(推理)统一的中间表示(Intermediate Representaton,IR)。
使用export
接口将模型保存为MindIR。
MindIR
同时保存Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。
model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")
加载MindIR
MindIR可以通过load
接口加载,传入nn.GraphCell
即可进行推理。
nn.GraphCell
仅支持图模式。
mindspore.set_context(mode=mindspore.GRAPH_MODE)
graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)
总结
这一节的内容对模型的保存和加载接口进行了介绍。通过save_checkpoint
将模型保存为.ckpt
格式,通过load_checkpoint
和load_param_into_net
将.ckpt
文件重新加载到网络中进行后续操作。此外,还可以通过export
接口将模型保存为MindIR
格式,同时保存Checkpoint
和模型结果,然后通过load
加载到nn.GraphCell
中即可进行推理。