在 PyTorch(一个流行的深度学习框架)中,
reshape
和 view
都是用于改变张量(tensor)形状的方法,但它们在实现方式和使用上有一些区别。下面是它们之间的主要区别:
-
实现方式:
reshape
:reshape
方法创建一个新的张量,其元素与原始张量共享内存空间。这意味着改变形状后,原始张量和新的张量将共享相同的数据存储,所以在一个张量上的修改会影响到另一个张量。view
:view
方法并不会创建一个新的张量,而是返回一个与原始张量共享数据存储的新视图(view)。如果原始张量和新的视图张量上的元素被修改,它们会互相影响,因为它们共享相同的数据。
-
支持条件:
reshape
: 可以用于任意形状的变换,但需要确保变换前后元素总数保持一致,否则会抛出错误。view
: 只能用于支持大小相同的变换,也就是变换前后元素总数必须保持不变。这是因为view
并不改变数据的存储,所以必须保持数据总量不变,否则会抛出错误。
-
内存连续性:
reshape
: 不保证新张量在内存中的连续性,即可能导致新张量的元素在内存中的存储顺序与原始张量不同。view
: 如果原始张量在内存中是连续存储的,那么新视图张量也会保持连续性,否则会返回一个不连续的张量。
-
是否支持自动计算维度:
reshape
: 可以通过将某个维度指定为-1,让 PyTorch 自动计算该维度的大小。view
: 不支持将维度指定为-1,需要手动计算新视图张量的大小。当对不连续的张量进行形状变换时,PyTorch 会自动将其复制为连续的张量,这可能会导致额外的内存开销。为了避免这种情况,你可以使用contiguous()
方法将张量变为连续的。例如:x.contiguous().view(3, 4)
。
import torch # 原始张量 x = torch.arange(12) # 使用 reshape x_reshaped = x.reshape(3, 4) # 创建一个新的形状为(3, 4)的张量 x_reshaped[0, 0] = 100 # 修改新张量的元素会影响到原始张量 print(x) # 输出 tensor([100, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) print(x_reshaped) # 输出 tensor([[100, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) # 使用 view x_viewed = x.view(3, 4) # 创建一个新的形状为(3, 4)的张量视图 x_viewed[0, 1] = 200 # 修改视图张量的元素会影响到原始张量 print(x) # 输出 tensor([100, 200, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) print(x_viewed) # 输出 tensor([[100, 200, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) # 使用 view 自动计算维度大小 x_auto_viewed = x.view(3, -1) # 可以将某个维度指定为-1,让 PyTorch 自动计算大小 print(x_auto_viewed) # 输出 tensor([[100, 200, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]]) # 由于 x_auto_viewed 是连续的,所以修改它也会影响原始张量 x x_auto_viewed[2, 2] = 300 print(x) # 输出 tensor([100, 200, 2, 3, 4, 5, 6, 7, 8, 9, 300, 11])