论文115:Reinforced GNNs for multiple instance learning (TNNLS‘24)

news2024/10/6 22:23:46

文章目录

  • 1 要点
  • 2 预备知识
    • 2.1 MIL
    • 2.2 MIL-GNN
    • 2.3 Markov博弈
    • 2.4 深度Q-Learning
  • 3 方法
    • 3.1 观测生成与交互
    • 3.2 动作选择和指导
    • 3.3 奖励计算
    • 3.4 状态转移和终止
    • 3.5 多智能体训练

1 要点

题目:用于MIL的强化GNN

代码:https://github.com/RingBDStack/RGMIL

背景:MIL是一种监督学习变体,它处理包含多个实例的包,其中训练阶段只有包级别的标签可用。MIL在现实世界的应用中很多,尤其是在医学领域;

挑战:现有的GNN在MIL中通常需要过滤实例间的低置信度边,并使用新的包结构来调整图神经网络架构。这样的调整过程频繁且忽视了结构和架构之间的相关性;

RGMIL框架:首次在MIL任务中利用多智能体深度强化学习 (MADRL)。MADRL允许灵活定义或扩展影响包图或GNN的因素,并同步控制它们;

贡献

  1. 引入MADRL到MIL中,实现对包结构和GNN架构的自动化和同步控制;
  2. 使用边阈值和GNN层数作为因素案例来构建RGMIL,探索了以前在MIL研究中被忽视的边密度和聚合范围之间的相关性;
  3. 实验结果表明,RGMIL在多个MIL数据集上实现了最佳性能,并且具有出色的可解释性;

细节

  1. RGMIL将训练过程建模为一个完全合作的马尔可夫博弈 (MG);
  2. 通过两个智能体搜索边过滤阈值和GNN层数;
  3. 利用反向分解网络 (VDN) 来衡量智能体的贡献和相关性;
  4. 引入图注意力网络 (GAT) 并设计参数共享机制以提高效率;

符号表

符号含义
B \mathcal{B} B包集合
G \mathcal{G} G与包相对应的、图的集合
Y \mathcal{Y} Y包标签
M \mathcal{M} MMarkov博弈的七元组
S \mathcal{S} S M \mathcal{M} M的状态空间
O \mathcal{O} O M \mathcal{M} M的观测空间
A \mathcal{A} A M \mathcal{M} M的动作空间
L \mathcal{L} L智能体或者GNN模型的训练损失
N N N包数量
M M M包内实例数量
L L LGNN层的数量
T T T时间步的数量
I I I智能体的数量
D D D特征表示的维度
A \mathbf{A} A与图相对应的邻接矩阵
F \mathbf{F} F与图相对应的实例特征矩阵
E \mathbf{E} E与图相对应的包图特征矩阵
Z \mathbf{Z} Z特征变换矩阵
C \mathbf{C} C重要性系数矩阵
i ; j ; k ; l ; t i;j;k;l;t i;j;k;l;t索引变量
s ; o ; a ; r s;o;a;r s;o;a;r状态、观测、动作、奖励
v v v注意力机制特征向量
γ \gamma γ折扣系数
α \alpha α智能体学习率
μ \mu μ动作或者奖励的窗口大小
λ \lambda λ终止条件的奖励阈值
& ; % \&;\% &;%逻辑和取余运算
⊕ \oplus 拼接操作
∥ ⋅ ∥ \|\cdot\| 矩阵的Norm函数
σ ( ⋅ ) \sigma(\cdot) σ()激活函数
π ( ⋅ ) \pi(\cdot) π()智能体状态-动作函数
RWD ( ⋅ ) \text{RWD}(\cdot) RWD()奖励函数
TRN ( ⋅ ) \text{TRN}(\cdot) TRN()状态转移函数
AGG ( ⋅ ) \text{AGG}(\cdot) AGG()特征聚合函数
POL ( ⋅ ) \text{POL}(\cdot) POL()特征池化函数
EVL ( ⋅ ) \text{EVL}(\cdot) EVL()分类性能评估函数

2 预备知识

2.1 MIL

B = { B i ∣ i = 1 , … , N } \mathcal{B}=\{\mathcal{B}_i|i=1,\dots,N\} B={Bii=1,,N}表示包含多个包 B i = { B i , j ∣ j = 1 … , M } \mathcal{B}_i=\{\mathcal{B}_{i,j}|j=1\dots,M\} Bi={Bi,jj=1,M},其中 N N N M M M分别表示包和包中实例的数量 (通常 M M M是变化的)。每个包对应一个两类包标签 Y i = max ⁡ ( Y i , 1 , … , Y i , M ) \mathcal{Y}_i=\max(\mathcal{Y}_{i,1},\dots,\mathcal{Y}_{i,M}) Yi=max(Yi,1,,Yi,M),其中 Y i , j ∈ { 0 , 1 } \mathcal{Y}_{i,j}\in\{0,1\} Yi,j{0,1}是假设的实例标签。尽管数据集中少量的实例具有真实的标签,然而在MIL的训练过程中,实例标签是不可用的。因此,MIL的目标是学习一个将包映射为标签的映射函数 B → Y \mathcal{B\to Y} BY,其中 Y = { Y i ∣ i = 1 , … , N } \mathcal{Y}=\{ \mathcal{Y}_i | i=1,\dots,N \} Y={Yii=1,,N}

2.2 MIL-GNN

对于MIL-GNN,其首先需要将所有的包转换为一个图的集合 G = { G i ∣ i = 1 , … , N } \mathcal{G}=\{ \mathcal{G}_i|i=1,\dots,N \} G={Gii=1,,N},其中每个包对应一个图 G i = ( A i , F i ) \mathcal{G}_i=(\mathbf{A}_i,\mathbf{F}_i) Gi=(Ai,Fi),此外,每个实例可以看作是一个节点。每个邻接矩阵 A i ∈ R M × M \mathbf{A}_i\in\mathbb{R}^{M\times M} AiRM×M使用原始节点特征构建,并通过阈值来过滤边,其每个元素表示一跳邻域信息。 F i ∈ R M × D \mathbf{F}_i\in\mathbb{R}^{M\times D} FiRM×D表示实例节点的特征矩阵

基于此, L L L层GNN被用于传递节点特征信息,其中对于第 i i i个图 G i \mathcal{G}_i Gi,其在第 l l l层的聚合过程表示为:
F i l = σ ( AGG l ( A i , F i l − 1 ) ) (1) \tag{1} \mathbf{F}_i^l=\sigma\left( \text{AGG}^l (\mathbf{A}_i,\mathbf{F}_i^{l-1})\right) Fil=σ(AGGl(Ai,Fil1))(1)其中 AGG l ( ⋅ ) \text{AGG}^l(\cdot) AGGl()表示在第 l l l层的聚合函数,例如卷积和注意力、 σ ( ⋅ ) \sigma(\cdot) σ()表示激活函数、 F i l \mathbf{F}_i^l Fil是更新后的特征矩阵。

接下来,一个节点特征池化函数 POL ( ⋅ ) \text{POL}(\cdot) POL()被用于GNN的最后一层,以获取最终的图级别特征矩阵 E ( i ) ∈ R 1 × D \mathbf{E}(i)\in\mathbb{R}^{1\times D} E(i)R1×D
E ( i ) = POL ( { F i L ( j ) ∣ j = 1 , … , M } ) (2) \tag{2} \mathbf{E}(i)=\text{POL}(\{ \mathbf{F}_i^L(j) |j=1,\dots,M \}) E(i)=POL({FiL(j)j=1,,M})(2)其中 F i L ( j ) ∈ R 1 × D \mathbf{F}_i^L(j)\in\mathbb{R}^{1\times D} FiL(j)R1×D是实例节点 B i , j \mathcal{B}_{i,j} Bi,j的特征向量。最后, E ( i ) \mathbf{E}(i) E(i)传递给一个包图分类器。因此,在MIL-GNN中,其映射过程为 B → G → Y \mathcal{B\to G\to Y} BGY

2.3 Markov博弈

在多智能体强化学习 (MARL) 中,Markov博弈 (MG) 是从Markov决策过程 (MDP) 扩展而来。特别地,一个MG包含多个能够共同影响奖励和状态转移的智能体。根据是否所有的智能体都能完全获得全局状态信息,已有的MG可被看作是完全或者部分可观测,其中后者则更为普遍。

部分可观测的MG可以被抽象为一个七元组 M = < S , O i , A i , π ( ⋅ ) , RED i ( ⋅ ) , TRN ( ⋅ ) , γ > \mathcal{M}=<\mathcal{S,O_i,A_i},\pi(\cdot),\text{RED}_i(\cdot),\text{TRN}(\cdot),\gamma> M=<S,Oi,Ai,π(),REDi(),TRN(),γ>,其中:

  1. S \mathcal{S} S:MG的全局状态空间;
  2. A i \mathcal{A}_i Ai:第 i i i个智能体的动作空间。在每个时间步 t ∈ [ 1 , T ] t\in[1,T] t[1,T],每个智能体根据其独有的状态动作函数 π i ( ⋅ ) \pi_i(\cdot) πi()来选择动作 a i t ∈ A i a_i^t\in\mathcal{A}_i aitAi
  3. 每个智能体会从全局状态获得一个独立的部分观察 o i t ∈ O i o_i^t\in\mathcal{O}_i oitOi,因此, π i ( ⋅ ) \pi_i(\cdot) πi()可以表示为 S → O i → A i \mathcal{S\to O_i\to A_i} SOiAi
  4. 每个智能体使用其奖励函数 RED i ( ⋅ ) \text{RED}_i(\cdot) REDi()获得即时奖励 r i t r_i^t rit,这种博弈也被称为分散的部分可观测MDP (Dec-POMDP),旨在最大化累积奖励 ∑ t = 1 T γ ( t − 1 ) r ∗ t \sum^T_{t=1}\gamma^{(t−1)}r^{*t} t=1Tγ(t1)rt,其中 γ γ γ表示控制后续奖励的折扣系数;
  5. 状态转移函数 TRN ( ⋅ ) \text{TRN}(\cdot) TRN()将当前状态 s t s^t st与联合动作 a ∗ t a^{*t} at映射到下一个状态 s ( t + 1 ) s^{(t+1)} s(t+1),即 S × A ∗ → S \mathcal{S \times A^*\to S} S×AS

2.4 深度Q-Learning

