持续学习中避免灾难性遗忘的Elastic Weight Consolidation Loss数学原理及代码实现

news2024/9/24 21:23:41

训练人工神经网络最重要的挑战之一是灾难性遗忘。神经网络的灾难性遗忘(catastrophic forgetting)是指在神经网络学习新任务时,可能会忘记之前学习的任务。这种现象特别常见于传统的反向传播算法和深度学习模型中。主要原因是网络在学习新数据时,会调整权重以适应新任务,这可能会导致之前学到的知识被覆盖或忘记,尤其是当新任务与旧任务有重叠时。

在本文中,我们将探讨一种方法来解决这个问题,称为Elastic Weight Consolidation。EWC提供了一种很有前途的方法来减轻灾难性遗忘,使神经网络在获得新技能的同时保留先前学习任务的知识。

在任务a和任务B的灰色和黄色区域中,存在许多具有期望的低误差的最优参数配置。假设我们为任务A找到了一个这样的配置θꭺ*,当继续从这样的配置训练模型到新的任务B时,会出现三种不同的场景:

蓝色箭头:简单地继续在任务B上进行训练而不受惩罚,将在任务B的低水平区域结束,但在任务A上的表现低于预期的准确性。

绿色箭头:使用任务A的权重的L2约束可能太强,使得模型在任务A上表现良好,但在任务B上表现不佳。

红色箭头:这是EWC是提出的解决方案,它将在模型在两个任务上都表现良好的区域(两个区域之间的交叉点)中找到参数。

下面我们将解释这是如何完成的。

费雪信息矩阵(FIM)

EWC方法所基于的FIM(Fisher Information Matrix)。FIM是一种统计度量,用于量化给定数据提供的关于我们要估计的未知参数θ的信息量。在持续学习的背景下,FIM将有助于识别神经网络参数,这些参数从以前的任务中获取的数据信息较少。通过更新这些参数,网络可以学习新的任务,而不会删除存储在参数中的重要信息,这些信息是关于先前学习任务的非常有用的信息。

假设X是一个随机变量,其概率密度函数f(X |θ)参数化为θ。样本x的似然函数(仅在数据固定的情况下为参数函数)为:

和对数拟然:

将FIM定义为:

这表明对数似然函数对参数的微小变化有多敏感。我们可以将FIM视为似然函数二阶导数的负期望:

当求二阶导数时,基本上是在看似然函数的曲率。

可以考虑下面的两个绘制的似然函数的图表。蓝色曲线表示在峰值附近非常窄的分布,表明数据更有可能在θ附近,并且随着远离θ而迅速减少。相反,黑色曲线代表一个更广泛的分布,即使远离θ,数据也保持相似的可能性。

FIM量化了这个概念——数据是多么严格地限制在某个θ值上。较大的FIM(如蓝色曲线所示)意味着参数值的微小变化将导致数据在这些参数下的可能性显著下降。相反,较小的FIM(如黑色曲线所示)意味着参数值的较小变化将导致可能性的较小降低。

事实证明,费雪信息矩阵与数据的方差(或多变量情况下的协方差)成反比。在上面的图表中,如果假设曲线分别代表均值θ 0和方差σ²ᵦₗᵤₑ和σ²ᵦₗ꜀ₖ的两个高斯分布,其中σ²ᵦₗᵤₑ< σ²ᵦₗ꜀ₖFIM等于1/σ²,因此蓝色曲线包含更多信息。

弹性重量固结

给定数据D和一个参数为θ的神经网络,我们的目标是在给定数据p(θ|D)的情况下最大化参数的概率。根据贝叶斯规则,我们得到:

弹性权重保持

弹性权重保持(Elastic Weight Consolidation,EWC)是一种用于减轻神经网络灾难性遗忘问题的方法。它的基本思想是在学习新任务时保护先前任务的关键权重。

给定数据D和一个参数为θ的神经网络,我们的目标是在给定数据p(θ|D)的情况下最大化参数的概率。根据贝叶斯规则,我们得到:

对两边应用对数并不改变最大化的目标,因为对数是一个单调变换。因此目标变成:

假设两个独立任务D = {A, B},我们有:

最后一个是独立于A和B的。这里log(p(B|θ))是任务B的损失,log(p(B))是B的可能性,它可以作为优化的常数,因为它不依赖于θ, log(p(θ| a))是任务a的后验分布,它包含了任务a重要参数的所有信息。

