很多文章都是从 D a t a s e t Dataset Dataset等对象自下网上进行介绍的,但是对于初学者而言,其实这并不好理解,因为有时候,会不自觉的陷入到一些细枝末节中去,而不能把握重点,所以本文将自上而下的对 P y t o r c h Pytorch Pytorch数据读取方法进行介绍。
自上而下理解三者关系
首先,我们看一下 D a t a L o a d e r . n e x t DataLoader.next DataLoader.next的源代码长什么样,为方便理解,我只选取了num_works为0的情况,(num_works)简单理解都是能够并行化读取数据。
def __next__(self):
if self.num_workers == 0: # same-process loading
indices = next(self.sample_iter) # may raise StopIteration
batch = self.collate_fn([self.dataset[i] for i in indices])
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
在阅读上面代码时候,我们可以假设,我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取的数据就只需要对应index即可,即上面代码中的
i
n
d
i
c
e
s
indices
indices,而选取index的方式有多种:有按顺序的,也有乱序的,所以这个工作需要
S
a
m
p
l
e
r
Sampler
Sampler来完成,现在你不需要具体的细节,后面会介绍,只需要了解
D
a
t
a
L
o
a
d
e
r
DataLoader
DataLoader和
S
a
m
p
l
e
r
Sampler
Sampler在这里产生关系.
那么
D
a
t
a
s
e
t
Dataset
Dataset和
D
a
t
a
L
o
a
d
e
r
DataLoader
DataLoader在什么时候产生关系呢?没错就是下面一行,我们已经拿到了
i
n
d
i
c
e
s
indices
indices,那么下一步,我们只需要根据
i
n
d
i
c
e
s
indices
indices对数据进行读取即可.
在下面 i f if if语句的作用都是,如果 p i n m e m o r y = T r u e , pin_memory=True, pinmemory=True,,那么 P y t o r c h Pytorch Pytorch会采用一系列操作把数据拷贝到GPU中,总之为了加速.
综上,可以了解DataLoader Sampler和Dataset三者关系如下:
在阅读后文中,始终需要将上面的关系记在心里,这样能帮助你更好的理解
Sampler
参数传递
class DataLoader(object):
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
batch_sampler=None, num_workers=0, collate_fn=default_collate,
pin_memory=False, drop_last=False, timeout=0,
worker_init_fn=None)
要更加细致的理解
S
a
m
p
l
e
r
Sampler
Sampler原理,我们需要先阅读以下
D
a
t
a
L
o
a
d
e
r
DataLoader
DataLoader的源代码 如下:
可以看到初始化参数有两种
S
a
m
p
l
e
r
Sampler
Sampler : Sampler和batch_sampler
都默认为None,前者作用是生成一系列
i
n
d
e
x
index
index,而batch_sampler则是将sampler生成indices打包分组,得到一个又一个batch的index,例如,下面所示示例:
Batchsampler将
S
e
q
u
e
n
t
i
a
l
S
a
m
p
l
e
r
SequentialSampler
SequentialSampler,生成的index按照指定的batchsize分组.
>>>in : list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
>>>out: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
pyTorch已经实现的sampler有以下几种
-
SequentialSampler
-
RandomSampler
-
WeightedSampler
-
SubsetRandomSampler
需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读理解源码更深刻的理解,这里只做总结: -
源码
- 如果自定义batch_sampler,那么这些参数都必须使用默认值:batch_size Shuffle sampler drop_last.
- 如果自定义了sampler :那么shuffle需要设置为false
- 如果sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
- 若shuffle = True时,则sampler=RandomSampler(dataset)
- 若shuffle = False时,则sampler=SequentialSampler(dataset)
如果定义sampler和BatchSampler
仔细查看源代码可以发现,所有采样器其实都继承同一个父类,即 S a m p l e r Sampler Sampler,其代码定义如下:
class Sampler(object):
r"""Base class for all Samplers.
Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
way to iterate over indices of dataset elements, and a :meth:`__len__` method
that returns the length of the returned iterators.
.. note:: The :meth:`__len__` method isn't strictly required by
:class:`~torch.utils.data.DataLoader`, but is expected in any
calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
"""
def __init__(self, data_source):
pass
def __iter__(self):
raise NotImplementedError
def __len__(self):
return len(self.data_source)
所以,你要做好的都是定义好__iter__(self) 函数,不过要注意的是该函数的返回值需要是可迭代的,例如
S
e
q
u
e
n
t
i
a
l
S
a
m
p
l
e
r
SequentialSampler
SequentialSampler返回的是:
iter(range(len(self.data_source)))
另外
B
a
t
c
h
S
a
m
p
l
e
r
BatchSampler
BatchSampler与其他
S
a
m
p
l
e
r
Sampler
Sampler的主要区别是其需要将
S
a
m
p
l
e
r
Sampler
Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表,也就是说后面读取数据的过程中都是
b
a
t
c
h
s
a
m
p
l
e
r
batch sampler
batchsampler.
Dataset
定义如下
class Dataset(object):
def __init__(self):
...
def __getitem__(self, index):
return ...
def __len__(self):
return ...
上面三个方法最基本的,其中__getitem__是最主要的方法,其规定了如何读取数据,但是其又不同于一般的方法,因为它是 p y t h o n b u i l t − i n python built-in pythonbuilt−in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问,加入你定义好一个dataset,那么可以直接通过dataset[0]来访问第一个数据,在之前,我一值没弄清__getitem__是什么作用,所以一值不知道该怎么进入这个函数进行调试,现在如果你想对__getitem__方法进行调试,可以写一个for循环遍历dataset来进行调试,而不用构建dataloader等一大堆东西啦,建议学会使用ipdb这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:
class DataLoader(object):
...
def __next__(self):
if self.num_workers == 0:
indices = next(self.sample_iter)
batch = self.collate_fn([self.dataset[i] for i in indices]) # this line
if self.pin_memory:
batch = _utils.pin_memory.pin_memory_batch(batch)
return batch
我们仔细可以发现,前面有一个self.collate_fn方法,这个是干嘛用的呢?在介绍前,我们需要知道每个参数的含义:
- indices: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表
- self.dataset[i] 这里对第i个数据进行读取操作.
一般来说:self.dataset[i]=(img, label)
我们不难猜出,collate_fn的作用就是将一个batch的数据进行合并的操作,默认的是collate_fn是将img和label分别合并成 i m g s imgs imgs和 l a b e l s labels labels,所以,如果你的__getitem__方法只是返回img,label.那么你可以使用默认的collate_fn方法,但是如果你每次读取的数据有img, box, label等等,那么你就需要自定义collate_fn来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。
自己理解
DataLoader Dataset和Sampler之间的关系
- Sampler产生对数据进行采样
- Dataset:产生数据
- DataLoader将数据迭代产生batch_size数据格式.
总结
会自己看源代码,根据源代码了解,这里只是做总结
慢慢的将各种数据之间的关系都搞明白,全部都将其搞透彻.