Meta-Learning数学原理

news2025/2/23 22:53:25

文章目录

  • 什么是元学习
  • 元学习的目标
  • 元学习的类型
  • 数学推导
    • 1. 传统机器学习的数学表述
    • 2. 元学习的基本思想
    • 3. MAML 算法推导
      • 3.1 元任务设置
      • 3.2 内层优化:任务级别学习
      • 3.3 外层优化:元级别学习
      • 3.4 元梯度计算
      • 3.5 最终更新规则
    • 4. 算法合并
    • 5. 理解 MAML 的优化
  • 图例
  • MAML 的优势
  • 其他元学习方法
  • 总结
  • 手写笔记

🍃作者介绍:双非本科大四网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发,目前开始人工智能领域相关知识的学习
🦅个人主页:@逐梦苍穹
📕所属专栏:人工智能
🌻gitee地址:xzl的人工智能代码仓库
✈ 您的一键三连,是我创作的最大动力🌹

之前介绍过元学习的内容:https://xzl-tech.blog.csdn.net/article/details/142025393
这篇文章讲一下Meta-Learning的数学原理。

什么是元学习

元学习(Meta-Learning),也称为“学习如何学习”,是一种机器学习方法,其目的是通过学习算法的经验和结构特性,提升算法在新任务上的学习效率。

换句话说,元学习试图学习一种更有效的学习方法,使得模型能够快速适应新的任务或环境。


传统的机器学习算法通常需要大量的数据来训练模型,并且当数据分布发生变化或者遇到一个新任务时,模型往往需要重新训练才能保持良好的性能。

而元学习则不同,它通过 从多个相关任务中学习,从而在面对新任务时更快速地进行学习。

元学习的核心思想是利用“学习的经验”来提高学习的速度和质量。

在元学习的框架中,有两个层次的学习过程:

  1. 元学习者(Meta-Learner): 负责从多个任务中提取经验和知识,用于更新学习策略或模型参数。
  2. 基础学习者(Base Learner): 在每个具体任务上执行实际的学习过程。

元学习的目标

元学习的目标是解决以下问题:

  • 快速适应: 当模型面临新任务时,能够基于已有的经验快速适应,而无需大量的数据和计算资源。
  • 跨任务泛化: 提高模型从多个任务中学习到的知识在新任务上的泛化能力。
  • 提高数据效率: 减少模型在新任务上所需的数据量,尤其是在数据稀缺或高昂的情况下。

元学习的类型

元学习可以按照不同的方式分类,以下是三种主要类型:

  1. 基于模型的元学习(Model-Based Meta-Learning):
    • 这种方法通过直接设计一种能够快速适应新任务的模型架构,通常是通过某种特殊的神经网络结构来实现的。例如,基于记忆的神经网络(如 LSTM 或 Memory-Augmented Neural Networks)被设计成能有效地记住过去的任务信息,并在新任务上进行快速调整。
    • 例子: MANN(Memory-Augmented Neural Networks),SNAIL(Simple Neural Attentive Meta-Learner)。
  2. 基于优化的元学习(Optimization-Based Meta-Learning):
    • 这种方法的核心是通过改进优化过程本身来实现快速学习。其代表算法是 MAML(Model-Agnostic Meta-Learning),它通过在所有任务上共享一个初始模型参数,使得初始模型在每个任务上进行少量梯度下降更新后能够快速适应新任务。
    • 例子: MAML(Model-Agnostic Meta-Learning),Reptile。
  3. 基于记忆的元学习(Memory-Based Meta-Learning):
    • 这类方法直接存储并检索训练过程中的经验数据。当遇到新任务时,通过查找与之相似的旧任务,并利用这些旧任务的数据和经验来快速学习。k-NN(k-近邻)方法是最基本的例子,而更复杂的方法可能使用深度记忆网络。
    • 例子: Meta Networks,Prototypical Networks。

数学推导

1. 传统机器学习的数学表述

