STM32 X-CUBE-AI:Pytorch模型部署全流程

news2024/11/16 12:21:49

文章目录

    • 概要
    • 版本:
    • 参考资料
    • STM32CUBEAI安装
    • CUBEAI模型支持
    • LSTM模型转换注意事项
    • 模型转换
    • 模型应用
      • 1 错误类型及代码
      • 2 模型创建和初始化
      • 3 获取输入输出数据变量
      • 4 获取模型前馈输出
      • 模型应用小结
    • 小结

概要

STM32 CUBE MX扩展包:X-CUBE-AI部署流程:模型转换、CUBEAI模型验证、CUBEAI模型应用
深度学习架构使用Pytorch模型,模型包括多个LSTM和全连接层(包含Dropout和激活函数层)。

版本:

STM32CUBEMX:6.8.1
X-CUBE-AI:8.1.0 (推荐该版本,对LSTM支持得到更新)
ONNX:1.14.0

参考资料

遇到ERROR和BUG可到ST社区提问:ST社区
CUBEAI入门指南下载地址:X-CUBE-AI入门指南手册
官方应用示例:部署示例

STM32CUBEAI安装

CUBEAI扩展包的安装目前已有许多教程,这里不再赘述。CUBEAI安装
需要注意的是,在STM32CUBEMX上安装CUBEAI时,可能并不能安装到最新版本的CUBEAI,因此可以前往ST官网下载最新版本(最新版本会对模型实现进行更新)。https://www.st.com/zh/embedded-software/x-cube-ai.html
在这里插入图片描述

CUBEAI模型支持

目前,CUBEAI支持三种类型的模型:

  1. Keras:.h5
  2. TensorFlow:.tflite
  3. 一切可以转换成ONNX格式的模型:.onnx

Pytorch部署CUBEAI需要将Pytorch生成模型.pth转换成.onnx。

LSTM模型转换注意事项

  1. 由于CUBEAI扩展包和ONNX对LSTM的转换限制,在Pytorch模型搭建时,需要设置LSTM的batch_first=False(并不影响模型训练和应用),设置后,需要注意输入输出数据的格式。
  2. LSTM模型内部全连接层的输入前,将数据切片,取最后时间步,可避免一些问题。
  3. 对于多LSTM,可以采取forward函数多输入x的形式,每一个x是一个LSTM的输入。

模型转换

  1. Pytorch->Onnx
    使用torch.onnx.export()函数,其中可动态设置变量,函数使用已有很多教程,暂不赘述。Pytorch转ONNX及验证
    由于部署后属于模型应用阶段,输入数据batch=1,seq_lengthinput_num可自行设置,也可设置动态参数(设置seq_length为动态参数后,CUBEMX验证的示例数据中的seq_length=1)。

  2. Onnx->STM32
    大致流程可参考: STM32 模型验证
    STM32CUBEMX中选中相应模型即可,可修改模型名称方便后续网络部署(不要使用默认network名字)。
    在这里插入图片描述
    验证阶段可能会发生多种问题,问题错误种类参见:
    在这里插入图片描述

模型应用

到目前为止,Pytorch模型已经成功转换成ONNX并在CUBEMX进行了验证且得到通过,下面则是STM32的模型应用部分。

通过keil打开项目后,在下面的目录下存放模型相关文件,主要用到的函数为:modelName.c、modelName.h,其中modelName为在CUBEMX中定义的模型名称。
在这里插入图片描述
主要使用的函数如下:

  1. ai_modelName_create_and_init:用于模型创建和初始化
  2. ai_modelName_inputs_get:用于获取模型输入数据
  3. ai_modelName_outputs_get:用于获取模型输出数据
  4. ai_pytorch_ftc_lstm_run:用于前馈运行模型得到输出
  5. ai_mnetwork_get_error:用于获取模型错误代码,调试用

各函数相关参数及用法如下。
相似代码可参见:X-CUBE-AI入门指南手册

1 错误类型及代码

/*!
 * @enum ai_error_type
 * @ingroup ai_platform
 *
 * Generic enum to list network error types.
 */
