IoT客户端+联邦学习微调大模型

news2024/12/23 23:57:12

大型模型的训练涉及到微调,微调则面临着高质量数据的稀缺性。与基于集中式数据中心的解决方案相比,物联网-IoT中大型模型的更新面临着分布式客户端私有且异构数据的协调挑战。为了解决这一挑战,作者提出了KOALA来推动物联网中大模型的训练。由于物联网客户端获得的资源是有限的,在本地执行大模型并以保护隐私的方式更新模型是不可行的。因此,利用联邦学习和知识蒸馏,通过与小模型协作来更新大模型,小模型可以在物联网客户端本地运行,单独处理其私有数据,并通过服务器和客户端之间的迭代学习实现大-小模型知识转移。此外,为了支持计算能力相似或不同的客户端,KOALA设计了同质或异构两种大模型联合学习模式。实验结果表明,与传统方法相比,KOALA不仅可以达到相似的训练效果,而且可以显著减少对本地存储和计算能力资源的需求。

来自:Federated Knowledge Transfer Fine-tuning Large Server Model with Resource-Constrained IoT Clients

目录

  • 背景概述
    • 相关工作-联邦学习
    • 相关工作-知识蒸馏
    • 相关工作-联邦知识蒸馏
  • 方法
    • 问题陈述
    • 动机
    • KOALA
      • 本地知识提取
      • 反向知识蒸馏
      • 前向知识蒸馏
  • 实验设置
    • 模型
    • 数据集
    • 基线

背景概述

面对规模不断增长的模型,如BERT,GPT,ViT等。如何利用分布式计算能力,在各种物联网场景中对其进行训练和应用变得至关重要。不幸的是,IoT客户端通常存在数据保护的考虑,和受限制的计算能力。这些因素阻碍了使用丰富的 IoT 数据来训练复杂和大规模的模型。

为了应对数据隐私的挑战,通常采用基于联邦学习FL的解决方案,以协作和隐私保护的方式支持大模型训练,例如,Yu S等人提出了一种在具有私有数据的客户端和具有标记公共数据的服务器上交替训练大模型的方法;Wu C等人引入了一种用于训练个性化大模型的联邦互蒸馏方法,可以显著降低通信成本。尽管私有知识可以通过FL在分布式客户端之间共享,但当前方法的共同前提是具有足够的本地计算能力,可以在每个学习客户端上直接运行大模型,这使得它们无法支持本地资源不足的分布式IoT客户端。

因此,为了支持大模型的微调和模型自适应以赋能各种IoT场景,该研究的目标定义如图1所示:

  • 服务器具有足够的存储和计算能力,但缺乏高质量的数据----仅具有有限数量的未标记代理数据集;
  • 物联网客户端作为一个群体具有丰富的测量数据和分布式计算能力,但对于每个客户端而言,其设备和私有数据是异构的,它的本地资源有限,无法支持大模型的运行;

fig1

  • 图1:服务端与IoT端的情况

通过集成FL在物联网客户端之间共享私有知识和知识蒸馏(KD)在不同模型之间(即教师和学生模型之间)传递编码知识,提出的KOALA实现联合迭代学习,允许物联网客户端运行其本地小模型以提取和共享本地知识,然后服务器根据每个客户端本地更新的小模型再更新大模型的adapter。具体来说,为了实现这样的学习过程,我们将前向蒸馏和反向蒸馏技术联合使用,首先对训练好的小模型进行反向蒸馏,对大模型进行微调,然后对大模型进行前向蒸馏,为IoT客户端更新小模型----小模型更新大模型,大模型再更新小模型。


前向蒸馏:让Student输出接近Teacher:目标是使Student的输出分布尽可能接近Teacher的输出分布,从而提高学生模型的性能

反向蒸馏:把KL(P,Q)改为KL(Q,P)

前向蒸馏的分布Q会比较宽,反向蒸馏的分布Q会比较窄,因此反向蒸馏可以防止Student高估Teacher的低概率区


