第九章 番外篇:TORCHSCRIPT

news2025/1/15 16:49:42

下文中的代码都使用参考教程中的例子。
会给出一点自己的解释。
参考教程:

文章目录

  • Introduction
    • 复习一下nn.Module()
    • Torchscript
      • torch.jit.ScriptModule()
      • torch.jit.script()
      • torch.jit.trace()
      • 一个小区别
  • 使用示例
    • tracing Modules
    • scripting Module
    • Mixing scripting and tracing
    • 保存和加载模型
  • 实践与优化

Introduction

我们训练好并保存的pytorch,支持在python语言下的使用,但是不支持在一些c++语言下使用。为了能让我们的模型在high-performance environment c++环境下使用,我们需要对模型进行格式转换。

好消息!torch本身是有模型格式转换的功能的,所以我们不需要下载额外的包,就可以把它转为能在c++使用的torchscript模型。

复习一下nn.Module()

之前的章节中有讲过,torch中所有模型都是基于nn.Module()这个类,模型的定义都继承了这个类的属性与方法。
一个完整的模型要包括以下三个基本的部分:

  1. 一个构造函数,用于调用模型模块
  2. parameters和sub-modules。它们在构造函数中被初始化,并能在调用中被使用。
  3. forward()函数,决定了模型调用的顺序。

教程中给出了下面一个简单的例子。
例子中定义了一个名为MyCell的类,它继承了torch.nn.Module()的功能。因为这个模型中没有需要训练的参数和网络层,所以先跳过parameters和sub-modules这一步。要注意这里使用了super,调用了父类的构造函数。
在forward()的部分,该方法的传入参数为x和h(忽略了self)。计算过程中只使用了torch.tanh(x+h),这一步没有参数需要更新。返回的结果为new_h。

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()

    def forward(self, x, h):
        new_h = torch.tanh(x + h)
        return new_h

my_cell = MyCell()
x = torch.rand(3, 4)
h = torch.rand(3, 4)
print(my_cell(x, h))

接下来对这个小模型进行一些改动,增加一些需要训练的参数。在教程例子中,它给这个模型增加了一个线性层。

class MyCell(torch.nn.Module):
    def __init__(self):
        super(MyCell, self).__init__()
        self.linear = torch.nn.Linear(4,4) # 在这部分增加了一个线性层
        
    def forward(self, x, h):
        new_h = torch.tanh(self.linear(x) + h) # 在调用的时候也使用了线性层,这里的参数需要在训练中更新
        return new_h, new_h
    
my_cell = MyCell()
x = torch.rand(3,4)
h = torch.rand(3,4)

print(my_cell(x,h))

可以看一下我们输出的结果中,多了一个grad_fn,之前我们曾经解释过,这个是反向传播中梯度计算的方法,因为现在有了要学习的参数,所以增加了这个方法。
在这里插入图片描述
pytorch具有很高的灵活性。在教程中提到了重要的一点是,很多框架都会在给出完整定义的情况下再进行求导的计算,而在pytorch中不是的,pytorch会在计算进行的时候记录这个操作,并在求导的过程中replay。所以pytorch时并没有很明确的对这些求导操作做出定义。

我自己也不是太理解这些话。我的个人理解是在backwards过程中tensor的grad_fn是随着当前步更新的,而不是预设好的。下面放出原文。

Many frameworks take the approach of computing symbolic derivatives given a full program representation. However, in PyTorch, we use a gradient tape. We record operations as they occur, and replay them backwards in computing derivatives. In this way, the framework does not have to explicitly define derivatives for all constructs in the language.

Torchscript

torchscript的作用就是根据pytorch code来创建一个模型,这个模型可以在非python环境下被使用。所以在pytorch中训练的模型,能够很容易地被应用到一个非python依赖的生产环境中去。

我们先来看一下代码,熟悉一下其中的方法的作用。

torch.jit.ScriptModule()

