Overcoming catastrophic forgetting in neural networks

news2024/12/25 9:23:36

目录

预备知识:

论文笔记

1. Introduction

2. Elastic weight consolidation

2.1 EWC allows continual learning in a supervised learning context

2.2 EWC allows continual learning in a reinforcement learning context

3. Conclusion


文章链接:https://arxiv.org/abs/1612.00796

预备知识:

有关Life-Long Learning:Lecture 14:Life-long Learning_zzz_qing的博客-CSDN博客

Life-Long Learning Goal: A model can beat all task!

Life-Long Learning有很多种不同的做法,EWC属于Regularization-based methods:

在李宏毅老师Life-Long Learning课程对应的HW14中,给出了EWC的简介和相关代码:

The ewc class applied EWC algorithm to calculate the regularization term. Here we will focus on the algorithm of EWC.

In this assignment, we want to let our model learn 10 tasks successively. Here we show a simple example that lets the model learn 2 tasks(task A and task B) successively.

In the EWC algorithm, the definition of the loss function is shown below:

Assume we have a neural network with more than two parameters.

下面是EWC的部分代码,由代码可知道EWC需要使用label数据:

# Simply use groud truth label of dataset.  
label = data[1].to(self.device)
          
# generate Fisher(F) matrix for EWC     
loss = F.nll_loss(F.log_softmax(output, dim=1), label)
loss.backward()   

for n, p in self.model.named_parameters():
  # get the gradient of each parameter and square it, then average it in all validation set.                          
  precision_matrices[n].data += p.grad.data ** 2 / number_data
def penalty(self, model: nn.Module):
  loss = 0
  for n, p in model.named_parameters():
    # generate the final regularization term by the ewc weight (self._precision_matrices[n]) and the square of weight difference ((p - self.p_old[n]) ** 2).  
    _loss = self._precision_matrices[n] * (p - self.p_old[n]) ** 2
    loss += _loss.sum()
  return loss

论文笔记

本文通过选择性地放慢对那些任务重要的权重的学习来记住旧任务。

本文通过解决一组基于 MNIST 手写数字数据集的分类任务并通过顺序学习多个 Atari 2600 游戏来证明我们的方法是可扩展且有效的。

1. Introduction

intelligent agents必须展示持续学习的能力:即学习连续任务而不忘记如何执行先前训练的任务的能力。

关于灾难性遗忘相关知识点在life-long learning笔记中有,这里不赘述。

在life-long learning笔记中有提到同时训练多个任务不会发生灾难性遗忘,本文给出了解释:在实现continual learning方面,当前的方法通常确保来自所有任务的数据在训练期间同时可用。通过在学习期间交错来自多个任务的数据,不会发生遗忘,因为网络的权重可以针对所有任务的性能进行联合优化。在这种机制中——通常被称为多任务学习范式——深度学习技术已被用于训练能够成功玩多个 Atari 游戏的单个智能体。(在这篇文献提出的时间点,按顺序学习任务就是把所有任务的数据都记下来,学习新任务的时候再把前面的任务再学一遍,需要存储很多数据,训练起来也很慢)

本研究为人工神经网络开发了一种弹性权重整合(简称 EWC)算法。该算法根据某些权重对先前看到的任务的重要性来减慢对某些权重的学习。我们展示了 EWC 如何用于监督学习和强化学习问题以按顺序训练多个任务而不忘记旧任务。

2. Elastic weight consolidation

在本节中,我们将解释为什么我们希望在旧任务的邻域中找到新任务的解决方案,我们如何实现约束,最后我们如何确定哪些参数是重要的。

深度神经网络由多层linear projection followed by element-wise non-linearities组成。学习任务包括调整linear projection的一组权重和偏差 θ,以优化性能。 θ 的许多配置将导致相同的性能,这与 EWC 相关:过度参数化使得任务 B 的解决方案 θ* B 很可能与之前找到的任务 A 的解决方案 θ* A 很接近。因此,在学习任务 B 时,EWC 保护了通过将参数限制在以 θ* A 为中心的任务 A 的低误差区域中,可以提高任务 A 的性能,如图 1 中的示意图所示:(这就是李宏毅老师在Regularization-based的方法中介绍的思想)

