Pytorch中张量矩阵乘法函数(mm, bmm, matmul)使用说明,含高维张量实例及运行结果

news2024/11/25 12:13:28

Pytorch中张量矩阵乘法函数使用说明

  • 1 torch.mm() 函数
    • 1.1 torch.mm() 函数定义及参数
    • 1.2 torch.bmm() 官方示例
  • 2 torch.bmm() 函数
    • 2.1 torch.bmm() 函数定义及参数
    • 2.2 torch.bmm() 官方示例
  • 3 torch.matmul() 函数
    • 3.1 torch.matmul() 函数定义及参数
    • 3.2 torch.matmul() 规则约定
    • 3.3 torch.matmul() 官方示例
    • 3.4 高维数据实例解释
  • 参考博文及感谢

1 torch.mm() 函数

全称为matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是2维

1.1 torch.mm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的矩阵
** mat2
* (Tensor) – – 第二个要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

1.2 torch.bmm() 官方示例

mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)

tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

2 torch.bmm() 函数

全称为batch matrix-matrix product,对输入的张量做矩阵乘法运算,输入输出维度一定是3维;

2.1 torch.bmm() 函数定义及参数

torch.bmm(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一批要相乘的矩阵
** mat2
* (Tensor) – – 第二批要相乘的矩阵
不支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

2.2 torch.bmm() 官方示例

input = torch.randn(10, 3, 4)
mat2 = torch.randn(10, 4, 5)
res = torch.bmm(input, mat2)
res.size()

torch.Size([10, 3, 5])

3 torch.matmul() 函数

可进行多维矩阵运算,根据不同输入维度进行广播机制然后运算,和点积类似,广播机制可参考之前博文torch.mul()函数。

3.1 torch.matmul() 函数定义及参数

torch.matmul(input, mat2, , out=None) → Tensor
input (Tensor) – – 第一个要相乘的张量
** mat2
* (Tensor) – – 第二个要相乘的张量
支持广播到通用形状、类型推广以及整数、浮点和复杂输入。

3.2 torch.matmul() 规则约定

(1)若两个都是1D(向量)的,则返回两个向量的点积;

(2)若两个都是2D(矩阵)的,则按照(矩阵相乘)规则返回2D;

(3)若input维度1D,other维度2D,则先将1D的维度扩充到2D(1D的维数前面+1),然后得到结果后再将此维度去掉,得到的与input的维度相同。即使作扩充(广播)处理,input的维度也要和other维度做对应关系;

(4)若input是2D,other是1D,则返回两者的点积结果;

(5)如果一个维度至少是1D,另外一个大于2D,则返回的是一个批矩阵乘法( a batched matrix multiply)

  • (a)若input是1D,other是大于2D的,则类似于规则(3);
  • (b)若other是1D,input是大于2D的,则类似于规则(4);
  • (c)若input和other都是3D的,则与torch.bmm()函数功能一样;
  • (d)如果input中某一维度满足可以广播(扩充),那么也是可以进行相乘操作的。例如 input(j,1,n,m)* other (k,m,p) = output(j,k,n,p)

matmul() 根据输入矩阵自动决定如何相乘。低维根据高维需求,合理广播。

3.3 torch.matmul() 官方示例

# vector x vector
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()

torch.Size([])
# matrix x vector
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()

torch.Size([3])
# batched matrix x broadcasted vector
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()

torch.Size([10, 3])
# batched matrix x batched matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()

torch.Size([10, 3, 5])
# batched matrix x broadcasted matrix
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4, 5)
torch.matmul(tensor1, tensor2).size()

torch.Size([10, 3, 5])

3.4 高维数据实例解释

直接看一个4维的二值例子,先看图(红虚线和实线是为了便于区分维度而添加),不懂再结合代码和结果分析,先做广播,然后对应矩阵进行乘积运算
在这里插入图片描述

代码如下:

import torch
import numpy as np

np.random.seed(2022)
a = np.random.randint(low=0, high=2, size=(2, 2, 3, 4))
a = torch.tensor(a)
b = np.random.randint(low=0, high=2, size=(2, 1, 4, 3))
b = torch.tensor(b)
c = torch.matmul(a, b)
# or
# c = a @ b
print(a)
print("=============================================")
print(b)
print("=============================================")
print(c.size())
print("=============================================")
print(c)

运行结果为:

tensor([[[[1, 0, 1, 0],
          [1, 1, 0, 1],
          [0, 0, 0, 0]],

         [[1, 1, 1, 1],
          [1, 1, 0, 0],
          [0, 1, 0, 1]]],


        [[[0, 0, 0, 1],
          [0, 0, 0, 1],
          [0, 1, 0, 0]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [0, 0, 0, 0]]]], dtype=torch.int32)
