PyTorch 内 LibTorch/TorchScript 的使用

news2024/9/22 17:37:48

PyTorch 内 LibTorch/TorchScript 的使用

  • 1. .pt .pth .bin .onnx 格式
    • 1.1 模型的保存与加载到底在做什么?
    • 1.2 为什么要约定格式?
    • 1.3 格式汇总
      • 1.3.1 .pt .pth 格式
      • 1.3.2 .bin 格式
      • 1.3.3 直接保存完整模型
      • 1.3.4 .onnx 格式
      • 1.3.5 jit.trace
      • 1.3.6 jit.script
    • 1.4 总结
  • 2. TorchScript 的转换
    • 2.1 jit trace 注意事项
    • 2.2 jit trace 验证技巧
    • 2.3 混合使用 trace 和 script
    • 2.4 trace 和 script 的性能
    • 2.5 总结
  • 3. LibTorch 的使用
    • 3.1 LibTorch 的链接
    • 3.2 接口和实现

Reference:

  1. [Pytorch].pth转.pt文件
  2. Pytorch格式 .pt .pth .bin .onnx 详解
  3. pytorch 基于tracing/script方式转ONNX

在这里插入图片描述

1. .pt .pth .bin .onnx 格式

1.1 模型的保存与加载到底在做什么?

我们在使用pytorch构建模型并且训练完成后,下一步要做的就是把这个模型放到实际场景中应用,或者是分享给其他人学习、研究、使用。因此,我们开始思考一个问题,提供哪些模型信息,能够让对方能够完全复现我们的模型?

  • 模型代码
    1. 包含了我们如何定义模型的结构,包括模型有多少层/每层有多少神经元等等信息;
    2. 包含了我们如何定义的训练过程,包括epoch batch_size等参数;
    3. 包含了我们如何加载数据和使用;
    4. 包含了我们如何测试评估模型。
  • 模型参数:提供了模型代码之后,对方确实能够复现模型,但是运行的参数需要重新训练才能得到,而没有办法在我们的模型参数基础上继续训练,因此对方还希望我们能够把模型的参数也保存下来给对方。
    1. 包含model.state_dict(),这是模型每一层可学习的节点的参数,比如weight/bias;
    2. 包含optimizer.state_dict(),这是模型的优化器中的参数;
    3. 包含我们其他参数信息,如epoch/batch_size/loss等。
  • 数据集
    1. 包含了我们训练模型使用的所有数据;
    2. 可以提示对方如何去准备同样格式的数据来训练模型。
  • 使用文档
    1. 根据使用文档的步骤,每个人都可以重现模型;
    2. 包含了模型的使用细节和我们相关参数的设置依据等信息。

可以看到,根据我们提供的模型代码/模型参数/数据集/使用文档,我们就可以有理由相信对方是有手就会了,那么目的就达到了。

现在我们反转一下思路,我们希望别人给我们提供模型的时候也能够提供这些信息,那么我们就可以拿捏住别人的模型了。

1.2 为什么要约定格式?

根据上一段的思路,我们知道模型重现的关键是模型结构/模型参数/数据集,那么我们提供或者希望别人提供这些信息,需要一个交流的规范,这样才不会1000个人给出1000种格式,而 .pt .pth .bin 以及 .onnx 就是约定的格式。

torch.save: Saves a serialized object to disk. This function uses Python’s pickle utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.

不同的后缀只是用于提示我们文件可能包含的内容,但是具体的内容需要看模型提供者编写的 README.md 才知道。而在使用 torch.load() 方法加载模型信息的时候,并不是根据文件的后缀进行的读取,而是根据文件的实际内容自动识别的,因此对于 torch.load() 方法而言,不管你把后缀改成是什么,只要文件是对的都可以读取

torch.load: Uses pickle’s unpickling facilities to deserialize pickled object files to memory. This function also facilitates the device to load the data into.

1.3 格式汇总

格式解释适用场景可对应的后缀
.pt 或 .pthPyTorch 的默认模型文件格式,用于保存和加载完整的 PyTorch 模型,包含模型的结构和参数等信息需要保存和加载完整的 PyTorch 模型的场景,例如在训练中保存最佳的模型或在部署中加载训练好的模型.pt 或 .pth
.bin一种通用的二进制格式,可以用于保存和加载各种类型的模型和数据需要将 PyTorch 模型转换为通用的二进制格式的场景.bin
ONNX一种通用的模型交换格式,可以用于将模型从一个深度学习框架转换到另一个深度学习框架或硬件平台。在 PyTorch 中,可以使用 torch.onnx.export 函数将 PyTorch 模型转换为 ONNX 格式需要将 PyTorch 模型转换为其他深度学习框架或硬件平台可用的格式的场景.onnx
TorchScriptPyTorch 提供的一种序列化和优化模型的方法,可以将 PyTorch 模型转换为一个序列化的程序,并使用 JIT 编译器对模型进行优化。在 PyTorch 中,可以使用 torch.git.trace 或 torch.git.script 函数将 PyTorch 模型转换为 TorchScript 格式需要将 PyTorch 模型序列化和优化,并在没有 Python 环境的情况下运行模型的场景.pt 或 .pth