为了证明这种约束选择的合理性并定义哪些权重对任务最重要,从概率的角度考虑神经网络训练是很有用的。从这个角度来看,优化参数无异于在给定一些数据 D 的情况下找到它们最可能的值。我们可以根据参数 p(θ) 的先验概率和数据 p(D|θ) 使用贝叶斯规则:

关于任务 A 的所有信息都必须被吸收到后验分布 p(θ|DA) 中。这个后验概率必须包含关于哪些参数对任务 A 很重要的信息,因此是实施 EWC 的关键。

真正的后验概率是难以处理的,因此将后验近似为高斯分布,其均值由参数 θ∗ A 给出,对角线精度由 Fisher 信息矩阵 F 的对角线给出。F 具有三个关键属性:(a) 它相当于接近最小值的损失的二阶导数,(b) 它可以单独从一阶导数计算,因此即使计算也很容易对于大型模型,(c) 它保证是半正定的。

这种方法类似于期望传播,其中每个子任务都被视为后验的一个因素。给定这个近似值,我们在 EWC 中最小化的函数 L 是:

LB(θ) 仅是任务 B 的损失,λ 设置旧任务与新任务相比的重要性,i 标记每个参数。

2.1 EWC allows continual learning in a supervised learning context

如图 A 所示,使用普通随机梯度下降 (SGD) 对这一系列任务进行训练会导致灾难性遗忘。

这个问题不能通过对每个权重使用固定的二次约束(绿色曲线,L2 正则化)对网络进行正则化来解决。对每个权重使用固定的二次约束会导致任务 A 中的性能下降要轻得多,但任务 B 无法正确学习,因为约束保护了所有权重,导致在 B 上留下很少的剩余容量来学习。但是,当我们使用 EWC 时,考虑到了每个权重对任务 A 的不同的重要性,网络可以很好地学习任务 B 而不会忘记任务 A(红色曲线)。

将传统的 dropout 正则化与 EWC 进行比较,单独使用 dropout 正则化的随机梯度下降是有限的,并且它不能扩展到更多任务(图 B)。相比之下,EWC 允许按顺序学习大量任务,错误率只会适度增长:

EWC 允许网络有效地将更多功能压缩到具有固定容量的网络中。下面评估它是否为每个任务分配完全独立的网络部分,或者是否通过共享表示以更有效的方式使用容量(通过测量任务对各自的 Fisher 信息矩阵之间的重叠来确定每个任务是否依赖于相同的权重集,小的重叠意味着这两个任务依赖于不同的权重集,大的重叠表明两个任务都使用了权重)。

图 C 显示了作为深度函数的重叠。作为一个简单的控制,当一个网络在两个非常相似的任务上进行训练时,这些任务依赖于整个网络中相似的权重集(灰色曲线) .当这两个任务彼此更加不同时,网络开始为这两个任务(黑线)分配单独的容量(即权重)。然而,即使对于大排列,更靠近输出的网络层确实被重新用于两个任务。这反映了这样一个事实,即排列使输入域非常不同,但输出域(即类标签)是共享的。

2.2 EWC allows continual learning in a reinforcement learning context

为了在RL中应用 EWC,在每个任务切换时计算 Fisher 信息矩阵。对于每个任务,都会添加一个惩罚,其中锚点由参数的当前值给出,权重由 Fisher 信息矩阵乘以通过超参数搜索优化的缩放因子 λ 给出。我们只对经历了至少 2000 万帧的游戏添加了 EWC 惩罚。

我们还允许 DQN 代理为每个推断的任务维护单独的短期记忆缓冲区:这些允许使用经验重放机制在策略外学习每个任务的动作值。因此,整个系统在两个时间尺度上有记忆:在短时间尺度上,经验回放机制允许 DQN 中的学习基于交错和不相关的经验。在更长的时间范围内,跨任务的专业知识通过使用 EWC 得到整合。最后,我们允许少量网络参数特定于游戏,而不是跨游戏共享。特别是,我们允许网络的每一层都有特定于每个游戏的偏差和每个元素的乘法增益。

