【论文解读】元学习:MAML

news2025/1/15 10:06:12

一、简介

元学习的目标是在各种学习任务上训练模型,这样它就可以只使用少量的训练样本来解决新任务。
在这里插入图片描述

论文所提出的算法训练获取较优模型的参数,使其易于微调,从而实现快速自适应。该算法与任何用梯度下降训练的模型兼容,适用于各种学习问题,包括分类、回归和强化学习。
论文中表明,该算法在few-shot image classification基准上达到了SOTA的性能,在few-shot regression上也产出了良好的结果,并加速了策略梯度强化学习的微调

1.1 元学习与一般ML的区别

  • ML: 根据给定数据找到一个函数f,后续在相同的任务上运用该函数
  • Meta Learning: 根据大量任务(数据)找一个 F可以输出f 的能力,后续运用的时候在F上进行较少数据量的update 后就可以得到对应运用任务的函数f
    在这里插入图片描述

二、算法思路与伪代码(监督学习)

2.1 主要思路

核心思路就是找到一个较好的初始参数值,可以在任何同一类型的任务上进行少量数据较少次数update 后就可以得到较好的模型,下图展示了meta Learning 最终学习的参数 ϕ \phi ϕ
在这里插入图片描述

2.2 伪代码

Algorithm2 MAML for Few-Shot Supervised Learning Require:   p ( T ) : distribution over tasks Require:   α : 一系列task训练-supportSet,梯度更新学习率-在循环内更新 β : 一系列task评估-querySet,梯度更新学习率-在循环外更新  1: 初始化参数  θ  2:  while   not done  do    3:  从任务集合中抽取任务  T i ∼ p ( T )  4:  for   all   T i   do    5:  从任务中抽取k shot个样本 D = { X j , Y j } ∈ T i  6:  基于任务的损失函数计算损失 L T i = l ( Y j , f θ i ( X j ) )  7:  基于损失函数计算梯度, 并更新参数 ∂ L T i ∂ θ i = ∇ θ L T i ( f θ ) θ i ′ = θ − α ∇ θ L T i ( f θ )  8:  从任务中抽取 q query 个样本 D ′ = { X j , Y j } ∈ T i 基于更新后的 θ ′ 进行预测并计算损失,用于循环后更新 L T i ′ = l ( Y j , f θ i ′ ( X j ) ) 计算梯度 ∂ L T i ′ ∂ θ i ′ = ∇ θ L T i ′ ( f θ ′ ) 计算最终梯度 ∇ θ L T i ( f θ ′ ) = ∂ L T i ′ ∂ θ i = ∂ L T i ′ ∂ θ i ′ ∂ θ i ′ ∂ θ i  9:  end   for 10:  Update  θ ← θ − β ∑ T i ∼ p ( T ) ∇ θ L T i ( f θ ′ ) 11:  end   while   r e t u r n   θ \begin{aligned} &\rule{110mm}{0.4pt} \\ &\text{Algorithm2 MAML for Few-Shot Supervised Learning}\\ &\rule{110mm}{0.4pt} \\ &\textbf{Require: } p(\mathcal{T}): \text{distribution over tasks}\\ &\textbf{Require: } \alpha \text{: 一系列task训练-supportSet,梯度更新学习率-在循环内更新} \\ &\hspace{17mm} \beta \text{: 一系列task评估-querySet,梯度更新学习率-在循环外更新}\\ &\rule{110mm}{0.4pt} \\ &\text{ 1: 初始化参数 } \theta \\ &\text{ 2: }\textbf{while }\text{not done }\textbf{do }\\ &\text{ 3: }\hspace{5mm}\text{从任务集合中抽取任务 }\mathcal{T}_i \sim p(\mathcal{T}) \\ &\text{ 4: }\hspace{5mm}\textbf{for all }\mathcal{T}_i\textbf{ do }\\ &\text{ 5: }\hspace{10mm}\text{从任务中抽取k shot个样本} \mathcal{D}=\{X^j, Y^j\} \in \mathcal{T}_i\\ &\text{ 6: }\hspace{10mm}\text{基于任务的损失函数计算损失} \mathcal{L}_{\mathcal{T}_i}=l(Y^j, f_{\theta_{i}}(X^j))\\ &\text{ 7: }\hspace{10mm}\text{基于损失函数计算梯度, 并更新参数} \frac{\partial{\mathcal{L}_{\mathcal{T}_i}}}{\partial \theta_i} = \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta) \\ &\hspace{17mm} \theta_i^{\prime} = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta) \\ &\text{ 8: }\hspace{10mm}\text{从任务中抽取 q query 个样本} \mathcal{D}^{\prime}=\{X^j, Y^j\} \in \mathcal{T}_i\\ &\hspace{15mm} \text{基于更新后的}\theta^{\prime}\text{进行预测并计算损失,用于循环后更新} \mathcal{L}^{\prime}_{\mathcal{T}_i}=l(Y^j, f_{\theta^{\prime}_{i}}(X^j))\\ &\hspace{15mm} \text{计算梯度}\frac{\partial{\mathcal{L}^{\prime}_{\mathcal{T}_i}}}{\partial \theta^{\prime}_i} = \nabla_\theta \mathcal{L}^{\prime}_{\mathcal{T}_i}(f_{\theta^{\prime}}) \\ &\hspace{15mm} \text{计算最终梯度} \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_{\theta^{\prime}}) = \frac{\partial{\mathcal{L}^{\prime}_{\mathcal{T}_i}}}{\partial \theta_i}=\frac{\partial{\mathcal{L}^{\prime}_{\mathcal{T}_i}}}{\partial \theta^{\prime}_i}\frac{\partial \theta^{\prime}_i}{\partial \theta_i} \\ &\text{ 9: }\hspace{5mm}\textbf{end for} \\ &\text{10: }\hspace{5mm}\text{Update } \theta \leftarrow \theta - \beta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_{\theta^{\prime}}) \\ &\text{11: }\textbf{end while } \\ &\bf{return} \: \theta \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned} Algorithm2 MAML for Few-Shot Supervised LearningRequire: p(T):distribution over tasksRequire: α一系列task训练-supportSet,梯度更新学习率-在循环内更新β一系列task评估-querySet,梯度更新学习率-在循环外更新 1: 初始化参数 θ 2: while not done do  3: 从任务集合中抽取任务 Tip(T) 4: for all Ti do  5: 从任务中抽取k shot个样本D={Xj,Yj}Ti 6: 基于任务的损失函数计算损失LTi=l(Yj,fθi(Xj)) 7: 基于损失函数计算梯度并更新参数θiLTi=θLTi(fθ)θi=θαθLTi(fθ) 8: 从任务中抽取 q query 个样本D={Xj,Yj}Ti基于更新后的θ进行预测并计算损失,用于循环后更新LTi=l(Yj,fθi(Xj))计算梯度θiLTi=θLTi(fθ)计算最终梯度θLTi(fθ)=θiLTi=θiLTiθiθi 9: end for10: Update θθβTip(T)θLTi(fθ)11: end while returnθ

三、简单实践

用Meta Learning 学习 y = a × s i n ( x + b ) y = a\times sin(x + b) y=a×sin(x+b), 不同的a, b代表不同的任务

3.1 任务数据准备

class SineWaveTask:
    def __init__(self):
        self.a = np.random.uniform(0.1, 5.0)
        self.b = np.random.uniform(1, 2 * np.pi)
        self.train_x = None
    
    def f(self, x):
        return self.a * np.sin(x + self.b)
    
    def train_set(self, size=10, force_new=False):
        if self.train_x is None and not force_new:
            self.train_x = np.random.uniform(-5, 5, size)
            x = self.train_x
        elif not force_new:
            x = self.train_x
        else:
            x = np.random.uniform(-5, 5, size)
        
        y = self.f(x)
        return torch.Tensor(x).float(), torch.Tensor(y).float()

    def test_set(self, size=50):
        x = np.linspace(-5, 5, size)
        y = self.f(x)
        return torch.Tensor(x).float(), torch.Tensor(y).float()
    
    def plot(self, *args, **kwargs):
        x, y = self.test_set()
        return plt.plot(x.cpu().detach().numpy(), y.cpu().detach().numpy(), *args, **kwargs)


SineWaveTask().plot()
SineWaveTask().plot()
SineWaveTask().plot()
plt.show()

在这里插入图片描述

3.2 模型

因为query task中需要用support task后的参数进行推理,后进行二阶导来update 参数,所以多了一个query_forward 方法

class sineModel(nn.Module):
    def __init__(self):
        super(sineModel, self).__init__()
        self.l1 = nn.Linear(1, 40)
        self.l2 = nn.Linear(40, 40)
        self.head = nn.Linear(40, 1)
    
    def forward(self, x):
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return self.head(x)
    
    def query_forward(self, x, support_param_dict):
        x = torch.relu(
            F.linear(x, support_param_dict['l1.weight'], support_param_dict['l1.bias'])
            )
        x = torch.relu(
            F.linear(x, support_param_dict['l2.weight'], support_param_dict['l2.bias'])
            )
        return F.linear(x, support_param_dict['head.weight'], support_param_dict['head.bias'])

SUPPORT_QUERY_TASKS = [SineWaveTask() for _ in range(1000)]
TEST_TASKS = [SineWaveTask() for _ in range(1000)]

3.3 MAML


def maml_sine(model, epochs, lr=1e-3, inner_lr=0.1, batch_size=1, first_order=False):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    ep_loss = []
    for ep_i in range(epochs):
        tqd_bar = tqdm(
            enumerate(random.sample(SUPPORT_QUERY_TASKS, len(SUPPORT_QUERY_TASKS))),
            total=len(SUPPORT_QUERY_TASKS)
        )
        tqd_bar.set_description(f'[ {ep_i+1:02d} / {epochs:02d} ]')
        task_loss = []
        for idx, suport_t in tqd_bar:
            fast_weights = OrderedDict(model.named_parameters())
            s_x, s_y = suport_t.train_set(force_new=False)
            q_x, q_y = suport_t.train_set(force_new=True)
            # support
            for _ in range(1): 
                s_y_hat = model(torch.Tensor(s_x[:, None]))
                loss = loss_fn(s_y_hat, torch.Tensor(s_y.reshape(-1, 1)))
                grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=not first_order) # 便于进行二阶导
                fast_weights = OrderedDict(
                    (name, param - inner_lr * (grad.detach().data if first_order else grad) )
                    for ((name, param), grad) in zip(fast_weights.items(), grads)
                )
            
            # query
            logits = model.query_forward(torch.Tensor(q_x[:, None]), fast_weights)
            loss = loss_fn(logits, torch.Tensor(q_y.reshape(-1, 1)))
            task_loss.append(loss)
            
            if (idx + 1) % batch_size == 0:
                # update
                model.train()
                opt.zero_grad()
                meta_batch_loss = torch.stack(task_loss).mean()
                meta_batch_loss.backward()
                opt.step()
                loss_item = meta_batch_loss.cpu().detach().numpy()
                tqd_bar.set_postfix({'loss': "{:.3f}".format(loss_item)})
                task_loss = []

        ep_loss.append(loss_item)
    return ep_loss


sine_model = sineModel()
ep_losses = maml_sine(sine_model, epochs=5, lr=1e-3, inner_lr=0.02, batch_size=2, first_order=False)

结果查看

全部代码见笔者github:maml.ipynb

maml训练结果显然要好于随机模型
在这里插入图片描述

参考

  • Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
  • 李宏毅老师的课程PPT(国立台湾大学)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/998502.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

群辉 Synology NAS Docker 安装 RustDesk-server 自建服务器只要一个容器

from https://blog.zhjh.top/archives/M8nBI5tjcxQe31DhiXqxy 简介 之前按照网上的教程,rustdesk-server 需要安装两个容器,最近想升级下版本,发现有一个新镜像 rustdesk-server-s6 可以只安装一个容器。 The S6-overlay acts as a supervi…

【Proteus仿真】【STM32单片机】便携式血糖仪

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 系统运行后,LCD1602显示开机界面信息,当按下K1键开始测量,步进电机运行启动针头采血,然后检测血糖值显示在屏幕上;如果血糖高于上限&#xff0c…

Upload-labs十六和十七关

目录 第十六关第十七关 第十六关 直接上传php文件判断限制方式: 同第十五关白名单限制 第十六关源码: 代码逻辑判断了后缀名、content-type,以及利用imagecreatefromgif判断是否为gif图片,最后再做了一次二次渲染 第71行检测…

计算机网络第四章——网络层(中)

提示:待到山花烂漫时,她在丛中笑。 文章目录 需要加头加尾,其中头部最重要的就是加了IP地址和MAC地址(也就是逻辑地址和物理地址)集线器物理层设备,交换机是物理链路层的设备,如上图路由器左边就…

Vue使用ts的枚举类型

vue项目中要使用ts的枚举类型需要为script标签的lang属性添加ts属性值 <script lang"ts" setup></script > 首先要声明一下&#xff08;我这里是声明了一个名称一个颜色&#xff09;&#xff1a; 接下来是页面中的标签使用&#xff08;用的是element表格…

Linux系统编程--IO系统调用

文章目录 一、I/O系统调用1.open() 打开文件1.1 所需基础知识1.2. open() 详解1.3 示例代码 2.read() 读取文件2.1.基础知识2.2.read() 详解2.3. 读入所有字节 3.write() 写文件3.1. 基础背景知识3.2.write() 详解3.3.示例代码3.4.注意点3.4.1.同步IO1. fsync() 和fdatasync()2…

MySQL高可用搭建方案之(MMM)

有的时候博客内容会有变动&#xff0c;首发博客是最新的&#xff0c;其他博客地址可能会未同步,认准https://blog.zysicyj.top 注意&#xff1a;这篇转载文章&#xff0c;非原创 首发博客地址 原文地址 前言 MySQL的高可用有很多种&#xff0c;有我们经常说的MMM架构、MHA架构、…

内网隧道代理技术(二十三)之 DNS隧道反弹Shell