typedef enum {
  AI_ERROR_NONE                         = 0x00,     /*!< No error */
  AI_ERROR_TOOL_PLATFORM_API_MISMATCH   = 0x01,
  AI_ERROR_TYPES_MISMATCH               = 0x02,
  AI_ERROR_INVALID_HANDLE               = 0x10,
  AI_ERROR_INVALID_STATE                = 0x11,
  AI_ERROR_INVALID_INPUT                = 0x12,
  AI_ERROR_INVALID_OUTPUT               = 0x13,
  AI_ERROR_INVALID_PARAM                = 0x14,
  AI_ERROR_INVALID_SIGNATURE            = 0x15,
  AI_ERROR_INVALID_SIZE                 = 0x16,
  AI_ERROR_INVALID_VALUE                = 0x17,
  AI_ERROR_INIT_FAILED                  = 0x30,
  AI_ERROR_ALLOCATION_FAILED            = 0x31,
  AI_ERROR_DEALLOCATION_FAILED          = 0x32,
  AI_ERROR_CREATE_FAILED                = 0x33,
} ai_error_type;

/*!
 * @enum ai_error_code
 * @ingroup ai_platform
 *
 * Generic enum to list network error codes.
 */
typedef enum {
  AI_ERROR_CODE_NONE                = 0x0000,    /*!< No error */
  AI_ERROR_CODE_NETWORK             = 0x0010,
  AI_ERROR_CODE_NETWORK_PARAMS      = 0x0011,
  AI_ERROR_CODE_NETWORK_WEIGHTS     = 0x0012,
  AI_ERROR_CODE_NETWORK_ACTIVATIONS = 0x0013,
  AI_ERROR_CODE_LAYER               = 0x0014,
  AI_ERROR_CODE_TENSOR              = 0x0015,
  AI_ERROR_CODE_ARRAY               = 0x0016,
  AI_ERROR_CODE_INVALID_PTR         = 0x0017,
  AI_ERROR_CODE_INVALID_SIZE        = 0x0018,
  AI_ERROR_CODE_INVALID_FORMAT      = 0x0019,
  AI_ERROR_CODE_OUT_OF_RANGE        = 0x0020,
  AI_ERROR_CODE_INVALID_BATCH       = 0x0021,
  AI_ERROR_CODE_MISSED_INIT         = 0x0030,
  AI_ERROR_CODE_IN_USE              = 0x0040,
  AI_ERROR_CODE_LOCK                = 0x0041,
} ai_error_code;

2 模型创建和初始化

/*!
 * @brief Create and initialize a neural network (helper function)
 * @ingroup pytorch_ftc_lstm
 * @details Helper function to instantiate and to initialize a network. It returns an object to handle it;
 * @param network an opaque handle to the network context
 * @param activations array of addresses of the activations buffers
 * @param weights array of addresses of the weights buffers
 * @return an error code reporting the status of the API on exit
 */
AI_API_ENTRY
ai_error ai_modelName_create_and_init(
  ai_handle* network, const ai_handle activations[], const ai_handle weights[]);

重点关注输入参数networkactivations:数据类型均为ai_handle(即void*)。初始化方式如下:

	ai_error err;
	ai_handle network = AI_HANDLE_NULL;
	const ai_handle act_addr[] = { activations };
		
	// 实例化神经网络
	err = ai_modelName_create_and_init(&network, act_addr, NULL);
	if (err.type != AI_ERROR_NONE)
	{
		printf("E: AI error - type=%d code=%d\r\n", err.type, err.code);
	}

3 获取输入输出数据变量

/*!
 * @brief Get network inputs array pointer as a ai_buffer array pointer.
 * @ingroup pytorch_ftc_lstm
 * @param network an opaque handle to the network context
 * @param n_buffer optional parameter to return the number of outputs
 * @return a ai_buffer pointer to the inputs arrays
 */
AI_API_ENTRY
ai_buffer* ai_modelName_inputs_get(
  ai_handle network, ai_u16 *n_buffer);

