Title
题目
MMGPL: Multimodal Medical Data Analysis with Graph Prompt Learning
MMGPL: 基于图提示学习的多模态医学数据分析
01
文献速递介绍
神经系统疾病,包括自闭症谱系障碍(ASD)(Lord 等,2018)和阿尔茨海默病(AD)(Scheltens 等,2021),严重损害了患者的社交、语言和认知能力,已经成为全球范围内严重的公共卫生问题(Feigin 等,2020)。不幸的是,对于大多数神经系统疾病(例如 ASD 和 AD),目前尚无明确的治愈方法,因此,迫切需要对神经系统疾病进行诊断,以促进早期干预并延缓其恶化(Wingo 等,2021;Zhu 等,2022)。
在过去的十年中,研究人员(Wen 等,2020;Li 等,2021;Dvornek 等,2019)应用了多种机器学习方法,例如卷积神经网络(CNN)(LeCun 和 Bengio,1995)、图神经网络(GNN)(Kipf 和 Welling,2017)和循环神经网络(RNN)(Schuster 和 Paliwal,1997),来诊断神经系统疾病。尽管这些方法取得了显著进展,但由于这些方法直接在小规模和复杂的医学数据集上进行训练,因此难以保证这些深度学习模型的稳健性和有效性(Dinsdale 等,2022)。
Aastract
摘要
Prompt learning has demonstrated impressive efficacy in the fine-tuning of multimodal large models to awide range of downstream tasks. Nonetheless, applying existing prompt learning methods for the diagnosisof neurological disorder still suffers from two issues: (i) existing methods typically treat all patches equally,despite the fact that only a small number of patches in neuroimaging are relevant to the disease, and (ii) theyignore the structural information inherent in the brain connection network which is crucial for understandingand diagnosing neurological disorders. To tackle these issues, we introduce a novel prompt learning modelby learning graph prompts during the fine-tuning process of multimodal models for diagnosing neurologicaldisorders. Specifically, we first leverage GPT-4 to obtain relevant disease concepts and compute semanticsimilarity between these concepts and all patches. Secondly, we reduce the weight of irrelevant patchesaccording to the semantic similarity between each patch and disease-related concepts. Moreover, we constructa graph among tokens based on these concepts and employ a graph convolutional network layer to extract thestructural information of the graph, which is used to prompt the pre-trained multimodal models for diagnosingneurological disorders. Extensive experiments demonstrate that our method achieves superior performance forneurological disorder diagnosis compared with state-of-the-art methods and validated by clinicians.
提示学习在多模态大型模型的微调过程中对各种下游任务表现出了显著的效果。然而,将现有的提示学习方法应用于神经系统疾病的诊断仍然面临两个问题:(i)现有方法通常将所有补丁视为同等重要,尽管在神经影像中只有少数补丁与疾病相关;(ii)它们忽略了大脑连接网络中固有的结构信息,而这对于理解和诊断神经系统疾病至关重要。为了解决这些问题,我们在多模态模型微调过程中引入了一种通过学习图提示来诊断神经系统疾病的新颖提示学习模型。具体而言,我们首先利用GPT-4获取相关的疾病概念,并计算这些概念与所有补丁之间的语义相似度。其次,根据每个补丁与疾病相关概念之间的语义相似度,减少与疾病无关的补丁的权重。此外,我们基于这些概念在标记之间构建了一个图,并采用图卷积网络层来提取该图的结构信息,这些信息用于提示预训练的多模态模型进行神经系统疾病的诊断。大量实验表明,与最先进的方法相比,我们的方法在神经系统疾病诊断方面表现出优越的性能,并得到了临床医生的验证。
Method
方法
Utilizing transformers (Vaswani et al., 2017) as the architectureof encoders to process multimodal data has become a popular choicein modern multimodal large models, as it can effectively integrateinformation from multiple modalities. For example, pre-trained visionlanguage models like CLIP (Radford et al., 2021) employ separatetransformer-based backbones (e.g., ViT) to encode images and textseparately. To obtain representations of the samples, the transformerarchitecture involves two key components: (i) Tokenization: convertingthe raw data into tokens. (ii) Encoding: performing attention-basedfeature extraction layers on all tokens.
利用Transformer(Vaswani 等,2017)作为编码器的架构来处理多模态数据已成为现代多模态大型模型中的一种流行选择,因为它能够有效整合来自多种模态的信息。例如,预训练的视觉语言模型如CLIP(Radford 等,2021)采用基于Transformer的独立骨干网络(如ViT)分别对图像和文本进行编码。为了获得样本的表示,Transformer架构涉及两个关键组件:(i) 标记化:将原始数据转换为标记。(ii) 编码:对所有标记执行基于注意力的特征提取层。
Conclusion
结论
In this paper, we proposed a graph prompt learning fine-turningframework for neurological disorder diagnosis, by jointly considering the impact of irrelevant patches as well as the structural information among tokens in multimodal medical data. Specifically, weconduct concept learning, aiming to reduce the weights of irrelevant tokens according to the semantic similarity between each tokenand disease-related concepts. Moreover, we conducted graph promptlearning with concept embeddings, aiming to bridge the gap betweenmultimodal models and neurological disease diagnosis. Experimentalresults demonstrated the effectiveness of our proposed method, compared to state-of-the-art methods on neurological disease diagnosistasks.
在本文中,我们提出了一种用于神经系统疾病诊断的图提示学习微调框架,该框架结合了多模态医学数据中与疾病无关的补丁的影响以及标记之间的结构信息。具体而言,我们进行了概念学习,旨在根据每个标记与疾病相关概念之间的语义相似度来减少与疾病无关的标记的权重。此外,我们利用概念嵌入进行了图提示学习,旨在弥合多模态模型与神经系统疾病诊断之间的差距。实验结果表明,与最先进的方法相比,我们提出的方法在神经系统疾病诊断任务中具有显著的效果。
Figure
图
Fig. 1. The flowchart of the proposed MMGPL consists of three modules i.e., multimodal data tokenizer (light blue block), concept learning (light green block), and graph promptlearning (light yellow block). First, MMGPL divides the multimodal medical data into multiple patches and project them into a shared embedding space (Section 3.2). Second,MMGPL prompts the GPT-4 to generate disease-related concepts and further learn the weights of tokens based on the semantic similarity between tokens and concepts (Section 3.3).Third, MMGPL learns a graph among tokens and extracts structural information to prompt the unified encoder (Section 3.4). Finally, MMGPL obtains the output from the unifiedencoder and uses it to predict the label of the subject.
图1. 所提出的MMGPL流程图由三个模块组成,即多模态数据标记器(浅蓝色块)、概念学习(浅绿色块)和图提示学习(浅黄色块)。首先,MMGPL将多模态医学数据分割成多个补丁并将其投影到共享的嵌入空间中(第3.2节)。其次,MMGPL提示GPT-4生成与疾病相关的概念,并根据标记与概念之间的语义相似度进一步学习标记的权重(第3.3节)。第三,MMGPL在标记之间学习一个图并提取结构信息,以提示统一的编码器(第3.4节)。最后,MMGPL从统一的编码器中获取输出并用其预测受试者的标签。
Fig. 2. Performance of MMGPL with different combinations of components on all datasets, i.e., ‘‘B’’ denotes baseline method, ‘‘B+G’’ denotes baseline method with graph promptlearning, ‘‘B+W’’ denotes baseline method with token weights, and ‘‘B+W+G’’ denotes baseline method with graph prompt learning and token weights.
图2. MMGPL在所有数据集上使用不同组件组合的性能表现,其中“B”表示基线方法,“B+G”表示结合图提示学习的基线方法,“B+W”表示结合标记权重的基线方法,“B+W+G”表示结合图提示学习和标记权重的基线方法。
Fig. 3. Performance of MMGPL with different modalities
图3. MMGPL在不同模态下的性能表现。
Fig. 4. Heat maps generated by MMGPL on different subjects in ADNI dataset.
图4. MMGPL在ADNI数据集不同受试者上生成的热图。
Fig. 5. The visualization of concept-similarity graph on the ADNI dataset. The horizontal and vertical axes represent concepts and tokens. Different colors represent conceptsbelonging to different categories. The red texts represent concepts related to NC, the green texts represent concepts related to LMCI, and the blue texts represent concepts relatedto AD.
图5. ADNI数据集上概念相似性图的可视化。横轴和纵轴代表概念和标记。不同的颜色代表属于不同类别的概念。红色文本代表与NC相关的概念,绿色文本代表与LMCI相关的概念,蓝色文本代表与AD相关的概念。
Fig. 6. The visualization of the quantified impact of different concepts on the ADNIdataset. The concepts are shown on the left side, while classes are shown on the rightside. The width of the lines corresponds to the magnitude of the weights, and thevalues indicate the specific weight values.
图6. 不同概念对ADNI数据集量化影响的可视化。概念显示在左侧,类别显示在右侧。线条的宽度对应权重的大小,数值表示具体的权重值。
Table
表
Table 1Diagnose performance (mean and standard deviation) of all methods on all datasets. Note that, ‘‘ADNI-3CLS’’ and ‘‘ADNI-4CLS’’ indicate theclassification on three classes ‘‘NC/LMCI/AD’’ and the classification on four classes ‘‘NC/EMCI/LMCI/AD’’, respectively.
表1所有方法在所有数据集上的诊断性能(均值和标准差)。需要注意的是,“ADNI-3CLS”和“ADNI-4CLS”分别表示对三类“NC/LMCI/AD”和四类“NC/EMCI/LMCI/AD”的分类。
Table 2Comparison between MMGPL and related works on scalability. Note that, ✓(vanilla)indicates can only supports two modalities and is challenging to expand to supportsmore modalities.
表2MMGPL与相关工作在可扩展性方面的比较。需要注意的是,✓(vanilla)表示仅支持两种模态,且难以扩展以支持更多模态。