文章目录
- 1 要点
- 2 预备知识
- 2.1 MIL
- 2.2 MIL-GNN
- 2.3 Markov博弈
- 2.4 深度Q-Learning
- 3 方法
- 3.1 观测生成与交互
- 3.2 动作选择和指导
- 3.3 奖励计算
- 3.4 状态转移和终止
- 3.5 多智能体训练
1 要点
题目:用于MIL的强化GNN
代码:https://github.com/RingBDStack/RGMIL
背景:MIL是一种监督学习变体,它处理包含多个实例的包,其中训练阶段只有包级别的标签可用。MIL在现实世界的应用中很多,尤其是在医学领域;
挑战:现有的GNN在MIL中通常需要过滤实例间的低置信度边,并使用新的包结构来调整图神经网络架构。这样的调整过程频繁且忽视了结构和架构之间的相关性;
RGMIL框架:首次在MIL任务中利用多智能体深度强化学习 (MADRL)。MADRL允许灵活定义或扩展影响包图或GNN的因素,并同步控制它们;
贡献:
- 引入MADRL到MIL中,实现对包结构和GNN架构的自动化和同步控制;
- 使用边阈值和GNN层数作为因素案例来构建RGMIL,探索了以前在MIL研究中被忽视的边密度和聚合范围之间的相关性;
- 实验结果表明,RGMIL在多个MIL数据集上实现了最佳性能,并且具有出色的可解释性;
细节:
- RGMIL将训练过程建模为一个完全合作的马尔可夫博弈 (MG);
- 通过两个智能体搜索边过滤阈值和GNN层数;
- 利用反向分解网络 (VDN) 来衡量智能体的贡献和相关性;
- 引入图注意力网络 (GAT) 并设计参数共享机制以提高效率;
符号表:
符号 | 含义 |
---|---|
B \mathcal{B} B | 包集合 |
G \mathcal{G} G | 与包相对应的、图的集合 |
Y \mathcal{Y} Y | 包标签 |
M \mathcal{M} M | Markov博弈的七元组 |
S \mathcal{S} S | M \mathcal{M} M的状态空间 |
O \mathcal{O} O | M \mathcal{M} M的观测空间 |
A \mathcal{A} A | M \mathcal{M} M的动作空间 |
L \mathcal{L} L | 智能体或者GNN模型的训练损失 |
N N N | 包数量 |
M M M | 包内实例数量 |
L L L | GNN层的数量 |
T T T | 时间步的数量 |
I I I | 智能体的数量 |
D D D | 特征表示的维度 |
A \mathbf{A} A | 与图相对应的邻接矩阵 |
F \mathbf{F} F | 与图相对应的实例特征矩阵 |
E \mathbf{E} E | 与图相对应的包图特征矩阵 |
Z \mathbf{Z} Z | 特征变换矩阵 |
C \mathbf{C} C | 重要性系数矩阵 |
i ; j ; k ; l ; t i;j;k;l;t i;j;k;l;t | 索引变量 |
s ; o ; a ; r s;o;a;r s;o;a;r | 状态、观测、动作、奖励 |
v v v | 注意力机制特征向量 |
γ \gamma γ | 折扣系数 |
α \alpha α | 智能体学习率 |
μ \mu μ | 动作或者奖励的窗口大小 |
λ \lambda λ | 终止条件的奖励阈值 |
& ; % \&;\% &;% | 逻辑和取余运算 |
⊕ \oplus ⊕ | 拼接操作 |
∥ ⋅ ∥ \|\cdot\| ∥⋅∥ | 矩阵的Norm函数 |
σ ( ⋅ ) \sigma(\cdot) σ(⋅) | 激活函数 |
π ( ⋅ ) \pi(\cdot) π(⋅) | 智能体状态-动作函数 |
RWD ( ⋅ ) \text{RWD}(\cdot) RWD(⋅) | 奖励函数 |
TRN ( ⋅ ) \text{TRN}(\cdot) TRN(⋅) | 状态转移函数 |
AGG ( ⋅ ) \text{AGG}(\cdot) AGG(⋅) | 特征聚合函数 |
POL ( ⋅ ) \text{POL}(\cdot) POL(⋅) | 特征池化函数 |
EVL ( ⋅ ) \text{EVL}(\cdot) EVL(⋅) | 分类性能评估函数 |
2 预备知识
2.1 MIL
令 B = { B i ∣ i = 1 , … , N } \mathcal{B}=\{\mathcal{B}_i|i=1,\dots,N\} B={Bi∣i=1,…,N}表示包含多个包 B i = { B i , j ∣ j = 1 … , M } \mathcal{B}_i=\{\mathcal{B}_{i,j}|j=1\dots,M\} Bi={Bi,j∣j=1…,M},其中 N N N和 M M M分别表示包和包中实例的数量 (通常 M M M是变化的)。每个包对应一个两类包标签 Y i = max ( Y i , 1 , … , Y i , M ) \mathcal{Y}_i=\max(\mathcal{Y}_{i,1},\dots,\mathcal{Y}_{i,M}) Yi=max(Yi,1,…,Yi,M),其中 Y i , j ∈ { 0 , 1 } \mathcal{Y}_{i,j}\in\{0,1\} Yi,j∈{0,1}是假设的实例标签。尽管数据集中少量的实例具有真实的标签,然而在MIL的训练过程中,实例标签是不可用的。因此,MIL的目标是学习一个将包映射为标签的映射函数 B → Y \mathcal{B\to Y} B→Y,其中 Y = { Y i ∣ i = 1 , … , N } \mathcal{Y}=\{ \mathcal{Y}_i | i=1,\dots,N \} Y={Yi∣i=1,…,N}。
2.2 MIL-GNN
对于MIL-GNN,其首先需要将所有的包转换为一个图的集合 G = { G i ∣ i = 1 , … , N } \mathcal{G}=\{ \mathcal{G}_i|i=1,\dots,N \} G={Gi∣i=1,…,N},其中每个包对应一个图 G i = ( A i , F i ) \mathcal{G}_i=(\mathbf{A}_i,\mathbf{F}_i) Gi=(Ai,Fi),此外,每个实例可以看作是一个节点。每个邻接矩阵 A i ∈ R M × M \mathbf{A}_i\in\mathbb{R}^{M\times M} Ai∈RM×M使用原始节点特征构建,并通过阈值来过滤边,其每个元素表示一跳邻域信息。 F i ∈ R M × D \mathbf{F}_i\in\mathbb{R}^{M\times D} Fi∈RM×D表示实例节点的特征矩阵。
基于此,
L
L
L层GNN被用于传递节点特征信息,其中对于第
i
i
i个图
G
i
\mathcal{G}_i
Gi,其在第
l
l
l层的聚合过程表示为:
F
i
l
=
σ
(
AGG
l
(
A
i
,
F
i
l
−
1
)
)
(1)
\tag{1} \mathbf{F}_i^l=\sigma\left( \text{AGG}^l (\mathbf{A}_i,\mathbf{F}_i^{l-1})\right)
Fil=σ(AGGl(Ai,Fil−1))(1)其中
AGG
l
(
⋅
)
\text{AGG}^l(\cdot)
AGGl(⋅)表示在第
l
l
l层的聚合函数,例如卷积和注意力、
σ
(
⋅
)
\sigma(\cdot)
σ(⋅)表示激活函数、
F
i
l
\mathbf{F}_i^l
Fil是更新后的特征矩阵。
接下来,一个节点特征池化函数
POL
(
⋅
)
\text{POL}(\cdot)
POL(⋅)被用于GNN的最后一层,以获取最终的图级别特征矩阵
E
(
i
)
∈
R
1
×
D
\mathbf{E}(i)\in\mathbb{R}^{1\times D}
E(i)∈R1×D:
E
(
i
)
=
POL
(
{
F
i
L
(
j
)
∣
j
=
1
,
…
,
M
}
)
(2)
\tag{2} \mathbf{E}(i)=\text{POL}(\{ \mathbf{F}_i^L(j) |j=1,\dots,M \})
E(i)=POL({FiL(j)∣j=1,…,M})(2)其中
F
i
L
(
j
)
∈
R
1
×
D
\mathbf{F}_i^L(j)\in\mathbb{R}^{1\times D}
FiL(j)∈R1×D是实例节点
B
i
,
j
\mathcal{B}_{i,j}
Bi,j的特征向量。最后,
E
(
i
)
\mathbf{E}(i)
E(i)传递给一个包图分类器。因此,在MIL-GNN中,其映射过程为
B
→
G
→
Y
\mathcal{B\to G\to Y}
B→G→Y。
2.3 Markov博弈
在多智能体强化学习 (MARL) 中,Markov博弈 (MG) 是从Markov决策过程 (MDP) 扩展而来。特别地,一个MG包含多个能够共同影响奖励和状态转移的智能体。根据是否所有的智能体都能完全获得全局状态信息,已有的MG可被看作是完全或者部分可观测,其中后者则更为普遍。
部分可观测的MG可以被抽象为一个七元组 M = < S , O i , A i , π ( ⋅ ) , RED i ( ⋅ ) , TRN ( ⋅ ) , γ > \mathcal{M}=<\mathcal{S,O_i,A_i},\pi(\cdot),\text{RED}_i(\cdot),\text{TRN}(\cdot),\gamma> M=<S,Oi,Ai,π(⋅),REDi(⋅),TRN(⋅),γ>,其中:
- S \mathcal{S} S:MG的全局状态空间;
- A i \mathcal{A}_i Ai:第 i i i个智能体的动作空间。在每个时间步 t ∈ [ 1 , T ] t\in[1,T] t∈[1,T],每个智能体根据其独有的状态动作函数 π i ( ⋅ ) \pi_i(\cdot) πi(⋅)来选择动作 a i t ∈ A i a_i^t\in\mathcal{A}_i ait∈Ai;
- 每个智能体会从全局状态获得一个独立的部分观察 o i t ∈ O i o_i^t\in\mathcal{O}_i oit∈Oi,因此, π i ( ⋅ ) \pi_i(\cdot) πi(⋅)可以表示为 S → O i → A i \mathcal{S\to O_i\to A_i} S→Oi→Ai;
- 每个智能体使用其奖励函数 RED i ( ⋅ ) \text{RED}_i(\cdot) REDi(⋅)获得即时奖励 r i t r_i^t rit,这种博弈也被称为分散的部分可观测MDP (Dec-POMDP),旨在最大化累积奖励 ∑ t = 1 T γ ( t − 1 ) r ∗ t \sum^T_{t=1}\gamma^{(t−1)}r^{*t} ∑t=1Tγ(t−1)r∗t,其中 γ γ γ表示控制后续奖励的折扣系数;
- 状态转移函数 TRN ( ⋅ ) \text{TRN}(\cdot) TRN(⋅)将当前状态 s t s^t st与联合动作 a ∗ t a^{*t} a∗t映射到下一个状态 s ( t + 1 ) s^{(t+1)} s(t+1),即 S × A ∗ → S \mathcal{S \times A^*\to S} S×A∗→S。
2.4 深度Q-Learning
作为基于价值的RL的基算法,Q-Learning非常适合实现单一智能体的顺序决策系统。QLearning包含一个状态-动作表
π
(
⋅
)
π(·)
π(⋅),它记录了各种状态下所有可能动作的
Q
Q
Q值。初始化后,智能体不断与环境交互,并通过Bellman方程更新
π
(
⋅
)
π(·)
π(⋅)直到收敛。
π
(
⋅
)
π(·)
π(⋅)的更新过程可以表示如下:
x
=
x
+
α
[
r
t
+
γ
max
a
π
(
s
t
+
1
,
a
)
−
x
]
]
s.t.
x
=
π
(
s
t
,
a
t
)
(3)
\tag{3} \begin{aligned} & x = x + \alpha \left[ r_t + \gamma \max_{a} \pi(s_{t+1}, a) - x \right] ]\\ & \text{s.t. } x = \pi(s_t, a_t) \end{aligned}
x=x+α[rt+γamaxπ(st+1,a)−x]]s.t. x=π(st,at)(3)其中:
π
(
s
t
,
a
t
)
\pi(s_t, a_t)
π(st,at)是预测的Q值,以及在状态
s
t
s_t
st下选择动作
a
t
a_t
at的预期奖励、
r
t
r_t
rt表示时间步
t
t
t的即时奖励、
max
a
π
(
s
t
+
1
,
a
)
\max_a \pi(s_{t+1}, a)
maxaπ(st+1,a)是下一个状态
s
t
+
1
s_{t+1}
st+1的最大Q值,以及
α
\alpha
α是
π
(
⋅
)
\pi(·)
π(⋅)的学习率。
在实际应用中,许多环境的状态空间是无限的,记录所有状态-动作对的值是不可行的。受深度学习的启发,许多工作引入了深度神经网络 (DNN) 来近似返回值,其中深度Q-Learning (DQN) 是传统Q-Learning的直接扩展:
- DQN使用DNN构建动作-价值函数 π π π (亦称为 Q Q Q函数),该函数将每个状态向量映射到 Q Q Q值向量 π ( s ) ∈ R 1 × ∣ A ∣ \pi(s) \in \mathbb{R}^{1 \times |A|} π(s)∈R1×∣A∣,其中 ∣ A ∣ |A| ∣A∣表示动作空间 A A A的大小;
- DQN应用经验回放和目标网络技术来更新函数
π
(
⋅
)
\pi(·)
π(⋅)。例如,给定过去时间步
t
t
t的经验记录,其元组形式为
⟨
s
t
,
a
t
,
r
t
,
s
t
+
1
⟩
\langle s_t, a_t, r_t, s_{t+1} \rangle
⟨st,at,rt,st+1⟩,则
π
π
π的时序差分损失可以计算如下:
L π = E s , a , r , s ′ [ ( π ‾ ( s t , a t ) − π ( s t , a t ) ) 2 ] s.t. π ‾ ( s t , a t ) = r t + γ max a π ‾ ( s t + 1 , a ) (4) \tag{4} \begin{aligned} &L_\pi = \mathbb{E}_{s,a,r,s'} \left[ \left( \overline{\pi}(s_t, a_t) - \pi(s_t, a_t) \right)^2 \right]\\ &\text{s.t. } \overline{\pi}(s_t, a_t) = r_t + \gamma \max_a \overline{\pi}(s_{t+1}, a) \end{aligned} Lπ=Es,a,r,s′[(π(st,at)−π(st,at))2]s.t. π(st,at)=rt+γamaxπ(st+1,a)(4)其中: π ( ⋅ ) \pi(·) π(⋅)表示评估网络,其用于预测状态 s t s_t st和动作 a t a_t at的 Q Q Q值的评估网络、 π ‾ ( ⋅ ) \overline{\pi}(·) π(⋅)是一个目标网络,其架构与 π ( ⋅ ) \pi(·) π(⋅)相同。只有 π ( ⋅ ) \pi(·) π(⋅)被优化,并且其训练参数周期性复制到 π ‾ ( ⋅ ) \overline{\pi}(·) π(⋅)。由于 π ‾ \overline{\pi} π不更新时目标 Q Q Q值是稳定的,因此 π ( ⋅ ) \pi(·) π(⋅)的训练稳定性是极好的; - 为了权衡探索新动作的概率,DQN应用了
ϵ
ϵ
ϵ-贪婪算法。因此,它并不总是选择
π
(
s
)
\pi(s)
π(s)中最大条目的对应动作,其可以表示如下:
a = { random action , w.p. ϵ argmax a π ( s t , a ) , w.p. 1 − ϵ (5) \tag{5} \begin{aligned} a = \begin{cases} \text{random action}, & \text{w.p.} \quad\epsilon \\ \text{argmax}_a \pi(s_t, a), & \text{w.p.} \quad 1 - \epsilon \end{cases} \end{aligned} a={random action,argmaxaπ(st,a),w.p.ϵw.p.1−ϵ(5)其中, ϵ \epsilon ϵ表示随机选择动作的概率,即探索,而 1 − ϵ 1-\epsilon 1−ϵ表示选择当前基于 π π π的最优动作,即利用。通过这样做,DQN避免了在强化学习任务中的探索-利用困境,避开了局部最优,并促进了更好的 π π π函数的发现。
3 方法
本节介绍RGMIL的细节,包括:1) 用于提升博弈公平性的观测生成与交互;2) 用于提升GNN效率的动作选择和指导技术;3) 用于提升博弈稳定性的奖励计算;4) 用于确保博弈收敛的状态转移和终止技术;以及5) 多智能体训练。
RGMIL的总览如图4所示,其中左子图对应章节3.1至3.4,右子图对应章节3.5。
3.1 观测生成与交互
在RGMIL中,我们将其训练过程建模为一个合作的马尔可夫博弈 (MG),涉及两个智能体,分别用于搜索最佳的边过滤阈值和GNN层数:
- 利用一个改进的VDN来实现MG:
- 将训练集划分为多个等大小的区块,其中一个区块作为验证集,其余区块用作构建MG状态空间 S S S;
- 在第一个时间步之前,随机选择一个训练区块作为全局状态;
- 由于边过滤阈值的选择通常与拓扑信息相关,我们随后指定当前状态中包图的结构特征作为第一个智能体的观察;
- 通过包图的成对相似性建立实例节点的初始边。以属于当前区块的第 i i i个包 B i \mathcal{B}_i Bi为例,它的包图 G i \mathcal{G}_i Gi可以被抽象为一个邻接矩阵 A i \mathbf{A}_i Ai以及一个特征矩阵 F i \mathbf{F}_i Fi;
- 给定初始矩阵
F
i
0
\mathbf{F}^0_i
Fi0,初始邻接矩阵
A
i
\mathbf{A}_i
Ai的计算如下:
A i ( j , j ′ ) = ∥ F i 0 ( j ) − F i 0 ( j ′ ) ∥ 2 (6) \tag{6} \mathbf{A}_i(j, j') = \|\mathbf{F}^0_i(j) - \mathbf{F}^0_i(j')\|_2 Ai(j,j′)=∥Fi0(j)−Fi0(j′)∥2(6)其中 ∥ ⋅ ∥ 2 \|\cdot\|_2 ∥⋅∥2表示矩阵的二范数、 A i ( j , j ′ ) \mathbf{A}_i(j, j') Ai(j,j′)编码了第 j j j个和第 j ′ j' j′个实例节点之间的欧式距离。 - 因此,第一个智能体的观察计算如下:
o 1 ( d ) = 1 N d ∑ i = 1 N d exp ( − A i ) s.t. M i = d , d ∈ [ 1 , max M i ] (7) \tag{7} \begin{aligned} &o_1(d) = \frac{1}{N_d} \sum_{i=1}^{N_d} \exp(-\mathbf{A}_i)\\ & \text{s.t. } M_i = d, \quad d \in [1, \max M_i] \end{aligned} o1(d)=Nd1i=1∑Ndexp(−Ai)s.t. Mi=d,d∈[1,maxMi](7)其中 o 1 ( d ) o_1(d) o1(d)表示向量 o 1 o_1 o1的第 d d d个条目、 N d N_d Nd是当前区块中包的数量,并且它包含的实例数量等于 d d d、 M i M_i Mi是包图 G i G_i Gi的实例节点数量;
- 由于GNN层数控制特征聚合的迭代,随后从初始节点特征
F
i
0
\mathbf{F}^0_i
Fi0中获取第二个智能体的观察:
o 2 = 1 N ∑ i = 1 N ( 1 M i ∑ j = 1 M i F i 0 ( j ) ) (8) \tag{8} o_2 = \frac{1}{N} \sum_{i=1}^{N} \left( \frac{1}{M_i} \sum_{j=1}^{M_i} F^0_i(j) \right) o2=N1i=1∑N(Mi1j=1∑MiFi0(j))(8)其中 F i 0 ( j ) \mathbf{F}^0_i(j) Fi0(j)是第 j j j个实例节点的特征向量、 N N N是当前区块中包图的总数; - 为了进一步探索边密度和聚合迭代之间的潜在相关性,引入了观察信息交互:
o 1 = o 1 ⊕ σ ( ( o 1 ⊕ o 2 ) ( o 2 ⊕ o 1 ) T o 1 ) o 2 = o 2 ⊕ σ ( ( o 1 ⊕ o 2 ) ( o 2 ⊕ o 1 ) T o 2 ) (9) \tag{9} \begin{aligned} &o_1 = o_1 \oplus \sigma((o_1 \oplus o_2)(o_2 \oplus o_1)^T {o_1})\\ &o_2 = o_2 \oplus \sigma((o_1 \oplus o_2)(o_2 \oplus o_1)^T {o_2}) \end{aligned} o1=o1⊕σ((o1⊕o2)(o2⊕o1)To1)o2=o2⊕σ((o1⊕o2)(o2⊕o1)To2)(9)其中 ⊕ ( ⋅ ) \oplus(\cdot) ⊕(⋅)是向量的连接操作。通过此操作,观察 o 1 o_1 o1和 o 2 o_2 o2具有相同的维度,并且都编码了来自对方的信息;
RGMIL减轻了由于观察的特征维度或信息量的变化可能导致的MG中的不公平博弈。此外,为了提高这部分的效率,RGMIL只为每个数据区块一次性计算并记录这些初始邻接矩阵和观察。
3.2 动作选择和指导
当输入当前的观察向量 o i o_i oi后,每个智能体将其映射为一个 Q Q Q值向量 π i ( o i ) ∈ R 1 × ∣ A i ∣ \pi_i(o_i) \in \mathbb{R}^{1 \times |\mathcal{A}_i|} πi(oi)∈R1×∣Ai∣,并基于最大的 Q Q Q值条目或随机选择一个动作 a i a_i ai (如公式5):
- 第一个阈值动作 a 1 ∈ [ 0 , 1 ] a_1 \in [0, 1] a1∈[0,1]是一个小数,而第二个层数动作 a 2 a_2 a2是一个整数;
- 在
a
1
a_1
a1的指导下,可以获得一个更可靠的邻接矩阵
A
i
\mathbf{A}_i
Ai:
A i ( j , j ′ ) = { 1 , if exp ( − A i ( j , j ′ ) ) ≥ a 1 0 , if exp ( − A i ( j , j ′ ) ) < a 1 (10) \tag{10} \mathbf{A}_i(j, j') = \begin{cases} 1, & \text{if } \exp(-\mathbf{A}_i(j, j')) \geq a_1 \\ 0, & \text{if } \exp(-\mathbf{A}_i(j, j')) < a_1 \end{cases} Ai(j,j′)={1,0,if exp(−Ai(j,j′))≥a1if exp(−Ai(j,j′))<a1(10) - 在
a
2
a_2
a2的指导下,RGMIL将构建定制的GNN。以GAT为例,节点特征的聚合过程可以表示为:
C i ( l − 1 ) ( j , j ′ ) = v ⋅ ( F i ( l − 1 ) ( j ) Z ( l − 1 ) ⊕ F i ( l − 1 ) ( j ′ ) Z ( l − 1 ) ) T F i l ( j ) = σ ( ∑ j ′ x F i ( l − 1 ) ( j ′ ) Z ( l − 1 ) ) s.t. x = softmax ( σ ( C i l − 1 ( j , j ′ ) ) ) & A i ( j , j ′ ) = 1 \begin{aligned} &\mathbf{C}^{(l-1)}_i(j, j') = v \cdot (\mathbf{F}^{(l-1)}_i(j) \mathbf{Z}^{(l-1)} \oplus \mathbf{F}^{(l-1)}_i(j') \mathbf{Z}^{(l-1)} )^T\\ &\mathbf{F}^{l}_i(j) = \sigma\left(\sum_{j'}x \mathbf{F}^{(l-1)}_i(j') \mathbf{Z}^{(l-1)}\right)\\ & \text{s.t.}\quad x=\text{softmax}\left(\sigma\left(\mathbf{C}_i^{l-1}\left(j,j'\right)\right)\right)\&\mathbf{A}_i(j,j')=1 \end{aligned} Ci(l−1)(j,j′)=v⋅(Fi(l−1)(j)Z(l−1)⊕Fi(l−1)(j′)Z(l−1))TFil(j)=σ j′∑xFi(l−1)(j′)Z(l−1) s.t.x=softmax(σ(Cil−1(j,j′)))&Ai(j,j′)=1其中: l ∈ [ 1 , a 2 ] l \in [1, a_2] l∈[1,a2]表示运行 a 2 a_2 a2次迭代聚合、 F i l ( j ) \mathbf{F}^{l}_i(j) Fil(j)是第 l l l层GNN中第 j j j个节点 B i , j \mathcal{B}_{i,j} Bi,j的 D l D_l Dl维特征向量、 Z ( l − 1 ) \mathbf{Z}^{(l-1)} Z(l−1)表示形状为 D ( l − 1 ) × D l D^{(l-1)} \times D^l D(l−1)×Dl的特征转换矩阵。此外, v ∈ R 1 × 2 D l v \in \mathbb{R}^{1 \times 2D^l} v∈R1×2Dl表示自注意力机制的特征向量、 C i ( j , j ′ ) \mathbf{C}_i(j, j') Ci(j,j′)是邻居 B i , j \mathcal{B}_{i,j} Bi,j相对于其目标 B i , j ′ \mathcal{B}_{i,j'} Bi,j′的重要性系数,其需要通过softmax函数获得,以及 & \& &表示逻辑操作。基于注意力的节点特征池化函数,RGMIL获得当前训练区块的最终包图特征矩阵 E ∈ R N × D a 2 E \in \mathbb{R}^{N \times D^{{a_2}}} E∈RN×Da2:
C i ( a 2 ) ( j ) = softmax ( v ′ ( F i a 2 ( j ) Z ′ ) T ) E ( i ) = ∑ j = 1 M i C i a 2 ( j ) F i a 2 ( j ) (12) \tag{12} \begin{aligned} &\mathbf{C}^{(a_2)}_i(j) = \text{softmax}\left( v' (\mathbf{F}^{a_2}_i(j) \mathbf{Z}')^T \right)\\ &\mathbf{E}(i) = \sum_{j=1}^{M_i} \mathbf{C}^{a_2}_i(j) \mathbf{F}^{a_2}_i(j) \end{aligned} Ci(a2)(j)=softmax(v′(Fia2(j)Z′)T)E(i)=j=1∑MiCia2(j)Fia2(j)(12)其中: F i a 2 ( j ) \mathbf{F}^{a_2}_i(j) Fia2(j)表示最后一层中第 j j j个节点的 D a 2 D^{a_2} Da2维特征向量,其相应的重要性系数为 C i a 2 ( j ) \mathbf{C}^{a_2}_i(j) Cia2(j)。 v ′ v' v′和 Z ′ \mathbf{Z}' Z′分别是注意力机制的查询向量和线性变换矩阵。 E ( i ) \mathbf{E}(i) E(i)是 E \mathbf{E} E的第 i i i行,也是包图 G i G_i Gi的特征向量。结合包图标签 Y \mathcal{Y} Y,GNN损失表示为:
L GNN = − ∑ i = 1 N Y ‾ i log ( E ( i ) Z ‾ ) T (13) \tag{13} \mathcal{L}_{\text{GNN}} = -\sum_{i=1}^{N} \overline{\mathcal{Y}}_i \log\left(\mathbf{E}(i) \overline{\mathbf{Z}}\right)^T LGNN=−i=1∑NYilog(E(i)Z)T(13)其中: Z ‾ \overline{\mathbf{Z}} Z是图分类器、 Y ‾ i \overline{\mathcal{Y}}_i Yi是包图 G i G_i Gi的标签向量,由 Y i ∈ Y \mathcal{Y}_i \in \mathcal{Y} Yi∈Y扩展而来。
为了提高GNN效率,RGMIL在GNN框架中引入了参数共享机制,其层数固定为最大动作值。这样,RGMIL每次只需要使用并微调GNN框架的前 a 2 a_2 a2层。RGMIL避免了每次重建和重新训练新GNN时消耗大量时间和空间资源。此外,RGMIL记录了每个动作组合的出现次数。如果 ( a 1 , a 2 ) (a_1, a_2) (a1,a2)的记录超过了预定义的数量,当前的GNN训练过程将被省略。
3.3 奖励计算
获得动作组合并优化GNN之后,RGMIL将通过在验证数据区块上计算即时奖励来评估该组合。具体来说,在RGMIL建模的完全合作MG中,所有智能体拥有相同的联合奖励 (也称为团队奖励)。由于GNN模型旨在提高表示学习,奖励是基于相邻时间步上的包图分类性能差异来计算的。类似地,RGMIL根据动作
a
1
a_1
a1处理验证样本,并将它们输入到具有
a
2
a_2
a2层的模型中。奖励函数
RWD
(
⋅
)
\text{RWD}(\cdot)
RWD(⋅):
r
∗
=
RWD
(
a
1
,
a
2
)
=
EVL
(
t
)
−
1
μ
∑
t
′
t
EVL
(
t
′
)
s.t.
t
′
=
t
−
μ
+
1
(14)
\tag{14} \begin{aligned} &r^* = \text{RWD}(a_1, a_2) = \text{EVL}(t) - \frac{1}{\mu} \sum_{t'}^{t} \text{EVL}(t')\\ &\text{s.t.}\quad t'=t-\mu+1 \end{aligned}
r∗=RWD(a1,a2)=EVL(t)−μ1t′∑tEVL(t′)s.t.t′=t−μ+1(14)其中
t
t
t表示当前步、
EVL
(
⋅
)
\text{EVL}(\cdot)
EVL(⋅)是分类性能评估函数、
μ
\mu
μ表示历史记录窗口大小。RGMIL平均
μ
\mu
μ个历史记录以确保奖励的可靠性以及博弈的稳定性。特别的,
μ
\mu
μ还作为动作组合的预定义记录数量。
3.4 状态转移和终止
RGMIL引入了一种新颖的启发式状态转移函数来获取下一个全局状态和观察:
- RGMIL根据当前动作组合计算下一个全局状态的数据区块索引。考虑到
a
1
a_1
a1和
a
2
a_2
a2分别属于小数和整数,RGMIL将它们视为不同的状态转移依赖性。下一个状态对应的数据区块索引
k
k
k计算如下:
k = ( ( round ( a 1 ) + a 2 ) % ∣ S ∣ ) + 1 s.t. k ∈ [ 1 , ∣ S ∣ ] & round ( a 1 ) ∈ { 0 , 1 } & a 2 > ∣ S ∣ (15) \tag{15} \begin{aligned} &k = ((\text{round}(a_1) + a_2) \%|S|) + 1\\ &\text{s.t. } k \in [1, |S|] \& \text{round}(a_1) \in \{0, 1\} \& a_2 > |S| \end{aligned} k=((round(a1)+a2)%∣S∣)+1s.t. k∈[1,∣S∣]&round(a1)∈{0,1}&a2>∣S∣(15)其中: round ( ⋅ ) \text{round}(\cdot) round(⋅)表示四舍五入、 % \% %是余数操作。动作 a 2 a_2 a2较大以确保覆盖训练区块,而 round ( a 1 ) \text{round}(a_1) round(a1)则提供小偏移以增加变化。由于 k k k主要受 a 2 a_2 a2的影响,RGMIL避免了由于后期两个动作同时剧烈波动可能导致的博弈不收敛问题; - 通过第3.1节介绍的方法构建下一个观察;
- 一个时间步的经验
⟨
(
o
1
,
o
2
)
,
(
a
1
,
a
2
)
,
r
∗
,
(
o
1
′
,
o
2
′
)
⟩
\langle (o_1, o_2), (a_1, a_2), r^*, (o'_1, o'_2) \rangle
⟨(o1,o2),(a1,a2),r∗,(o1′,o2′)⟩被存储起来。转移将不会终止,直到达到最后一个时间步
T
T
T,或者在较早的中间步
t
t
t (其中
t
≤
T
t \leq T
t≤T) 满足以下终止条件:
∣ 1 μ ∑ t ′ t r ∗ t ′ ∣ < λ , s.t. t ′ = t − μ + 1 (16) \tag{16} \left|\frac{1}{\mu} \sum_{t'}^{t} r^{*t'}\right| < \lambda,\qquad\text{s.t.}\ \ t'=t-\mu+1 μ1t′∑tr∗t′ <λ,s.t. t′=t−μ+1(16)其中:不等式符号表示过去 μ \mu μ个奖励的平均值没有超过预定义阈值 λ \lambda λ,以及 r ∗ t ′ r^{*t'} r∗t′是过去时间步 t ′ t' t′的联合奖励。
3.5 多智能体训练
当历史经验的数量大于 μ \mu μ并且博弈尚未结束时,RGMIL需要在完成上述过程后通过经验回放训练两个智能体。由于VDN证明了联合 Q Q Q函数可以分解为不同智能体的 Q Q Q函数,因此RGMIL以值分解的方式更新智能体:
- RGMIL通过反向传播将联合 Q Q Q值分解给每个智能体。两个智能体将通过测量它们对联合 Q Q Q值的贡献来积极地朝着共同的目标工作;
- 给定在时间步
t
t
t收集的经验元组
⟨
(
o
1
t
,
o
2
t
)
,
(
a
1
t
,
a
2
t
)
,
r
t
∗
,
(
o
1
t
+
1
,
o
2
t
+
1
)
⟩
\langle (o^t_1, o^t_2), (a^t_1, a^t_2), r^*_t, (o^{t+1}_1, o^{t+1}_2) \rangle
⟨(o1t,o2t),(a1t,a2t),rt∗,(o1t+1,o2t+1)⟩,智能体的联合损失计算如下:
L π ∗ = E ⟨ s , a , r , s ′ ⟩ [ ( ( π ‾ ∗ ( o ∗ t , a ∗ t ) − π ∗ ( o ∗ t , a ∗ t ) ) 2 ] (17) \tag{17} L_{\pi^*} = \mathbb{E}_{\langle s,a,r,s'\rangle} \left[ \left( (\overline{\pi}^*(o^{*t},a^{*t}) - \pi^*(o^{*t},a^{*t})\right)^2 \right] Lπ∗=E⟨s,a,r,s′⟩[((π∗(o∗t,a∗t)−π∗(o∗t,a∗t))2](17)其中 π ∗ ( o ∗ t , a ∗ t ) \pi^*(o^{*t},a^{*t}) π∗(o∗t,a∗t)表示预测的联合 Q Q Q值:
π ∗ ( o ∗ t , a ∗ t ) ≈ π 1 ( o 1 t , a 1 t ) + π 2 ( o 2 t , a 2 t ) (18) \tag{18} \pi^*(o^{*t},a^{*t}) \approx \pi_1(o^t_1, a^t_1) + \pi_2(o^t_2, a^t_2) π∗(o∗t,a∗t)≈π1(o1t,a1t)+π2(o2t,a2t)(18)其中: π 1 ( ⋅ ) \pi_1(\cdot) π1(⋅)和 π 2 ( ⋅ ) \pi_2(\cdot) π2(⋅)分别是第一个和第二个智能体的评估网络 (Q函数)、 π ‾ ∗ ( o ∗ t , a ∗ t ) \overline{\pi}^*(o^{*t},a^{*t}) π∗(o∗t,a∗t)表示目标联合 Q Q Q值其类似于公式(18)的总和形式,其中每个加法分量 π ‾ i ( o i t , a i t ) \overline{\pi}_i(o^t_i, a^t_i) πi(oit,ait))可以表示为:
π ‾ i ( o i t , a i t ) = r ∗ t + γ max a π ‾ i ( o i t + 1 , a ) (19) \tag{19} \overline{\pi}_i(o^t_i, a^t_i) = r^{*t} + \gamma \max_{a} \overline{\pi}_i(o^{t+1}_i, a) πi(oit,ait)=r∗t+γamaxπi(oit+1,a)(19)其中: π ‾ i ( ⋅ ) \overline{\pi}_i(\cdot) πi(⋅)是第 i i i个智能体的目标网络,其架构与 π i ( ⋅ ) \pi_i(\cdot) πi(⋅)相同; - RGMIL只训练评估网络,并且每隔 μ \mu μ个时间步将它们的参数复制到目标网络;
- 为了缓解可能的高估问题,RGMIL采用传统的双DQN算法来计算目标Q值,该算法用评估网络确定动作,并用目标网络计算
Q
Q
Q值,如图4右子图所示。因此,公式(19)重写为:
π ‾ i ( o i t , a i t ) = r ∗ t + γ π ‾ i ( o i t + 1 , argmax a π ( o i t + 1 , a ) ) (20) \tag{20} \overline{\pi}_i(o^t_i, a^t_i) = r^{*t} + \gamma \overline{\pi}_i \left(o^{t+1}_i, \text{argmax}_a \pi\left(o^{t+1}_i, a\right) \right) πi(oit,ait)=r∗t+γπi(oit+1,argmaxaπ(oit+1,a))(20)