/*!
 * @brief Get network outputs array pointer as a ai_buffer array pointer.
 * @ingroup pytorch_ftc_lstm
 * @param network an opaque handle to the network context
 * @param n_buffer optional parameter to return the number of outputs
 * @return a ai_buffer pointer to the outputs arrays
 */
AI_API_ENTRY
ai_buffer* ai_modelName_outputs_get(
  ai_handle network, ai_u16 *n_buffer);

需要先创建输入输出数据:

// 输入输出结构体
ai_buffer* ai_input;
ai_buffer* ai_output;

// 结构体内容如下
/*!
 * @struct ai_buffer
 * @ingroup ai_platform
 * @brief Memory buffer storing data (optional) with a shape, size and type.
 * This datastruct is used also for network querying, where the data field may
 * may be NULL.
 */
typedef struct ai_buffer_ {
  ai_buffer_format        format;     /*!< buffer format */
  ai_handle               data;       /*!< pointer to buffer data */
  ai_buffer_meta_info*    meta_info;  /*!< pointer to buffer metadata info */
  /* New 7.1 fields */
  ai_flags                flags;      /*!< shape optional flags */
  ai_size                 size;       /*!< number of elements of the buffer (including optional padding) */
  ai_buffer_shape         shape;      /*!< n-dimensional shape info */
} ai_buffer;

之后调用函数进行结构体赋值:

	ai_input = ai_modelName_inputs_get(network, NULL);
	ai_output = ai_modelName_outputs_get(network, NULL);

接下来需要对结构体中的data进行赋值,ai_input和ai_output均为输入输出地址,对于多输入形式的模型,可以数组索引多个输入:

// 单输入
ai_float *pIn;
ai_output[0].data = AI_HANDLE_PTR(pIn);

// 多输入
ai_float *pIn[]
for(int i=0; i<AI_MODELNAME_IN_NUM; i++)
	{
		ai_input[i].data = AI_HANDLE_PTR(pIn[i]);
	}
// 输出
ai_float *pOut;
ai_output[0].data = AI_HANDLE_PTR(pOut);

pIn为指针数组,数组内存储多个输入数据指针;AI_MODELNAME_IN_NUM为宏定义,表示输入数据数量。
AI_HANDLE_PTRai_handle类型宏定义,传入ai_float *指针,将数据转换成ai_handle类型。

#define AI_HANDLE_PTR(ptr_)           ((ai_handle)(ptr_))

4 获取模型前馈输出

/*!
 * @brief Run the network and return the output
 * @ingroup pytorch_ftc_lstm
 *
 * @details Runs the network on the inputs and returns the corresponding output.
 * The size of the input and output buffers is stored in this
 * header generated by the code generation tool. See AI_PYTORCH_FTC_LSTM_*
 * defines into file @ref pytorch_ftc_lstm.h for all network sizes defines
 *
 * @param network an opaque handle to the network context
 * @param[in] input buffer with the input data
 * @param[out] output buffer with the output data
 * @return the number of input batches processed (default 1) or <= 0 if it fails
 * in case of error the error type could be queried by 
 * using @ref ai_pytorch_ftc_lstm_get_error
 */
AI_API_ENTRY
ai_i32 ai_modelName_run(
  ai_handle network, const ai_buffer* input, ai_buffer* output);

函数传入网络句柄,输入输出buffer指针,返回处理的批次数量(应用阶段应该为1),可通过判断返回值是否为1,说明模型运行是否成功。

	printf("---------Running Network-------- \r\n");
	batch = ai_modelName_run(network, ai_input, ai_output);
	printf("---------Running End-------- \r\n");
	if (batch != BATCH) {
		err = ai_mnetwork_get_error(network);
		printf("E: AI error - type=%d code=%d\r\n", err.type, err.code);
		Error_Handler();
	}

运行后,可通过查看pOut数组数据得到模型输出。