此外,在传统FL中,全局模型和局部模型具有相同的结构,并且可以直接基于局部更新的聚合来更新全局模型。然而,在KOALA中实现的大-小模型协同学习过程需要在服务器端和客户端支持不同的模型,这使得传统的FL方法不可行。因此,根据小模型之间的差异,KOALA实现了两种学习模式,用于聚合同质或异质小模型中编码的局部知识。同质方法支持IoT客户端运行结构相同的小模型,异质方法支持每个IoT客户端运行不同的小模型,可以根据客户端的实际计算能力创建小模型,更加灵活。在大模型更新后,可以使用同质或异质方法,从最新的大模型中蒸馏相关的小模型,并将其分派给相应的客户端,开始新的学习迭代。

基于标准数据集评估了KOALA的效率。实验结果表明,与基线相比,KOALA可以在所有任务上接近相似的训练性能(在本地加载和执行大模型的情况),因此,KOALA显著减少了对本地资源的需求。

主要贡献如下:

  • 在数据保护和资源受限的IoT场景下,作者提出了一种新颖的大小模型协同学习过程,通过该过程,FL和KD可以共同支持大小模型的迭代学习,即使它们在模型结构上是跨尺度的;
  • 为了更好地处理基于本地数据更新的异构小模型的输出,作者设计了一种反向知识蒸馏策略,通过该策略对代理数据集上的本地模型输出进行蒸馏和集成,生成共识软标签,用于大模型微调;
  • 经验证,该方法具有性能等效和资源高效的特点。具体来说,通过KOALA微调的大模型可以达到与传统方法更新的模型相似的精度。同时,与传统方法相比,加载局部模型所需的存储空间(Homo)和存储空间(Hete)分别减少了97.6%和97.2%,局部模型的FLOPs (Homo)和FLOPs (Hete)分别减少了98.4%和98.6%。

相关工作-联邦学习

联邦学习是一种保护隐私的机器学习框架,其中服务器协调多个客户端以学习全局可共享的模型,而无需直接交换本地数据。作为经典方法,FedAvg管理每个客户端训练其本地模型,并将更新后的本地模型上传到服务器。然后,聚合本地模型以更新全局模型,然后由活动客户端在下一轮中下载全局模型。然而,客户端之间非独立同分布(Non-IID)数据的问题降低了联邦学习的性能,促使许多方法旨在缓解这一问题。因此,FedProx在局部训练中引入了损失函数的proximal term,以约束模型参数的更新。SCAFFOLD引入控制变量来减少“客户漂移”。MOON将联邦学习和对比学习相结合,使局部模型更新更接近全局模型,远离以前的局部模型。由于高度异构的数据可能会阻碍模型的收敛,并且通用的全局模型无法满足不同客户端的个性化需求,因此个性化的联邦学习是必不可少的。Per-FedAvg结合了经典的元学习框架MAML,以基于全局元模型训练个性化模型。不同的是,PFedMe没有直接利用全局模型,而是同时训练全局模型和个性化模型。

相关工作-知识蒸馏

Hinton等人首先引入了知识蒸馏。他们的工作采用hard损失和soft损失的加权总和作为完全损失。soft损失是学生模型的soft输出与教师模型生成的soft标签之间的损失,hard损失是学生模型的hard输出与真实标签之间的损失。Adriana Romero等提出了基于隐藏层知识特征的知识蒸馏(hints)。Zhang等人提出相互蒸馏(mutual distillation),使不同的模型能够相互从彼此中提取知识。

相关工作-联邦知识蒸馏

知识蒸馏与联邦学习的集成越来越受到关注。FedMD基于共享数据集进行集成,以计算平均分数用于指导每个客户端的知识蒸馏。相反,FD消除了对共享数据集的需求,并允许客户端在其本地数据集上计算每个标签的预测分数,并允许服务器计算每个标签的全局平均预测分数,这在本地蒸馏期间充当软标签。FedGKT结合了联邦学习和分裂学习(SL, split learning----将一个模型分成多个部分,每个部分都在一个分布式设备上)。FedDKC与FedGKT类似,可以减少异构模型知识分布之间的差距。虽然FedGKT和FedDKC可以支持资源受限的客户端,但这两种方法都需要上传本地的真实标签,这会损害客户端的隐私。而且,他们的目标是在大模型的指导下训练小模型,而不是考虑如何整合从不同客户端提取的知识来快速有效地更新大模型