在传统的机器学习中,我们通常试图找到一个函数 f θ f_\theta fθ来最小化给定数据集 D D D的损失函数:
θ ∗ = arg ⁡ min ⁡ θ L ( f θ , D ) \theta^* = \arg\min_{\theta} L(f_\theta, D) θ=argminθL(fθ,D)
其中:

  • θ \theta θ是模型的参数。
  • L ( f θ , D ) L(f_\theta, D) L(fθ,D)是损失函数,例如交叉熵损失。
  • 通过梯度下降等优化方法,我们不断更新参数 θ \theta θ以最小化损失。

2. 元学习的基本思想

元学习的目标是找到一种元算法 F ϕ F_\phi Fϕ,使得它可以快速学习新任务。这里的关键是学习一种 学习算法。换句话说,元学习希望找到一组元参数 ϕ \phi ϕ,从而在给定一个新任务 T i T_i Ti时,使用少量数据和梯度更新就可以迅速找到特定任务的参数 θ i \theta_i θi

3. MAML 算法推导

MAML 的目标是学习一个初始模型参数 θ \theta θ,使得它可以通过少量的梯度更新快速适应新任务。

3.1 元任务设置

假设有一组任务 { T 1 , T 2 , … , T N } \{T_1, T_2, \dots, T_N\} {T1,T2,,TN},每个任务 T i T_i Ti有自己的训练数据 D i train D_i^{\text{train}} Ditrain和测试数据 D i test D_i^{\text{test}} Ditest

3.2 内层优化:任务级别学习

对于每个任务 T i T_i Ti,我们首先使用任务的训练数据 D i train D_i^{\text{train}} Ditrain和当前的模型参数 θ \theta θ进行一次或多次梯度更新,得到任务特定的参数 θ i ′ \theta_i' θi
θ i ′ = θ − α ∇ θ L T i ( f θ , D i train ) \theta_i' = \theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θi=θαθLTi(fθ,Ditrain)
其中:

  • α \alpha α是学习率。
  • L T i ( f θ , D i train ) L_{T_i}(f_\theta, D_i^{\text{train}}) LTi(fθ,Ditrain)是任务 T i T_i Ti的损失函数,例如对于分类任务可以是交叉熵损失。

3.3 外层优化:元级别学习

在每个任务的测试数据上评估更新后的模型参数 θ i ′ \theta_i' θi,计算其损失,并在所有任务上最小化测试损失的总和:
min ⁡ θ ∑ i = 1 N L T i ( f θ i ′ , D i test ) \min_{\theta} \sum_{i=1}^N L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) minθi=1NLTi(fθi,Ditest)
θ i ′ \theta_i' θi展开,这个目标实际上是关于初始参数 θ \theta θ的优化问题:
min ⁡ θ ∑ i = 1 N L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \min_{\theta} \sum_{i=1}^N L_{T_i}(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}) minθi=1NLTi(fθαθLTi(fθ,Ditrain),Ditest)

3.4 元梯度计算

为了优化这个目标,我们需要对 θ \theta θ求梯度。这里涉及二阶梯度,因为 θ i ′ \theta_i' θi是通过内层优化得到的:
θ ← θ − β ∑ i = 1 N ∇ θ L T i ( f θ i ′ , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θθβi=1NθLTi(fθi,Ditest)
其中 β \beta β是元学习的学习率。

  • 这个更新包含了二阶导数项: ∇ θ θ i ′ = ∇ θ ( θ − α ∇ θ L T i ( f θ , D i train ) ) \nabla_\theta \theta_i' = \nabla_\theta \left(\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})\right) θθi=θ(θαθLTi(fθ,Ditrain))

3.5 最终更新规则

最终的元学习更新规则可以写为:
θ ← θ − β ∑ i = 1 N ∇ θ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_\theta L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) θθβi=1NθLTi(fθαθLTi(fθ,Ditrain),Ditest)

4. 算法合并

