Prototypical networks for few-shot learning.

news2024/11/16 18:07:01

这篇论文是介绍《Prototypical Networks for Few-shot Learning》。作者公布了他的Pytorh代码。如果看不太懂原作者的代码话可以看一下这一个:https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch

0. Few-shot learning

Few-shot learning是一类机器学习问题,指的是从少量样本中学习新的任务或类别。传统的机器学习算法通常需要大量的数据来训练模型,而few-shot learning则试图通过利用少量数据来学习新的任务或类别,实现更加灵活、高效的学习。这种学习方式可以在很多领域使用,如自然语言处理、图像识别、计算机视觉等。

1. Prototypical Networks

Prototypical neural networks是一种基于原型的神经网络模型,用于解决分类问题。该模型的主要思想是将每个类别的样本表示成其原型,即该类别的所有样本的平均值。然后,使用欧几里得距离或余弦相似度等度量方法,将待分类的样本与每个类别的原型进行比较,从而确定其属于哪个类别。该模型在许多图像分类、语音识别等任务中取得了良好的效果,并且具有较强的泛化能力。

Prototypical网络的思想是每一个类别都存在一个prototype representation,样本点都是散落在prototype representation的周围。为了估计出类别的prototype representation,使用神经网络非线性映射将样本点映射到embedding space,所有support set在embedding space的均值记为prototype。在测试阶段,我们将离query point最近的prototype的类别作为预测类别。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-uQ0X6vaU-1685938244078)(./figs/prototypical.png)]

Few-shot 分类任务一般提供了一个support set和一个querry set。对于querry set中的每一个例子,我们都希望从support set给定的例子中预测出相应的标签。
Prototypical networks也算作是一种metric learning算法。它有两个阶段:

  1. 将support set的样本和querry set的样本输入到网络中得到他们在特征空间中的向量;
  2. 将querry sample与support set中的sample作比较预测它的类别。

因此,对于few-shot问题我们的挑战是:

  1. 找到一个好的featuer space。将样本投射到该特征空间中使相同类别的样本距离较小,不同类别中的样本距离较大。
  2. 找到一个他们在特征空间中比较representations向量的方法。

Prototypical Networks就属于第二种,该模型提出了一种在feature space中比较representations的方法。Protypical Networks计算每一个类别的prototype,即support set中每一个样本embedding的均值。计算querry sample的representations与每一个prototype的欧式距离,距离最小的prototype所属类别作为该样本的预测类别。

2. Notations

通过一个映射函数 f θ : R D → R M f_\theta :\mathbb{R} ^D\to \mathbb{R} ^M fθ:RDRM将样本从样本空间映射到prototype representation空间,prototype记为 c k ∈ R M c_k \in \mathbb{R}^{M} ckRM,维度为 M M M
每一个prototype是embeded support points向量的均值:
c k = 1 S k ∑ ( x i , y i ) ∈ S k f θ ( x i ) c_k = \frac{1}{S _k}\sum_{(x_i,y_i)\in S_k}^{} f_{\theta }(x_i) ck=Sk1(xi,yi)Skfθ(xi)
给定一个距离公式 d d d,prototypical network产生query point x x x与每个prototype距离的softmax:
p θ ( y = k ∣ x ) = exp ( − d ( f θ ( x ) , c k ) ) ∑ k ′ exp ( − d ( f θ ( x ) , c k ′ ) ) p_{\theta }(y=k|x) = \frac{\text{exp}(−d(f_{\theta }(x), c_k))}{ {\textstyle \sum_{k'}^{}} \text{exp}(−d(f_{\theta }(x), c_{k'}))} pθ(y=kx)=kexp(d(fθ(x),ck))exp(d(fθ(x),ck))
通过减小类别 k k k P P P的负对数概率来优化模型:
J ( θ ) = − log ⁡ p θ ( y = k ∣ x ) J(\theta ) = −\log p_{\theta }(y=k|x) J(θ)=logpθ(y=kx)

