昇思25天学习打卡营第09天 | 保存与加载

news2024/10/6 0:58:11

昇思25天学习打卡营第09天 | 保存与加载

在训练网络模型的过程中,通常希望保存中间状态和最后的结果,用于后续的模型微调、推理和部署。

文章目录

  • 昇思25天学习打卡营第09天 | 保存与加载
    • 定义网络
    • 保存模型
    • 加载模型
    • 保存MindIR
    • 加载MindIR
    • 总结
    • 打卡

定义网络

def network():
   model = nn.SequentialCell(
               nn.Flatten(),
               nn.Dense(28*28, 512),
               nn.ReLU(),
               nn.Dense(512, 512),
               nn.ReLU(),
               nn.Dense(512, 10))
   return model

model = network()

保存模型

通过save_checkpoint()接口,传入网络和路径来保存模型:

mindspore.save_checkpoint(model, "model.ckpt")

加载模型

加载模型时首先需要创建相同的模型实例,然后通过load_checkpoint()load_param_into_net()方法加载参数:

model = network()
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)

保存MindIR

MindSpore提供云端(训练)和端侧(推理)统一的中间表示(Intermediate Representaton,IR)。
使用export接口将模型保存为MindIR。

  • MindIR同时保存Checkpoint和模型结构,因此需要定义输入Tensor来获取输入shape。
model = network()
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")

加载MindIR

MindIR可以通过load接口加载,传入nn.GraphCell即可进行推理。

nn.GraphCell仅支持图模式。

mindspore.set_context(mode=mindspore.GRAPH_MODE)

graph = mindspore.load("model.mindir")
model = nn.GraphCell(graph)
outputs = model(inputs)
print(outputs.shape)

总结

这一节的内容对模型的保存和加载接口进行了介绍。通过save_checkpoint将模型保存为.ckpt格式,通过load_checkpointload_param_into_net.ckpt文件重新加载到网络中进行后续操作。此外,还可以通过export接口将模型保存为MindIR格式,同时保存Checkpoint和模型结果,然后通过load加载到nn.GraphCell中即可进行推理。

打卡

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1903615.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RUST 编程语言 绘制随机颜色图片 画圆形 画矩形 画直线

什么是Rust Rust是一种系统编程语言,旨在提供高性能和安全性。它是由Mozilla和其开发社区创建的开源语言,设计目标是在C的应用场景中提供一种现代、可靠和高效的选择。Rust的目标是成为一种通用编程语言,能够处理各种计算任务,包…

#数据结构 顺序表

线性表 顺序表 每种结构都有它存在意义 线性表的顺序存储实现指的是用一组连续的存储单元存储线性表的数据元素。 概念 顺序表是用一段物理地址连续的存储单元依次存储数据元素的线性表,一般情况下采用数组存储。在数组上完成数据的增查改删。 逻辑结构&#…

数值分析笔记(五)线性方程组解法

三角分解法 A的杜利特分解公式如下: u 1 j a 1 j ( j 1 , 2 , ⋯ , n ) , l i 1 a i 1 / u 11 ( i 2 , 3 , ⋯ , n ) , u k j a k j − ∑ m 1 k − 1 l b m u m j ⇒ a k j ( j k , k 1 , ⋯ , n ) , l i k ( a i k − ∑ m 1 k − 1 l i n u m k ) /…

阶段三:项目开发---搭建项目前后端系统基础架构:QA:可能遇到的问题及解决方案

任务实现 常见问题1:文件监视程序的系统限制。 1、错误提示:如果在Vue项目中,使用【 npm run serve】运行kongguan_web项目时报以下错误: 2、产生原因:文件监视程序的系统产生了限制,达到了默认的上限&am…

数据结构1:C++实现变长数组

数组作为线性表的一种,具有内存连续这一特点,可以通过下标访问元素,并且下标访问的时间复杂的是O(1),在数组的末尾插入和删除元素的时间复杂度同样是O(1),我们使用C实现一个简单的边长数组。 数据结构定义 class Arr…

【Docker系列】Docker 命令行输出格式化指南

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

【Python】搭建属于自己 AI 机器人

目录 前言 1 准备工作 1.1 环境搭建 1.2 获取 API KEY 2 写代码 2.1 引用库 2.2 创建用户 2.3 创建对话 2.4 输出内容 2.5 调试 2.6 全部代码 2.7 简短的总结 3 优化代码 3.1 规范代码 3.1.1 引用库 3.1.2 创建提示词 3.1.3 创建模型 3.1.4 规范输出&#xf…

