PyTorch2.0向后兼容性和加速效果浅探

news2025/1/24 1:37:14

前言

在PyTorch2022开发者大会上,PyTorch团队发布了一个新特性——torch.compile,将PyTorch的性能推向了新的高度,称这个新版本为PyTorch2.0。torch.compile的引入不影响之前的功能,其是一个完全附加和可选的功能,因此PyTorch2.0完全向后兼容,基于之前1.x版本开发的项目可以直接迁移到PyTorch2.0使用。

环境升级

比较简单,按照官方说明安装即可。
在这里插入图片描述
先建一个新环境torch2.0.1,python版本使用3.8+,在新环境中安装PyTorch2.0:

conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

测试向后兼容性

待所有依赖包安装好之后,切换到新环境。

conda activate torch2.0.1

运行之前torch1.x下能正常运行的网络训练代码,可以看到能够正常运行。此时速度没什么明显差别。

需要注意的是,如果是使用DDP模式训练的话,可能会报“local_rank”相关的错。将代码中的相关配置参数修改一下:

__author__ = 'TracelessLe'

import argparse
import torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    if (torch.__version__).startswith('2.0'):
        parser.add_argument("--local-rank", type=int, required=True)
    else:
        parser.add_argument("--local_rank", type=int, required=True)
    
    main()

测试加速效果

根据PyTorch官方博客中的内容,使用torch.compile后模型训练和推理的加速效果很明显。
在这里插入图片描述
这里快速上手,直接根据新手教程中的操作来修改相应代码:

import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = torch.compile(model, backend="inductor")
model(torch.randn(1,3,64,64))

在实验中发现自己的某个简单的网络训练速度由~0.8s/step加速到~0.6s/step,加速比达到25%。实践说明该新功能确实能够加速训练速度。

本次不深入测试更多的功能,包括不同的backend,以及纯推理过程的加速比。
在这里插入图片描述

其他说明

使用torch.compile功能时如果同时需要加载预训练模型,根据预训练模型保存的版本和正在使用的PyTorch版本的区别分情况进行处理:

1、预训练好的模型由PyTorch1.x保存,需要使用PyTorch2.0的torch.compile加速功能。则需要网络先加载模型参数,再使用torch.compile进行编译。

__author__ = 'TracelessLe'

import torch


device = 'cuda:0'
model_pth = 'pretrained_model.pth'
model = TrainNet()
model_state_dict = torch.load(model_pth, map_location=device)
model.load_state_dict(model_state_dict, strict=False)

if (torch.__version__).startswith('2.0'):
    model = torch.compile(model, backend="inductor")

2、预训练好的模型由PyTorch2.x保存,需要使用PyTorch2.0的torch.compile加速功能。则需要网络先编译再加载模型参数。

__author__ = 'TracelessLe'

import torch


device = 'cuda:0'
model_pth = 'pretrained_model.pth'
model = TrainNet()

if (torch.__version__).startswith('2.0'):
    model = torch.compile(model, backend="inductor")

model_state_dict = torch.load(model_pth, map_location=device)
model.load_state_dict(model_state_dict, strict=False)

当然,PyTorch2.0保存的模型PyTorch1.x也是可以正常加载的,只是需要注意的是模型中存的key有一定差异需要特殊处理一下。

其中,PyTorch2.0模型的key前缀是“_orig_mod.module.”,而PyTorch1.x模型的key前缀是“module.”。根据这个差异对模型加载过程特殊处理即可。

__author__ = 'TracelessLe'

import torch
import collections


def load_model_compile(model, model_pth, device, strict=False, backend="inductor"):
    # 兼容torch1/2大版本之间的模型加载
    origin_dict = torch.load(model_pth, map_location=device)
    state_dict = collections.OrderedDict()
    # torch1_model_prefix = 'module.'
    # offset1 = len(torch1_model_prefix)
    torch2_model_prefix = '_orig_mod.'
    offset2 = len(torch2_model_prefix)
    for key, value in origin_dict.items():
        if key.startswith(torch2_model_prefix):
            if (torch.__version__).startswith('2.0'):
                model = torch.compile(model, backend=backend)
                model.load_state_dict(origin_dict, strict=strict)
            else:
                for key, value in origin_dict.items():
                    state_dict[key[offset2: len(key)]] = value
                model.load_state_dict(state_dict, strict=strict)
        else:
            if (torch.__version__).startswith('2.0'):
                model.load_state_dict(origin_dict, strict=strict)
                model = torch.compile(model, backend=backend)
            else:
                model.load_state_dict(origin_dict, strict=strict) 
        break
    return model

