小白学Pytorch 系列–Torch API
Torch version 1.13
Tensors
TORCH.IS_TENSOR
如果obj是PyTorch张量,则返回True。
注意,这个函数只是简单地执行isinstance(obj, Tensor)
。使用isinstance
更适合用mypy进行类型检查,而且更显式-所以建议使用它而不是is_tensor
。
obj (Object) – Object to test
True if obj is a PyTorch tensor.
import torch
x = torch.tensor([1,2,3])
is_t = torch.is_tensor(x)
print(is_t)
TORCH.IS_STORAGE
判断是否是存储对象
x = torch.tensor([1,2,3])
is_s = torch.is_storage(x)
print(is_s)
TORCH.IS_COMPLEX
此方法的意思是如果输入是一个复数数据类型(例如torch.complex64或者 torch.complex128)就返回True,否则返回False。
import torch
a = torch.tensor([1, 2], dtype=torch.float32)
b = torch.tensor([3, 4], dtype=torch.float32)
z = torch.complex(a, b)
print(z)
print(z.dytpe)
print(torch.is_complex(z))
TORCH.IS_CONJ
如果输入是一个共轭张量,即它的共轭位被设置为True,则返回True。
x = torch.tensor([-1 + 1j, -2 + 2j, 3 - 3j])
x.is_conj()
y = torch.conj(x)
tensor([-1.-1.j, -2.-2.j, 3.+3.j])
y.is_conj()
TORCH.IS_FLOATING_POINT
判断给定的input中data的值是不是浮点类型
import torch
a = torch.tensor([1, 2], dtype=torch.float16)
torch.is_floating_point(a)
TORCH.IS_NONZERO
判断一个标量是不是为0, 不能使用多维度张量。
torch.is_nonzero(torch.tensor([0.]))
torch.is_nonzero(torch.tensor([1.5]))
torch.is_nonzero(torch.tensor([False]))
torch.is_nonzero(torch.tensor([3]))
torch.is_nonzero(torch.tensor([1, 3, 5]))
torch.is_nonzero(torch.tensor([]))
TORCH.SET_DEFAULT_DTYPE
设置pytorch中浮点数的默认类型。pytorch中有很多浮点类型,例如torch.float16、torch.float32、torch.float64这些在初始化一个浮点tensor的时候是可以指定的,如果我们不指定,那么pytorch就是默认给其一个类型,此方法的作用就是指定pytorch默认给不指定浮点类型的浮点数哪个类型。
torch.set_default_dtype(torch.float64)
set_default_tensor_type
设置pytorch中张量的默认类型
torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32
torch.set_default_tensor_type(torch.DoubleTensor)
torch.tensor([1.2, 3]).dtype # a new floating point tensor
numel
返回input中的元素个数
a = torch.randn(1, 2, 3, 4, 5)
torch.numel(a)
a = torch.zeros(4,4)
torch.numel(a)
set_printoptions
打印时显示浮点tensor中元素的精度(显示到小数点后几位),默认是4
# Limit the precision of elements
torch.set_printoptions(precision=2)
torch.tensor([1.12345])
# Limit the number of elements shown
torch.set_printoptions(threshold=5)
torch.arange(10)
# Restore defaults
torch.set_printoptions(profile='default')
torch.tensor([1.12345])
torch.arange(10)
set_flush_denormal
使CPU上不正规的浮点数失效。
torch.set_flush_denormal(True)
torch.tensor([1e-323], dtype=torch.float64)
torch.set_flush_denormal(False)
torch.tensor([1e-323], dtype=torch.float64)