用于OOD预测的稳定学习

news2024/10/5 21:25:53

当测试数据和训练数据共享相似的分布时,基于深度神经网络的方法取得了惊人的性能,但在其他情况下可能会失败。因此,消除训练和测试数据之间分布变化的影响对于构建有前景的深度模型至关重要。作者考虑了一个更具挑战性的情况。通过训练样本的学习来消除特征之间的依赖关系,从而解决OOD问题,这有助于深度模型摆脱虚假的相关性,进而更多地关注判别特征和标签之间的真实联系

来自:Deep stable learning for out-of-distribution generalization

fig1

  • 图1:当大多数训练图像包含水中的狗时,ResNet-18和StableNet生成的显著性图的可视化。
  • 显著性图的亮度指示模型对输入图像的特定区域的关注程度(即,较亮的区域比较暗的区域对预测起着更关键的作用)。由于虚假的相关性,ResNet18模型倾向于同时关注狗和水,而StableNet主要关注狗。

fig2

  • 图2:StableNet的总体架构。LSWD是指去相关的学习样本加权。最终损失用于优化分类网络。

作者通过对样本进行全局加权来直接解除每个输入样本的所有特征相关性,从而消除相关和不相关特征之间的统计相关性,从而解决分布偏移问题。StableNet利用随机傅立叶特征(RFF)和样本加权的特性,消除了特征之间的线性和非线性依赖关系。

该文要解决的问题,就是如何在深度学习网络中找到一组样本权重,使得所有变量之间都可以做到互相独立,即任意选取一个变量为目标变量,目标变量的分布不随其它变量的值的改变而改变

X ⊂ R m X X\sub \mathbb{R}^{m_X} XRmX表示原始像素的空间, Y ⊂ R m Y Y\sub \mathbb{R}^{m_Y} YRmY表示输出空间, Z ⊂ R m Z Z\sub\mathbb{R}^{m_Z} ZRmZ为表征空间。 f : X → Z f:X\rightarrow Z f:XZ为表征函数, g : Z → Y g:Z\rightarrow Y g:ZY为预测函数。假设有 n n n个样本, X i X_{i} Xi y i y_{i} yi表示第 i i i个样本, Z i : j Z_{i:j} Zi:j表示第 i i i个样本的第 j j j个变量。 w ∈ R n w\in\mathbb{R}^{n} wRn表示样本权重, u u u v v v为随机傅里叶特征映射函数。

为了消除任何一对特征 Z : , i Z_{:,i} Z:,i Z : , j Z_{:,j} Z:,j之间的相关性,作者引入了假设检验来衡量随机变量之间的独立性。假设有两个一维随机变量 A , B A,B A,B A , B A,B A,B代表 Z : , i Z_{:,i} Z:,i Z : , j Z_{:,j} Z:,j以简化描述),分别从 A A A B B B的分布中采样 ( A 1 , A 2 , . . . , A n ) (A_{1},A_{2},...,A_{n}) (A1,A2,...,An) ( B 1 , B 2 , . . . , B n ) (B_{1},B_{2},...,B_{n}) (B1,B2,...,Bn),主要问题是这两个变量基于样本的相关性是如何的。


正定核的一个重要性质是能够产生一个内积空间的特征映射,使得在该映射下的内积运算等价于在输入空间中进行的核函数计算。

RKHS是正定核函数所对应的函数空间,它是一个希尔伯特空间(Hilbert Space),具有一些特殊的性质。在RKHS中,核函数起到了一个重要的作用,它定义了内积运算和范数,从而形成了一个完备的函数空间。


考虑在随机变量 A A A的域上有可测量的正定核 k A k_{A} kA,相应的RKHS由 H A H_{A} HA表示, k B , H B k_{B},H_{B} kB,HB同样被定义,交叉-协方差操作 Σ A B \Sigma_{AB} ΣAB为: E A B [ h A ( A ) h B ( B ) ] − E A [ h A ( A ) ] E B [ h B ( B ) ] \mathbb{E}_{AB}[h_{A}(A)h_{B}(B)]-\mathbb{E}_{A}[h_{A}(A)]\mathbb{E}_{B}[h_{B}(B)] EAB[hA(A)hB(B)]EA[hA(A)]EB[hB(B)]其中, h A ∈ H A , h B ∈ H B h_{A}\in H_{A},h_{B}\in H_{B} hAHA,hBHB。然后,独立性可以由以下命题确定: Σ A B = 0 ↔ A ⊥ B \Sigma_{AB}=0\leftrightarrow A\bot B ΣAB=0AB

深度网络的各维特征间存在复杂的依赖关系,仅去除变量间的线形相关性并不足以完全消除无关特征与标签之间的虚假关联,所以一个直接的想法就是通过kernel(核方法)映射到高维空间,但是经过kernel映射后原始特征的特征图维度被扩大到无穷维,使得各维变量间的相关性无法计算。

鉴于随机傅立叶特征(Random Fourier Feature, RFF)在近似核函数以及衡量特征独立性方面的性质,采用RFF将原始特征映射到高维空间中(可以理解为在样本维度进行扩充),消除新特征间的线形相关性即可保证原始特征严格独立。

RFF的函数空间为 H R F F H_{RFF} HRFF H R F F = { h : x → 2 c o s ( w x + ϕ ) ∣ w ∼ N ( 0 , 1 ) , ϕ ∼ U n i f o r m ( 0 , 2 π ) } H_{RFF}=\left\{h:x\rightarrow\sqrt{2}cos(wx+\phi)|w\sim N(0,1),\phi\sim Uniform(0,2\pi)\right\} HRFF={h:x2 cos(wx+ϕ)wN(0,1),ϕUniform(0,2π)}使用 w w w做样本加权,且 ∑ i = 1 n w i = n \sum_{i=1}^{n}w_{i}=n i=1nwi=n。加权后,变量 A A A B B B的交叉协方差为: Σ ^ A B ; w = 1 n − 1 ∑ i = 1 n [ ( w i u ( A i ) − 1 n ∑ j = 1 n w j u ( A j ) ) T ⋅ ( w i v ( B i ) − 1 n ∑ j = 1 n w j v ( B j ) ) ] u ( A ) = ( u 1 ( A ) , . . . , u n A ( A ) ) , u j ( A ) ∈ H R F F v ( B ) = ( v 1 ( B ) , . . . , v n B ( B ) ) , v j ( B ) ∈ H R F F \widehat{\Sigma}_{AB;w}=\frac{1}{n-1}\sum_{i=1}^{n}[(w_{i}u(A_{i})-\frac{1}{n}\sum_{j=1}^{n}w_{j}u(A_{j}))^{T}\cdot (w_{i}v(B_{i})-\frac{1}{n}\sum_{j=1}^{n}w_{j}v(B_{j}))]\\ u(A)=(u_{1}(A),...,u_{n_{A}}(A)),u_{j}(A)\in H_{RFF}\\v(B)=(v_{1}(B),...,v_{n_{B}}(B)),v_{j}(B)\in H_{RFF} Σ AB;w=n11i=1n[(wiu(Ai)n1j=1nwju(Aj))T(wiv(Bi)n1j=1nwjv(Bj))]u(A)=(u1(A),...,unA(A)),uj(A)HRFFv(B)=(v1(B),...,vnB(B)),vj(B)HRFFStableNet的目标是独立任何一对特征: w ∗ = a r g m i n w ∈ Δ n ∑ 1 ≤ i ≤ j ≤ m Z ∣ ∣ Σ ^ Z : , i Z : , j ; w ∣ ∣ F 2 w^{*}=argmin_{w\in\Delta_{n}}\sum_{1\leq i\leq j\leq m_{Z}}||\widehat{\Sigma}_{Z_{:,i}Z_{:,j};w}||_{F}^{2} w=argminwΔn1ijmZ∣∣Σ Z:,iZ:,j;wF2其中, Δ n = { w ∈ R n ∣ ∑ i = 1 n w i = n } \Delta_{n}=\left\{w\in R^{n}|\sum_{i=1}^{n}w_{i}=n\right\} Δn={wRni=1nwi=n}因此,用最优 w ∗ w^* w对训练样本进行加权可以最大限度地减轻特征之间的依赖性。

算法迭代优化样本权重 w w w、表示函数 f f f和预测函数 g g g,如下所示: f ( t + 1 ) , g ( t + 1 ) = a r g m i n f , g ∑ i = 1 n w i ( t ) L ( g ( f ( X i ) ) , y i ) w ( t + 1 ) = a r g m i n w ∈ Δ n ∑ 1 ≤ i ≤ j ≤ m Z ∣ ∣ Σ ^ Z : , i ( t + 1 ) Z : , j ( t + 1 ) ; w ∣ ∣ F 2 f^{(t+1)},g^{(t+1)}=argmin_{f,g}\sum_{i=1}^{n}w^{(t)}_{i}L(g(f(X_{i})),y_{i})\\ w^{(t+1)}=argmin_{w\in\Delta_{n}}\sum_{1\leq i\leq j\leq m_{Z}}||\widehat{\Sigma}_{Z_{:,i}^{(t+1)}Z_{:,j}^{(t+1)};w}||_{F}^{2} f(t+1),g(t+1)=argminf,gi=1nwi(t)L(g(f(Xi)),yi)w(t+1)=argminwΔn1ijmZ∣∣Σ Z:,i(t+1)Z:,j(t+1);wF2其中, Z ( t + 1 ) = f ( t + 1 ) ( X ) Z^{(t+1)}=f^{(t+1)}(X) Z(t+1)=f(t+1)(X) L L L表示交叉熵损失, t t t为时间步,初始 w ( 0 ) = ( 1 , 1 , . . . , 1 ) T w^{(0)}=(1,1,...,1)^{T} w(0)=(1,1,...,1)T

上述公式要求在训练过程中为每个训练样本都学习一个特定的权重,但在实践中,尤其对于深度学习任务,要想利用全部样本全局地学习样本权重需要巨大的计算和存储开销。此外,使用SGD对网络进行优化时,每轮迭代中仅有部分样本对模型可见,因此无法获取全部样本。

作者提出一种存储、重加载样本特征与样本权重的方法,在每个训练迭代的结束融合并保存当前的样本特征与权重,在下一个训练迭代开始时重加载,作为训练数据的全局先验知识优化新一轮的样本权重。

对于每个batch,用于优化样本权重的特征生成如下: Z O = C o n c a t ( Z G 1 , Z G 2 , . . . , Z G k , Z L ) w O = C o n c a t ( w G 1 , w G 2 , . . . , w G k , w L ) Z_{O}=Concat(Z_{G_{1}},Z_{G_{2}},...,Z_{G_{k}},Z_{L})\\ w_{O}=Concat(w_{G_{1}},w_{G_{2}},...,w_{G_{k}},w_{L}) ZO=Concat(ZG1,ZG2,...,ZGk,ZL)wO=Concat(wG1,wG2,...,wGk,wL)这里,符号 Z O Z_{O} ZO w O w_{O} wO分别表示用于优化新样本权重的特征和权重, Z G 1 , Z G 2 , . . . , Z G k Z_{G_{1}},Z_{G_{2}},...,Z_{G_{k}} ZG1,ZG2,...,ZGk w G 1 , w G 2 , . . . , w G k w_{G_{1}},w_{G_{2}},...,w_{G_{k}} wG1,wG2,...,wGk分别为全局特征和权重,其在每个批次结束时更新并且表示整个训练数据集的全局信息。 Z L Z_L ZL w L w_L wL是当前batch中的特征和权重,表示局部信息。

用于合并上式中的所有特征的操作是沿着样本的级联,比如,如果batch size为 B B B Z O Z_{O} ZO为矩阵,size是 ( ( k + 1 ) B ) × m Z ((k+1)B)\times m_{Z} ((k+1)B)×mZ w O w_{O} wO ( ( k + 1 ) B ) ((k+1)B) ((k+1)B)维的向量,在对每个batch训练时,保持 w G i w_{G_{i}} wGi固定,只有 w L w_{L} wL是可学习的。在每次迭代训练结束时,融合全局信息 ( Z G i , w G i ) (Z_{G_{i}},w_{G_{i}}) (ZGi,wGi)和局部信息 ( Z L , w L ) (Z_{L},w_{L}) (ZL,wL) Z G i ′ = α i Z G i + ( 1 − α i ) Z L w G i ′ = α i w G i + ( 1 − α i ) w L Z'_{G_{i}}=\alpha_{i} Z_{G_{i}}+(1-\alpha_{i})Z_{L}\\w'_{G_{i}}=\alpha_{i} w_{G_{i}}+(1-\alpha_{i})w_{L} ZGi=αiZGi+(1αi)ZLwGi=αiwGi+(1αi)wL对于每组全局信息 ( Z G i , w G i ) (Z_{G_{i}},w_{G_{i}}) (ZGi,wGi),我们使用 k k k个不同的平滑参数以考虑long-term memory( α i \alpha_{i} αi较大),和short-term memory( α i \alpha_{i} αi较小), k k k表示预测特征是原始特征的 k k k倍。然后,用 ( Z G i ′ , w G i ′ ) (Z'_{G_{i}},w'_{G_{i}}) (ZGi,wGi)代替所有 ( Z G i , w G i ) (Z_{G_{i}},w_{G_{i}}) (ZGi,wGi)作为下一batch。

