昇思MindSpore进阶教程--自动向量化Vmap(下)

news2024/10/5 13:33:23

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

文章上半部分请查看
自动向量化Vmap(上)

自动向量化

Vmap可以帮助我们隐藏批处理维度,您只需要调用一个接口便可以将函数转换为向量化形式。

from mindspore import vmap

auto_vectorization_conv = vmap(convolve)
auto_vectorization_conv(x_batch, w_batch)

Vmap除了为您提供了简易的编程体验外,将循环逻辑下沉至函数的各个基元操作中,结合分布式并行优化以获得更高的执行性能。 默认情况下,vmap的输入输出沿第一个轴进行批处理,如果您的输入和输出并不总是期望沿0轴批处理,可以通过in_axes和out_axes参数进行指定。您可以为所有输入或输出位置分别指定批处理轴索引,也可以为所有输入或输出指定相同的批处理轴索引。

w_batch_t = ops.transpose(w_batch, (1, 0))

auto_vectorization_conv = vmap(convolve, in_axes=(0, 1), out_axes=1)
output = auto_vectorization_conv(x_batch, w_batch_t)

ops.transpose(output, (1, 0))

对于多个输入的场景,您还可以指定只对其中的某些入参进行批处理,如上述场景变为求一组一维向量与某一权重的卷积,可在in_axes参数中的输入对应位置配置None即可,None表示不沿任何轴进行批处理。

auto_vectorization_conv = vmap(convolve, in_axes=(0, None), out_axes=0)
auto_vectorization_conv(x_batch, w)

高阶函数的嵌套

Vmap本质上是一种高阶函数,它将函数作为输入,并返回可应用于批处理数据的向量化函数。用法上它允许和其他框架提供的高阶函数进行嵌套组合使用。

  • vmap与vmap嵌套使用,应用于两层以上的批处理逻辑。
hyper_x = Tensor([[1., 2., 3., 4., 5.], [2., 3., 4., 5., 6.], [3., 4., 5., 6., 7.]], mindspore.float32)
hyper_w = Tensor([[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]], mindspore.float32)

hyper_vmap_ger = vmap(vmap(convolve, in_axes=[None, 0]), in_axes=[0, None])
hyper_vmap_ger(hyper_x, hyper_w)

  • grad内部嵌套vmap使用,应用于计算向量化函数的梯度等场景。
from mindspore import grad

def forward_fn(x, y):
    out = x + 2 * y
    out = ops.sin(out)
    reduce_sum = ops.ReduceSum()
    return reduce_sum(out)

x_hat = Tensor([[1., 2., 3.], [2., 3., 4.]], mindspore.float32)
y_hat = Tensor([[2., 3., 4.], [3., 4., 5.]], mindspore.float32)

grad_vmap_ger = grad(vmap(forward_fn), grad_position=(0, 1))
grad_vmap_ger(x_hat, y_hat)

  • vmap内部嵌套grad使用,应用于计算批量梯度、高阶梯度计算等场景,如计算Jacobian矩阵。
vmap_grad_ger = vmap(grad(forward_fn, grad_position=(0, 1)))
vmap_grad_ger(x_hat, y_hat)

本教程中只简单介绍两层高阶函数组合嵌套的用法,您可以根据场景需求实现更多层次的嵌套。

Cell的自动向量化

之前的用例我们都是以函数对象作为输入,下面将介绍Cell对象结合vmap的用法。这是一个简单定义的全连接层的例子。

import mindspore.nn as nn
from mindspore import Parameter
from mindspore.common.initializer import initializer

class Dense(nn.Cell):
    def __init__(self, in_channels, out_channels, weight_init='normal', bias_init='zeros'):
        super(Dense, self).__init__()
        self.scalar = 1
        self.weight = Parameter(initializer(weight_init, [out_channels, in_channels]), name="weight")
        self.bias = Parameter(initializer(bias_init, [out_channels]), name="bias")
        self.matmul = ops.MatMul(transpose_b=True)

    def construct(self, x):
        x = self.matmul(x, self.weight)
        output = ops.bias_add(x, self.bias)
        return output

input_a = Tensor([[1, 2, 3], [4, 5, 6]], mindspore.float32)
input_b = Tensor([[2, 3, 4], [5, 6, 7]], mindspore.float32)
input_c = Tensor([[3, 4, 5], [6, 7, 8]], mindspore.float32)

dense_net = Dense(3, 4)
print(dense_net(input_a))
print(dense_net(input_b))
print(dense_net(input_c))

inputs = mnp.stack([input_a, input_b, input_c])

vmap_dense_net = vmap(dense_net)
print(vmap_dense_net(inputs))

Cell和函数式的自动向量化用法基本一致,只需要将vmap的第一个入参替换为Cell实例即可,Vmap将construct转换为作用于批处理数据的向量化construct。另外,该用例中初始化函数定义了两个Parameter参数, Vmap对于这类执行函数的自由变量的处理等同于将其作为入参并配置对应in_axes位置为None的场景。

通过这种方式,我们可以实现批量输入在同一个模型上进行训练或推理等功能,与现有网络模型输入支持batch轴输入的区别在于,利用Vmap实现的批处理维度更加灵活,不局限于NCHW等输入格式。

模型集成场景

模型集成场景将来自多个模型的预测结果组合在一起,传统的实现方式是通过分别在某些输入上运行各个模型,然后将各自的预测结果组合在一起。假如您正在运行的是具有相同架构的模型,那么您可以借助Vmap将它们进行向量化,从而实现加速效果。

该场景下涉及权重数据的向量化,如果您运行的模型是通过函数式编程形式实现,即权重参数在模型外部定义并通过入参传递给模型操作,那您可以直接通过配置in_axes的方式进行相应的批处理。而MindSpore框架为了提供便捷的模型定义功能,绝大部分nn接口的权重参数都在接口内部定义并初始化,这意味着模型中的权重参数在原始Vmap中无法对权重进行批处理,改造成通过入参传递的函数式实现需要额外工作量。不过您不必担心,MindSpore的vmap接口已经替您优化了该场景。您只需要将运行的多个模型实例以CellList的形式传入给vmap,框架即可自动实现权重参数的批处理。

让我们演示如何使用一组简单的CNN模型来实现模型集成推理和训练。