将内层优化 θ i ′ \theta_i' θi代入外层优化的公式中,外层优化的梯度 ∇ θ L T i ( f θ i ′ , D i test ) \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θLTi(fθi,Ditest)需要应用链式法则:
∇ θ L T i ( f θ i ′ , D i test ) = ∇ θ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) = \nabla_\theta L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) θLTi(fθi,Ditest)=θLTi(fθαθLTi(fθ,Ditrain),Ditest)
通过链式法则,展开这个公式:
∇ θ L T i ( f θ i ′ , D i test ) = ∇ θ i ′ L T i ( f θ i ′ , D i test ) ⋅ ∇ θ θ i ′ \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) = \nabla_{\theta_i'} L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) \cdot \nabla_\theta \theta_i' θLTi(fθi,Ditest)=θiLTi(fθi,Ditest)θθi
其中 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi的形式为:
∇ θ θ i ′ = I − α ∇ θ 2 L T i ( f θ , D i train ) \nabla_\theta \theta_i' = I - \alpha \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θθi=Iαθ2LTi(fθ,Ditrain)
I I I是单位矩阵, ∇ θ 2 L T i ( f θ , D i train ) \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θ2LTi(fθ,Ditrain)是损失函数关于 θ \theta θ的二阶导数(Hessian 矩阵)。


最终的公式:

将这些部分合并在一起,得到 MAML 的最终更新公式为:
θ ← θ − β ∑ i = 1 N ∇ θ i ′ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) ⋅ ( I − α ∇ θ 2 L T i ( f θ , D i train ) ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_{\theta_i'} L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) \cdot \left(I - \alpha \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}})\right) θθβi=1NθiLTi(fθαθLTi(fθ,Ditrain),Ditest)(Iαθ2LTi(fθ,Ditrain))


解释:

  • 内层优化:第一部分 θ i ′ = θ − α ∇ θ L T i ( f θ , D i train ) \theta_i' = \theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θi=θαθLTi(fθ,Ditrain)表示在每个任务上用梯度下降更新 θ \theta θ,得到特定于任务的参数 θ i ′ \theta_i' θi
  • 外层优化:外层优化考虑测试集上的损失,并通过链式法则计算对 θ \theta θ的梯度。这部分的关键是包含了内层更新的二阶导数 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi
  • 合并公式:最终的更新公式同时结合了内层和外层优化的过程,充分考虑了内层更新对外层优化的影响。

简化(在某些情况下):

在实际应用中,计算二阶导数(Hessian 矩阵)非常昂贵。因此,有时会使用近似方法来简化计算,例如“一次近似 MAML (First-Order MAML, FOMAML)”,忽略二阶项,仅使用一阶导数进行更新。简化后的更新公式为:
θ ← θ − β ∑ i = 1 N ∇ θ i ′ L T i ( f θ i ′ , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_{\theta_i'} L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θθβi=1NθiLTi(fθi,Ditest)

这个简化版本去除了 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi中的二阶导数计算。

5. 理解 MAML 的优化

通过上面的推导,MAML 的优化分为两个阶段:

  1. 内层优化:在每个任务上利用任务的训练数据对模型进行一次或多次更新,以获得任务特定的模型参数。
  2. 外层优化:在所有任务的测试数据上评估内层优化后的模型,并利用这个评估结果更新模型的初始参数。

图例

MAML 的优势

MAML 的一个关键优势在于,它学习了一个初始参数 θ \theta θ,使得它可以通过少量梯度更新快速适应新任务。这使得它非常适合少样本学习场景,如几次样本分类。

其他元学习方法

除了 MAML,文件中还提到其他元学习方法,如基于优化器的元学习、网络架构搜索(NAS)等。这些方法都在不同程度上优化了元学习的过程,使得模型能够在少量数据的情况下快速学习。

总结

元学习的数学推导核心在于通过多个任务的训练,学习到一个通用的学习算法(或模型初始化),使得模型可以快速适应新任务。MAML 是元学习的一个经典方法,通过在元任务上进行二阶优化,使模型获得更好的泛化能力。

手写笔记

最后放几张今天的手写笔记,主要是方便查阅。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2145679.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Paper Digest|OpenSPG 超大规模知识仓储 KGFabric 论文解读

本文作者:祝锦烨,蚂蚁集团开发工程师,主要研究方向是图谱存储与计算。过去一年在团队的主要工作是蚂蚁知识图谱平台和 KGFabric 相关研发,研究成果收录于 VLDB24。 2024 年 8 月 26 日,数据管理与数据库领域顶级国际会…

[数据集][目标检测]红外微小目标无人机直升机飞机飞鸟检测数据集VOC+YOLO格式7559张4类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):7559 标注数量(xml文件个数):7559 标注数量(txt文件个数):7559 标注…

Hikvision综合安防管理平台isecure center文件读取深度利用

前言 远离一线很久了,很难有实战的机会。碰到Hikvision的漏洞,市面上的很多文章又很模糊,自己摸全点做个详细记录。 参考文章,向佬学习。本次测试为内部授权测试,已脱敏。https://mp.weixin.qq.com/s/zvo195UQvWwTppm…

WPF 的TreeView的TreeViewItem下动态生成TreeViewItem

树形结构仅部分需要动态生成TreeViewItem的可以参考本文。 xaml页面 <TreeView MinWidth"220" ><TreeViewItem Header"功能列表" ItemsSource"{Binding Functions}"><TreeViewItem.ItemTemplate><HierarchicalDataTempla…

TikTok直播专线服务商推荐

在追求TikTok直播的极致体验时&#xff0c;搭建稳定高效的专线网络无疑是最重要的第一步。国内市场涌现出众多TikTok直播专线服务商&#xff0c;面对如此多的选择&#xff0c;用户究竟该如何权衡利弊&#xff0c;作出明智的决策呢&#xff1f;以下是一些关键因素和TIKTOK直播专…

基于 K8S kubernetes 的常见日志收集方案

目录 1、日志对我们来说到底重不重要&#xff1f; 2、常见的日志收集方案 2.1 EFK 2.2 ELK Stack 2.3 ELKfilebeat 2.4 其他方案 2、elasticsearch组件介绍 3、filebeat组件介绍 3.1 filebeat和beat关系 3.2 filebeat是什么&#xff1f; 3.3 Filebeat工作原理 3.4 …

FEAD:fNIRS-EEG情感数据库(视频刺激)

摘要 本文提出了一种可用于训练情绪识别模型的fNIRS-EEG情感数据库——FEAD。研究共记录了37名被试的脑电活动和脑血流动力学反应&#xff0c;以及被试对24种情绪视听刺激的分类和维度评分。探讨了神经生理信号与主观评分之间的关系&#xff0c;并在前额叶皮层区域发现了显著的…

56.【C语言】字符函数和字符串函数(strtok函数)(未完)

目录 12.strtok函数(较复杂) *简单使用 总结: *优化 12.strtok函数(较复杂) *简单使用 strtok:string into tokens cplusplus的介绍 点我跳转 翻译: 函数 strtok char * strtok ( char * str, const char * delimiters ); 总结: delimiters参数指向一个字符串&#xff0…

RK3568平台(基础篇)示波器的使用

一.示波器面板介绍 示波器的横轴表示的是时间,在横轴上有10个小格,每个小格的时间是200us。 示波器的纵轴表示的是电压,在纵轴上有8个小格,每个小格的电压表示1V。 以上是个方波,方波在纵轴上占5个小格,每个小格的电压是500mv,所以这个方波的电压为2500mv。 方波在横…

每日OJ题_牛客_dd爱框框(滑动窗口)

目录 dd爱框框&#xff08;滑动窗口&#xff09; 解析代码 dd爱框框&#xff08;滑动窗口&#xff09; dd爱框框_牛客题霸_牛客网 解析代码 基础同向双指针算法。关于滑动窗口的介绍可看这篇&#xff1a;Offer必备算法02_滑动窗口_八道力扣OJ题详解&#xff08;由易到难&am…

【我的 PWN 学习手札】Largebin Attack(<= glibc-2.38可利用)

目录 前言 一、Largebin Attack的通用利用方法 二、再次 Largebin Attack 三、测试与模板 前言 早期的 Largebin Attack&#xff0c;通过修改 largebin 中 free chunk 的 bk 和 bk_nextsize 指针域&#xff0c;能够实现任意地址写堆地址。然而在 glibc > version2.30 后…

Mycat搭建读写分离

启动Mycat 进入 /mycat/conf/datasources目录下&#xff0c;修改prototypeDs.datasource.json文件 去mycat/bin目录用启动mycat ./mycat start (关闭mycat ./mycat stop)连接mycat 默认端口8066 用户名root 密码123456 注意&#xff1a;这里ip设为null表示任何ip都可以访问…

【学习笔记】SSL/TLS安全机制之CAA

1、概念界定 CAA全称Certificate Authority Authorization&#xff0c;即证书颁发机构授权&#xff0c;每个CA都能给任何网站签发证书。 2、CAA要解决的问题 例如&#xff0c;蓝色网站有一张橙色CA颁发的证书&#xff0c;我们也知道还有许多其他的CA&#xff1b;中间人可以说服…

JACM23 - A New Algorithm for Euclidean Shortest Paths in the Plane

前言 如果你对这篇文章感兴趣&#xff0c;可以点击「【访客必读 - 指引页】一文囊括主页内所有高质量博客」&#xff0c;查看完整博客分类与对应链接。 本文关注的问题为计算几何学中的经典问题&#xff0c;即「在平面上给定一组两两不相交的多边形障碍物&#xff0c;寻找两点…

Redis(redis基础,SpringCache,SpringDataRedis)

文章目录 前言一、Redis基础1. Redis简介2. Redis下载与安装3. Redis服务启动与停止3 Redis数据类型4. Redis常用命令5. 扩展数据类型 二、在Java中操作Redis1. Spring Data Redis的使用1.1. 介绍1.2. 环境搭建1.3. 编写配置类&#xff0c;创建RedisTemplate对象1.4. 通过Redis…

助力数字农林业发展服务香榧智慧种植,基于嵌入式端超轻量级模型LeYOLO全系列【n/s/m/l】参数模型开发构建香榧种植场景下香榧果实检测识别系统

作为一个生在北方但在南方居住多年的人&#xff0c;居然头一次听过香榧&#xff08;fei&#xff09;这种作物&#xff0c;而且这个字还不会念&#xff0c;查了以后才知道读音&#xff08;fei&#xff09;&#xff0c;三声&#xff0c;这着实引起了我的好奇心&#xff0c;我相信…

C++入门基础知识75(高级)——【关于C++ Web 编程】

成长路上不孤单&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a;&#x1f60a; 【14后&#x1f60a;///C爱好者&#x1f60a;///持续分享所学&#x1f60a;///如有需要欢迎收藏转发///&#x1f60a;】 今日分享关于C Web 编程的相关内容&#xff01; 关于…

HomeAssistant显示节假日

先看效果 步骤&#xff1a; 新建卡片时选择“Markdown 卡片”代码在文章最下方&#xff0c;当然你也可以自己修改 点击保存/完成 ### {% if now().hour > 6 and now().hour < 9 -%} 早上好&#xff0c; {%- elif now().hour > 9 and now().hour < 12 -%} 上午好…

【SSM-Day2】第一个SpringBoot项目

运行本篇中的代码&#xff1a;idea专业版或者idea社区版本&#xff08;2021.1~2022.1.4&#xff09;->这个版本主要是匹配插件spring boot Helper的免费版(衰) 【SSM-Day2】第一个SpringBoot项目 框架->Spring家族框架快速上手Spring BootSpring Boot的作用通过idea创建S…

【iOS】引用计数

引用计数 自动引用计数引用计数内存管理的思考方式自己生成的对象&#xff0c;自己所持有非自己生成的对象&#xff0c;自己也能持有不再需要自己持有的对象时释放无法释放非自己持有的对象 自动引用计数 自动引用计数(ARC,Automatic Reference Counting)是指内存管理中对引用…