【博士每天一篇文献-算法】持续学习经典算法之LwF: Learning without forgetting

news2025/1/10 10:49:29

1 介绍

年份:2017
作者:Zhizhong Li,Amazon AWS Rekognition;Derek Hoiem,伊利诺伊大学计算机科学教授
会议:IEEE transactions on pattern analysis and machine intelligence
引用量:4325
Li Z, Hoiem D. Learning without forgetting[J]. IEEE transactions on pattern analysis and machine intelligence, 2017, 40(12): 2935-2947.
作者提出了一种名为“Learning without Forgetting”(LwF)的方法。利用知识蒸馏损失来保持旧任务的输出,这是一种创新的损失函数应用,与传统的参数正则化方法相比,能够更直接地保留旧任务的知识。这种方法使用新任务数据来训练网络,同时保留原始功能。LwF的表现优于常用的特征提取和微调适应技术。
image.png
image.png

2 创新点

  1. 新任务学习与旧任务保留的结合
    • 论文提出了一种新颖的方法,使得在训练卷积神经网络以学习新任务的同时,能够保留对旧任务的记忆,有效解决了灾难性遗忘问题。
  2. 无需旧任务数据
    • LwF算法不需要旧任务的训练数据,只需要新任务的数据来更新网络,这与传统的多任务学习和迁移学习方法不同,后者通常需要访问所有任务的数据。
  3. 知识蒸馏的创新应用
    • 利用知识蒸馏损失来保持旧任务的输出,这是一种创新的损失函数应用,与传统的参数正则化方法相比,能够更直接地保留旧任务的知识。
  4. 预热步骤与联合优化步骤
    • 算法采用了预热步骤和联合优化步骤的训练策略,预热步骤首先训练新任务参数,然后联合优化步骤同时训练所有参数,这种分阶段的训练方法提高了学习效率和性能。
  5. 损失函数的平衡权重
    • 通过引入损失平衡权重 λ o \lambda_o λo,LwF算法能够平衡新旧任务的损失,提供了一种灵活的方法来调整新旧任务性能之间的权衡。

3 相关研究

  1. Catastrophic Forgetting:
    • 描述了在神经网络中,当学习新任务时,旧任务的性能可能会急剧下降的现象,这被称为灾难性遗忘。
    • 文献:[1] M. McCloskey and N. J. Cohen, “Catastrophic interference in connectionist networks: The sequential learning problem.”
  2. Transfer Learning:
    • 讨论了迁移学习,即在一个任务上学到的知识可以帮助另一个不同但相关任务的学习。
    • 文献:[5] J. Donahue et al., “Decaf: A deep convolutional activation feature for generic visual recognition;” [6] R. Girshick et al., “Rich feature hierarchies for accurate object detection and semantic segmentation.”
  3. Multi-task Learning:
    • 多任务学习,旨在同时提高多个任务的性能,通过共享表示来提高泛化能力。
    • 文献:[7] R. Caruana, “Multitask learning.”
  4. Feature Extraction:
    • 特征提取方法,其中预训练的深度CNN用于计算图像的特征,然后使用这些特征训练新任务的分类器。
    • 文献:[5] J. Donahue et al., “Decaf: A deep convolutional activation feature for generic visual recognition.”
  5. Fine-tuning:
    • 微调方法,通过修改预训练网络的参数来适应新任务,通常使用较小的学习率以避免大幅偏离原始参数。
    • 文献:[6] R. Girshick et al., “Rich feature hierarchies for accurate object detection and semantic segmentation.”
  6. Joint Training:
    • 联合训练,即同时优化所有任务的参数,通过交错不同任务的样本来进行训练。
    • 文献:[7] R. Caruana, “Multitask learning.”
  7. Continual Learning or Lifelong Learning:
    • 持续学习或终身学习,关注如何在学习新任务的同时保留对旧任务的记忆。
    • 文献:[24] S. Thrun, “Lifelong learning algorithms;” [25] T. Mitchell et al., “Never-ending learning.”
  8. Knowledge Distillation:
    • 知识蒸馏,一种将大型网络的知识转移到小型网络的方法,通过优化损失函数使得小型网络的输出接近大型网络。
    • 文献:[11] G. Hinton et al., “Distilling the knowledge in a neural network.”
  9. Net2Net:
    • Net2Net方法,可以快速初始化网络以进行超参数探索,通过生成一个更深、更宽的网络,该网络在功能上等同于现有的网络。
    • 文献:[20] T. Chen et al., “Net2net: Accelerating learning via knowledge transfer.”

