大型模型的训练涉及到微调,微调则面临着高质量数据的稀缺性。与基于集中式数据中心的解决方案相比,物联网-IoT中大型模型的更新面临着分布式客户端私有且异构数据的协调挑战。为了解决这一挑战,作者提出了KOALA来推动物联网中大模型的训练。由于物联网客户端获得的资源是有限的,在本地执行大模型并以保护隐私的方式更新模型是不可行的。因此,利用联邦学习和知识蒸馏,通过与小模型协作来更新大模型,小模型可以在物联网客户端本地运行,单独处理其私有数据,并通过服务器和客户端之间的迭代学习实现大-小模型知识转移。此外,为了支持计算能力相似或不同的客户端,KOALA设计了同质或异构两种大模型联合学习模式。实验结果表明,与传统方法相比,KOALA不仅可以达到相似的训练效果,而且可以显著减少对本地存储和计算能力资源的需求。
来自:Federated Knowledge Transfer Fine-tuning Large Server Model with Resource-Constrained IoT Clients
目录
- 背景概述
- 相关工作-联邦学习
- 相关工作-知识蒸馏
- 相关工作-联邦知识蒸馏
- 方法
- 问题陈述
- 动机
- KOALA
- 本地知识提取
- 反向知识蒸馏
- 前向知识蒸馏
- 实验设置
- 模型
- 数据集
- 基线
背景概述
面对规模不断增长的模型,如BERT,GPT,ViT等。如何利用分布式计算能力,在各种物联网场景中对其进行训练和应用变得至关重要。不幸的是,IoT客户端通常存在数据保护的考虑,和受限制的计算能力。这些因素阻碍了使用丰富的 IoT 数据来训练复杂和大规模的模型。
为了应对数据隐私的挑战,通常采用基于联邦学习FL的解决方案,以协作和隐私保护的方式支持大模型训练,例如,Yu S等人提出了一种在具有私有数据的客户端和具有标记公共数据的服务器上交替训练大模型的方法;Wu C等人引入了一种用于训练个性化大模型的联邦互蒸馏方法,可以显著降低通信成本。尽管私有知识可以通过FL在分布式客户端之间共享,但当前方法的共同前提是具有足够的本地计算能力,可以在每个学习客户端上直接运行大模型,这使得它们无法支持本地资源不足的分布式IoT客户端。
因此,为了支持大模型的微调和模型自适应以赋能各种IoT场景,该研究的目标定义如图1所示:
- 服务器具有足够的存储和计算能力,但缺乏高质量的数据----仅具有有限数量的未标记代理数据集;
- 物联网客户端作为一个群体具有丰富的测量数据和分布式计算能力,但对于每个客户端而言,其设备和私有数据是异构的,它的本地资源有限,无法支持大模型的运行;
- 图1:服务端与IoT端的情况
通过集成FL在物联网客户端之间共享私有知识和知识蒸馏(KD)在不同模型之间(即教师和学生模型之间)传递编码知识,提出的KOALA实现联合迭代学习,允许物联网客户端运行其本地小模型以提取和共享本地知识,然后服务器根据每个客户端本地更新的小模型再更新大模型的adapter。具体来说,为了实现这样的学习过程,我们将前向蒸馏和反向蒸馏技术联合使用,首先对训练好的小模型进行反向蒸馏,对大模型进行微调,然后对大模型进行前向蒸馏,为IoT客户端更新小模型----小模型更新大模型,大模型再更新小模型。
前向蒸馏:让Student输出接近Teacher:目标是使Student的输出分布尽可能接近Teacher的输出分布,从而提高学生模型的性能
反向蒸馏:把KL(P,Q)改为KL(Q,P)
前向蒸馏的分布Q会比较宽,反向蒸馏的分布Q会比较窄,因此反向蒸馏可以防止Student高估Teacher的低概率区
此外,在传统FL中,全局模型和局部模型具有相同的结构,并且可以直接基于局部更新的聚合来更新全局模型。然而,在KOALA中实现的大-小模型协同学习过程需要在服务器端和客户端支持不同的模型,这使得传统的FL方法不可行。因此,根据小模型之间的差异,KOALA实现了两种学习模式,用于聚合同质或异质小模型中编码的局部知识。同质方法支持IoT客户端运行结构相同的小模型,异质方法支持每个IoT客户端运行不同的小模型,可以根据客户端的实际计算能力创建小模型,更加灵活。在大模型更新后,可以使用同质或异质方法,从最新的大模型中蒸馏相关的小模型,并将其分派给相应的客户端,开始新的学习迭代。
基于标准数据集评估了KOALA的效率。实验结果表明,与基线相比,KOALA可以在所有任务上接近相似的训练性能(在本地加载和执行大模型的情况),因此,KOALA显著减少了对本地资源的需求。
主要贡献如下:
- 在数据保护和资源受限的IoT场景下,作者提出了一种新颖的大小模型协同学习过程,通过该过程,FL和KD可以共同支持大小模型的迭代学习,即使它们在模型结构上是跨尺度的;
- 为了更好地处理基于本地数据更新的异构小模型的输出,作者设计了一种反向知识蒸馏策略,通过该策略对代理数据集上的本地模型输出进行蒸馏和集成,生成共识软标签,用于大模型微调;
- 经验证,该方法具有性能等效和资源高效的特点。具体来说,通过KOALA微调的大模型可以达到与传统方法更新的模型相似的精度。同时,与传统方法相比,加载局部模型所需的存储空间(Homo)和存储空间(Hete)分别减少了97.6%和97.2%,局部模型的FLOPs (Homo)和FLOPs (Hete)分别减少了98.4%和98.6%。
相关工作-联邦学习
联邦学习是一种保护隐私的机器学习框架,其中服务器协调多个客户端以学习全局可共享的模型,而无需直接交换本地数据。作为经典方法,FedAvg管理每个客户端训练其本地模型,并将更新后的本地模型上传到服务器。然后,聚合本地模型以更新全局模型,然后由活动客户端在下一轮中下载全局模型。然而,客户端之间非独立同分布(Non-IID)数据的问题降低了联邦学习的性能,促使许多方法旨在缓解这一问题。因此,FedProx在局部训练中引入了损失函数的proximal term,以约束模型参数的更新。SCAFFOLD引入控制变量来减少“客户漂移”。MOON将联邦学习和对比学习相结合,使局部模型更新更接近全局模型,远离以前的局部模型。由于高度异构的数据可能会阻碍模型的收敛,并且通用的全局模型无法满足不同客户端的个性化需求,因此个性化的联邦学习是必不可少的。Per-FedAvg结合了经典的元学习框架MAML,以基于全局元模型训练个性化模型。不同的是,PFedMe没有直接利用全局模型,而是同时训练全局模型和个性化模型。
相关工作-知识蒸馏
Hinton等人首先引入了知识蒸馏。他们的工作采用hard损失和soft损失的加权总和作为完全损失。soft损失是学生模型的soft输出与教师模型生成的soft标签之间的损失,hard损失是学生模型的hard输出与真实标签之间的损失。Adriana Romero等提出了基于隐藏层知识特征的知识蒸馏(hints)。Zhang等人提出相互蒸馏(mutual distillation),使不同的模型能够相互从彼此中提取知识。
相关工作-联邦知识蒸馏
知识蒸馏与联邦学习的集成越来越受到关注。FedMD基于共享数据集进行集成,以计算平均分数用于指导每个客户端的知识蒸馏。相反,FD消除了对共享数据集的需求,并允许客户端在其本地数据集上计算每个标签的预测分数,并允许服务器计算每个标签的全局平均预测分数,这在本地蒸馏期间充当软标签。FedGKT结合了联邦学习和分裂学习(SL, split learning----将一个模型分成多个部分,每个部分都在一个分布式设备上)。FedDKC与FedGKT类似,可以减少异构模型知识分布之间的差距。虽然FedGKT和FedDKC可以支持资源受限的客户端,但这两种方法都需要上传本地的真实标签,这会损害客户端的隐私。而且,他们的目标是在大模型的指导下训练小模型,而不是考虑如何整合从不同客户端提取的知识来快速有效地更新大模型。
方法
问题陈述
假设有
N
N
N个客户端(
i
=
1
,
2
,
.
.
.
,
N
i=1,2,...,N
i=1,2,...,N),每个客户端有一个私有数据集,标签类别为
j
=
1
,
2
,
.
.
.
,
C
j=1,2,...,C
j=1,2,...,C。客户端
i
i
i的样本量为
n
i
n_{i}
ni。为了支持分类任务,在式1中定义的关键目标是,在局部资源受限的情况下,所提出方法更新的大模型与常规模型(conventional model)之间的损失差最小,其中
Ω
Ω
Ω和
Ω
C
o
n
v
Ω_{Conv}
ΩConv分别是所提出方法训练的大模型和常规模型,
L
(
)
L()
L()是损失函数,
D
D
D是测试数据集:
动机
所提出方法是基于这样的直觉:小模型可以被视为本地私有知识的提取器,可以在服务器上使用它将嵌入在私有数据中的知识传递给大模型。
为了验证这个直觉,作者设计了一个简单的实验,其中在每一轮中,小模型由标记数据集训练,然后通过知识蒸馏基于代理数据集微调大模型,小模型作为Teacher,大模型作为Student。注意,CIFAR-10用于小模型训练,CIFAR-10的测试数据集用于评估大型模型的性能。此外,小模型为MobileNet V3 small,大模型为VGG19。
- 图2:知识迁移和随机选择的Acc。随机选择是不微调大模型,随机选择一个分类概率值。
从图2所示的结果可以看出,即使只处理未标记的代理数据集,被小模型蒸馏后大模型的准确率可以得到显著提高。因此,基于知识转移,小模型可以与大模型共享本地私有知识,这促使作者设计能够整合联邦学习和知识蒸馏的KOALA,实现一个大-小模型协同学习过程。
KOALA
KOALA实现了一个大小模型协同学习的过程,通过小模型作为本地知识提取器,并根据从小模型中提取的知识对大模型进行微调。具体来说,在每个IoT客户端中,从服务器下载相应的小模型,并根据其私有数据在本地进行训练。在服务器端,引入双向知识蒸馏机制,支持:
- 基于小模型的反向蒸馏对大模型进行微调
- 基于大模型的正向蒸馏对小模型进行更新
如图3所示,KOALA包括三个步骤,即:1.本地知识提取,2.反向知识蒸馏,3.正向知识蒸馏。由于IoT客户端不仅在数据上是异构的,而且在计算能力上也是异构的,因此KOALA设计了两种学习模式,一种是同质小模型(homo),另一种是异构小模型(hete)。
本地知识提取
在此步骤中,根据相应IoT客户端的私有数据更新homo或hete小模型。提取本地知识后,将小模型上传到服务器。
反向知识蒸馏
收集到所有本地更新的小模型后,服务器启动反向蒸馏,其中大模型作为Student,小模型作为Teacher。
具体而言,在homo模式下,首先将小模型聚合生成全局小模型
w
w
w,然后根据代理数据
x
x
x生成伪标签,如下所示,
T
T
T为蒸馏温度:
全局小模型
w
w
w把知识迁移到大模型
Ω
Ω
Ω,其中,大模型仅更新它的adapter,反向蒸馏损失
l
o
s
s
r
h
o
m
o
loss_{r}^{homo}
lossrhomo在homo模式中使用,如下所示,其中,
l
K
L
l_{KL}
lKL是KL损失:
由于异构小模型不能直接聚合,在hete模式下,对小模型的输出分布进行细化和集成,生成共识软标签。为了调解输出分布中的异质性,作者引入了一种分布调整策略。假设在输出分布
f
(
x
,
w
i
)
f(x,w_{i})
f(x,wi)内,最大和最小值分别是
z
i
,
m
a
x
z_{i,max}
zi,max和
z
i
,
m
i
n
z_{i,min}
zi,min,标签
j
j
j对应的值为
z
i
,
j
z_{i,j}
zi,j,调整的值为
z
^
i
,
j
\widehat{z}_{i,j}
z
i,j,定义如下,其中
w
i
w_{i}
wi是客户端
i
i
i的模型,
k
k
k是用于调整的系数:
将所有标签的调整值相加,我们可以得到:
在式5中,
z
‾
i
\overline{z}_{i}
zi是输出分布
f
(
x
,
w
i
)
f(x,w_{i})
f(x,wi)的平均值。假设所有小模型的精化分布的均值等于
A
A
A (
A
A
A是一个常数),因此:
因此可以计算出
k
k
k:
将其代入式4,得到分布调整策略为:
根据式8,得到调整后的分布
z
^
i
=
{
z
^
i
,
1
,
z
^
2
,
.
.
.
,
z
^
i
,
C
}
\widehat{z}_{i}=\left\{\widehat{z}_{i,1},\widehat{z}_{2},...,\widehat{z}_{i,C}\right\}
z
i={z
i,1,z
2,...,z
i,C}。然后,通过式9得到小模型的综合输出分布
z
~
\widetilde{z}
z
,假设当前轮中,活动的客户端为集合
S
S
S:
基于
z
~
\widetilde{z}
z
,共识软标签为:
然后,我们根据公式11中定义的反向蒸馏损失
l
o
s
s
r
h
e
t
e
loss_{r}^{hete}
lossrhete对大模型
Ω
Ω
Ω进行微调。
前向知识蒸馏
在反向蒸馏之后,作者实现正向蒸馏,根据更新后的大模型更新小模型,其中大模型作为Teacher,小模型作为Student。为了计算正向蒸馏损失,需要计算输出特征损失(输出层之间的损失)和隐藏特征损失(隐藏层之间的损失)。
在homo模式中,全局小模型
w
w
w作为student被更新,
Ω
h
Ω^h
Ωh表示大模型中的前
h
h
h层,
w
g
w^g
wg表示全局小模型的前
g
g
g层。因此,输出特征损失
l
o
s
s
o
u
t
h
o
m
o
loss_{out}^{homo}
lossouthomo和隐藏特征损失
l
o
s
s
h
i
d
h
o
m
o
loss_{hid}^{homo}
losshidhomo分别根据式12和13计算,其中
W
W
W是桥接矩阵,
l
M
S
E
(
)
l_{MSE}()
lMSE()是MSE损失。
因此,组合两者得到前向蒸馏损失:
对于hete模式,每个小模型
w
i
,
i
∈
S
w_{i},i\in S
wi,i∈S作为Student,假设
w
i
w_i
wi是第
i
i
i个模型,
w
i
g
w_{i}^{g}
wig是它的前
g
g
g层,
W
i
W_{i}
Wi是它的桥接矩阵,则输出损失和隐藏损失为:
第
i
i
i个小模型的前向蒸馏损失为:
最后,无论是homo模式还是hete模式,都是基于前向蒸馏损失对小模型进行更新,更新后将其分派给相关客户端开始新的学习轮,直到满足某些条件(例如,模型收敛或达到最大学习轮)。
实验设置
模型
选择TorchVision backbone,并将分类器附加到每个骨干网的最后一层,形成实验中使用的大模型和小模型。大模型的分类器被视为adapter。大模型的主干是VGG19。在homo模式中,小模型统一为MobileNetV2,在hete模式中,小模型分别为MobileNet V2、MobileNet V3 small、EfficientNet B0、ShuffleNet V2 X0_5和ShuffleNet V2 X2_0。此外,作者实现了额外的工具来计算模型FLOPs,其中使用64×64随机生成的“图像”作为输入。
数据集
作者选择了4个数据集:CIFAR-10、Fashion-MNIST、USPS和GTSRB。每个数据集的整个测试集用于评估大模型,记录其在训练前(第0轮)和每个学习轮结束时的性能。通过去除标签,代理数据集是原始训练集的子集。客户端的本地数据集采用Dirichlet分布,浓度参数为1.0(从原始数据集减去代理集再采样)。此外,代理数据集和私有客户端数据集之间没有重叠。
基线
作者在假设所有IoT客户端都有足够的本地资源来直接运行大型模型的情况下设置了基线,并使用FedAvg来更新全局模型。具体的,基线更新全局模型的工作流程包括三个步骤,即:
- 客户端下载全局大模型
- 对大模型进行局部微调
- 将大模型参数上传到服务器进行全局聚合。
在客户端-服务器交互期间,服务器和客户端之间交换的是adapter,而不是整个模型(除了第一次将大模型从服务器下载到客户机)。