一般
    
     
      
       
        P
       
       
        y
       
       
        T
       
       
        o
       
       
        r
       
       
        c
       
       
        h
       
      
      
       PyTorch
      
     
    PyTorch加载数据的固定格式是:
 dataset = MyDataset() : 构建
    
     
      
       
        D
       
       
        a
       
       
        t
       
       
        a
       
       
        s
       
       
        e
       
       
        t
       
      
      
       Dataset
      
     
    Dataset对象
 dataLoader = DataLoader(dataset) #通过
    
     
      
       
        D
       
       
        a
       
       
        t
       
       
        a
       
       
        L
       
       
        o
       
       
        a
       
       
        d
       
       
        e
       
       
        r
       
      
      
       DataLoader
      
     
    DataLoader来构造迭代对象.
 num_epoches = 100
 for epoch in range(num_epoches): #逐步迭代数据
 for img,label in dataLoader:
 #训练代码
 但是小样本有episode这个概念,所以需要额外用一个
    
     
      
       
        s
       
       
        a
       
       
        m
       
       
        p
       
       
        l
       
       
        e
       
       
        r
       
      
      
       sampler
      
     
    sampler,写篇文章记录下原型网络是怎么加载数据哒.
episode有分 s u p p o r t support support集和 q u e r y query query集, 如果 s u p p o r t support support有N个类,每个类有 N N N个样本,我们就叫做 N w a y − K s h o t Nway-Kshot Nway−Kshot,另外,在 q u e r y query query集每个样类中有 Q Q Q个样本,注意着 Q Q Q个样本和 S u p p o r t Support Support集中 K K K个样本没有重复样本.
DataLoader是怎么获取数据的?

 当我们使用下面代码获取一个batch数据时候,
    
     
      
       
        S
       
       
        a
       
       
        m
       
       
        p
       
       
        l
       
       
        e
       
       
        r
       
      
      
       Sampler
      
     
    Sampler会先产生64个下标,然后DataLoader会根据这64个下标获取数据,最后封装成一个
    
     
      
       
        t
       
       
        e
       
       
        n
       
       
        s
       
       
        o
       
       
        r
       
      
      
       tensor
      
     
    tensor,返回给small_data
 from torch.utils.data import DataLoader:
 train_dataloader = DataLoader(training_data,batch_size = 64,shuffle = True)
 from small_data in train_dataloader:
 print(data.shape)
原型代码中的dataloader simpler和dataset分别对应于:
dataset:
omniglot.py
ds = TransformDataset(ListDataset(class_names), transforms)
Sampler
base.py
EpisodicBatchSampler
- dataLoader 
omniglot.py 
torch.utils.data.DataLoader(ds, batch_sampler=sampler, num_workers=0)
加载数据产生batch_size
Nway -Kshot-Qquery的episode
dataset
在原型代码中使用omnigolot数据集训练,在图片外面有两个文件夹,第一层是1种类,第二层是2编号,注意源来的
    
     
      
       
        o
       
       
        m
       
       
        i
       
       
        g
       
       
        l
       
       
        o
       
       
        t
       
      
      
       omiglot
      
     
    omiglot数据集有区分:
 images_background.zip和images_background.zip
 代码作者分别解压之后,合并到了omniglot/data文件夹下,详细可以看
 
 官方代码
 download_omniglot.sh文件。
 
 在官方代码中有对omiglot进行分割的txt文件,如下图所示:
 
 可以在想要的训练集,验证集和测试集时分别读取不同的txt文件完成,以train为例,
    
     
      
       
        t
       
       
        r
       
       
        a
       
       
        i
       
       
        n
       
       
        .
       
       
        t
       
       
        x
       
       
        t
       
      
      
       train.txt
      
     
    train.txt的第一行为:
Angelic/character01/rot000
分别代表了1种类/编号/旋转角度.
ll5行,首先读取train.txt文件的所有行,将其存放至容器class_names = []中.
 然后在119行
