torch.utils.data模块主要用于进行数据集处理,是常用的一个包。在构建数据集的过程中经常会用到。要使用data函数必须先导入:
from torch.utils import data
下面介绍几个经常使用到的类。
torch.utils.data.DataLoader
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, *, prefetch_factor=2, persistent_workers=False)
DataLoader构造函数最重要的参数是 dataset,它指示要从中加载数据的数据集对象。PyTorch 支持两种不同类型的数据集——映射式数据集和可迭代式数据集。
映射式数据集是Dataset 子类的实例,它实现了 __getitem__()
和 __len__()
协议,它表示从索引/键值到数据样本的映射。例如,当使用 dataset[idx]
访问此类数据集时,它可以从磁盘上的文件夹中读取第 idx
幅图像及其对应的标签。
可迭代式数据集是IterableDataset 子类的实例,它实现了 __iter__()
协议,并表示数据样本上的可迭代对象。这种类型的数据集特别适合随机读取代价高昂甚至不可能的情况,以及批大小取决于获取的数据的情况。例如,当调用 iter(dataset)
时,此类数据集可以返回从数据库、远程服务器甚至实时生成的日志中读取的数据流。
torch.utils.data.Dataset
表示一个Dataset的抽象类。所有表示键到数据样本映射的数据集都应该继承它。所有子类都应该重写__getitem__()
,支持为给定键获取数据样本。子类还可以选择性地重写__len__()
,许多Sampler实现和DataLoader的默认选项都期望它返回数据集的大小。子类还可以选择性地实现__getitems__()
,以加速批量样本加载。此方法接受批量样本索引列表并返回样本列表。
代码运用示例:
import torch
from torch.utils.data import Dataset, DataLoader
# 自定义数据集
class SimpleDataset(Dataset):
def __init__(self, data, labels):
"""
Args:
data (list or tensor): 输入数据
labels (list or tensor): 数据对应的标签
"""
self.data = torch.tensor(data, dtype=torch.float32) # 转为张量
self.labels = torch.tensor(labels, dtype=torch.long) # 转为张量
def __len__(self):
"""返回数据集的大小"""
return len(self.data)
def __getitem__(self, idx):
"""根据索引返回一个样本"""
return self.data[idx], self.labels[idx]
# 创建数据和标签
data = [1, 2, 3, 4, 5]
labels = [0, 1, 0, 1, 0]
# 实例化数据集
dataset = SimpleDataset(data, labels)
# 用 DataLoader 加载数据
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 遍历 DataLoader
for batch_data, batch_labels in dataloader:
print("Data:", batch_data)
print("Labels:", batch_labels)
运行结果:(顺序会随着Shuffle=True发生变化)
torch.utils.data.IterableDataset
一个可迭代的数据集。所有表示数据样本可迭代的数据集都应该继承它。当数据来自流时,这种形式的数据集特别有用。所有子类都应该重写__iter__()
,它将返回此数据集中样本的迭代器。当子类与DataLoader一起使用时,数据集中的每个项目都将从DataLoader迭代器中产生。当num_workers > 0
时,每个工作进程将拥有数据集对象的副本,因此通常希望独立配置每个副本以避免工作进程返回重复的数据。get_worker_info()在工作进程中调用时,返回有关工作进程的信息。它可以在数据集的__iter__()
方法或DataLoader的worker_init_fn
选项中使用来修改每个副本的行为。
代码运用示例:
import torch
from torch.utils.data import IterableDataset, DataLoader
# 自定义 IterableDataset
class NumberStreamDataset(IterableDataset):
def __init__(self, start, end):
"""
Args:
start (int): 起始值
end (int): 结束值
"""
self.start = start
self.end = end
def __iter__(self):
"""
定义数据生成逻辑,返回一个迭代器
"""
for num in range(self.start, self.end):
yield num
# 创建一个数据集实例
dataset = NumberStreamDataset(start=0, end=10)
# 用 DataLoader 加载数据
dataloader = DataLoader(dataset, batch_size=3)
# 遍历 DataLoader
for batch in dataloader:
print(batch)
运行结果:
torch.utils.data.TensorDataset(*tensors)
包装张量的数据集。每个样本将通过沿第一个维度索引张量来检索。参数*tensors (张量)表示第一个维度大小相同的张量。
代码运用示例:
import torch
from torch.utils.data import TensorDataset, DataLoader
# 创建输入张量和标签张量
data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
labels = torch.tensor([0, 1, 0, 1])
# 使用 TensorDataset 封装数据
dataset = TensorDataset(data, labels)
# 使用 DataLoader 加载数据
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 遍历 DataLoader
for batch_data, batch_labels in dataloader:
print("Batch data:", batch_data)
print("Batch labels:", batch_labels)
运行结果:(顺序会随着shuffle=True而发生变化)
torch.utils.data.ConcatDataset(datasets)
将多个数据集连接起来的数据集。此类用于组装不同的现有数据集。参数datasets (序列) 表示要连接的数据集列表
代码运用示例:
import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader
# 创建两个数据集
data1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
labels1 = torch.tensor([0, 1])
dataset1 = TensorDataset(data1, labels1)
data2 = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
labels2 = torch.tensor([1, 0])
dataset2 = TensorDataset(data2, labels2)
# 使用 ConcatDataset 拼接两个数据集
concat_dataset = ConcatDataset([dataset1, dataset2])
# 用 DataLoader 加载数据
dataloader = DataLoader(concat_dataset, batch_size=2, shuffle=True)
# 遍历 DataLoader
for batch_data, batch_labels in dataloader:
print("Batch data:", batch_data)
print("Batch labels:", batch_labels)
运行结果:
torch.utils.data.Subset(dataset, indices)
指定索引处数据集的子集。参数dataset (Dataset)表示整个数据集,indices (序列) – 为子集选择的整个集合中的索引。
代码运用示例:
import torch
from torch.utils.data import TensorDataset, Subset, DataLoader
# 创建一个原始数据集
data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
labels = torch.tensor([0, 1, 0, 1])
dataset = TensorDataset(data, labels)
# 使用 Subset 提取索引为 [1, 3] 的样本
indices = [1, 3]
subset = Subset(dataset, indices)
# 用 DataLoader 加载子集
dataloader = DataLoader(subset, batch_size=1)
# 遍历 DataLoader
for batch_data, batch_labels in dataloader:
print("Batch data:", batch_data)
print("Batch labels:", batch_labels)
运行结果: