【PyTorch函数解析】einsum的用法示例

news2024/12/27 1:37:39

一、前言

einsum 是一个非常强大的函数,用于执行张量(Tensor)运算。它的名称来源于爱因斯坦求和约定(Einstein summation convention),在PyTorch中,einsum 可以方便地进行多维数组的操作和计算。

在Transfomer中,einsum用的非常多,比如使用 einsum 实现自注意力机制中注意力权重的获取,也就是Q和K的内积:

  • Q(Query):形状为 (batch_size, seq_len, d_k)

  • K(Key):形状为 (batch_size, seq_len, d_k)

import torch
import torch.nn.functional as F

Q = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)
K = torch.randn(2, 10, 64)  # (batch_size, seq_len, d_k)

# (batch_size, seq_len, seq_len)
attention_scores = torch.einsum('bqd,bkd->bqk', Q, K) / torch.sqrt(torch.tensor(64.0))
# (batch_size, seq_len, seq_len)   
attention_weights = F.softmax(attention_scores, dim=-1)  

二、常见用法示例

2.1 向量点积

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,i->', a, b)
print(result)  # 输出 32

这里,'i,i->' 表示对向量 a 和 b 进行点积操作,其中 i 是索引表示,-> 之后为空表示求和。

2.2 矩阵乘法

A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.einsum('ij,jk->ik', A, B)
print(result)  # 输出 tensor([[19, 22], [43, 50]])

这里,'ij,jk->ik' 表示矩阵乘法,其中 i 和 k 是结果的维度,j 是求和维度。

2.3 批量矩阵乘法

A = torch.randn(2, 3, 4)
B = torch.randn(2, 4, 5)
result = torch.einsum('bij,bjk->bik', A, B)

这里,'bij,bjk->bik' 表示对批量的矩阵进行乘法运算。

解释:

bij,bjk分别是A和B的3个维度,用字符串的形式指代。

为什么最后得到的是bik呢?这个和线性代数的矩阵运算规则有关系。

矩阵乘法规则:

  • 给定矩阵 A 的形状为 (m,n)

  • 给定矩阵 B 的形状为 (n,p)

  • 矩阵乘法 A×B 的结果矩阵 C 的形状为 (m,p)

在矩阵乘法中,结果矩阵的每个元素 Cik 是通过 A 的第 i 行和 B 的第 k 列的对应元素相乘并求和得到的,即:

C_{ik}=\sum_{j=1}^nA_{ij}\cdot B_{jk}

计算过程:

1. 匹配批次维度 (b)

  • 对于每个批次,独立进行矩阵乘法运算。

2. 求和维度 (j):

  • j 是两个张量中共同的维度,根据线性代数中的矩阵乘法规则,需要对 j 维度进行求和。

3. 保留和产生的维度:

  • i 来自 A,表示保留 A 的第一个维度。

  • k 来自 B,表示保留 B 的第二个维度。

经过上述分析,einsum 的结果保留了 b(批次维度)、i(来自 A 的第一个维度)和 k(来自 B 的第二个维度)。因此,结果张量的形状为 (batch_size, seq_len_i, seq_len_k),也就是 bik。

同样,延伸到4维计算的话。

torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

首先,假设 queries 和 keys 的形状为:

  • queries: (batch_size, seq_len_q, num_heads, head_dim)

  • keys: (batch_size, seq_len_k, num_heads, head_dim)

用具体变量名表示:

  • n: batch_size,批次大小。

  • q: seq_len_q,查询序列的长度。

  • k: seq_len_k,键序列的长度。

  • h: num_heads,多头注意力中的头数。

  • d: head_dim,每个头的维度。

1. 匹配批次维度 (n) 和头部维度 (h):

  • 批次大小和头部数量在两个输入张量中都是相同的,保持不变。

2. 求和维度 (d):

  • d 表示每个头的维度。在 queries 和 keys 中,d 都是最后一个维度,对这个维度进行点积运算后求和。

3. 保留和产生的维度:

  • q 来自 queries,表示查询序列的长度。

  • k 来自 keys,表示键序列的长度。

所以最后是nhqk。

2.4 转置操作

A = torch.tensor([[1, 2, 3], [4, 5, 6]])
result = torch.einsum('ij->ji', A)
print(result)  # 输出 tensor([[1, 4], [2, 5], [3, 6]])

这里,'ij->ji' 表示将矩阵进行转置操作。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1861642.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PCL笔记二 之VS环境配置(不同版本Debug+Release编译)

PCL笔记二 之VS环境配置(不同版本DebugRelease编译) PCL官网:https://github.com/PointCloudLibrary/pcl/releases众所周知,PCL是一个用于点云处理并且依赖不少三方库的一个算法库,同时在编译配置环境时也很复杂&…

【ONLYOFFICE8.1桌面编辑器】强势来袭—— 一款全面的办公软件套件

在日常工作和学习中,我们经常需要处理各种文档、表格和演示文稿。一款功能强大、易于使用的办公软件成为我们提高工作效率、便捷沟通和展示想法的得力助手。 而ONLYOFFICE 8.1桌面编辑器正是一款全面、高效的办公软件,集合了Word、PPT、Excel的功能&…

如何设置windows计划任务

如何设置windows计划任务 前言:在工作过程中写了一个python脚本,用于调用jira接口查询bug单数量,想要在本地定时任务执行,每天发送到钉钉群提醒,写下操作步骤用于记录。 1. 准备 Python 脚本 确保你的 Python 脚本已…

