【Pytorch】(十三)模型部署: TorchScript

news2025/1/13 13:22:01

文章目录

  • (十三)模型部署: TorchScript
    • Pytorch动态图的优缺点
    • TorchScript
    • Pytorch模型转换为TorchScript
      • torch.jit.trace
      • torch.jit.script
      • trace和script的区别总结
      • trace 和script 混合使用
      • 保存和加载模型

(十三)模型部署: TorchScript

Pytorch动态图的优缺点

与Tensorflow使用静态计算图不同,PyTorch 使用的是动态计算图:

动态图允许在运行时渐进地构建计算图,使得模型设计更加灵活。开发者可以使用 Python 的控制流结构(如循环、条件语句等)来动态地定义模型的结构,从而更容易实现复杂的模型逻辑。

这种计算方式更直观,更pythonic。开发者可以更容易地理解和调试模型各个模块,快速地修改、迭代模型。

然而,与静态图相比,动态图的执行效率可能会较低。因为动态图难以进行一些计算图的优化,如运算符融合、图优化等。而且,动态图依赖于Python 环境。这些因素使得动态图不适合在低延迟要求较高的生产环境下部署。

因此,在部署Pytorch训练后的模型时,需要将动态图转换为静态图,这就要用到TorchScript。

TorchScript

TorchScript是PyTorch模型的一种静态图表示形式,支持模型的部署优化、跨平台部署以及与其他深度学习框架的集成:

  • 模型的部署优化:TorchScript 可以帮助优化 PyTorch 模型以提高性能和效率。通过将模型转换为静态图形式,TorchScript 可以应用各种优化技术,如运算符融合、图优化等,从而加速模型执行并降低内存消耗。
  • 跨平台部署:将模型转换为 TorchScript 格式可以实现跨平台部署,模型可以在没有 Python 环境的情况下运行。这对于在生产环境中部署模型到服务器、移动设备或边缘设备上非常有用。
  • 与其他框架集成:通过将 PyTorch 模型转换为 TorchScript 格式,可以更方便地与其他深度学习框架进行交互。例如,可以将TorchScript 进一步转换为 ONNX 格式,从而与 TensorFlow 等其他框架进行集成和交互操作。

Pytorch模型转换为TorchScript

torch.jit.tracetorch.jit.script 是 PyTorch 中用于模型转换为 TorchScript 格式的工具,但它们有不同的作用和使用场景。

torch.jit.trace

通过torch.jit.trace 将 没有控制流的MyCell 模块转化为TorchScript:


import torch  # This is all you need to use both PyTorch and TorchScript!

torch.manual_seed(191009)  # set the seed for reproducibility


class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h)
        return new_h, new_h

my_cell = MyCell()
x, h = torch.rand(3, 4), torch.rand(3, 4)
traced_cell = torch.jit.trace(my_cell, (x, h))
print(traced_cell)
MyCell(
  original_name=MyCell
  (linear): Linear(original_name=Linear)
)

torch.jit.trace调用了my_cell,记录了模块计算时发生的操作,并创建了一个torch.jit.ScriptModule的实例(TracedModule是其实例)traced_celltraced_cell 记录了my_cell的计算图。我们可以使用.graph属性来查看:

print(traced_cell.graph)
graph(%self.1 : __torch__.MyCell,
      %x : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu),
      %h : Float(3, 4, strides=[4, 1], requires_grad=0, device=cpu)):
  %linear : __torch__.torch.nn.modules.linear.Linear = prim::GetAttr[name="linear"](%self.1)
  %20 : Tensor = prim::CallMethod[name="forward"](%linear, %x)
  %11 : int = prim::Constant[value=1]() # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %12 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::add(%20, %h, %11) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %13 : Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu) = aten::tanh(%12) # /var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:189:0
  %14 : (Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu), Float(3, 4, strides=[4, 1], requires_grad=1, device=cpu)) = prim::TupleConstruct(%13, %13)
  return (%14)

然而,图中包含的大多数信息对我们没有用处。我们可以使用.code属性对其进行Python语法解释:

print(traced_cell.code)
def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  linear = self.linear
  _0 = torch.tanh(torch.add((linear).forward(x, ), h))
  return (_0, _0)