ScriptModule()也继承了nn.Module()类,所以它也有很多和nn.Module()一样的方法。比如children(),named_children()等。
它还包括一些神秘的方法。比如:
PROPERTY code 返回forward()函数中代码。这个功能是nn.Module()中没有的。
PROPERTY graph 返回forward()函数中的graph。

torch.jit.script()

torch.jit.script(obj, optimize=None, _frames_up=0, _rcb=None, example_inputs=None)

script() 的作用是检查一个function或者nn.Module()的源码,并把它编译成torchscript code并返回一个ScriptModule或者ScriptFunctions。
TorchScript本身是python language的一个子集,所以它并不能完全支持python中的所有功能,但是一些模型相关的计算它都是支持的。

更详细的介绍可以参考。

https://pytorch.org/docs/stable/jit_language_reference.html#language-reference

里面提到了一些对torchscript的限制,比如函数中的参数类型是不可以发生改变的,在python语言中你可以判断参数的种类并作出对应的操作,在torchscript中这是一个错误操作。torchscript中的参数为做特别说明的情况下,均默认为tensor。

这里的输入可以是一个function也可以是一个nn.Module(),要注意这里的example_inputs是有格式要求的:
(Union[List[Tuple], Dict[Callable, List[Tuple]], None])。

我们对我们定义的MyCell进行script,输入是一个nn.Module(),返回结果是一个ScriptModule()。
在这里插入图片描述

torch.jit.trace()

torch.jit.trace(func, example_inputs=None, optimize=None, check_trace=True, check_inputs=None, check_tolerance=1e-05, strict=True, _force_outplace=False, _module_class=None, _compilation_unit=<torch.jit.CompilationUnit object>, example_kwarg_inputs=None, _store_inputs=True)

torch.jit.trace()可以对一个function进行追踪,并返回一个可执行object或者一个ScriptFunction。你必须提供一个example_inputs。

  • The resulting recording of a standalone function produces ScriptFunction.
  • The resulting recording of nn.Module.forward or nn.Module produces ScriptModule.

当传入的是一个普调的function时,如下图,返回的结果是一个scriptfunction。
在这里插入图片描述

不管传入的是nn.Module还是它本身的forward函数,返回的结果都是一样的。
在这里插入图片描述

一个小区别

torch.jit.trace(func, input)只会记录这个input在function中走过的路径,比如下图的示例,虽然我们的在forward()中定义了一个a = torch.rand(3,4),但是这个值和我们的input没有什么关系,所以trace的时候没有记录。而script()则会对整个源码进行分析与记录。
在这里插入图片描述
此外trace()无法对if-else等分支进行记录,后面会详细介绍。

使用示例

tracing Modules

torchscript提供了一个方法,帮你获取你的模型的完整定义。首先来看一下tracing方法的作用。

使用上方定义的带线性层的小模型。
在这里插入图片描述
来看一下jit.trace做了什么操作,它首先传入了my_cell,然后传入了对应的输入。trace方法会调用这个Module,并且记录其中的每一步操作,并创造一个ScriptModule的实例。

我们可以看一下它的code。
在这里插入图片描述
使用trace方法会有一些天然的缺陷。它追踪了你的输入在function中经过的每一步操作,所以如果你的function中存在判断语句时,未被触发的操作就会被忽略掉。

使用教程中给出的例子。

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

my_cell = MyCell(MyDecisionGate())
traced_cell = torch.jit.trace(my_cell, (x, h))

print(traced_cell.dg.code)
print(traced_cell.code)

在这个例子中,MyDecisionGate函数进行了一个判断,假如传入的x的总和大于0,就返回x本身,假如x的总和小于0,就返回-x。

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

在这里插入图片描述
我们可以看到因为我们的输入并不能走过if-else的两条路径,所以我们trace的结果中也只有一条路。我们的if-else方法不见了。

scripting Module

