torch.tensor
和 torch.from_numpy
都是用于在 PyTorch 中创建张量的函数,但它们有一些重要的区别,尤其是在数据类型转换和共享内存方面
1. torch.tensor
torch.tensor
是 PyTorch 中的一个常用函数,用于从数据(如列表、数组等)创建一个新的张量。无论传入的数据是 Python 列表还是 NumPy 数组,torch.tensor
都会 创建一个新的张量,这个张量是原数据的深拷贝。因此,张量和原始数据之间 不共享内存,修改张量的值不会影响原始数据,反之亦然。
示例:
import torch
import numpy as np
# 从列表创建张量
data = [1, 2, 3]
tensor = torch.tensor(data)
# 修改tensor不会影响原始数据
tensor[0] = 0
print(tensor) # 输出: tensor([0, 2, 3])
print(data) # 输出: [1, 2, 3],列表未被修改
特点:
- 深拷贝:创建的新张量和原始数据不共享内存。
- 数据类型转换:如果传入的数据类型与默认的张量数据类型不同(例如:NumPy 数组),
torch.tensor
会自动将数据转换为默认的 PyTorch 张量类型(通常是torch.float32
)。
优势:
- 适合从各种数据类型(Python 列表、元组等)生成张量。
- 由于深拷贝,不会影响原始数据,适合在需要独立管理张量和数据时使用。
2. torch.from_numpy
torch.from_numpy
专门用于将 NumPy 数组 转换为 PyTorch 张量。它与 torch.tensor
的最大区别在于,torch.from_numpy
不会 拷贝数据,而是 共享内存。这意味着生成的张量和原始的 NumPy 数组指向同一个内存区域,修改其中任何一个都会影响另一个。
示例:
import torch
import numpy as np
# 从NumPy数组创建张量
np_array = np.array([1, 2, 3])
tensor = torch.from_numpy(np_array)
# 修改张量会影响原始的NumPy数组
tensor[0] = 0
print(tensor) # 输出: tensor([0, 2, 3])
print(np_array) # 输出: [0 2 3],NumPy数组也被修改了
特点:
- 共享内存:转换生成的张量和原始的 NumPy 数组共享内存,修改其中一个会影响另一个。
- 数据类型限制:
torch.from_numpy
仅支持 NumPy 数组,而且 NumPy 数组的数据类型必须是float32
、float64
、int32
、int64
等 PyTorch 支持的类型。如果数据类型不支持,会报错。
优势:
- 适合需要在 NumPy 和 PyTorch 之间高效转换和共享数据的场景,避免不必要的深拷贝,节省内存。
- 在需要与 NumPy 代码进行集成时非常方便,例如当你需要将 NumPy 数据输入到 PyTorch 模型中训练时。
3. 关键区别
功能 | torch.tensor | torch.from_numpy |
---|---|---|
数据来源 | Python 列表、NumPy 数组等 | 只支持 NumPy 数组 |
内存共享 | 不共享内存(深拷贝) | 共享内存 |
数据类型转换 | 自动将输入数据转换为 PyTorch 的默认数据类型 | 维持 NumPy 数组的数据类型,且只支持部分数据类型 |
适用场景 | 当不需要与原始数据共享内存时,或从列表等创建 | 需要高效地将 NumPy 数据与 PyTorch 张量共享内存时 |
4. 使用场景
1. 使用 torch.tensor
- 深拷贝数据:适合在不希望原始数据(如 NumPy 数组或 Python 列表)被修改的场景下使用。
- 数据类型转换:当你想要自动转换输入数据类型为 PyTorch 默认类型时可以使用。
2. 使用 torch.from_numpy
- 共享内存:适合在内存敏感的应用中使用,因为不会进行数据拷贝,尤其是大数据量的场景下。
- 快速互操作:当你需要在 PyTorch 和 NumPy 之间频繁转换数据时,这种方法可以节省大量时间和内存。
总结
torch.tensor
适用于从任何类型的数据创建独立的张量,并不与原始数据共享内存,适合需要深拷贝和数据隔离的场景。torch.from_numpy
适合从 NumPy 数组快速创建张量,且共享内存,适合在需要高效数据传递和共享内存的场景下使用。