4 算法原理

image.png

  1. 定义参数
    • 假设一个CNN有一组共享参数 (\theta_s)(例如AlexNet中的五个卷积层和两个全连接层)。
    • 旧任务有特定的参数 (\theta_o)(例如用于ImageNet分类的输出层及其权重)。
    • 新任务有随机初始化的特定参数 (\theta_n)(例如场景分类器)。
    • 原始网络对新任务图像的旧任务输出 (y_o)。
  2. 根据新任务调整网络结构
    • 根据新任务的分类输出数量,调整输出层的节点。
  3. 训练过程
    • 使用随机梯度下降(SGD)训练网络,最小化所有任务的损失和正则化项 ®。正则化项通常是一个简单的权重衰减(0.0005)。
    • 训练分为两个步骤:
      • 预热步骤(warm-up step):冻结 θ s \theta_s θs θ o \theta_o θo,只训练 θ n \theta_n θn直到收敛。这有助于提高新任务的性能。新任务损失使用的是多类分类的常用损失函数,例如多项式逻辑损失(multinomial logistic loss)

L w a r m − u p = L n e w ( y n , y ^ n ) L_{warm-up}=L_{new}(y_n,\hat{y}_n) Lwarmup=Lnew(yn,y^n)

  - **联合优化步骤(joint-optimize step)**:联合训练所有权重$\theta_s$、$\theta_o$和$\theta_n$直到收敛。这有助于在新任务上优化共享参数,同时保留旧任务的性能。旧任务损失使用的是知识蒸馏损失(Knowledge Distillation loss),这是一种修改后的交叉熵损失,增加了对较小概率的权重,以鼓励网络输出接近原始网络的输出。通过损失平衡权重$\lambda_o$来调整新旧任务损失的相对重要性,在新任务和旧任务性能之间取得平衡。

L j o i n t = λ o L o l d ( y o , y ^ o ) + L n e w ( y n , y ^ n ) + R ( θ s , θ o , θ n ) L_{joint}=λ_{o}L_{old}(y_{o},\hat{y}_o)+L_{new}(y_n,\hat{y}_n)+R(\theta_s,\theta_o,\theta_n) Ljoint=λoLold(yo,y^o)+Lnew(yn,y^n)+R(θs,θo,θn)

5 实验分析

(1)不同方法性能比较
不同方法相对于LwF的性能差异。实验结果以与LwF方法的比较为基础来报告,以便进行比较。对于VOC数据集,使用平均精度均值(mAP)来衡量性能,而对于其他数据集,则使用准确率(acc)

  • 正值:表示某种方法的性能高于LwF方法。例如,如果某个方法在新任务上的准确率比LwF方法高出2%,则这个差值会表示为+2%。
  • 负值:表示某种方法的性能低于LwF方法。例如,如果某种方法在旧任务上的准确率比LwF方法低1%,则这个差值会表示为-1%。