暗影精灵8Pro声音没有了,这个方法可以解决,亲测有效!

这个OMEN by HP Gaming Laptop 16-k0xxx Windows 10 Sound Driver Mod ,真的解决了我的大问题! 如果你的暗影精灵8 Pro酷睿版突然变得哑巴了,扬声器和麦克风都发不出声音,那可能是声卡驱动出了问题。 别担心,我也是个…

Android简介

1. Android简介 Android是一种基于Linux内核的自由及开放源代码的操作系统。最初是由安迪鲁宾(Andy Rubin)开发的一款相机操作系统。2005年8月被Google收购。2007年11月,Google与84家硬件制造商、软件开发商及电信营运商组建开放手机联盟共同研发改良Android系统。…

8.12 矢量图层面要素单一符号使用七(随机标记填充)

文章目录 前言随机标记填充(Random Marker Fill)QGis设置面符号为随机标记填充(Random Marker Fill)二次开发代码实现随机标记填充(Random Marker Fill) 总结 前言 本章介绍矢量图层线要素单一符号中使用随…

手机照片压缩到20k以内免费,这几款心动软件快收好!

在数字化时代,手机拍照已成为我们记录生活的重要方式之一。然而,高清的照片也意味着占用着越来越多的手机存储空间。如果你正在为手机内存告急而烦恼,那么这几款手机照片压缩神器或许能成为你的救星!它们不仅可以将照片轻松压缩至…

Three.js——第一篇:部署以及基础代码创建场景、GUI调整样式

three.js官网:three.js docs 中文技术文档1:| 麒跃科技 中文技术文档2:3. 开发和学习环境,引入threejs | Three.js中文网 很多教程一开始要大家自己部署three.js的中文本地部署,我就不弄了,我弄了半天也没…

大厂薪资福利篇第四弹:字节跳动

欢迎来到绝命Coding! 今天继续更新大家最关心的 大厂薪资福利系列! 往期分享: 福利开水喝不完?大厂薪资福利篇!美团 职场文化发源地?大厂薪资福利篇!阿里巴巴 给这么多!还能带宠物上…

Adams Flex模块功能介绍

通过该教程对Adams Flex模块有基本的认知,为以后使用柔性体进行刚柔耦合做好基础学习。 有需要购买的可以邮箱(digitaltwins126.com)或站内信联系,谢谢!

机器学习之数学基础(七)~过拟合(over-fitting)和欠拟合(under-fitting)

目录 1. 过拟合与欠拟合 1.1 Preliminary concept 1.2 过拟合 over-fitting 1.3 欠拟合 under-fitting 1.4 案例解析:黑天鹅 1. 过拟合与欠拟合 1.1 Preliminary concept 误差 经验误差:模型对训练集数据的误差。泛化误差:模型对测试…

基于SpringBoot的“智慧食堂”管理系统设计与实现

你好呀,我是计算机学姐码农小野!如果有相关需求,可以私信联系我。 开发语言:Java 数据库:MySQL 技术:SpringBootVue 工具:IDEA/Eclipse、Navicat、Maven 系统展示 首页 用户管理界面 菜品…

超炫酷, 不用学前端也能自己做网页!这个Python库,3分钟内复刻GPT WEB应用

大家好,我是海鸽。 今天,我要和大家分享如何将请求 GPT 的案例,快速“复刻”成 GPT 网页版。这不仅简单,而且对于我们这些后端开发者来说,简直是福音! 先睹为快 看看这个界面,是不是感觉很熟…

更适合敏感口腔的护理牙刷

最近在用一款清九野小红盾舒敏牙刷,感觉它很适合牙龈敏感的人,让刷牙体验有了显著的提升。这款牙刷的柔软刷毛和精细设计让我的刷牙过程变得轻松愉快。它的内外圈双重植毛技术,在清洁牙齿的同时,还能深入牙缝,温和地去…

js实现数据加密,jwt加密方式

一个简单的数据加密 const crypto require("crypto");// 普通的数据加密 function sign(msg,key){ // 原始信息,密钥 String// "sha256" :加密的算法,key :密钥,msg :要加密的信息,"hex" :转成16…

攻击者开始使用 XLL 文件进行攻击

近期,研究人员发现使用恶意 Microsoft Excel 加载项(XLL)文件发起攻击的行动有所增加,这项技术的 MITRE ATT&CK 技术项编号为 T1137.006。 这些加载项都是为了使用户能够利用高性能函数,为 Excel 工作表提供 API …

【Mac】XnViewMP for Mac(图片浏览查看器)及同类型软件介绍

软件介绍 XnViewMP 是一款多功能、跨平台的图像查看和管理软件,适用于 macOS、Windows 和 Linux 系统。它是经典 XnView 软件的增强版本,更加现代化且功能更强大。XnViewMP 支持数百种图像格式,并提供多种图像处理工具,使其成为摄…

基于Java微信小程序自驾游拼团设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 🌟文末获取源码数据库🌟感兴趣的可以先收藏起来,还…

全省高等职业学校大数据技术专业建设暨专业质量监测研讨活动顺利开展

6月21日,省教育评估院在四川邮电职业技术学院组织开展全省高等职业学校大数据技术专业建设暨专业质量监测研讨活动。省教育评估院副院长赖长春,四川邮电职业技术学院党委副书记、校长冯远洪,四川邮电职业技术学院党委委员、副校长程德杰等出席…