估计log(p(θ|A))比较复杂的,因为计算它将涉及在整个参数空间上对高维函数进行积分。但是它近似为正态分布,其均值为任务a - θꭺ的最优参数,方差为费雪信息矩阵。这种近似是有意义的,因为我们可以假设A和B任务的新参数θ与任务A的最优参数相差不远。在所有θꭺ的参数中,会有一些参数对任务A的良好表现更重要,并且不希望它们改变太多,这就是FIM的作用,FIM的值表明在这种情况下,改变某个参数将如何影响任务A的损失。因此,FIM中值越高的参数变化受到的惩罚越大。

现在,我们对任务A的最优权值进行泰勒展开直到第二项

其中log(p(θꭺ*|A))是一个常数,我们可以在优化中忽略它。我们也可以忽略第二项,因为在最优θꭺ*处,梯度为零。这样就找到了log(p(θ|A))的表达式,把它代回到图8的原始公式中:

第二项的二阶导数为Hessian,可以根据图5的定义用费雪信息矩阵近似。log(p(B|θ))是新任务B的损失,例如交叉熵,我们记为Lᵦ(θ)

我们不需要进行二阶导数,只需根据图4中等价于图5的定义,即对数似然梯度的外积,用一阶导数近似FIM即可:

这样优化L(θ)的总损失为:

λ是一个超参数,表示在前一个任务a上保持精度的重要性。

上面涉及梯度向量的外积的定义捕获了梯度的协方差结构。而FIM的对角线近似通常由梯度的平方给出,它只计算参数的方差,但计算成本较低,足以完成任务:

Pytorch实现

上面我们介绍了弹性权重保持的数学原理,下面我们来看看Pytorch的代码实现

让我们首先导入一些库以及分别代表任务A和任务B的MNIST和Fashion MNIST数据集。我们还定义了一个简单的神经网络:

 importtorch
 importtorch.nnasnn
 importtorch.nn.functionalasF
 importtorch.optimasoptim
 fromtorchimportautograd
 importnumpyasnp
 fromtorch.utils.dataimportDataLoader
 
 fromtorch.utils.dataimportDataset, DataLoader
 fromtorchvisionimportdatasets, transforms
 fromtqdmimporttqdm
 
 defget_accuracy(model, dataloader):
     model=model.eval()
     acc=0
     forinput, targetindataloader:
         o=model(input.to(device))
         acc+= (o.argmax(dim=1).long() ==target.to(device)).float().mean()
     returnacc/len(dataloader)
 
 classLinearLayer(nn.Module):
     # from https://github.com/shivamsaboo17/Overcoming-Catastrophic-forgetting-in-Neural-Networks/blob/master/elastic_weight_consolidation.py
     def__init__(self, input_dim, output_dim, act='relu', use_bn=False):
         super(LinearLayer, self).__init__()
         self.use_bn=use_bn
         self.lin=nn.Linear(input_dim, output_dim)
         self.act=nn.ReLU() ifact=='relu'elseact
         ifuse_bn:
             self.bn=nn.BatchNorm1d(output_dim)
     defforward(self, x):
         ifself.use_bn:
             returnself.bn(self.act(self.lin(x)))
         returnself.act(self.lin(x))
 
 classFlatten(nn.Module):
     
     defforward(self, x):
         returnx.view(x.shape[0], -1)
     
 classModel(nn.Module):
     
     def__init__(self, num_inputs, num_hidden, num_outputs):
         super(Model, self).__init__()
         self.f1=Flatten()
         self.lin1=LinearLayer(num_inputs, num_hidden, use_bn=True)
         self.lin2=LinearLayer(num_hidden, num_hidden, use_bn=True)
         self.lin3=nn.Linear(num_hidden, num_outputs)
         
     defforward(self, x):
         returnself.lin3(self.lin2(self.lin1(self.f1(x))))
 
 # Load MNIST dataset, representint task A
 mnist_train=datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
 mnist_test=datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())
 train_loader=DataLoader(mnist_train, batch_size=100, shuffle=True)
 test_loader=DataLoader(mnist_test, batch_size=100, shuffle=False)
 
 # FashiomMNIST is task B
 f_mnist_train=datasets.FashionMNIST("../data", train=True, download=True, transform=transforms.ToTensor())
 f_mnist_test=datasets.FashionMNIST("../data", train=False, download=True, transform=transforms.ToTensor())
 f_train_loader=DataLoader(f_mnist_train, batch_size=100, shuffle=True)
 f_test_loader=DataLoader(f_mnist_test, batch_size=100, shuffle=False)

