昇思25天学习打卡营第8天|模型权重与 MindIR 的保存加载

news2024/11/16 6:29:35

目录

导入Python 库和模块

创建神经网络模型

保存和加载模型权重

保存和加载MindIR


导入Python 库和模块


        上一章节着重阐述了怎样对超参数予以调整,以及如何开展网络模型的训练工作。在网络模型训练的整个进程当中,事实上我们满怀期望能够留存中间阶段以及最终的成果,以便用于细微的调整(fine-tune)以及后续的模型推理和部署操作。在本章节,我们将会为您介绍怎样去保存以及加载模型。

        首先,我们进行了一系列的 Python 库和模块的导入操作:我们导入了 NumPy 库,并将其简称为 np 。要知道,NumPy 通常被广泛应用于数值计算领域以及数组相关的操作之中。此外,我们还导入了 MindSpore 库,MindSpore 乃是一个极为出色的深度学习框架。不仅如此,我们从 MindSpore 库中导入了 nn 模块,这里面或许涵盖了与神经网络相关联的各类类和函数。最后,我们还从 MindSpore 库中导入了 Tensor 类,其主要作用在于创建张量这种数据结构。

        代码如下:

import numpy as np  
import mindspore  
from mindspore import nn  
from mindspore import Tensor  

创建神经网络模型


        定义了一个被称作“network”的函数,此函数旨在创建一个神经网络模型。在该函数的内部,通过运用“nn.SequentialCell”构建了一个按照顺序相互连接的神经网络。最终,这个函数会返回构建完成的模型。

        代码如下:

def network():  
    model = nn.SequentialCell(  
                #用于将输入数据展平为一维向量  
                nn.Flatten(),  
                #全连接层,输入维度为 28*28,输出维度为 512。  
                nn.Dense(28*28, 512),  
                #激活函数 ReLU 层。  
                nn.ReLU(),  
                #全连接层,输入维度为 512,输出维度为 512。  
                nn.Dense(512, 512),  
                #激活函数 ReLU 层。  
                nn.ReLU(),  
                #全连接层,输入维度为 512,输出维度为 10。  
                nn.Dense(512, 10))  
    return model  

保存和加载模型权重


        当对模型进行保存操作时,将采用 save_checkpoint 这一接口,并将网络和特定指定的保存路径传入其中。

        代码如下:

model = network()  
mindspore.save_checkpoint(model, "model.ckpt")  

        分析:在 MindSpore 框架中,“model = network()”这行代码一般而言是创建了一个被命名为“model”的对象。此对象是通过对名为“network”的函数或者类的调用而得以生成。而“mindspore.save_checkpoint(model, "model.ckpt")”这行代码,其发挥的作用是借助 MindSpore 框架所提供的“save_checkpoint”函数,把创建好的“model”对象的当前状态保存至一个叫做“model.ckpt”的文件之中。之所以要进行这样的操作,通常是出于如下目的:在后续的一系列操作里,能够重新加载这个模型的状态,从而便于继续开展训练工作、执行预测任务,或者实现模型的迁移以及部署等相关操作。

        要实现模型权重的加载,第一步是创建相同的模型实例,接下来则要通过 load_checkpoint 和 load_param_into_net 方法对参数予以加载。

        代码如下:

model = network()  
param_dict = mindspore.load_checkpoint("model.ckpt")  
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)  
print(param_not_load)  

         分析:首先,通过 model = network() 创建了一个名为 model 的网络对象。

        然后,使用 mindspore.load_checkpoint("model.ckpt") 从名为 "model.ckpt" 的文件中加载模型的检查点数据,并将其存储在 param_dict 中。

        接着,通过 mindspore.load_param_into_net(model, param_dict) 尝试将加载的参数数据 param_dict 加载到模型 model 中。同时,返回未成功加载的参数以及一个相关的标识,未成功加载的参数存储在 param_not_load 中。

        最后,使用 print(param_not_load) 输出未成功加载的参数。

        运行结果:

        []

保存和加载MindIR


        除了 Checkpoint 之外,MindSpore 为云侧(训练)和端侧(推理)提供了统一的中间表示(Intermediate Representation,IR)。用户能够通过 export 接口,直接将模型保存为 MindIR 格式。这种统一的中间表示和便捷的模型保存方式,为模型的训练和推理提供了高效且便捷的支持,极大地提升了开发和应用的效率。

        代码如下:

model = network()  
inputs = Tensor(np.ones([1, 1, 28, 28]).astype(np.float32))  
mindspore.export(model, inputs, file_name="model", file_format="MINDIR")  
mindspore.set_context(mode=mindspore.GRAPH_MODE)  
graph = mindspore.load("model.mindir")  
model = nn.GraphCell(graph)  
outputs = model(inputs)  
print(outputs.shape) 

        分析:首先定义了一个叫做 model 的网络模型,接着准备了一个输入数据 inputs ,这个数据的值全是 1 ,并且是张量的形式。然后通过 mindspore.export 把模型和这个输入保存成 MINDIR 格式的文件,文件名就叫 model 。接下来设置 MindSpore 的运行环境为图模式。再去加载之前保存的 model.mindir 文件,并把它转变为 GraphCell 类型的模型。之后使用之前准备好的输入数据 inputs 来对模型进行推理运算,从而得到输出 outputs 。最后把输出的形状给打印出来。

        运行结果:

        (1, 10)

        运行截图:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1892404.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

python基础语法 004-3流程控制- while

1 while while 主要用的场景没有 for 循环多。 while循环&#xff1a;主要运行场景 我不知道什么时候结束。。。不知道运行多少次 1.1 基本用法 # while 4 > 3: #一直执行 # print("hell0")while 4 < 3: #不会打印&#xff0c;什么都没有print("…

go开源webssh终端源码main.go分析

1.地址: https://github.com/Jrohy/webssh.git 2.添加中文注释地址: https://github.com/tonyimax/webssh_cn.git main.go分析 主包名&#xff1a;main package main //主包名 依赖包加载 //导入依赖包 import ("embed" //可执行文件…

【Qwen2部署实战】探索Qwen2-7B:通过FastApi框架实现API的部署与调用

系列篇章&#x1f4a5; No.文章1【Qwen部署实战】探索Qwen-7B-Chat&#xff1a;阿里云大型语言模型的对话实践2【Qwen2部署实战】Qwen2初体验&#xff1a;用Transformers打造智能聊天机器人3【Qwen2部署实战】探索Qwen2-7B&#xff1a;通过FastApi框架实现API的部署与调用4【Q…

springboot双学位招生管理系统-计算机毕业设计源码93054

摘 要 科技进步的飞速发展引起人们日常生活的巨大变化&#xff0c;电子信息技术的飞速发展使得电子信息技术的各个领域的应用水平得到普及和应用。信息时代的到来已成为不可阻挡的时尚潮流&#xff0c;人类发展的历史正进入一个新时代。在现实运用中&#xff0c;应用软件的工作…

松下Panasonic机器人维修故障原因

松下机器人伺服电机是许多工业自动化设备的关键组成部分。了解如何进行Panasonic工业机械臂电机维修&#xff0c;对于确保设备正常运行至关重要。 【松下焊接机器人维修案例】【松下机器人维修故障排查】 一、常见松下工业机械手伺服电机故障及原因 1. 过热&#xff1a;过热可…

Webpack: 并行构建

概述 受限于 Node.js 的单线程架构&#xff0c;原生 Webpack 对所有资源文件做的所有解析、转译、合并操作本质上都是在同一个线程内串行执行&#xff0c;CPU 利用率极低&#xff0c;因此&#xff0c;理所当然地&#xff0c;社区出现了一些以多进程方式运行 Webpack&#xff0…

铜排载流量计算

母线载流量的理论计算 有些设计规范给出了根据电流密度确定母线大小的标准&#xff0c;一般铜母线的要求是每平方毫米载流量1.55A&#xff0c;但只可以作为设计“自由空气中的单导体母线”的参考&#xff0c;不可以作为实际设备中选择母线截面积的方法。也有些设计手册里给出了…

无线领夹麦克风选什么价位,揭秘无线领夹麦克风哪个品牌音质最好

在自媒体的快速发展之下&#xff0c;越来越多人加入到短视频拍摄行业。当我们踏出户外&#xff0c;想要用声音记录美好生活时&#xff0c;一个优质的麦克风便成了不可或缺的装备。户外环境的喧嚣与手机麦克风的局限性常常让我们的声音淹没在背景噪音之中&#xff0c;使得同期录…

小白也能看懂的Python基础教程(8)

Python面向对象 目录 Python面向对象 一、面向对象的概念 1、常见的编程思想 2、面向过程是什么&#xff1f; 3、什么是面向对象&#xff1f; 4、封装 5、继承 6、多态 二、面向对象的概念 1、两个重要概念 2、类 3、对象 4、self关键字 三、对象属性 1、什么…