作为基于价值的RL的基算法,Q-Learning非常适合实现单一智能体的顺序决策系统。QLearning包含一个状态-动作表 π ( ⋅ ) π(·) π(),它记录了各种状态下所有可能动作的 Q Q Q值。初始化后,智能体不断与环境交互,并通过Bellman方程更新 π ( ⋅ ) π(·) π()直到收敛。 π ( ⋅ ) π(·) π()更新过程可以表示如下:
x = x + α [ r t + γ max ⁡ a π ( s t + 1 , a ) − x ] ] s.t.  x = π ( s t , a t ) (3) \tag{3} \begin{aligned} & x = x + \alpha \left[ r_t + \gamma \max_{a} \pi(s_{t+1}, a) - x \right] ]\\ & \text{s.t. } x = \pi(s_t, a_t) \end{aligned} x=x+α[rt+γamaxπ(st+1,a)x]]s.t. x=π(st,at)(3)其中: π ( s t , a t ) \pi(s_t, a_t) π(st,at)是预测的Q值,以及在状态 s t s_t st下选择动作 a t a_t at的预期奖励、 r t r_t rt表示时间步 t t t的即时奖励、 max ⁡ a π ( s t + 1 , a ) \max_a \pi(s_{t+1}, a) maxaπ(st+1,a)是下一个状态 s t + 1 s_{t+1} st+1的最大Q值,以及 α \alpha α π ( ⋅ ) \pi(·) π()的学习率。

在实际应用中,许多环境的状态空间是无限的,记录所有状态-动作对的值是不可行的。受深度学习的启发,许多工作引入了深度神经网络 (DNN) 来近似返回值,其中深度Q-Learning (DQN) 是传统Q-Learning的直接扩展:

  1. DQN使用DNN构建动作-价值函数 π π π (亦称为 Q Q Q函数),该函数将每个状态向量映射到 Q Q Q值向量 π ( s ) ∈ R 1 × ∣ A ∣ \pi(s) \in \mathbb{R}^{1 \times |A|} π(s)R1×A,其中 ∣ A ∣ |A| A表示动作空间 A A A的大小;
  2. DQN应用经验回放和目标网络技术来更新函数 π ( ⋅ ) \pi(·) π()。例如,给定过去时间步 t t t的经验记录,其元组形式为 ⟨ s t , a t , r t , s t + 1 ⟩ \langle s_t, a_t, r_t, s_{t+1} \rangle st,at,rt,st+1,则 π π π时序差分损失可以计算如下:
    L π = E s , a , r , s ′ [ ( π ‾ ( s t , a t ) − π ( s t , a t ) ) 2 ] s.t.  π ‾ ( s t , a t ) = r t + γ max ⁡ a π ‾ ( s t + 1 , a ) (4) \tag{4} \begin{aligned} &L_\pi = \mathbb{E}_{s,a,r,s'} \left[ \left( \overline{\pi}(s_t, a_t) - \pi(s_t, a_t) \right)^2 \right]\\ &\text{s.t. } \overline{\pi}(s_t, a_t) = r_t + \gamma \max_a \overline{\pi}(s_{t+1}, a) \end{aligned} Lπ=Es,a,r,s[(π(st,at)π(st,at))2]s.t. π(st,at)=rt+γamaxπ(st+1,a)(4)其中: π ( ⋅ ) \pi(·) π()表示评估网络,其用于预测状态 s t s_t st和动作 a t a_t at Q Q Q值的评估网络、 π ‾ ( ⋅ ) \overline{\pi}(·) π()是一个目标网络,其架构与 π ( ⋅ ) \pi(·) π()相同。只有 π ( ⋅ ) \pi(·) π()被优化,并且其训练参数周期性复制到 π ‾ ( ⋅ ) \overline{\pi}(·) π()。由于 π ‾ \overline{\pi} π不更新时目标 Q Q Q值是稳定的,因此 π ( ⋅ ) \pi(·) π()的训练稳定性是极好的;
  3. 为了权衡探索新动作的概率,DQN应用了 ϵ ϵ ϵ-贪婪算法。因此,它并不总是选择 π ( s ) \pi(s) π(s)中最大条目的对应动作,其可以表示如下:
    a = { random action , w.p. ϵ argmax a π ( s t , a ) , w.p. 1 − ϵ (5) \tag{5} \begin{aligned} a = \begin{cases} \text{random action}, & \text{w.p.} \quad\epsilon \\ \text{argmax}_a \pi(s_t, a), & \text{w.p.} \quad 1 - \epsilon \end{cases} \end{aligned} a={random action,argmaxaπ(st,a),w.p.ϵw.p.1ϵ(5)其中, ϵ \epsilon ϵ表示随机选择动作的概率,即探索,而 1 − ϵ 1-\epsilon 1ϵ表示选择当前基于 π π π的最优动作,即利用。通过这样做,DQN避免了在强化学习任务中的探索-利用困境,避开了局部最优,并促进了更好的 π π π函数的发现。

3 方法

本节介绍RGMIL的细节,包括:1) 用于提升博弈公平性的观测生成与交互;2) 用于提升GNN效率的动作选择和指导技术;3) 用于提升博弈稳定性的奖励计算;4) 用于确保博弈收敛的状态转移和终止技术;以及5) 多智能体训练。

RGMIL的总览如图4所示,其中左子图对应章节3.1至3.4,右子图对应章节3.5。

图4:RGMIL总览。左右子图分别对应经验收集和代理优化:1) 每一个时间步,初始观测从当前的block导出;2) 观测作为代理的输入,用于选择当前的动作;3) 构建可信包图,并作为定制的GNN的输入;4) GNN训练后,通过动作组合来评估性能,并确定当前的奖励;5) 带有动作的转移函数作为输入,以生成下一次观测;6) 记录以上过程,到达一定数量后,由VDN执行代理优化

3.1 观测生成与交互