现在让我们在MNIST任务上训练模型:

 # parameters
 EPOCHS=4
 lr=0.001
 weight=100000
 accuracies= {}
 
 device='cuda:1'
 
 criterion=nn.CrossEntropyLoss()
 
 # train model on task A 
 model=Model(28*28, 100, 10).to(device)
 optimizer=optim.Adam(model.parameters(), lr)
 
 for_inrange(EPOCHS):
     forinput, targetintqdm(train_loader):
         output=model(input.to(device))
         loss=criterion(output, target.to(device))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         
 accuracies['mnist_initial'] =get_accuracy(model, test_loader)

现在可以定义函数来估计FIM和EWC损失中使用的先前参数:

 defewc_loss(model, weight, estimated_fishers, estimated_means):
     losses= []
     forparam_name, paraminmodel.named_parameters():
         estimated_mean=estimated_means[param_name]
         estimated_fisher=estimated_fishers[param_name]
         losses.append((estimated_fisher* (param-estimated_mean) **2).sum())
         
     return (weight/2) *sum(losses)
 
 defestimate_ewc_params(model, train_ds, batch_size=100, num_batch=300, estimate_type='true'):
     estimated_mean= {}
 
     forparam_name, paraminmodel.named_parameters():
         estimated_mean[param_name] =param.data.clone()
         
     estimated_fisher= {}
     dl=DataLoader(train_ds, batch_size, shuffle=True)
     
     forn, pinmodel.named_parameters():
         estimated_fisher[n] =torch.zeros_like(p)
         
     model.eval()
     fori, (input, target) inenumerate(dl):
         ifi>num_batch:
             break
         model.zero_grad()
 
         output=model(input.to(device))
         # https://www.inference.vc/on-empirical-fisher-information/ - more on this here
         ifESTIMATE_TYPE=='empirical':
             # empirical
             label=target.to(device)
         else:
             # true estimate
             label=output.max(1)[1]
 
         loss=F.nll_loss(F.log_softmax(output, dim=1), label)
         loss.backward()
 
         # accumulate all the gradients
         forn, pinmodel.named_parameters():
             estimated_fisher[n].data+=p.grad.data**2/len(dl)
 
     estimated_fisher= {n: pforn, pinestimated_fisher.items()}
     returnestimated_mean, estimated_fisher

然后继续在任务B上训练EWC损失的网络:

 # compute fisher and mean parameters for EWC loss
 estimated_mean, estimated_fisher=estimate_ewc_params(model, mnist_train)
 
 # Train task B fashion mnist
 for_inrange(EPOCHS):
     forinput, targetintqdm(f_train_loader):
         output=model(input.to(device))
         loss=ewc_loss(model, weight, estimated_fisher, estimated_mean) +criterion(output, target.to(device))
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
         
 accuracies['mnist_EWC'] =get_accuracy(model, test_loader)
 accuracies['f_mnist_EWC'] =get_accuracy(model, f_test_loader)

可以得到以下精度:

 {'mnist_initial': tensor(0.9772, device='cuda:1'),
  'mnist_AB': tensor(0.9717, device='cuda:1'),
  'f_mnist': tensor(0.8312, device='cuda:1')}

最后将这些与没有EWC损失的模型进行比较:

 {'mnist_initial': tensor(0.9762, device='cuda:1'),
  'mnist_AB': tensor(0.1769, device='cuda:1'),
  'f_mnist': tensor(0.8672, device='cuda:1')}

可以看到EWC损失有助于保持任务A的准确率几乎不变,而学习任务B的准确率几乎与没有EWC损失的情况相同。

总结

我们看到了一种允许神经网络在继续学习新任务的同时保留其先前学习的知识的技术,虽然EWC在解决灾难性遗忘方面效果显著,但仍有一些挑战,例如对费雪信息矩阵的计算和存储需求较高,以及在复杂的深度神经网络结构中的实施复杂性。

还有还有其他方法可以使模型进行持续学习,比如:

