回归原型网络代码episode数据加载

news2025/1/15 6:51:27

一般 P y T o r c h PyTorch PyTorch加载数据的固定格式是:
dataset = MyDataset() : 构建 D a t a s e t Dataset Dataset对象
dataLoader = DataLoader(dataset) #通过 D a t a L o a d e r DataLoader DataLoader来构造迭代对象.
num_epoches = 100
for epoch in range(num_epoches): #逐步迭代数据
for img,label in dataLoader:
#训练代码
但是小样本有episode这个概念,所以需要额外用一个 s a m p l e r sampler sampler,写篇文章记录下原型网络是怎么加载数据哒.

episode有分 s u p p o r t support support集和 q u e r y query query集, 如果 s u p p o r t support support有N个类,每个类有 N N N个样本,我们就叫做 N w a y − K s h o t Nway-Kshot NwayKshot,另外,在 q u e r y query query集每个样类中有 Q Q Q个样本,注意着 Q Q Q个样本和 S u p p o r t Support Support集中 K K K个样本没有重复样本.

DataLoader是怎么获取数据的?

在这里插入图片描述
当我们使用下面代码获取一个batch数据时候, S a m p l e r Sampler Sampler会先产生64个下标,然后DataLoader会根据这64个下标获取数据,最后封装成一个 t e n s o r tensor tensor,返回给small_data
from torch.utils.data import DataLoader:
train_dataloader = DataLoader(training_data,batch_size = 64,shuffle = True)
from small_data in train_dataloader:
print(data.shape)

原型代码中的dataloader simpler和dataset分别对应于:

dataset:

omniglot.py

ds = TransformDataset(ListDataset(class_names), transforms)

Sampler

base.py

EpisodicBatchSampler

  • dataLoader
    omniglot.py

torch.utils.data.DataLoader(ds, batch_sampler=sampler, num_workers=0)

加载数据产生batch_size

Nway -Kshot-Qquery的episode

dataset

在原型代码中使用omnigolot数据集训练,在图片外面有两个文件夹,第一层是1种类,第二层是2编号,注意源来的 o m i g l o t omiglot omiglot数据集有区分:
images_background.zip和images_background.zip
代码作者分别解压之后,合并到了omniglot/data文件夹下,详细可以看

官方代码
download_omniglot.sh文件。
在这里插入图片描述
在官方代码中有对omiglot进行分割的txt文件,如下图所示:
在这里插入图片描述
可以在想要的训练集,验证集和测试集时分别读取不同的txt文件完成,以train为例, t r a i n . t x t train.txt train.txt的第一行为:

Angelic/character01/rot000
分别代表了1种类/编号/旋转角度.

ll5行,首先读取train.txt文件的所有行,将其存放至容器class_names = []中.
然后在119行

ds = TransformDataset(ListDataset(class_names), transforms)

查看官方文档和教程产生List形式数据

可以发现就是在获得ListDataset(class_names)中的任意一项数据时,
都会对其进行transform变换,transforms变换
内容如下:
trainsformers = [partial(convert_dict,‘class’),
load_class_images,
partial(extract_episode, n_support, n_query)]
其中convert_dict,load_class_images,extract_episode.均为代码作者自定义函数,而 p a r t i a l partial partial关键字就是在调用函数的同时,有几个参数固定为所给定的值,以partial(convert_dict,‘class’),就等同于将convert_dict函数从

def convert_dict(k, v):
    return { k: v }

转变为:

def convert_dict(v):
	# 此时已经不用传k的值了,因为 partial(convert_dict, 'class') 已经给k 赋值了'class'
    return { k: v }

如果有functional(a,b,c,d,e,f,g),则:

  • partial(convert_dict, ‘value1’) 就已经给参数a传递了 v a l u e 1 value1 value1
  • partial(convert_dict, ‘value1’, ‘value2’) 就已经给参数a传了value1、参数b传了value2
  • partial(convert_dict, ‘value1’, ‘value2’ ,‘value3’) 就已经给参数a传了value1、参数b传了value2、参数c传了value3
    等等等>
    以Avesta/character12/rot890为例,我们来看transforme最终能取得什么效果,在原型代码中设置 e p i s o d e episode episode,为 5 w a y − 5 s h o t − 15 q u e r y 5way-5shot-15query 5way5shot15query
    为了区分两个5, 我们在这以 10 w a y − 5 s h o t − 15 q u e r y 10way-5shot-15query 10way5shot15query
    进行举例子.