调用traced_cell会产生与Python模块实例my_cell() 相同的结果:

print(my_cell(x, h))
print(traced_cell(x, h))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))
(tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>), tensor([[-0.2541,  0.2460,  0.2297,  0.1014],
        [-0.2329, -0.2911,  0.5641,  0.5015],
        [ 0.1688,  0.2252,  0.7251,  0.2530]], grad_fn=<TanhBackward0>))

torch.jit.script

我们先尝试通过torch.jit.trace 将 带有控制流的MyCell 模块转化为TorchScript:

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)
/var/lib/workspace/beginner_source/Intro_to_TorchScript_tutorial.py:261: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

def forward(self,
    argument_1: Tensor) -> NoneType:
  return None

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = (linear).forward(x, )
  _1 = (dg).forward(_0, )
  _2 = torch.tanh(torch.add(_0, h))
  return (_2, _2)

可以看到,if-else分支并没有被表示出来。为什么?
trace记录代码运行发生的操作,并构造一个ScriptModule。控制流中只有一种情况被记录了下来,其他情况都被忽略了。

这就需要用到torch.jit.script了:

scripted_gate = torch.jit.script(MyDecisionGate())

my_cell = MyCell(scripted_gate)
scripted_cell = torch.jit.script(my_cell)

print(scripted_gate.code)
print(scripted_cell.code)
def forward(self,
    x: Tensor) -> Tensor:
  if bool(torch.gt(torch.sum(x), 0)):
    _0 = x
  else:
    _0 = torch.neg(x)
  return _0

def forward(self,
    x: Tensor,
    h: Tensor) -> Tuple[Tensor, Tensor]:
  dg = self.dg
  linear = self.linear
  _0 = torch.add((dg).forward((linear).forward(x, ), ), h)
  new_h = torch.tanh(_0)
  return (new_h, new_h)

可以考到,控制流也被记录了下来。
现在让我们尝试运行该程序:

# New inputs
x, h = torch.rand(3, 4), torch.rand(3, 4)
print(scripted_cell(x, h))
(tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
        [ 0.5228,  0.7122,  0.6985, -0.0656],
        [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>), tensor([[ 0.5679,  0.5762,  0.2506, -0.0734],
        [ 0.5228,  0.7122,  0.6985, -0.0656],
        [ 0.6187,  0.4487,  0.7456, -0.0238]], grad_fn=<TanhBackward0>))

trace和script的区别总结

  • torch.jit.tracetorch.jit.trace 用于将一个具体的输入示例追踪(trace)模型的一次计算过程,从而生成一个 TorchScript 模型。对于动态控制流(如条件语句),它只会记录每个分支中的一种情况。因此,它不适用于无固定形状输入、具有动态控制流的模型。

  • torch.jit.scripttorch.jit.script 用于将整个 PyTorch 模型转换为 TorchScript 模型,包括模型的所有逻辑和控制流。script适用于无固定形状输入、具有动态控制流的模型 。但是,它可能会把保存一些多余的代码, 产生额外的性能开销。

因此,可以将两者混合使用,扬长避短。

trace 和script 混合使用

torch.jit.tracetorch.jit.script 可以混合使用: 复杂模型中静态部分用torch.jit.trace进行转换, 动态部分用torch.jit.script 进行转换,以发挥各自的优势。以下是两个可能的情况:

  • torch.jit.script内联traced模块的代码,
class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)
def forward(self,
    xs: Tensor) -> Tuple[Tensor, Tensor]:
  h = torch.zeros([3, 4])
  y = torch.zeros([3, 4])
  y0 = y
  h0 = h
  for i in range(torch.size(xs, 0)):
    cell = self.cell
    _0 = (cell).forward(torch.select(xs, 0, i), h0, )
    y1, h1, = _0
    y0, h0 = y1, h1
  return (y0, h0)
  • torch.jit.trace内联scripted模块的代码,
class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

保存和加载模型

  • traced.save : 保存TorchScript

  • torch.jit.load : 加载TorchScript

traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)
RecursiveScriptModule(
  original_name=WrapRNN
  (loop): RecursiveScriptModule(
    original_name=MyRNNLoop
    (cell): RecursiveScriptModule(
      original_name=MyCell
      (dg): RecursiveScriptModule(original_name=MyDecisionGate)
      (linear): RecursiveScriptModule(original_name=Linear)
    )
  )
)
def forward(self,
    xs: Tensor) -> Tensor:
  loop = self.loop
  _0, y, = (loop).forward(xs, )
  return torch.relu(y)

参考:
https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1628943.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ubuntu解密:Root账户登录问题一网打尽

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 Ubuntu解密&#xff1a;Root账户登录问题一网打尽 前言Root用户简介Root账户无法登录的原因重设Root账户密码解决ssh不能root登录问题安全性考虑 前言 Ubuntu作为广受欢迎的Linux发行版&#xff0c;对…

[Android]引导页

使用Kotlin Jetpack Compose创建一个左右滑动的引导页, 效果如图. 1.添加依赖项 androidx.compose.ui最新版本查询:https://maven.google.com/web/index.html com.google.accompanist:accompanist-pager最新版本查询:https://central.sonatype.com/ 确保在 build.gradle (M…

C++|继承(菱形+虚拟)

目录 一、继承的概念及定义 1.1概念 1.2定义 1.2.1定义格式 1.2.2继承关系和访问限定符 1.2.3继承基类成员访问方式的变化 二、基类和派生类对象赋值转换 三、继承中的作用域 四、派生类的默认成员函数 五、继承与友元、静态成员 六、菱形继承与虚拟继承 6.1单继承与…

Redis - Set 集合

目录 前言 命令 SADD 将一个或者多个元素添加到 set 中 语法 SMEMBERS 获取一个 set 中的所有元素 语法 SISMEMBER 判断⼀个元素在不在 set 中 语法 SCARD 获取 set 中的元素个数 语法 SPOP 从 set 中随机删除并返回⼀个或者多个元素 语法 SMOME 将⼀个元素从源 se…

电脑教程1

一、介绍几个桌面上面的软件 1、火绒&#xff1a;主要用于电脑的安全防护和广告拦截 1.1 广告拦截 1.打开火绒软件点击安全工具 点击弹窗拦截 点击截图拦截 拦截具体的小广告 2、向日葵远程控制&#xff1a;可以通过这个软件进行远程协助 可以自己去了解下 这个软件不要…

每日算法之两两交换链表中的节点

题目描述 给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点交换&#xff09;。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4] 输出&…

Apollo 7周年大会自动驾驶生态利剑出鞘

前言 4月22日&#xff0c;百度Apollo在北京车展前夕举办了以“破晓•拥抱智变时刻”为主题的智能汽车产品发布会&#xff0c;围绕汽车智能化&#xff0c;发布了智驾、智舱、智图等全新升级的“驾舱图”系列产品。 1、7周年大会 自2013年百度开始布局自动驾驶&#xff0c;201…

Kotlin基础​​

数据类型 定义变量 var表示定义变量&#xff0c;可以自动推导变量类型&#xff0c;所以Int可以不用写。 定义常量 条件语句 if表达式可以返回值&#xff0c;该值一般写在if里的最后一行 类似switch的用法 区间 循环 a是标签&#xff0c;可以直接break到标签的位置&#xf…

Docker 的数据管理 端口映射 容器互联 镜像创建

一 Docker 的数据管理 1 管理 Docker 容器中数据主要有两种方式&#xff1a; 数据卷&#xff08;Data Volumes&#xff09; 数据卷容器&#xff08;DataVolumes Containers&#xff09;。 1.1 数据卷 数据卷是一个供容器使用的特殊目录&#xff0c;位于容器中。可将宿主机…

【PLC学习十四】TIA.V18无法启动仿真的问题

【PLC学习十四】TIA.V18无法启动仿真的问题 文章目录 【PLC学习十四】TIA.V18无法启动仿真的问题前言一、程序仿真出现的问题二、解决方法1.无法仿真的问题2.因安全问题&#xff0c;无法编译的问题3、在TIA V18内部设置完成PG/PC接口后&#xff0c;下一次打开仍然不能仿真&…

【Vue3+Tres 三维开发】01-HelloWord

预览 什么是TRESJS 简单的说,就是基于THREEJS封装的能在vue3中使用的一个组件,可以像使用组件的方式去创建场景和模型。优势就是可以快速创建场景和要素的添加,并且能很明确知道创景中的要素构成和结构。 项目创建 npx create-vite@latest # 选择 vue typescript安装依赖…

【Linux 进程间通信】管道

文章目录 1.为什么操作系统需要向用户提供进程间通信方式&#xff1f;2.进程间通信的种类3.管道3.1管道的作用3.2管道的本质3.3管道的通信原理3.4管道的分类 1.为什么操作系统需要向用户提供进程间通信方式&#xff1f; ①&#x1f34e;资源共享&#xff1a;有的时候两个进程需…

QT——简易计算器(从0开始)

目录 一、题目描述&#xff1a; 二、创建工程&#xff1a; 1. ​编辑 2. 3. 4. 默认 5. 6. 7. 8. 默认 9. 创建完成 三、UI界面设计&#xff1a; 1. 添加按钮 1. 2. 按钮界面 3. 按钮绑定快捷键 2. 文本框添加 1. 文本框字体 2. 默认文本 3. 文本对齐方式…

英智数字孪生机器人解决方案,赋能仓库物流模式全面升级

工业机械臂、仓储机器人、物流机器人等模式的机器人系统在现代产业中扮演着愈发重要的角色&#xff0c;他们的发展推动了自动化和智能化水平的提高&#xff0c;有助于为制造业、物流业、医疗保健业和服务业等行业创造新效率并提升人们的生活质量。 行业面临的挑战 机器人开发、…

为何要与云产商进行云端防护合作,上云企业如何保障云端安全

随着大数据、云计算等信息技术的迅猛发展&#xff0c;企业为了降低成本、提高效率&#xff0c;纷纷将业务迁移至云端。 随着大数据、云计算等信息技术的迅猛发展&#xff0c;企业为了降低成本、提高效率&#xff0c;纷纷将业务迁移至云端。这一全面的上云浪潮对传统的安全企业格…

YOLOv8+PyQt5野外火焰检测系统(可以从图像、视频和摄像头三种路径检测)

1.效果视频&#xff1a;https://www.bilibili.com/video/BV1Tm421s7te/?spm_id_from333.999.0.0 2.资源包含可视化的野外火焰检测系统&#xff0c;可用于火灾预警或火灾救援&#xff0c;该系统可自动检测和识别图片或视频当中出现的火焰&#xff0c;以及自动开启摄像头&#…

使用Windows GDI进行绘图

使用Windows GDI绘图&#xff0c;可以使用MFC&#xff0c;也可以直接使用Windows API绘图&#xff0c;两者其实都一样。MFC也是封装了Windows API。 下面以MFC为例&#xff0c;进行说明。因为MFC帮我们做好了一些底层&#xff0c;可以直接使用Windows GDI的函数。 在MFC中使用…

如此建立网络根文件系统 Mount NFS RootFS

安静NFS系统服务 sudo apt-get install nfs-kernel-server 创建目录 sudo mkdir /rootfsLee 将buildroot编译的根文件系统解压缩到 sudo tar xvf rootfs.tar -C /rootfsLee/ 添加文件NFS访问路径 sudo vi /etc/exports sudo /etc/exports文件&#xff0c;添加如下一行 …

SecureCRT中添加命令显示为空如何处理?(原因添加了空行)

相关背景信息 配置相关路径:~/Library/Application\ Support/VanDyke/SecureCRT/Config包括的配置信息 按钮、命令、全局配置、色彩、以及license都在$ ls ButtonBarV4.ini Commands Global.ini SSH2.ini Button…

STM32单片机通过ST-Link 烧录和调试

系列文章目录 STM32单片机系列专栏 C语言术语和结构总结专栏 文章目录 1. ST-LINK V2 2. 操作步骤 2.1 连接方式 2.2 驱动安装常规步骤 2.3 Keil中的设置 3. 调式仿真 4. 常见问题排查 1. ST-LINK V2 ST LINK v2下载器用于STM32单片机&#xff0c;可以下载程序、调试…