昇思MindSpore学习总结八——模型保存与加载

news2024/12/22 16:56:27

        在训练网络模型的过程中,实际上我们希望保存中间和最后的结果,用于微调(fine-tune)和后续的模型推理与部署,接下来将介绍如何保存与加载模型。

1.构建模型

import numpy as np
import mindspore
from mindspore import nn
from mindspore import Tensor

def network():
    model = nn.SequentialCell(
                nn.Flatten(),
                nn.Dense(28*28, 512),
                nn.ReLU(),
                nn.Dense(512, 512),
                nn.ReLU(),
                nn.Dense(512, 10))
    return model

这里是没有经过训练的,可以直接用上一节训练的模型model。

2、保存和加载权重

2.1 保存

保存模型使用save_checkpoint接口,传入网络和指定的保存路径:

mindspore.save_checkpoint(save_objckpt_file_nameintegrated_save=Trueasync_save=Falseappend_dict=Noneenc_key=Noneenc_mode='AES-GCM'choice_func=None**kwargs)

【参数】

  • save_obj (Union[Cell, list, dict]) - 待保存的对象。数据类型可为 mindspore.nn.Cell 、list或dict。若为list,可以是 Cell.trainable_params() 的返回值,或元素为dict的列表(如[{“name”: param_name, “data”: param_data},…],param_name 的类型必须是str,param_data 的类型必须是Parameter或者Tensor);若为dict,可以是 mindspore.load_checkpoint() 的返回值。

  • ckpt_file_name (str) - checkpoint文件名称。如果文件已存在,将会覆盖原有文件。

  • integrated_save (bool) - 在并行场景下是否合并保存拆分的Tensor。默认值: True 。

  • async_save (bool) - 是否异步执行保存checkpoint文件。默认值: False 。

  • append_dict (dict) - 需要保存的其他信息。dict的键必须为str类型,dict的值类型必须是int、float、bool、string、Parameter或Tensor类型。默认值: None 。

  • enc_key (Union[None, bytes]) - 用于加密的字节类型密钥。如果值为 None ,那么不需要加密。默认值: None 。

  • enc_mode (str) - 该参数在 enc_key 不为 None 时有效,指定加密模式,目前仅支持 "AES-GCM" , "AES-CBC" 和 "SM4-CBC" 。默认值: "AES-GCM" 。

  • choice_func (function) - 一个用于自定义控制保存参数的函数。函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回 True ,则匹配自定义条件的Parameter将被保存。 如果返回 False ,则未匹配自定义条件的Parameter不会被保存。默认值: None 。

  • kwargs (dict) - 配置选项字典。

model = network()
mindspore.save_checkpoint(model, "model.ckpt")

运行之后会在同路径下找到一个文件

2.2 加载

        要加载模型权重,需要先创建相同模型的实例,然后使用load_checkpointload_param_into_net方法加载参数。

2.2.1 load_checkpoint

mindspore.load_checkpoint(ckpt_file_namenet=Nonestrict_load=Falsefilter_prefix=Nonedec_key=Nonedec_mode='AES-GCM'specify_prefix=Nonechoice_func=None)

【参数】

  • ckpt_file_name (str) - checkpoint的文件名称。

  • net (Cell) - 加载checkpoint参数的网络。默认值: None 。

  • strict_load (bool) - 是否将严格加载参数到网络中。如果是 False ,它将根据相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行强制精度转换,比如将 float32 转换为 float16 。默认值: False 。

  • filter_prefix (Union[str, list[str], tuple[str]]) - 废弃(请参考参数 choice_func)。以 filter_prefix 开头的参数将不会被加载。默认值: None 。

  • dec_key (Union[None, bytes]) - 用于解密的字节类型密钥,如果值为 None ,则不需要解密。默认值: None 。

  • dec_mode (str) - 该参数仅当 dec_key 不为 None 时有效。指定解密模式,目前支持 "AES-GCM" , "AES-CBC" 和 "SM4-CBC" 。默认值: "AES-GCM" 。

  • specify_prefix (Union[str, list[str], tuple[str]]) - 废弃(请参考参数 choice_func)。以 specify_prefix 开头的参数将会被加载。默认值: None 。

  • choice_func (Union[None, function]) - 函数的输入值为字符串类型的Parameter名称,并且返回值是一个布尔值。如果返回 True ,则匹配自定义条件的Parameter将被加载。 如果返回 False ,则匹配自定义条件的Parameter将被删除。默认值: None 。