image.png
(2)逐步向预训练网络添加新任务时的每种任务的表现
image.png
在大多数情况下,LwF方法的性能随着时间的退化速度比微调(fine-tuning)慢,并且在大多数情况下优于特征提取方法。对于Places2到VOC的任务对,LwF方法的性能与联合训练相当,表明在这种情况下,不需要旧任务数据也能达到类似性能。
(3)新任务训练集大小减少(即数据子采样)对比较方法的影响
x轴显示了训练集大小的减少,即从完整的训练集逐渐减少到较小的子集。
image.png
即使在训练数据较少的情况下,LwF方法相对于其他方法仍然表现出更好的性能,这显示了LwF在不同数据条件下的稳健性。随着训练集大小的减少,所有方法的性能都可能受到影响,但LwF方法能够更好地保持性能,尤其是在新任务上。
(4)LwF算法的扩展性
image.png
图(a)中表示每个新任务有自己的全连接层,可以针对特定的任务进行训练和优化。
图(b)在网络扩展方法中,通过在现有层之上添加新的节点来扩展网络,这些节点为新任务提供了额外的表示能力。新增节点的权重初始化方法参考了Net2Net的扩展方式,即复制现有节点的权重。
理论上将更多层设为任务特定可能有助于新任务的学习,但实验结果表明,这种方法并没有带来一致的性能提升。这可能是因为过多的任务特定层增加了模型的复杂性,而没有相应的性能收益。
LwF可以应用于网络扩展,通过解冻所有节点并匹配旧任务的输出响应,进一步优化新旧任务的平衡。

(5)其他设计选择的对比
包括改变任务特定层的数量、使用网络扩展技术、调整共享参数的学习率,以及使用不同的损失函数。
image.png

  • 表中比较了仅在输出层(最后一层)与其他层也作为任务特定时的性能差异。实验结果表明,仅在输出层进行任务特定化与在更多层进行任务特定化相比,并没有表现出一致的性能优势或劣势。
  • 网络扩展是一种在现有层上添加新节点的方法,以增加网络容量。表中比较了仅进行网络扩展与结合LwF进行网络扩展的性能。结果显示,虽然网络扩展可以提供一些性能提升,但结合LwF并没有带来额外的性能增益。
  • 表中探讨了在微调过程中降低共享参数学习率的效果。实验结果表明,仅仅降低学习率并不足以保持旧任务的性能,而且可能会降低新任务的性能。这强调了LwF方法中输出保持损失的重要性。
  • 表中比较了使用L1损失、L2损失、交叉熵损失和知识蒸馏损失的性能。实验结果表明,知识蒸馏损失略微优于其他损失函数,尽管优势不是很显著。知识蒸馏损失在某些情况下可能提供轻微的性能提升,但总体而言,损失函数的选择并不是影响LwF性能的关键因素。

(6)不同方法在新旧任务上的性能
image.png

  • LwF方法在新任务上通常优于微调(fine-tuning)和其他基线方法,并且在旧任务上的性能也显著优于微调。
  • 联合训练(Joint Training)作为上界,使用旧任务数据,通常在旧任务上表现最佳,但在新任务上可能不如LwF。
  • 不同的损失函数(如L1、L2、交叉熵和知识蒸馏损失)对LwF方法的性能有轻微影响。知识蒸馏损失通常提供稍微好一点的性能,但优势不大。

6 思考

(1)使用知识蒸馏的损失函数去做联合优化很有启发。
(2)本文中特征提取的算法原理
特征提取(Feature Extraction)是一种迁移学习方法,其核心思想是利用在一个大型、多样化的数据集(如ImageNet)上预训练的卷积神经网络(CNN)来为新任务提取特征。以下是特征提取方法的算法原理:

  1. 使用预训练的网络:选择一个在大规模数据集上预训练好的CNN模型,该模型已经在其参数中学习到了丰富的特征表示。
  2. 固定共享层:在特征提取方法中,预训练网络的共享参数 θs 被冻结,即不对其进行进一步的训练更新。
  3. 提取特征:通过前向传播,使用冻结的共享层来提取输入数据的特征表示。这些特征通常是网络中最后一层全连接层之前的激活输出。
  4. 训练新任务的分类器:在提取的特征之上训练一个新的分类器(例如,一个新的全连接层),这个分类器专门针对新任务进行训练,以学习如何根据特征表示对新任务的类别进行分类。
  5. 新任务训练:使用新任务的数据来训练新分类器,而原始的共享层参数保持不变,这样新任务就能从预训练网络中受益,同时避免了对旧知识的灾难性遗忘。

