I. Outline
- Reading the torch.jit.trace related code
- ONNX internals
- Checking whether an op is an aten operator
- Implementing aten operators
- Implementing operators via torch.autograd.Function
- Implementing custom operators
- Finding unimplemented nodes
- Discovering all unimplemented aten operators at once
II. Implementation
- Reading the torch.jit.trace related code
1. torch.jit.script(): compiles a function or nn.Module into a runnable script. The converted script can be called like an ordinary Python function, and it can also be saved to disk and executed in an environment without a Python dependency (e.g. loaded from C++ via libtorch).
2. torch.jit.trace: records the execution path taken for the given example inputs, so when the traced module is used for inference, the input tensors must have the same shapes and dtypes as the ones used during tracing.
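A minimal sketch of the difference, using an illustrative function not from the original text: trace bakes in the branch taken for the example input, while script preserves the control flow.

import torch

def f(x):
    # Data-dependent control flow: kept by script, baked in by trace.
    if x.sum() > 0:
        return x * 2
    return x - 1

scripted = torch.jit.script(f)               # keeps the if/else
traced = torch.jit.trace(f, torch.ones(3))   # records only the branch taken for ones(3)

neg = torch.full((3,), -1.0)
print(scripted(neg))  # follows the else branch: x - 1
print(traced(neg))    # silently replays the traced branch: x * 2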
- Checking whether an op is an aten operator
import torch

print(
    torch.jit.trace(
        torch.nn.ELU(),  # module to trace
        torch.ones(1),   # example input
    ).graph
)

The printed graph lists the traced operators, so you can see which aten ops (here aten::elu) the model actually uses.
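Representative output (a sketch only; exact value names and formatting vary across PyTorch versions), showing the aten::elu node to look for:

graph(%self : __torch__.torch.nn.modules.activation.ELU,
      %input : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %alpha : float = prim::Constant[value=1.]()
  %out : Tensor = aten::elu(%input, %alpha, %alpha, %alpha)
  return (%out)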
- Implementing aten operators
1. Look up the torch interface definition in torch/nn/functional.pyi.
2. Look up the ONNX operator name in https://github.com/onnx/onnx/blob/main/docs/Operators.md.
3. See how symbolic registration functions are written in symbolic_opset9.py.
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

# Define a custom symbolic function for aten::relu.
# Interface definition (from torch/nn/functional.pyi):
#   def relu(input: Tensor) -> Tensor: ...
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)  # map to the ONNX Relu operator

# Register the symbolic function for aten::relu.
torch.onnx.register_custom_op_symbolic(
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

# With a correct symbolic function, find_mismatch reports no mismatched subgraphs.
graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
- Implementing operators via torch.autograd.Function
If the operator is written as a subclass of torch.autograd.Function, it can be made exportable by adding a static symbolic method, as below.
import torch

class MyRelu(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def symbolic(g: torch.Graph, input: torch.Value) -> torch.Value:
        # Export as ONNX Clip with min=0, which is equivalent to ReLU.
        return g.op(
            "Clip",
            input,
            g.op("Constant", value_t=torch.tensor(0, dtype=torch.float)),
        )
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

myrelu = MyRelu.apply  # key step: call through .apply so the symbolic method is used on export

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.Linear(4, 5),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return myrelu(self.layers(x))

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
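To confirm the mapping, a quick check (a sketch, assuming the onnx package is installed; the file name is arbitrary) is to export the model and list the node types, which should show Clip instead of Relu:

import onnx

torch.onnx.export(Model(), (torch.randn(2, 3),), "myrelu.onnx", opset_version=opset_version)
print([node.op_type for node in onnx.load("myrelu.onnx").graph.node])  # expect Gemm nodes plus a Clip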
- Implementing custom operators
1. Implement the operator as a custom C++ operator, following the "Extending TorchScript with Custom C++ Operators" tutorial, then register an ONNX symbolic function for it.
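A minimal sketch of the Python side, assuming the C++ extension has already registered a custom op under the hypothetical name mylib::my_relu:

import torch
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

# mylib::my_relu is a hypothetical op registered by a C++ extension.
@parse_args("v")
def my_relu_symbolic(g, input):
    # Map the custom op onto the standard ONNX Relu node.
    return g.op("Relu", input)

register_custom_op_symbolic("mylib::my_relu", my_relu_symbolic, opset_version=15)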
- Finding unimplemented nodes
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

# Define a custom symbolic function for aten::relu.
# This symbolic function is deliberately incorrect (it just passes the input
# through), so find_mismatch will flag the aten::relu nodes.
def incorrect_relu_symbolic_function(g, self):
    return self

torch.onnx.register_custom_op_symbolic(
    "aten::relu",
    incorrect_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
===================== Mismatch info for graph partition : ======================
================================ Mismatch error ================================
Tensor-likes are not close!
Mismatched elements: 12 / 12 (100.0%)
Greatest absolute difference: 0.2328854203224182 at index (1, 2) (up to 1e-07 allowed)
Greatest relative difference: 0.699536174352349 at index (1, 3) (up to 0.001 allowed)
==================================== Tree: =====================================
5 X   __2 X    __1 ✓
id:  |  id: 0 |  id: 00
     |        |
     |        |__1 X (aten::relu)
     |           id: 01
     |
     |__3 X    __1 ✓
        id: 1 |  id: 10
              |
              |__2 X    __1 X (aten::relu)
                 id: 11 |  id: 110
                        |
                        |__1 ✓
                           id: 111
=========================== Mismatch leaf subgraphs: ===========================
['01', '110']
============================= Mismatch node kinds: =============================
{'aten::relu': 2}
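Reading the tree: each node shows how many ops are in that graph partition, X marks a partition whose ONNX output mismatches the PyTorch output, and ✓ one that matches. The mismatching leaf subgraphs '01' and '110' isolate the two aten::relu nodes introduced by the incorrect symbolic function.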
After the fix:
Option 1: aten operator implementation
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

# Define a correct symbolic function for aten::relu
# (interface definition: def relu(input: Tensor) -> Tensor: ...).
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)  # the matching ONNX operator

# Register the symbolic function.
torch.onnx.register_custom_op_symbolic(
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
Option 2: C++ / TorchScript custom-operator style registration
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

# Register a symbolic function for a TorchScript operator.
from torch.onnx import register_custom_op_symbolic
from torch.onnx.symbolic_helper import parse_args

# The @parse_args decorator: symbolic functions for TorchScript operators must
# declare the type of every input argument. "v" means a Torch Value (usually a
# tensor), "i" an int, "f" a float, and "none" an empty argument. The full list
# of type codes is in torch/onnx/symbolic_helper.py.
@parse_args("v")
def correct_relu_symbolic_function(g, input):
    return g.op("Relu", input)

# Register the symbolic function.
register_custom_op_symbolic(
    "aten::relu",
    correct_relu_symbolic_function,
    opset_version=opset_version,
)

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

graph_info = torch.onnx.verification.find_mismatch(
    Model(),
    (torch.randn(2, 3),),
    opset_version=opset_version,
)
- Discovering all unimplemented aten operators at once
import torch
import torch.onnx.verification

torch.manual_seed(0)
opset_version = 15

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(3, 4),
            torch.nn.ReLU(),
            torch.nn.Linear(4, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 6),
        )

    def forward(self, x):
        return self.layers(x)

# Returns the TorchScript graph plus every op that has no symbolic function
# registered for the target opset, in a single pass.
torch_script_graph, unconvertible_ops = torch.onnx.utils.unconvertible_ops(
    Model(), (torch.randn(2, 3),), opset_version=opset_version
)
print(set(unconvertible_ops))  # empty set here: Linear and ReLU are all convertible