2.2.2 load_param_into_net

mindspore.load_param_into_net(netparameter_dictstrict_load=False)将参数加载到网络中,返回网络中没有被加载的参数列表。

【参数】

  • net (Cell) - 将要加载参数的网络。

  • parameter_dict (dict) - 加载checkpoint文件得到的字典。

  • strict_load (bool) - 是否将参数严格加载到网络中。如果是 False , 它将以相同的后缀名将参数字典中的参数加载到网络中,并会在精度不匹配时,进行精度转换,比如将 float32 转换为 float16 。默认值: False 。

model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

【运行结果】

param_not_load是未被加载的参数列表,为空时代表所有参数均加载成功。 

3、保存和加载MindIR

        除Checkpoint外,MindSpore提供了云侧(训练)和端侧(推理)统一的中间表示(Intermediate Representation,IR)。可使用export接口直接将模型保存为MindIR。

3.1 保存

mindspore.export(net*inputsfile_namefile_format**kwargs)将MindSpore网络模型导出为指定格式的文件。

【参数】

  • net (Union[Cell, function]) - MindSpore网络结构。

  • inputs (Union[Tensor, Dataset, List, Tuple, Number, Bool]) - 网络的输入,如果网络有多个输入,需要一同传入。当传入的类型为 Dataset 时,将会把数据预处理行为同步保存起来。需要手动调整batch的大小,当前仅支持获取 Dataset 的 image 列。

  • file_name (str) - 导出模型的文件名称。

  • file_format (str) - MindSpore目前支持导出”AIR”,”ONNX”和”MINDIR”格式的模型。

    • AIR - Ascend Intermediate Representation。一种Ascend模型的中间表示格式。推荐的输出文件后缀是”.air”。

    • ONNX - Open Neural Network eXchange。一种针对机器学习所设计的开放式的文件格式。推荐的输出文件后缀是”.onnx”。

    • MINDIR - MindSpore Native Intermediate Representation for Anf。一种MindSpore模型的中间表示格式。推荐的输出文件后缀是”.mindir”。

  • kwargs (dict) - 配置选项字典。

    • enc_key (byte) - 用于加密的字节类型密钥,有效长度为16、24或者32。

    • enc_mode (Union[str, function]) - 指定加密模式,当设置 enc_key 时启用。

      • 对于 ‘AIR’和 ‘ONNX’格式的模型,当前仅支持自定义加密导出。

      • 对于 ‘MINDIR’格式的模型,支持的加密选项有: ‘AES-GCM’, ‘AES-CBC’, ‘SM4-CBC’和用户自定义加密算法。默认值: "AES-GCM"

      • 关于使用自定义加密导出的详情,请查看 教程。

    • dataset (Dataset) - 指定数据集的预处理方法,用于将数据集的预处理导入MindIR。

    • obf_config (dict) - 模型混淆配置选项字典。

      • type (str) - 混淆类型,目前支持动态混淆,即 ‘dynamic’ 。

      • obf_ratio (Union[str, float]) - 全模型算子的混淆比例,可取浮点数(0, 1]或者字符串 "small" 、 "medium" 、 "large" 。"small" 、"medium" 、"large" 分别对应于 0.1、0.3、0.6。

      • customized_func (function) - 在自定义函数模式下需要设置的Python函数,用来控制混淆结构中的选择分支走向。它的返回值需要是bool类型,且是恒定的,用户可以参考不透明谓词进行设置(请查看 动态混淆教程 中的 my_func())。如果设置了 customized_func ,那么在使用 load 接口导入模型的时候,需要把这个函数也传入。

      • obf_random_seed (int) - 混淆随机种子,是一个取值范围为(0, 9223372036854775807]的整数,不同的随机种子会使模型混淆后的结构不同。如果用户设置了 obf_random_seed ,那么在部署混淆模型的时候,需要在调用 mindspore.nn.GraphCell 接口中传入 obf_random_seed 。需要注意的是,如果用户同时设置了 customized_func 和 obf_random_seed ,那么后一种模式将会被采用。

    • custom_func (function) - 用户自定义的导出策略的函数。该函数会在网络导出时,对模型使用该函数进行自定义处理。需要注意,当前仅支持对 format 为 MindIR 的文件使用 custom_func ,且自定义函数仅接受一个代表 MindIR 文件 Proto 对象的入参。当使用 custom_func 对模型进行修改时,需要保证修改后模型的正确性,否则可能导致模型加载失败或功能错误。默认值: None 。