方法

问题陈述

假设有 N N N个客户端( i = 1 , 2 , . . . , N i=1,2,...,N i=1,2,...,N),每个客户端有一个私有数据集,标签类别为 j = 1 , 2 , . . . , C j=1,2,...,C j=1,2,...,C。客户端 i i i的样本量为 n i n_{i} ni。为了支持分类任务,在式1中定义的关键目标是,在局部资源受限的情况下,所提出方法更新的大模型与常规模型(conventional model)之间的损失差最小,其中 Ω Ω Ω Ω C o n v Ω_{Conv} ΩConv分别是所提出方法训练的大模型和常规模型, L ( ) L() L()是损失函数, D D D是测试数据集:
eq1

动机

所提出方法是基于这样的直觉:小模型可以被视为本地私有知识的提取器,可以在服务器上使用它将嵌入在私有数据中的知识传递给大模型。

为了验证这个直觉,作者设计了一个简单的实验,其中在每一轮中,小模型由标记数据集训练,然后通过知识蒸馏基于代理数据集微调大模型,小模型作为Teacher,大模型作为Student。注意,CIFAR-10用于小模型训练,CIFAR-10的测试数据集用于评估大型模型的性能。此外,小模型为MobileNet V3 small,大模型为VGG19。

fig2

  • 图2:知识迁移和随机选择的Acc。随机选择是不微调大模型,随机选择一个分类概率值。

从图2所示的结果可以看出,即使只处理未标记的代理数据集,被小模型蒸馏后大模型的准确率可以得到显著提高。因此,基于知识转移,小模型可以与大模型共享本地私有知识,这促使作者设计能够整合联邦学习和知识蒸馏的KOALA,实现一个大-小模型协同学习过程。

KOALA

KOALA实现了一个大小模型协同学习的过程,通过小模型作为本地知识提取器,并根据从小模型中提取的知识对大模型进行微调。具体来说,在每个IoT客户端中,从服务器下载相应的小模型,并根据其私有数据在本地进行训练。在服务器端,引入双向知识蒸馏机制,支持:

  • 基于小模型的反向蒸馏对大模型进行微调
  • 基于大模型的正向蒸馏对小模型进行更新

如图3所示,KOALA包括三个步骤,即:1.本地知识提取,2.反向知识蒸馏,3.正向知识蒸馏。由于IoT客户端不仅在数据上是异构的,而且在计算能力上也是异构的,因此KOALA设计了两种学习模式,一种是同质小模型(homo),另一种是异构小模型(hete)。

本地知识提取

在此步骤中,根据相应IoT客户端的私有数据更新homo或hete小模型。提取本地知识后,将小模型上传到服务器。

反向知识蒸馏

收集到所有本地更新的小模型后,服务器启动反向蒸馏,其中大模型作为Student,小模型作为Teacher。

具体而言,在homo模式下,首先将小模型聚合生成全局小模型 w w w,然后根据代理数据 x x x生成伪标签,如下所示, T T T为蒸馏温度:
eq2

全局小模型 w w w把知识迁移到大模型 Ω Ω ,其中,大模型仅更新它的adapter,反向蒸馏损失 l o s s r h o m o loss_{r}^{homo} lossrhomo在homo模式中使用,如下所示,其中, l K L l_{KL} lKL是KL损失:
eq3

由于异构小模型不能直接聚合,在hete模式下,对小模型的输出分布进行细化和集成,生成共识软标签。为了调解输出分布中的异质性,作者引入了一种分布调整策略。假设在输出分布 f ( x , w i ) f(x,w_{i}) f(x,wi)内,最大和最小值分别是 z i , m a x z_{i,max} zi,max z i , m i n z_{i,min} zi,min,标签 j j j对应的值为 z i , j z_{i,j} zi,j,调整的值为 z ^ i , j \widehat{z}_{i,j} z i,j,定义如下,其中 w i w_{i} wi是客户端 i i i的模型, k k k是用于调整的系数:
eq4
将所有标签的调整值相加,我们可以得到:
eq5
在式5中, z ‾ i \overline{z}_{i} zi是输出分布 f ( x , w i ) f(x,w_{i}) f(x,wi)的平均值。假设所有小模型的精化分布的均值等于 A A A ( A A A是一个常数),因此:
eq6
因此可以计算出 k k k
eq7

将其代入式4,得到分布调整策略为:
eq8
根据式8,得到调整后的分布 z ^ i = { z ^ i , 1 , z ^ 2 , . . . , z ^ i , C } \widehat{z}_{i}=\left\{\widehat{z}_{i,1},\widehat{z}_{2},...,\widehat{z}_{i,C}\right\} z i={z i,1,z 2,...,z i,C}。然后,通过式9得到小模型的综合输出分布 z ~ \widetilde{z} z ,假设当前轮中,活动的客户端为集合 S S S
eq9
基于 z ~ \widetilde{z} z ,共识软标签为:
eq10
然后,我们根据公式11中定义的反向蒸馏损失 l o s s r h e t e loss_{r}^{hete} lossrhete对大模型 Ω Ω Ω进行微调。
eq11

前向知识蒸馏

在反向蒸馏之后,作者实现正向蒸馏,根据更新后的大模型更新小模型,其中大模型作为Teacher,小模型作为Student。为了计算正向蒸馏损失,需要计算输出特征损失(输出层之间的损失)和隐藏特征损失(隐藏层之间的损失)。

在homo模式中,全局小模型 w w w作为student被更新, Ω h Ω^h Ωh表示大模型中的前 h h h层, w g w^g wg表示全局小模型的前 g g g层。因此,输出特征损失 l o s s o u t h o m o loss_{out}^{homo} lossouthomo和隐藏特征损失 l o s s h i d h o m o loss_{hid}^{homo} losshidhomo分别根据式12和13计算,其中 W W W是桥接矩阵, l M S E ( ) l_{MSE}() lMSE()是MSE损失。
eq13
因此,组合两者得到前向蒸馏损失:
eq14

对于hete模式,每个小模型 w i , i ∈ S w_{i},i\in S wi,iS作为Student,假设 w i w_i wi是第 i i i个模型, w i g w_{i}^{g} wig是它的前 g g g层, W i W_{i} Wi是它的桥接矩阵,则输出损失和隐藏损失为:
eq16
i i i个小模型的前向蒸馏损失为:
eq17
最后,无论是homo模式还是hete模式,都是基于前向蒸馏损失对小模型进行更新,更新后将其分派给相关客户端开始新的学习轮,直到满足某些条件(例如,模型收敛或达到最大学习轮)。

实验设置

模型

选择TorchVision backbone,并将分类器附加到每个骨干网的最后一层,形成实验中使用的大模型和小模型。大模型的分类器被视为adapter。大模型的主干是VGG19。在homo模式中,小模型统一为MobileNetV2,在hete模式中,小模型分别为MobileNet V2、MobileNet V3 small、EfficientNet B0、ShuffleNet V2 X0_5和ShuffleNet V2 X2_0。此外,作者实现了额外的工具来计算模型FLOPs,其中使用64×64随机生成的“图像”作为输入。

数据集

作者选择了4个数据集:CIFAR-10、Fashion-MNIST、USPS和GTSRB。每个数据集的整个测试集用于评估大模型,记录其在训练前(第0轮)和每个学习轮结束时的性能。通过去除标签,代理数据集是原始训练集的子集。客户端的本地数据集采用Dirichlet分布,浓度参数为1.0(从原始数据集减去代理集再采样)。此外,代理数据集和私有客户端数据集之间没有重叠。

基线

作者在假设所有IoT客户端都有足够的本地资源来直接运行大型模型的情况下设置了基线,并使用FedAvg来更新全局模型。具体的,基线更新全局模型的工作流程包括三个步骤,即:

  • 客户端下载全局大模型
  • 对大模型进行局部微调
  • 将大模型参数上传到服务器进行全局聚合。

