【博士每天一篇文献-算法】Overcoming catastrophic forgetting in neural networks

news2024/11/15 14:01:27

阅读时间:2023-10-24

1 介绍

年份:2016
作者:James Kirkpatrick, Razvan Pascanu, Neil Rabinowitz, Joel Veness, Guillaume Desjardins, Andrei A. Rusu, Kieran Milan, John Quan, Tiago Ramalho, Agnieszka Grabska-Barwinska, Demis Hassabis, Claudia Clopath, Dharshan Kumaran, Raia Hadsell,加州史丹佛大學史丹佛大學
期刊:Proceedings of the national academy of sciences
引用量:5449
这篇论文的主题是关于神经网络如何克服灾难性遗忘的问题,灾难性遗忘是神经网络在顺序学习任务时的一个限制。论文提出了一种称为弹性权重合并(EWC)的方法,可以使神经网络在学习新任务的同时记住旧任务。EWC会有选择地降低对先前学习任务重要的权重的学习速度,从而防止灾难性遗忘。作者通过在MNIST数据集上解决分类任务和顺序学习Atari 2600游戏的实验来证明EWC的有效性。论文将EWC与其他方法如L2正则化和dropout正则化进行了比较,结果表明EWC在保持旧任务高性能的同时学习新任务方面优于这些方法。论文解释了EWC的实现和合理性,包括如何约束重要参数和确定哪些权重对于每个任务是重要的。论文还讨论了哺乳动物大脑可能支持无灾难遗忘连续学习的神经机制。总的来说,这篇论文通过使用EWC提出了解决神经网络灾难性遗忘问题的方法。

2 创新点

  1. EWC方法:论文提出了一种名为弹性权重整合的算法,用于实现神经网络的连续学习。该算法根据先前学习任务中权重的重要性,减缓学习过程,从而保留旧任务的知识。
  2. 在MNIST数据集和Atari 2600游戏中的应用:论文通过在MNIST数据集上进行分类任务和在Atari 2600游戏中进行学习来展示EWC的有效性。结果表明,相比于L2正则化和dropout正则化等其他方法,EWC在学习新任务的同时能够维持旧任务的高性能。
  3. EWC的实施和正当性:论文解释了EWC的具体实施和合理性,包括对重要参数的约束和确定每个任务中哪些权重是重要的。论文还提到了哺乳动物大脑中支持连续学习而不发生灾难性遗忘的神经机制。

3 算法

(1)计算步骤

  • 计算每个权重在先前任务中的重要性:
    • 先前任务的损失函数:L_prev(θ),其中θ表示网络的权重。
    • Fisher信息矩阵:F_prev(θ) = E[∇²L_prev(θ)],其中∇²表示梯度的二阶导数。
    • 权重重要性:I_prev(θ) = F_prev(θ) * (θ - θ_prev)²,其中θ_prev表示在先前任务上训练后的权重。
  • 计算当前任务的损失函数:当前任务的损失函数:L_curr(θ)。
  • 计算正则化项并更新网络权重:
    • 正则化项:EWC_loss(θ) = L_curr(θ) + λ * Σ[I_prev(θ)], 其中λ是正则化项的权重,Σ表示对所有权重求和。
    • 更新网络权重:θ_new = argmin(θ)[EWC_loss(θ)]

(2)推理过程
训练了一个模型,其参数为 θ \theta θ,定义最小化以下损失函数来完成此操作:
L ( θ ) = L n e w ( θ ) + ∑ i = 1 n λ 2 F i ( θ i − θ i ∗ ) 2 \mathcal{L}(\theta) = \mathcal{L}_{new}(\theta) + \sum_{i=1}^{n} \frac{\lambda}{2} F_i (\theta_i - \theta_i^*)^2 L(θ)=Lnew(θ)+i=1n2λFi(θiθi)2