昇思25天学习打卡营第8天|MindSpore保存与加载(保存和加载MindIR)

在MindIR中&#xff0c;一个函数图&#xff08;FuncGraph&#xff09;表示一个普通函数的定义&#xff0c;函数图一般由ParameterNode、ValueNode和CNode组成有向无环图&#xff0c;可以清晰地表达出从参数到返回值的计算过程。在上图中可以看出&#xff0c;python代码中两个函…

Unity Scrollview的Scrollbar控制方法

备忘&#xff1a;碰到用scrollview自带的scrollbar去控制滑动&#xff0c;结果发现用代码控制scrollbar.value无效&#xff0c;搜了一下都是说用scrollRect.verticalNormalizedPosition和scrollRect.horizontalNormalizedPosition来控制的。我寻思着有关联的scrollbar为什么用不…

【TB作品】智能台灯控制器,ATMEGA128单片机,Proteus仿真

题目 8 &#xff1a;智能台灯控制器 基于单片机设计智能台灯控制器&#xff0c;要求可以调节 LED 灯的亮度&#xff0c;实现定时开启与关闭&#xff0c; 根据光照自动开启与关闭功能。 具体要求如下&#xff1a; &#xff08;1&#xff09;通过 PWM 功能调节 LED 灯亮度&#x…

Jenkins 使用 Publish over SSH进行远程访问

Publish over SSH 是 Jenkins 的一个插件,可以让你通过 SSH 将构建产物分发到远程服务器。以下是如何开启 Publish over SSH 的步骤: 一、安装 Publish over SSH 插件 在 Jenkins 中,进入 "Manage Jenkins" > "Manage Plugins"。选择 "Availab…

储能锂电池出货量持续增长 国家政策推动行业发展速度加快

储能锂电池出货量持续增长 国家政策推动行业发展速度加快 储能锂电池又称锂离子储能电池&#xff0c;指专为储存电能而设计的锂离子电池。储能锂电池具有转换效率高、能量密度高、维护成本低、环境适应性强、响应速度快等优势&#xff0c;在数据中心、通信基站以及电力系统等领…

香橙派AIpro如何赋能AI+边缘流媒体设备

文章目录 &#xff08;一&#xff09;前言&#xff08;二&#xff09;AI边缘流媒体设备展示&#xff08;三&#xff09;赋能AI边缘流媒体设备1、准备开发环境2、在板子中下载编译安装SRS3、基本推拉流测试4、多路推流性能测试 &#xff08;四&#xff09;一些注意事项1、开发板…

webSocket网页通信---使用js模拟多页面实时通信

webSocket是什么 WebSocket是一种先进的网络技术&#xff0c;它提供了一种在单个TCP连接上进行全双工通信的能力。传统的基于HTTP的通信是单向的&#xff0c;即客户端发起请求&#xff0c;服务器响应请求&#xff0c;然后连接关闭。但是&#xff0c;WebSocket允许服务器和客户端…

Nginx系列(二)---Mac上的快速使用

一、安装 前置软件&#xff1a;Homebrew 安装方法&#xff1a;终端输入/bin/bash -c "$(curl -fsSL <https://cdn.jsdelivr.net/gh/ineo6/homebrew-install/install.sh>)"更新&#xff1a; brew update 设置中科大镜像源&#xff1a;git -C "$(brew --r…

【串口通信】之TTL电平

1. 什么是串口 串口,全称为串行通信端口,是一种计算机硬件接口,用于实现数据的串行传输。与并行通信不同,串口通信一次只传输一个比特,数据通过串行线按顺序传输。串口通信在嵌入式系统、工业控制、计算机与外围设备通信等领域非常常见 2. 什么是串口通信 串口通信是指通过…

和闺蜜的泰国之旅

每当我回想起那次和闺蜜丽丽&#xff08;全名罗莉&#xff09;的泰国之旅&#xff0c;心中总是涌起复杂的情绪。那段经历仿佛一场噩梦&#xff0c;至今仍无法从脑海中挥去。 我们满怀期待地抵达曼谷&#xff0c;热带的阳光、繁忙的街道、美味的街头小吃&#xff0c;都让我们兴…

Redis 管道(Pipeline)是什么?有什么用?

目录 1. redis 客户端-服务端模型的不足之处 2. redis 管道是什么&#xff1f;有什么好处&#xff1f; 3. 管道的使用场景 4. 管道使用的注意事项 1. redis 客户端-服务端模型的不足之处 众所周知&#xff0c;redis 是一个客户端-服务端的模型设计&#xff0c;客户端向服务…