DNS隧道反弹Shell DNS隧道 DNS协议是一种请求、应答协议,也是一种可用于应用层的隧道技术。DNS隧道的工作原理很简单,在进行DNS查询时,如果查询的域名不在DNS服务器本机缓存中,就会访问互联网进行查询,然后返回结果。如果在互联网上有一台定制的服务器,那么依靠DNS协议…

标准C库IO函数和Linux系统IO函数

linux系统的io函数更加偏底层&#xff0c;更加建议使用C库的函数&#xff0c;效率较高&#xff08;有缓冲区&#xff09; 磁盘满了或者手动fflush或者关闭文件才会io一次&#xff0c;效率提高&#xff0c;但是linux没有缓冲区 主要通过file *fp指针操作文件&#xff0c;文件描…

COSCon'23 社区召集令

一年一度的开源盛会&#xff0c;COSCon23 第八届中国开源年会&#xff0c;将于10月28~29日&#xff0c;在四川成都市高新区菁蓉汇召开&#xff01;本次大会的主题是&#xff1a;“开源&#xff1a;川流不息、山海相映”&#xff01; 三年新冠疫情没有将我们击垮&#xff0c;开源…

记录socket的使用 | TCP/IP协议下服务器与客户端之间传送数据 | java学习笔记

谨以此篇&#xff0c;记录TCP编程&#xff0c;方便日后查阅笔记 注意&#xff1a;用BufferedWriter write完后&#xff0c;一定要flush&#xff1b;否则字符不会进入流中。去看源码可知&#xff1a;真正将字符写入的不是write()&#xff0c;而是flush()。 服务器端代码&#…

运维学习之部署Alertmanager-0.24.0

参考《监控系统部署prometheus基本功能》先完成prometheus部署。 参考《运维学习之采集器 node_exporter 1.3.1安装并使用》安装node_exporter。 下载 nohup wget https://github.com/prometheus/alertmanager/releases/download/v0.24.0/alertmanager-0.24.0.linux-amd64.ta…

SecureCRT ssh链接服务器

SecureCRT通过密钥进行SSH登录 说明&#xff1a; 一般的密码方式登录容易被密码暴力破解。所以一般我们会将 SSH 的端口设置为默认22以外的端口&#xff0c;或者禁用root账户登录。其实可以通过密钥登录这种方式来更好地保证安全。 密钥形式登录的原理是&#xff1a;利用密钥…

day34 集合总结

集合总结 一、概述 作用&#xff1a;存储对象的容器&#xff0c;代替数组的&#xff0c;使用更加的便捷 所处的位置&#xff1a;java.util 体系结构 二、Collection 内部的每一个元素都得是引用数据类型 常用方法 add(Object o) 添加元素 addAll(Collection c) 将指定集…

【LeetCode周赛】LeetCode第362场周赛

LeetCode第362场周赛 与车相交的点判断能否在给定时间到达单元格将石头分散到网格图的最少移动次数 与车相交的点 给你一个下标从 0 开始的二维整数数组 nums 表示汽车停放在数轴上的坐标。对于任意下标 i&#xff0c;nums[i] [starti, endi] &#xff0c;其中 starti 是第 i…

讯飞星火认知大模型,多种应用一键体验整合

分享几个可以&#xff0c;直接可以使用的AI应用&#xff0c;依托于讯飞星火大模型实现的&#xff1b; 现在讯飞星火认知大模型&#xff0c;使用已经完全开放&#xff0c;可以直接使用&#xff1b; AI抖音商品种草文案 功能&#xff1a; 通过将商品信息输入到讯飞星火AI大模…

IntelliJ IDEA工具常用插件汇总

&#x1f61c;作 者&#xff1a;是江迪呀✒️本文关键词&#xff1a;IntelliJ IDEA 、常用插件☀️每日 一言&#xff1a;人的一生其实都在偏见和走出偏见中度过 文章目录 一、前言二、Plugins1.Key Promoter X2.CodeGlance3.Git Integration&#xff1a;4.Markdow…

SpringBoot整合Mybatis-Plus(含自动配置分析)

目录 1. Mybatis-Plus介绍2. 创建Mysql表和添加测试数据3. 添加pom.xml依赖4. 自动配置分析5. 代码实现5.1 User类实现5.2 指定MapperScan扫描路径5.3 Mapper接口实现5.4 Service实现5.5 UserMapper测试 1. Mybatis-Plus介绍 Mybatis-Plus是一个Mybatis的增强工具&#xff0c;…

Rich Bowen: 无论你在创造什么,最终交付的是信任。

早在开源被我们称之为开源&#xff0c;Rich Bowen 就已经参与其中。作为 Apache 软件基金会的成员&#xff0c;Rich 目前担任董事会成员、会议副总裁。此外&#xff0c;他还是亚马逊云科技的开源策略师。这些多重角色赋予了他对开源的更广泛和深刻的理解。 在他于 2023 年 Com…