PyTorch 中通道在最后的内存格式(beta)

news2024/9/20 19:55:03

PyTorch 中通道在最后的内存格式(beta)

什么是通道在最后

通道在最后的内存格式是在保留内存尺寸的顺序中对 NCHW 张量进行排序的另一种方法。 通道最后一个张量的排序方式使通道成为最密集的维度(又称为每像素存储图像)。

例如,NCHW 张量的经典(连续)存储(在我们的示例中是具有 3 个颜色通道的两个2x2图像)如下所示:

在这里插入图片描述

通道最后的存储格式对数据的排序方式不同:

在这里插入图片描述

Pytorch 通过使用现有的跨步结构支持内存格式(并提供与现有模型(包括 eager,JIT 和 TorchScript)的向后兼容性)。 例如,通道在最后的格式中的10x3x16x16批量的步幅等于(768, 1, 48, 3)

通道最后一个存储格式仅适用于 4D NCWH 张量。

import torch
N, C, H, W = 10, 3, 32, 32

内存格式 API

这是在连续和通道最后存储格式之间转换张量的方法。

经典 PyTorch 连续张量

x = torch.empty(N, C, H, W)
print(x.stride()) # Ouputs: (3072, 1024, 32, 1)

出:

(3072, 1024, 32, 1)

转换运算符

x = x.contiguous(memory_format=torch.channels_last)
print(x.shape) # Outputs: (10, 3, 32, 32) as dimensions order preserved
print(x.stride()) # Outputs: (3072, 1, 96, 3)

出:

torch.Size([10, 3, 32, 32])
(3072, 1, 96, 3)

返回连续

x = x.contiguous(memory_format=torch.contiguous_format)
print(x.stride()) # Outputs: (3072, 1024, 32, 1)

出:

(3072, 1024, 32, 1)

替代选择

x = x.to(memory_format=torch.channels_last)
print(x.stride()) # Ouputs: (3072, 1, 96, 3)

出:

(3072, 1, 96, 3)

格式检查

print(x.is_contiguous(memory_format=torch.channels_last)) # Ouputs: True

出:

True

最后创建为渠道

x = torch.empty(N, C, H, W, memory_format=torch.channels_last)
print(x.stride()) # Ouputs: (3072, 1, 96, 3)

出:

(3072, 1, 96, 3)

clone保留内存格式

y = x.clone()
print(y.stride()) # Ouputs: (3072, 1, 96, 3)

出:

(3072, 1, 96, 3)

tocudafloat…保留内存格式

if torch.cuda.is_available():
    y = x.cuda()
    print(y.stride()) # Ouputs: (3072, 1, 96, 3)

出:

(3072, 1, 96, 3)

empty_like*_like运算符保留内存格式

y = torch.empty_like(x)
print(y.stride()) # Ouputs: (3072, 1, 96, 3)

出:

(3072, 1, 96, 3)

点向运算符保留内存格式

z = x + y
print(z.stride()) # Ouputs: (3072, 1, 96, 3)

出:

(3072, 1, 96, 3)

转换,Batchnorm模块支持通道在最后(仅适用于CudNN >= 7.6

if torch.backends.cudnn.version() >= 7603:
    input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, device="cuda", requires_grad=True)
    model = torch.nn.Conv2d(8, 4, 3).cuda().float()

    input = input.contiguous(memory_format=torch.channels_last)
    model = model.to(memory_format=torch.channels_last) # Module parameters need to be Channels Last

    out = model(input)
    print(out.is_contiguous(memory_format=torch.channels_last)) # Ouputs: True

出:

True

性能提升

在具有张量核心支持的 Nvidia 硬件上观察到了最大的性能提升。 在运行 Nvidia 提供的 AMP(自动混合精度)训练脚本时,我们可以将性能提高 22% 以上。

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 ./data

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
# CUDNN VERSION: 7603
# => creating model 'resnet50'
# Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
# Defaults for this optimization level are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Epoch: [0][10/125] Time 0.866 (0.866) Speed 230.949 (230.949) Loss 0.6735125184 (0.6735) Prec@1 61.000 (61.000) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.259 (0.562) Speed 773.481 (355.693) Loss 0.6968704462 (0.6852) Prec@1 55.000 (58.000) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.258 (0.461) Speed 775.089 (433.965) Loss 0.7877287269 (0.7194) Prec@1 51.500 (55.833) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.259 (0.410) Speed 771.710 (487.281) Loss 0.8285319805 (0.7467) Prec@1 48.500 (54.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.260 (0.380) Speed 770.090 (525.908) Loss 0.7370464802 (0.7447) Prec@1 56.500 (54.500) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.258 (0.360) Speed 775.623 (555.728) Loss 0.7592862844 (0.7472) Prec@1 51.000 (53.917) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.258 (0.345) Speed 774.746 (579.115) Loss 1.9698858261 (0.9218) Prec@1 49.500 (53.286) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.260 (0.335) Speed 770.324 (597.659) Loss 2.2505953312 (1.0879) Prec@1 50.500 (52.938) Prec@5 100.000 (100.000)

传递--channels-last true允许以通道在最后的格式运行模型,观察到 22% 的表现增益。

python main_amp.py -a resnet50 --b 200 --workers 16 --opt-level O2 --channels-last true ./data

# opt_level = O2
# keep_batchnorm_fp32 = None <class 'NoneType'>
# loss_scale = None <class 'NoneType'>
#
# CUDNN VERSION: 7603
#
# => creating model 'resnet50'
# Selected optimization level O2:  FP16 training with FP32 batchnorm and FP32 master weights.
#
# Defaults for this optimization level are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
# Processing user overrides (additional kwargs that are not None)...
# After processing overrides, optimization options are:
# enabled                : True
# opt_level              : O2
# cast_model_type        : torch.float16
# patch_torch_functions  : False
# keep_batchnorm_fp32    : True
# master_weights         : True
# loss_scale             : dynamic
#
# Epoch: [0][10/125] Time 0.767 (0.767) Speed 260.785 (260.785) Loss 0.7579724789 (0.7580) Prec@1 53.500 (53.500) Prec@5 100.000 (100.000)
# Epoch: [0][20/125] Time 0.198 (0.482) Speed 1012.135 (414.716) Loss 0.7007197738 (0.7293) Prec@1 49.000 (51.250) Prec@5 100.000 (100.000)
# Epoch: [0][30/125] Time 0.198 (0.387) Speed 1010.977 (516.198) Loss 0.7113101482 (0.7233) Prec@1 55.500 (52.667) Prec@5 100.000 (100.000)
# Epoch: [0][40/125] Time 0.197 (0.340) Speed 1013.023 (588.333) Loss 0.8943189979 (0.7661) Prec@1 54.000 (53.000) Prec@5 100.000 (100.000)
# Epoch: [0][50/125] Time 0.198 (0.312) Speed 1010.541 (641.977) Loss 1.7113249302 (0.9551) Prec@1 51.000 (52.600) Prec@5 100.000 (100.000)
# Epoch: [0][60/125] Time 0.198 (0.293) Speed 1011.163 (683.574) Loss 5.8537774086 (1.7716) Prec@1 50.500 (52.250) Prec@5 100.000 (100.000)
# Epoch: [0][70/125] Time 0.198 (0.279) Speed 1011.453 (716.767) Loss 5.7595844269 (2.3413) Prec@1 46.500 (51.429) Prec@5 100.000 (100.000)
# Epoch: [0][80/125] Time 0.198 (0.269) Speed 1011.827 (743.883) Loss 2.8196096420 (2.4011) Prec@1 47.500 (50.938) Prec@5 100.000 (100.000)

以下模型列表完全支持通道在最后,并在 Volta 设备上显示了 8%-35% 的表现增益:alexnetmnasnet0_5mnasnet0_75mnasnet1_0mnasnet1_3mobilenet_v2resnet101resnet152resnet18resnet34resnet50resnext50_32x4dshufflenet_v2_x0_5shufflenet_v2_x1_0shufflenet_v2_x1_5shufflenet_v2_x2_0squeezenet1_0squeezenet1_1vgg11vgg11_bnvgg13vgg13_bnvgg16vgg16_bnvgg19vgg19_bnwide_resnet101_2wide_resnet50_2

转换现有模型

通道在最后支持不受现有模型的限制,因为只要输入格式正确,任何模型都可以转换为通道在最后,并通过图传播格式。

# Need to be done once, after model initialization (or load)
model = model.to(memory_format=torch.channels_last) # Replace with your model

# Need to be done for every input
input = input.to(memory_format=torch.channels_last) # Replace with your input
output = model(input)

但是,并非所有运算符都完全转换为支持通道在最后(通常返回连续输出)。 这意味着您需要根据支持的运算符列表来验证已使用运算符的列表,或将内存格式检查引入急切的执行模式并运行模型。

运行以下代码后,如果运算符的输出与输入的存储格式不匹配,运算符将引发异常。

def contains_cl(args):
    for t in args:
        if isinstance(t, torch.Tensor):
            if t.is_contiguous(memory_format=torch.channels_last) and not t.is_contiguous():
                return True
        elif isinstance(t, list) or isinstance(t, tuple):
            if contains_cl(list(t)):
                return True
    return False

def print_inputs(args, indent=''):
    for t in args:
        if isinstance(t, torch.Tensor):
            print(indent, t.stride(), t.shape, t.device, t.dtype)
        elif isinstance(t, list) or isinstance(t, tuple):
            print(indent, type(t))
            print_inputs(list(t), indent=indent + '    ')
        else:
            print(indent, t)

def check_wrapper(fn):
    name = fn.__name__

    def check_cl(*args, **kwargs):
        was_cl = contains_cl(args)
        try:
            result = fn(*args, **kwargs)
        except Exception as e:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            print('-------------------')
            raise e
        failed = False
        if was_cl:
            if isinstance(result, torch.Tensor):
                if result.dim() == 4 and not result.is_contiguous(memory_format=torch.channels_last):
                    print("`{}` got channels_last input, but output is not channels_last:".format(name),
                          result.shape, result.stride(), result.device, result.dtype)
                    failed = True
        if failed and True:
            print("`{}` inputs are:".format(name))
            print_inputs(args)
            raise Exception(
                'Operator `{}` lost channels_last property'.format(name))
        return result
    return check_cl

old_attrs = dict()

def attribute(m):
    old_attrs[m] = dict()
    for i in dir(m):
        e = getattr(m, i)
        exclude_functions = ['is_cuda', 'has_names', 'numel',
                             'stride', 'Tensor', 'is_contiguous', '__class__']
        if i not in exclude_functions and not i.startswith('_') and '__call__' in dir(e):
            try:
                old_attrs[m][i] = e
                setattr(m, i, check_wrapper(e))
            except Exception as e:
                print(i)
                print(e)

attribute(torch.Tensor)
attribute(torch.nn.functional)
attribute(torch)

出:

Optional
'_Optional' object has no attribute '__name__'

如果您发现不支持通道在最后的张量的运算符并且想要贡献力量,请随时使用以下开发人员指南。

下面的代码是恢复火炬的属性。

for (m, attrs) in old_attrs.items():
  for (k,v) in attrs.items():
    setattr(m, k, v)

要做的工作

仍有许多事情要做,例如:

  • 解决 N1HW 和 NC11 张量的歧义;
  • 测试分布式训练支持;
  • 提高运算符覆盖率。

如果您有反馈和/或改进建议,请通过创建 ISSUE 来通知我们。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/675909.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java——《面试题——SpringCloud》

前文 java——《面试题——基础篇》 Java——《面试题——JVM篇》 Java——《面试题——多线程&并发篇》 Java——《面试题——Spring篇》 Java——《面试题——SpringBoot篇》 Java——《面试题——MySQL篇》​​​​​​ 目录 前文 1、什么是SpringCloud 2、什…

实战:NPMYARN构建工具实践-2023.6.22(测试成功)

实战&#xff1a;NPM&YARN构建工具实践-2023.6.22(测试成功) 目录 推荐文章 https://www.yuque.com/xyy-onlyone/aevhhf?# 《玩转Typora》 实验环境 gitlab/gitlab-ce:15.0.3-ce.0 jenkins/jenkins:2.346.3-2-lts-jdk11 openjdk 11.0.18 [rootDevops6 ~]#npm -v 6.14.12…

智能汽车 | 整车控制器(VCU)系统框图,功能拆解及供应商排名

摘要&#xff1a; 随着新能源EE架构的迭代及控制单元集成度越来越高&#xff0c;VCU的功能可能会被拆解到中央控制器域控制器&#xff0c;或者拆解到多合一的控制单元&#xff1b; VCU&#xff08;Vehicle Control Unit&#xff09;即整车控制器&#xff0c;是新能源汽车控制系…

JavaSE基础语法--static成员

假设我们现在有一个场景&#xff0c;定义一个学生类。 class Student{private String name;private int age;private int classroom_num;public Student(String name, int age, int classroom_num) {this.name name;this.age age;this.classroom_num classroom_num;} } pu…

翻筋斗觅食策略改进灰狼算法(IGWO)

目录 一、动态扰动因子策略 二、翻筋斗觅食策略 三、改进灰狼算法收敛曲线图 灰狼优化算法存在收敛的不合理性等缺陷&#xff0c;目前对GWO算法的收敛性改进方式较少&#xff0c;除此之外&#xff0c;当GWO迭代至后期&#xff0c;所有灰狼个体都逼近狼、狼、狼&#xff0c;…

HarmonyOS学习路之开发篇—多媒体开发(媒体会话管理开发)

一、媒体会话管理开发 AVSession是一套媒体播放控制框架&#xff0c;对媒体服务和界面进行解耦&#xff0c;并提供规范的通信接口&#xff0c;使应用可以自由、高效地在不同的媒体之间完成切换。 约束与限制 在使用完AVSession类后&#xff0c;需要及时进行资源释放。播放器类…

Linux常用命令——ftpshut命令

在线Linux命令查询工具 ftpshut 在指定的时间关闭FTP服务器 补充说明 功能说明&#xff1a;在指定的时间关闭ftp服务器。本指令提供系统管理者在设置的时间关闭FTP服务器&#xff0c;且能在关闭之前发出警告信息通知用户。关闭时间若设置后为"none"&#xff0c;则…

【实战项目开发技术分享】如何解决机器人运动不平稳的问题

文章目录 前言一、机器人设计的考虑因素二、控制算法的优化三、传感器改进四、实时监测与调试五、总结前言 机器人的运动平稳性对于其在各种应用中的成功执行任务至关重要。当机器人在执行任务过程中出现不稳定的运动,可能导致任务失败、损坏周围环境或甚至危及人员安全。因此…

ChatGPT在能源行业的预测场景:智能能源管理和异常检测的未来趋势

第一章&#xff1a;引言 能源是现代社会发展的关键驱动力之一&#xff0c;然而&#xff0c;传统的能源管理方法存在许多挑战&#xff0c;如能源浪费、供需不平衡以及能源异常等。为了应对这些挑战&#xff0c;智能能源管理系统逐渐崭露头角。在本文中&#xff0c;我们将探讨Ch…

基于Java+Swing实现仿QQ屏幕截图工具

基于JavaSwing实现仿QQ屏幕截图工具 一、系统介绍二、功能展示三、其它1.其他系统实现四.获取源码 一、系统介绍 实现能够实现对屏幕的随机截取&#xff0c;复制&#xff0c;保存以及添加文字等操作&#xff0c;便于用户对数据的处理。 该软件的功能&#xff1a; &#xff08…

I/O设备与主机信息传送的方式(程序查询方式,程序中断方式,DMA方式)

一.程序查询方式 CPU和I/O设备串行工作&#xff0c;CPU连接I/O设备和内存&#xff0c;CPU需要等待&#xff0c;效率很低 &#xff08;由CPU通过程序不断查询IO设备是否已经做好准备&#xff0c;从而控制IO设备与主机交换信息&#xff09; 二.程序中断方式&#xff1a; 中断&…