其中, L n e w ( θ ) \mathcal{L}_{new}(\theta) Lnew(θ) 是新任务的损失函数,n 是先前任务的数量, F i F_i Fi 是 Fisher 信息矩阵的对角线元素, θ i ∗ \theta_i^* θi 是在先前任务i中找到的最优参数。 λ \lambda λ 是一个超参数,控制先前任务对新任务的影响。
Fisher 信息矩阵是 Hessian 矩阵的期望值,它衡量了损失函数对参数的二阶导数。 在 EWC 中,只计算对角线元素,因为它们提供了最大的信息,同时也更容易计算。Fisher 信息矩阵的对角线元素可以通过以下公式计算:
F i , j = E x ∼ D i [ ∂ log ⁡ p ( y ∣ x , θ ) ∂ θ i ∂ log ⁡ p ( y ∣ x , θ ) ∂ θ j ] F_{i,j} = \mathbb{E}_{x\sim D_i}[\frac{\partial \log p(y|x,\theta)}{\partial \theta_i} \frac{\partial \log p(y|x,\theta)}{\partial \theta_j}] Fi,j=ExDi[θilogp(yx,θ)θjlogp(yx,θ)]

其中, D i D_i Di是先前任务i的数据分布, p ( y ∣ x , θ ) p(y|x,\theta) p(yx,θ)是模型在给定输入x 和参数 θ \theta θ的情况下预测输出y的概率分布。
在每次学习新任务之前,需要计算 Fisher 信息矩阵和最优参数 θ i ∗ \theta_i^* θi。这可以通过在先前任务上运行梯度下降来实现,直到收敛为止。一旦计算出 Fisher 信息矩阵和最优参数,就可以使用 EWC 来学习新任务,同时保留先前任务的知识。
最后,可以使用以下公式计算 EWC 梯度:
g i = ∇ θ i L n e w ( θ ) + λ ∑ j = 1 n F i , j ( θ i − θ i ∗ ) g_i = \nabla_{\theta_i} \mathcal{L}_{new}(\theta) + \lambda \sum_{j=1}^{n} F_{i,j} (\theta_i - \theta_i^*) gi=θiLnew(θ)+λj=1nFi,j(θiθi)
其中, g i g_i gi是 EWC 梯度, ∇ θ i L n e w ( θ ) \nabla_{\theta_i} \mathcal{L}_{new}(\theta) θiLnew(θ)是新任务的梯度。通过添加正则化项,EWC 可以确保新任务不会完全覆盖先前任务的知识,从而在连续学习中实现知识共享。

5 实验结果分析

(1)总结一
image.png

  • 使用纯随机梯度下降(SGD)训练这个任务序列会引发灾难性遗忘。
  • 图2A展示了两个不同任务的测试集性能。在训练从第一个任务切换到第二个任务时,任务B的性能迅速下降,而任务A的性能迅速上升。
  • 任务A的遗忘问题会随着更长的训练时间而进一步恶化。
  • 使用L2正则化不能解决这个问题,因为它对所有权重施加了相同的保护限制,导致在任务B上学习的能力受到限制。
  • 然而,使用EWC可以根据任务A中每个权重的重要性,使网络能够在不遗忘任务A的情况下很好地学习任务B。
  • 图2B展示了使用EWC和使用SGD与dropout正则化的所有任务的平均性能。可以看到EWC在旧任务上保持了高性能,并且仍然能够学习新任务。
  • 图2C展示了两个不同置换程度下网络深度的Fisher信息矩阵的相似性。任务越不相似,早期层的Fisher信息矩阵重叠越小。

(2)总结二

  • 当网络在两个非常相似的任务上训练(两个MNIST版本,只有少数像素被重排),这两个任务在整个网络中依赖于相似的权重集
  • 当两个任务之间更不相似时,网络开始为两个任务分配单独的能力(即权重)。

在进行大量重排时,网络靠近输出的层确实被两个任务重复使用。这反映了重排使得输入对内容是非常不同的,但输出的内容(即类别标签)是共享的。
(3)总结三
EWC可以在要求更高的强化学习(RL)领域中支持连续学习。作者测试了在经典的Atari 2600游戏集上,将Deep Q Networks与EWC相结合的方法。实验中,通过使用EWC,能够学习多个游戏,而不会忘记以前学习的游戏 。与以前的RL方法相比,EWC利用了固定资源(即网络容量)的单个网络,并且计算开销较小。