在客户端-服务器交互期间,服务器和客户端之间交换的是adapter,而不是整个模型(除了第一次将大模型从服务器下载到客户机)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2085718.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

单线程Redis:Redis为什么这么快

1 Redis是不是单线程 Redis 使用后台线程执行耗时的 I/O 操作,以免阻塞主线程 bio_close_file:后台 I/O 线程之一,异步关闭文件 bio_aof_fsync:后台 I/O 线程之一,异步地将 AOF(Append Only File&#xff…

C++系列-STL容器之vector

STL概念 vector基本概念vector与数组的区别vector容器的特点动态大小连续存储自动扩容尾部操作高效 vector动态扩展的含义vector常用的接口示意 vector的构造函数vector赋值操作重载赋值assign赋值 vector的容量和大小vector的插入和删除vector数据存取vector互换容器vector互换…

音视频入门基础:WAV专题(7)——FFmpeg源码中计算WAV音频文件每个packet的size值的实现

一、引言 从文章《音视频入门基础:WAV专题(6)——通过FFprobe显示WAV音频文件每个数据包的信息》中我们可以知道,通过FFprobe命令可以显示WAV音频文件每个packet(也称为数据包或多媒体包)的信息&#xff0…

VMware16安装包+详细安装教程

VMware Workstation Pro16.0安装 安装包下载: 通过百度网盘分享的文件:VMware16.0.rar 链接:https://pan.baidu.com/s/1ZSWns5kJYUmhpZFjuKXqrQ?pwdv7s2 提取码:v7s2右键解压之后的安装包【VMware-workstation-full-16.0.0-16…

FrameNet介绍——从同义词语义知识库到框架语义知识库

FrameNet 是一个为期三年的项目,获得了 NSF(美国国家科学基金会)的支持,专注于基于语料库的计算词典编纂。 项目特点 FrameNet承诺使用语料库证据(corpus evidence)来进行语义和句法的概括; 并…

网络基础-实现在Windows系统下的socket环境地址通信

实现客户端和服务端的数据交互 1.写所要实现功能的声明&#xff08;封装在tcpsocket.h文件&#xff09; #ifndef TCPSOCKET_H #define TCPSOCKET_H//在Windows下进行网络编程&#xff0c;需要引入Windows的socket库 #include <winsock2.h> //做一些预编译工作&#xff…

MyBatis结果集复杂映射超详细版(一对多关系映射)

目录 1.一对多关系映射 1.1创建两个表&#xff1a;goods表与goods_class表 1.2xml文件中两部分&#xff1a;与(存放SQL语句)1.3数据库中&#xff1a;测试SQL语句&#xff0c;涉及到的知识点&#xff1a;左连接 1.一对多关系映射 1.1创建两个表&#xff1a;goods表与goods_c…

C++对C的扩充(8.28)

1.使用C手动封装一个顺序表&#xff0c;包括成员数组1个&#xff0c;成员变量n个 代码&#xff1a; #include <iostream>using namespace std;//类型重命名 using datatype int; #define MAX 30struct seqList { private: //私有权限datatype *data; //相当于 …

【项目源码】终于有人将打字游戏和编程英语结合起来啦!编程初学者的福音

Hello&#xff01;各位彦祖&#xff0c;亦菲们&#xff01;又是美好的一天&#xff01;今天给大家分享一个Java项目源码&#xff1a;Java打字游戏项目源码&#xff01; 看到这里&#xff0c;你可能会说&#xff01; 一个破打字游戏有什么可神气的&#xff01;&#xff01;&…

【自由能系列(中级)】状态与动作的协同机制解析 ——从马尔可夫毯到大脑功能的全方位剖析

状态与动作的协同机制解析 ——从马尔可夫毯到大脑功能的全方位剖析 Synergistic Mechanism of States and Actions —— A Comprehensive Analysis from Markov Blanket to Brain Function 核心结论&#xff1a; 中文总结&#xff1a; 系统将状态划分为内部状态和隐藏或外…

Flutter中的Key

在Flutter 中&#xff0c;Key 是 几乎所有 widget 都具有的属性。为什么 widget 具有 Key 呢&#xff1f;Key的作用是什么&#xff1f; 什么是 Key Key是Widget、Element 和 SemanticNodes 的标识符。 Key 是Widget、Element 和 SemanticNodes的唯一标识。例如对于 Widget 在 …

MyBatis的学习————下篇

目录 一、动态SQL 简介 1、if标签 2、where标签 3、trim标签 4、choose、when、otherwise 5、foreach 5.1、批量删除 5.2、批量添加 6、sql标签 二、MyBatis的缓存 1、一级缓存 2、二级缓存 3、二级缓存的相关配置 4、MyBatis缓存查询的顺序 5、 第三方缓存EHCac…

如何在Windows 11上关闭无响应的应用程序?这里有详细步骤

序言 无响应的应用程序令人沮丧,但更糟糕的是这些应用程序拒绝关闭。如果你发现自己处于这种情况,我们有几种方法可以帮助你强制关闭Windows 11 PC上的这些应用程序。让我们找出可用的解决方案。 使用键盘快捷键结束程序 关闭无响应应用程序的最简单方法是使用Windows键盘…

DataWhale AI夏令营 2024大运河杯-数据开发应用创新赛-task2

DataWhale AI夏令营 2024大运河杯-数据开发应用创新赛 YOLO(You Only Look Once)上分心得分享 YOLO(You Only Look Once) YOLO算的上是近几年最火的目标检测模型了&#xff0c;被广泛的应用在工业、学术等领域。 YOLOv1&#xff08;You Only Look Once 第一版&#xff09;于 2…

基于麒麟信安操作系统的光伏发电功率预测系统完成大规模部署建设

麒麟信安操作系统&#xff0c;作为行业数智化建设的安全根基&#xff0c;为电力业务系统提供了稳定可靠的底层平台&#xff0c;在全球能源结构转型大潮中扮演着至关重要的角色。某光伏电站项目中&#xff0c;基于麒麟信安操作系统的光伏发电功率预测系统完成大规模部署建设&…

c#如何加密exe程序防止反编译附软件

1. 先说软件&#xff0c;使用的软件是Dotfuscator&#xff0c;下载地址如下&#xff1a; 链接&#xff1a;https://pan.quark.cn/s/6f2e785c003f2. 软件使用方法&#xff0c;打开软件&#xff0c;选择Create New Project 3. 找到input&#xff0c;把你需要加密的文件导入 4.…

k8s项目的发布

目录 三种发布方式 1.蓝绿发布 2.金丝雀发布&#xff08;灰度发布&#xff09; 实验&#xff1a;k8s实现金丝雀发布 3.滚动发布&#xff08;默认形式&#xff09; 因为应用升级以及新旧业务切换&#xff0c;所以在这个过程当中如何保证对外的服务正常是一个非常重要的问题…

手把手教你如何使用Python连接MySQL数据

数据库编程是在应用程序中与数据库交互和管理数据的关键部分。MySQL是一种流行的关系型数据库管理系统&#xff08;RDBMS&#xff09;&#xff0c;在Python中进行MySQL数据库编程相对容易。 本文介绍如何使用Python进行MySQL数据库编程&#xff0c;包括连接数据库、执行SQL查询…

高频面试题:SpringMVC的执行流程

SpringMVC一直以来都是面试中的重点&#xff0c;尽管随着近年来springboot和微服务的广泛流行&#xff0c;关于对springMVC的考察比重略有下降&#xff0c;但依然是面试中的重点&#xff0c;也需要我们对其有一个比较清楚和全面的认识。 如果将java的发展史中重要的组件进行排…

备忘录模式 详解

备忘录模式 简介: 保存一个对象的某个状态&#xff0c;以便在适当的时候恢复对象, 允许在不破坏封装性的前提下&#xff0c;捕获和恢复对象的内部状态。 场景: 很多地方都用到了备忘录模式, 比如网络消息的序列化和反序列化, 数据的本地保存与加载等, 最简单的json的dump和loa…