前端Vue自定义数字输入框 带加减按钮的数字输入框组件

前端Vue自定义数字输入框 带加减按钮的数字输入框组件&#xff0c; 下载完整代码请访问uni-app插件市场地址&#xff1a;https://ext.dcloud.net.cn/plugin?id13163 效果图如下&#xff1a; # cc-numbox #### 使用方法 使用方法 <!-- title: 标题 isSetMax: 是否设置最…

手把手叫你学会搭建FreeRTOS工程文件

手把手教你学会搭建FreeRTOS工程文件 一.序言二.提取文件2.1 Source文件夹2.2 portble文件夹2.3 Demo 文件夹 三.建立FreeRTOS工程3.1 新建FreeRTOS目录3.2 移植src文件夹3.3 移植port文件夹3.4 添加include文件夹3.5 提取FreeRTOSConfig.h文件3.5.1 拷贝FreeRTOSConfig.h文件 …

前端Vue自定义简单实用轮播图封装组件 快速实现轮播图

前端Vue自定义简单实用轮播图封装组件 快速实现轮播图&#xff0c; 下载完整代码请访问uni-app插件市场地址&#xff1a;https://ext.dcloud.net.cn/plugin?id13153 效果图如下&#xff1a; # cc-mySwiper #### 使用方法 使用方法 <!-- 自定义轮播图 swiperArr: 轮播数…

Day5——数据库基础2-SQL查询语句

网络安全学习笔记Day5 SQL查询语句&#xff08;重在实操&#xff01;&#xff01;&#xff01;&#xff09; 回顾1.中英文符号混淆2.数据库的操作流程&#xff08;回顾mysql相关语句&#xff09;3.“everything”工具 学习目标1.查询数据基本语法形式基本查询语句表单查询where…

ubuntu20下yolov4训练多目标实战

1、安装nvidia驱动和cudnn,不熟悉的小伙伴请移步&#xff1a;Ubuntu20.04安装NVIDIA显卡驱动、CUDA、CUDNN及突破NVENC并发限制_ubuntu20.04安装显卡驱动_BetterJason的博客-CSDN博客 2、编译opencv&#xff0c;不熟悉的小伙伴请移步:ubuntu20.04 和centos8平台opencv4.5.3&am…

【八大排序(九)】计数排序-非比较排序法

&#x1f493;博主CSDN主页:杭电码农-NEO&#x1f493;   ⏩专栏分类:八大排序专栏⏪   &#x1f69a;代码仓库:NEO的学习日记&#x1f69a;   &#x1f339;关注我&#x1faf5;带你学习排序知识   &#x1f51d;&#x1f51d; 计数排序 1. 前言2. 计数排序基本思路3. …

6.19 Nginx网站服务——服务基础

文章目录 一.Nginx服务基础1.关于Nginx的特点2.简述Nginx和Apache的差异3.Nginx 相对于 Apache 的优点4.Apache 相对于 Nginx 的优点5.阻塞与非阻塞6.同步与异步7.nginx的应用场景 二.编译安装nginx服务1.在线安装nginx1.1 yum部署Nginx1.2 扩展源安装完后直接安装Nginx 2.ngin…

【Red Hat 7.9---详细安装Oracle 11g】---静默方式安装

【Red Hat 7.9---详细安装Oracle 11g】---静默方式安装 &#x1f53b; 一、安装前规划&#x1f53b; 二、安装前准备一&#xff08;系统参数修改&#xff09;⛳ 2.1 内核版本、系统版本查看⛳ 2.2 修改主机名-重启生效⛳ 2.3 关闭selinux⛳ 2.4 防火墙设置1521端口开放⛳ 2.5 系…

哈希密码的加盐强化

&#x1f389;&#x1f389;&#x1f389;点进来你就是我的人了博主主页&#xff1a;&#x1f648;&#x1f648;&#x1f648;戳一戳,欢迎大佬指点! 欢迎志同道合的朋友一起加油喔&#x1f93a;&#x1f93a;&#x1f93a; 目录 一、什么是哈希加密&#xff1f; 二、哈希加密…