第三章:PyTorch 转 ONNX 详解 — mmdeploy 0.12.0 文档
torch.onnx — PyTorch 2.0 documentation
torch.onnx.export
细解
计算图导出方法
TorchScript是一种序列化和优化PyTorch模型的格式,将torch.nn.Module模型转换为TorchScript的torch.jit.ScriptModule模型,也是一种中间表示。
torch.onnx.export中使用的模型实际上是torch.jit.ScriptModule。
将torch.nn.Module转化为TorchScript模型(导出计算图)有两种模式:跟踪(trace)和脚本化(script)。
torch.onnx.export输入一个torch.nn.Module,默认会使用跟踪(trace)的方法导出。
import torch
class Model(torch.nn.Module):
def __init__(self, n):
super().__init__()
self.n = n
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
for i in range(self.n):#控制输入张量被卷积的次数
x = self.conv(x)
return x
models = [Model(2), Model(3)]# n=2和n=3的模型
model_names = ['model_2', 'model_3']
for model, model_name in zip(models, model_names):
dummy_input = torch.rand(1, 3, 10, 10)
dummy_output = model(dummy_input)
model_trace = torch.jit.trace(model, dummy_input)
model_script = torch.jit.script(model)
#torch.onnx.export默认使用trace,所有不需要先trace
# 跟踪法与直接 torch.onnx.export(model, ...)等价
# torch.onnx.export(model_trace, dummy_input, f'{model_name}_trace.onnx', example_outputs=dummy_output, opset_version = 11)
torch.onnx.export(model, dummy_input,f'{model_name}_trace.onnx', example_outputs=dummy_output, opset_version = 11)
# 脚本化必须先调用 torch.jit.sciprt
torch.onnx.export(model_script, dummy_input, f'{model_name}_script.onnx', example_outputs=dummy_output)
# 如果是先运行了torch.jit.script,将模型转化成TorchScript,则export函数不需要再运行一遍
# 如果输入不是TorchScript,则export需要运行一遍模型
# dummy_input和dummy_output表示输入输出张量的数据类型和形状
跟踪法trace中,不同的n得到的ONNX模型结构是不一样的。
脚本法script中,Loop节点表示循环,不同的n可以有相同的结构。
推理引擎对静态图支持更好,不需要显式的将PyTorch模型转换为TorchScript,直接使用torch.onnx.export跟踪法导出即可。
虽然在代码中没有直接将trace的脚本作为export输入,但是可以通过trace来定位export问题是否出现在trace中。
参数讲解
def export(model, args, f, export_params=True, verbose=False, training=TrainingMode.EVAL,
input_names=None, output_names=None, aten=False, export_raw_ir=False,
operator_export_type=None, opset_version=None, _retain_param_name=True,
do_constant_folding=True, example_outputs=None, strip_doc_string=True,
dynamic_axes=None, keep_initializers_as_inputs=None, custom_opsets=None,
enable_onnx_checker=True, use_external_data_format=False):
- 模型(model):必选
- 输入(args):必选
- 导出的 onnx 文件名(f):必选
- 模型中是否保存权重(export_params):一般模型结构和模型权重放在一个文件里存储,所以默认是true,如果是在不同的框架间传递模型,而不是用于部署,则设置为false。
- 输入/输出张量名称(input_names, output_names):推理引擎一般都需要通过“名称-张量值”的数据对来输入数据,并根据输出张量的名称来获取输出数据,保证ONNX和推理引擎中使用同一套名称。
- opset_version:ONNX算子集版本。
- dynamic_axes:指定输入输出张量的哪些维度是动态的,为了追求效率,ONNX默认所有参与运算的张量都是静态的(张量的形状不发生改变)。可以显式的指明输入输出张量的哪几个维度的大小是可变的
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()#继承父类构造函数中
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
x = self.conv(x)
return x
model = Model()
dummy_input = torch.rand(1, 3, 10, 10)
model_names = ['model_static.onnx',
'model_dynamic_0.onnx',
'model_dynamic_23.onnx']
# dynamic_axes_0 = {#第0维动态
# 'in' : [0],
# 'out' : [0]
# }
dynamic_axes_0 = {
'in' : {0: 'batch'},
'out' : {0: 'batch'}
}
dynamic_axes_23 = {#第2、3维动态
'in' : [2, 3],
'out' : [2, 3]
}
torch.onnx.export(model, dummy_input, model_names[0], input_names=['in'], output_names=['out'])#没有动态维度
torch.onnx.export(model, dummy_input, model_names[1], input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_0)#第0维动态
torch.onnx.export(model, dummy_input, model_names[2], input_names=['in'], output_names=['out'], dynamic_axes=dynamic_axes_23)#第2、3维动态
# ONNX 要求每个动态维度都有一个名字,直接这样写会引出一条UserWarning,警告我们通过列表方式设置动态维度的话,系统会自动为它们分配名字
import onnxruntime
import numpy as np
origin_tensor = np.random.rand(1, 3, 10, 10).astype(np.float32)
mult_batch_tensor = np.random.rand(2, 3, 10, 10).astype(np.float32)
big_tensor = np.random.rand(1, 3, 20, 20).astype(np.float32)
inputs = [origin_tensor, mult_batch_tensor, big_tensor]
exceptions = dict()
model_names = ['model_static.onnx',#批量或者维度增加就会出错
'model_dynamic_0.onnx',#维度增加就会出错
'model_dynamic_23.onnx']#批量增加就会出错
for model_name in model_names:
for i, input in enumerate(inputs):
try:
ort_session = onnxruntime.InferenceSession(model_name)
ort_inputs = {'in': input}
ort_session.run(['out'], ort_inputs)#只有在设置了对应的动态维度后才不会出错
except Exception as e:
exceptions[(i, model_name)] = e
print(f'Input[{i}] on model {model_name} error.')
print(exceptions[(1, 'model_static.onnx')])
else:
print(f'Input[{i}] on model {model_name} succeed.')
使用技巧
torch.onnx.is_in_onnx_export():PyTorch推理时不运行,但是在执行torch.onnx.export()时为真。
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
x = self.conv(x)
if torch.onnx.is_in_onnx_export():# 仅在模型导出时把输出张量的数值限制在[0,1]之间
#可以在代码中添加和模型部署相关的逻辑
x = torch.clip(x, 0, 1)
return x
利用中断张量跟踪的操作
import torch
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
#item、for、list等方法都会导致ONNX模型不太正确
x = x * x[0].item()#跟踪法会把某些取决于输入的中间结果变成常量
# .item()把torch中的张量转换成普通的Python遍历
return x, torch.Tensor([i for i in x])#遍历torch张量,并用一个列表新建一个torch张量。
model = Model()
dummy_input = torch.rand(10)
torch.onnx.export(model, dummy_input, 'a.onnx')
涉及到张量与普通变量转换的逻辑都会导致最终ONNX模型不太正确。
利用这个性质,在保证正确性的前提下令模型中间结果变成常量。
这个技巧尝尝用于模型的静态化上,即令模型中所有张量形状都变成常量。
使用张量为输入(PyTorch版本 < 1.9.0)
PyTorch 对 ONNX 的算子支持
如果torch.onnx.export()正常执行后,另一个容易出现的问题就是算子不兼容。
在转换普通torch.nn.Module模型时:
- Pytorch利用跟踪法执行前向推理,把遇到的算子整合成计算图;
- Pytorch把遇到的算子翻译成ONNX定义的算子。
算子翻译的过程可能遇到的情况:
- 算子可以一对一翻译成ONNX算子。
- 算子没有一对一的ONNX算子,被翻译成一个或多个ONNX算子。
- 算子没有翻译成ONNX的规则。
ONNX 算子文档
onnx/Operators.md at main · onnx/onnx · GitHub
算子变更表格(算子名,算子变更版本号opset_version),第一次变更的版本号,表示算子第一次被支持,且第一个改动记录可以知道当前算子集中该算子的定义规则。
表格中的链接可以说明该算子的输入输出参数规定使用示例。
PyTorch 对 ONNX 算子的映射
pytorch/torch/onnx at master · pytorch/pytorch · GitHub
symbloic_opset{n}.py表示pytorch对应的ONNX算子集版本。
在vscode中限定在torch/onnx文件夹搜索对应算子
按照调用逻辑直接跳转到
@_onnx_symbolic(
"aten::upsample_bicubic2d",
decorate=[_apply_params("upsample_bicubic2d", 4, "cubic")],
)
->
@_beartype.beartype
def _interpolate(name: str, dim: int, interpolate_mode: str):
return symbolic_helper._interpolate_helper(name, dim, interpolate_mode)
->
@_beartype.beartype
def _interpolate_helper(name, dim, interpolate_mode):
@quantized_args(True, False, False)
def symbolic_fn(g, input, output_size, *args):
...
return symbolic_fn
symbolic_fn中插值算子被映射成多个ONNX算子,一个g.op对应ONNX
return g.op(
"Resize",
input,
empty_roi,
empty_scales,
output_size,
coordinate_transformation_mode_s=coordinate_transformation_mode,
cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
mode_s=interpolate_mode, # nearest, linear, or cubic
nearest_mode_s="floor",
) # only valid when mode="nearest"
查找对应的ONNXonnx/Operators.md at main · onnx/onnx · GitHub resize算子定义,可以知道对应参数含义。
查询PyTorch到ONNX的映射关系,然后在torch.onnx.export()的opset_version设定一个版本号,然后去PyTorch符号表文件里去查。如果没有对应算子,就需要考虑用其他算子替代,或者自定义算子。
总结
- 跟踪法和脚本化在导出待控制语句的计算图时有什么区别。
- torch.onnx.export()中如何设置input_names, output_names, dynamic_axes。
- 使用torch.onnx.is_in_onnx_export()来使得模型在转换到ONNX时有不同的行为。
- 查询ONNX 算子文档。
- 查询ONNX算子对PyTorch算子支持情况。
- 查询ONNX算子对PyTorch算子使用方式。