void printData_(ai_float *pOut, ai_i8 num)
{
	printf("(Total Num: %d): ", num);
	for (int i=0; i < num; i++)
	{
		if (i == num-1)
		{
			printf("%.4f. \r\n", pOut[i]);
		}
		else
		{
			printf("%.4f, ", pOut[i]);
		}
	}
}

模型应用小结

可以根据官方部署示例中的方法对AI_InitAI_Run进行封装。

小结

遇到的BUG持续更新。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1197012.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ROS 多级tf坐标转换

题目 现有一移动机器人&#xff0c;该机器人的基坐标系为“base_link”&#xff0c;机器人包含3个子坐标系分别为“joint1”&#xff0c;“joint2”&#xff0c;“joint3”。 要求&#xff1a;利用多坐标转换&#xff0c;实现joint1下的坐标向joint2下的坐标转换&#xff0c;…

AMD64内存属性详解

本文参考文档为AMD64 Architecture Programmer’s Manual Volume 2: System Programming&#xff0c;版本号3.41&#xff0c;这不是对原英文文档的翻译&#xff0c;但是所有内容的排版都是根据原手册的排版来的&#xff0c;如有与官方文档冲突的内容&#xff0c;以官方文档为准…

[LeetCode]-622. 设计循环队列

目录 662. 设计循环队列 题目 思路 代码 662. 设计循环队列 622. 设计循环队列 - 力扣&#xff08;LeetCode&#xff09;https://leetcode.cn/problems/design-circular-queue/ 题目 设计你的循环队列实现。 循环队列是一种线性数据结构&#xff0c;其操作表现基于 FIFO&…

推荐几个宝藏app

立冬后&#xff0c;真尼玛冷&#xff0c;哎&#xff01;记得多穿点衣服呀&#xff0c;老铁们&#xff01;&#xff01; GKD 去广告神器 下载网址&#xff1a;https://github.com/gkd-kit/gkd 特性&#xff1a; 它不仅支持跳过开屏广告&#xff0c;还支持跳过弹窗广告等&#xf…

Shopee活动取消规则是什么?shopee官方促销活动怎么取消?

作为一家知名的电商平台&#xff0c;shopee官方对于消费者取消促销活动的请求给予了相应的规定和处理流程。 shopee活动取消规则是什么&#xff1f; 首先&#xff0c;消费者应该明确了解虾皮的促销活动取消规则。根据虾皮的官方规定&#xff0c;消费者在参与促销活动之前&…

公司注册股东选择几个人合适?

创业初期很多创业者都会选择有注册有限责任公司&#xff0c;有限责任由五十个以下的股东出资设立&#xff0c;每个股东以其所认缴的出资额为限对公司承担有限责任。那么问题来了股东人数选择几个最合适呢&#xff0c;下面上海注册公司网&#xff08;www.91kaiye.cn&#xff09;…

部署ruoyi-vue-plus和ruoyi-app

nginx.conf worker_processes 1;error_log /var/log/nginx/error.log warn; pid /var/run/nginx.pid;events {worker_connections 1024; }http {include mime.types;default_type application/octet-stream;sendfile on;keepalive_timeout 65;# 限制…

人工智能基础——Python:运行效率与时间复杂度

人工智能的学习之路非常漫长&#xff0c;不少人因为学习路线不对或者学习内容不够专业而举步难行。不过别担心&#xff0c;我为大家整理了一份600多G的学习资源&#xff0c;基本上涵盖了人工智能学习的所有内容。点击下方链接,0元进群领取学习资源,让你的学习之路更加顺畅!记得…

15 # 手写 throttle 节流方法

什么是节流 节流是限制事件触发的频率&#xff0c;当持续触发事件时&#xff0c;在一定时间内只执行一次事件&#xff0c;这个效果跟英雄联盟里的闪现技能释放差不多。 函数防抖关注一定时间连续触发的事件只在最后执行一次&#xff0c;而函数节流侧重于一段时间内只执行一次…

快速排序实现方法(剑指offer思路)

快速排序思想 从参与排序的数组中&#xff0c;选择一个数&#xff0c;把小于这个数的放在左边&#xff0c;大于这个数的放在右边&#xff0c;然后递归操作。 实现算法思路 选择最后一个当作参考值&#xff0c;使用small索引当作比这个数小的下标值遍历数组&#xff0c;如果小…

MySQL查询时间处理相关函数与方法实践笔记

1. 实践案例 在查询mysql数据库获取数据时&#xff0c;有这样一个需求&#xff1a;按每30分钟分组获取电量数据&#xff0c;形成1天48个数据点。 方法一&#xff1a; select hour(a.CreateTime) 时点,case when MINUTE(a.CreateTime)<30 then 1 else 2 end 半小时,sum(a…

思维模型 斯金纳箱原理

本系列文章 主要是 分享 思维模型&#xff0c;涉及各个领域&#xff0c;重在提升认知。通过合理奖惩&#xff0c;塑造行为&#xff0c;此名为“学习”。 1 斯金纳箱原理的应用 1.1 斯金纳箱在游戏设计中的应用-《糖果传奇》 《糖果传奇》是一款由 King 开发的三消游戏&#x…

基于SSM的培训机构运营系统

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;Vue 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 目录…

find和grep命令的简单使用

find和grep命令的简单使用 一、find例子--不同条件查找 二、grep正则表达式的简单说明例子--简单文本查找例子--结合管道进行查找 一、find find 命令在指定的目录下查找对应的文件。 find [path] [expression]● path 是要查找的目录路径&#xff0c;可以是一个目录或文件名…

链表OJ题(2)

目录 1.移除链表元素❓√ 2.反转链表 3.相交链表 4.链表的中间节点 5.链表中倒数第k个节点❓ 6.合并链表❓√ 7.分割链表❓ 今天链表面试OJ题目 移除链表元素反转链表相交链表链表的中间节点链表中倒数第k个节点合并链表分割链表 &#x1f642;起始条件 中间节点 结束条…

YOLOv8-Seg改进:分割注意力系列篇 | 高效多尺度注意力 EMA | ICASSP2023

🚀🚀🚀本文改进:EMA跨空间学习高效多尺度注意力引入到YOLOv8中进行二次创新,改进方法1)head层输出层结合;2)加入backbone; 🚀🚀🚀EMAAttention 亲测在多个数据集能够实现涨点,同样适用于小目标分割 🚀🚀🚀YOLOv8-seg创新专栏:http://t.csdnimg.cn/…

Postgres的级数生成函数generate_series应用

Postgres的级数生成函数generate_series应用 引用&#xff1a;http://postgres.cn/docs/12/functions-srf.html 函数文档 函数 参数类型 返回类型 描述 generate_series(start, stop) int、bigint或者numeric setof int、setof bigint或者setof numeric&#xff08;与参数类型相…

【推荐】一款AI写作大师、问答、绘画工具-「智元兔 AI」

在当今技术飞速发展的时代&#xff0c;越来越多的领域开始应用人工智能&#xff08;Artificial Intelligence&#xff0c;简称AI&#xff09;。其中&#xff0c;AI写作工具备受瞩目&#xff0c;备受推崇。在众多的选择中&#xff0c;智元兔AI是一款在笔者使用过程中非常有帮助的…

RetroMAE论文阅读

1. Introduction 在NLP常用的预训练模型通常是由token级别的任务进行训练的&#xff0c;如MLM和Seq2Seq&#xff0c;但是密集检索任务更倾向于句子级别的表示&#xff0c;需要捕捉句子的信息和之间的关系&#xff0c;一般主流的策略是自对比学习&#xff08;self-contrastive …

【图像分类】【深度学习】【Pytorch版本】 GoogLeNet(InceptionV2)模型算法详解

【图像分类】【深度学习】【Pytorch版本】 GoogLeNet(InceptionV2)模型算法详解 文章目录 【图像分类】【深度学习】【Pytorch版本】 GoogLeNet(InceptionV2)模型算法详解前言GoogLeNet(InceptionV2)讲解Batch Normalization公式InceptionV2结构InceptionV2特殊结构GoogLeNet(I…