在每一个episodes,从训练集中随机选择一个子集来,然后再从子集中每个类别选择一部分数据作为support set,剩下的数据为query set。

  • 首先,将support set输入到网络然后产生embedded representations,计算每一个类别样本的均值作为该类别的prototype
  • 将query point输入到网络中得到embeded representation,计算它与每个prototype的距离,选择最近的一个prototype作为预测类别。
  • 将预测类别与真实标签进行比较,然后使用损失函数 J J J优化模型。

3. Dataset

论文中使用的数据集是Omniglot数据集,它采集了来自50个字母表的1623个手写字符,每个字符都由20位不同的人书写。你可以使用torchvision包来下载该数据集:

image_size = 28  
train_set = Omniglot(  
    root="./data",  
    background=True,  
    transform=transforms.Compose(  
        [  
            transforms.Grayscale(num_output_channels=3),  
            transforms.RandomResizedCrop(image_size),  
            transforms.RandomHorizontalFlip(),  
            transforms.ToTensor(),  
        ]  
    ),  
    download=True,  
)  
test_set = Omniglot(  
    root="./data",  
    background=False,  
    transform=transforms.Compose(  
        [  
            transforms.Grayscale(num_output_channels=3),  
            transforms.Resize([  
                int(image_size * 1.15), int(image_size * 1.15)  
            ]),  
            transforms.CenterCrop(image_size),  
            transforms.ToTensor(),  
        ]  
    ),  
    download=True,  
)

background设置为True选择training data,background设置为False选择test data。
此外,Omniglot数据集是灰度图,只有一个channel,而模型是期望输入三个channels,所以需要使用transforms.Grayscale(num_output_channels=3)进行与处理。

4. Prototypical Networks

下面是一个Prototypical networks的一个简单部署,源代码来自于这里。

class PrototypicalNetworks(nn.Module):  
    def __init__(self, backbone: nn.Module):  
        super(PrototypicalNetworks, self).__init__()  
        self.backbone = backbone  
  
    def forward(  
        self,  
        support_images: torch.Tensor,  
        support_labels: torch.Tensor,  
        query_images: torch.Tensor,  
    ) -> torch.Tensor:  
        """  
        Predict query labels using labeled support images.  
        """  
  
        # Extract the features of support and query images  
        z_support = self.backbone.forward(support_images)  
        z_query = self.backbone.forward(query_images)  
  
        # Infer the number of classes from the labels of the support set  
        n_way = len(torch.unique(support_labels))  
        # Prototype i is the mean of all support features vector with label i  
        z_proto = torch.cat(  
            [  
                z_support[torch.nonzero(support_labels == label)].mean(0)  
                for label in range(n_way)  
            ]  
        )  
  
        # Compute the euclidean distance from queries to prototypes  
        dists = torch.cdist(z_query, z_proto)  
  
        scores = -dists  
        return scores  
  
  
convolutional_network = resnet18(pretrained=True)  
convolutional_network.fc = nn.Flatten()  
  
model = PrototypicalNetworks(convolutional_network).cuda()

这里的backbone是一个特征提取器,可以定义成你想使用的任何网络。这里使用的是在ImageNet上预训练的ResNet-18网络作为backbone,FC layer被替换成了nn.Flatten(),因此backbone的输出是一个512维的向量。

5. Build Dataloader

Pytorch中给的dataloader一般不适用与few-shot learning问题,所以我们这里要自己定义一个dataloader。这个dataloader:

  1. 每个类别的数量应该相等;
  2. 需要将数据划分成为support set和querry set。

因此,首先我们要将数据集划分为 n n n-way个类别。然后,每个类别包含 n n n-shot和 n n n-query个样本 (每个batch包含 n n n-way*( n n n-shot + n n n-query)个样本)(注意:这里的 n n n不相等)。

N_WAY = 5 # Number of classes in a task  
N_SHOT = 5 # Number of images per class in the support set  
N_QUERY = 10 # Number of images per class in the query set  
N_EVALUATION_TASKS = 100