在推理阶段,预测模型直接进行预测,而不需要计算任何样本权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/526700.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

大数据Doris(二十一):Bloom Filter索引以及Doris索引总结

文章目录 Bloom Filter索引以及Doris索引总结 一、Bloom Filter索引 1、BloomFilter索引原理 2、BloomFilter索引语法 3、注意事项 二、Doris索引总结 Bloom Filter索引以及Doris索引总结 一、Bloom Filter索引 1、BloomFilter索引原理 BloomFilter是由Bloom在1970年提…

移动机器人运动规划---基于图搜索的基础知识---广度优先遍历与深度优先遍历

移动机器人运动规划---基于图搜索的基础知识---广度优先遍历与深度优先遍历 广度优先搜索(BFS)深度优先搜索(DFS)BFS vs DFS 图搜索优化的方向就是: 按照什么规则去访问节点,按照什么规则弹出节点&#xff…

快速了解 TypeScript

目录 1、简介 2、安装TypeScript 3、编译代码 4、类型注解 5、接口 6、类 7、运行TypeScript Web应用 1、简介 TypeScript是JavaScript类型的超集,它可以编译成纯JavaScript。 TypeScript可以在任何浏览器、任何计算机和任何操作系统上运行,并且…

【哈士奇赠书活动 - 23期】-〖你好 ChatGPT〗

文章目录 ⭐️ 赠书 - 《你好 ChatGPT》⭐️ 内容简介⭐️ 作者简介⭐️ 精彩书评⭐️ 赠书活动 → 获奖名单 ⭐️ 赠书 - 《你好 ChatGPT》 ⭐️ 内容简介 人工智能(AI)时代已经来临,AIGC(人工智能生成内容)正在进一步…

【精选】各种节日祝福(C语言,可修改),Easyx图形库应用+源代码分享

博主:命运之光✨✨ 专栏:Easyx图形库应用📂 目录 ✨一、程序展示 范例一:❤新年祝福❤ 范例二:❤母亲节祝福❤ ✨二、项目环境 简单介绍一下easyx图形库应用 Easyx图形库 ✨三、运行效果展示(视频&am…

【C++起飞之路】初级——缺省参数、函数重载、引用

C:函数重载、引用 一、缺省参数🛫1.1 🚝什么是缺省参数1.2 🚝缺省参数的分类a. 全缺省参数b. 半缺省参数(部分缺省参数) 1.3 🚝注意事项 二、函数重载🛫2.1 🚝什么是函数…

时间复杂度:根号n一般来说大于log(n)

f ( x ) x − l o g 2 x f(x)\sqrt{x}-log_2 x f(x)x ​−log2​x 对这函数求导后,比较分母大小,可以得到结论 f ( x ) f(x) f(x)先减后增,分界点为 x 4 ( l n 2 ) 2 x \frac{4}{(ln2)^2} x(ln2)24​ f ( x ) f(x) f(x)的图像如下所示&a…

PPT技能之文字格式,转身的文字这样做

只要用PPT,一定需要设置文字格式。好的文字格式,给人惊艳的感觉,是一种愉悦的享受。 你的关注,是我最大的动力!你的转发,我的10W!茫茫人海有你的支持,给我无限动力。 1、字体。 按…

什么是Java中的阻塞队列?它有什么作用?

在Java中,阻塞队列是一种特殊的队列,它可以在队列为空或队列已满时阻塞添加或移除元素的操作。阻塞队列通常用于多线程编程中,可以帮助我们更加方便地进行线程通信和协作。在本文中,我将从面试的角度,详细讲解Java中的…

在线办公时代,如何选择合适的云办公软件?

文章目录 在线办公时代,如何选择合适的云办公软件?在线文档石墨文档腾讯文档飞书文档 远程控制ToDesk向日葵 会议协同腾讯会议ZOOM 总结 在线办公时代,如何选择合适的云办公软件? 随着数字经济的发展和疫情的影响,云办…

100天精通Python(可视化篇)——第87天:matplotlib绘制不同种类炫酷雷达图参数说明+代码实战(普通、堆叠、多个、矩阵、极坐标雷达图)

文章目录 专栏导读1. 雷达图1)介绍2)参数说明 2. 基本雷达图3. 堆叠雷达图4. 六边形战士5. 多个雷达图6. 雷达图矩阵7. 极坐标雷达图 专栏导读 🔥🔥本文已收录于《100天精通Python从入门到就业》:本专栏专门针对零基础…

