Mxnet转Onnx 踩坑记录

news2024/9/29 19:22:29

0. 前言

使用将MXNET模型转换为ONNX的过程中有很多算子不兼容,在此对那些不兼容的算子替换。在此之前需要安装mxnet分支v1.x版本作为mx2onnx的工具,git地址如下:

mxnet/python/mxnet/onnx at v1.x · apache/mxnet · GitHub

同时还参考了如下的两个知乎链接:

https://zhuanlan.zhihu.com/p/166267806

https://zhuanlan.zhihu.com/p/165294876

1. UpSample

1.1 采用Resize实现

mxnet/contrib/onnx/mx2onnx/_op_translations.py
def create_helper_tensor_node(input_vals, output_name, kwargs):

    """create extra tensor node from numpy values"""

    data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[input_vals.dtype]



    tensor_node = onnx.helper.make_tensor_value_info(

        name=output_name,

        elem_type=data_type,

        shape=input_vals.shape

    )

    kwargs["initializer"].append(

        onnx.helper.make_tensor(

            name=output_name,

            data_type=data_type,

            dims=input_vals.shape,

            vals=input_vals.flatten().tolist(),

            raw=False,

        )

    )



    return tensor_node



@mx_op.register("UpSampling")

def convert_upsample(node, **kwargs):

    """Map MXNet's UpSampling operator attributes to onnx's Upsample operator

    and return the created node.

    """

   

    name, input_nodes, attrs = get_inputs(node, kwargs)

   

    sample_type = attrs.get('sample_type', 'nearest')

    sample_type = 'linear' if sample_type == 'bilinear' else sample_type

    scale = convert_string_to_list(attrs.get('scale'))

    scaleh = scalew = float(scale[0])

    if len(scale) > 1:

        scaleh = float(scale[0])

        scalew = float(scale[1])

    scale = np.array([1.0, 1.0, scaleh, scalew], dtype=np.float32)

    roi = np.array([], dtype=np.float32)

    node_roi=create_helper_tensor_node(roi, name+'roi', kwargs)

    node_sca=create_helper_tensor_node(scale, name+'scale', kwargs)

   

    node = onnx.helper.make_node(

        'Resize',

        inputs=[input_nodes[0], name+'roi', name+'scale'],

        outputs=[name],

        coordinate_transformation_mode='asymmetric',

        mode=sample_type,

        nearest_mode='floor',

        name=name

    )

    return [node_roi, node_sca, node]

1.2 采用ConvTranspose实现

@mx_op.register("UpSampling")

def convert_upsample(node, **kwargs):

    """Map MXNet's UpSampling operator attributes to onnx's Upsample operator

    and return the created node.

    """

    import math

    name, inputs, attrs = get_inputs(node, kwargs)

   

    #==NearestNeighbor ==

    channels=64  #此处需要手动修改!!!

    scale=int(attrs.get('scale'))

    pad=math.floor((scale - 1)/2.0)

    weight = np.ones((channels,1,scale,scale), dtype=np.float32)

    weight_node=create_helper_tensor_node(weight, name+'_weight', kwargs)

    pad_dims = [pad, pad]

    pad_dims = pad_dims + pad_dims

    #print(pad_dims)

    deconv_node = onnx.helper.make_node(

        "ConvTranspose",

        inputs=[inputs[0], name+'_weight'],

        outputs=[name],

        auto_pad="VALID",

        strides=[scale, scale],

        kernel_shape=[scale,scale],

        pads=pad_dims,

        group=channels,

        name=name)

   

return [deconv_node]

2. 使用Slice代替Crop.

mxnet/contrib/onnx/mx2onnx/_op_translations.py

def create_helper_shape_node(input_node, node_name):

    """create extra transpose node for dot operator"""

    trans_node = onnx.helper.make_node(

        'Shape',

        inputs=[input_node],

        outputs=[node_name],

        name=node_name

    )

   

    return trans_node



@mx_op.register("Crop")

def convert_crop(node, **kwargs):

    """Map MXNet's crop operator attributes to onnx's Crop operator

    and return the created node.

    """

   

    name, inputs, attrs = get_inputs(node, kwargs)

   

    start=np.array([0, 0, 0, 0], dtype=np.int) #index是int类型

    start_node=create_helper_tensor_node(start, name+'__starts', kwargs)

    shape_node = create_helper_shape_node(inputs[1], inputs[1]+'__shape')

   

    crop_node = onnx.helper.make_node(

        "Slice",

        inputs=[inputs[0], name+'__starts', inputs[1]+'__shape'], #data、start、end

        outputs=[name],

        name=name

    )

   

    logging.warning(

        "Using an experimental ONNX operator: Crop. " \

        "Its definition can change.")



    return [start_node, shape_node, crop_node]

3. ONNX softmax维度转换问题

onnx的实现的softmax在处理多维输入(NCHW)存在问题。

@mx_op.register("softmax")