在RGMIL中,我们将其训练过程建模为一个合作的马尔可夫博弈 (MG),涉及两个智能体,分别用于搜索最佳的边过滤阈值GNN层数

  1. 利用一个改进的VDN来实现MG:
    • 将训练集划分为多个等大小的区块,其中一个区块作为验证集,其余区块用作构建MG状态空间 S S S
    • 在第一个时间步之前,随机选择一个训练区块作为全局状态
    • 由于边过滤阈值的选择通常与拓扑信息相关,我们随后指定当前状态中包图的结构特征作为第一个智能体的观察;
    • 通过包图的成对相似性建立实例节点的初始边。以属于当前区块的第 i i i个包 B i \mathcal{B}_i Bi为例,它的包图 G i \mathcal{G}_i Gi可以被抽象为一个邻接矩阵 A i \mathbf{A}_i Ai以及一个特征矩阵 F i \mathbf{F}_i Fi
    • 给定初始矩阵 F i 0 \mathbf{F}^0_i Fi0初始邻接矩阵 A i \mathbf{A}_i Ai的计算如下:
      A i ( j , j ′ ) = ∥ F i 0 ( j ) − F i 0 ( j ′ ) ∥ 2 (6) \tag{6} \mathbf{A}_i(j, j') = \|\mathbf{F}^0_i(j) - \mathbf{F}^0_i(j')\|_2 Ai(j,j)=Fi0(j)Fi0(j)2(6)其中 ∥ ⋅ ∥ 2 \|\cdot\|_2 2表示矩阵的二范数、 A i ( j , j ′ ) \mathbf{A}_i(j, j') Ai(j,j)编码了第 j j j个和第 j ′ j' j个实例节点之间的欧式距离。
    • 因此,第一个智能体的观察计算如下:
      o 1 ( d ) = 1 N d ∑ i = 1 N d exp ⁡ ( − A i ) s.t.  M i = d , d ∈ [ 1 , max ⁡ M i ] (7) \tag{7} \begin{aligned} &o_1(d) = \frac{1}{N_d} \sum_{i=1}^{N_d} \exp(-\mathbf{A}_i)\\ & \text{s.t. } M_i = d, \quad d \in [1, \max M_i] \end{aligned} o1(d)=Nd1i=1Ndexp(Ai)s.t. Mi=d,d[1,maxMi](7)其中 o 1 ( d ) o_1(d) o1(d)表示向量 o 1 o_1 o1的第 d d d个条目、 N d N_d Nd是当前区块中包的数量,并且它包含的实例数量等于 d d d M i M_i Mi是包图 G i G_i Gi的实例节点数量;
  2. 由于GNN层数控制特征聚合的迭代,随后从初始节点特征 F i 0 \mathbf{F}^0_i Fi0中获取第二个智能体的观察
    o 2 = 1 N ∑ i = 1 N ( 1 M i ∑ j = 1 M i F i 0 ( j ) ) (8) \tag{8} o_2 = \frac{1}{N} \sum_{i=1}^{N} \left( \frac{1}{M_i} \sum_{j=1}^{M_i} F^0_i(j) \right) o2=N1i=1N(Mi1j=1MiFi0(j))(8)其中 F i 0 ( j ) \mathbf{F}^0_i(j) Fi0(j)是第 j j j个实例节点的特征向量、 N N N是当前区块中包图的总数;
  3. 为了进一步探索边密度和聚合迭代之间的潜在相关性,引入了观察信息交互
    o 1 = o 1 ⊕ σ ( ( o 1 ⊕ o 2 ) ( o 2 ⊕ o 1 ) T o 1 ) o 2 = o 2 ⊕ σ ( ( o 1 ⊕ o 2 ) ( o 2 ⊕ o 1 ) T o 2 ) (9) \tag{9} \begin{aligned} &o_1 = o_1 \oplus \sigma((o_1 \oplus o_2)(o_2 \oplus o_1)^T {o_1})\\ &o_2 = o_2 \oplus \sigma((o_1 \oplus o_2)(o_2 \oplus o_1)^T {o_2}) \end{aligned} o1=o1σ((o1o2)(o2o1)To1)o2=o2σ((o1o2)(o2o1)To2)(9)其中 ⊕ ( ⋅ ) \oplus(\cdot) ()是向量的连接操作。通过此操作,观察 o 1 o_1 o1 o 2 o_2 o2具有相同的维度,并且都编码了来自对方的信息;

RGMIL减轻了由于观察的特征维度或信息量的变化可能导致的MG中的不公平博弈。此外,为了提高这部分的效率,RGMIL只为每个数据区块一次性计算并记录这些初始邻接矩阵和观察

3.2 动作选择和指导

当输入当前的观察向量 o i o_i oi后,每个智能体将其映射为一个 Q Q Q值向量 π i ( o i ) ∈ R 1 × ∣ A i ∣ \pi_i(o_i) \in \mathbb{R}^{1 \times |\mathcal{A}_i|} πi(oi)R1×Ai,并基于最大的 Q Q Q值条目或随机选择一个动作 a i a_i ai (如公式5):

  1. 第一个阈值动作 a 1 ∈ [ 0 , 1 ] a_1 \in [0, 1] a1[0,1]是一个小数,而第二个层数动作 a 2 a_2 a2是一个整数;
  2. a 1 a_1 a1的指导下,可以获得一个更可靠的邻接矩阵 A i \mathbf{A}_i Ai
    A i ( j , j ′ ) = { 1 , if  exp ⁡ ( − A i ( j , j ′ ) ) ≥ a 1 0 , if  exp ⁡ ( − A i ( j , j ′ ) ) < a 1 (10) \tag{10} \mathbf{A}_i(j, j') = \begin{cases} 1, & \text{if } \exp(-\mathbf{A}_i(j, j')) \geq a_1 \\ 0, & \text{if } \exp(-\mathbf{A}_i(j, j')) < a_1 \end{cases} Ai(j,j)={1,0,if exp(Ai(j,j))a1if exp(Ai(j,j))<a1(10)
  3. a 2 a_2 a2的指导下,RGMIL将构建定制的GNN。以GAT为例,节点特征的聚合过程可以表示为:
    C i ( l − 1 ) ( j , j ′ ) = v ⋅ ( F i ( l − 1 ) ( j ) Z ( l − 1 ) ⊕ F i ( l − 1 ) ( j ′ ) Z ( l − 1 ) ) T F i l ( j ) = σ ( ∑ j ′ x F i ( l − 1 ) ( j ′ ) Z ( l − 1 ) ) s.t. x = softmax ( σ ( C i l − 1 ( j , j ′ ) ) ) & A i ( j , j ′ ) = 1 \begin{aligned} &\mathbf{C}^{(l-1)}_i(j, j') = v \cdot (\mathbf{F}^{(l-1)}_i(j) \mathbf{Z}^{(l-1)} \oplus \mathbf{F}^{(l-1)}_i(j') \mathbf{Z}^{(l-1)} )^T\\ &\mathbf{F}^{l}_i(j) = \sigma\left(\sum_{j'}x \mathbf{F}^{(l-1)}_i(j') \mathbf{Z}^{(l-1)}\right)\\ & \text{s.t.}\quad x=\text{softmax}\left(\sigma\left(\mathbf{C}_i^{l-1}\left(j,j'\right)\right)\right)\&\mathbf{A}_i(j,j')=1 \end{aligned} Ci(l1)(j,j)=v(Fi(l1)(j)Z(l1)Fi(l1)(j)Z(l1))TFil(j)=σ jxFi(l1)(j)Z(l1) s.t.x=softmax(σ(Cil1(j,j)))&Ai(j,j)=1其中: l ∈ [ 1 , a 2 ] l \in [1, a_2] l[1,a2]表示运行 a 2 a_2 a2次迭代聚合、 F i l ( j ) \mathbf{F}^{l}_i(j) Fil(j)是第 l l l层GNN中第 j j j个节点 B i , j \mathcal{B}_{i,j} Bi,j D l D_l Dl维特征向量、 Z ( l − 1 ) \mathbf{Z}^{(l-1)} Z(l1)表示形状为 D ( l − 1 ) × D l D^{(l-1)} \times D^l D(l1)×Dl特征转换矩阵。此外, v ∈ R 1 × 2 D l v \in \mathbb{R}^{1 \times 2D^l} vR1×2Dl表示自注意力机制的特征向量、 C i ( j , j ′ ) \mathbf{C}_i(j, j') Ci(j,j)是邻居 B i , j \mathcal{B}_{i,j} Bi,j相对于其目标 B i , j ′ \mathcal{B}_{i,j'} Bi,j的重要性系数,其需要通过softmax函数获得,以及 & \& &表示逻辑操作。基于注意力的节点特征池化函数,RGMIL获得当前训练区块的最终包图特征矩阵 E ∈ R N × D a 2 E \in \mathbb{R}^{N \times D^{{a_2}}} ERN×Da2
    C i ( a 2 ) ( j ) = softmax ( v ′ ( F i a 2 ( j ) Z ′ ) T ) E ( i ) = ∑ j = 1 M i C i a 2 ( j ) F i a 2 ( j ) (12) \tag{12} \begin{aligned} &\mathbf{C}^{(a_2)}_i(j) = \text{softmax}\left( v' (\mathbf{F}^{a_2}_i(j) \mathbf{Z}')^T \right)\\ &\mathbf{E}(i) = \sum_{j=1}^{M_i} \mathbf{C}^{a_2}_i(j) \mathbf{F}^{a_2}_i(j) \end{aligned} Ci(a2)(j)=softmax(v(Fia2(j)Z)T)E(i)=j=1MiCia2(j)Fia2(j)(12)其中: F i a 2 ( j ) \mathbf{F}^{a_2}_i(j) Fia2(j)表示最后一层中第 j j j个节点的 D a 2 D^{a_2} Da2维特征向量,其相应的重要性系数为 C i a 2 ( j ) \mathbf{C}^{a_2}_i(j) Cia2(j) v ′ v' v Z ′ \mathbf{Z}' Z分别是注意力机制的查询向量和线性变换矩阵。 E ( i ) \mathbf{E}(i) E(i) E \mathbf{E} E的第 i i i行,也是包图 G i G_i Gi的特征向量。结合包图标签 Y \mathcal{Y} YGNN损失表示为:
    L GNN = − ∑ i = 1 N Y ‾ i log ⁡ ( E ( i ) Z ‾ ) T (13) \tag{13} \mathcal{L}_{\text{GNN}} = -\sum_{i=1}^{N} \overline{\mathcal{Y}}_i \log\left(\mathbf{E}(i) \overline{\mathbf{Z}}\right)^T LGNN=i=1NYilog(E(i)Z)T(13)其中: Z ‾ \overline{\mathbf{Z}} Z是图分类器、 Y ‾ i \overline{\mathcal{Y}}_i Yi是包图 G i G_i Gi的标签向量,由 Y i ∈ Y \mathcal{Y}_i \in \mathcal{Y} YiY扩展而来。

为了提高GNN效率,RGMIL在GNN框架中引入了参数共享机制,其层数固定为最大动作值。这样,RGMIL每次只需要使用并微调GNN框架的前 a 2 a_2 a2层。RGMIL避免了每次重建和重新训练新GNN时消耗大量时间和空间资源。此外,RGMIL记录了每个动作组合的出现次数。如果 ( a 1 , a 2 ) (a_1, a_2) (a1,a2)的记录超过了预定义的数量,当前的GNN训练过程将被省略。

3.3 奖励计算

获得动作组合并优化GNN之后,RGMIL将通过在验证数据区块上计算即时奖励来评估该组合。具体来说,在RGMIL建模的完全合作MG中,所有智能体拥有相同的联合奖励 (也称为团队奖励)。由于GNN模型旨在提高表示学习,奖励是基于相邻时间步上的包图分类性能差异来计算的。类似地,RGMIL根据动作 a 1 a_1 a1处理验证样本,并将它们输入到具有 a 2 a_2 a2层的模型中。奖励函数 RWD ( ⋅ ) \text{RWD}(\cdot) RWD()
r ∗ = RWD ( a 1 , a 2 ) = EVL ( t ) − 1 μ ∑ t ′ t EVL ( t ′ ) s.t. t ′ = t − μ + 1 (14) \tag{14} \begin{aligned} &r^* = \text{RWD}(a_1, a_2) = \text{EVL}(t) - \frac{1}{\mu} \sum_{t'}^{t} \text{EVL}(t')\\ &\text{s.t.}\quad t'=t-\mu+1 \end{aligned} r=RWD(a1,a2)=EVL(t)μ1ttEVL(t)s.t.t=tμ+1(14)其中 t t t表示当前步、 EVL ( ⋅ ) \text{EVL}(\cdot) EVL()分类性能评估函数 μ \mu μ表示历史记录窗口大小。RGMIL平均 μ \mu μ个历史记录以确保奖励的可靠性以及博弈的稳定性。特别的, μ \mu μ还作为动作组合的预定义记录数量。

3.4 状态转移和终止

RGMIL引入了一种新颖的启发式状态转移函数来获取下一个全局状态和观察:

  1. RGMIL根据当前动作组合计算下一个全局状态的数据区块索引。考虑到 a 1 a_1 a1 a 2 a_2 a2分别属于小数和整数,RGMIL将它们视为不同的状态转移依赖性。下一个状态对应的数据区块索引 k k k计算如下:
    k = ( ( round ( a 1 ) + a 2 ) % ∣ S ∣ ) + 1 s.t.  k ∈ [ 1 , ∣ S ∣ ] & round ( a 1 ) ∈ { 0 , 1 } & a 2 > ∣ S ∣ (15) \tag{15} \begin{aligned} &k = ((\text{round}(a_1) + a_2) \%|S|) + 1\\ &\text{s.t. } k \in [1, |S|] \& \text{round}(a_1) \in \{0, 1\} \& a_2 > |S| \end{aligned} k=((round(a1)+a2)%∣S)+1s.t. k[1,S]&round(a1){0,1}&a2>S(15)其中: round ( ⋅ ) \text{round}(\cdot) round()表示四舍五入、 % \% %是余数操作。动作 a 2 a_2 a2较大以确保覆盖训练区块,而 round ( a 1 ) \text{round}(a_1) round(a1)提供小偏移以增加变化。由于 k k k主要受 a 2 a_2 a2的影响,RGMIL避免了由于后期两个动作同时剧烈波动可能导致的博弈不收敛问题;
  2. 通过第3.1节介绍的方法构建下一个观察;
  3. 一个时间步的经验 ⟨ ( o 1 , o 2 ) , ( a 1 , a 2 ) , r ∗ , ( o 1 ′ , o 2 ′ ) ⟩ \langle (o_1, o_2), (a_1, a_2), r^*, (o'_1, o'_2) \rangle ⟨(o1,o2),(a1,a2),r,(o1,o2)⟩被存储起来。转移将不会终止,直到达到最后一个时间步 T T T,或者在较早的中间步 t t t (其中 t ≤ T t \leq T tT) 满足以下终止条件
    ∣ 1 μ ∑ t ′ t r ∗ t ′ ∣ < λ , s.t.   t ′ = t − μ + 1 (16) \tag{16} \left|\frac{1}{\mu} \sum_{t'}^{t} r^{*t'}\right| < \lambda,\qquad\text{s.t.}\ \ t'=t-\mu+1 μ1ttrt <λ,s.t.  t=tμ+1(16)其中:不等式符号表示过去 μ \mu μ个奖励的平均值没有超过预定义阈值 λ \lambda λ,以及 r ∗ t ′ r^{*t'} rt是过去时间步 t ′ t' t联合奖励

3.5 多智能体训练

当历史经验的数量大于 μ \mu μ并且博弈尚未结束时,RGMIL需要在完成上述过程后通过经验回放训练两个智能体。由于VDN证明了联合 Q Q Q函数可以分解为不同智能体的 Q Q Q函数,因此RGMIL以值分解的方式更新智能体

  1. RGMIL通过反向传播将联合 Q Q Q值分解给每个智能体。两个智能体将通过测量它们对联合 Q Q Q值的贡献来积极地朝着共同的目标工作;
  2. 给定在时间步 t t t收集的经验元组 ⟨ ( o 1 t , o 2 t ) , ( a 1 t , a 2 t ) , r t ∗ , ( o 1 t + 1 , o 2 t + 1 ) ⟩ \langle (o^t_1, o^t_2), (a^t_1, a^t_2), r^*_t, (o^{t+1}_1, o^{t+1}_2) \rangle ⟨(o1t,o2t),(a1t,a2t),rt,(o1t+1,o2t+1)⟩智能体的联合损失计算如下:
    L π ∗ = E ⟨ s , a , r , s ′ ⟩ [ ( ( π ‾ ∗ ( o ∗ t , a ∗ t ) − π ∗ ( o ∗ t , a ∗ t ) ) 2 ] (17) \tag{17} L_{\pi^*} = \mathbb{E}_{\langle s,a,r,s'\rangle} \left[ \left( (\overline{\pi}^*(o^{*t},a^{*t}) - \pi^*(o^{*t},a^{*t})\right)^2 \right] Lπ=Es,a,r,s[((π(ot,at)π(ot,at))2](17)其中 π ∗ ( o ∗ t , a ∗ t ) \pi^*(o^{*t},a^{*t}) π(ot,at)表示预测的联合 Q Q Q
    π ∗ ( o ∗ t , a ∗ t ) ≈ π 1 ( o 1 t , a 1 t ) + π 2 ( o 2 t , a 2 t ) (18) \tag{18} \pi^*(o^{*t},a^{*t}) \approx \pi_1(o^t_1, a^t_1) + \pi_2(o^t_2, a^t_2) π(ot,at)π1(o1t,a1t)+π2(o2t,a2t)(18)其中: π 1 ( ⋅ ) \pi_1(\cdot) π1() π 2 ( ⋅ ) \pi_2(\cdot) π2()分别是第一个和第二个智能体的评估网络 (Q函数)、 π ‾ ∗ ( o ∗ t , a ∗ t ) \overline{\pi}^*(o^{*t},a^{*t}) π(ot,at)表示目标联合 Q Q Q其类似于公式(18)的总和形式,其中每个加法分量 π ‾ i ( o i t , a i t ) \overline{\pi}_i(o^t_i, a^t_i) πi(oit,ait))可以表示为:
    π ‾ i ( o i t , a i t ) = r ∗ t + γ max ⁡ a π ‾ i ( o i t + 1 , a ) (19) \tag{19} \overline{\pi}_i(o^t_i, a^t_i) = r^{*t} + \gamma \max_{a} \overline{\pi}_i(o^{t+1}_i, a) πi(oit,ait)=rt+γamaxπi(oit+1,a)(19)其中: π ‾ i ( ⋅ ) \overline{\pi}_i(\cdot) πi()是第 i i i个智能体的目标网络,其架构与 π i ( ⋅ ) \pi_i(\cdot) πi()相同;
  3. RGMIL只训练评估网络,并且每隔 μ \mu μ个时间步将它们的参数复制到目标网络
  4. 为了缓解可能的高估问题,RGMIL采用传统的双DQN算法来计算目标Q值,该算法用评估网络确定动作,并用目标网络计算 Q Q Q值,如图4右子图所示。因此,公式(19)重写为:
    π ‾ i ( o i t , a i t ) = r ∗ t + γ π ‾ i ( o i t + 1 , argmax a π ( o i t + 1 , a ) ) (20) \tag{20} \overline{\pi}_i(o^t_i, a^t_i) = r^{*t} + \gamma \overline{\pi}_i \left(o^{t+1}_i, \text{argmax}_a \pi\left(o^{t+1}_i, a\right) \right) πi(oit,ait)=rt+γπi(oit+1,argmaxaπ(oit+1,a))(20)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1718331.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ollama 本地大模型框架

该篇教程主要讲解*Ollama的安装和简单使用* Ollama&#xff1a; 在本地启动并运行大型语言模型。 主要流程目录&#xff1a; 1.安装 2.使用 2.1.下载模型 2.2.简单使用 2.3.中文模型 2.4.中文社区 3.总结 1.安装 创建一个容器 切换”高级视图“ 参考填写 ollama oll…

ARM32开发——总线与时钟

&#x1f3ac; 秋野酱&#xff1a;《个人主页》 &#x1f525; 个人专栏:《Java专栏》《Python专栏》 ⛺️心若有所向往,何惧道阻且长 文章目录 APB总线时钟树时钟树 外部晶振内部晶振 在这个例子中&#xff0c;这条大街和巴士构成了一套系统&#xff0c;我们称之为AHB总线。 …

响应式界面控件DevExtreme - 更强的数据分析和可视化功能

DevExtreme拥有高性能的HTML5 / JavaScript小部件集合&#xff0c;使您可以利用现代Web开发堆栈&#xff08;包括React&#xff0c;Angular&#xff0c;ASP.NET Core&#xff0c;jQuery&#xff0c;Knockout等&#xff09;构建交互式的Web应用程序。从Angular和Reac&#xff0c…

新火种AI|OpenAI要和苹果合作了?微软有些不高兴

作者&#xff1a;一号 编辑&#xff1a;美美 和苹果之间的合作&#xff0c;可能会称为Altman引以为傲的功绩。 根据 The Information 援引知情人士的消息&#xff0c;OpenAI 已经和苹果达成了协议&#xff0c;将在其产品中运用 OpenAI 的对话式 AI。 如果进展顺利&#xff…

gitlab服务器迁移(亲测有效)

描述&#xff1a;最近公司迁移gitlab&#xff0c;我没有迁移过&#xff0c;经过网上查找资料最终完成迁移&#xff0c;途中也遇到挺多坑和两个问题&#xff0c;希望能帮到你。 新服务器安装gitlab 注意&#xff1a;新服务器gitlab版本也需要和旧版本一致。 首先查看原Gitlab…

12V转5V5A降压芯片:AH8317的全面解析

# 12V转5V降压芯片&#xff1a;AH8317的全面解析 在电子设计领域&#xff0c;电压转换器是不可或缺的组件之一&#xff0c;它们允许电子设备在不同的电源电压下稳定运行。今天&#xff0c;我们将深入探讨一款高性能的同步降压转换器——AH8317&#xff0c;它以其出色的性能和广…

连锁便利店水电远程抄表管理系统是什么?

一、系统概述 连锁便利店水电远程抄表管理系统是一种高效、智能化的解决方案&#xff0c;旨在优化便利店的能源管理&#xff0c;提高运营效率。它通过先进的技术手段&#xff0c;实现了对便利店水电用量的实时监控和远程抄表&#xff0c;大大降低了人工成本&#xff0c;提升了…

悬剑武器库5.04版

工具介绍 悬剑5 基于“悬剑网盘”精选工具集悬剑5“飞廉”云武器库制作。 操作系统&#xff1a;Windows 10 专业版 锁屏密码&#xff1a;secquan.org 解压密码: 圈子社区secquan.org 镜像大小&#xff1a;33.1GB 系统占用空间63.0 GB 镜像导入 下载镜像&#xff0c;文末…

vm:为虚拟机配置多个虚拟网卡(ubuntu20.04)

前言&#xff1a; 环境&#xff1a;虚拟机 ubuntu 20.04 要求&#xff1a;如标题&#xff0c;但是这里针对的是 ubuntu 20.04&#xff0c;对于其他操作系统&#xff0c;可以找一下其他操作系统对应的配置文件是什么 vm 添加虚拟网卡 首先进入 vm&#xff1a; 点击设置&#xf…

员工恶意删除公司数据怎么办,如何防范员工恶意删除公司数据

员工恶意删除公司数据怎么办&#xff0c;如何防范员工恶意删除公司数据 面对员工恶意删除公司数据的情况&#xff0c;企业应当采取一系列紧急且有序的应对措施&#xff0c;以最小化损失并确保业务连续性。以下是一套推荐的应对流程&#xff1a; 1.立即行动&#xff1a; 断开网…

freertos初体验 - 在stm32上移植

1. 说明 freertos内核 非常精简&#xff0c;代码量也很少&#xff0c;官方也针对主流的编译器和内核准备好了移植文件&#xff0c;所以 freertos 的移植是非常简单的&#xff0c;很多工具&#xff08;例如CubeMX&#xff09;点点鼠标就可以生成一个 freertos 的工程&#xff0…

【Python】解决Python报错:ModuleNotFoundError: No module named ‘xxx.yyy‘

&#x1f9d1; 博主简介&#xff1a;阿里巴巴嵌入式技术专家&#xff0c;深耕嵌入式人工智能领域&#xff0c;具备多年的嵌入式硬件产品研发管理经验。 &#x1f4d2; 博客介绍&#xff1a;分享嵌入式开发领域的相关知识、经验、思考和感悟&#xff0c;欢迎关注。提供嵌入式方向…

VNC server ubuntu20 配置

介绍 最近想使用实验室的4卡服务器跑一些深度学习实验&#xff0c;因为跑的是三维建图实验&#xff0c;需要配上可视化界面&#xff0c;本来自带的IPMI可以可视化&#xff0c;但分辨率固定在640*480&#xff0c;看起来很别扭&#xff0c;就捣鼓服务器远程可视化访问了两天&…

无法删除dll文件

碰到xxxxxx.dll文件无法删除不要慌&#xff01; 通过Tasklist /m dll文件名称 去查看它和哪个系统文件绑定运行&#xff0c;发现是explorer.exe。 我们如果直接通过del命令【当然需要在该dll文件所在的路径中】。发现拒绝访问 我们需要在任务管理器中&#xff0c;将资源管理器…

【开源】在线考试系统 JAVA+Vue.js+SpringBoot 新手入门项目

目录 一、项目介绍 二、项目截图 三、核心代码 【开源】在线考试系统 JAVAVue.jsSpringBoot 新手入门项目 一、项目介绍 经典老框架SSM打造入门项目《在线考试系统》&#xff0c;包括班级模块、教师学生模块、试卷模块、试题模块、考试模块、考试回顾模块&#xff0c;项目编…

采样频率低于“奈奎斯特频率”时发生的混叠现象(抽样定理与信号恢复实验)

混叠现象&#xff08;Aliasing&#xff09; 混叠现象发生在采样频率低于奈奎斯特频率时&#xff0c;即采样频率低于信号最高频率的两倍。此时&#xff0c;信号的高频成分会被错误地映射到低频范围内&#xff0c;导致无法正确重建原始信号。具体来说&#xff1a; 奈奎斯特频率…

6-Django项目--分页模块化封装参数共存

目录 utils/page_data.py 分页模块化封装 在app当中创建一个python package 在当前包里面创建py文件 参数共存 完整代码 utils/page_data.py --包里创建py文件. # -*- coding:utf-8 -*- from django.utils.safestring import mark_safe from copy import deepcopyclass…

怎么制作能下载文件的二维码?扫码实现文件下载的方法

现在很多人为了能够方便其他人查看文件&#xff0c;经常会将文件生成二维码图片后&#xff0c;将二维码分享给其他人扫码在手机上查看&#xff0c;这种方式既能够节省成本&#xff0c;又可以实现多人同时获取内容&#xff0c;有利于文件的快速分享。 在制作文件二维码的时候&a…

python数据集优化技巧:统一小分类的方法

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言 二、统一小分类的需求与背景 三、统一小分类的步骤与方法 1. 数据集分析 2. 确…

renren-fast-vue启动报错

问题描述 拉取人人开源vue项目启动失败 报错信息 版本信息 序号名称版本号1node14.21.3 启动方案 1.拉取项目 git clone https://gitee.com/renrenio/renren-fast-vue.git 2.执行安装依赖命令 npm install 3.此时报错 chromedriver2.27.2 install: node install.js 4.手动…