1.3.1 .pt .pth 格式

一个完整的 PyTorch 模型文件,包含了如下参数:

  • model_state_dict:模型参数
  • optimizer_state_dict:优化器的状态
  • epoch:当前的训练轮数
  • loss:当前的损失值

下面是一个 .pt 文件的保存和加载示例(注意,后缀也可以是 .pth):

  • .state_dict():包含所有的参数和持久化缓存的字典,model 和 optimizer 都有这个方法
  • torch.save():将所有的组件保存到文件中

模型保存

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)# 初始化优化器

loss = nn.MSELoss()# 初始化损失函数

PATH = "model.pth" # 保存路径

# 保存模型
torch.save({
            'epoch': 10,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            }, PATH)

netron 可得:
在这里插入图片描述

模型加载

import torch
import torch.nn as nn

# 定义同样的模型结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 加载模型
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
PATH = "model.pth"
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
model.eval()

1.3.2 .bin 格式

.bin 文件是一个二进制文件,可以保存 PyTorch 模型的参数和持久化缓存。.bin 文件的大小较小,加载速度较快,因此在生产环境中使用较多。

下面是一个.bin文件的保存和加载示例(注意:也可以使用 .pt .pth 后缀—后缀无意义):
保存模型

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()
# 保存参数到.bin文件
torch.save(model.state_dict(), PATH)

加载模型

import torch
import torch.nn as nn

# 定义相同的模型结构
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

# 加载.bin文件
model = Net()
model.load_state_dict(torch.load(PATH))
model.eval()

1.3.3 直接保存完整模型

可以看出来,我们在之前的保存方式中,都是保存了 .state_dict(),但是没有保存模型的结构,在其他地方使用的时候,必须先重新定义相同结构的模型(或兼容模型),才能够加载模型参数进行使用,如果我们想直接把整个模型都保存下来,避免重新定义模型,可以按如下操作:
保存模型

PATH = "entire_model.pt"
# PATH = "entire_model.pth"
# PATH = "entire_model.bin"
torch.save(model, PATH)

netron 可得:
在这里插入图片描述

可以看到与上面仅保存参数的方式相比,多了很多信息。

加载模型

model = torch.load("entire_model.pt")
model.eval()

1.3.4 .onnx 格式

上述保存的文件可以通过 PyTorch 提供的 torch.onnx.export 函数转化为ONNX格式,这样可以在其他深度学习框架中使用 PyTorch 训练的模型。转化方法如下:

import torch
import torch.onnx

# 将模型保存为.bin文件
model = torch.nn.Linear(3, 1)
torch.save(model.state_dict(), "model.bin")
# torch.save(model.state_dict(), "model.pt")
# torch.save(model.state_dict(), "model.pth")

# 将.bin文件转化为ONNX格式
model = torch.nn.Linear(3, 1)
model.load_state_dict(torch.load("model.bin"))
# model.load_state_dict(torch.load("model.pt"))
# model.load_state_dict(torch.load("model.pth"))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, "model.onnx", input_names=["input"], output_names=["output"])

加载 ONNX 格式的代码可以参考以下示例代码(注意 ONNX 只能推理不能训练,不包含反向信息的):

import onnx
import onnxruntime

# 加载ONNX文件
onnx_model = onnx.load("model.onnx")

# 将ONNX文件转化为ORT格式
ort_session = onnxruntime.InferenceSession("model.onnx")

# 输入数据
input_data = np.random.random(size=(1, 3)).astype(np.float32)

# 运行模型
outputs = ort_session.run(None, {"input": input_data})

# 输出结果
print(outputs)

注意,需要安装 onnxonnxruntime 两个 Python 包。此外,还需要使用 numpy 等其他常用的科学计算库。

1.3.5 jit.trace

保存模型

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 初始化优化器
loss = nn.MSELoss() # 初始化损失函数
model.eval()

PATH = "model_trace.pth"

# 保存模型
example = torch.rand(1, 10)
traced_module = torch.jit.trace(model, example)
traced_module.save(PATH)

在这里插入图片描述

1.3.6 jit.script

保存模型

import torch
import torch.nn as nn

# 定义一个简单的模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

model = Net()

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 初始化优化器
loss = nn.MSELoss() # 初始化损失函数
model.eval()

PATH = "model_script.pth" # 保存路径

# 保存模型
scripted_module = torch.jit.script(model)
scripted_module.save(PATH)

netron 可得:
在这里插入图片描述

1.4 总结

综上,PyTorch 可以导出的模型的几种后缀格式,但是模型导出的关键并不是后缀,而是到处时候提供的信息到底是什么,只要知道了模型的 model.state_dict()optimizer.state_dict(),以及相应的epoch batch_size loss等信息,我们就能够重建出模型,至于要导出哪些信息,就取决于你了,务必在 readme.md 中写清楚,导出了哪些信息。

