一般
P
y
T
o
r
c
h
PyTorch
PyTorch加载数据的固定格式是:
dataset = MyDataset() : 构建
D
a
t
a
s
e
t
Dataset
Dataset对象
dataLoader = DataLoader(dataset) #通过
D
a
t
a
L
o
a
d
e
r
DataLoader
DataLoader来构造迭代对象.
num_epoches = 100
for epoch in range(num_epoches): #逐步迭代数据
for img,label in dataLoader:
#训练代码
但是小样本有episode这个概念,所以需要额外用一个
s
a
m
p
l
e
r
sampler
sampler,写篇文章记录下原型网络是怎么加载数据哒.
episode有分 s u p p o r t support support集和 q u e r y query query集, 如果 s u p p o r t support support有N个类,每个类有 N N N个样本,我们就叫做 N w a y − K s h o t Nway-Kshot Nway−Kshot,另外,在 q u e r y query query集每个样类中有 Q Q Q个样本,注意着 Q Q Q个样本和 S u p p o r t Support Support集中 K K K个样本没有重复样本.
DataLoader是怎么获取数据的?
当我们使用下面代码获取一个batch数据时候,
S
a
m
p
l
e
r
Sampler
Sampler会先产生64个下标,然后DataLoader会根据这64个下标获取数据,最后封装成一个
t
e
n
s
o
r
tensor
tensor,返回给small_data
from torch.utils.data import DataLoader:
train_dataloader = DataLoader(training_data,batch_size = 64,shuffle = True)
from small_data in train_dataloader:
print(data.shape)
原型代码中的dataloader simpler和dataset分别对应于:
dataset:
omniglot.py
ds = TransformDataset(ListDataset(class_names), transforms)
Sampler
base.py
EpisodicBatchSampler
- dataLoader
omniglot.py
torch.utils.data.DataLoader(ds, batch_sampler=sampler, num_workers=0)
加载数据产生batch_size
Nway -Kshot-Qquery的episode
dataset
在原型代码中使用omnigolot数据集训练,在图片外面有两个文件夹,第一层是1种类,第二层是2编号,注意源来的
o
m
i
g
l
o
t
omiglot
omiglot数据集有区分:
images_background.zip和images_background.zip
代码作者分别解压之后,合并到了omniglot/data文件夹下,详细可以看
官方代码
download_omniglot.sh文件。
在官方代码中有对omiglot进行分割的txt文件,如下图所示:
可以在想要的训练集,验证集和测试集时分别读取不同的txt文件完成,以train为例,
t
r
a
i
n
.
t
x
t
train.txt
train.txt的第一行为:
Angelic/character01/rot000
分别代表了1种类/编号/旋转角度.
ll5行,首先读取train.txt文件的所有行,将其存放至容器class_names = []中.
然后在119行
ds = TransformDataset(ListDataset(class_names), transforms)
查看官方文档和教程产生List形式数据
可以发现就是在获得ListDataset(class_names)中的任意一项数据时,
都会对其进行transform变换,transforms变换
内容如下:
trainsformers = [partial(convert_dict,‘class’),
load_class_images,
partial(extract_episode, n_support, n_query)]
其中convert_dict,load_class_images,extract_episode.均为代码作者自定义函数,而
p
a
r
t
i
a
l
partial
partial关键字就是在调用函数的同时,有几个参数固定为所给定的值,以partial(convert_dict,‘class’),就等同于将convert_dict函数从
def convert_dict(k, v):
return { k: v }
转变为:
def convert_dict(v):
# 此时已经不用传k的值了,因为 partial(convert_dict, 'class') 已经给k 赋值了'class'
return { k: v }
如果有functional(a,b,c,d,e,f,g),则:
- partial(convert_dict, ‘value1’) 就已经给参数a传递了 v a l u e 1 value1 value1
- partial(convert_dict, ‘value1’, ‘value2’) 就已经给参数a传了value1、参数b传了value2
- partial(convert_dict, ‘value1’, ‘value2’ ,‘value3’) 就已经给参数a传了value1、参数b传了value2、参数c传了value3
等等等>
以Avesta/character12/rot890为例,我们来看transforme最终能取得什么效果,在原型代码中设置 e p i s o d e episode episode,为 5 w a y − 5 s h o t − 15 q u e r y 5way-5shot-15query 5way−5shot−15query
为了区分两个5, 我们在这以 10 w a y − 5 s h o t − 15 q u e r y 10way-5shot-15query 10way−5shot−15query
进行举例子.
transforms = [partial(convert_dict, ‘class’),
load_class_images,
partial(extract_episode, 5, 15)]
首先把Avesta/character/rot090传递给参数partial(convert_dict, ‘class’)
def convert_dict(k, v):
return { k: v }
返回结果为:{ ‘class’ :“Avesta/character12/rot090” }
然后把{ ‘class’ :“Avesta/character12/rot090” }传进load_class_images
def load_class_images(d): # { 'class' :"Avesta/character12/rot090" }
if d['class'] not in OMNIGLOT_CACHE:
alphabet, character, rot = d['class'].split('/')
# 值分别为 Avesta character12 rot090
image_dir = os.path.join(OMNIGLOT_DATA_DIR, 'data', alphabet, character)
# OMNIGLOT_DATA_DIR是omniglot的根路径,此句就是为了拼接出Avesta/character12的路径
class_images = sorted(glob.glob(os.path.join(image_dir, '*.png')))
# 获得路径下以png结尾的文件路径,然后排序,这是一个列表
if len(class_images) == 0:
raise Exception("No images found for omniglot class {} at {}. Did you run download_omniglot.sh first?".format(d['class'], image_dir))
image_ds = TransformDataset(ListDataset(class_images), # 这个同上文讲过的,不再赘述
compose([partial(convert_dict, 'file_name'),
partial(load_image_path, 'file_name', 'data'),
partial(rotate_image, 'data', float(rot[3:])),
partial(scale_image, 'data', 28, 28),
partial(convert_tensor, 'data')]))
loader = torch.utils.data.DataLoader(image_ds, batch_size=len(image_ds), shuffle=False)
# 将全部数据封装成一个batch,作为一个episode中的一个类
for sample in loader:
OMNIGLOT_CACHE[d['class']] = sample['data']
break # only need one sample because batch size equal to dataset length
return { 'class': d['class'], 'data': OMNIGLOT_CACHE[d['class']] }
返回结果为:
```python
{ 'class' :"Avesta/character12/rot090" ,
'data': size为(20,1,28,28)的一个tensor
}
注意这里的(20,1,28,28)是固定的,因为一个文件夹下面只有20张
28
∗
28
28*28
28∗28的黑白图片
然后在把上面结果传递给extract_episode:注意
10
w
a
y
−
5
s
h
o
t
−
15
q
u
e
r
y
10way-5shot-15query
10way−5shot−15query进行举例
所以每个类有5个作为
s
u
p
p
o
r
t
support
support 15个作为
q
u
e
r
y
query
query
def extract_episode(n_support, n_query, d):
# data: N x C x H x W
n_examples = d['data'].size(0) # 20
if n_query == -1:
n_query = n_examples - n_support
example_inds = torch.randperm(n_examples)[:(n_support+n_query)]
# 从20个样本中 选取5+15个
support_inds = example_inds[:n_support]
# 从中选5个作为support
query_inds = example_inds[n_support:]
# 剩下的20-5个作为query
# 根据下标加载数据
xs = d['data'][support_inds]
xq = d['data'][query_inds]
return {
'class': d['class'],
'xs': xs,
'xq': xq
}
返回结果为
{
'class': "Avesta/character12/rot090" ,
'xs': size为(5,1,28,28)的一个tensor ,
'xq': size为(15,1,28,28)的一个tensor
}
Simpler
根据 E p i s o d i c B a t c h S a m p l e r 的定义和调用 EpisodicBatchSampler的定义和调用 EpisodicBatchSampler的定义和调用,可以根据需求生成类的下标:
class EpisodicBatchSampler(object):
def __init__(self, n_classes, n_way, n_episodes):
self.n_classes = n_classes
self.n_way = n_way
self.n_episodes = n_episodes
def __len__(self):
return self.n_episodes
def __iter__(self):
for i in range(self.n_episodes):
yield torch.randperm(self.n_classes)[:self.n_way]
sampler = EpisodicBatchSampler(len(ds), n_way, n_episodes)
此时的len(ds)等于train.txt行数,
解释
此时的 l e n ( d s ) len(ds) len(ds)等于train.txt行数,也就是4112行,
- n_way等于10,.
- n_episodes等于100
- 就是一个epoch里面有100个episode.
- 每个episode有10个类,从采样定义,我们可以发现,认为字符和转换某个角度(90、180、270)后的字符,是不i一样的,可以看作是一种数据增强吧,另外 S a m p l e r Sampler Sampler的代码显示其返回的list有10个下标,
dataloader
然后 d a t a l o a d e r dataloader dataloader根据它提供的10个下标,去 d a t a s e t dataset dataset找对应下标的数据.
{
'class': "Avesta/character12/rot090" ,
'xs': size为(5,1,28,28)的一个tensor ,
'xq': size为(15,1,28,28)的一个tensor
}
然后dataloader将10个数据合并成一个episode.
在lin35行使用了dataloader:
https://github.com/jakesnell/prototypical-networks/blob/c9bb4d258267c11cb6e23f0a19242d24ca98ad8a/protonets/engine.py#L35
其中state[‘loader’]定义在lin39行,lin39
tqdm相当于有进度条的
f
o
r
for
for循环.
因此,功能上可以看作:
for sample in tqdm(state['loader'], desc="Epoch {:d} train".format(state['epoch'] + 1)):
...
# 等价于
for sample in train_loader:
...
调试分析sampler可以看作和推断一致.
总结
会自己将代码研究透彻,构造各种数据框架,会自己研究代码,将其全部都搞定都行啦的理由与打算.
- 慢慢的会自己将代码都给其弄明白,全部都将其搞透彻,研究彻底都行啦的里由与打算.