model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

MindIR同时保存了Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。

3.2 加载

已有的MindIR模型可以方便地通过load接口加载,传入nn.GraphCell即可进行推理。nn.GraphCell仅支持图模式。

mindspore.nn.GraphCell(graphparams_init=Noneobf_random_seed=None)

运行从MindIR加载的计算图。

此功能仍在开发中。目前 GraphCell 不支持修改图结构,在导出MindIR时只能使用shape和类型与输入相同的数据。

【参数】

  • graph (FuncGraph) - 从MindIR加载的编译图。

  • params_init (dict) - 需要在图中初始化的参数。key为参数名称,类型为字符串,value为 Tensor 或 Parameter。如果参数名在图中已经存在,则更新其值;如果不存在,则忽略。默认值: None 。

  • obf_random_seed (Union[int, None]) - 用于动态混淆保护的混淆随机种子。动态混淆是一种模型保护方法,可以参考 mindspore.obfuscate_model() 。如果导入的 graph 是一个经过混淆的模型,那么须提供 obf_random_seed 。 obf_random_seed 的取值范围是(0, 9223372036854775807]。默认值: None 。

mindspore.set_context(mode=mindspore.GRAPH_MODE)

graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)

 这里时间改了一下,之前差8个小时。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1885282.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

别再被大模型骗了,一个小技巧,让LLaMa3诚信度提升65%

人工智能正以惊人的速度发展,大语言模型(LLM)作为其中的"明星",展现了令人赞叹的语言理解和生成能力。然而,在享受大语言模型带来便利的同时,我们也必须正视其在诚实性和安全性方面所面临的挑战。 近期,华中…

CSF视频文件格式转换WMV格式(2024年可用)

如果大家看过一些高校教学讲解视频的话,很可能见过这样一个难得的格式,".csf ",非常漂亮 。 用暴风影音都可以打开观看,会自动下载解码。 但是一旦我们想要利用或者上传视频的时候就麻烦了,一般网站不认这…

3个企业级最佳实践,教你ByteHouse云数仓这么用

随着各业务场景各行业数字化转型加快,数据量呈爆炸式增长。在拥有庞大数据的同时,业务也在分析、查询与响应层面,对数据库系统性能提出了更高要求。云原生技术推动了分布式数据库系统的迭代升级,对云数仓技术而言,“写…

MacBook关闭谷歌浏览器双指左右移动(扫动)前进后退功能

这个功能真的很反人类,正常上下滑动页面的时候很容易误操作,尤其是当你在一个页面上做了很多的编辑工作后误触发了此手势,那真叫一个崩溃! 其实这应该是 Macbook 触控板提供的一个快捷操作,跟浏览器本身估计没关系&am…

mysql-sql-第十三周

学习目标: sql 学习内容: 37.查询各科成绩最高分、最低分和平均分: 以如下形式显示:课程 ID,课程 name,最高分,最低分,平均分,及格率,中等率,优良率,优秀率 及格为>60,中等为:70-80,优良为:80-90,优秀…

使用Comsol进行边坡稳定性分析的例子——详细步骤(第二部分)

使用Comsol进行边坡稳定性分析的例子——详细步骤 研究1方法结果书接上回 在FOS参数的帮助下,对材料强度进行参数化。在第二个研究步骤中添加 FOS 的辅助扫描。对于某些 FOS 值,解不会收敛,并且设置为最后一个 FOS 值的默认图将给出错误。禁用此研究的默认绘图以避免出现错误…

65、基于卷积神经网络的调制分类(matlab)

1、基于卷积神经网络的调制分类的原理及流程 基于卷积神经网络(CNN)的调制分类是一种常见的信号处理任务,用于识别或分类不同调制方式的信号。下面是基于CNN的调制分类的原理和流程: 原理: CNN是一种深度学习模型&a…

root密码忘了怎么办(从系统引导过程解决)

目录 1.Linux系统密码忘记 2.系统引导过程 2.1 systemd 2.2 GRUB和GRUB2 2.3 运行级别 3.修复MBR扇区故障和GRUB引导故障 3.1 MBR扇区故障 3.2 GRUB引导故障 1.Linux系统密码忘记 我们在生活中经常遇到这类困扰,就是某个账号还是账户密码忘了,这…

Llama也能做图像生成?文生图模型已开源