=============================================
tensor([[[[0, 1, 0],
          [1, 1, 0],
          [0, 0, 0],
          [1, 1, 0]]],


        [[[0, 1, 0],
          [1, 1, 1],
          [1, 1, 1],
          [1, 0, 1]]]], dtype=torch.int32)
=============================================
torch.Size([2, 2, 3, 3])
=============================================
tensor([[[[0, 1, 0],
          [2, 3, 0],
          [0, 0, 0]],

         [[2, 3, 0],
          [1, 2, 0],
          [2, 2, 0]]],


        [[[1, 0, 1],
          [1, 0, 1],
          [1, 1, 1]],

         [[3, 3, 3],
          [3, 3, 3],
          [0, 0, 0]]]], dtype=torch.int32)

参考博文及感谢

部分内容参考以下链接,这里表示感谢 Thanks♪(・ω・)ノ
参考博文1 官方文档查询地址
https://pytorch.org/docs/stable/index.html
参考博文2 Pytorch矩阵乘法之torch.mul() 、 torch.mm() 及torch.matmul()的区别
https://blog.csdn.net/irober/article/details/113686080

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1010318.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

黑马头条 热点文章实时计算、kafkaStream

热点文章-实时计算 1 今日内容 1.1 定时计算与实时计算 1.2 今日内容 kafkaStream 什么是流式计算kafkaStream概述kafkaStream入门案例Springboot集成kafkaStream 实时计算 用户行为发送消息kafkaStream聚合处理消息更新文章行为数量替换热点文章数据 2 实时流式计算 2…

【C++】哈希思想的应用——位图、布隆过滤器和哈希切割

前言: 前面我们学习了unordered_map和unordered_set和哈希表哈希桶等,并且我们自己用哈希桶封装了unordered_map和unordered_set。我们知道哈希的查找效率非常高为O(1),本章我们将延续哈希的思想,共同学习哈希的应用。 目录 &am…

sh脚本工具集锦(文件批量操作、音视频相关)持续更新

1 文件夹目录下所有图片转换成视频文件 pic_2_videos.sh: #!/bin/bash # 放到图片文件夹目录下,把所有jpeg图片推成视频文件 # sh pic_2_videos.sh 0 # 0: pad to 1920*1080 ; 1 or other no pad pad_1920$1if [[ $pad_1920 0 ]] thenfilesls|grep jp…

【Flink】 FlinkCDC读取Mysql( DataStream 方式)(带完整源码,直接可使用)

简介: FlinkCDC读取Mysql数据源,程序中使用了自定义反序列化器,完整的Flink结构,开箱即用。 本工程提供 1、项目源码及详细注释,简单修改即可用在实际生产代码 2、成功编译截图 3、自己编译过程中可能出现的问题 4、mysql建表语句及测试数据 5、修复FlinkCDC读取Mys…

【C++】匿名对象 ① ( 匿名对象引入 | 匿名对象简介 | 匿名对象概念 | 匿名对象作用域 - 对象创建与销毁 )

文章目录 一、匿名对象引入二、匿名对象简介1、匿名对象概念2、匿名对象作用域 - 对象创建与销毁3、代码示例 - 创建并使用匿名对象 一、匿名对象引入 匿名对象引入 : 在上一篇博客 【C】拷贝构造函数调用时机 ② ( 对象值作为函数参数 | 对象值作为函数返回值 ) 中 , 讲到了 如…

【Java基础】- RMI原理和使用详解

【Java基础】- RMI原理和使用详解 文章目录 【Java基础】- RMI原理和使用详解一、什么RMI二、RMI原理2.1 工作原理图2.2 工作原理 三、RMI远程调用步骤3.1 RMI远程调用运行流程图3.2 RMI 远程调用步骤 四、JAVA RMI简单实现4.1 如何实现一个RMI程序4.2 JAVA实现RMI程序 一、什么…

小程序中如何查看指定会员的所有订单?

在小程序中,查看指定会员的所有订单可以通过如下方式实现。 1. 找到指定的会员卡。在管理员后台->会员管理处,找到需要查看订单记录的会员卡。也支持对会员卡按卡号、手机号和等级进行搜索。 2. 查看会员卡详情。点击查看详情进入该会员卡的详情页面…

GTS 中testPersistentProcessMemory fail 详解

0. 前言 GTS 在测试 case armeabi-v7a GtsMemoryTestCases 的时候出现下面异常,本文总结一下。 com.google.android.memory.gts.MemoryTest#testPersistentProcessMemory 1. error log 09-14 09:41:40.523 10182 13340 13359 E TestRunner: failed: testPersiste…

网攻西北工业大学的美国安局人员真实身份锁定!

14日,《环球时报》从国家计算机病毒应急处理中心和360获悉,在侦办西北工业大学网络攻击案过程中,我方成功提取了名为“二次约会”(Second Date)“间谍”软件的多个样本。在多国业内伙伴通力合作下,现已成功…

【ELK】日志分析系统概述及部署

目录 一、ELK概述 1、ELK是什么? 2、ELK的组成部分 2.1 ElasticSearch (1)分片和副本 (2)es和传统数据库的区别 2.2 Kiabana 2.3 Logstash (1)Log Stash主要组件 2.4 可添加的其它组件 …

Opencv之区域生长和分裂

区域生长 1.基本原理 区域生长法是较为基础的一种区域分割方法 它的基本思想我说的通俗些,即是一开始有一个生长点(可以一个像素也可以是一个小区域),从这个生长点开始往外扩充,扩充的意思就是它会把跟自己有相似特征…

Python工程师Java之路(p)Maven聚合和继承

文章目录 依赖管理依赖传递可选依赖和排除依赖 继承与聚合 依赖管理 指当前项目运行所需的jar&#xff0c;一个项目可以设置多个依赖 <!-- 设置当前项目所依赖的所有jar --> <dependencies><!-- 设置具体的依赖 --><dependency><!-- 依赖所属群组…

gmssl v2 用 dgst 命令通过 sm2 签名出的结果,在别的工具上无法验签的问题分析

结论 通过分析发现&#xff0c;导致问题的原因是&#xff1a;gmssl v2 调用的算法不是 sm2 算法。 分析详情 具体情况如下所述 在 gmssl 调用 pkey_ec_init 函数时&#xff0c;默认会把 ec_scheme 设置为 NID_secg_scheme 签名的过程中会调用 pkey_ec_sign 函数&#xff0c…

【Redis】Redis的特性和应用场景 · 数据类型 · 持久化 · 数据淘汰 · 事务 · 多机部署

【Redis】Redis常见面试题&#xff08;3&#xff09; 文章目录 【Redis】Redis常见面试题&#xff08;3&#xff09;1. 特性&应用场景1.1 Redis能实现什么功能1.2 Redis支持分布式的原理1.3 为什么Redis这么快1.4 Redis实现分布式锁1.5 Redis作为缓存 2. 数据类型2.1 Redis…

Day_14 > 指针进阶(3)> bubble函数

目录 1.回顾回调函数 2.写一个bubble_sort函数 2.1认识一下qsort函数 ​编辑2.2写bubble_sort函数 今天我们继续深入学习指针 1.回顾回调函数 我们回顾一下之前学过的回调函数 回调函数就是一个通过函数指针调用的函数 如果你把函数的指针&#xff08;地址&#xff09;…

​Qt for Python 入门¶​

本页重点介绍如何从源代码构建Qt for Python&#xff0c;如果你只想安装PySide2。 与你需要运行&#xff1a;pip pip install pyside2有关更多详细信息&#xff0c;请参阅我们的快速入门指南。此外&#xff0c;您可以 查看与项目相关的常见问题解答。 一般要求 Python&#xf…

博客系统(升级(Spring))(四)(完)基本功能(阅读,修改,添加,删除文章)(附带项目)

博客系统 (三&#xff09; 博客系统博客主页前端后端个人博客前端后端显示个人文章删除文章 修改文章前端后端提取文章修改文章 显示正文内容前端后端文章阅读量功能 添加文章前端后端 如何使用Redis项目地点&#xff1a; 博客系统 博客系统是干什么的&#xff1f; CSDN就是一…

Activity生命周期递归问题查看

这类问题一般比较难分析&#xff0c;符合以下情况的才有可能分析出来&#xff1a; 能够复现并调试有问题时的堆栈以及对应的event log TaskFragment#shouldSleepActivities 方法导致递归 There is a recursion among check for sleep and complete pause during sleeping 关…

dlib库详解及Python环境安装指南

dlib是一个开源的机器学习库&#xff0c;它包含了众多的机器学习算法&#xff0c;例如分类、回归、聚类等。此外&#xff0c;dlib还包含了众多的数据处理、模型训练等工具&#xff0c;使得其在机器学习领域被广泛应用。本文将详细介绍dlib库的基本概念、功能&#xff0c;以及如…

删除数据库

MySQL从小白到总裁完整教程目录:https://blog.csdn.net/weixin_67859959/article/details/129334507?spm1001.2014.3001.5502 语法格式: drop database 数据库名称;这个命令谨慎使用,俗话说:删库跑路! 案列:删除testing数据库,并验证 mysql> show databases; -----------…