扩散方法与传统方法相结合,挺有意思~
本文提出一种称为生成式元胞自动机 (Generative Cellular Automata,GeCA) 的新模型系列,其灵感来自于生物体从单细胞进化而来的过程,显著提高了11 种不同眼科疾病分类任务的表现。
论文:An Organism Starts with a Single Pix-Cell: A Neural Cellular Diffusion for High-Resolution Image Synthesis
代码:https://github.com/xmed-lab/GeCA (即将开源)
0、摘要
生成建模寻求近似真实数据的统计特性,使合成的新数据与原始分布非常相似。生成对抗网络(GANs)和去噪扩散概率模型(DDPMs)代表了生成建模方面的重大进步,它们分别从博弈论和热力学中获得了灵感。然而,通过生物进化的视角来探索生成建模,在很大程度上仍未被开发。
本文介绍了一个新的模型家族,称为生成式元胞自动机(GeCA),其灵感来自于从单个细胞进化而来的生物体。针对两种数据模态的眼部疾病分类,GeCA 是一种有效的增强工具。
OCT成像数据稀缺,类别分布具有固有倾斜,GeCA显著提高了11种不同眼部疾病分类的性能,与传统基线相比,平均F1评分增加了12%。
在相似的参数限制下,GeCAs 的性能优于两种包含 UNet 或最先进的基于 Transformer 的去噪扩散模型方法。(好牛的样子,(●’◡’●))
1、引言
1.1、深度学习在眼底照相和OCT应用中的限制
(1)数据稀缺:缺乏公开可访问的数据集,特别是对于OCT;
(2)类不平衡:疾病分布倾斜性;
1.2、目前生成模型合成数据的局限
(1)大部分生成模型严重依赖于 UNet 和 Transformers,需要大量参数,在大规模数据集上进行训练;在医学成像领域,数据集、标注和计算资源往往稀缺;
(2)神经元胞自动机(Neural Cellular Automata,NCA)受生物过程启发,在更少的参数下改善各种任务性能;在生成任务中,NCA 具有低分辨率输出,且缺乏全面的性能比较,在下游任务的评估中,NCA的图像生成效率仍然是一个未解决的挑战;
1.3、本文贡献
(1)提出生成式元胞自动机(GeCA),一种集成了神经元胞自动机(NCA)和扩散目标的新模型,专门针对 NCA 的独特结构进行了定制;
(2)利用基因遗传指导(GHG)来改进 GeCA 的图像采样。GHG 使 GeCA 在图像生成和视网膜疾病分类方面超过了 SOTA DiT,其参数仅占 DiT 的一半;
(3)证明合成图像可增强训练数据集的能力,提高了 OCT 多标签视网膜疾病分类任务性能;
PS:元胞自动机学习传送:【数学建模】元胞自动机(CA)详解 + Matlab代码实现 (还怪好玩的耶~)
2、生成式元胞自动机
2.1、一个生物体从一个小细胞开始
NCA 将输入图像建模为
H
×
W
{H×W}
H×W 的网格,包含
H
×
W
{H×W}
H×W 个实体,命名为:pix-cells(像素元胞)。每个 pix-cells 代表一个时间依赖的状态空间表示,促进类似于细胞向生物体的动态进化,即像素演化为图像。
将每个 pix-cells 在步骤
m
{m}
m 处的状态参数化为标量向量,定义为:
C
i
n
{C^{in}}
Cin:表示图像输入通道,灰度为1, RGB为3;
C
γ
{C^{\gamma}}
Cγ:表示位置编码,由一个连续平滑的正弦函数定义,促进网格内的空间感知;
C
o
u
t
{C^{out}}
Cout:表示 pix-cells 的输出状态;
C
h
{C^{h}}
Ch:表示 pix-cells 的隐藏状态变量;
为了从像素进化为图像,本文遵循传统的 NCA,采用一个随机规则:意味着一个 pix-cells 在第 m {m} m 步以概率 p {p} p 随机更新,反映了生物体中细胞更新的非同步性质。
针对 pix-cells 更新,
C
i
n
{C^{in}}
Cin 和
C
γ
{C^{\gamma}}
Cγ 是恒定的,只用关注
C
o
u
t
{C^{out}}
Cout 和
C
h
{C^{h}}
Ch,该过程如 图2 中 GeCA step 所示,被定义为:
与 SOTA 扩散 Transformer(DiT)中的
M
{M}
M 层分层建模不同,本文将
Θ
{Θ}
Θ 参数化为一个具有局部自注意机制的 single DiT block,特别是在 pix-cells 的8个最近相邻中计算。
局部注意策略,允许每个 pix-cells 根据等式(2),使用
Θ
{Θ}
Θ ,独立生长
M
{M}
M 次。
GeCA 方法将图像生成的重点转向了局部空间交互,避免了 UNet 和标准 transformers 等传统模型中的全局上下文依赖。但 GeCA 通过 C h {C^{h}} Ch 积累长期的状态空间表示来获得全局一致性,与 NCA 、Mamba、通用 Transformers 和 MLP-mixers 中记录的基本概念相一致。
GeCA 总体框架:
2.2、元胞扩散:将细胞进化为生物体
为训练模型参数 Θ {Θ} Θ ,引入的成熟的扩散过程,并在正向和反向步骤中进行特定的修改。
在正向扩散过程中,将
C
o
u
t
{C^{out}}
Cout 和
C
h
{C^{h}}
Ch 初始化为0,除了位于
H
×
W
{H×W}
H×W 网格中心的单个 pix-cells ,它用随机标量初始化,作为元胞过程的起点。
C
γ
{C^{\gamma}}
Cγ 用一个正弦位置编码来初始化,
C
i
n
{C^{in}}
Cin 可以在每个 pix-cells 的正向扩散过程中描述为(只有
C
i
n
{C^{in}}
Cin 参与了扩散过程):
使用等式(2)实现
M
{M}
M 次元胞更新来发展
C
o
u
t
{C^{out}}
Cout 和
C
h
{C^{h}}
Ch ,当
T
→
∞
{T → ∞}
T→∞ 时,
C
T
i
n
{C_{T}^{in}}
CTin 为各向同性高斯分布,优化过程为从 pix-cells 预测噪声:
2.3、通过基因遗传改进反向采样
GeCA 用 pix-cells 表示输入图像,它是一个时间依赖的状态空间表示,长期信息由内部隐藏状态
C
h
{C^{h}}
Ch 保存,可类比为遗传物质。
因此,利用
t
+
1
{t+1}
t+1 时刻的
C
h
{C^{h}}
Ch 来指导反向生成过程, 反映了遗传特性,本文修改了反向过程中的每个步骤,以启动 pix-cells 的隐藏状态
C
h
{C^{h}}
Ch :
同时,对于每个时间步,网格中心像素 pix-cells 的
C
o
u
t
{C^{out}}
Cout 被定义为:
该过程被称为基因遗传指导(Gene Heredity Guidance,GHG),为
C
i
n
{C^{in}}
Cin 去噪和
C
h
{C^{h}}
Ch 细化设置了一个合理的起点,去噪采样一个 pix-cells 的
C
0
i
n
{C_{0}^{in}}
C0in ,遵循传统的扩散步骤:
2.4、视网膜疾病分类
由于数据稀缺和偏态的类别分布,从OCT图像中分类视网膜疾病面临重大挑战。利用生成建模来有效地增强数据集,与传统的增强技术相比,这一策略被证明可以显著增强下游分类任务。
本文合成了一个扩展的训练集,反映了原始训练集的分布,给定原始数据集类别分布: p o r i g ( y ) {p_{orig}(y)} porig(y), y {y} y 表示数据集标签, N o r i g {N_{orig}} Norig 表示原始数据集大小,目标是将数据集扩展五倍: N a u g = 5 × N o r i g {N_{aug} = 5×N_{orig}} Naug=5×Norig ,同时保留 p o r i g ( y ) {p_{orig}(y)} porig(y),这是通过确保增强数据集中,每个标签 y {y} y 的计数 C o c u n t a u g ( y ) {Cocunt_{aug}(y)} Cocuntaug(y),是其原始计数的5倍来实现的:
3、实验与结果
3.1、数据集
(1)多标签OCT数据集:OCT-ML,共203例患者369只眼的1435个样本,有多种疾病(11类),包括正常、干性老年性黄斑变性(dAMD)、湿性年龄相关性黄斑变性(wAMD)、糖尿病视网膜病变(DR)、中枢性浆液性脉络膜视网膜病变(CSC)、色素上皮脱离(PED)、黄斑视网膜上膜(MEM)、液体(FLD)、渗出(EXU)、脉络膜新生血管(CNV)和视网膜血管阻塞(RVO)。
数据分布:
(2)DeepDRiD数据集:眼底成像数据,有5类。1080训练,120验证,400测试。
3.2、基线与实施细节
(1)与DiT和LDM比较;
(2)使用相同的 Classifier Free Guidance (CFG) 策略实现条件生成;
(3)所有生成都在类似于LDM的隐空间中完成的,输出大小为256×256;
(4)batch size=128,训练14000个epoch;
(5)下游分类任务采用Resnet34,Adam优化器;
3.3、生成建模评估
特征似然发散(Feature Likelihood Divergence,FLD)量化的泛化差距(generalization gap,GG),可评价:新颖性 novelty(不同于训练样本)、保真度 fidelity 和合成样本的多样性 diversity;
定量评估:
可视化:
3.4、视网膜疾病分类
所有生成模型都显著地提高了分类任务的各种指标性能。GeCA 扩展训练数据集获得了mean average precision(mAP为73.28%)。
3.5、GHG 消融
Gene Heredity Guidance (GHG)的影响:
感觉是基于每一步扩散过程结果再进行了元胞更新~