保存场景保存方法文件后缀
整个模型(保存模型结构)model = Net()
torch.save(model, PATH)
.pt .pth .bin
仅模型参数(不保存模型结构)model = Net()
torch.save(model.state_dict(), PATH)
.pt .pth .bin
checkpoints使用model = Net()
torch.save({‘epoch’:10,‘model_state_dict’:model.state_dict(),‘optimizer_state_dict’: optimizer.state_dict(),‘loss’: loss,}, PATH)
.pt .pth .bin
ONNX通用保存model = Net()
model.load_state_dict(torch.load(“model.bin”))
example_input = torch.randn(1, 3)
torch.onnx.export(model, example_input, “model.onnx”, input_names=[“input”], output_names=[“output”])
.onnx
TorchScript 无 Python 环境使用model = Net()
model_scripted = torch.jit.script(model) # Export to TorchScript
model_scripted.save(‘model_scripted.pt’)
model = torch.jit.load(‘model_scripted.pt’)
model.eval()
.pt .pth

2. TorchScript 的转换

上文内提到 .pthpt 等价,而且后缀主要用于提示。不过相对来说,PyTorch 的模型文件一般保存为 .pth 文件的更多一点,而 C++ 接口一般读取的是 .pt 文件,因此,C++ 在调用 PyTorch 训练好的模型文件的时候,就需要转换为以 .pt 为代表的 TorchScript 文件,才能够读取。

Script mode 通过 torch.jit.trace 或者 torch.jit.script 来调用。这两个函数都是将 Python 代码转换为 TorchScript 的两种不同的方法。

  • torch.jit.trace:将一个特定的输入(通常是一个张量,需要我们提供一个input)传递给一个 PyTorch 模型,torch.jit.trace 会跟踪此 input 在 model 中的计算过程,然后将其转换为 Torch 脚本。这个方法适用于那些在静态图中可以完全定义的模型,例如具有固定输入大小的神经网络。通常用于转换预训练模型。

  • torch.jit.script 直接将 Python 函数(或者一个 Python 模块)通过 Python 语法规则和编译转换为 Torch 脚本。torch.jit.script 更适用于动态图模型,这些模型的结构和输入可以在运行时发生变化。例如,对于 RNN 或者一些具有可变序列长度的模型,使用 torch.jit.script 会更为方便。

在通常情况下,更应该倾向于使用 torch.jit.trace 而不是 torch.jit.script

在模型部署方面,ONNX 被大量使用。而导出 ONNX 的过程,也是 model 进行 torch.jit.trace 的过程,因此这里我们把 torch 的 trace 做稍微详细一点的介绍。

2.1 jit trace 注意事项

