官方文档:https://pytorch.org/docs/stable/generated/torch.split.html?highlight=split
torch.split(tensor, split_size_or_sections, dim=0)
Splits the tensor into chunks. Each chunk is a view of the original tensor.
If
split_size_or_sections
is an integer type, then tensor will be split into equally sized chunks (if possible). Last chunk will be smaller if the tensor size along the given dimensiondim
is not divisible bysplit_size
.If
split_size_or_sections
is a list, then tensor will be split intolen(split_size_or_sections)
chunks with sizes indim
according tosplit_size_or_sections
.将张量拆分为块。每个块都是原始张量的一个视图。
如果split_size_or_sections是整型,那么张量将被拆分为大小相等的块(如果可能的话)。如果沿着给定维度dim的张量大小不能被split_size整除,则最后一个块将更小。
如果split_size_or_sections是一个列表,那么张量将被拆分为len(split_size _or_section)块,其大小根据split_sze_or_secttions为dim。
参数:
-
tensor (Tensor) – tensor to split.需要分裂的tensor
-
split_size_or_sections (int) or (list(int)) – size of a single chunk or list of sizes for each chunk单个块的大小
-
dim (int) – dim默为0,即按行分类;dim=1按列分裂
返回类型:List[Tensor]
import torch
a = torch.arange(10).reshape(5, 2)
print(a)
torch.split(a, 2)
torch.split(a, [1, 4])
torch.split(a, 1)
torch.split(a, [3,2])
# dim=1的时候,按列分裂
a = torch.arange(10).reshape(2, 5)
print(a)
torch.split(a, [3,2],1)
结果: