张量
张量(Tensor)是深度学习和PyTorch中的核心概念之一,是标量、向量、矩阵在高维空间中的推广。在PyTorch中,张量用于表示和操作数据。以下是张量的基本类型、创建方法、类型转换、数值计算、索引操作、形状操作以及自动微分模块的详细介绍。
基本类型
-
0维张量:标量(scalar)
scalar = torch.tensor(7)
-
1维张量:向量(vector)
vector = torch.tensor([7, 7])
-
2维张量:矩阵(matrix)
MATRIX = torch.tensor([[7, 8], [9, 10]])
-
多维张量
TENSOR = torch.tensor([[[1, 2, 3], [3, 6, 9], [2, 4, 5]]])
张量创建
-
根据指定数据创建张量
data = torch.tensor(10) data = np.random.randn(2, 3) data = torch.tensor(data) data = [[10., 20., 30.], [40., 50., 60.]] data = torch.tensor(data)
-
根据形状创建张量
data = torch.Tensor(2, 3) data = torch.Tensor([10]) data = torch.Tensor([10, 20])
注意:如果传递列表,则创建包含指定元素的张量。
-
创建指定类型的张量
data = torch.ShortTensor() # int16 data = torch.IntTensor() # int32 data = torch.LongTensor() # int64 data = torch.FloatTensor() # float32 data = torch.DoubleTensor() # float64
-
创建线性和随机张量
# 线性张量 data = torch.arange(start, end, step) data = torch.linspace(start, end, steps) # 随机张量 data = torch.randn(n, m)
-
设置随机种子
torch.random.manual_seed(n)
-
创建0、1或指定值
data = torch.zeros(n, m) data = torch.ones(n, m) data = torch.full([n, m], 值)
张量元素类型转换
- 使用
.type()
或简写方法data = data.type(torch.ShortTensor) # int16 data = data.type(torch.IntTensor) # int32 data = data.type(torch.LongTensor) # int64 data = data.type(torch.FloatTensor) # float32 data = data.type(torch.DoubleTensor) # float64 data = data.short() # int16 data = data.int() # int32 data = data.long() # int64 data = data.float() # float32 data = data.double() # float64
张量类型转换
-
张量转换为NumPy数组
data_n = data_t.numpy() data_n = data_t.numpy().copy() # 避免共享内存
-
NumPy数组转换为张量
data_t = torch.from_numpy(data_n) data_t = torch.from_numpy(data_n.copy()) # 避免共享内存 data_t = torch.tensor(data_n) # 默认不共享内存
-
使用
item()
函数提取标量元素value = T对象.item()
张量数值计算
-
基本运算
# 加、减、乘、除、取负号 data = torch.add(data1, data2) data = torch.sub(data1, data2) data = torch.mul(data1, data2) data = torch.div(data1, data2) data = torch.neg(data) # 就地操作(带下划线的版本) data1.add_(data2)
-
点乘运算
data = torch.mul(data1, data2) data = data1 * data2
-
矩阵乘法运算
data = data1 @ data2 data = torch.matmul(data1, data2)
张量运算函数
-
平均值
mean_value = T对象.mean(dim=0) # 按列计算 mean_value = T对象.mean(dim=1) # 按行计算
-
总和
sum_value = T对象.sum(dim=0) sum_value = T对象.sum(dim=1)
-
平方根、幂次方、指数和对数
sqrt_value = T对象.sqrt() pow_value = T对象.pow(n) exp_value = T对象.exp() log_value = T对象.log() log2_value = T对象.log2() log10_value = T对象.log10()
张量索引操作
-
简单操作
row_data = T对象[行] col_data = T对象[:, 列]
-
列表索引
data = T对象[[0, 1], [2, 3]] data = T对象[[[0], [1]], [1, 2]]
-
范围索引
data = T对象[起始:结束:步长, 起始:结束:步长] data = T对象[[[0],[2]],[0,1]] data = T对象[0:3:2, :2]
-
布尔索引
data = T对象[:, T对象[2] > 5]
-
多维索引
data = T对象[0, :, :] data = T对象[:, 0, :] data = T对象[:, :, 0] data = T对象[:, 1:3, 2]
张量形状操作
-
获取张量形状
shape = T对象.shape shape = T对象.size()
-
修改形状
data = data.reshape(n, m) data = data.view(n, m) # 仅用于整块内存中的张量
-
内存问题
is_contiguous = T对象.is_contiguous() contiguous_data = T对象.contiguous()
-
改变维度
data = data.reshape(新形状)
-
堆叠
new_T = torch.stack([T1, T2, T3], dim=0)
-
升维和降维
data = torch.unsqueeze(data, dim=n) data = torch.squeeze(data, dim=n)
-
维度交换
data = T对象.permute(dims) data = torch.transpose(data, dim1, dim2)
张量拼接操作
- 拼接
data = torch.cat([data1, data2], dim=n)
张量自动微分模块
-
反向传播
data.backward()
-
获取梯度
grad_value = data.grad