在Pytorch中,定义dataloader时需要注意三个参数:dataset, sampler和collate_fn (只有在map style dataset的时候才会用到)。

test_set.labels = [  
    instance[1] for instance in test_set._flat_character_images  
]  

test_sampler = TaskSampler(  
    test_set,   
    n_way=N_WAY,   
    n_shot=N_SHOT,   
    n_query=N_QUERY,   
    n_tasks=N_EVALUATION_TASKS,  
)  

test_loader = DataLoader(  
    test_set,  
    batch_sampler=test_sampler,  
    num_workers=12,  
    pin_memory=True,  
    collate_fn=test_sampler.episodic_collate_fn,  
)

下面我们来看看对于一个 5 5 5-way 5 5 5-shot 任务,产生的数据集是什么样的:

(  
    example_support_images,  
    example_support_labels,  
    example_query_images,  
    example_query_labels,  
    example_class_ids,  
) = next(iter(test_loader))  
  
plot_images(example_support_images, "support images", images_per_row=N_SHOT)  
plot_images(example_query_images, "query images", images_per_row=N_QUERY)

产生的support set的数据集是这样的:
support images
query set的数据集是这样的:
query images
在获取到数据后对模型进行训练:

model.eval()  
example_scores = model(  
    example_support_images.cuda(),  
    example_support_labels.cuda(),  
    example_query_images.cuda(),  
).detach()  
_, example_predicted_labels = torch.max(example_scores.data, 1)  
print("Ground Truth / Predicted")  
for i in range(len(example_query_labels)):  
    print(  
        f"{test_set._characters[example_class_ids[example_query_labels[i]]]} / {test_set._characters[example_class_ids[example_predicted_labels[i]]]}"  
    )

测试模型:

def evaluate_on_one_task(  
    support_images: torch.Tensor,  
    support_labels: torch.Tensor,  
    query_images: torch.Tensor,  
    query_labels: torch.Tensor,  
) -> [int, int]:  
    """  
    Returns the number of correct predictions of query labels, and the total   
    number of predictions.  
    """  
    return (  
        torch.max(  
            model(  
                support_images.cuda(),   
                support_labels.cuda(),   
                query_images.cuda(),  
            ).detach().data,
            1,  
        )[1]  
        == query_labels.cuda()  
    ).sum().item(), len(query_labels)  
def evaluate(data_loader: DataLoader):  
    # We'll count everything and compute the ratio at the end  
    total_predictions = 0  
    correct_predictions = 0  
    # eval mode affects the behaviour of some layers (such as batch normalization or dropout)  
    # no_grad() tells torch not to keep in memory the whole computational graph (it's more lightweight this way)  
    model.eval()  
    with torch.no_grad():  
        for episode_index, (  
            support_images,  
            support_labels,  
            query_images,  
            query_labels,  
            class_ids,  
        ) in tqdm(enumerate(data_loader), total=len(data_loader)):  
            correct, total = evaluate_on_one_task(  
                support_images, support_labels, query_images, query_labels  
            )
            total_predictions += total  
            correct_predictions += correct  
    print(
        f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions/total_predictions):.2f}%"  
    )
evaluate(test_loader)

在Omniglot数据集上 5 5 5-way的准确率为86%:

100%|██████████| 100/100 [00:06<00:00, 16.41it/s]  
Model tested on 100 tasks. Accuracy: 86.32%

这个是原作者的Colab和Github。

Reference

  1. Your Own Few-Shot Classification Model Ready in 15mn with PyTorch

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/611305.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一文通吃:从 ZooKeeper 一致性,Leader选举讲到 ZAB 协议与 PAXOS 算法(下)

本文首发自\[慕课网] &#xff0c;想了解更多IT干货内容&#xff0c;程序员圈内热闻&#xff0c;欢迎关注"慕课网"及“慕课网公众号”&#xff01; 作者&#xff1a;大能 | 慕课网讲师 上篇文章&#xff0c;我们介绍了ZooKeeper集群保证数据一致性和Zookeeper集群Le…

