摘要:药物推荐是智能医疗系统的一个重要方面,因为它涉及根据患者的特定健康需求开具最合适的药物。不幸的是,目前使用的许多复杂模型往往忽视医疗数据的细微语义,而仅仅严重依赖于标识信息。此外,这些模型在处理首次就诊患者的病例时面临重大挑战,因为它们缺乏先前的处方记录可供参考。为了解决这些问题,我们利用大语言模型(LLMs)强大的语义理解和输入无关特性。我们的研究旨在使用大语言模型变革现有的药物推荐方法。在本文中,我们引入了一种名为大语言模型蒸馏药物推荐(LEADER)的新方法。我们首先创建合适的提示模板,使大语言模型能够有效地推荐药物。然而,将大语言模型直接整合到推荐系统中存在特定药物的语料库外问题。我们通过为大语言模型适配一个新颖的输出层和优化的微调损失函数来解决这个问题。尽管基于大语言模型的模型表现出卓越的能力,但它们在推理过程中计算成本高昂,这对于医疗保健领域来说不切实际。为了缓解这一问题,我们开发了一种特征级知识蒸馏技术,该技术将大语言模型的专业知识转移到更紧凑的模型中。在两个真实世界数据集MIMIC - III和MIMIC - IV上进行的实验表明,我们提出的模型不仅能产生有效的结果,而且效率很高。为了便于实验的可重复性,我们在网上发布了实现代码 。 https://github.com/liuqidong07/LEADER-pytorch
指标术语- -用药推荐;大型语言模型;知识蒸馏;
I. INTRODUCTION
处方作为患者治疗的关键环节,劳动强度大,且需要专业知识[8]。自动化药物推荐系统通过提供决策支持,有望为不堪重负的医疗保健专业人员减轻负担[9]。当代药物推荐模型主要专注于根据患者的诊断和诊疗数据生成药物推荐。尽管已取得重大进展,但仍存在两个主要挑战:(i)缺乏语义理解:现有模型[1 - 4]主要通过标识数据捕捉药物、诊断和诊疗之间的协作信息。然而,处方中的语义理解,尤其是在医疗情境中[10],在药物推荐中常常被忽视。(ii)单次就诊患者的挑战:近期研究[5 - 7]表明,处方历史是当前处方实践中的一个关键因素。如表1所示,像MICRON[5]、COGNet[6]和REFINE[7]等模型将历史用药记录作为输入,以提高性能。然而,这种对历史数据的依赖在为首次就诊患者(即单次就诊患者)推荐药物时带来了重大挑战。在现实世界的医疗保健系统中,排除单次就诊患者是不可接受的,这表明这是一个亟待改进的关键领域。
大语言模型的出现[11]为改进现有药物推荐系统提供了契机。一方面,大量研究已证实大语言模型强大的语义理解能力[12]。这使得从医学语义角度优化药物推荐成为可能。另一方面,大语言模型以自然语言作为输入,这使得它们本质上对输入变量的类型和数量不敏感[13]。因此,与一些现有的药物推荐模型不同,基于大语言模型的药物推荐系统可以将任何可想到的变量,包括患者资料和历史处方,纳入模型。这种灵活性使它们能够满足所有患者的需求,无论患者是否有记录在案的医疗史。针对前面提到的两个挑战,将大语言模型应用于药物推荐任务成为一个极具吸引力的解决方案。
一些先驱性工作[14, 15]已迈出了将大语言模型与推荐系统集成的第一步。然而,它们在药物推荐任务中的直接应用受到两个重大问题的阻碍:(i)语料库外问题。众多研究[16 - 18]探索了创建输入提示以调用大语言模型。然而,自然语言输出与所需的语料库内药物之间的不兼容性仍然存在。这一挑战可能导致基于大语言模型的推荐系统给出的推荐并非药物集合的一部分,从而潜在地损害推荐性能。例如,大语言模型可能生成一个在药物库中无法验证的药物名称,导致推荐失败。(ii)高推理成本问题。鉴于大语言模型拥有数十亿参数,它们常常面临高推理延迟和内存问题[19]。虽然一般应用可以利用云计算来满足基于大语言模型服务的实时要求,但由于隐私问题[20],在医院等医疗机构中部署医疗服务很常见。此外,为每个医疗中心配备高性能计算平台存在物流挑战。因此,对于基于大语言模型的药物推荐而言,更高效的解决方案至关重要。
为应对上述挑战,我们引入了通过蒸馏的大语言模型增强药物推荐(LEADER)。在我们将大语言模型适配于药物推荐的方法中,我们首先开发合适的提示模板以激活大语言模型的语义理解能力。具体而言,针对语料库外问题,我们通过引入一个带有相应训练损失的新输出层来增强大语言模型。经过有监督的微调后,大语言模型获得了药物推荐能力,并展现出卓越的性能。然而,基于大语言模型的模型的应用受到高推理成本的阻碍。为解决这一问题,我们致力于将大语言模型的强大能力转移到一个小模型上。详细来说,我们设计了一种特征级蒸馏方法,以增强基于适配大语言模型的小药物推荐模型。本文的贡献如下:
- 我们通过修改大语言模型特定的输出层和微调损失,验证了大语言模型在药物推荐任务上的强大能力。据我们所知,我们是首次探索药物推荐与大语言模型的集成。
- 我们引入了一种特征级知识蒸馏方法,利用大语言模型来增强小模型,从而得到一个高效且有效的药物推荐模型。
- 在两个公共数据集,即MIMIC - III和MIMIC - IV上进行了广泛的实验。实验结果一致表明,所提出的LEADER模型优于当前的基线模型。
表2:LEADER . ’ med . ‘、’ diag . ‘和’ proc . '中使用的概念是用药、诊断和过程的缩写。
II. PRELIMINARY
电子健康记录(EHR)是智能医疗系统的核心要素之一,它收集患者详细的医疗程序数据。在电子健康记录中,患者的数据可以按其就诊情况进行处理。假设数据库中有
N
N
N名患者,那么患者
z
z
z的记录表示为
X
(
z
)
=
[
X
1
(
z
)
,
…
,
X
i
(
z
)
,
…
,
X
T
z
(
z
)
]
\mathcal{X}^{(z)} = [\mathcal{X}_{1}^{(z)}, \ldots, \mathcal{X}_{i}^{(z)}, \ldots, \mathcal{X}_{T_{z}}^{(z)}]
X(z)=[X1(z),…,Xi(z),…,XTz(z)],其中
T
z
T_{z}
Tz是该患者的就诊次数。为简便起见,以下省略患者标识
(
z
)
(z)
(z)。由于在现实世界中,诊断和诊疗程序对于处方开具至关重要[3, 4],所以每次就诊记录中都包含这两个要素以及用药信息。在第
i
i
i次就诊时,记录表示为
X
i
=
{
M
i
,
D
i
,
P
i
}
\mathcal{X}_{i} = \{\mathcal{M}_{i}, \mathcal{D}_{i}, \mathcal{P}_{i}\}
Xi={Mi,Di,Pi}。
患者可能服用多种药物,并接受多项诊断和诊疗程序,所以设 M i = { m 1 , … , m j , … , m ∣ M ∣ } \mathcal{M}_{i} = \{m_{1}, \ldots, m_{j}, \ldots, m_{|\mathcal{M}|}\} Mi={m1,…,mj,…,m∣M∣}, D i = { d 1 , … , d j , … , d ∣ D ∣ } \mathcal{D}_{i} = \{d_{1}, \ldots, d_{j}, \ldots, d_{|\mathcal{D}|}\} Di={d1,…,dj,…,d∣D∣}, P i = { p 1 , … , p j , … , p ∣ P ∣ } \mathcal{P}_{i} = \{p_{1}, \ldots, p_{j}, \ldots, p_{|\mathcal{P}|}\} Pi={p1,…,pj,…,p∣P∣}分别表示用药、诊断和诊疗程序的集合。 ∣ M ∣ |\mathcal{M}| ∣M∣、 ∣ D ∣ |\mathcal{D}| ∣D∣和 ∣ P ∣ |\mathcal{P}| ∣P∣分别表示它们的总数。患者的一些人口统计学特征,如年龄、性别等,也很关键,这些特征标记为 P P P。我们在表II中列出了本文的重要符号。
药物推荐旨在根据患者所有可能的医疗数据给出合适的用药集合
M
T
\mathcal{M}_{T}
MT。如前所述,许多现有方法采用患者的历史处方记录以实现更准确的推荐,这要求患者有多次就诊记录,即
T
>
1
T > 1
T>1。然而,它们无法处理
T
=
1
T = 1
T=1的单次就诊患者。在本文中,我们探索为这两种类型的患者推导模型。因此,我们分别定义问题。对于单次就诊患者,根据
{
D
T
,
P
T
,
P
}
\{\mathcal{D}_{T}, \mathcal{P}_{T}, P\}
{DT,PT,P}推荐
M
T
\mathcal{M}_{T}
MT。对于多次就诊患者,基于
{
X
1
,
…
,
X
T
−
1
,
D
T
,
P
T
,
P
}
\{\mathcal{X}_{1}, \ldots, \mathcal{X}_{T - 1}, \mathcal{D}_{T}, \mathcal{P}_{T}, P\}
{X1,…,XT−1,DT,PT,P}给出
M
T
\mathcal{M}_{T}
MT。
III. METHOD
图1:所提出的LEADER框架概述,其包含两个训练阶段。第一阶段是对教师药物推荐模型(即大语言模型)进行有监督的微调。在第二阶段,我们通过知识蒸馏训练所设计的学生药物推荐模型。为了提高效率,推理时仅使用学生模型。
在本节中,将介绍所提出的LEADER模型的详细信息。首先,我们将在III - A部分给出总体概述。然后,在III - B部分说明针对药物推荐对大语言模型(LLM)所做的修改。在III - C部分,我们将阐述用于将大语言模型强大的语义理解能力转移到小模型的蒸馏方法。最后,在III - D部分详细介绍优化和推理的过程。
A. 概述
所提出的LEADER模型的概述如图1所示。为了利用大语言模型,我们设计提示模板,将患者的电子健康记录格式化为自然语言。然后,修改输出和微调损失函数,以更好地适应药物推荐任务,该任务可被视为一个多分类问题。尽管大语言模型最近已被证明具有出色的能力[21 - 23],但它们面临高推理成本的问题,这很难被医疗保健系统接受。因此,我们探索通过所提出的知识蒸馏方法,将大语言模型的强大能力转移到设计的小模型上。在图中,基于大语言模型的模型和小模型分别表示为“教师”和“学生”。我们将使用真实标签和来自经过良好微调的教师模型的知识蒸馏损失,从零开始训练学生模型。
B. 用于药物推荐的大语言模型
大语言模型的输入和输出都是自然语言,而在传统的药物推荐模型[3, 4, 6]中,输入和输出是非语义的标识,比如“药物ID:2”。因此,为了将大语言模型应用于药物推荐,我们必须填补这一差距。一方面,我们设计合适的提示模板,将电子健康记录格式化为自然语言,使其能够直接输入到大语言模型中。另一方面,大语言模型用于推荐的语言输出面临语料库外的挑战[17, 24],所以我们用分类输出层替换原来的语言头。相应地,大语言模型的微调目标也进行了修改。接下来,我们将在以下部分详细介绍提示模板和输出层。
1) 提示模板
我们设计提示模板
T
\mathcal{T}
T,以得出患者电子健康记录(EHR)的语言表示
P
(
z
)
\mathcal{P}^{(z)}
P(z),这可以指导大语言模型理解患者的健康状况。设计的模板如下:
在该模板中,带下划线的位置将填入电子健康记录(EHR)数据。“<VISIT_NUM>”是一位患者的就诊次数。蓝色部分表示患者的历史记录({\mathcal{X}1, \ldots, \mathcal{X}{T - 1}})。然而,我们认为首次就诊的患者同样重要。对于这些单次就诊患者,他们的提示中没有这部分内容。此外,为利用大语言模型的语义理解能力,诊断、诊疗程序和药物均以其名称表示。因此,“<DIAG_NAME>”、“<PROC_NAME>”和“<MED_NAME>”在提示中都是标准医学术语。构建提示后,大语言模型可以从语言输入中理解药物推荐。
2) 输出层
大多数现有的基于大语言模型的推荐系统[25, 26]以自然语言输出推荐的名称或标识,但面临语料库外的挑战。为解决这个问题,我们用一个伴随有sigmoid函数的线性层替换预训练的词元生成层。然后,修改后的大语言模型的输出是每种药物的概率。
y ^ = σ ( W C L S ⋅ h ) (1) \hat{\mathbf{y}} = \sigma(\mathbf{W}_{CLS} \cdot \mathbf{h}) \tag{1} y^=σ(WCLS⋅h)(1)
其中, y ^ ∈ R ∣ M ∣ × 1 \hat{\mathbf{y}} \in \mathbb{R}^{|\mathcal{M}| \times 1} y^∈R∣M∣×1和 h ∈ R d h × 1 \mathbf{h} \in \mathbb{R}^{d_h \times 1} h∈Rdh×1分别是药物的预测概率和大语言模型中最后一个Transformer层的隐藏状态。 W C L S ∈ R ∣ M ∣ × d h \mathbf{W}_{CLS} \in \mathbb{R}^{|\mathcal{M}| \times d_h} WCLS∈R∣M∣×dh是一个可学习的权重矩阵, σ ( ⋅ ) \sigma(\cdot) σ(⋅)表示sigmoid函数。对于最终推荐,将设置一个阈值 γ \gamma γ。当 y k > γ y_k > \gamma yk>γ时,药物 k k k将被纳入处方药物集合。
3) 优化
由于我们更新了大语言模型的输出层,有监督微调(SFT)是必要的。同时,有监督微调能帮助大语言模型完成特定任务[10, 27]。然而,条件语言建模目标[21, 22]不适用于修改后的大语言模型,因为输出层是用于分类的。为了更好地适配药物推荐任务和输出层,我们将有监督微调的损失函数修改如下:
L S F T = − ∑ i = 1 N y ( i ) log ( y ^ ( i ) ) + ( 1 − y ( i ) ) log ( 1 − y ^ ( i ) ) (2) \mathcal{L}_{SFT} = - \sum_{i = 1}^{N} \mathbf{y}^{(i)} \log(\hat{\mathbf{y}}^{(i)}) + (1 - \mathbf{y}^{(i)}) \log(1 - \hat{\mathbf{y}}^{(i)}) \tag{2} LSFT=−i=1∑Ny(i)log(y^(i))+(1−y(i))log(1−y^(i))(2)
在这个公式中, y \mathbf{y} y是真实的药物标签。值得注意的是,微调大语言模型的所有参数成本极高。因此,在本文中我们采用LoRA[28]微调,它只更新低秩矩阵集,同时冻结大语言模型的预训练权重。设 { A i , B i } i = 1 L \{\mathbf{A}_i, \mathbf{B}_i\}_{i = 1}^{L} {Ai,Bi}i=1L表示可训练矩阵集,其中 L L L是伴随LoRA层的层数。然后,在有监督微调期间,只有参数 W C L S \mathbf{W}_{CLS} WCLS和 { A i , B i } i = 1 L \{\mathbf{A}_i, \mathbf{B}_i\}_{i = 1}^{L} {Ai,Bi}i=1L是可训练的,并且通过正态分布进行初始化。
C. 通过蒸馏进行增强
尽管大语言模型具备强大的语义理解能力,但它们需要高推理内存且存在延迟问题。这对于医疗保健系统来说是不可接受的,因此我们旨在将大语言模型的能力转移到一个相对较小的模型上。知识蒸馏[29]是一种很有前景的方法,但学生模型架构和具体的蒸馏方法仍有待确定。
1) 学生模型设计
考虑到效率问题,学生模型采用标识而非语义术语。如在第二部分所述,输入变量可以写成 { D 1 , … , D T ; P 1 , … , P T ; M 1 , … , M T − 1 ; P } \{\mathcal{D}_1, \ldots, \mathcal{D}_T; \mathcal{P}_1, \ldots, \mathcal{P}_T; \mathcal{M}_1, \ldots, \mathcal{M}_{T - 1}; P\} {D1,…,DT;P1,…,PT;M1,…,MT−1;P},其中 D \mathcal{D} D、 P \mathcal{P} P和 M \mathcal{M} M分别是诊断、诊疗程序和药物的集合。
为了从每种类型的集合中捕获协作信息,我们为它们设计了三个同质编码器,分别表示为 E D i a g \mathcal{E}_{Diag} EDiag、 E P r o c \mathcal{E}_{Proc} EProc和 E M e d \mathcal{E}_{Med} EMed。为简洁起见,我们仅以 E D i a g \mathcal{E}_{Diag} EDiag为例进行说明。我们首先推导一个嵌入表 E d ∈ R ∣ D ∣ × d e \mathbf{E}_d \in \mathbb{R}^{|\mathcal{D}| \times d_e} Ed∈R∣D∣×de,其中每一行对应诊断的唯一代码。 d e d_e de表示嵌入表的维度。然后,诊断代码集合 D i \mathcal{D}_i Di被 E d \mathbf{E}_d Ed转换为一组向量,记为 D ‾ i = [ d 1 , … , d ∣ D i ∣ ] \overline{\mathcal{D}}_i = [\mathbf{d}_1, \ldots, \mathbf{d}_{|\mathcal{D}_i|}] Di=[d1,…,d∣Di∣]。接下来,我们提议采用Transformer架构对每个集合中包含的相互关系进行编码。由多头注意力和前馈网络组成的一对Transformer层可以写成:
M = L a y e r N o r m ( D ~ i , M u l t i H e a d ( D ~ i , D ~ i , D ~ i ) ) M = LayerNorm(\widetilde{D}_i, MultiHead(\widetilde{D}_i, \widetilde{D}_i, \widetilde{D}_i)) M=LayerNorm(D i,MultiHead(D i,D i,D i)) (3)
其中, L a y e r N o r m ( ⋅ ) LayerNorm(\cdot) LayerNorm(⋅) 和 M u l t i H e a d ( ⋅ ) MultiHead(\cdot) MultiHead(⋅) 分别表示层归一化和多头注意力机制。Transformer层的另一个组件是带有残差连接的前馈神经网络,其公式如下:
D ^ ( 1 ) = L a y e r N o r m ( M , F N N ( M ) ) \widehat{D}^{(1)} = LayerNorm(M, FNN(M)) D (1)=LayerNorm(M,FNN(M)) (4)
其中 F N N ( ⋅ ) FNN(\cdot) FNN(⋅) 是一个可训练的线性层。第一个Transformer层的输出记为 D ^ ( 1 ) \widehat{D}^{(1)} D (1),它是一个向量序列。然后,我们对 E D i a g \mathcal{E}_{Diag} EDiag 的最后一个Transformer层的输出应用平均池化,得到诊断集的表示,即 D i ∈ R d t \mathbf{D}_i \in \mathbb{R}^{d_t} Di∈Rdt。
D i = A v g _ p o o l ( D ^ ( L d ) ) \mathbf{D}_i = Avg\_pool(\widehat{D}^{(L_d)}) Di=Avg_pool(D (Ld)) (5)
其中 L d L_d Ld 表示 E D i a g \mathcal{E}_{Diag} EDiag 中Transformer层的数量。通过诊断编码器,输入的诊断记录 { D 1 , … , D T } \{\mathcal{D}_1, \ldots, \mathcal{D}_T\} {D1,…,DT} 被转换为一组向量,即 [ D 1 , … , D T ] [\mathbf{D}_1, \ldots, \mathbf{D}_T] [D1,…,DT]。类似地,我们可以使用具有相同结构的 E P r o c \mathcal{E}_{Proc} EProc 和 E M e d \mathcal{E}_{Med} EMed 分别得到操作和药物集的表示。
然后,我们设计了一个就诊编码器 E V i s i t \mathcal{E}_{Visit} EVisit 来捕捉患者的历史健康状况。具体来说, E V i s i t \mathcal{E}_{Visit} EVisit 同样由几个Transformer层堆叠而成,与 E D i a g \mathcal{E}_{Diag} EDiag 相同。因此, E V i s i t \mathcal{E}_{Visit} EVisit 会将诊断记录序列编码为一个嵌入向量 D ~ \widetilde{\mathbf{D}} D ,其表达式如下:
D ~ = E V i s i t ( [ D 1 , … , D T ] ) \widetilde{\mathbf{D}} = \mathcal{E}_{Visit}([\mathbf{D}_1, \ldots, \mathbf{D}_T]) D =EVisit([D1,…,DT]) (6)
以同样的方式,我们可以得到历史操作和用药记录的表示,分别记为 P ~ \widetilde{\mathbf{P}} P 和 M ~ \widetilde{\mathbf{M}} M 。值得注意的是,这三种类型的记录共享就诊编码器 E V i s i t \mathcal{E}_{Visit} EVisit,因为这样的设计不仅可以减少参数数量,还有助于学习共享的医学知识 [2]。
对于学生模型来说,另一个挑战是单就诊患者的情况,因为当 T = 1 T = 1 T=1 时, E V i s i t \mathcal{E}_{Visit} EVisit 的用药记录输入为空。在这里,我们提议使用档案信息作为伪用药记录,因为档案可以反映患者的健康状况。具体来说,像年龄这样的档案特征被离散化,然后通过嵌入矩阵进行编码。所有档案特征的表示被连接起来,然后投影到一个 d t d_t dt 维向量,记为 P \mathbf{P} P。档案向量将被插入到用药记录序列中,所以 E V i s i t \mathcal{E}_{Visit} EVisit 的用药输入变为 [ M 1 , … , M T − 1 , P ] [\mathbf{M}_1, \ldots, \mathbf{M}_{T - 1}, \mathbf{P}] [M1,…,MT−1,P]。
最后,我们将 D ~ \widetilde{\mathbf{D}} D 、 P ~ \widetilde{\mathbf{P}} P 和 M ~ \widetilde{\mathbf{M}} M 连接起来,并采用两个线性层进行最终的用药推荐。
y ^ = σ ( W 2 ( W 1 ⋅ [ D ~ ∣ ∣ P ~ ∣ ∣ M ~ ] + b 1 ) + b 2 ) \hat{\mathbf{y}} = \sigma(\mathbf{W}_2(\mathbf{W}_1 \cdot [\widetilde{\mathbf{D}}||\widetilde{\mathbf{P}}||\widetilde{\mathbf{M}}] + \mathbf{b}_1) + \mathbf{b}_2) y^=σ(W2(W1⋅[D ∣∣P ∣∣M ]+b1)+b2) (7)
其中 W 1 ∈ R 3 d t × d t \mathbf{W}_1 \in \mathbb{R}^{3d_t \times d_t} W1∈R3dt×dt, W 2 ∈ R d t × ∣ M ∣ \mathbf{W}_2 \in \mathbb{R}^{d_t \times |\mathcal{M}|} W2∈Rdt×∣M∣, b 2 ∈ R 1 × d t \mathbf{b}_2 \in \mathbb{R}^{1 \times d_t} b2∈R1×dt 且 b 2 ∈ R 1 × ∣ M ∣ \mathbf{b}_2 \in \mathbb{R}^{1 \times |\mathcal{M}|} b2∈R1×∣M∣ 是可训练参数。那么,真实标签的损失函数写为:
L b c e = − ∑ i = 1 N y ( i ) l o g ( y ^ ( i ) ) + ( 1 − y ( i ) ) l o g ( 1 − y ^ ( i ) ) \mathcal{L}_{bce} = -\sum_{i = 1}^{N} \mathbf{y}^{(i)} log(\hat{\mathbf{y}}^{(i)}) + (1 - \mathbf{y}^{(i)}) log(1 - \hat{\mathbf{y}}^{(i)}) Lbce=−∑i=1Ny(i)log(y^(i))+(1−y(i))log(1−y^(i)) (8)
2)知识蒸馏
为了将基于大语言模型(LLM)的模型的强大能力转移到学生模型上,我们提出了一种特征级知识蒸馏方法。由于大语言模型在记忆方面很擅长 [30, 31],它们能够以相对较高的准确率预测训练集中的样本。这将导致大语言模型的预测与真实标签相似,不适合用于蒸馏。因此,我们提议通过大语言模型的隐藏状态来蒸馏学生模型。
隐藏状态 h \mathbf{h} h 是大语言模型最后一个Transformer层的表示。在传统的预训练大语言模型中,这个隐藏状态用于通过一个线性层生成词元,所以它包含全面的语义信息。在修改后的大语言模型中,由于 h \mathbf{h} h 可以通过一个分类层输出用药的概率,考虑到任务的相似性,它也适合指导学生模型。
然而,学生模型中的表示仍然与 h \mathbf{h} h 处于不同的空间,因为学生模型没有语义输入。因此,我们设计了一个可训练的投影器,将隐藏状态转换到大语言模型的表示空间。那么,知识蒸馏的损失可以写为:
L K D = 1 N ∑ i = 1 N ∣ h i − W p r o j ⋅ ( W 1 ⋅ [ D ~ i ∣ ∣ P ~ i ∣ ∣ M ~ i ] + b 1 ) ∣ 2 \mathcal{L}_{KD} = \frac{1}{N} \sum_{i = 1}^{N} | \mathbf{h}_i - \mathbf{W}_{proj} \cdot (\mathbf{W}_1 \cdot [\widetilde{\mathbf{D}}_i||\widetilde{\mathbf{P}}_i||\widetilde{\mathbf{M}}_i] + \mathbf{b}_1) |^2 LKD=N1∑i=1N∣hi−Wproj⋅(W1⋅[D i∣∣P i∣∣M i]+b1)∣2 (9)
其中 W p r o j ∈ R d t × d h \mathbf{W}_{proj} \in \mathbb{R}^{d_t \times d_h} Wproj∈Rdt×dh 是投影层的权重。请注意,在蒸馏过程中,学生模型的所有参数和 W p r o j \mathbf{W}_{proj} Wproj 会被更新,而大语言模型的参数是冻结的。
3 )轮廓对齐:
由于轮廓特征设计为伪用药记录,我们的模型可以推荐给单次就诊的患者。然而,轮廓和药物集合的表示实际上处于不同的空间,这给训练带来了困难。因此,为了对齐这两种不同类型的表示,我们设计了一种轮廓对齐方法。
受多模态研究中用于模态对齐的对比学习启发 [32, 33],我们提出一种对比损失函数,用于对齐档案特征和用药集合。为了实现更好的性能 [34],我们首先将档案特征表示 P \mathbf{P} P 和目标用药集合 M T \mathbf{M}_T MT 投影到一个新空间:
Z P = W p r o j P ⋅ P \mathbf{Z}_P = \mathbf{W}_{proj}^P \cdot \mathbf{P} ZP=WprojP⋅P (10)
Z M = W p r o j M ⋅ M T \mathbf{Z}_M = \mathbf{W}_{proj}^M \cdot \mathbf{M}_T ZM=WprojM⋅MT (11)
其中, W p r o j P ∈ R d t × d t \mathbf{W}_{proj}^P \in \mathbb{R}^{d_t \times d_t} WprojP∈Rdt×dt 和 W p r o j M ∈ R d t × d t \mathbf{W}_{proj}^M \in \mathbb{R}^{d_t \times d_t} WprojM∈Rdt×dt 是投影矩阵。设 [ Z P 1 , … , Z P B ] [\mathbf{Z}_P^1, \ldots, \mathbf{Z}_P^B] [ZP1,…,ZPB] 和 [ Z M 1 , … , Z M B ] [\mathbf{Z}_M^1, \ldots, \mathbf{Z}_M^B] [ZM1,…,ZMB] 分别表示一批档案特征和用药的表示,其中 B B B 是批量大小。当 i = j i = j i=j 时,我们将 Z P i \mathbf{Z}_P^i ZPi 和 Z M j \mathbf{Z}_M^j ZMj 视为正样本对。那么,档案特征的对比损失函数可定义为:
L P M = − 1 B ∑ i = 1 B log exp ( s i m ( Z P i , Z M i ) / τ ) ∑ j = 1 B I [ i ≠ j ] exp ( s i m ( Z P i , Z M j ) / τ ) \mathcal{L}_{PM} = -\frac{1}{B} \sum_{i = 1}^{B} \log \frac{\exp(sim(\mathbf{Z}_P^i, \mathbf{Z}_M^i)/\tau)}{\sum_{j = 1}^{B} \mathbb{I}_{[i \neq j]} \exp(sim(\mathbf{Z}_P^i, \mathbf{Z}_M^j)/\tau)} LPM=−B1∑i=1Blog∑j=1BI[i=j]exp(sim(ZPi,ZMj)/τ)exp(sim(ZPi,ZMi)/τ) (11)
其中, I [ i ≠ j ] \mathbb{I}_{[i \neq j]} I[i=j] 表示指示函数, τ \tau τ 表示损失函数中的温度参数。同样地,我们也可以推导出用药的对比损失函数 L M P \mathcal{L}_{MP} LMP。因此,对齐损失函数为这两个损失函数之和:
L a l i g n = ∑ N L P M + L M P \mathcal{L}_{align} = \sum_{N} \mathcal{L}_{PM} + \mathcal{L}_{MP} Lalign=∑NLPM+LMP (12)
D. 训练与推理
所提出的 LEADER 模型需要两阶段优化。在第一阶段,我们需要通过公式 (2) 优化修改后的大语言模型(LLM)。微调后的大语言模型将作为教师模型,称为 LEADER(T)。在第二阶段,被称为 LEADER(S) 的学生模型,通过结合真实标签损失、知识蒸馏损失和档案对齐损失,从零开始进行训练,即:
L = L b c e + α ⋅ L K D + β ⋅ L a l i g n \mathcal{L} = \mathcal{L}_{bce} + \alpha \cdot \mathcal{L}_{KD} + \beta \cdot \mathcal{L}_{align} L=Lbce+α⋅LKD+β⋅Lalign (13)
其中, α \alpha α 和 β \beta β 是超参数,用于调整蒸馏和对齐的规模。优化后,LEADER(T) 和 LEADER(S) 都可以完成用药推荐任务,但输入格式不同。为了更清楚地展示训练和推理过程,我们总结了算法 1。
首先,我们指定一些必要的超参数,并为大语言模型构建自然语言输入(第 1 - 3 行)。然后,在第一阶段,使用推导得到的语言数据集对修改后的大语言模型进行有监督微调(第 4 - 10 行)。微调后的修改大语言模型可直接用于蒸馏或用药推荐。在第二训练阶段,以自然语言格式呈现的电子健康记录(EHR)和身份信息分别被教师模型和学生模型接收(第 11 - 13 行)。然后,我们通过结合二元交叉熵(BCE)损失、蒸馏损失和对齐损失来更新学生模型(第 14 - 16 行)。在推理方面,我们可以采用 LEADER(S) 或 LEADER(T) 进行最终推荐(第 17 - 19 行)。
IV. EXPERIMENT
在本节中,我们将通过在两个真实世界数据集上进行的全面实验来分析所提出的 LEADER 模型。我们探讨以下研究问题(RQ)以阐明研究结果:
- RQ1:与当前最先进的用药推荐模型和基于大语言模型(LLM)的推荐模型相比,所提出的 LEADER 模型表现如何?
- RQ2:LEADER 的所有设计是否都有效?
- RQ3:设计的知识蒸馏和档案对齐对 LEADER 的性能有何影响?
- RQ4:所提出的学生模型能否高效地进行用药推荐?
A. 实验设置
- 数据集:实验中使用的数据集来自重症监护医学信息数据库(MIMIC)2 。目前有两个版本,即 MIMIC - III 和 MIMIC - IV。MIMIC - III 收集了 2001 年至 2012 年的数据,而 MIMIC - IV 包含 2008 年至 2019 年的记录。我们遵循先前研究 [3, 4] 中的预处理方法。由于篇幅限制,我们将数据集的更详细介绍放在附录 A 中。
- 基线模型:在实验中,我们将 LEADER 与几种最先进的用药推荐模型(RETAIN [1]、G - Bert [2]、GAMENet [3]、SafeDrug [4]、MICRON [5]、COGNet [6]、REFINE [7])以及基于大语言模型的推荐模型(TALLRec [26]、BI - GRec [24]、E4SRec [35])进行比较。基线模型的详细介绍和实现可在附录 B 中查看。我们还比较了第三节 B 部分中提出的修改后的大语言模型,记为 LEADER(T)。此外,在后续实验中,蒸馏后的学生模型标记为 LEADER(S)。
- 实现细节:本文中的所有实验均在配备 Tesla V100 32G GPU 的英特尔至强金牌 6133 平台上进行。代码基于 Python 3.9.5 和 PyTorch 1.12.0。对于基于大语言模型的用药推荐,即 LEADER(T) 以及所有基于大语言模型的推荐基线模型,本文采用 LLaMA - 7B 3 [22] 作为基础模型。此外,对于所有基于大语言模型的模型,我们采用 LoRA [28] 作为微调方法。由于篇幅限制,我们将更多实现细节放在附录 C 中。为了便于模型的复现,我们将代码在线发布 4 。
- 评估指标:遵循先前的研究 [3, 4, 6, 7],我们使用三个常用指标来评估所提出的模型,即精确率 - 召回率曲线下面积(PRAUC ↑)、杰卡德相似系数(Jaccard ↑)和平均 F1 值(F1 ↑)。为了保证实验结果的稳健性,我们在测试过程中采用自助抽样法。具体来说,每轮随机抽取 80% 的样本。下面所示的指标是在 10 轮测试中的平均值。
B. 整体性能(RQ1)
为了回答研究问题(RQ1),我们在表 III 和表 IV 中展示了所提出的方法与其他竞争模型的性能比较,然后对结果进行分析。
表 III:在 MIMIC - III 数据集上,竞争基线模型和 LEADER 模型的总体结果。粗体表示最高分,下划线表示模型的最佳结果。“*”表示相对于最佳基线模型有统计学上的显著改进(即双侧 t 检验,p < 0.05)。“-”表示由于无法处理单就诊患者,模型无法获得相应结果,或者 TALLRec 由于输出的是药物名称而非概率,没有精确率 - 召回率曲线下面积(PRAUC)指标。
表 IV:在 MIMIC - IV 数据集上,竞争基线模型和 LEADER 模型的总体结果。粗体表示最高分,下划线表示模型的最佳结果。“*”表示相对于最佳基线模型有统计学上的显著改进(即双侧 t 检验,p < 0.05)。“-”表示由于无法处理单就诊患者,模型无法获得相应结果,或者 TALLRec 由于输出的是药物名称而非概率,没有精确率 - 召回率曲线下面积(PRAUC)指标。
总体而言,在两个数据集上,LEADER(T) 与所有其他模型相比表现出强大的领先优势,这表明了大语言模型(LLM)的语义理解能力。同时,经过蒸馏的学生模型 LEADER(S) 也优于用药推荐模型和基于大语言模型的推荐模型。这种现象表明了所设计的蒸馏方法的成功。
然后,我们根据不同的患者群体探讨性能比较。如前所述,一些近期的基线模型,例如 G - Bert、MICRON、COGNet 和 REFINE,将历史用药记录视为必要输入之一,所以它们没有单就诊患者的结果。我们首先观察多就诊患者群体。G - Bert 表现最差,因为它没有考虑患者的治疗过程。然后,我们可以发现 MICRON、COGNet 和 REFINE 这三个基线模型在多就诊患者群体中能很好地利用历史处方信息,表现优于其他竞争模型。这样的比较表明,利用先前的用药记录实际上对当前就诊的推荐有帮助。由于大语言模型的强大能力,所提出的 LEADER(T) 能够始终超越所有模型。对于设计的 LEADER(S),它在精确率 - 召回率曲线下面积(PRAUC)指标上优于其他模型,但在杰卡德相似系数(Jaccard)和 F1 值上比 COGNet 差。我们认为原因在于 COGNet 采用集束搜索生成最终推荐,但它面临效率问题。
在单就诊患者群体和总体群体的性能方面,GAMENet 和 SafeDrug 优于 RETAIN,因为它们通过电子健康记录(EHR)图和分子图更细致地对药物之间的关系进行建模。然而,它们仍然持续不如所提出的 LEADER 的两个变体。一方面,LEADER 可以利用历史信息,并在多就诊患者群体中大幅超越基线模型,这有助于提高总体得分。另一方面,由于大语言模型的语义理解能力,LEADER(T) 和 LEADER(S) 在两个数据集的单就诊患者群体中都超越了竞争模型。值得注意的是,在单就诊患者群体中,经过蒸馏的 LEADER(S) 在 PRAUC 指标上甚至优于 LEADER(T)。这种现象表明了结合来自学生模型的协作信号和来自大语言模型的语义信息的好处。
对于基于大语言模型的模型,TALLRec 甚至不如一些用药推荐模型。其性能较差是由药物名称的直接输出导致的,这凸显了语料库外问题。BIGRec 和 E4SRec 在总体和单就诊患者群体上都能获得更高的推荐准确率,这表明了大语言模型强大的语义理解能力。然而,它们仍然落后于所提出的 LEADER。对于 BIGRec,原因在于其基础方法并非最优。就 E4SRec 而言,它仅将协作信号整合到大语言模型中,导致大语言模型的利用不足。
从分析中我们得出结论,所提出的基于大语言模型的用药推荐模型比传统模型表现出更强的语义理解能力和单就诊患者处理能力。此外,所设计的蒸馏方法实际上可以提升得到的学生模型。
表5:在两个数据集上的消融研究。由于篇幅有限,仅在表格中显示PRAUC得分。
C. 消融研究(RQ2)
为了验证为 LEADER 提出的每个组件的有效性,我们进行了消融实验。结果见表 V。首先,我们旨在验证所设计的特征级知识蒸馏对学生模型的影响。“w/o KD”表示在训练 LEADER(S) 时直接去除知识蒸馏损失,而“w/o feature - KD”表示使用学生模型和教师模型输出概率的 KL 散度作为知识蒸馏损失 [36]。
从结果中我们可以发现,这两种变体的性能都大幅低于所提出的 LEADER(S)。性能的急剧下降表明,特征级知识蒸馏实际上可以增强协作式学生模型。此外,与传统的输出级知识蒸馏相比,所设计的特征级知识蒸馏更适合从大语言模型进行知识转移。
然后,我们探索我们为学生模型所做的设计是否合理。在表 V 中,“w/o align”表示我们去掉了第三节 C3 中提出的档案对齐模块。实验结果表明,对齐对单就诊患者群体更有益,这有助于整体性能的提升。原因可能是对齐可以细化档案的表示,对于单就诊患者,档案被视为单一的就诊记录。“w/o shared E V i s i t \mathcal{E}_{Visit} EVisit”表示设计的学生模型对诊断、操作和用药采用单独的就诊编码器。这个变体比 LEADER(S) 差,这表明共享编码器有助于学习更通用的医学知识。作为对 RQ2 的回答,我们可以得出结论,所设计的特征级知识蒸馏以及学生模型中的其他组件对 LEADER(S) 都是有益的。此外,为了验证诸如 QWen 等各种大语言模型的效果,我们将相关实验结果和分析放在附录 D 中。
D. 超参数分析(RQ3)
为了回答 RQ3,我们在训练过程中调整知识蒸馏和档案对齐的强度。图 2 和图 3 分别显示了根据 α \alpha α 和 β \beta β 的性能变化。我们观察到,当 α \alpha α 在一定范围内增加时,LEADER(S) 的性能会提升。这种现象表明,来自基于大语言模型的教师模型的知识转移对协作模型有益。然而,过大的知识蒸馏损失会使模型训练朝着真实标签产生混淆,所以精确率 - 召回率曲线下面积(PRAUC)分数会随着 α \alpha α 的持续增加而下降。因此,MIMIC - III 数据集下 α \alpha α 的最佳值是 0.4。就档案对齐而言,图中显示总体性能趋势是随着 β \beta β 从 0.1 变化到 0,先上升后下降。PRAUC 首先增加的原因是,过大的对比损失强度会对模型收敛产生不利影响。相反,由于对齐有助于细化档案的表示,当 β \beta β 随后降至 0 时,PRAUC 会下降。因此,MIMIC - III 数据集下 β \beta β 的最佳值是 5 e − 3 5e^{-3} 5e−3。
E. 效率分析(RQ4)
如前所述,推理效率是医疗应用中的一个重要问题。因此,我们比较了基于大语言模型(LLM)的模型和协作式学生模型的效率,以回答 RQ4。我们使用延迟时间和 GPU 内存来衡量效率。具体而言,延迟时间是通过将测试集的总推理时间除以测试样本数量来计算的平均值。因此,延迟时间代表了为一位患者完成推荐的平均等待时间。内存是推理所需的最小 GPU 内存。如图 4 所示,我们可以发现 LEADER(T) 的延迟时间比一般的 TALLRec 短。这是由于在生成词元时的集束搜索导致的,而经过修改的大语言模型可以一次输出概率。总之,所提出的对大语言模型的修改能够同时提高有效性和效率。然而,基于大语言模型的用药推荐模型仍然存在推理成本高的问题。从结果来看,所提出的 LEADER(S) 与 LEADER(T) 相比,可以实现 25 倍至 30 倍的推理加速,并且仅需要约 1/15 的 GPU 内存。总之,所设计的基于大语言模型蒸馏的用药推荐模型能够在性能和效率之间取得更好的平衡。
V. RELATED WORKS
A. 用于推荐的大语言模型
最近,大语言模型在推荐系统领域的应用成为了热点 [15, 37, 38]。在用于推荐的大语言模型(LLM4Rec)领域主要有两类工作。一类是可微调的 LLM4Rec,这类工作通常通过微调来使大语言模型适应推荐任务。P5 [25] 首先将推荐任务表述为语言生成任务,然后将各种推荐任务整合到一个统一的语言模型中。它对 T5 [39] 模型进行微调,使其具备生成推荐的能力。随后,在诸如 LLaMA 和 ChatGLM 等更大模型的应用中,带来了更多的性能提升。TALLRec [26] 设计了与用户历史记录相结合的合适指令,并对 LLaMA - 7b 进行微调以完成序列推荐。值得注意的是,参数高效微调 [28, 40] 常被采用,以应对效率问题。InstructRec [41] 构造偏好、意图和任务形式来组成提示输入。为了进一步理解用户并缩短提示长度,PALR [42] 将用户概况摘要而非原始特征插入提示中。更具体地说,一些研究侧重于在提示中突出对推荐系统至关重要的物品标识。Chu 等人 [43] 设计了一种新颖的掩码机制和位置嵌入,以便在微调 GLM 模型时从语言输入中区分物品。此外,E4SRec [35] 提议使用与提示中物品嵌入相关联的 ID 嵌入。RecInterpreter [16] 和 LLaRA [44] 与 E4SRec 有相似的想法,不过它们应用了预训练的序列推荐编码器来识别物品。另一类是非可微调的 LLM4Rec,这类工作主要致力于设计超大规模大语言模型(如 ChatGPT 和 GPT - 4)的工作流程。例如,Chat - Rec [45] 将推荐任务重新表述为对话过程,从而能够利用 ChatGPT 给出合适的推荐。Hou 等人 [46] 提出结合几种类型的提示来提高排名性能。
尽管现有工作在使大语言模型适应推荐方面迈出了早期步伐,但它们仍然面临一些挑战,例如推理成本高和语料库外问题。在本文中,我们提出了一种新颖的方法来解决这两个问题。
B. 用药推荐
近年来,用药推荐因其实际价值而备受关注。在研究的早期阶段,一些工作旨在仔细对当前就诊中诊断与处方之间的关系进行建模。例如,Leap [47] 捕捉了几种诊断之间的相互影响,并将推荐建模为序列决策过程。后来,4SDrug [48] 提出测量症状与药物集之间的相似性以进行推荐。此外,Zhang 等人 [49] 通过知识图谱和属性构建基于图的架构来嵌入症状与药物之间的关系。与仅使用当前就诊信息的模型相比,许多其他工作旨在对历史治疗记录进行建模以获得更好的性能。RETAIN [1] 率先专门为医疗保健开发了一个时间序列预测模型。GAMENet [3] 和 SafeDrug [4] 都利用历史诊断和程序数据进行推荐,并考虑药物 - 药物相互作用问题。G - Bert [2] 引入预训练技术以获得更好的诊断和用药编码器用于最终推荐。此外,一些工作进一步纳入处方历史,这在当时是一个重要参考。例如,MICRON [5] 和 COGNet [6] 都考虑以一定概率将历史处方复制到当前推荐药物集中。REFINE [7] 直接将记录输入到 Transformer 编码器中进行建模。然而,现有模型仅利用身份信息,而忽略了电子健康记录中包含的医学语义。据我们所知,我们是第一个将大语言模型与用药推荐相结合以获取语义知识的。
VI. CONCLUSION
在本文中,我们提出了一种通过蒸馏增强的大语言模型用药推荐方法(LEADER)。为了使大语言模型适应用药推荐任务,我们首先设计合适的提示模板,以为大语言模型生成语言输入。然后,我们替换大语言模型的头部层以缓解语料库外问题,并采用二元交叉熵(BCE)损失对修改后的大语言模型进行微调。然而,基于大语言模型的模型面临推理成本高的挑战。为了提高效率,我们设计了一种特征级知识蒸馏方法,将大语言模型的强大能力转移到一个相对较小的学生模型上。通过在两个公开数据集上进行广泛实验,我们验证了与现有最先进的模型相比,所提出的 LEADER 能够实现有效且高效的用药推荐。在未来的工作中,我们将考虑基于大语言模型的用药推荐中的药物 - 药物相互作用问题,这与处方的安全性相关。