参考文献 [5] J. Donahue et al., “Decaf: A deep convolutional activation feature for generic visual recognition,” in International Conference in Machine Learning (ICML), 2014
(3)本文中微调的算法原理
微调(Fine-tuning)是一种迁移学习技术,用于将预训练的模型调整到一个新的、通常数据量较小的任务上。以下是微调方法的算法原理:

  1. 预训练模型:从一个大型数据集(如ImageNet)开始,该数据集已经在模型的参数中学习到了通用的特征表示。
  2. 修改网络结构:根据新任务的需求,可能需要对网络结构进行修改。例如,对于一个新的分类任务,可能需要替换或扩展网络的输出层,以适应新任务的类别数。
  3. 冻结共享层:在某些微调策略中,可以选择冻结网络的某些层(通常是底层),以保留学习到的通用特征。
  4. 训练新任务参数:使用新任务的数据来训练网络的某些层,特别是那些被修改或新添加的层。这通常涉及到使用比原始预训练时更小的学习率。
  5. 反向传播:通过反向传播算法来更新网络的参数,以最小化新任务的损失函数。
  6. 平衡新旧知识:微调过程中的一个关键挑战是平衡新任务学习与保留旧任务知识之间的关系,避免对旧任务性能的灾难性遗忘。

参考文献 [6] R. Girshick et al., “Rich feature hierarchies for accurate object detection and semantic segmentation,” in The IEEE Conference on Computer Vision and Pattern Recognition (CVPR), June 2014。
(4)新增节点的权重初始化方法参考了Net2Net的扩展方式,其中Net2Net是什么意思?
Net2Net是一种用于加速深度学习模型训练的技术,由Terry Chen、Ian Goodfellow和Jon Shlens在他们的论文中提出。这项技术的核心思想是允许网络在不牺牲性能的情况下快速适应新任务,通过将一个已经训练好的网络转化为一个具有更多或更少层的新网络,同时保持两者在功能上的等价性。Net2Net通过复制已有的层来扩展网络,每个新层的节点(或神经元)是原层节点的副本。新层的权重是通过原层的权重进行初始化的,这样可以保留已经学到的知识,并为新任务的学习提供一个良好的起点。
(5)本文微调算法和特征提取算法的异同是什么?
相同点:

  1. 预训练基础:两者都使用在大型数据集(如ImageNet)上预训练的卷积神经网络(CNN)作为基础。
  2. 参数重用:它们都利用了预训练模型中学习到的参数,尤其是网络的底层参数,这些参数捕获了通用的特征表示。
  3. 适应性:两种方法都旨在使模型适应新的数据集或任务,同时保留从原始任务中学到的知识。

不同点:

  1. 参数更新:
    • 特征提取:通常只训练新任务的顶层,底层的共享参数被冻结,不进行更新。
    • 微调:不仅训练新任务的顶层,还可能更新一部分或全部共享参数,以更好地适应新任务。
  2. 网络结构修改:
    • 特征提取:不需要修改网络结构,直接在预训练模型的基础上添加新的分类层。
    • 微调:可能需要根据新任务的需要修改网络结构,例如改变或扩展输出层。
  3. 训练策略:
    • 特征提取:训练过程集中在新添加的顶层,底层参数保持不变。
    • 微调:训练过程涉及整个网络或网络的一部分,使用较小的学习率来微调参数。
  4. 对新任务的适应性:
    • 特征提取:由于只训练顶层,对新任务的适应性可能有限,但可以快速部署。
    • 微调:通过对共享层的微调,可以更好地适应新任务,但风险是可能会损害旧任务的性能(灾难性遗忘)。
  5. 性能影响:
    • 特征提取:通常在新任务上的性能不如微调,因为底层特征没有针对新任务进行优化。
    • 微调:可能在新任务上实现更好的性能,但需要仔细平衡以保护旧任务的性能。
  6. 训练数据的需求:
    • 特征提取:由于底层参数不更新,对新任务的训练数据量需求较少。
    • 微调:需要足够的新任务训练数据来有效更新参数,否则可能引起过拟合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1987636.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【NXP-MCXA153】开发板救砖教程