重播记忆(Replay Memory):保存旧数据以便周期性地重训练。

联合训练(Joint Training):同时训练网络以处理旧任务和新任务。

元学习方法(Meta-learning Approaches):通过元学习算法来优化模型,以便快速适应新任务而不会忘记旧任务。

这些方法有助于减轻灾难性遗忘的影响,使神经网络能够持续学习和适应多个任务。

https://avoid.overfit.cn/post/56aee34117764e89a1a707c316fa305f

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924313.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

全网最详细单细胞保姆级分析教程(二) --- 多样本整合

上一节我们研究了如何对单样本进行分析,这节我们就着重来研究一下如何对多样本整合进行研究分析! 1. 导入相关包 library(Seurat) library(tidyverse) library(patchwork)2. 数据准备 # 导入单样本文件 dir c(~/Desktop/diversity intergration/scRNA_26-0_filtered_featur…

基于TCP的在线词典系统(分阶段实现)(阻塞io和多路io复用(select)实现)

1.功能说明 一共四个功能&#xff1a; 注册 登录 查询单词 查询历史记录 单词和解释保存在文件中&#xff0c;单词和解释只占一行, 一行最多300个字节&#xff0c;单词和解释之间至少有一个空格。 2.功能演示 3、分阶段完成各个功能 3.1 完成服务器和客户端的连接 servic…

WAF基础介绍

WAF 一、WAF是什么&#xff1f;WAF能够做什么 二 waf的部署三、WAF的工作原理 一、WAF是什么&#xff1f; WAF的全称是&#xff08;Web Application Firewall&#xff09;即Web应用防火墙&#xff0c;简称WAF。 国际上公认的一种说法是&#xff1a;Web应用防火墙是通过执行一…

电表及销售统计Python应用及win程序2

接着上一篇给代码添加了表格功能&#xff0c;方便更好的处理数据。 import json import os from datetime import datetime from tkinter import * from tkinter import messagebox from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg from matplotlib.figure …

JAVA设计模式>>结构型>>适配器模式

本文介绍23种设计模式中结构型模式的适配器模式 目录 1. 适配器模式 1.1 基本介绍 1.2 工作原理 1.3 适配器模式的注意事项和细节 1.4 类适配器模式 1.4.1 类适配器模式介绍 1.4.2 应用实例 1.4.3 注意事项和细节 1.5 对象适配器模式 1.5.1 基本介绍 1.5.2 …

visual studio 2019版下载以及与UE4虚幻引擎配置(过程记录)(官网无法下载visual studio 2019安装包)

一、概述 由于需要使用到UE4虚幻引擎&#xff0c;我使用的版本是4.27版本的&#xff0c;其官方默认的visual studio版本是2019版本的&#xff0c;相应的版本对应关系可以通过下面的官方网站对应关系查询。https://docs.unrealengine.com/4.27/zh-CN/ProductionPipelines/Develo…

java实现资产管理系统图形化用户界面

创建一个&#x1f495;资产管理系统的GUI&#xff08;图形用户界面&#xff09;❤️画面通常需要使用Java的Swing或者JavaFX库。下面我将提供一个简单的资产管理系统GUI的示例代码&#xff0c;使用Java Swing库来实现。这个示例将包括一个主窗口&#xff0c;一个表格来显示资产…

捷配笔记-如何设计PCB板布线满足生产标准?

PCB板布线是铺设连接各种设备与通电信号的路径的过程。PCB板布线是铺设连接各种设备与通电信号的路径的过程。 在PCB设计中&#xff0c;布线是完成产品设计的重要步骤。可以说&#xff0c;之前的准备工作已经为它做好了。在整个PCB设计中&#xff0c;布线设计过程具有最高的极限…

Web浏览器通过串口读取RFID卡号js JavaScript

本示例使用的读卡器&#xff1a;USB转RS232COM虚拟串口RFID读卡器主动读卡Web浏览器Andro、Linux-淘宝网 (taobao.com) <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"…

翁恺-C语言程序设计-05-3. 求a的连续和

05-3. 求a的连续和 输入两个整数a和n&#xff0c;a的范围是[0,9]&#xff0c;n的范围是[1,8]&#xff0c;求数列之和S aaaaaa…aaa…a&#xff08;n个a&#xff09;。如a为2、n为8时输出的是222222…22222222的和。 输入格式&#xff1a; 输入在一行中给出两个整数&#xf…

深入了解 MySQL 的 EXPLAIN 命令

一、什么是 EXPLAIN 命令&#xff1f; EXPLAIN 命令用于显示 MySQL 如何执行某个 SQL 语句&#xff0c;尤其是 SELECT 语句。通过 EXPLAIN 命令&#xff0c;可以看到查询在实际执行前的执行计划&#xff0c;这对于优化查询性能至关重要。 二、EXPLAIN 的基本用法 要使用 EXP…

【算法篇】KMP算法,一种高效的字符串匹配算法

我们今天了解一个字符串匹配算法-KMP算法&#xff0c;内容难度相对来说较高&#xff0c;建议先收藏再细品&#xff01;&#xff01;&#xff01; KMP算法的基本概念 KMP算法是一种高效的字符串匹配算法&#xff0c;由D.E.Knuth&#xff0c;J.H.Morris和V.R.Pratt提出的&#…

Cxx Primer-CP-2

开篇第一句话足见作者的高屋建瓴&#xff1a;类型决定程序中数据和操作的意义。随后列举了简单语句i i j;的意义取决于i和j的类型。若它们都是整形&#xff0c;则为通常的算术意义。若它们都为字符串型&#xff0c;则为进行拼接操作。若为用户自定义的class类型&#xff0c;则…

《Linux系统编程篇》Visual Studio Code配置下载,中文配置,连接远程ssh ——基础篇

引言 vscode绝对值得推荐&#xff0c;非常好用&#xff0c;如果你能体会其中的奥妙的话。 工欲善其事&#xff0c;必先利其器 ——孔子 文章目录 引言下载VS Code配置VS Code中文扩展连接服务器 连接服务器测试确定服务器的IP地址VS code 配置ssh信息选择连接到主机选择这个添…

K8s学习笔记1-搭建k8s集群

本次使用kubeadm方式&#xff0c;部署1.23.17版本 安装包百度云盘地址&#xff1a; 链接&#xff1a;https://pan.baidu.com/s/1UrIotP253DoyDIYB7G1C0Q 提取码&#xff1a;8q6a 集群所需虚拟机环境 主机名称IP资源harbor10.0.0.2301c2gmaster10.0.0.2312c4gworker110.0.0…

全新UI自助图文打印系统源码(含前端小程序源码 PHP后端 数据库)

最新自助图文打印系统和证件照云打印小程序源码PHP后端&#xff0c;为用户用户自助打印的服务&#xff0c;包括但不限于文档、图片、表格等多种格式的文件。此外&#xff0c;它们还提供了诸如美颜、换装、文档打印等功能&#xff0c;以及后台管理系统&#xff0c;方便管理员对打…

7.13 专题训练DP

P1255 数楼梯 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) ac代码 #include<bits/stdc.h> using namespace std; typedef long long ll; #define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0) const ll mod 1e97;int main() {IOS;int n;cin>>n;int a[…

面向对象与C++进阶—并发与多线程篇

文章目录 18. 并发与多线程篇说在前面(1).线程和进程(2).并发和并行(3).thread(C11)#1.thread库与thread类#2.join和detach方法#3.id类和get_id方法#4.this_thread和系列操作 (4).原子操作#1.为什么是原子?#2.为什么需要原子操作&#xff1f;#3.atomic库 (5).竞态条件(6).线程…

7-1、2、3 IPFS介绍使用及浏览器交互(react+区块链实战)

7-1、2、3 IPFS介绍使用及浏览器交互&#xff08;react区块链实战&#xff09; 7-1 ipfs介绍7-2 IPFS-desktop使用7-3 reactipfs-api浏览器和ipfs交互 7-1 ipfs介绍 IPFS区块链上的文件系统 https://ipfs.io/ 这个网站本身是需要科学上网的 Ipfs是点对点的分布式系统 无限…

TEB局部路径规划算法代码及原理解读

TEB(Timed Elastic Band) 是一个基于图优化的局部路径规划算法&#xff0c;具有较好的动态避障能力&#xff0c;在ROS1/ROS2的导航框架中均被采用。该图优化以g2o优化框架实现&#xff0c;以机器人在各个离散时刻的位姿和离散时刻之间的时间间隔为顶点&#xff0c;约束其中的加…