Problem 1: repeat_interleave cannot be translated
The exact error is:
TypeError: 'torch._C.Value' object is not iterable
(Occurred when translating repeat_interleave).
The cause is this code in my model:
batch_indices = torch.repeat_interleave(torch.arange(cand_nums.shape[0]).to(device), cand_nums)
percep_feats_expanded = percep_feats[batch_indices] # shape [ΣN_i, D_f, H', W']
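Here the size of cand_nums is not fixed. In older versions of PyTorch, the static ONNX graph is incompatible with this kind of dynamic tensor operation, so repeat_interleave cannot be translated. In my case cand_nums can only have multiple elements during training; at inference it always has exactly one, so adding an if solves the problem: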
batch_size = cand_nums.shape[0]
if batch_size > 1:
    batch_indices = torch.repeat_interleave(torch.arange(cand_nums.shape[0]).to(device), cand_nums)
    percep_feats_expanded = percep_feats[batch_indices] # shape [ΣN_i, D_f, H', W']
else:
    percep_feats_expanded = percep_feats.repeat(cand_nums[0], 1, 1, 1) # shape [N, D_f, H', W']
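Later I found that newer versions of PyTorch already support this dynamic tensor operation. The if branch above may still make the ONNX model more efficient, though, since inference does not actually need dynamic sizes.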
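To check that the two branches really agree when the batch has a single element, a small sketch like the following can help. The function names expand_general and expand_single and the tensor sizes are made up for illustration; only the two expressions inside them come from the model code above.

```python
import torch


def expand_general(percep_feats, cand_nums):
    # Training path: row i of percep_feats is repeated cand_nums[i] times.
    batch_indices = torch.repeat_interleave(
        torch.arange(cand_nums.shape[0], device=percep_feats.device), cand_nums
    )
    return percep_feats[batch_indices]  # shape [sum(N_i), D_f, H', W']


def expand_single(percep_feats, cand_nums):
    # Inference path: batch size is 1, so a plain repeat along dim 0 suffices.
    return percep_feats.repeat(cand_nums[0], 1, 1, 1)  # shape [N, D_f, H', W']


# Hypothetical sizes: B=1, D_f=4, H'=2, W'=2, and 3 candidates.
percep_feats = torch.randn(1, 4, 2, 2)
cand_nums = torch.tensor([3], dtype=torch.int64)

general = expand_general(percep_feats, cand_nums)
single = expand_single(percep_feats, cand_nums)
same = torch.equal(general, single)
```

With trace-based export, a Python if on a shape is resolved while tracing, so only the branch taken for the example input should end up in the exported graph. Exporting with a single-element cand_nums therefore keeps repeat_interleave out of the ONNX model.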
Problem 2: expected scalar type Long but found Float
The stack trace looks roughly like this:
[W shape_type_inference.cpp:419] Warning: Constant folding in symbolic shape inference fails: expected scalar type Long but found Float
Exception raised from data_ptr<long int> at /pytorch/build/aten/src/ATen/core/TensorMethods.cpp:5759 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7efd098d3a22 in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x5b (0x7efd098d03db in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: long* at::Tensor::data_ptr<long>() const + 0xde (0x7efd0ba1683e in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #3: torch::jit::onnx_constant_fold::runTorchSlice_opset10(torch::jit::Node const*, std::vector<at::Tensor, std::allocator<at::Tensor> >&) + 0x47e (0x7efdae76068e in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #4: torch::jit::onnx_constant_fold::runTorchBackendForOnnx(torch::jit::Node const*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, int) + 0x1c5 (0x7efdae761985 in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #5: <unknown function> + 0xaed38e (0x7efdae7a038e in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #6: torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map<std::string, c10::IValue, std::less<std::string>, std::allocator<std::pair<std::string const, c10::IValue> > > const&, int) + 0x906 (0x7efdae7a5146 in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #7: <unknown function> + 0xaf4df4 (0x7efdae7a7df4 in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #8: <unknown function> + 0xa71010 (0x7efdae724010 in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0x5015fe (0x7efdae1b45fe in /home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #10: PyCFunction_Call + 0x52 (0x4f5572 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #11: _PyObject_MakeTpCall + 0x3bb (0x4e0e1b in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #12: _PyEval_EvalFrameDefault + 0x4dfc (0x4dcf0c in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #13: _PyEval_EvalCodeWithName + 0x2f1 (0x4d70d1 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #14: _PyFunction_Vectorcall + 0x19c (0x4e823c in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #15: _PyEval_EvalFrameDefault + 0x1153 (0x4d9263 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #16: _PyEval_EvalCodeWithName + 0x2f1 (0x4d70d1 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #17: _PyFunction_Vectorcall + 0x19c (0x4e823c in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #18: _PyEval_EvalFrameDefault + 0x1153 (0x4d9263 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #19: _PyEval_EvalCodeWithName + 0x2f1 (0x4d70d1 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #20: _PyFunction_Vectorcall + 0x19c (0x4e823c in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #21: _PyEval_EvalFrameDefault + 0x1153 (0x4d9263 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #22: _PyEval_EvalCodeWithName + 0x2f1 (0x4d70d1 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #23: _PyFunction_Vectorcall + 0x19c (0x4e823c in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #24: _PyEval_EvalFrameDefault + 0x49b1 (0x4dcac1 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #25: _PyEval_EvalCodeWithName + 0x2f1 (0x4d70d1 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #26: _PyFunction_Vectorcall + 0x19c (0x4e823c in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #27: _PyEval_EvalFrameDefault + 0x1153 (0x4d9263 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #28: _PyFunction_Vectorcall + 0x106 (0x4e81a6 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #29: _PyEval_EvalFrameDefault + 0x399 (0x4d84a9 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #30: _PyEval_EvalCodeWithName + 0x2f1 (0x4d70d1 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #31: PyEval_EvalCodeEx + 0x39 (0x585e29 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #32: PyEval_EvalCode + 0x1b (0x585deb in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #33: /home/abc/anaconda3/envs/pmapnet/bin/python() [0x5a5bd1]
frame #34: /home/abc/anaconda3/envs/pmapnet/bin/python() [0x5a4bdf]
frame #35: /home/abc/anaconda3/envs/pmapnet/bin/python() [0x45c538]
frame #36: PyRun_SimpleFileExFlags + 0x340 (0x45c0d9 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #37: /home/abc/anaconda3/envs/pmapnet/bin/python() [0x44fe8f]
frame #38: Py_BytesMain + 0x39 (0x579e89 in /home/abc/anaconda3/envs/pmapnet/bin/python)
frame #39: __libc_start_main + 0xf3 (0x7efdd0895083 in /lib/x86_64-linux-gnu/libc.so.6)
frame #40: /home/abc/anaconda3/envs/pmapnet/bin/python() [0x579d3d]
(function ComputeConstantFolding)
...
Traceback (most recent call last):
File "/home/abc/Sources/MyModel/utils/export.py", line 26, in <module>
export_onnx(model, 'path_matcher.onnx')
File "/home/abc/Sources/MyModel/utils/export.py", line 17, in export_onnx
torch.onnx.export(model, (sdmap_labels, cand_nums, percep_labels, percep_weights), output_path, verbose=True,
File "/home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/onnx/__init__.py", line 275, in export
return utils.export(model, args, f, export_params, verbose, training,
File "/home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/onnx/utils.py", line 88, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "/home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/onnx/utils.py", line 689, in _export
_model_to_graph(model, args, verbose, input_names,
File "/home/abc/anaconda3/envs/pmapnet/lib/python3.8/site-packages/torch/onnx/utils.py", line 501, in _model_to_graph
params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
RuntimeError: expected scalar type Long but found Float
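The error message looks like a type mismatch, but it does not point to a specific location in the code.
According to #66623, this appears to be a problem with exporting MultiheadAttention in PyTorch v1.9 that is fixed by upgrading to v1.10, and other posts corroborate this. After I upgraded PyTorch to 2.6.0, the problem did indeed go away.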
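Since the fix here is a version upgrade, an export script could warn early when it runs on an affected version. The sketch below is only an illustration: the helper names and the (1, 10) threshold are assumptions based on the report above, not an official compatibility table.

```python
import re


def parse_version(version):
    """Extract (major, minor) from strings such as '1.9.0+cu111'."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def export_may_hit_constant_fold_bug(version, fixed_in=(1, 10)):
    # Assumption: versions below `fixed_in` can raise the Long/Float error.
    return parse_version(version) < fixed_in
```

In an export script this would be called as export_may_hit_constant_fold_bug(torch.__version__) before torch.onnx.export.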
Problem 3: Concat type mismatch when loading the ONNX model in a C++ project
Loading the model in the C++ code throws an exception:
terminate called after throwing an instance of 'Ort::Exception'
what(): Load model from my_model.onnx failed:Type Error: Type parameter (T) of Optype (Concat) bound to different types (tensor(int32) and tensor(int64) in node (/Concat_1).
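This error message names the failing node, which makes it easier to track down. Opening the ONNX file in Netron and searching for Concat_1 locates the node, which clearly corresponds to this line in the code: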
percep_feats_expanded = percep_feats.repeat(cand_nums[0], 1, 1, 1)
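The last three arguments of repeat are the three Constant outputs feeding Concat_1 in the graph, and they are all treated as int64. cand_nums[0], however, is an element of an int32 tensor, which causes the Concat type mismatch. Python does not complain because of implicit type conversion, but after conversion to ONNX the types are checked strictly when the model is loaded in C++, so the load fails. The fix is to change cand_nums to int64.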
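One way to apply this fix is to cast integer inputs to int64 before passing them to the model and to torch.onnx.export. The helper below is a hypothetical sketch (cast_index_inputs is not part of the original code), and it assumes the integer inputs are only used as counts or indices.

```python
import torch


def cast_index_inputs(*tensors):
    """Cast narrower integer tensors to int64; leave other dtypes untouched."""
    narrow_ints = (torch.uint8, torch.int8, torch.int16, torch.int32)
    return tuple(
        t.to(torch.int64) if t.dtype in narrow_ints else t for t in tensors
    )


# Hypothetical example inputs: cand_nums arrives as int32 from a data pipeline.
cand_nums = torch.tensor([3], dtype=torch.int32)
percep_feats = torch.randn(1, 4, 2, 2)

cand_nums64, feats = cast_index_inputs(cand_nums, percep_feats)
expanded = feats.repeat(cand_nums64[0], 1, 1, 1)
```

If the model is exported with an int64 cand_nums, the C++ side would presumably also need to feed that input as an int64 tensor.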