[代码学习]einsum详解

news2024/9/27 15:30:33

einsum详解

该函数用于对一组输入 Tensor 进行 Einstein 求和,该函数目前仅适用于paddle的动态图。

Einstein 求和是一种采用 Einstein 标记法描述的 Tensor 求和,输入单个或多个 Tensor,输出单个 Tensor。

在这里插入图片描述

paddle.einsum(equation, *operands)

参数

  • equation (str):求和标记
  • operands (Tensor, [Tensor, …]):输入 Tensor

返回

  • Tensor:输出 Tensor

求和特例

  • 单操作数

    • 迹:trace

    • 对角元:diagonal

    • 转置:transpose

    • 求和:sum

  • 双操作数

    • 内积:dot

    • 外积:outer

    • 广播乘积:mul,*

    • 矩阵乘:matmul

    • 批量矩阵乘:bmm

  • 多操作数

    • 广播乘积:mul,*

    • 多矩阵乘:A.matmul(B).matmul(C)

关于求和标记的约定

  • 维度分量下标:Tensor 的维度分量下标使用英文字母表示,不区分大小写,如’ijk’表示 Tensor 维度分量为 i,j,k

  • 下标对应输入操作数:维度下标以`,`分段,按顺序 1-1 对应输入操作数

  • 广播维度:省略号`…`表示维度的广播分量,例如,'i…j’表示首末分量除外的维度需进行广播对齐

  • 自由标和哑标:输入标记中仅出现一次的下标为自由标,重复出现的下标为哑标,哑标对应的维度分量将被规约消去

  • 输出:输出 Tensor 的维度分量既可由输入标记自动推导,也可以用输出标记定制化

  • 自动推导输出

    • 广播维度分量位于维度向量高维位置,自由标维度分量按字母顺序排序,位于维度向量低纬位置,哑标维度分量不输出
  • 定制化输出

    • 维度标记中`->`右侧为输出标记

    • 若输出包含广播维度,则输出标记需包含`…`

    • 输出标记为空时,对输出进行全量求和,返回该标量

    • 输出不能包含输入标记中未出现的下标

    • 输出下标不可以重复出现

    • 哑标出现在输出标记中则自动提升为自由标

    • 输出标记中未出现的自由标被降为哑标

例子

  • ‘…ij, …jk’,该标记中 i,k 为自由标,j 为哑标,输出维度’…ik’

  • ‘ij -> i’,i 为自由标,j 为哑标

  • ‘…ij, …jk -> …ijk’,i,j,k 均为自由标

  • ‘…ij, …jk -> ij’,若输入 Tensor 中的广播维度不为空,则该标记为无效标记

求和规则

Einsum 求和过程理论上等价于如下四步,但实现中实际执行的步骤会有差异。

  • 第一步,维度对齐:将所有标记按字母序排序,按照标记顺序将输入 Tensor 逐一转置、补齐维度,使得处理后的所有 Tensor 其维度标记保持一致

  • 第二步,广播乘积:以维度下标为索引进行广播点乘

  • 第三步,维度规约:将哑标对应的维度分量求和消除

  • 第四步,转置输出:若存在输出标记,则按标记进行转置,否则按广播维度+字母序自由标的顺序转置,返回转之后的 Tensor 作为输出

关于 trace 和 diagonal 的标记约定(待实现功能)

  • 在单个输入 Tensor 的标记中重复出现的下标称为对角标,对角标对应的坐标轴需进行对角化操作,如’i…i’表示需对首尾坐标轴进行对角化

  • 若无输出标记或输出标记中不包含对角标,则对角标对应维度规约为标量,相应维度取消,等价于 trace 操作

  • 若输出标记中包含对角标,则保留对角标维度,等价于 diagonal 操作

实例实践

首先,看一下一维度简单实验:

import paddle

# 定义两个输入矩阵
# paddle.seed(102)
# x = paddle.rand([4])
# y = paddle.rand([5])
x = paddle.to_tensor([1,2,], dtype='float32')
y = paddle.to_tensor([3,4,5], dtype='float32')

# sum
sum_x = paddle.einsum('i->', x).numpy()

# dot
dox_x = paddle.einsum('i,i->', x, x).numpy()

# outer
outer_xy = paddle.einsum("i,j->ij", x, y).numpy()

print(f"x: {x.numpy()}, shape: {x.shape}")
print(f"y: {y.numpy()}, shape: {y.shape}")
print(f"sum_x: {sum_x}, shape: {sum_x.shape}")
print(f"dox_x: {dox_x}, shape: {dox_x.shape}")
print(f"outer_xy: {outer_xy}, shape: {outer_xy.shape}")

结果输出为:

x: [1. 2.], shape: [2]
y: [3. 4. 5.], shape: [3]
sum_x: 3.0, shape: ()
dox_x: 5.0, shape: ()
outer_xy: [[ 3.  4.  5.]
 [ 6.  8. 10.]], shape: (2, 3)

然后,看一下高纬度的实验:

import paddle

# A = paddle.rand([2, 3, 2])
# B = paddle.rand([2, 2, 3])
A = paddle.to_tensor([[[1,2],[1,2],[1,2]], [[1,2],[1,2],[1,2]]], dtype='float32')
B = paddle.to_tensor([[[3,4,5],[3,4,5]], [[3,4,5],[3,4,5]]], dtype='float32')

# transpose
transpose_A = paddle.einsum('ijk->kji', A)

# batch matrix multiplication
BMM_AB = paddle.einsum('ijk, ikl->ijl', A,B)

# Ellipsis transpose
ET_A = paddle.einsum('...jk->...kj', A)

# Ellipsis batch matrix multiplication
EBMM_AB = paddle.einsum('...jk, ...kl->...jl', A,B)

print(f"A: {A.numpy()}, shape: {A.shape}")
print(f"B: {B.numpy()}, shape: {B.shape}")
print(f"transpose_A: {transpose_A.numpy()}, shape: {transpose_A.shape}")
print(f"BMM_AB: {BMM_AB.numpy()}, shape: {BMM_AB.shape}")
print(f"ET_A: {ET_A.numpy()}, shape: {ET_A.shape}")
print(f"EBMM_AB: {EBMM_AB.numpy()}, shape: {EBMM_AB.shape}")

结果输出为:

A: [[[1. 2.]
  [1. 2.]
  [1. 2.]]

 [[1. 2.]
  [1. 2.]
  [1. 2.]]], shape: [2, 3, 2]
B: [[[3. 4. 5.]
  [3. 4. 5.]]

 [[3. 4. 5.]
  [3. 4. 5.]]], shape: [2, 2, 3]
transpose_A: [[[1. 1.]
  [1. 1.]
  [1. 1.]]

 [[2. 2.]
  [2. 2.]
  [2. 2.]]], shape: [2, 3, 2]
BMM_AB: [[[ 9. 12. 15.]
  [ 9. 12. 15.]
  [ 9. 12. 15.]]

 [[ 9. 12. 15.]
  [ 9. 12. 15.]
  [ 9. 12. 15.]]], shape: [2, 3, 3]
ET_A: [[[1. 1. 1.]
  [2. 2. 2.]]

 [[1. 1. 1.]
  [2. 2. 2.]]], shape: [2, 2, 3]
EBMM_AB: [[[ 9. 12. 15.]
  [ 9. 12. 15.]
  [ 9. 12. 15.]]

 [[ 9. 12. 15.]
  [ 9. 12. 15.]
  [ 9. 12. 15.]]], shape: [2, 3, 3]

reference

关于matmul可以查看:https://blog.csdn.net/orDream/article/details/133744368
官方链接:
@misc{BibEntry2023Oct,
title = {{einsum-API文档-PaddlePaddle深度学习平台}},
year = {2023},
month = oct,
urldate = {2023-10-10},
language = {chinese},
note = {[Online; accessed 10. Oct. 2023]},
url = {https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/einsum_cn.html}
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1078651.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Win11自定义目录安装Linux子系统wsl

1. 启用适用于 Linux 的 Windows 子系统和虚拟机功能 以管理员身份打开 PowerShell(“开始”菜单 >“PowerShell” >单击右键 >“以管理员身份运行”),然后依次输入执行以下命令: dism.exe /online /enable-feature /f…

3ds Max渲染太慢?创意云“一键云渲染”提升3ds Max渲染体验

在数字艺术设计领域,3ds Max是广泛使用的三维建模和渲染软件之一。然而,许多用户都面临着一个共同的问题:渲染速度太慢。渲染一帧画面需要耗费数小时,让人无法忍受。除了之前给大家介绍的几种解决方法外: …

【斗破年番】导演紧急删减第66集预告,陨落心炎事件要重演?

Hello,小伙伴们,我是小郑继续为大家深度解析斗破苍穹年番最新资讯。 斗破苍穹年番第65集已经出来了,在这一集出来后按例官方放出来第66集和第67集的预告。只是让小郑没有想到的是,刚开始看第66集预告还好好的,但是等到再看的时候就…

代理SSL证书的优势——JoySSL

随着互联网的发展,越来越多的企业和个人开始使用网站来提供服务和信息。而SSL证书作为保障网站安全的重要工具,也逐渐被广泛应用。然而,对于许多企业和个人来说,购买和安装SSL证书是一项昂贵的任务。这就需要代理SSL证书的出现&am…

SLAM从入门到精通(camera数据读取)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 实际ros开发的时候,现场有很多特征都可以用来进行采集和标定。比如说地面,对于外资企业或者管理比较规范的企业来说&#x…

UNet及其变体在医学图像分割中的性能分析

论文链接:https://arxiv.org/abs/2309.13013 机构:英国伦敦布鲁内尔大学 日期:20230922 因为太长了长达37页所以我也就记点重点内容了hhh,我重点关注的还是在Unet以及其变体上,不过感觉严格意义上来说里面提到的方法不算很新&a…

嵌入式C语言中整形溢出问题分析

整型溢出有点老生常谈了,bla, bla, bla… 但似乎没有引起多少人的重视。整型溢出会有可能导致缓冲区溢出,缓冲区溢出会导致各种黑客攻击。 今天分享一篇文章,希望大家都了解一下整型溢出,编译器的行为,以及如何防范&a…

【安全】 Java 过滤器 解决存储型xss攻击问题

文章目录 XSS简介什么是XSS?分类反射型存储型 XSS(cross site script)跨站脚本攻击攻击场景解决方案 XSS简介 跨站脚本( cross site script )为了避免与样式css(Cascading Style Sheets层叠样式表)混淆,所以简称为XSS。 XSS是一种经常出现在web应用中的计算机安全…

mistyR官网教程 空转spatial

Modeling spatially resolved omics with mistyR • mistyR (saezlab.github.io) mistyR and data formats • mistyR (saezlab.github.io) Heidelberg University and Heidelberg University Hospital, Heidelberg, Germany Jožef Stefan Institute, Ljubljana, Sloveniajov…

阿里5年经验之谈 —— 浅谈自动化测试方法!

导读 在当今快节奏的软件开发环境中,高质量的代码交付至关重要。而针对经过多次迭代,主要功能趋向稳定的产品,大量传统的重复性手动测试方法已经无法满足高效、快速的需求。为了提高测试效率保证产品质量,本文通过产品实践应用&a…

Python接口自动化测试之token参数关联

前言 在做自动化接口测试时,有时候会遇到token的动态关联,例如查询余额接口,需要关联登录接口的token动态值,如何利用python脚本进行接口token关联呢?今天我们爱学习一下吧! 一:获取登录接口返回的token…

研发项目管理系统对比:找到最适合的高效工具

研发部门是企业非常重要的部门,代表着企业未来能否在市场上拥有优秀的技术,站稳市场份额。很多企业的研发方式往往是瀑布式的,所以项目的阶段规划,然后每个阶段的需求分配给开发人员,可以随时查看每个需求的开发进度和…

Redis学习5——有序集合Zset数据类型的操作

有序集合Zset 常用命令 数据结构 跳跃表 跳跃表

移远通信EM060K系列LTE-A Cat 6模组完成全球认证覆盖

近日,移远通信LTE-A Cat 6模组EM060K系列顺利完成全球认证覆盖,将以卓越的性能和品质助力海内外客户终端大规模部署,为其提供畅快的高速网络连接。同时,凭借着有竞争力的性能和成本优势,EM060K系列将加速释放海外固定无…

[架构之路-235]:目标系统 - 纵向分层 - 数据库 - 数据库系统基础与概述:数据库定义、核心概念、系统组成

目录 一、核心概念 1.1 什么是数据与信息 1.2 数据与数据库的关系 1.3 什么是数据库 1.4 数据库中的数据的特点 1.5 数据库与数据结构的关系 二、数据库系统 2.1 什么是数据库管理系统 2.2 什么是数据库系统 2.3 数据库相关的人员 2.4 数据库的主要功能 2.5 Excel表…

Vuex的基础使用存值及异步

目录 一、概述 ( 1 ) 讲述 ( 2 ) 概念 ( 3 ) 作用 二、取值 1. 安装 2. 菜单栏 3. 模块 4. 引用 三、改值 四、异步&后台请求 带来的获取 一、概述 ( 1 ) 讲述 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储管理应用的所有组件的…

【Linux初阶】多线程1 | 页表的索引作用,线程基础(优缺点、异常、用途),线程VS进程,线程控制,C++多线程引入

文章目录 ☀️一、深入理解页表☀️二、Linux线程概念🌻1.什么是线程(重点)⚡(1)线程的概念⚡(2)线程库初识 🌻2.线程的优点🌻3.线程的缺点🌻4.线程异常&…

为什么设置静态代理IP后无法正常上网,怎么解决?

静态代理IP是一个固定的IP地址,因为其出色的稳定性和安全性而得到广泛应用,常用于一些对网络质量要求高、需要长期稳定和持续可靠连接的业务。设置静态代理IP后无法上网是用户常见的网络问题,通常有多种原因: 1. 静态代理IP不可用…

【Flutter学习】AppBar

App Bar 可以视为页面的标题栏,在 Flutter 中用AppBar组件实现。 一个简单的AppBar实现代码如下: import package:flutter/material.dart;void main() {runApp(const AppBarTest()); }class AppBarTest extends StatelessWidget {const AppBarTest({Key…

【AGC】云托管新建站点时间过长的问题排查方法

【问题描述】 开发者按照指导文档使用云托管服务,已经申请了域名,在创建站点时页面显示证书配置最长需要12小时,然而,在等了两天后依然是激活中的状态,没有如期上线。 ​ 【解决方案】 卡在上线中的状态有以下几个原…