文章目录
- 1. reshape函数
- 1.1. 功能与用法
- 1.2. 特点
- 2. transpose和permute函数
- 2.1. transpose
- 2.2. permute
- 2.3. 区别
- 3. view和contiguous函数
- 3.1. view
- 3.2. contiguous
- 3.3. 特点
- 4. squeeze和unsqueeze函数
- 4.1. squeeze
- 4.2. unsqueeze
- 5. 应用场景
- 6. 形状操作综合比较
- 7. 最佳实践建议
- 8. 总结
1. reshape函数
1.1. 功能与用法
reshape函数可以改变张量的形状而不改变其数据,新形状的元素总数必须与原张量一致。
import torch
x = torch.arange(6) # tensor([0, 1, 2, 3, 4, 5])
# 改变形状为2x3
y = x.reshape(2, 3)
"""
tensor([[0, 1, 2],
[3, 4, 5]])
"""
# 自动推断维度大小
z = x.reshape(3, -1) # -1表示自动计算该维度大小
"""
tensor([[0, 1],
[2, 3],
[4, 5]])
"""
1.2. 特点
-
不改变原始数据,只改变视图
-
可以处理非连续内存的张量
-
当无法返回视图时会自动复制数据
2. transpose和permute函数
2.1. transpose
交换两个指定维度:
x = torch.randn(2, 3, 4)
# 交换维度0和1
y = x.transpose(0, 1) # 形状变为(3, 2, 4)
# 对于2D张量,transpose相当于矩阵转置
matrix = torch.randn(3, 4)
matrix.T == matrix.transpose(0, 1) # True
2.2. permute
重新排列所有维度:
x = torch.randn(2, 3, 4, 5)
# 重新排列维度顺序
y = x.permute(2, 0, 3, 1) # 新形状(4, 2, 5, 3)
2.3. 区别
- transpose只能交换两个维度
- permute可以任意重新排列所有维度
3. view和contiguous函数
3.1. view
类似于reshape,但要求张量是连续的:
x = torch.arange(6)
# 改变形状
y = x.view(2, 3)
# 会报错的情况
x_non_contiguous = x.t() # 转置后不连续
try:
x_non_contiguous.view(6)
except RuntimeError as e:
print(e) # 需要连续张量
3.2. contiguous
使张量在内存中连续排列:
x = torch.randn(3, 4).transpose(0, 1) # 不连续张量
# 转换为连续张量
x_cont = x.contiguous() # 可能复制数据
# 现在可以使用view
y = x_cont.view(12)
3.3. 特点
-
view比reshape更快,但有限制
-
转置、切片等操作可能导致不连续
-
需要view操作前应检查连续性
4. squeeze和unsqueeze函数
4.1. squeeze
移除所有大小为1的维度:
x = torch.randn(1, 3, 1, 2)
# 移除所有大小为1的维度
y = x.squeeze() # 形状变为(3, 2)
# 只移除指定维度
z = x.squeeze(dim=0) # 形状变为(3, 1, 2)
4.2. unsqueeze
在指定位置增加大小为1的维度:
x = torch.randn(3, 4)
# 在维度0增加一个维度
y = x.unsqueeze(0) # 形状变为(1, 3, 4)
# 在维度1增加一个维度
z = x.unsqueeze(1) # 形状变为(3, 1, 4)
5. 应用场景
-
unsqueeze常用于广播前的维度对齐
-
squeeze常用于移除不必要的单维度
-
神经网络输入/输出经常需要调整维度
6. 形状操作综合比较
操作 | 是否改变数据 | 是否要求连续 | 适用场景 | 性能 |
---|---|---|---|---|
reshape | 否 | 否 | 通用形状改变 | 中 |
view | 否 | 是 | 快速形状改变 | 高 |
transpose | 否 | 否 | 交换两个维度 | 高 |
permute | 否 | 否 | 复杂维度重排 | 中 |
squeeze | 否 | 否 | 移除单维度 | 高 |
unsqueeze | 否 | 否 | 增加单维度 | 高 |
7. 最佳实践建议
-
优先使用view:当确定张量连续时,view性能更好
-
注意连续性:复杂操作后使用is_contiguous()检查
-
维度顺序:保持合理的维度顺序(N,C,H,W等)
-
避免频繁reshape:多次形状改变可能降低性能
-
使用-1推断:合理利用-1自动计算维度大小
# 形状操作典型工作流示例
def prepare_input(data):
# 增加batch维度
data = data.unsqueeze(0)
# 确保内存连续
if not data.is_contiguous():
data = data.contiguous()
# 改变形状为网络输入格式
return data.view(1, -1, data.size(-1))
8. 总结
- reshape:用来改变张量的形状,返回一个新的张量。
- transpose:交换张量的两个维度。
- permute:按指定的维度顺序重新排列张量的所有维度。
- view:用来改变张量的形状,要求张量在内存中是连续的。
- contiguous:确保张量是连续的,可以在需要 view 操作时使用。
- squeeze:去除张量中维度为1的维度。
- unsqueeze:在张量的指定位置添加一个维度。