导读 基于next-token prediction的图像生成方法首次在ImageNet benchmark超越了LDM, DiT等扩散模型,证明了最原始的自回归模型架构同样可以实现极具竞争力的图像生成性能。 Llama也能做图像生成?文生图模型已开源 香港大学、字节跳动提出了基于自回归模…

【AI大模型】大型模型飞跃升级—文档图像识别领域迎来技术巨变_图像识别大模型

写在前面 2023年12月31日,第十九届中国图象图形学学会青年科学家会议在广州举行,由中国图象图形学学会主办。 该会议的目标是促进青年科学家之间的交流与合作,以提升我国在图像图形领域的科研水平和创新能力。 由中国图象图形学学会和上海合合…

如何将音频文件发送至摄像头

目前再很多互联互通的场景下,如AI盒子再从摄像头上取视频分析,分析出发生某个事件,需要反向通过摄像头的喇叭播放语音,发出告警提示,使用场景如下 盒子上对于此类场景的需求往往不能满足,或者为这个需求需要…

Day8: 232.用栈实现队列 225. 用队列实现栈 20. 有效的括号 1047. 删除字符串中的所有相邻重复项

题目232. 用栈实现队列 - 力扣(LeetCode) class MyQueue { public:MyQueue() {}void push(int x) { // 出栈input.push(x);}int pop() {// 如果出栈为空,把入栈元素全都转移到出栈if (output.empty()) {while (!input.empty()) {int itop i…

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第52课-语音控制机器人

【WEB前端2024】3D智体编程:乔布斯3D纪念馆-第52课-语音控制机器人 使用dtns.network德塔世界(开源的智体世界引擎),策划和设计《乔布斯超大型的开源3D纪念馆》的系列教程。dtns.network是一款主要由JavaScript编写的智体世界引擎…

彭涛 | 2024年6月小结

6月是忙碌的一个月,换办公室,买家具,群发售,新小伙伴入职等等 1、出海小报童 这个月时间主要做小报童,从刚开始设计内容大纲,到写作,后续拉新花费了大量时间。 比如我们要去调研同行&#xff0c…

新能源行业必会基础知识-----电力市场概论笔记-----中长期合约电力市场

新能源行业知识体系-------主目录-----持续更新(进不去说明我没写完):https://blog.csdn.net/grd_java/article/details/139946830 目录 1. 合约市场2. 双边交易3. 集中交易4. 挂牌交易及互联网中长期电力交易平台5. 中长期交易的优势 1. 合约市场 什么是合约市场 …

从选题到定稿:软考高级系统架构设计师论文写作全攻略

一、论文考试概述 软考系统架构设计师考试的最后一门是论文写作,安排在下午进行,时长两小时,要求撰写约3000字的论文,以45分为及格线。时间紧迫,不容过多犹豫与思考,因此需迅速选定并着手撰写。论文题目通…

【数据结构】C语言实现二叉树

C语言实现二叉树 导读一、二叉树的数据类型二、二叉树的初始化2.1 补充知识点——传址传参2.2 补充知识点——指针传参 三、二叉树的创建3.1 通过添加结点创建BST3.2 通过结点序列创建二叉树3.2.1 由遍历序列手算构建二叉树3.2.1.1 构建步骤3.2.1.2 习题演练3.2.1.3 小结 3.2.2…

在C#/Net中使用Mqtt

net中MQTT的应用场景 c#常用来开发上位机程序,或者其他一些跟设备打交道比较多的系统,所以会经常作为拥有数据的终端,可以用来采集上传数据,而MQTT也是物联网常用的协议,所以下面介绍在C#开发中使用MQTT。 安装MQTTn…

yolov5实例分割跑通以及C#读取yolov5_Seg实例分割转换onnx进行检测部署

一、首先需要训练yolov5_seg的模型,可以去网上学习,或者你直接用我的, 训练环境和yolov5—7.0的环境一样,你可以直接拷过来用。 yolov5_seg算法 链接:https://pan.baidu.com/s/1m-3lFWRHwg5t8MmIOKm4FA 提取码&…

第十四届蓝桥杯省赛C++B组D题【飞机降落】题解(AC)

解题思路 这道题目要求我们判断给定的飞机是否都能在它们的油料耗尽之前降落。为了寻找是否存在合法的降落序列,我们可以使用深度优先搜索(DFS)的方法,尝试所有可能的降落顺序。 首先,我们需要理解题目中的条件。每架…