前言
构建onnx方式通常有两种:
1、通过代码转换成onnx结构,比如pytorch —> onnx
2、通过onnx 自定义结点,图,生成onnx结构
本文主要是简单学习和使用两种不同onnx结构,
下面以 Suqeeze
结点进行分析
方式
方法一:pytorch --> onnx
暂缓,主要研究方式二
方法二: onnx
import onnx
from onnx import helper
from onnx import TensorProto
# 创建一个squeeze节点
def create_squeeze_node(input_name, output_name, axes):
node = helper.make_node(
'Squeeze',
inputs=[input_name],
outputs=[output_name],
axes=axes
)
return node
# 创建一个ONNX图
def create_onnx_graph():
input_name = 'input'
output_name = 'output'
axes = [1, 2] # 压缩的轴
# 创建输入张量
input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [1, 1, 3, 3])
# 创建输出张量
output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [3, 3])
# 创建squeeze节点
squeeze_node = create_squeeze_node(input_name, output_name, axes)
# 创建ONNX图
graph = helper.make_graph(
[squeeze_node],
'squeeze_graph',
[input_tensor],
[output_tensor]
)
# 创建ONNX模型
model = helper.make_model(graph)
return model
# 保存ONNX模型
def save_onnx_model(model, file_path):
onnx.save(model, file_path)
# 创建并保存ONNX模型
model = create_onnx_graph()
save_onnx_model(model, 'squeeze_model.onnx')