做一名活动策划是什么体验

在一些不了解的人眼中,活动策划就是那种外表光鲜亮丽,气场十足,眼神犀利,跷着二郎腿,情绪饱满的完成一场又一场的完美的秀。 好像确实是这样,但是你们又知不知道这背后的一切我们活动策划到底付出了什么&a…

SpringMVC的三大功能

目录 一、初识SpringMVC 1.1 MVC的定义 1.2 MVC和SpringMVC的关系是什么? 1.3 SpringMVC的重要性 二、Spring MVC的三大功能 2.1 连接功能 2.1.1 RequestMapping 注解介绍 2.1.2 GetMapping 和 PostMapping 2.2 获取参数功能 2.2.1 传递普通参数 2.2.2 传递对象 2…

【K8s】Ingress的使用

文章目录 一、Ingress介绍1、Ingress的作用2、Ingress工作流程 二、Ingress使用1、测试数据准备2、HTTP代理3、HTTPS代理 一、Ingress介绍 1、Ingress的作用 上一章中,NotePort和LoadBalancer类型的Service可给集群外部机器提供访问,但这两种类型都有缺…

JavaScript数组

1.数组是什么 2.数组的基本使用 3.操作数组 4.数组案例 一、数组是什么? 1.数组(Array)是一种可以按顺序保存数据的数据类型2.为什么要使用数组?例如:如果想保存一个班所有同学的姓名怎么办?场景:如果有多个数据可以用…

vue3中ts定义对象,pinia中使用ts定义状态对象

文章目录 引入reactive中使用数组reactive中定义对象类型pinia中定义状态对象 引入 用惯了js,突然使用ts属实有点不习惯,这里介绍一下自己在vue3中使用ts初始化内容的一些小技巧 reactive中使用数组 例如下面所示的代码,我们就像写js代码一…

数组a与数组b作内积:即a和b所有对应位置两元素相乘 将所有的相乘结果(积)求和 numpy.inner(a,b)

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 数组a与数组b作内积: 即a和b所有对应位置两元素相乘 将所有的相乘结果(积)求和 numpy.inner(a,b) [太阳]选择题 请问关于以下代码的输出结果是? import numpy as np …

招银网络科技-2024届暑期实习-Java后端开发

目录 1.SpringBoot 中的 SpringBootApplication注解的作用是什么?2.SpringBoot 中你们是如何加载配置信息的?3.RabbitMQ 如何保证消息不丢失?4.如果消费者这边消费到一半宕机了怎么办?5.RabbitMQ 如何保证消息没有被重复消费&…

C语言函数大全-- w 开头的函数(3)

C语言函数大全 本篇介绍C语言函数大全-- w 开头的函数 1. wcsdup 1.1 函数说明 函数声明函数功能wchar_t *wcsdup(const wchar_t *str);用于复制宽字符字符串 参数: str : 待复制的宽字符串 返回值: 如果成功复制,则返回指向该…

跨域解决方案

同源策略 同源策略是一种约定,它是浏览器最核心也是最基本的安全功能,如果缺少了同源策略,浏览器很容易受到XSS、CSRF等攻击。 所谓的同源是指【协议域名端口】三者相同,即便两个不同的域名,指向同一个IP地址&#xf…