为了能够把模型编写的更能够被 jit trace,需要对代码做一些妥协,例如:

  1. 如果 model 中有 DataParallel 的子模块,或者 model 中有将 tensors 转换为 numpy arrays,或者调用了 OpenCV 的函数等,这种情况下,model 不是一个正确的在单个设备上、正确连接的 graph,这种情况下,不管是使用 torch.jit.script 还是 torch.jit.trace 都不能 trace 出正确的 TorchScript 来。

  2. model 的输入输出应该是 Union[Tensor, Tuple[Tensor], Dict[str, Tensor]] 的类型,而且在 dict 中的值,应该是同样的类型。但是对于 model 中间子模块的输入输出,可以是任意类型,例如 dicts of Any, classes, kwargs 以及 Python 支持的都可以。对于 model 输入输出类型的限制是比较容易满足的,在Detectron2中,有类似的例子:

    outputs = model(inputs)   # inputs和outputs是python的类型, 例如dictsor classes
    # torch.jit.trace(model, inputs)  # 失败!trace只支持Union[Tensor,Tuple[Tensor], Dict[str, Tensor]]类型
    adapter = TracingAdapter(model, inputs)  # 使用Adapter,将model inputs包装为trace支持的类型
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)  # 现在以trace成功
    
    # Traced model的输出只能是tuple tensors类型:
    flattened_outputs = traced(*adapter.flattened_inputs)
    # 再通过adapter转换为想要的输出类型
    new_outputs = adapter.outputs_schema(flattened_outputs)
    
  3. 一些数值类型的问题。比如下面的代码片段:

    import torch
    a=torch.tensor([1,2])
    print(type(a.size(0)))
    print(type(a.size()[0]))
    print(type(a.shape[0]))
    

    在eager mode下,这几个返回值的类型都是int型。上面代码的输出为:

    <class 'int'>
    <class 'int'>
    <class 'int'>
    

    但是在 trace mode 下,这几个表达式的返回值类型都是 Tensor 类型。因此,有些表达式使用不当,如果在 trace 过程中,一些 shape 表达式的返回值类型是 int 型,那么可能造成这块代码没有被 trace。在代码中,可以通过使用 torch.jit.is_tracing 来检查这块代码在 trace mode 下有没有被执行。

  4. 由于动态的 control flow,造成模型没有被完整的 trace。看下面的例子:

    import torch
    
    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    
    m = torch.jit.trace(f, torch.tensor(3))
    print(m.code)
    

    输出为:

    def f(x: Tensor) -> Tensor:
      return torch.sqrt(x)
    

    可以看到 trace 后的 model 只保留了一条分支。因此由于输入造成的 dynamic 的 control flow,trace 后容易出现错误。

    这种情况下,我们可以使用 torch.jit.script 来进行 TorchScript 的转换。

    import torch
    
    def f(x):
        return torch.sqrt(x) if x.sum() > 0 else torch.square(x)
    
    m = torch.jit.script(f)
    print(m.code)
    

    输出为:

    def f(x: Tensor) -> Tensor:
      if bool(torch.gt(torch.sum(x), 0)):
        _0 = torch.sqrt(x)
      else:
        _0 = torch.square(x)
      return _0
    

    在大多数情况下,我们应该使用 torch.jit.trace,但是像上面的这种 dynamic control flow 的情况,我们可以混合使用 torch.jit.tracetorch.jit.script,在后面会进行阐述
    另外在一些 Blog 中,对于 dynamic control flow 的定义是有错误的,例如 if x[0] == 4: x += 1 是 dynamic control flow,但是:

    model: nn.Sequential = ...
    for m in model:
      x = m(x)
    

    以及:

    class A(nn.Module):
      backbone: nn.Module
      head: Optiona[nn.Module]
      def forward(self, x):
        x = self.backbone(x)
        if self.head is not None:
            x = self.head(x)
        return x
    

    都不是 dynamic control flowdynamic control flow 是由于对输入条件的判断造成的不同分支的执行

  5. trace 过程中,将变量 trace 成了常量。看下面一个例子:

    import torch
    a, b = torch.rand(1), torch.rand(2)
    
    def f1(x): return torch.arange(x.shape[0])
    def f2(x): return torch.arange(len(x))
    
    print(torch.jit.trace(f1, a)(b))
    # 输出: tensor([0, 1])
    # 可以看到trace后的model是没问题的,这里使用变量a作为torch.jit.trace的example input,然后将转换后的TorchScript用变量b作为输入,正常情况下,b的shape是2维的,因此返回值是tensor([0,1])是正确的
    
    print(torch.jit.trace(f2, a)(b))
    # 输出:
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    # tensor([0])
    # 可以看到这个输出结果是错误的,b的维度是2维,输出应该是tensor([0,1]),这里torch.jit.trace也提示了,使用len可能会造成不正确的trace。
    
    # 我们打印一下两者的区别
    print(torch.jit.trace(f1, a).code, '\n',torch.jit.trace(f2, a).code)
    # 输出
    # def f1(x: Tensor) -> Tensor:
    #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
    #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _1
    
    #  def f2(x: Tensor) -> Tensor:
    #   _0 = torch.arange(1, dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
    #   return _0
    
    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    
    # 从trace的code中可以看出,使用x.shape这种方式,在trace后的code里面,是有shape的一个变量值存在的,但是直接使用len这种方式,trace后的code里面,就直接是1
    

    我们导出 ONNX 的过程,也是进行 torch.jit.trace 的过程,在导出 ONNX 的时候,有时候也会遇到

    TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

    这样的提示信息,这时候要检查一下代码中是不是有可能 trace 过程中,变量会被当做常量的情况,有可能会导致导出的 ONNX 精度异常。

    • 关于 ONNX
      ONNX 默认基于 trace 的方式,运行一次模型,记录下和 tensor 的相关操作。trace 将不会捕获根据输入数据而改变的行为。比如 if 语句,只会记录执行的那一条分支,同样的,for 循环的次数,导出与跟踪运行完全相同的静态图。如果要使用动态控制流导出模型,则需要使用 torch.jit.script
      torch.jit.script:真正的去编译,在 PYTHON 的 AST 语法树做语法分析句法分析。因此可以使用if等动态控制流。返回 ScriptModule。
      torch.onnx.export 在运行时,先判断是否是 SriptModule,如果不是,则进行 torch.jit.trace,因此 export 需要一个随机生成的输入参数。
      import torch.nn as nn
      import torch
      import torch.nn.functional as F
      import cv2
      import numpy as np
      import onnx
      import onnxruntime as ort
      
      #from torch.onnx import register_custom_op_symbolic # 私有层支持
      
      class test_net(nn.Module):
          def __init__(self,):
              super(test_net, self).__init__()
              #self.model = nn.MaxPool3d(kernel_size=(1,3,3), stride=(2,1,2))
              #self.model = nn.AvgPool3d(kernel_size=(1,3,3), stride=(2,1,2)) #-> AveragePool
              self.model = nn.Conv3d(3,64,kernel_size=(1,3,3), stride=(2,1,2))
              self.relu = nn.ReLU()
              self.relu6 = nn.ReLU6()
              self.relu66 = nn.ReLU6()
      
          def forward(self, x):
              out1 = self.model(x)
              f_mean = torch.mean(out1) # -> ReduceMean
              #f_mean = torch.mean(out1).item() # item()会将f_mean转换为常数 会丢失 mean操作
              # script模式转onnx会报错 torch._C._jit_pass_erase_number_types(graph) RuntimeError: Unknown number type: Scalar
              out2 = torch.div(out1, f_mean)
              #outlist = list()
              #for i in range(3):
              #    if i in [0]:
              #        #outlist.append(nn.ReLU()(out2))  # script模式下报错 类对象要提前构建
              #        outlist.append(self.relu(out2))   # scrip_to_onnx 报错 找不到25 BUG
              #    else:
              #        #outlist.append(nn.ReLU6()(out2))
              #        outlist.append(self.relu6(out2))
              #out = torch.cat(outlist)
              # 上述 for循环构图在tracing模式下会展开
              # script模式下难转换,报错
              # 手动平铺
              o1 = self.relu(out2)
              o2 = self.relu6(out2)
              #o3 = self.relu6(out2)   # script模式下被优化掉了 BUG
              o3 = self.relu66(out2)   # script模式下被优化掉了
              out = torch.cat([o1,o2,o3])
      
              return out
      
      # 模型构建和运行
      imgh, imgw = 24, 94
      net = test_net().eval() # 若存在batchnorm、dropout层则一定要eval() 使得BN层参数不更新
      dummy_input = torch.randn(1,3,3,imgh, imgw)# n c d h w
      torch_out = net.forward(dummy_input)# net(dummy_input)
      
      
      # export onnx
      dynamic_axes = {'input': {3: 'height', 4: 'width'}, 'output': {3: 'height', 4: 'width'}} # 配置动态分辨率
      onnx_pth = "test-conv-relu.onnx"
      
      # 传入原model,采用默认trace方式捕获模型,需要运行模型
      torch.onnx.export(net, dummy_input, onnx_pth, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes)
      # 也可传入 scriptModule
      #net_script= torch.jit.script(test_net())
      # 需要外加配置 example_outputs,用来获取输出的shape和dtype,无需运行模型
      #torch.onnx.export(net_script, dummy_input, onnx_pth, input_names=['input'], output_names=['output'], dynamic_axes=dynamic_axes, example_outputs=[torch_out])
      
      # ort run
      oxx_m = ort.InferenceSession(onnx_pth)
      onnx_blob = dummy_input.data.numpy()
      onnx_out = oxx_m.run(None, {'input':onnx_blob})[0]
      
      dummy_input2 = torch.randn(1,3,3,imgh*2, imgw*2)
      onnx_blob2 = dummy_input2.data.numpy()
      onnx_out2 = oxx_m.run(None, {'input':onnx_blob2})[0]
      
      # opencv run
      #cv_m = cv2.dnn.readNet(onnx_pth)
      
      print('mean diff = ', np.mean(onnx_out - torch_out.data.numpy()))
      

    除了 len 会导致 trace 错误,其他几个也会导致 trace 出现问题:

    • .item() 会在 trace 过程中将 tensors 转为 int/float

    • 任何将 torch 类型转为 numpy/python 类型的代码

    • 一些有问题的算子,例如 advanced indexing

    • torch.jit.trace 不会对传入的 device 生效

      import torch
      def f(x):
          return torch.arange(x.shape[0], device=x.device)
      m = torch.jit.trace(f, torch.tensor([3]))
      print(m.code)
      # 输出
      # def f(x: Tensor) -> Tensor:
      #   _0 = ops.prim.NumToTensor(torch.size(x, 0))
      #   _1 = torch.arange(annotate(number, _0), dtype=None, layout=None, device=torch.device("cpu"), pin_memory=False)
      #   return _1
      print(m(torch.tensor([3]).cuda()).device)
      # 输出:device(type='cpu')
      

      trace 不会对传入的 cuda device 生效。