ds = TransformDataset(ListDataset(class_names), transforms)
查看官方文档和教程产生List形式数据
可以发现就是在获得ListDataset(class_names)中的任意一项数据时,
 都会对其进行transform变换,transforms变换
 内容如下:
 trainsformers = [partial(convert_dict,‘class’),
 load_class_images,
 partial(extract_episode, n_support, n_query)]
 其中convert_dict,load_class_images,extract_episode.均为代码作者自定义函数,而
    
     
      
       
        p
       
       
        a
       
       
        r
       
       
        t
       
       
        i
       
       
        a
       
       
        l
       
      
      
       partial
      
     
    partial关键字就是在调用函数的同时,有几个参数固定为所给定的值,以partial(convert_dict,‘class’),就等同于将convert_dict函数从
def convert_dict(k, v):
    return { k: v }
 
转变为:
def convert_dict(v):
	# 此时已经不用传k的值了,因为 partial(convert_dict, 'class') 已经给k 赋值了'class'
    return { k: v }
 
如果有functional(a,b,c,d,e,f,g),则:
- partial(convert_dict, ‘value1’) 就已经给参数a传递了 v a l u e 1 value1 value1
 - partial(convert_dict, ‘value1’, ‘value2’) 就已经给参数a传了value1、参数b传了value2
 - partial(convert_dict, ‘value1’, ‘value2’ ,‘value3’) 就已经给参数a传了value1、参数b传了value2、参数c传了value3
等等等>
以Avesta/character12/rot890为例,我们来看transforme最终能取得什么效果,在原型代码中设置 e p i s o d e episode episode,为 5 w a y − 5 s h o t − 15 q u e r y 5way-5shot-15query 5way−5shot−15query
为了区分两个5, 我们在这以 10 w a y − 5 s h o t − 15 q u e r y 10way-5shot-15query 10way−5shot−15query
进行举例子. 
transforms = [partial(convert_dict, ‘class’),
load_class_images,
partial(extract_episode, 5, 15)]
首先把Avesta/character/rot090传递给参数partial(convert_dict, ‘class’)
def convert_dict(k, v):
    return { k: v }
 
返回结果为:{ ‘class’ :“Avesta/character12/rot090” }
 然后把{ ‘class’ :“Avesta/character12/rot090” }传进load_class_images
def load_class_images(d):  # { 'class' :"Avesta/character12/rot090" }
    if d['class'] not in OMNIGLOT_CACHE:
        alphabet, character, rot = d['class'].split('/')  
        # 值分别为 Avesta  character12  rot090
        
        image_dir = os.path.join(OMNIGLOT_DATA_DIR, 'data', alphabet, character)
        # OMNIGLOT_DATA_DIR是omniglot的根路径,此句就是为了拼接出Avesta/character12的路径
        class_images = sorted(glob.glob(os.path.join(image_dir, '*.png')))
        # 获得路径下以png结尾的文件路径,然后排序,这是一个列表
        if len(class_images) == 0:
            raise Exception("No images found for omniglot class {} at {}. Did you run download_omniglot.sh first?".format(d['class'], image_dir))
        image_ds = TransformDataset(ListDataset(class_images), # 这个同上文讲过的,不再赘述
                                    compose([partial(convert_dict, 'file_name'),
                                             partial(load_image_path, 'file_name', 'data'),
                                             partial(rotate_image, 'data', float(rot[3:])),
                                             partial(scale_image, 'data', 28, 28),
                                             partial(convert_tensor, 'data')]))
        loader = torch.utils.data.DataLoader(image_ds, batch_size=len(image_ds), shuffle=False)
        # 将全部数据封装成一个batch,作为一个episode中的一个类
        for sample in loader:
            OMNIGLOT_CACHE[d['class']] = sample['data']
            break # only need one sample because batch size equal to dataset length
    return { 'class': d['class'], 'data': OMNIGLOT_CACHE[d['class']] }
