这篇文章记录torch.matmul()
的用法
这里仿照官方文档中的例子说明,此处取整数随机数,用于直观的查看效果:
vector x vector
两个一维向量的matmul
相当于点积,得到一个标量
tensor1 = torch.randint(1, 6, (3,))
tensor2 = torch.randint(1, 6, (3,))
output = torch.matmul(tensor1, tensor2)
print("tensor1: {}".format(tensor1))
print("tensor2: {}".format(tensor2))
print("output: {}".format(output))
print("output_size: {}".format(output.size()))
tensor1: tensor([5, 1, 1])
tensor2: tensor([3, 1, 2])
output: 18
output_size: torch.Size([])
matrix x vector
一个二维张量和一个一维张量的乘积,是把多出的一维提出来,相当于沿那个维度分别点积
tensor1 = torch.randint(1, 6, (3, 4))
tensor2 = torch.randint(1, 6, (4,))
output = torch.matmul(tensor1, tensor2)
print("tensor1: {}".format(tensor1))
print("tensor2: {}".format(tensor2))
print("output: {}".format(output))
print("output_size: {}".format(output.size()))
tensor1: tensor([[5, 1, 5, 1],
[4, 4, 4, 2],
[2, 4, 4, 1]])
tensor2: tensor([1, 3, 4, 3])
output: tensor([31, 38, 33])
output_size: torch.Size([3])
batched matrix x broadcasted vector
一个批量矩阵也就是三维张量和一个一维广播向量的乘积,把多出的两维都提出来,沿这两维分别点积
tensor1 = torch.randint(1, 3, (2, 3, 4))
tensor2 = torch.randint(1, 3, (4,))
output = torch.matmul(tensor1, tensor2)
print("tensor1: {}".format(tensor1))
print("tensor2: {}".format(tensor2))
print("output: {}".format(output))
print("output_size: {}".format(output.size()))
tensor1: tensor([[[1, 2, 2, 2],
[2, 2, 2, 2],
[2, 2, 1, 2]],
[[1, 2, 1, 1],
[2, 1, 1, 1],
[2, 2, 2, 2]]])
tensor2: tensor([2, 2, 1, 2])
output: tensor([[12, 14, 13],
[ 9, 9, 14]])
output_size: torch.Size([2, 3])
batched matrix x batched matrix
批量矩阵三维和三维的乘积,把公共的维度提出来,然后二维矩阵分别乘积
tensor1 = torch.randint(1, 3, (2, 3, 4))
tensor2 = torch.randint(1, 3, (2, 4, 5))
output = torch.matmul(tensor1, tensor2)
print("tensor1: {}".format(tensor1))
print("tensor2: {}".format(tensor2))
print("output: {}".format(output))
print("output_size: {}".format(output.size()))
tensor1: tensor([[[2, 2, 2, 1],
[2, 2, 2, 2],
[2, 1, 1, 2]],
[[1, 2, 2, 1],
[1, 1, 1, 1],
[2, 1, 2, 1]]])
tensor2: tensor([[[1, 1, 1, 1, 2],
[2, 1, 1, 2, 1],
[2, 1, 1, 2, 1],
[1, 1, 1, 1, 1]],
[[1, 2, 2, 2, 2],
[2, 1, 1, 1, 1],
[1, 2, 1, 1, 1],
[1, 1, 1, 2, 2]]])
output: tensor([[[11, 7, 7, 11, 9],
[12, 8, 8, 12, 10],
[ 8, 6, 6, 8, 8]],
[[ 8, 9, 7, 8, 8],
[ 5, 6, 5, 6, 6],
[ 7, 10, 8, 9, 9]]])
output_size: torch.Size([2, 3, 5])
batched matrix x broadcasted matrix
批量矩阵三维和广播矩阵二维,多出来的一维提出来,二维矩阵乘积
tensor1 = torch.randint(1, 3, (2, 3, 4))
tensor2 = torch.randint(1, 3, (4, 5))
output = torch.matmul(tensor1, tensor2)
print("tensor1: {}".format(tensor1))
print("tensor2: {}".format(tensor2))
print("output: {}".format(output))
print("output_size: {}".format(output.size()))
tensor1: tensor([[[2, 1, 2, 1],
[1, 1, 2, 1],
[1, 2, 1, 2]],
[[1, 1, 1, 2],
[1, 2, 2, 1],
[1, 1, 2, 2]]])
tensor2: tensor([[2, 1, 1, 1, 1],
[2, 1, 1, 2, 2],
[1, 1, 1, 1, 1],
[1, 2, 2, 1, 1]])
output: tensor([[[9, 7, 7, 7, 7],
[7, 6, 6, 6, 6],
[9, 8, 8, 8, 8]],
[[7, 7, 7, 6, 6],
[9, 7, 7, 8, 8],
[8, 8, 8, 7, 7]]])
output_size: torch.Size([2, 3, 5])