实验结果表明,通过使用 EWC,agent确实学会了玩多个游戏:

3. Conclusion

本文提出了一种新算法,即弹性权重合并,它解决了神经网络持续学习带来的重大问题。 EWC 允许在新学习过程中保护先前任务的知识,从而避免灾难性地遗忘旧能力。

我们将 EWC 实现为一个软的二次约束,其中每个权重被拉回其旧值,其数量与其对先前学习任务的性能的重要性成正比。就任务共享结构而言,使用 EWC 训练的网络会重用网络的共享组件。

EWC 算法可以基于贝叶斯学习方法。形式上,当有新任务要学习时,网络参数由先验调节,先验是给定先前任务数据的参数的后验分布。这使得对先前任务约束不佳的参数的快速学习率成为可能,而对于那些至关重要的参数则可以降低学习率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/508558.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

100ASK_全志V853-PRO开发板支持人形检测和人脸识别

1.前言 V853 芯片内置一颗 NPU核,其处理性能为最大 1 TOPS 并有 128KB 内部高速缓存用于高速数据交换,支持 OpenCL、OpenVX、android NN 与 ONNX 的 API 调用,同时也支持导入大量常用的深度学习模型。本章提供一个例程,展示如何使…

JavaScript基础之数值计算

常见的几种场景 场景一:进行浮点值运算结果的判断 常见错误写法:floatNum1 floatNum2 res 我们在Chrome里测试一下 0.1 0.2 0.3,得出的结果是false,而不是预期结果true,因为 0.1 0.2 0.30000000000000004 场…

【Vue-Treeselect 和 vue3-treeselect】树形下拉框

Vue-Treeselect Vue2树形下拉框 链接 文档:Vue-Treeselect 实现 第一步:安装 npm install --save riophae/vue-treeselect 第二步:实现 import Treeselect from riophae/vue-treeselect import riophae/vue-treeselect/dist/vue-treeselect.css属性…

python数据类型总结

