在 PyTorch 中,张量的形状操作是非常重要的,可以让你灵活地调整和处理张量的维度和数据结构。以下是一些常用的张量形状函数及其用法,带有详细解释和举例说明:
1. reshape()
功能: 改变张量的形状,但不改变数据的顺序。
语法: tensor.reshape(*shape)
示例:
import torch
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
x_reshaped = x.reshape(3, 2)
print(x_reshaped)
输出:
tensor([[1, 2], [3, 4], [5, 6]])
在这个例子中,张量 x
被从形状 (2, 3)
重塑为 (3, 2)
。
2. squeeze()
功能: 去除张量中大小为1的维度(例如,形状是 (1, 3, 1)
会变成 (3)
)。
语法: tensor.squeeze(dim=None)
示例:
x = torch.tensor([[[1, 2, 3]]]) # shape: (1, 1, 3)
x_squeezed = x.squeeze()
print(x_squeezed)
输出:
tensor([1, 2, 3])
在这个例子中,squeeze()
去除了前两个大小为1的维度。
3. unsqueeze()
功能: 在指定的维度插入大小为1的新维度。
语法: tensor.unsqueeze(dim)
示例:
x = torch.tensor([1, 2, 3]) # shape: (3,)
x_unsqueezed = x.unsqueeze(0) # 插入新的0维
print(x_unsqueezed.shape) # 输出: torch.Size([1, 3])
在这个例子中,
unsqueeze(0)
在第0个维度插入一个新的大小为1的维度,将形状从 (3,)
变成 (1, 3)
。
4. transpose()
功能: 交换张量的两个维度。
语法: tensor.transpose(dim0, dim1)
示例:
x = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: (2, 3)
x_transposed = x.transpose(0, 1)
print(x_transposed)
输出:
tensor([[1, 4], [2, 5], [3, 6]])
在这个例子中,transpose(0, 1)
交换了维度0和维度1,使张量的形状从 (2, 3)
变成 (3, 2)
。
5. permute()
功能: 改变张量的维度顺序,允许对多个维度进行交换。
语法: tensor.permute(*dims)
示例:
x = torch.randn(2, 3, 5) # shape: (2, 3, 5)
x_permuted = x.permute(2, 0, 1)
print(x_permuted.shape) # 输出: torch.Size([5, 2, 3])
在这个例子中,permute(2, 0, 1)
重新排列了维度顺序,
使得形状从 (2, 3, 5)
变为 (5, 2, 3)
。
6. view()
功能: 类似于 reshape()
,但是 view()
需要张量在内存中是连续的。
语法: tensor.view(*shape)
示例:
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
x_viewed = x.view(3, 2)
print(x_viewed)
输出:
tensor([[1, 2], [3, 4], [5, 6]])
view()
的使用需要张量是连续的,否则会报错。
关于连续性,可以结合 contiguous()
使用。
7. contiguous()
功能: 将非连续的张量转换为在内存中连续存储的张量。
语法: tensor.contiguous()
示例:
x = torch.randn(2, 3, 5)
x_permuted = x.permute(2, 0, 1) # 这使得张量不再连续
x_contiguous = x_permuted.contiguous().view(5, 6) # 转换为连续后再进行view操作
print(x_contiguous.shape)
permute()
操作后的张量不一定是连续的,因此需要 contiguous()
来保证可以使用 view()
。
8. expand()
和 repeat()
功能: 扩展张量到更高的维度。
-
expand()
只是广播,不复制内存。 -
repeat()
会实际复制数据。
示例:
x = torch.tensor([1, 2, 3])
x_expanded = x.expand(3, 3) # 广播
x_repeated = x.repeat(3, 1) # 重复数据
print(x_expanded)
print(x_repeated)
输出:
tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]]) tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
区别在于 expand()
不会占用更多内存,而 repeat()
会真正复制数据。
总结
上述这些张量操作函数在处理多维数据时非常有用,能够灵活地调整和转换张量的形状,以便进行各种操作和模型设计。