transforms = [partial(convert_dict, ‘class’),
load_class_images,
partial(extract_episode, 5, 15)]
首先把Avesta/character/rot090传递给参数partial(convert_dict, ‘class’)

def convert_dict(k, v):
    return { k: v }

返回结果为:{ ‘class’ :“Avesta/character12/rot090” }
然后把{ ‘class’ :“Avesta/character12/rot090” }传进load_class_images

def load_class_images(d):  # { 'class' :"Avesta/character12/rot090" }
    if d['class'] not in OMNIGLOT_CACHE:
        alphabet, character, rot = d['class'].split('/')  
        # 值分别为 Avesta  character12  rot090
        
        image_dir = os.path.join(OMNIGLOT_DATA_DIR, 'data', alphabet, character)
        # OMNIGLOT_DATA_DIR是omniglot的根路径,此句就是为了拼接出Avesta/character12的路径

        class_images = sorted(glob.glob(os.path.join(image_dir, '*.png')))
        # 获得路径下以png结尾的文件路径,然后排序,这是一个列表

        if len(class_images) == 0:
            raise Exception("No images found for omniglot class {} at {}. Did you run download_omniglot.sh first?".format(d['class'], image_dir))

        image_ds = TransformDataset(ListDataset(class_images), # 这个同上文讲过的,不再赘述
                                    compose([partial(convert_dict, 'file_name'),
                                             partial(load_image_path, 'file_name', 'data'),
                                             partial(rotate_image, 'data', float(rot[3:])),
                                             partial(scale_image, 'data', 28, 28),
                                             partial(convert_tensor, 'data')]))

        loader = torch.utils.data.DataLoader(image_ds, batch_size=len(image_ds), shuffle=False)
        # 将全部数据封装成一个batch,作为一个episode中的一个类

        for sample in loader:
            OMNIGLOT_CACHE[d['class']] = sample['data']
            break # only need one sample because batch size equal to dataset length

    return { 'class': d['class'], 'data': OMNIGLOT_CACHE[d['class']] }

返回结果为:

```python
{ 'class' :"Avesta/character12/rot090" ,
'data':  size为(20,1,28,28)的一个tensor
}

注意这里的(20,1,28,28)是固定的,因为一个文件夹下面只有20张 28 ∗ 28 28*28 2828的黑白图片
然后在把上面结果传递给extract_episode:注意 10 w a y − 5 s h o t − 15 q u e r y 10way-5shot-15query 10way5shot15query进行举例
所以每个类有5个作为 s u p p o r t support support 15个作为 q u e r y query query

def extract_episode(n_support, n_query, d):
    # data: N x C x H x W
    n_examples = d['data'].size(0)  # 20

    if n_query == -1:
        n_query = n_examples - n_support  

    example_inds = torch.randperm(n_examples)[:(n_support+n_query)]
    # 从20个样本中 选取5+15个
    
    support_inds = example_inds[:n_support]
    # 从中选5个作为support
    
    query_inds = example_inds[n_support:]
    # 剩下的20-5个作为query

	# 根据下标加载数据
    xs = d['data'][support_inds]
    xq = d['data'][query_inds]

    return {
        'class': d['class'],
        'xs': xs,
        'xq': xq
    }

返回结果为

{
	'class': "Avesta/character12/rot090" ,
	'xs': size为(5,1,28,28)的一个tensor ,
	'xq': size为(15,1,28,28)的一个tensor
}

Simpler

根据 E p i s o d i c B a t c h S a m p l e r 的定义和调用 EpisodicBatchSampler的定义和调用 EpisodicBatchSampler的定义和调用,可以根据需求生成类的下标:

class EpisodicBatchSampler(object):
    def __init__(self, n_classes, n_way, n_episodes):
        self.n_classes = n_classes
        self.n_way = n_way
        self.n_episodes = n_episodes

    def __len__(self):
        return self.n_episodes

    def __iter__(self):
        for i in range(self.n_episodes):
            yield torch.randperm(self.n_classes)[:self.n_way]

sampler = EpisodicBatchSampler(len(ds), n_way, n_episodes)
此时的len(ds)等于train.txt行数,

解释

此时的 l e n ( d s ) len(ds) len(ds)等于train.txt行数,也就是4112行,

  • n_way等于10,.
  • n_episodes等于100
  • 就是一个epoch里面有100个episode.
  • 每个episode有10个类,从采样定义,我们可以发现,认为字符和转换某个角度(90、180、270)后的字符,是不i一样的,可以看作是一种数据增强吧,另外 S a m p l e r Sampler Sampler的代码显示其返回的list有10个下标,

dataloader

然后 d a t a l o a d e r dataloader dataloader根据它提供的10个下标,去 d a t a s e t dataset dataset找对应下标的数据.

{
	'class': "Avesta/character12/rot090" ,
	'xs': size为(5,1,28,28)的一个tensor ,
	'xq': size为(15,1,28,28)的一个tensor
}

然后dataloader将10个数据合并成一个episode.
在lin35行使用了dataloader:

https://github.com/jakesnell/prototypical-networks/blob/c9bb4d258267c11cb6e23f0a19242d24ca98ad8a/protonets/engine.py#L35

其中state[‘loader’]定义在lin39行,lin39
tqdm相当于有进度条的 f o r for for循环.
在这里插入图片描述
因此,功能上可以看作:

for sample in tqdm(state['loader'], desc="Epoch {:d} train".format(state['epoch'] + 1)):
	...
# 等价于
for sample in train_loader:
	...


调试分析sampler可以看作和推断一致.
在这里插入图片描述

总结

会自己将代码研究透彻,构造各种数据框架,会自己研究代码,将其全部都搞定都行啦的理由与打算.

  • 慢慢的会自己将代码都给其弄明白,全部都将其搞透彻,研究彻底都行啦的里由与打算.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/131100.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

元宇宙产业委评选2022全球元宇宙十大事件(含国外元宇宙五大事件)

中国移动通信联合会元宇宙产业工作委员会(简称为:元宇宙产业委) 评选2022全球元宇宙十大事件(含国外元宇宙五大事件) 1、1月5日,CES 2022上,英伟达(NVIDIA)宣布旗下元宇…

【django】HttpRequest对象的属性和路由补充

文章目录一、HttpRequest对象的常用属性1、request.GET:获取查询字符串参数案例:特别注意:2、request.POST:post请求数据,只能获取表单参数3、request.body:请求body,响应结果为字节类型4、request.method&…

一文搞懂G1垃圾回收器

G1是从JDK9之后的默认垃圾回收器,其功能强大,性能优异,不过目前市面的材料不算多,很多都是抄来抄去,讲得也不太清楚。经过仔细阅读oracle官网以及相关的材料,从整体上梳理了G1的过程,希望这一文…

数据库设计规范详解

对于后端开发人员,建表是个基础活,是地基,如果地基不大牢固,后面在程序开发过程中会带来很多麻烦,在建表的时候不注意细节,等后面系统上线之后,表的维护成本变得非常高,而且很容易踩…

基数排序分析

🥔 原理介绍: [排序算法] 基数排序 (C) - Amαdeus - 博客园 前述的各类排序方法都是建立在关键字比较的基础上,而基数排序是一种非比较型整数排序算法。它的基本思想是将整数按位数切割成不同的数字,然后按每个位数分别比较。 …

单片机基础知识之定时计数器和寄存器

目录 一、定时计数器 二、什么是寄存器 三、定时器如何定时10毫秒 四、定时器编程前寄存器配置计划 五、编程定时器控制LED每隔一秒亮灭 一、定时计数器 1、定时计数器的概念引入 定时器和计数器,电路一样 定时或者计数的本质就是让单片机某个部件数数 当定…

Linux基础------高级IO

文章目录阻塞IO非阻塞IO信号驱动异步IO多路转接(核心终点)实际上 IO “等” 拷贝 等什么呢? -----> 等待的是内核将数据准备好。 拷贝-------> 数据从内核考到用户 IO话题: 无非就是 1 , 改变等的方式 2 &…

Linux中编译带kafka模块的搜狗workflow开源库

workflow依赖的第三方库 openssl https://github.com/openssl/openssl apt install libssl-dev zlib https://github.com/madler/zlib git clone https://github.com/madler/zlib.git./configuremake -j4 make install lz4 (版本>1.7.5) https://github.com/lz4/lz4 …

C语言:预处理(2)

宏通常被用于执行简单的运算。 宏相比于函数的优势: 1.用于调用函数和从函数返回的代码可能比实际执行这个小型计算工作所需要的时间更多。所以宏比函数在程序的规模和速度方面更胜一筹。 2.更为重要的是函数的参数必须声明为特定的类型。所以函数只能在类型合适的…

Diffusion Model原理详解及源码解析

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题 🍊专栏推荐:深度学习网络原理与实战 🍊近期目标:写好专栏的每一篇文章 🍊支持小苏:点赞👍🏼、…

KubeSphere中间件部署

目录 🧡应用部署总览 🧡中间件部署 MySQL有状态副本集 🍠KubeSphere创建配置集 🍠KubeSphere创建存储卷 🍠KubeSphere创建有状态副本集 🍠集群访问 💟这里是CS大白话专场,让枯…

Entity Framework Core 代码自动化迁移

简述 文章内容基于:.NET6 Entity Framewor kCore 7.0.* 使用 EF Core 进行 Code First 开发的时候,肯定会遇到将迁移更新到生产数据库这个问题,大多数都是使用命令生成迁移 SQL,然后使用 SQL 脚本将更新迁移到生产数据库的方式&a…

【一起从0开始学习人工智能0x03】文本特征抽取TfidVectorizer

文章目录文本特征抽取TfidVectorizerTfidVecorizer--------Tf-IDFTF-IDF------重要程度文本特征抽取TfidVectorizer 前几种方法的缺点:有很多词虽然没意义,但是出现次数很多,会影响结果,有失偏颇------------关键词 TfidVecoriz…

一篇文章带你搞懂nodeJs环境配置

1、nodeJs下载地址,这里可以选择你想要的版本,我这里以14.15.1为例 2、下载完成后,直接傻瓜式安装即可。 3、打开命令行(以管理员身份打开),输入node -v,出现以下版本号,代表node成功安装 4、在…

html+css设计两个摆动的大灯笼

实现效果 新年马上就要到了,教大家用htmlcss设计两个大灯笼,喜气洋洋。 html代码: html代码部分非常简单,将一个灯笼分成几部分进行设计,灯笼最上方部分,中间的线条部分和最下方的灯笼穗。组合在一起就…

docker系列教程:docker图形化工具安装及docker系列教程总结

通过前面的学习,我们已经掌握了docker-compose容器编排及实战了。高级篇也算快完了。有没有相关,我们前面学习的时候,都是通过命令行来操作docker的,难道docker就没有图形化工具吗?答案是肯定有的。咱们本篇就来讲讲docker图形化工具及使用图形化工具安装Nginx及docker系列…

读书系列2022(下)读书纪录片

目录 一、认知类 二、纪录片 一、认知类 《蓝海战略》: 让你(企业/个人)在竞争中产生错位竞争,获得优势 《认知盈余》:“人们实际上很喜欢创造并分享”, 参与是一种行为 将人们的自由时间和特殊才能汇聚在一起,共同…

移动Web【字体图标、平面转换[位移,旋转,转换原点,多重转换]、渐变】

文章目录一、字体图标1.1 图标库1.2 下载字体包:1.3 使用字体图标:1.4 使用字体图标 – 类名:1.5 案例:淘宝购物车1.6 上传矢量图:二、平面转换2.1 位移2.1 位移-绝对定位居中2.3 案例2.4 旋转2.5 转换原点2.6 多重转换…

2022年终总结:不一样的形式,不一样的展现

Author:AXYZdong 硕士在读 工科男 有一点思考,有一点想法,有一点理性! 定个小小目标,努力成为习惯!在最美的年华遇见更好的自己! CSDNAXYZdong,CSDN首发,AXYZdong原创 唯…

你真的了解表达式求值吗?

表达式求值大家很熟悉特别是整型十进制的表达式求值。那么char类型的表达式求值是怎么样的&#xff1f;Eg&#xff1a;#include <stdio.h>int main() {char a 127;char b 3;char c a b;printf("%d %d %d\n",a,b,c);return 0; }上面程序输出的结果是多少&am…