标准数据类型 Python 有以下几种标准数据类型: 整数(int):表示整数值,如 1, -5, 0 等。浮点数(float):表示小数值,如 3.14, -0.01, 1.0 等。字符串(str&…

AI智能音箱高性价比出好音质的功放芯片

近几年人工智能等技术的不断发展,AI智能音箱已成为炙手可热的爆款;众多企业纷纷加入其中;如我们熟知的天猫精灵、小爱同学、小度智能音箱、华为AI音箱、腾讯叮当等等智能音箱;据不完全统计,目前国内做智能音箱的企业已…

SpringBoot——创建一个SpringBoot工程

简单介绍: 在之前我们学习JavaEE的时候,是直接使用Spring进行操作,以比较原始的方式进行了SSM的整合,这次我们就来学习一个强大的框架——SpringBoot,这个框架是用来简化Spring应用的初始化创建过程,以及开…

APP外包项目的代码规范

APP项目在工作中使用越来越多,2C的APP项目基本饱和,2B的AP项目P还有很大的发展空间。越来越多的企业希望通过APP来提升工作效率或加强和客户的沟通,但这些企业大多数不是专业的软件公司,开发软件时需要找软件外包开发公司&#xf…

C/C++每日一练(20230510) 编辑距离、多数元素、数列累和

目录 1. 编辑距离 🌟🌟🌟 2. 多数元素 🌟 3. 求分数数列的前N项和 ※ 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏 Java每日一练 专栏 1. 编辑距离 给你…

OpenCV教程——Mat对象

1.Mat对象和IplIamge对象 Mat对象是OpenCV2.0之后引进的图像数据结构、自动分配内存、不存在内存泄漏的问题,是面向对象的数据结构。分为两个部分:头部和数据部分。IplIamge是从2001年OpenCV发布之后就一直存在,是C语言风格的数据结构&#…

笔记本电脑没有声音怎么办?5个必会方法分享

案例:笔记本电脑没有声音怎么办? 【我的笔记本电脑为什么会没有声音呢?看视频听音乐一点声音都没有,实在太烦人了!应该怎么解决呢?】 笔记本电脑逐渐成为人们工作生活必备的工具。如果笔记本电脑没有声音…

怎样检测和维护LED显示屏系统

检测和维护LED显示屏系统是确保其正常运行和延长寿命的重要步骤。以下是一些常见的检测和维护LED显示屏系统的方法: 视觉检查:定期进行视觉检查以确保LED显示屏没有明显的损坏或故障。检查显示屏表面是否有损坏、裂纹或漏光等情况。如果发现任何问题&…

ChatGPT作者John Schulman:通往TruthGPT之路

OneFlow编译 翻译|贾川、徐佳渝、杨婷 大型语言模型(LLM)有一个众所周知的“硬伤”——它们经常会一本正经编造貌似真实的内容。 OpenAI团队希望通过改进强化学习反馈步骤“原生地”阻止神经网络产生幻觉,OpenAI首席科学家Ilya …

spark-sql 报错:Exception thrown flushing changes to datastore

报错背景 hive创建数据库时添加中文备注信息报错。 命令:CREATE DATABASE IF NOT EXISTS hive_ods_db COMMENT Hive ODS层数据库; 报错现象 FAILED: Execution Error, return code 1 from org.apache.hadoop.hive.ql.exec.DDLTask. MetaException(message:Excep…

递归行为与归并排序

master公式 T(N)a*T(N/b)O(N^d) T(N):问题的规模是N个数据 N/b:子过程的规模 a:调用的次数 O(N^d) :除子问题的调用之外,剩余的代码的时间复杂度 使用条件:满足子问题等规模的递归 arr[L,R]范围…

49天精通Java,第27天,队列、双端队列、优先队列

目录 一、队列与双端队列二、Queue和Deque三、api对比1、add和offer区别2、remove和poll3、element和peek 四、优先队列1、PriorityQueue常用方法2、ArrayDeque常用方法 大家好,我是哪吒。 一、队列与双端队列 双端队列是一种特殊的队列,它的两端都可以…

吴军《计算之魂》读后感

前言 断断续续,终于完成了这本书的第一次通读,记录下自己的一些想法。 先说一个小故事。前段时间家里买了一个小鱼缸,问我有没有办法让水自动循环,但不想用电。没有好的想法,去小某书上搜了下,好多案例教…

【哈士奇赠书活动 - 18期】-〖Flask Web全栈开发实战〗

文章目录 ⭐️ 赠书活动 - 《Flask Web全栈开发实战》⭐️ 编辑推荐⭐️ 内容提要⭐️ 赠书活动 → 获奖名单 ⭐️ 赠书活动 - 《Flask Web全栈开发实战》 内容简介: 《Flask Web全栈开发实战》围绕 Flask 框架,详细地讲解了使用 Flask 开发网站的各项技…

vue2项目中使用本机图片的一些问题

前后端分离项目,(vue2springboot2.6.3) 前端上传图片,后端将图片保存在本地。当前端使用上传的图片时出现的一些问题说明。 前端上传图片文件,后端接收到图片文件后将图片保存到vue项目的src/assets/club目录下,如下…

VMware vSphere Replication 8.7 (for vSphere 8.0U1) - 虚拟机复制和数据保护

请访问原文链接:https://sysin.org/blog/vmware-vsphere-replication-8/,查看最新版。原创作品,转载请保留出处。 作者主页:sysin.org 新增功能 vSphere Replication 8.7 | 2023 年 4 月 18 日 | 内部版本 21591677 vSphere Re…

GPT4.0+Midjourney=最佳组合,简单玩法

以下是我设计的一个简单的组合玩法,操作如下: 问gpt4: Here is a MidJourney Prompt Formula: A detailed image of [Subject] [doing something interesting] during [time of day], taken with a [type of camera], using [type of lens] with cinema…