1. ONNX 背景
ONNX 全称为 Open Neural Network Exchange,是微软提出并推广的一种机器学习模型的开放格式表示。ONNX定义了一组通用的算子集、一系列用于构建深度学习模型的模块以及一种通用的文件格式,使得人工智能开发人员能够将模型与各种框架、工具、运行时和编译器一起使用。ONNX可以理解为是 AI 算法框架与硬件平台之间的桥梁,AI算法研究人员可以使用任意的深度学习框架来设计并训练模型,训练完成后将模型转换成 ONNX 格式来进行存储,模型部署工程师可以针对 ONNX 这一中间格式来针对不同的硬件平台进行运行时设计和优化,从而实现AI模型设计和模型部署的解耦。
2. ONNX 结构分析
想要部署ONNX模型,我们首先需要了解 ONNX 模型的结构。神经网络模型是由 计算图 + 权重 组成的,计算图是一个有向无环的计算流程图,权重则是网络训练好的参数集合。关于 ONNX 更详尽的介绍,可以参考 ONNX 官方文档 onnx/docs at main · onnx/onnx · GitHub。
ONNX 模型是利用 ProtoBuf 这一数据结构存储协议来将模型序列化到硬盘上的。一个存储到本地的 .onnx 模型可以以结构化的方式解析成下图所示的各个部分,其中比较重要的部分加粗表示。加粗部分上方是当前结构的名称,下方是当前结构的类型,各个Proto类型的定义可以参考onnx/onnx/onnx.proto at main · onnx/onnx · GitHub。
2.1 查看onnx
- 可以使用 Netron 来可视化查看 ONNX 模型
- 可以使用 protoc 工具来解析 .onnx 模型文件。命令中 onnx.proto 是 ONNX 官方 repo 中的 Proto 定义。这条命令的含义是将super-resolution-10.onnx作为输入,按照 onnx.proto 定义从中提取 onnx.ModelProto 对象,并将结果重定向到 model.txt。
$ protoc --decode=onnx.ModelProto -I D:\Python\workspace\onnx_learn\onnx\onnx onnx.proto < D:\Python\workspace\onnx_learn\super-resolution-10.onnx > model.txt
2.2 使用 Python 来获取到 ONNX 模型结构
1) onnx model 结构
onnx model 是 ModelProto 类型,是 ONNX 模型最顶层的结构。其所包含的各个成员如下:
属性名 | 示例值 | 描述 |
ir_version | int64 | 模型的onnx IR版本。 |
opset_import | OperatorSetId | 可用于模型的算子集标识符集合。一个onnx实现中必须包含这个集合中的所有算子,否则将拒接模型。 |
producer_name | string | 生成这个onnx的生产者工具名称。 |
producer_version | string | 这个生产者工具的版本。 |
domain | string | onnx模型的命名空间,用反向域名命名,和java一样。 |
model_version | int64 | 模型本身的版本。 |
doc_string | string | 文档注释,可以是Markdown。 |
graph | Graph | 模型计算图。 |
metadata_props | map<string,string> | 元数据的键值对属性。 |
training_info | TrainingInfoProto[] | 包含训练信息的可选扩展。 |
functions | FunctionProto[] | 模型本地函数的可选列表。 |
使用Python API获取onnx model各成员
import onnx
model = onnx.load('super-resolution-10.onnx')
print(f"model.ir_version ---> {model.ir_version}")
print(f"model.opset_import ---> {model.opset_import}")
print(f"model.producer_name ---> {model.producer_name}")
print(f"model.producer_version ---> {model.producer_version}")
print(f"model.domain ---> {model.domain}")
print(f"model.model_version ---> {model.model_version}")
print(f"model.doc_string ---> {model.doc_string}")
print(f"model.metadata_props ---> {model.metadata_props}")
print(f"model.training_info ---> {model.training_info}")
print(f"model.functions ---> {model.functions}")
print(f"model.graph ---> {model.graph}")
2) model.graph 结构
model 中最重要的是 graph,类型是 GraphProto。 其所包含的各个成员如下:
属性名 | 示例值 | 描述 |
name | string | 模型计算图的名字 |
node | Node[] | 计算图中的算子集合(有向无环图的节点集),按照拓扑排序排列 |
initializer | Tensor[] | 计算图中的initializer,是一个tensor列表,通常存放模型的权重,可以理解为一个常量池。 |
doc_string | string | 文档注释 |
input | ValueInfo[] | 模型计算图的输入tensor列表。 |
output | ValueInfo[] | 模型计算图的输出tensor列表。 |
value_info | ValueInfo[] | 模型计算图除输入输出外中间tensor列表,当使用shape_inference时,推理出来的shape存储到这里,即 Netron 中看到的中间tensor的维度。 |
metadata_props | map<string,string> | 模型计算图的元数据(IR version >= 10) |
使用Python API获取 model.graph 各成员
print("-------------------- model.graph.name --------------------")
print(model.graph.name)
print("-------------------- model.graph.node --------------------")
print(model.graph.node)
print("-------------------- model.graph.initializer --------------------")
print(model.graph.initializer)
print("-------------------- model.graph.doc_string --------------------")
print(model.graph.doc_string)
print("-------------------- model.graph.input --------------------")
print(model.graph.input)
print("-------------------- model.graph.output --------------------")
print(model.graph.output)
print("-------------------- model.graph.value_info --------------------")
print(model.graph.value_info)
print("-------------------- model.graph.metadata_props (IR version >= 10) --------------------")
print(model.graph.metadata_props)
3) model.graph.node 结构
model.graph中的 node 是一个节点集列表,其中的每个元素节点均为 NodeProto 类型,所包含的成员如下:
属性名 | 示例值 | 描述 |
name | string | 节点的名字 |
input | string[] | 节点的输入列表,相当于计算图的输入边集。 |
output | string[] | 节点的输出列表,相当于计算图的输出边集。 |
op_type | string | 节点的类型,表明该节点的计算逻辑。 |
domain | string | ONNX中定义的节点集的域。由于 ONNX 是支持第三方拓展内置的算子集的,这个域唯一的指明节点的op_type,类似Java的包管理一样,用域名倒置表示。 |
attribute | Attribute[] | 节点的属性列表。例如Conv节点的kernel shape和padding等。 |
doc_string | string | 文档注释。 |
overload | string | 函数的唯一ID。(added in IR version 10) |
metadata_props | map<string,string> | 节点的元数据。(IR version >= 10) |
使用Python API获取 model.graph 各成员
# 打印 node 各个属性值
print("----------------------- model.graph.node[0].name -----------------------")
print(model.graph.node[0].name)
print("----------------------- model.graph.node[0].input -----------------------")
print(model.graph.node[0].input)
print("----------------------- model.graph.node[0].output -----------------------")
print(model.graph.node[0].output)
print("----------------------- model.graph.node[0].op_type -----------------------")
print(model.graph.node[0].op_type)
print("----------------------- model.graph.node[0].domain -----------------------")
print(model.graph.node[0].domain)
print("----------------------- model.graph.node[0].attribute -----------------------")
print(model.graph.node[0].attribute)
print("----------------------- model.graph.node[0].doc_string -----------------------")
print(model.graph.node[0].doc_string)
print("----------------------- model.graph.node[0].overload -----------------------")
print(model.graph.node[0].overload)
print("----------------------- model.graph.node[0].metadata_props -----------------------")
print(model.graph.node[0].metadata_props)
4) model.graph.initializer 结构
model.graph.initializer 是一个tensor列表,其中的元素类型为TensorProto。Initializer通常保存模型的权重参数,一些输入默认值也可以保存在这里,可以将其理解为一个tensor常量池。initializer每个元素的成员如下:
属性名 | 示例值 | 描述 |
name | string | 该tensor的名字 |
dims | int[] | 该tensor的维度 |
data_type | int | 该tensor的数据类型,不同的数值代表不同个的数据类型 |
raw_data | bytes | 该tensor保存的具体数据,二进制形式 |
doc_string | string | 文档注释 |
使用Python API获取 model.graph 各成员
print("----------------------- model.graph.initializer[0].name -----------------------")
print(model.graph.initializer[0].name)
print("----------------------- model.graph.initializer[0].dims -----------------------")
print(model.graph.initializer[0].dims)
print("----------------------- model.graph.initializer[0].data_type -----------------------")
print(model.graph.initializer[0].data_type)
print("----------------------- model.graph.initializer[0].raw_data -----------------------")
# 二进制表示,打印出来可能会很长
print(model.graph.initializer[0].raw_data)
5)model.graph.input & output & value_info 结构
graph 中的 input、output 和 value_info 均为一个列表,可以使用index进行索引。Input为计算图的所有输入,output是计算图的所有输出,value_info则为计算图中所有中间计算结果tensor的信息。当使用
onnx.shape_inference.infer_shapes()推理所有中间tensor的维度时,这些信息均会保存在value_info中。input、output 和 value_info的每个元素类型为ValueInfoProto,其包含的成员如下
属性名 | 示例值 | 描述 |
name | string | 当前值的名字 |
type | TypeProto | 当前值的类型,这其中包含当前值的数据类型和维度 |
使用Python API获取 model.graph 各成员
# 以第一个input为例
print("----------------------- model.graph.input[0].name -----------------------")
print(model.graph.input[0].name)
print("----------------------- model.graph.input[0].type -----------------------")
print(model.graph.input[0].type)
3. 总结
本文重点解析了 ONNX 模型结构,并演示了如何使用Python定位到ONNX模型各个层面的元素。在得到不同元素之后,我们可以对ONNX模型进行适当的修改,使其更加适配我们的后端运行时,进一步提高推理性能。我们在之后的文章会介绍如何使用 ONNX 官方的 API 来修改ONNX模型。
作者:高通工程师,阮慧源(Huiyuan Ruan)