在上面的trace方法中,它对你的输入走过的路径进行记录,所以它看不到输入没有经过的地方。而我们的第二个方法,script() 则是直接对你的源码进行分析,所以能够保留比较完整的结果。

在这里插入图片描述

Mixing scripting and tracing

假如你的代码中有些不希望被torch.jit.script记录的常量,你可以使用trace和script的组合,将这些常量隐藏。

对这部分的理解是,对于有多个分支并且又有你想要隐藏的参数的情况下,可以使用trace和script的组合。多分支的部分用script记录,隐藏参数的部分用trace记录。

class MyDecisionGate(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x
        else:
            return -x
        
scripted_gate = torch.jit.script(MyDecisionGate())

class MyCell(torch.nn.Module):
    def __init__(self, dg):
        super(MyCell, self).__init__()
        self.dg = dg
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x, h):
        new_h = torch.tanh(self.dg(self.linear(x)) + h)
        return new_h, new_h

第一个例子,torch.jit.script和traced module内联。

class MyRNNLoop(torch.nn.Module):
    def __init__(self):
        super(MyRNNLoop, self).__init__()
        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))

    def forward(self, xs):
        h, y = torch.zeros(3, 4), torch.zeros(3, 4)
        for i in range(xs.size(0)):
            y, h = self.cell(xs[i], h)
        return y, h

rnn_loop = torch.jit.script(MyRNNLoop())
print(rnn_loop.code)

第二个例子,torch.jit.trace()和scripted module内联。

class WrapRNN(torch.nn.Module):
    def __init__(self):
        super(WrapRNN, self).__init__()
        self.loop = torch.jit.script(MyRNNLoop())

    def forward(self, xs):
        y, h = self.loop(xs)
        return torch.relu(y)

traced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))
print(traced.code)

我们观察一下第二个例子,比较一下最后使用jit.trace和jit.script有什么区别。
大家可以看到使用trace时,loop的返回结果是_0, y;使用script时,lopp返回的结果是y, h。
在这里插入图片描述

保存和加载模型

torchscript可以将模型独立地保存下来,保存的信息包括模型的code,parameters, attribute和debug information。这些完整的信息让我们的模型可以独立地表达,并在一个完全不同的进程中被加载,下面给出了代码例子。

traced.save('wrapped_rnn.pt')

loaded = torch.jit.load('wrapped_rnn.pt')

print(loaded)
print(loaded.code)

实践与优化

放一下源码的链接OPTIMIZING VISION TRANSFORMER MODEL FOR DEPLOYMENT。链接里内容更详细,有条件的直接看源码。我只是crop出来了中间和torchscript相关的部分。

from PIL import Image
import torch
import timm
import requests
import torchvision.transforms as transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

print(torch.__version__)
# should be 1.8.0


model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()

transform = transforms.Compose([
    transforms.Resize(256, interpolation=3),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])

img = Image.open(requests.get("https://raw.githubusercontent.com/pytorch/ios-demo-app/master/HelloWorld/HelloWorld/HelloWorld/image.png", stream=True).raw)
ten = transform(img)[None,]
out = model(ten)
clsidx = torch.argmax(out)
print(clsidx.item())

在这里插入图片描述
将模型以script 的形式保存下来

model = torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)
model.eval()
scripted_model = torch.jit.script(model)
scripted_model.save("fbdeit_scripted.pt")

比较一下两者的时间,两者在时间上是没有什么明显差别的。在教程中使用了一些模型加速的方法,所以inference的时间会变快。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/667010.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

乐鑫线上研讨会|探索 LCD 屏在物联网中的发展趋势

LCD 屏通过显示实时信息并提供交互式体验&#xff0c;现已成为各类设备的重要组成部分。在当下的 AIoT 时代&#xff0c;随着物联网技术的快速发展和应用场景的不断拓展&#xff0c;LCD 作为人机交互的主要输入输出设备&#xff0c;在智能家居、智能安防、工业控制、智慧城市等…