cs231n作业2 双层神经网络

双层神经网络 我们选用ReLU函数和softmax函数: 步骤: 1、LOSS损失函数(前向传播)与梯度(后向传播)计算 Forward: 计算score,再根据score计算loss Backward:分别对W2、b2、W1、b1求…

品质至上!中国星坤连接器的发展之道!

在电子连接技术领域,中国星坤以其卓越的创新能力和对品质的不懈追求,赢得了业界的广泛认可。凭借在高精度连接器设计和制造上的领先地位,星坤不仅获得了多项实用新型专利,更通过一系列国际质量管理体系认证,彰显了其产…

知识社区在线提问小程序模板源码

蓝色的知识问答,问答交流,知识社区,在线提问手机app小程序网页模板。包含:社区主页、提问、我的、绑定手机,实名认证等。 知识社区在线提问小程序模板源码

P5. 微服务: Bot代码的执行

P5. 微服务: Bot代码的执行 0 概述1 Bot代码执行框架2 Bot代码传递给BotRunningSystem3 微服务: Bot代码执行的实现逻辑3.1 整体微服务逻辑概述3.2 生产者消费者模型实现3.3 consume() 执行代码函数的实现3.4 执行结果返回给 nextStep 4 扩展4.1 Bot代码的语言 0 概述 本章介绍…

Keysight 是德 DSA91304A 高性能示波器

Keysight 是德 DSA91304A 高性能示波器 DSA91304A Infiniium 高性能示波器:13 GHz 13 GHz4个模拟通道高达 1 Gpts 存储器和 40 GSa/s 采样率可以提供更完整的信号迹线捕获50 mV/格时低至 1.73 mVrms 的本底噪声和深入的抖动分析功能可以确保卓越的测量精度硬件加速…

C语言_数据的存储

数据类型介绍 1. 整形家族 //字符存储的时候,存储的是ASCII值,是整型 //char 默认是unsigned char还是signed char标准没有规定,其他类型都默认是signed char,unsigned char,signed char short,unsigned s…

windows机器免密登录linux主机

1. 正常连接需要输入密码 ssh root1.1.1.1 2. 在Windows上生成SSH密钥对(如果你还没有的话): ssh-keygen 3. scp将id_rsa.pub传输到对应的主机 4.对应机器上查看 5.从windows上免密登录

rsyslog日志转发

前言 Rsyslog可用于接受来自各种来源(本地和网络)的输入,转换它们,并将结果输出到不同(通过模板和filter过滤)的目的地(目录文件中) rsyslog是一个开源工具,被广泛用于Linux系统以通过TCP/UDP…

cs231n 作业3

使用普通RNN进行图像标注 单个RNN神经元行为 前向传播: 反向传播: def rnn_step_backward(dnext_h, cache):dx, dprev_h, dWx, dWh, db None, None, None, None, Nonex, Wx, Wh, prev_h, next_h cachedtanh 1 - next_h**2dx (dnext_h*dtanh).dot(…

第T4周:使用TensorFlow实现猴痘病识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 文章目录 一、前期工作1.设置GPU(如果使用的是CPU可以忽略这步)2. 导入数据3. 查看数据 二、数据预处理1、加载数据2、数据可视化3、再…

人脸识别课堂签到系统【PyQt5实现】

人脸识别签到系统 1、运用场景 课堂签到,上班打卡,进出门身份验证。 2、功能类别 人脸录入,打卡签到,声音提醒,打卡信息导出,打包成exe可执行文件 3、技术栈 python3.8,sqlite3,opencv,face_recognition,PyQt5,csv 4、流程图 1、导入库 2、编写UI界面 3、打…

Linux服务器使用总结-不定时更新

# 查看升级日志 cat /var/log/dpkg.log |grep nvidia|grep libnvidia-common

C++ 多态篇

文章目录 1. 多态的概念和实现1.1 概念1.2 实现1.2.1 协变1.2.2 析构函数1.2.3 子类虚函数不加virtual 2. C11 final和override3.1 final3.2 override 3. 函数重载、重写与隐藏4. 多态的原理5. 抽象类6.单继承和多继承的虚表6.1 单继承6.2 多继承 7. 菱形继承的虚表(了解)7.1 菱…