带你全面了解 OAuth2.0

最开始接触 OAuth2.0 的时候&#xff0c;经常将它和 SSO单点登录搞混。后来因为工作需要&#xff0c;在项目中实现了一套SSO&#xff0c;通过对SSO的逐渐了解&#xff0c;也把它和OAuth2.0区分开了。所以当时自己也整理了一篇文章《SSO单点登录原理及实现方式》 最近需要经常和…

STM32单片机GPS北斗无线防丢定位超出距离报警系统NRF24L01

实践制作DIY- GC0136-GPS北斗无线防丢定位超出距离报警系统 基于STM32单片机设计-GPS北斗无线防丢定位超出距离报警系统 二、功能介绍&#xff1a; 主机&#xff1a;STM32F103CxT6系列最小系统板OLED显示器NRF24L01无线通讯模块GPS北斗双模定位模块蜂鸣器距离加减2个按键1个模…

爬虫的操作

目录 爬虫基本 re/etree/beautifulsoup保存本地/连接数据库 基本 re lxml/etree beautifulsoup 保存到本地 传入数据库 大致分为 爬虫基本 re/etree/beautifulsoup保存本地/连接数据库 基本 爬一个很简单的百度新闻热搜 爬排名 热搜名 和热搜指数 百度热搜 我们直…

Lucid VS 特斯拉电驱系统

Lucid如何用更小的电机赶超特斯拉 Lucid 称其电机设计是小型化的&#xff0c;并且一直自夸其Air电动汽车中轻型且“紧凑“的电机轻便到可以放进登机行李箱中。然而&#xff0c;小型只是一个方面。Lucid的电机每个重只有67磅&#xff0c;能够产生670马力的动力&#xff0c;你没…

JAVA 生成微信小程序码-分享码

JAVA生成小程序码(太阳码) 工具类是获取token使用; appId 小程序appID appSecret 小程序秘钥 小程序中得配置分享项&#xff0c;不然图片是裂开的。 开发>开发管理>开发设置 nginx 配置 location ~ ^/share { #、share 你的访问路径default_type text/html;alias /d…

Dart 3.0 语法新特性 | 模式匹配 Patterns

theme: cyanosis 一、 Patterns 是什么 下面是官方对 Patterns 特性的说明 patterns :\ 从下面的第一句中可以知道&#xff0c;Patterns 是一种语法级的特性&#xff0c;而语法特性是一种语言的根基。 Patterns are a syntactic category in the Dart language, like statement…

10 缓存双写一致性之更新策略探讨

什么是缓存双写一致性 如果redis中有数据&#xff1a;需要和数据库中的值相同如果redis中无数据&#xff1a;数据库中的值要是最新值 缓存按照操作来分&#xff0c;有细分2种 只读缓存读写缓存 同步直写策略&#xff1a;写缓存时也同步写数据库&#xff0c;缓存和数据库中的…

如何移动下载文件夹到另一个盘?

下载文件夹占用了越来越多的C盘可用空间&#xff1f;本教程将教你如何安全易行地将下载文件夹移动到其他驱动器&#xff0c;以便你可以释放更多的C盘空间。 关于下载文件夹 从网站下载程序后它们会被存储在哪里&#xff1f;一般来说&#xff0c;当你从互联…

基于C++实现的智慧农业移动巡检系统设计(附源码)

Overview 项目源码 https://download.csdn.net/download/DeepLearning_/87863659 此项目开始于2023年2月7日&#xff0c;项目内容为一种AGV图形化操作系统&#xff0c;采用ROS2GO开发&#xff0c;开发环境为Ubuntu18.04、ROS melodic、Qt5.9.9&#xff0c;该项目作为23年挑战杯…

js函数this指向