class LeNet5(nn.Cell):
    """
    LeNet-5网络结构
    """
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
        self.fc1 = nn.Dense(16 * 5 * 5, 120)
        self.fc2 = nn.Dense(120, 84)
        self.fc3 = nn.Dense(84, num_class)
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.max_pool2d(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        return x

假设我们正在验证同一模型架构在不同权重参数下的效果,让我们模拟四个已经训练好的模型实例和一份batch大小为16,尺寸为32 x 32的虚拟图像数据集的minibatch。

net1 = LeNet5()
net2 = LeNet5()
net3 = LeNet5()
net4 = LeNet5()

minibatch = Tensor(mnp.randn(3, 1, 32, 32), mindspore.float32)

相较于利用for循环分别运行各个模型后将预测结果集合到一起,Vmap能够一次运行获得多个模型的预测结果。

总结

本教程重点在于介绍Vmap的场景使用说明,本质上自动向量化并非将循环逻辑执行于函数外部,而是将循环下沉至函数的各个基元操作中,并将映射轴信息在基元操作间传递,从而保证计算逻辑的正确性。Vmap的性能收益主要来自于各个基元操作所对应的VmapRule实现,由于循环下沉至算子层级,因而更容易结合并行技术进行性能优化,如果您有自定义算子的场景也可以尝试为自定义算子实现特定的VmapRule,从而获得更好的性能。对于性能极致追求的场景还可以再结合图算融合特性进行优化。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2189934.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

针对线上消息积压的排查思路以及解决方案

一、背景 我们在日常工作中,经常会碰到线上告警,消息队列消息积压了,试想如果对消息的消费速率有要求的场景,消息积压一定会或多或少对自己本身的业务场景有影响,这里就针对消息积压的场景,谈谈具体的排查…

过滤器Filter【详解】

过滤器Filter 1、 现有问题 在以往的Servlet中,有冗余的代码,多个Servlet都有重复的代码 比如编码格式设置 登录信息认证 2、 概念 过滤器(Filter)是处于客户端与服务器目标资源之间的一道过滤技术。 过滤器 3、 过滤器作用 执…

Python办公自动化教程(006):Word添加标题

2.3 word标题 介绍: 在 python-docx 中,您可以使用 add_heading() 方法为文档添加标题。此方法允许您指定标题的文本和级别(例如,一级标题、二级标题等)。标题级别的范围是从 0 到 9,其中 0 表示文档的主标…

深度解析:从浏览器输入链接到页面展现的奇幻历程

〇、前言 当我们在浏览器中输入一个网址,例如:example.com,按下回车键后,会发生什么呢? 主要会发生以下这些过程:域名解析、建立HTTP连接、发送HTTP请求、数据传输、渲染网页、断开HTTP连接。 一、域名解…

类型转换【C++提升】(隐式转换、显式转换、自定义转换、转换构造函数、转换运算符重载......你想知道的全都有)

更多精彩内容..... 🎉❤️播主の主页✨😘 Stark、-CSDN博客 本文所在专栏: C系列语法知识_Stark、的博客-CSDN博客 座右铭:梦想是一盏明灯,照亮我们前行的路,无论风雨多大,我们都要坚持不懈。 一…

【srm系统】供应商管理,招投标管理,电子采购系统,询价管理

前言: 随着互联网和数字技术的不断发展,企业采购管理逐渐走向数字化和智能化。数字化采购平台作为企业采购管理的新模式,能够提高采购效率、降低采购成本、优化供应商合作效率,已成为企业实现效益提升的关键手段。系统获取在文末…

[含文档+PPT+源码等]精品基于Python实现的美术馆网站的设计与实现

基于Python实现的美术馆网站,其设计与实现背景主要源于以下几个方面的需求和发展趋势: 一、文化艺术领域的发展需求 随着文化娱乐活动的日益丰富,美术馆作为展示艺术作品、传播文化的重要场所,其管理和服务模式的创新对于提升公…

LabVIEW提高开发效率技巧----使用动态事件

在LabVIEW开发过程中,用户交互行为可能是多样且不可预知的。为应对这些变化,使用动态事件是一种有效的策略。本文将从多个角度详细介绍动态事件的概念及其在LabVIEW开发中的应用技巧,并结合实际案例,说明如何通过动态事件提高程序…

【售后资料】软件售后服务方案(word原件)

软件售后服务方案的售后服务范围广泛,涵盖了多个方面,以确保客户在使用软件过程中得到全面、及时的支持。具体来说,这些服务范围通常包括以下几个核心内容: 技术支持服务维护与更新服务培训与教育服务定制化服务数据管理与服务客户…

如何获取网页内嵌入的视频?

如何获取网页内嵌入的视频? 有时插件无法识别的视频资源,可以通过手动使用浏览器的开发者工具来抓取。你可以按照以下步骤操作: 步骤: 打开网页并按 F12:在视频页面按下 F12 或右键点击网页并选择“检查”或“Inspe…

Spring Boot实现的大学生就业市场解决方案

1系统概述 1.1 研究背景 如今互联网高速发展,网络遍布全球,通过互联网发布的消息能快而方便的传播到世界每个角落,并且互联网上能传播的信息也很广,比如文字、图片、声音、视频等。从而,这种种好处使得互联网成了信息传…

【案例】距离限制模型透明

开发平台:Unity 2023 开发工具:Unity ShaderGraph   一、效果展示 二、路线图 三、案例分析 核心思路:计算算式:透明值 实际距离 / 最大距离 (实际距离 ≤ 最大距离)   3.1 说明 | 改变 Alpha 值 在 …

简易投影仪的制作

今天不做开发类的文章,来给大家整个活哈哈哈哈哈。由于前几天室友说看小屏幕的抖音太不舒服,比较累眼睛,所以我萌生出来一个制作投影仪的想法。于是查阅了资料最终完成以下的设计。 以下设计价格最高的是一部旧的可拆卸的智能手机 简易投影仪…

C++11新特性(基础)【2】

目录 1.范围for循环 2.智能指针 3.STL中一些变化 4.右值引用和移动语义 4.1 左值引用和右值引用 4.2 左值引用与右值引用比较 4.3 右值引用使用场景和意义 4.4 右值引用引用左值及其一些更深入的使用场景分析 4.5 完美转发 1.范围for循环 int main() {int array[10] { 1,2,3,4…

CSS | CSS中强大的margin负边距

css中的负边距(negative margin)是布局中的一个常用技巧,只要运用得合理常常会有意想不到的效果。很多特殊的css布局方法都依赖于负边距,所以掌握它的用法对于前端的同学来说,那是必须的。本文非常基础,老鸟可以略过。 一、负边距…

【宽搜】3. leetcode 515 在每个树行中找最大值

1 题目描述 题目链接:在每个树行中找最大值 2 题目解析 根据题目描述,是找出每一行中的最大值,这毋庸置疑是使用宽度优先遍历了。我在这篇文章中讲解了宽度优先遍历的模板,如果没有看的同学可以先去看一下。 这道题和模板的不…

基于微信小程序的调查问卷管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:…

数据结构与算法(七)静态链表

目录 前言 一、静态链表的引入 二、线性表的静态链表存储结构 三、静态链表的插入操作 四、静态链表的删除操作 五、静态链表的优缺点总结 1、优点 2、缺点 3、小结 六、单链表小结——Tecent面试题 1、普通解法: 2、高级解法: 前言 静态链表…

基于CAN总线的TMS320F28335 Bootloader设计说明

1 设计目的 根据客户要求,开发一款基于CAN总线的TI公司TMS320F28335 DSP(数字信号处理器)bootloader,以方便应用程序的刷写。CAN设备采用周立功CAN卡(USBCAN-I、USBCAN-II、USBCAN-E-mini)。 2 专有信息 …

一篇文章吃透OA系统

一、OA系统是什么,都有什么功能? OA系统(Office Automation System)是办公自动化系统的简称,是一种利用计算机技术和网络通信技术,为企业和组织提供办公管理和协作支持的信息化系统。OA系统旨在提高办公效…