def convert_softmax(node, **kwargs):

    """Map MXNet's softmax operator attributes to onnx's Softmax operator

    and return the created node.

    """

    name, input_nodes, attrs = get_inputs(node, kwargs)

    axis = int(attrs.get("axis", -1))

   

    c_softmax_node = []

    axis=-1

   

    transpose_node1 = onnx.helper.make_node(

        "Transpose",

        inputs=input_nodes,

        perm=(0,2,3,1), #NCHW--NHWC--(NHW,C)

        name=name+'_tr1',

        outputs=[name+'_tr1']

    )

   

    softmax_node = onnx.helper.make_node(

        "Softmax",

        inputs=[name+'_tr1'],

        axis=axis,

        name=name+'',

        outputs=[name+'']

    )

   

    transpose_node2 = onnx.helper.make_node(

        "Transpose",

        inputs=[name+''],

        perm=(0,3,1,2), #NHWC--NCHW

        name=name+'_tr2',

        outputs=[name+'_tr2']

    )

   

    c_softmax_node.append(transpose_node1)

    c_softmax_node.append(softmax_node)

    c_softmax_node.append(transpose_node2)

   

    return c_softmax_node

https://pic4.zhimg.com/80/v2-749aceeb26827b81603c3862f010cafb_720w.jpg

4. MaxPool 一致性对应不上问题,ceil设置

5. AvgPool  count_include_pad问题

AvgPool一致性对不上的时候,查看参数设置是否正确

6. AdaptiveAvgPooling2D 不支持

将AdaptiveAvgPooling2D 固定尺寸,转换为AvgPooling2D

7. onnx check错误,说明有些算子check不通过(mxnet自带的bug),安装mx2onnx v1.x重新再转一次

8. SoftmaxActivation

在mxnet中,SoftmaxActivation表明This operator has been deprecated.

https://pic1.zhimg.com/80/v2-4139577e2b687590c488ed114023eb20_720w.jpg

解决办法:手动修改SoftmaxActivation的op为softmax,axis=1对应channel。

9. FullConnect 全连接层转换不兼容

       

使用最新版本的mxnet 转,不要用mx2onnx v1.x里的注册函数

10. L2Normalization不支持instance模式

解决方法:无解,只有将L2Normal放到外面做

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1913655.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

李良济“小儿推拿妈妈班”圆满结课,以中医智慧守护儿童健康成长!

孩子生场病,妈妈半条命!作为妈妈最害怕的就是孩子生病,自己又无能为力! 为了帮助妈妈们,正确应对孩子健康问题,日常生活中科学帮助孩子提升体质少生病! 参加此次课程的,不仅有妈妈&a…

8.7结构体const使用场景

代码 #include <iostream> using namespace std; #include <string>//const使用场景//定义学生结构体 struct student {string name;int age;int score; };//将函数中的形参改为指针&#xff0c;可以减少内存空间&#xff0c;而且不会复制出新的副本 void printSt…

Spring Cloud LoadBalancer 入门与实战

一、什么是 LoadBalancer? LoadBalancer(负载均衡器) 是一种网络设备或软件机制&#xff0c;用于分发传入的网络流量负载&#xff08;请求&#xff09;到多个后端目标服务器上&#xff0c;从而实现系统资源的均衡利用和提高系统的可用性和新能。 1.1 负载均衡分类 负载均衡…

微信小程序中wx.navigateBack()页面栈返回上一页时执行上一页的方法或修改上一页的data属性值

let pages getCurrentPages();let prevPage pages[pages.length - 2]; // 获取上一个页面实例对象console.log(prevPage) //打印信息// 在 wx.navigateBack 的 success 回调中执行需要的方法wx.navigateBack({delta: 1, // 返回上一页success: function() {//修改上一页的属性…

Oracle基础以及一些‘方言’(二)

1、Oracle的查询语法结构 Oracle 的单表查询的语法结构&#xff1a; SELECT 1 FROM 2 WHERE 3 GROUP BY 4 HAVING 5 ORDER BY 6 其每个关键词的功能与MySQL中的功能已知&#xff0c;不过分页查询的关键词 limit 并不在Oracle的语法结构中。伪列&#xff1a; 在 Oracle 的表的使…

三品PLM管理系统软件:制造企业工程变更管理的革新者

在当今快速变化的市场环境中&#xff0c;制造企业面临着前所未有的挑战。客户需求的不断变化、供应链的波动、设计过程中的不确定性以及产品生命周期的缩短&#xff0c;都要求企业能够迅速响应并适应这些变化。工程变更管理作为企业响应市场变化、提升产品竞争力的关键环节&…

Loadlibrary failed with error 87:参数错误

问题描述&#xff1a; win10 系统在安装 Photoshop 2022 版后&#xff0c;点击桌面图标提示&#xff1a;Loadlibrary failed with error 87&#xff1a;参数错误&#xff0c;反复出现&#xff0c;反复确定&#xff0c;直至软件关闭。 解决方法&#xff1a; 1. 找到 C:\Window…

Kafka安装使用指南

Kafka是一种高吞吐量的分布式发布订阅消息系统。 Kafka启动方式有Zookeeper和Kraft&#xff0c;两种方式只能选择其中一种启动&#xff0c;不能同时使用。 【Kafka安装】 Kafka下载 https://downloads.apache.org/kafka/3.7.1/kafka_2.13-3.7.1.tgz Kafka解压 tar -xzf kafka_…

服务器数据恢复—raid5阵列热备盘没有激活导致阵列崩溃的数据恢复案例

服务器存储数据恢复环境&#xff1a; 一台EMC存储中有一组raid5磁盘阵列&#xff0c;划分1个lun供小型机使用&#xff0c;上层采用ZFS文件系统。 服务器存储故障&#xff1a; 一台有一组raid5磁盘阵列的存储在运行过程中突然崩溃。管理员检查发现存储中的raid5阵列有两块硬盘离…

大模型时代:人工智能与大数据平台的深度融合

在当今的大数据时代&#xff0c;数据已经成为驱动业务增长和创新的关键因素。与此同时&#xff0c;随着人工智能技术的不断进步&#xff0c;AI在大规模数据处理和分析方面的能力日益强大。因此&#xff0c;将人工智能与大数据平台相结合&#xff0c;可以为企业带来巨大的商业价…

✈️一文带你入门【NestJS】

✈️引言 在现代Web开发领域&#xff0c;框架和技术的迭代速度令人咋舌。其中&#xff0c;NestJS作为一款基于Node.js的后端框架&#xff0c;以其卓越的设计理念和强大的功能集&#xff0c;迅速吸引了众多开发者的眼球。本文将带你深入了解NestJS的起源、发展&#xff0c;以及…

LangChain教程 – 如何构建自定义知识聊天机器人

您可能已经了解到过去几个月发布的大量 AI 应用程序。您甚至可能已经开始使用其中的一些。 ChatPDF和CustomGPT AI等 AI 工具已经对人们变得非常有用——这是有充分理由的。您需要滚动浏览 50 页文档才能找到简单答案的日子已经一去不复返了。相反&#xff0c;您可以依靠 AI 来…

mysql 9 新特性

mysql9新特性 新特性Audit Log NotesC API NotesCharacter Set SupportCompilation NotesComponent NotesConfiguration NotesData Dictionary NotesData Type NotesDeprecation and Removal NotesEvent Scheduler NotesJavaScript ProgramsOptimizer NotesPerformance Schema …

Linux初始化新的git仓库

1.在git服务器上找到项目常部署的git地址可以根据其他项目的git地址确认 例如ssh://git192.168.10.100/opt/git/repository.git 用户名&#xff1a;git&#xff08;前面的是用户&#xff09; 服务器地址&#xff1a;192.168.10.100 git仓库路径&#xff1a;/opt/git/ 2.在服务器…

js 图片放大镜

写购物项目的时候&#xff0c;需要放大图片&#xff0c;这里用js写了一个方法&#xff0c;鼠标悬浮的时候放大当前图片 这个是class写法 <!--* Descripttion: * Author: 苍狼一啸八荒惊* LastEditTime: 2024-07-10 09:41:34* LastEditors: 夜空苍狼啸 --><!DOCTYPE …

IP 地址与 CDN 性能优化

内容分发网络&#xff08;CDN&#xff09;就是通过内容分配到离用户最优的服务器来提高访问速度。而IP地址如何分配与管理就是CND技术的基础。本文将来探讨介绍CDN中的IP地址分配与管理&#xff0c;以及如何通过CDN优化网络性能。 首先我们来了解CDN的基本原理 CDN是一种分布式…

数据库之DML

1&#xff0c;创建表 mysql> create table student(-> id int primary key,-> name varchar(20) not null,-> grade float-> );插入记录 mysql> insert into student values(1,monkey,98.5); Query OK, 1 row affected (0.01 sec)一次性插入多条记录 mysql…

百问网全志D1h开发板MIPI屏适配

MIPI屏适配 100ASK-D1-H_DualDisplay-DevKit V11 1. 显示适配 1.1 修改设备树 1.1.1 修改内核设备树 进入目录&#xff1a; cd /home/ubuntu/tina-d1-h/device/config/chips/d1-h/configs/nezha/linux-5.4修改board.dts: &lcd0 {lcd_used <1>;lcd…

MP | 基于kmer的泛基因组分析方法及应用

2024年5月24日&#xff0c;中国农业大学分子设计育种前沿科学中心作物杂种优势与利用教育部重点实验室郭伟龙与姚颖垠团队在《Molecular Plant》发表了题为《A k-mer-based pangenome approach for cataloging seed-storage-protein genes in wheat to facilitate genotype-to-…

成都云飞浩容文化传媒有限公司怎么样?

在电商行业风起云涌的今天&#xff0c;成都云飞浩容文化传媒有限公司以其独特的视角和专业的服务&#xff0c;成为了这一领域的佼佼者。今天&#xff0c;就让我们一起走进云飞浩容&#xff0c;探索其背后的故事和成功的秘诀。 一、专注电商&#xff0c;用心服务 成都云飞浩容文…