目录 this的绑定规则  绑定一&#xff1a;默认绑定&#xff1b; ​ 绑定二&#xff1a;隐式绑定&#xff1b; ​ 绑定三&#xff1a;显式绑定&#xff1b; 通过call或者apply绑定this对象  绑定四&#xff1a;new绑定&#xff1b; 内置函数的绑定 this绑定规则的…

给电脑重装系统的时间需要多久才能装好

在进行电脑重装系统时&#xff0c;如果遇到系统安装时间过长的情况&#xff0c;可能会引起用户的困惑和不安。本文将介绍一些常见的原因和解决方法&#xff0c;以帮助您理解并应对系统安装时间过长的情况。 ​工具/原料&#xff1a; 系统版本&#xff1a;Windows 10 专业版 品…

《Java并发编程实战》课程笔记(九)

Semaphore&#xff1a;如何快速实现一个限流器&#xff1f; 信号量模型 信号量模型还是很简单的&#xff0c;可以简单概括为&#xff1a;一个计数器&#xff0c;一个等待队列&#xff0c;三个方法。 在信号量模型里&#xff0c;计数器和等待队列对外是透明的&#xff0c;所以…

chatgpt赋能python:Python图片大小设置的SEO指南

Python 图片大小设置的SEO指南 在网站设计和开发中&#xff0c;图片大小通常是一个重要的问题。合适的图片大小可以极大地影响用户体验和搜索引擎优化&#xff08;SEO&#xff09;结果。Python是一种广泛使用的编程语言&#xff0c;可以用来控制和设置图片大小。在本文中&…

BUUCTF MD5

密文&#xff1a; e00cf25ad42683b3df678c61f42c6bda 简述&#xff1a; 一般MD5值是32位由数字“0-9”和字母“a-f”所组成的字符串&#xff0c;字母大小写统一&#xff1b;如果出现这个范围以外的字符说明这可能是个错误的md5值&#xff0c;就没必要再拿去解密了。 特征&…

SQL-DDL操作数据库、表

SQL-DDL操作数据库、表 1 DDL:操作数据库 1.1 查询数据库 查询所有的数据库 SHOW DATABASES; show databases;1.2 创建数据库 创建数据库 CREATE DATABASE 数据库名称; create database 数据库名称;创建数据库(判断&#xff0c;如果不存在则创建) CREATE DATABASE IF NOT…

SyntaxError:Unexpected end of JSON input while parsing near xxxxx 报错及解决

环境&#xff1a;Node 12.21.0、npm 6.14.11 &#xff08;其他版本也会出现这样的问题&#xff09; 找到报错日志并进行查看&#xff1a; less /Users/roc/.npm/_logs/2023-06-05T02_23_51_747Z-debug.log报错信息如下&#xff1a; 19067 verbose stack SyntaxError: Unexp…

【遇到的问题】JAVA应用程序处于安全原因被阻止。

遇到的问题&#xff1a; 直入正题&#xff0c;远程服务器用JAVA连接KVM报以下错(如图)。 应用程序处于安全原因被阻止 无法验证证书 将不执行该应用程序 名称&#xff1a;Java viewer 发行者&#xff1a;ATEN 位置&#xff1a;https://192.168.210:443 原因&#xff1a; 通过…

vue3实现高德地图多点标注(so easy)

vue3实现高德地图多点标注&#xff08;so easy&#xff09; 前言思路清晰&#xff0c;抽丝剥茧必要的准备工作最简单的部分处理数据之前&#xff08;最关键的思路&#xff09;效果完整代码 前言 非常感谢你能打开这篇博客&#xff0c;我想你一定是遇到了地图多点标注有关的问题…

采购管理系统对企业有什么作用?原来用零代码搭建如此便捷

什么是采购管理系统&#xff1f; 采购管理系统是一种企业内部管理软件&#xff0c;用于协调和管理企业的采购过程。它涵盖了采购计划、询价、比价、采购订单、采购合同、采购收货、发票等一系列采购环节&#xff0c;以及与供应商的信息和交流。其主要目的是&#xff1a;优化采…