返回结果为:
```python
{ 'class' :"Avesta/character12/rot090" ,
'data':  size为(20,1,28,28)的一个tensor
}
 
注意这里的(20,1,28,28)是固定的,因为一个文件夹下面只有20张
    
     
      
       
        28
       
       
        ∗
       
       
        28
       
      
      
       28*28
      
     
    28∗28的黑白图片
 然后在把上面结果传递给extract_episode:注意
    
     
      
       
        10
       
       
        w
       
       
        a
       
       
        y
       
       
        −
       
       
        5
       
       
        s
       
       
        h
       
       
        o
       
       
        t
       
       
        −
       
       
        15
       
       
        q
       
       
        u
       
       
        e
       
       
        r
       
       
        y
       
      
      
       10way-5shot-15query
      
     
    10way−5shot−15query进行举例
 所以每个类有5个作为
    
     
      
       
        s
       
       
        u
       
       
        p
       
       
        p
       
       
        o
       
       
        r
       
       
        t
       
      
      
       support
      
     
    support 15个作为
    
     
      
       
        q
       
       
        u
       
       
        e
       
       
        r
       
       
        y
       
      
      
       query
      
     
    query
def extract_episode(n_support, n_query, d):
    # data: N x C x H x W
    n_examples = d['data'].size(0)  # 20
    if n_query == -1:
        n_query = n_examples - n_support  
    example_inds = torch.randperm(n_examples)[:(n_support+n_query)]
    # 从20个样本中 选取5+15个
    
    support_inds = example_inds[:n_support]
    # 从中选5个作为support
    
    query_inds = example_inds[n_support:]
    # 剩下的20-5个作为query
	# 根据下标加载数据
    xs = d['data'][support_inds]
    xq = d['data'][query_inds]
    return {
        'class': d['class'],
        'xs': xs,
        'xq': xq
    }
 
返回结果为
{
	'class': "Avesta/character12/rot090" ,
	'xs': size为(5,1,28,28)的一个tensor ,
	'xq': size为(15,1,28,28)的一个tensor
}
 
Simpler
根据 E p i s o d i c B a t c h S a m p l e r 的定义和调用 EpisodicBatchSampler的定义和调用 EpisodicBatchSampler的定义和调用,可以根据需求生成类的下标:
class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes
    def __len__(self):
        return self.n_episodes
    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]
 
sampler = EpisodicBatchSampler(len(ds), n_way, n_episodes)
此时的len(ds)等于train.txt行数,
解释
此时的 l e n ( d s ) len(ds) len(ds)等于train.txt行数,也就是4112行,
- n_way等于10,.
 - n_episodes等于100
 - 就是一个epoch里面有100个episode.
 - 每个episode有10个类,从采样定义,我们可以发现,认为字符和转换某个角度(90、180、270)后的字符,是不i一样的,可以看作是一种数据增强吧,另外 S a m p l e r Sampler Sampler的代码显示其返回的list有10个下标,
 
dataloader
然后 d a t a l o a d e r dataloader dataloader根据它提供的10个下标,去 d a t a s e t dataset dataset找对应下标的数据.
{
	'class': "Avesta/character12/rot090" ,
	'xs': size为(5,1,28,28)的一个tensor ,
	'xq': size为(15,1,28,28)的一个tensor
}
 
然后dataloader将10个数据合并成一个episode.
 在lin35行使用了dataloader:
https://github.com/jakesnell/prototypical-networks/blob/c9bb4d258267c11cb6e23f0a19242d24ca98ad8a/protonets/engine.py#L35
 
其中state[‘loader’]定义在lin39行,lin39
 tqdm相当于有进度条的
    
     
      
       
        f
       
       
        o
       
       
        r
       
      
      
       for
      
     
    for循环.
 
 因此,功能上可以看作:
for sample in tqdm(state['loader'], desc="Epoch {:d} train".format(state['epoch'] + 1)):
	...
# 等价于
for sample in train_loader:
	...
 
调试分析sampler可以看作和推断一致.
 
总结
会自己将代码研究透彻,构造各种数据框架,会自己研究代码,将其全部都搞定都行啦的理由与打算.
- 慢慢的会自己将代码都给其弄明白,全部都将其搞透彻,研究彻底都行啦的里由与打算.
 

















![移动Web【字体图标、平面转换[位移,旋转,转换原点,多重转换]、渐变】](https://img-blog.csdnimg.cn/254076a8730141c08912fc2b69c5b6c2.png)