C#开发的OpenRA游戏之建造物品的窗口5

C#开发的OpenRA游戏之建造物品的窗口5 前面分析了TAB窗口的建立和运行,现在关注它的子窗口,也就是ProductionPaletteWidget类实现的窗口,这个窗口主要用来显示所有可以创建物品的ICON图标。用户可以通过这个窗口实现物品创建,如下图所示: 比如要创建电厂,就是点击上面…

【好书精读】网络是怎样连接的 之 创建套接字

&#xff08;该图由AI制作 学习AI绘图 联系我&#xff09; 目录 协议栈的内部结构 套接字的实体就是通信控制信息 真正的套接字 调用 socket 时的操作 从应用程序收到委托后 &#xff0c; 协议栈通过 TCP 协议收发数据的操作可以分为 4 个阶段 。 首先是创 建套接字 &…

SothisAI创建容器和conda环境

1.创建容器&#xff08;设置torch版本&#xff0c;cuda&#xff0c;python版本等等&#xff09;后进入web shell 2.shell里输入ssh username&#xff08;你自己的用户名&#xff09; IP&#xff08;你创建的实例的ip地址&#xff09; 3.在web平台创建你自己的文件夹 4.shel…

小程序请求封装、使用

小程序请求封装 1、要了解方法 1.1、wx.request() wx.request 发起 HTTPS 网络请求。&#xff08;详情点击wx.request查看官方文档&#xff09; 1.2、wx.showModal() wx.showModal 显示模态对话框。&#xff08;详情点击wx.showModal查看官方文档&#xff09; 1.3、wx.sho…

Swift 周报 第三十一期

文章目录 前言新闻和社区注册 WWDC23 实验室和活动Apple Vision Pro 和 visionOS 撼世登场App Store 中新增的隐私功能 提案正在审查的提案 Swift论坛推荐博文话题讨论关于我们 前言 本期是 Swift 编辑组自主整理周报的第二十二期&#xff0c;每个模块已初步成型。各位读者如果…

奇数分频器电路设计

目录 奇数分频器电路设计 1、奇数分频器电路简介 2、实验任务 3、程序设计 3.1、7分频电路代码 3.2、仿真验证 3.2.1、编写 TB 文件 3.2.2、仿真验证 4、用状态机实现7分频电路设计 4.1、代码如下&#xff1a; 4.2、使用状态机的好处 奇数分频器电路设计 前面一节我…

前端JS限制绕过测试-业务安全测试实操(17)

前端JS限制绕过测试,请求重放测试 前端JS限制绕过测试 测试原理和方法 很多商品在限制用户购买数量时,服务器仅在页面通过JS脚本限制,未在服务器端校验用户提交的数量,通过抓取客户端发送的请求包修改JS端生成处理的交易数据,如将请求中的商品数量改为大于最大数限制的值…

Vue中使用分布式事务管理解决方案

文章目录 分布式事务管理是什么优点&#xff1a;缺点&#xff1a;弥补缺点的方法有&#xff1a; 解决方案 分布式事务管理是什么 分布式事务管理是指在分布式系统中对跨多个数据库或服务的操作进行协调和保证一致性的机制。在分布式环境下&#xff0c;由于涉及到多个独立的资源…

半年面试12家大厂,我总结出了这份2023版互联网大厂(Java岗)面试真题汇总

Java面试 现在互联网大环境不好&#xff0c;互联网公司纷纷裁员并缩减HC&#xff0c;更多程序员去竞争更少的就业岗位&#xff0c;整的IT行业越来越卷。身为Java程序员的我们就更不用说了&#xff0c;上班8小时需要做好本职工作&#xff0c;下班后还要不断提升技能、技术栈&am…

docker 命令解释 - nginx镜像制作