当然,也可以直接改模型中参数的key以适配不同版本,此处不再展开。

针对PyTorch2.0的变化在官方博客中讲的很详细,需要深入应用的同学可以进一步查阅相关资料。

版权说明

本文为原创文章,独家发布在blog.csdn.net/TracelessLe。未经个人允许不得转载。如需帮助请email至tracelessle@163.com或扫描个人介绍栏二维码咨询。
在这里插入图片描述

参考资料

[1] PyTorch 2.0 重磅发布:一行代码提速 30% - 知乎
[2] Getting Started — PyTorch 2.0 documentation
[3] torch.compile — PyTorch 2.0 documentation
[4] 解决报错:train.py: error: unrecognized arguments: --local-rank=1 ERROR:torch.distributed.elastic.multipr_WTIAW.TIAW的博客-CSDN博客
[5] torch.compile — PyTorch 2.0 documentation
[6] Accelerated Image Segmentation using PyTorch | PyTorch
[7] Accelerated Generative Diffusion Models with PyTorch 2 | PyTorch
[8] PyTorch 2.0 | PyTorch
[9] torch.compile Tutorial — PyTorch Tutorials 2.0.1+cu117 documentation
[10] Training Compiled PyTorch 2.0 with PyTorch Lightning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/521739.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux系统调用函数(300多个)

前言:这里只是给出中文描述,方便浏览熟悉,具体情况建议去具体环境(Linux系统)下执行 1)man 2 systemcalls (查看所有系统调用函数);2)man 2 open &#xff08…

Codeforces Round 872 (Div. 2)

Problem - D2 - Codeforces 思路: 我们设good点到所有k点的距离和为dis。 假设good点不止一个,那么我们good点的dis应该都是相等的(废话)。设当前点u是good点,如果他往儿子v移动,儿子有w个点属于k&#…

Maven 项目模板学习

目录 Maven 项目模板 什么是 archetype? 使用项目模板 Maven 将询问原型的版本 创建的项目 创建 pom.xml Maven 项目文档 Maven 快照(SNAPSHOT) 什么是快照? 项目快照 vs 版本 app-ui 项目的 pom.xml 文件 Maven 快照(SNAPSHOT)的出现是因为为了如果pom有…

OpenPCDet系列 | 4.4 DataProcessor点云数据处理模块解析

文章目录 DataProcessor模块解析1. mask_points_and_boxes_outside_range2. shuffle_points3. transform_points_to_voxels DataProcessor模块解析 在对batch_data的处理中,经过了point_feature_encoder模块处理后,就轮到了进行data_processor处理。在d…

django路由(多应用配置)

一、配置全局路由 在应用下,定义视图函数views.py from django.http import HttpResponse from django.shortcuts import render# Create your views here.def get_order(request):return HttpResponse("orders应用下的路由") 在项目的urls路由配置中&…

Qt事件传递及相关的性能问题

在使用Qt时,我们都知道能通过mousePressEvent,eventFilter等虚函数的重写来处理事件,那么当我们向一个界面发送事件,控件和它的父控件之间的事件传递过程是什么样的呢? 本文将以下图所示界面为例,结合源码介…

【sentinel】热点规则详解及源码分析

何为热点?热点即经常访问的数据。很多时候我们希望统计某些热点数据中访问频次最高的Top K数据,并对其访问进行限制。 比如: 商品ID为参数,统计一段时间内最常购买的商品ID并进行限制用户ID为参数,针对一段时间内频繁…

【linux】init进程的详解

文章目录 概述init进程完成从内核态向用户态的转变(1)一个进程先后两种状态(2)init进程在内核态下的工作内容(3)init进程在用户态下的工作内容(4)init进程如何从内核态跳跃到用户态 …

springboot+vue高校社团管理系统(源码+文档)

风定落花生,歌声逐流水,大家好我是风歌,混迹在java圈的辛苦码农。今天要和大家聊的是一款基于springboot的高校社团管理系统。项目源码以及部署相关请联系风歌,文末附上联系信息 。 💕💕作者:风…

Linux快速安装Erlang和RabbitMQ单机版

环境 CentOS7Xshell6XFtp6Erlang 21.3RabbitMQ 3.8.4 安装方式 同一个软件有很多种安装方式,在Linux系统有几种常见的软件安装方式: 源码编译安装:一般需要解压,然后使用make、make install等命令RPM(RedHat Packa…

从物业管理到IT互联网精英,月薪11k的她几经辗转,终得偿所愿!

所谓“男怕入错行”,其实对女生来说也是一样,不同行业对人生的改变太大,想要找到满意的工作,就要不断去尝试。 西安的学员小文,大学毕业后从事的本专业(物业管理)工作,但不是很喜欢…

条款1:理解模板类型推导

现代C中被广泛应用的auto是建立在模板类型推导的基础上的。而当模板类型推导规则应用于auto环境时,有时不如应用于模板中那么直观。由于这个原因,真正理解auto基于的模板类型推导的方方面面非常重要。 在c中声明一个模板函数的伪代码基本如下&#xff1…

JVM 直接内存(Direct Memory)

直接内存概述 不是虚拟机运行时数据区的一部分&#xff0c;也不是<<Java 虚拟机规范>> 中定义的内存区域直接内存是Java 堆外的、直接向系统申请的内存区间来源于 NIO&#xff0c;通过存在堆中的 DirectByteBuffer 操作 Native 内存访问直接内存的速度会优于 Java…

智慧停车APP系统开发 停车取车缴费智能搞定

生活水平的提高让车辆成为很多人出行主要的代步工具&#xff0c;很多家庭现在已经不止拥有一辆汽车了&#xff0c;所以城市建设中关于停车场的规划管理也是很重要的部分。不过现在出门很多时候还是会碰到找不到停车场&#xff0c;没有车位、收费不合理、乱收费等现象。智慧停车…

调试和优化遗留代码

1. 认识调试器 1.1 含义 一个能让程序运行、暂停、然后对进程的状态进行观测甚至修改的工具。 在日常的开发当中使用非常广泛。(PHP开发者以及前端开发者除外) 1.2 常见的调试器 Go语言的自带的 delve 简写为 “dlv”GNU组织提供的 gdbPHP Xdebug前端浏览器debug 调试 1.3…

DNS投毒

定义 DNS缓存投毒又称DNS欺骗,是一种通过查找并利用DNS系统中存在的漏洞,将流量从合法服务器引导至虚假服务器上的攻击方式。与一般的钓鱼攻击采用非法URL不同的是,这种攻击使用的是合法URL地址。 DNS缓存中毒如何工作 在实际的DNS解析过程中,用户请求某个网站,浏览器首…

English Learning - L3 作业打卡 Lesson1 Day6 2023.5.10 周三

English Learning - L3 作业打卡 Lesson1 Day6 2023.5.10 周三 引言&#x1f349;句1: The expression was first used in America at the beginning of the twentieth century .成分划分弱读连读爆破语调 &#x1f349;句2: It probably comes from the fact that many babies…

分享一组有意思的按钮设计

先上效果图&#xff1a; 一共16个&#xff0c;每个都有自己不同的样式和效果&#xff0c;可以用在自己的项目中&#xff0c;提升客户体验~ 再上代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8">&l…

非Autosar软件手动集成XCP协议栈

文章目录 前言XCP发送XCP接收Xcp初始化Xcp主函数Xcp Event总结前言 最近项目由于各种原因没有直接采用基于Autosar工具生成的代码。只使用了NXP的MCAL。Demo需求实现XCP功能。本文记录手动集成XCP协议的过程,基于CAN总线。集成的前提过程是已有了XCP的静态代码和配置代码。可…

数据结构pta第一天: 堆中的路径 【用数组模拟堆的操作】

这道题其实就涉及两个堆操作&#xff0c; 一个是插入&#xff0c;一个是通过从底到根的遍历 堆的插入&#xff1a;其实就是从下面往上&#xff0c;一个一个比较&#xff0c;&#xff08;因为上面的节点里的值越来越小&#xff0c;如果插入的值比上面的节点小那么就要向上推&am…