6 代码

https://github.com/yashkant/Elastic-Weight-Consolidation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1134069.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

js实现在报表参数界面获取body中控件的值

要在报表参数界面获取body中控件的值,你可以使用JavaScript来实现。下面是一个详细的介绍: 1. DOM(文档对象模型): - DOM是用于操作HTML文档的API,它允许你通过JavaScript访问和操作文档中的元素。 - 在报…

python由0到1的基础第一篇(基础语法、变量类型、运算符)

文章目录 前言编程语言是什么?编译型语言和解释型语言的区别编译型语言解释型语言编译型语言和解释型语言的差异总结 一、Python是什么?Python简介1.1Python是什么?1.2Python简介1.2.1Python优点1.2.2Python的缺点 二、Python能干什么&#x…

iview form 动态表单

最开始用得网上得 <FormItemlabel"采购方开户行":rules"baseForm.receiptType 12? baseInfoRule.procureBank: [{ required: false }]"><Inputv-model"baseForm.procureBank"placeholder"请输入采购方开户行"style"w…

RISC-V架构——中断处理和中断控制器介绍

1、ARM架构中断机制介绍 本文不是从零开始讲解中断&#xff0c;对于中断的基本知识不再赘述&#xff0c;对中断不是很了解可以先学习ARM中断的文章。参考博客&#xff1a;《ARM架构的外部中断介绍(S5PV210芯片)》&#xff1b; 2、RIAC_V架构的中断控制器架构 &#xff08;1&…

如何运用设计模式中的享元模式

文章目录 &#x1f31f; 如何将设计模式中的享元模式运用到生活当中&#x1f34a; 什么是享元模式&#x1f34a; 生活中的应用&#x1f389; 衣物&#x1f389; 图书馆 &#x1f34a; 总结 &#x1f4d5;我是廖志伟&#xff0c;一名Java开发工程师、Java领域优质创作者、CSDN博…

WORD中的表格内容回车行距过大无法调整行距

word插入表格&#xff0c;编辑内容&#xff0c;换行遇到如下问题&#xff1a; 回车后行距过大&#xff0c;无法调整行距。 解决方法&#xff08;并行&#xff09;&#xff1a; 方法1&#xff1a;选中要调整的内容&#xff0c;菜单路径&#xff1a;“编辑-清除-格式” 方法2&am…

Unity3D 基础——WASD控制物体移动