前言 新手接触到NXP的板子时,一个不留神把调试的GPIO(RXD、TXD)改掉,很容易出现MDK Keil无法识别CMSIS-DAP调试器的情况;主控进入了莫名其妙模式导致调试器无法识别了,你根本无法下载程序,想改…

大数据-67 Kafka 高级特性 分区 分配策略 Ranger、RoundRobin、Sticky、自定义分区器

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

虚拟机Centos7 minimal版本安装docker

1、在 CentOS 7 上启用 EPEL 软件包存储库; (删除epel软件包和其他操作可参考:如何在 CentOS 7 上使用 EPEL (linux-console.net)) 1.1: 要安装epel前会报错,如下所示: 先参照这个链接安装&a…

【python】OpenCV—Image Super Resolution

文章目录 1、背景介绍2、准备工作3、EDSR4、ESPCN5、FSRCNN6、LapSRN7、汇总对比8、参考 1、背景介绍 图像超分,即图像超分辨率(Image Super Resolution,简称SR),是指由一幅低分辨率图像或图像序列恢复出高分辨率图像…

HTML基础 - HTML5

目录 一. 简介 二. 新增元素 三. 拖放 地理定位 A、HTML5 拖放(Drag and Drop) B.HTML5 地理定位(Geolocation) 四. input 五. web存储 webSQL 六. 应用程序缓存 web workers 七. web socket 可以先看上篇HTML基础再来看…

RabbitMQ、Kafka对比(超详细),Kafka、RabbitMQ、RocketMQ的区别

文章目录 一、kafka和rabbitmq全面对比分析1.1 简介1.2 kafka和rabbitmq全面对比分析1.3 影响因素 二、RabbitMQ、Kafka主要区别2.1 详解/主要区别2.1.1 设计目标和适用场景2.1.2 架构模型方面2.1.3 吞吐量和性能2.1.4 消息存储和持久化2.1.5 消息传递保证2.1.6 集群负载均衡方…

理解二分搜索算法

一.介绍 在本文中,我们将了解二分搜索算法。二分搜索算法是一种在排序数组中查找特定元素的高效方法。它的工作原理是将搜索间隔反复分成两半,从而大大减少了找到所需元素所需的比较次数。该算法的时间复杂度为 O(log n),因此对于大型数据集…

CLOS架构

CLOS Networking CLOS Networking 是指使用 Clos 网络拓扑结构(Clos Network Topology)进行网络设计的一种方法。该方法是由贝尔实验室的工程师 Charles Clos 在1950年代提出的,以解决电路交换网络的可扩展性和性能问题。随着现代计算和网络…

SpringBoot基础(一):快速入门

SpringBoot基础系列文章 SpringBoot基础(一):快速入门 目录 一、SpringBoot简介二、快速入门三、SpringBoot核心组件1、parent1.1、spring-boot-starter-parent1.2、spring-boot-dependencies 2、starter2.1、spring-boot-starter-web2.2、spring-boot-starter2.3、…

YOLOv10改进 | 主干篇 | YOLOv10引入CVPR2023 顶会论文BiFormer用于主干修改

1. 使用之前用于注意力的BiFormer在这里用于主干修改。 YOLOv10改进 | 注意力篇 | YOLOv10引入BiFormer注意力机制 2. 核心代码 from collections import OrderedDict from functools import partial from typing import Optional, Union import torch import torch.nn as n…

C++:vector容器