2.2 jit trace 验证技巧

为了保证trace的正确,我们可以通过一下的一些方法来尽量保证 trace 后的模型不会出错:
1.注意 warnings 信息。类似这样的:

TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.

TraceWarnings信息,它会造成模型的结果有可能不正确,但是它只是个 warning 等级。
2. 做单元测试。需要验证一下 eager mode 的模型输出与 trace 后的模型输出是否一致。

assert allclose(torch.jit.trace(model, input1)(input2), model(input2))
  1. 避免一些特殊的情况。例如下面的代码:
if x.numel() > 0:
  output = self.layers(x)
else:
  output = torch.zeros((0, C, H, W))  # 会创建一个空的输出

避免一些特殊情况比如空的输入输出之类的。

  1. 注意shape的使用。前面提到,tensor.size()在trace过程中会返回Tensor类型的数据,Tensor类型会在计算过程中被添加到计算图中,应该避免将Tensor类型的shape转为了常量。主要注意以下两点:
  • 使用 torch.size(0) 来代替 len(tensor),因为 torch.size(0) 返回的是 Tensor,len(tensor) 返回的是 int。对于自定义类,实现一个 .size 方法或者使用 .__len__() 方法来代替 len() ,例如这个例子
  • 不要使用 int() 或者 torch.as_tensor 来转换 size 的类型,因为这些操作也会被视为常量。
  1. 混合 tracing 和 scripting 方法。可以使用 torch.jit.script 来转换一些 torch.jit.trace 不能搞定的小的代码片段,混合使用 tracing 和 scripting,基本可以解决所有的问题。