目录 Dockerfile 部分命令解释 1、ENTRYPOINT 而ENTRYPOINT 语言 CMD的区别 1、docker run 启动容器的时候&#xff0c;可以传递参数进入给ENTRYPOINT里面的命令&#xff08;-e&#xff09; 2、当2者都存在的时候&#xff0c;CMD里的内容会成为 ENTRYPOINT 里的参数&#x…

Pytest中断言的重要性

目录 前言 pytest断言 增加断言详细信息 异常断言 .type .value .traceback pytest常用断言 前言 在pytest中&#xff0c;断言是非常重要的一部分。断言可以帮助我们验证代码的正确性&#xff0c;检查函数返回的值是否符合要求&#xff0c;以及判断程序中预期行为是否发生。如…

MySQL数据库(二)

前言 本文是关于MySQL数据库的第二弹。 临时表不受原表数据类型的约束&#xff01;&#xff01; SQL语法不区分大小写。 一、列的使用 &#xff08;一&#xff09;列的增加 1、全列插入 insert into 表名 values (数据,数据); 也可以同时插入多条数据: insert into 表名 va…

小程序跳转多次返回首页

小程序跳转多次返回首页 小程序跳转多个页面后直接返回首页 问题 例&#xff1a; 跳转&#xff1a;A(首页) - > B -> C -> D 返回&#xff1a;D -> A(首页) 1、页面中按钮跳转 <!--D页面 WXML--> <view class"-btn"><button bindtap&q…

6月第3周榜单丨飞瓜数据B站UP主排行榜(哔哩哔哩)发布!

飞瓜轻数发布2023年6月12日-6月18日飞瓜数据UP主排行榜&#xff08;B站平台&#xff09;&#xff0c;通过充电数、涨粉数、成长指数三个维度来体现UP主账号成长的情况&#xff0c;为用户提供B站号综合价值的数据参考&#xff0c;根据UP主成长情况用户能够快速找到运营能力强的B…

浅谈智能配电房的系统设计和技术方案

张心志acrelzxz 安科瑞电气股份有限公司 上海嘉定 201801 摘 要&#xff1a;为了进一步提升配网运维工作质量和效率&#xff0c;支撑配网技术发展向数字化、精益化、智能化转型。在大量的配电房现状问题分析以及新监测技术调研的基础上&#xff0c;文章提出了智能配电房…

WebGL/Threejs瀑布水流粒子效果

webgl瀑布效果 初始化场景 function init () {scene new THREE.Scene();camera new THREE.PerspectiveCamera (45, scr.w / scr.h, 0.1, 10000);renderer new THREE.WebGLRenderer ({ antialias: true });renderer.gammaInput true;renderer.gammaOutput true;renderer.…

盘点中国顶级黑客Top10,最后一位你猜是谁

第一名&#xff1a;袁仁广 别名&#xff1a;大兔子(datuzi)&#xff0c;人称袁哥。提起袁任广&#xff0c;知道的人或许并不多。但如果提起袁哥或者大兔子&#xff0c;在国内安全业界称得上尽人皆知。在国内&#xff0c;他的windows系统方面的造诣可谓首屈一指&#xff0c;早在…

centos系统socket5安装与使用

一、socket5安装 1、安装依赖 yum -y install gcc openldap-devel pam-devel openssl-devel 2、安装socket5 wget http://nchc.dl.sourceforge.net/project/ss5/ss5/3.8.9-8/ss5-3.8.9-8.tar.gztar -xzvf ss5-3.8.9-8.tar.gzcd ss5-3.8.9./configuremakemake install 二、…

Android Jetpack Compose — Slider滑动条

在Android Jetpack Compose中&#xff0c;Slider(滑动条&#xff09;是一个常用的用户界面控件&#xff0c;它允许通过滑动条来选择一个范围或数值。Slider控件非常适用于调整音量、亮度、进度等需要连续调整的场景。 一、Slider的属性 Slider是Android Jetpack Compose中的一个…