笔者一篇博客PyTorch 的 torch.unbind 函数详解与进阶应用:中英双语中有一个例子如下:
# 创建一个 3x2x2 的三维张量
x = torch.tensor([[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
[[9, 10], [11, 12]]])
# 第一步:沿第 0 维分解为 3 个 2x2 张量
unbind_result = torch.unbind(x, dim=0)
# 第二步:沿第 2 维重新堆叠
stack_result = torch.stack(unbind_result, dim=2)
print("最终结果:", stack_result)
结果
最终结果:
tensor([[[ 1, 5, 9],
[ 3, 7, 11]],
[[ 2, 6, 10],
[ 4, 8, 12]]])
- 使用 torch.unbind 沿第 0 维分解。
- 使用 torch.stack 沿第 2 维重新组合,从而完成了维度转换。
张量的形状在每一步的变化如下:
- 原始张量形状为 [3, 2, 2]。
- 分解后,得到 3 个形状为 [2, 2] 的张量。
- 堆叠时,将这些张量沿新的维度 dim=2 组合,最终形状变为 [2, 2, 3]。
通过这种分解和堆叠方式,我们可以灵活地操作张量的维度和数据布局。
具体是怎么变的,这里记录一下。
这个例子展示了如何通过 torch.unbind
和 torch.stack
动态调整张量的维度顺序。以下是对这个例子的详细解释,包括每一步的操作和张量形状变化:
1. 初始张量
我们先创建一个形状为 [3, 2, 2]
的张量 x
:
x = torch.tensor([[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
[[9, 10], [11, 12]]])
张量的内容:
x = [
[[1, 2], [3, 4]], # 第一个“平面”
[[5, 6], [7, 8]], # 第二个“平面”
[[9, 10], [11, 12]] # 第三个“平面”
]
形状:[3, 2, 2]
这里的含义:
- 第一维度(dim=0,大小为3):有3个“平面”(或者块)。
- 第二维度(dim=1,大小为2):每个“平面”有两行。
- 第三维度(dim=2,大小为2):每行有两个元素。
2. 使用 torch.unbind
沿 dim=0
分解
unbind_result = torch.unbind(x, dim=0)
torch.unbind
的作用是沿着指定的维度(这里是 dim=0
)移除这一维度,并返回一个元组,元组中的每个元素都是输入张量在该维度上的切片。
对于我们的例子:
x
沿着dim=0
分解,相当于把张量按“平面”切开。- 原始的 3×2×2 张量被分成了 3 个形状为
[2, 2]
的子张量。
unbind_result
的内容:
unbind_result = (
tensor([[1, 2], [3, 4]]), # 第一个平面
tensor([[5, 6], [7, 8]]), # 第二个平面
tensor([[9, 10], [11, 12]]) # 第三个平面
)
每个切片都是一个形状为 [2, 2]
的二维张量。
这里的维度变化:
- 原始张量形状
[3, 2, 2]
→ 切片形状[2, 2]
。
3. 使用 torch.stack
沿 dim=2
重新组合
stack_result = torch.stack(unbind_result, dim=2)
torch.stack
的作用是把一组张量沿着新的维度拼接起来。这里:
unbind_result
是一个包含 3 个[2, 2]
张量的元组。- 我们指定
dim=2
,意思是在原始张量的最后一维(第三维)增加一个新的维度来进行拼接。
拼接过程:
- 第一个子张量的每个位置与第二个、第三个子张量的对应位置对齐,按列方向拼接。
- 拼接后,原来
[2, 2]
的子张量变成了[2, 3]
的子张量。
举例说明:
- 原始三个
[2, 2]
的张量:tensor([[1, 2], [3, 4]]) tensor([[5, 6], [7, 8]]) tensor([[9, 10], [11, 12]])
- 沿
dim=2
进行拼接后:[ [[1, 5, 9], [3, 7, 11]], # 第一行拼接 [[2, 6, 10], [4, 8, 12]] # 第二行拼接 ]
最终结果:
stack_result = tensor([
[[ 1, 5, 9], [ 3, 7, 11]],
[[ 2, 6, 10], [ 4, 8, 12]]
])
形状变化:
- 原始张量
[3, 2, 2]
→ 分解后的切片[2, 2]
→ 拼接后的结果[2, 2, 3]
。
4. 形状变化总结
操作 | 张量内容 | 张量形状 |
---|---|---|
初始张量 | x | [3, 2, 2] |
使用 torch.unbind(dim=0) | 3 个 [2, 2] 的子张量 | [2, 2] |
使用 torch.stack(dim=2) | 拼接为一个新的张量 | [2, 2, 3] |
5. 为什么维度顺序调整了?
通过 torch.unbind
和 torch.stack
的组合,实际上我们重新定义了张量的组织方式:
torch.unbind
将dim=0
的维度移除,分解成多个子张量。torch.stack
指定新的维度(这里是dim=2
),将这些子张量拼接为一个新维度,从而实现了维度的重新排列。
最终,我们将原来的“平面”维度(dim=0)转移到了列方向(dim=2),实现了动态调整维度顺序的效果。
6. 总结
torch.unbind
用于移除一个维度并分解张量。torch.stack
用于沿指定的新维度拼接张量。- 两者结合可以灵活调整张量的维度顺序。
这个例子展示了如何从 [3, 2, 2]
变换到 [2, 2, 3]
,过程中分解和拼接操作相辅相成,适用于需要动态调整张量维度的高级场景。
后记
2024年12月12日22点28分于上海,基于GPT4o大模型生成。