2.3 混合使用 trace 和 script

trace 和 script 都有他们的问题,混合使用可以解决大部分问题。但是为了尽可能减小对于代码质量的负面影响,大部分情况下,都应该使用 torch.jit.trace,必要时才使用 torch.jit.script

  1. 在使用 torch.jit.trace 时,使用 @script_if_tracing 装饰器可以让被装饰的函数使用 script 方式进行编译

    def forward(self, ...):
      # ... some forward logic
      @torch.jit.script_if_tracing
      def _inner_impl(x, y, z, flag: bool):
          # use control flow, etc.
          return ...
      output = _inner_impl(x, y, z, flag)
      # ... other forward logic
    

    但是使用 @script_if_tracing 时,需要保证函数中没有 PyTorch 的 modules,如果有的话,需要做一些修改,例如下面的:

    # 因为代码中有self.layers(),是一个pytorch的module,因此不能使用@script_if_tracing
    if x.numel() > 0:
      x = preprocess(x)
      output = self.layers(x)
    else:
      # Create empty outputs
      output = torch.zeros(...)
    

    这里需要做如下修改:

    # 需要将self.layers移出if判断,这时候可以用@script_if_tracing
    if x.numel() > 0:
      x = preprocess(x)
    else:
      # Create empty inputs
      x = torch.zeros(...)
    # 需要将self.layers()修改为支持empty的输入,或者将原先的条件判断加入到self.layers中
    output = self.layers(x)
    
  2. 合并多次 trace 的结果
    使用 torch.jit.script 生成的模型相比使用 torch.jit.trace 有两个好处:

    • 可以使用条件控制流,例如模型中使用一个 bool 值来控制 forward 的 flow,在 traced modules 里面是不支持的
    • 使用 traced module,只能有一个 forward() 函数,但是使用 scripted module,可以有多个前向计算的函数
    class Detector(nn.Module):
      do_keypoint: bool
    
      def forward(self, img):
          box = self.predict_boxes(img)
          if self.do_keypoint:
              kpts = self.predict_keypoint(img, box)
    
      @torch.jit.export
      def predict_boxes(self, img): pass
    
      @torch.jit.export
      def predict_keypoint(self, img, box): pass
    

    对于这种有 bool 值的控制流,除了使用 script,还可以多次进行 trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后将他们的 weight 复制一遍,并合并两次 trace 的结果:

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):
      def forward(self, img, do_keypoint: bool):
        if do_keypoint:
            return self[0](img)
        else:
            return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

    对于这种有 bool 值的控制流,除了使用 script,还可以多次进行 trace,然后将结果合并。

    det1 = torch.jit.trace(Detector(do_keypoint=True), inputs)
    det2 = torch.jit.trace(Detector(do_keypoint=False), inputs)
    

    然后将他们的 weight 复制一遍,并合并两次 trace 的结果:

    det2.submodule.weight = det1.submodule.weight
    class Wrapper(nn.ModuleList):
      def forward(self, img, do_keypoint: bool):
        if do_keypoint:
            return self[0](img)
        else:
            return self[1](img)
    exported = torch.jit.script(Wrapper([det1, det2]))
    

2.4 trace 和 script 的性能

trace 总是会比 script 生成一样或者更简单的计算图,因此性能会更好一些。因为 script 会完整的表达 Python 代码的逻辑,甚至一些不必要的代码也会如实表达。例如下面的例子:

class A(nn.Module):
  def forward(self, x1, x2, x3):
    z = [0, 1, 2]
    xs = [x1, x2, x3]
    for k in z: x1 += xs[k]
    return x1