using System.Collections; using System.Collections.Generic; using UnityEngine;public class MotionControl : MonoBehaviour {public float speed 3f; //定义一个速度// Start is called before the first frame updatevoid Start(){}// Update is called once per fram…

一文读懂MT4:从小白到专家,MT4教程全解析!

亲爱的读者&#xff0c;欢迎来到这篇全面解析MT4交易平台的文章。无论你是刚刚接触金融交易&#xff0c;还是已经有一定经验的投资者&#xff0c;这篇文章都将为你提供深入浅出的MT4使用指南。通过阅读本文&#xff0c;你将能够全面了解并掌握MT4交易平台的使用技巧和操作方法。…

JAVA 链式编程和建造者模式的使用(lombok的使用)

0.说明 0.1 链式编程 链式编程的原理是返回一个this对象&#xff0c;也就是返回对象本身&#xff0c;从而达到链式效果。这样可以减少一些代码量&#xff0c;是java8新增的内容。 此处主要介绍在新建对象使用链式编程更加方便的创建对象。链式编程的一些常见用法可以看这个&a…

使用scapy 分析报文

比wireshark 更 happy udp 就简易的多&#xff0c;tcp 可能在设置bpf 时 多加几个条件 由于协议分析是手写的,所以可以对数据包的交互记录到excel 中再次进行分析

2024年天津中德应用技术大学专升本物流管理专业课考试大纲

天津中德应用技术大学物流管理专业&#xff08;高职升本科&#xff09;2024年专业基础考试大纲 一、试卷类型 物流管理专业升本专业课考试共1套试卷&#xff0c;总分200分&#xff0c;考试时间为2小时。内容包含仓储与配送管理40%、物流基础30%&#xff0c;运输管理30%&#…

双十一期间高预算广告增加,开发者如何精细化运营才能达到抢量增收目标?

随着双十一时间节点的临近&#xff0c;“双十一”大促也迎来了推广高峰&#xff0c;通常&#xff0c;大家总是认为推广高峰就是媒体收益高峰&#xff0c;但很多在变现的开发者都深有体会的是&#xff0c;广告主的投放高峰并不意味着收益高峰&#xff0c;很多开发者总结以往经验…

博通BCM575系列RDMA网卡驱动bnxt_re分析(一)

简介 整个BCM系列驱动分成以太网部分(bnxt_en.ko)和RDMA部分(bnxt_re.ko), 两个模块之间通过内核的auxiliary_bus进行管理.我们主要分析下bnxt_re驱动. 代码结构 这个驱动的核心是 qplib_fp.c, 这个文件主要包含了驱动的数据路径, 包括Post Send, Post Recv, Poll CQ流程的实…

项目管理-2023西电网课课后习题答案-第四章

文章目录 第四章答案1-1011-20 [✅] 第一章答案[✅] 第二章答案[✅] 第三章答案[✅] 第四章答案[✅] 第五章答案 第四章答案 1-10 11-20

在 history 模式下,为什么刷新页面会出现404?

1、原因 因为浏览器在刷新页面时&#xff0c;它会向服务器发送 GET 请求&#xff0c;但此时服务器并没有配置相应的资源来匹配这个请求&#xff0c;因此返回 404 错误。 2、解决方案 为了解决这个问题&#xff0c;我们需要在服务器端进行相关配置&#xff0c;让所有的路由都指…

质子 8.0-4 发布,支持更多 Linux 上的 Windows 游戏

导读Valve 近日发布了 Proton 8.0-4&#xff0c;这是 Steam Play 基于 Wine 和其他组件的开源兼容工具的最新版本&#xff0c;可让 Linux 用户玩 Windows 游戏。 在 Proton 8.0-3 发布两个半月后&#xff0c;Proton 8.0-4 正式支持更多 Windows 游戏在 Linux 上运行&#xff0c…

怎么禁止U盘拷贝电脑资料

怎么禁止U盘拷贝电脑资料 现如今U盘已经成为了人们日常传输文件的主要方式之一&#xff0c;U盘在给我们提供便利的同时&#xff0c;也带来了一些安全隐患&#xff0c;比如U盘可以轻松地复制电脑文件&#xff0c;这可能会导致机密信息泄露。因此&#xff0c;本文将介绍一些方法…

2023-10学习笔记

1.sql注入 不管是上一篇博客&#xff0c;通过java代码执行sql 还是我们常用的Mybatis的#{}和${} 都会提到sql注入的问题 1.1啥是sql注入 应该知道是说传入无关的参数&#xff0c;比如本来是想要一个where条件查询参数 但是你拼了一个drop 比如 原来的sql select * from…

经典卷积神经网络 - ResNet

ResNet是一种残差网络&#xff0c;咱们可以把它理解为一个子网络&#xff0c;这个子网络经过堆叠可以构成一个很深的网络。 我们一直在加深神经网络&#xff0c;但是加深不一定只会带来好处。 残差块 串联一个层改变函数类&#xff0c;我们希望能扩大函数类残差块加入快速通…

Unity的碰撞检测(三)

温馨提示&#xff1a;本文基于前一篇“Unity的碰撞检测(二)”继续探讨两个游戏对象具备刚体的碰撞检测&#xff0c;阅读本文则默认已阅读前文。 &#xff08;一&#xff09;测试说明 在基于两个游戏对象都具备碰撞器和刚体且属性一致的条件下&#xff0c;若二者刚体的BodyType…