概览 std::vector是C标准模板库(STL)中的一种动态数组容器。它提供了一种类似于数组的数据结构,但是具有动态大小和更安全的内存管理。 定义和基本特性 std::vector是C标准库中的一 个序列容器,它代表了能够动态改变大小的数组。与普通数组一样&#x…

酒店智能插座在酒店智慧化中的重要性

在当今数字化和智能化的时代,酒店行业也在不断追求创新和提升服务品质,以满足客人日益增长的需求。酒店智能插座作为酒店智慧化的重要组成部分,发挥着不可忽视的作用。 提升客人的便利性: 酒店智能插座能够为客人提供更加便捷的充…

使用 Java Swing 的 IMEI 验证器

一.介绍 本文档介绍如何使用 Java Swing 创建一个简单的 IMEI 验证器应用程序。 二.什么是 IMEI 号码 IMEI 代表国际移动设备识别码。IMEI 用于在移动设备连接到网络时对其进行识别。每个 GSM、CDMA 或卫星移动设备都有唯一的 IMEI 号码。此号码将印在设备电池组件内。用户可…

Flutter GPU 是什么?为什么它对 Flutter 有跨时代的意义?

Flutter 3.24 版本引入了 Flutter GPU 概念的新底层图形 API flutter_gpu ,还有 flutter_scene 的 3D 渲染支持库,它们目前都是预览阶段,只能在 main channel 上体验,并且依赖 Impeller 的实现。 Flutter GPU 是 Flutter 内置的底…

Python3 第六十六课 -- CGI编程

目录 一. 什么是 CGI 二. 网页浏览 三. CGI 架构图 四. Web服务器支持及配置 五. 第一个CGI程序 5.1. HTTP 头部 5.2. CGI 环境变量 六. GET和POST方法 6.1. 使用GET方法传输数据 6.1.1. 简单的url实例:GET方法 6.1.2. 简单的表单实例:GET方法…

暑期数据结构 空间复杂度

3.空间复杂度 空间复杂度也是一个数学表达式,是对一个算法在运行过程中临时占用存储空间大小的量度。 空间复杂度不是程序占用了多少bytes的空间,因为这个也没太大意义,所以空间复杂度算的是变量的个数。空间复杂度计算规则基本跟…

SAM2:在图像和视频中分割任何内容

SAM 2: Segment Anything in Images and Videos 一、关键信息 1. SAM 2概述: SAM 2 是一种基础模型,设计用于在图像和视频中实现可提示的视觉分割。该模型采用变压器架构和流式内存进行实时视频处理。它在原始的Segment Anything Model(SAM…

自用 K8S 资源对象清单 YAML 配置模板手册-1

Linux 常用资源对象清单配置速查手册-1 文章目录 1、Pod 容器集合2、Pod 的存储卷3、Pod 的容器探针4、ResourceQuota 全局资源配额管理5、PriorityClass 优先级类 管理多个资源对象清单文件常用方法: 使用 sed 流式编辑器批量修改脚本键值进行资源清单的创建&am…

【高中数学/函数/值域】求f(x)=(x^2+1)^0.5/(x-1) 的值域

【问题】 求f(x)(x^21)^0.5/(x-1) 的值域 【来源】 《高中数学解题思维策略》P3 例1-1 杨林军著 天津出版传媒集团出版 【解答】 表达式说明f(x)(x^21)^0.5/(x-1)f(x)((x^21)/(x-1)^2)^0.5准备采用配方法f(x)(12/(x-1)2/(x-1)^2)^0.5(1)式f(x)(2*(1/(x-1)1/2)^21/2)^0.5(2)…

Pytorch系列-张量的类型转换

🌈个人主页:羽晨同学 💫个人格言:“成为自己未来的主人~” 张量转换为NumPy数组 使用Tensor.numpy()函数可以将张量转换为ndarray数组 # 1.将张量转换为numpy数组 data_tensortorch.tensor([2,3,4]) # 使用张量对象中的numpy函数进行转…