model = A()
print(torch.jit.script(model).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   z = [0, 1, 2]
#   xs = [x1, x2, x3]
#   x10 = x1
#   for _0 in range(torch.len(z)):
#     k = z[_0]
#     x10 = torch.add_(x10, xs[k])
#   return x10
print(torch.jit.trace(model, [torch.tensor(1)] * 3).code)
# def forward(self, x1: Tensor, x2: Tensor, x3: Tensor) -> Tensor:
#   x10 = torch.add_(x1, x1)
#   x11 = torch.add_(x10, x2)
#   return torch.add_(x11, x3)

2.5 总结

trace 具有明显的局限性:这篇文章的大部分篇幅都在谈论 trace 的局限性以及如何解决这些问题。实际上,这正是 trace 的优势所在:它有明确的局限性(和解决方案),因此你可以推理它是否有效。

相反,script 更像是一个黑盒子:在尝试之前,没有人知道它是否有效。文章中没有提到如何修复 script 的任何诀窍:有很多诀窍,但不值得你花时间去探究和修复一个黑盒子。

trace 和 script 都会影响代码的编写方式,但 trace 因为我们明确它的要求,对我们原始的代码造成的一些修改也不会太严重:

  • 它限制了输入/输出格式,但仅限于最外层的模块。(如上所述,这个问题可以通过一个wrapper解决)。
  • 它需要修改一些代码才能通用(例如在 trace 时添加一些 script),但这些修改只涉及受影响模块的内部实现,而不是它们的接口。

3. LibTorch 的使用

在得到所需模型后,可以尝试在 C++ 环境下使用得到的模型,这里就用到了 LibTorch。

3.1 LibTorch 的链接

结合自己环境的 CUDA 版本,去官网下载对应版本的 libTorch。例如 CUDA 版本为 11.1,则需要在下载地址中找到 libtorch-cxx11-abi-shared-with-deps-1.9.1%2Bcu111.zip 进行下载。

链接进需要再 cmake 内加上这几行即可:

set(TORCH_PATH "/home/yj/libtorch/share/cmake/Torch")
message("TORCH_PATH set to: ${TORCH_PATH}")
set(Torch_DIR ${TORCH_PATH})

find_package(Torch REQUIRED)
message(STATUS "Torch version is: ${Torch_VERSION}")

# <target> is your target's name
target_link_libraries(<target> 
  ${TORCH_LIBRARIES}
)

3.2 接口和实现

  1. 头文件引入 :

    #include <torch/script.h>
    #include <torch/torch.h>
    
  2. 加载模型

    module = torch::jit::load(PATH);
    
  3. 函数实现

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1404682.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

品牌价值的累积与倍增:指数函数的含义及其在企业运营中的应用

品牌的价值日益凸显。品牌价值的累积与倍增不仅是企业追求的目标&#xff0c;也是市场竞争的重要标志。指数函数作为一种数学模型&#xff0c;对于描述品牌价值的增长具有重要意义。本文将深入探讨指数函数的含义及其在企业运营中的应用&#xff0c;并分析如何通过持续创新、品…

Unity 抽象工厂模式(实例详解)

文章目录 简介实例1实例2 简介 抽象工厂模式是一种创建型设计模式&#xff0c;它提供了一种方式来封装一组相关或相互依赖对象的创建过程&#xff0c;而无需指定具体类。这种模式常用于系统中有多组相关产品族&#xff0c;且客户端需要使用不同产品族中的对象时。 在Unity中&a…

canvas绘制旋转的椭圆花

查看专栏目录 canvas实例应用100专栏&#xff0c;提供canvas的基础知识&#xff0c;高级动画&#xff0c;相关应用扩展等信息。canvas作为html的一部分&#xff0c;是图像图标地图可视化的一个重要的基础&#xff0c;学好了canvas&#xff0c;在其他的一些应用上将会起到非常重…

Java 设计者模式以及与Spring关系(四) 代理模式

目录 简介: 23设计者模式以及重点模式 代理模式&#xff08;Proxy Pattern&#xff09; 静态代理示例 spring中应用 动态代理 1.基于JDK的动态代理 target.getClass().getInterfaces()作用 内名内部类写法(更简洁&#xff0c;但不推荐) 2.基于CGLIB实现 spring中应用 …

第137期 Oracle的数据生命周期管理(20240123)

数据库管理137期 2024-01-23 第137期 Oracle的数据生命周期管理&#xff08;20240123&#xff09;1 ILM2 Heat Map3 ADO4 优点5 对比总结 第137期 Oracle的数据生命周期管理&#xff08;20240123&#xff09; 作者&#xff1a;胖头鱼的鱼缸&#xff08;尹海文&#xff09; Orac…

图灵日记之java奇妙历险记--异常包装类泛型

目录 异常概念与体系结构异常的分类异常的处理防御式编程异常的抛出异常的捕获异常声明throwstry-catch捕获并处理 自定义异常类 包装类基本数据类型及其对应包装类装箱和拆箱 泛型泛型使用类型推导 裸类型说明 泛型的编译机制泛型的上界语法 异常概念与体系结构 在java中,将程…

《SPSS统计学基础与实证研究应用精解》视频讲解:SPSS数据排序

《SPSS统计学基础与实证研究应用精解》4.6 视频讲解 视频为《SPSS统计学基础与实证研究应用精解》张甜 杨维忠著 清华大学出版社 一书的随书赠送视频讲解4.6节内容。本书已正式出版上市&#xff0c;当当、京东、淘宝等平台热销中&#xff0c;搜索书名即可。本书旨在手把手教会使…

智能机器人与旋量代数(9)

Chapt 3. 螺旋运动与旋量代数 3.1 螺旋运动 螺旋运动是关于一条空间直线的一个旋转运动&#xff0c;并伴随沿此直线的一个平移。是一种刚体绕空间轴 s s s旋转 θ \theta θ角&#xff0c;再沿该轴平移距离 d d d的复合运动&#xff0c;类似螺母沿螺纹做进给运动的情形。 一…

x-cmd pkg | dasel - JSON、YAML、TOML、XML、CSV 数据的查询和修改工具

目录 简介首次用户快速实验指南基本功能性能特点竞品进一步探索 简介 dasel&#xff0c;是数据&#xff08;data&#xff09;和 选择器&#xff08;selector&#xff09;的简写&#xff0c;该工具使用选择器查询和修改数据结构。 支持 JSON&#xff0c;YAML&#xff0c;TOML&…

SQL提示与索引终章

✨博客主页&#xff1a;小小恶斯法克的博客 &#x1f388;该系列文章专栏&#xff1a;重拾MySQL-进阶篇 &#x1f4dc; 感谢大家的关注&#xff01; ❤️ 可以关注黑马IT&#xff0c;进行学习 目录 &#x1f680;SQL提示 &#x1f680;覆盖索引 &#x1f680;前缀索引 &…

科技、文化与旅游的融合创新:智慧文旅的未来之路

在当今社会&#xff0c;科技、文化与旅游的融合已经成为文旅产业转型升级的重要趋势。这种融合不仅有助于提升文旅产业的核心竞争力&#xff0c;更有助于推动产业的数字化转型和可持续发展。 本文将深入探讨科技、文化与旅游的融合创新&#xff0c;以及智慧文旅场景的解决方案…

Unity3d引擎中使用AIGC生成的360全景图(天空盒)

前言 在这里与Skybox AI一起&#xff0c;一键打造体验无限的360世界&#xff0c;这是这个AIGC一键生成全景图的网站欢迎语。 刚使用它是23年中旬&#xff0c;在没有空去给客户实地拍摄全景图时&#xff0c;可以快速用它生成一些相关的全景图&#xff0c;用作前期沟通的VR de…

【第十四课】并查集(acwing-837连通块中点的数量 / c++代码 / 思路详解)

目录 思路 代码如下 一些解释 思路 由于这道题是在并查集这个知识点下面&#xff0c;所以自然我们直接将无向图及之间连线的表示模型化为我们并查集的模板(或许其实也并不难想到?)&#xff0c;要解释一下的话就是&#xff1a;我们将无向图中的每个顶点当作一个集合&…

JSON简单了解

文章目录 1、JSON介绍2、ES6模版字符串3、JS对象转化为JSON字符串3.1、手动JS对象转化为JSON字符串3.2、自动JS对象转化为JSON字符串 4、JS对象和java互相转换 1、JSON介绍 JSON 概念&#xff1a;JavaScript Object Notation。JavaScript 对象表示法&#xff0c;简单理解JSON是…

C++参悟:数值运算相关

数值运算相关 一、概述二、常用数学函数1. 基础运算1. 浮点值的绝对值&#xff08; |x| &#xff09;2. 浮点除法运算的余数3. 除法运算的有符号余数4. 除法运算的有符号余数和最后三个二进制位5. 混合的乘加运算6. 两个浮点值的较大者7. 两个浮点值的较小者8. 两个浮点值的正数…

个人云服务器docker搭建部署前后端应用-myos

var code "87c5235c-b551-45bb-a5e4-9593cb104663" mysql、redis、nginx、java应用、前端应用部署 本文以单台云服务器为例&#xff1a; 1. 使用腾讯云服务器 阿里或其他云服务器皆可&#xff0c;类似 安装系统&#xff0c;现在服务器系统都集成安装了docker镜像&a…

仓库管理系统

仓库管理系统 项目环境要求 1.设备支持&#xff1a;Windows7、Windows8或Windows10&#xff1b; 2.数据库&#xff1a;Mysql 8.0&#xff1b; 3.软件支持&#xff1a;eclipse、navicat 需求分析 需求分析阶段的根本任务是要明确仓库管理系统功能需求&#xff0c;以便提出整个系…

Mapbox加载浙江省天地图服务和数据处理

1. 加载影像服务 通过浙江省天地图官网申请所需服务&#xff0c;使用token获取服务数据 由于浙江省天地图使用的坐标系是 cgcs2000&#xff0c;需要使用 的框架对应为 cgcs2000/mapbox-gl&#xff0c;通过cdn引入或npm下载 影像服务地址为&#xff1a; ‘https://ditu.zjzw…

2024年 复习 HTML5+CSS3+移动web 笔记 之CSS遍

第一天第二天第三天 1.1 引入方式 1.2 选择器 1.3 画盒子 1.4 文字控制 1.5 综合案例 一 新闻详情 2.1 复合选择器 2.2 伪类选择器 2.3 CSS 特性 2.4 Emmet 写法 2.5 背景属性 2.6 显示模式 2.6 综合案例 一 热词 &#xff08;设计稿&#xff1f;&#xff09; 2.7 综合案例 一…

优化用户体验测试应用领域:提升产品质量与用户满意度

在当今数字化时代&#xff0c;用户体验测试应用已经成为确保产品质量、提升用户满意度的关键工具。随着技术的不断发展&#xff0c;用户的期望也在不断演变&#xff0c;因此&#xff0c;为了保持竞争力&#xff0c;企业必须将用户体验置于产品开发的核心位置。本文将探讨用户体…