联邦学习FedAvg-基于去中心化数据的深度网络高效通信学习

news2024/11/26 4:53:43

        随着计算机算力的提升,机器学习作为海量数据的分析处理技术,已经广泛服务于人类社会。 然而,机器学习技术的发展过程中面临两大挑战:一是数据安全难以得到保障,隐私泄露问题亟待解决;二是网络安全隔离和行业隐私,不同行业部门之间存在数据壁垒,导致数据形成“孤岛”无法安全共享,而仅凭各部门独立数据训练的机器学习模型性能无法达到全局最优化。为解决上述问题,谷歌提出了联邦学习(FL,federated learning)技术。

        本文主要对联邦学习的开山之作《Communication-Efficient Learning of Deep Networks from Decentralized Data》 进行重点内容的解读与整理总结。

论文链接:Communication-Efficient Learning of Deep Networks from Decentralized Data

源码实现:https://gitcode.net/mirrors/WHDY/fedavg?utm_source=csdn_github_accelerator 

目录

摘要

1. 介绍

1.1 问题来源

1.2 本文贡献

1.3 联邦学习特性

1.4 联邦优化

1.5 相关工作

1.6 联邦学习框架图

2. 算法介绍

2.1 联邦随机梯度下降(FedSGD)

2.2 联邦平均算法(FedAvg)

3. 实验设计与实现

3.1 模型初始化

3.2 数据集的设置

3.2.1 MNIST数据集

3.2.2 莎士比亚作品集

3.3 实验优化

3.3.1 增加并行性

3.3.2 增加客户端计算量

 3.4 探究客户端数据集的过度优化

3.5 CIFAR实验

3.6 大规模LSTM实验

4. 总结展望

 摘要

现代移动设备拥有大量的适合模型学习的数据,基于这些数据训练得到的模型可以极大地提升用户体验。例如,语言模型能提升语音设别的准确率和文本输入的效率,图像模型能自动筛选好的照片。然而,移动设备拥有的丰富的数据经常具有关于用户的敏感的隐私信息且多个移动设备所存储的数据总量很大,这样一来,不适合将各个移动设备的数据上传到数据中心,然后使用传统的方法进行模型训练。作者提出了一个替代方法,这种方法可以基于分布在各个设备上的数据(无需上传到数据中心),然后通过局部计算的更新值进行聚合来学习到一个共享模型。作者定义这种非中心化方法为“联邦学习”。作者针对深度网络的联邦学习任务提出了一种实用方法,这种方法在学习过程中多次对模型进行平均。同时,作者使用了五种不同的模型和四个数据集对这种方法进行了实验验证。实验结果表明,这种方法面对不平衡以及非独立同分布的数据,具有较好的鲁棒性。在这种方法中,通信所产生的资源开销是主要的瓶颈,实验结果表明,与同步随机梯度下降相比,该方法的通信轮次减少了10-100倍。

1 介绍

1.1 问题来源

        移动设备中有大量数据适合机器学习任务,利用这些数据反过来可以改善用户体验。例如图像识别模型可以帮助用户挑选好的照片。但是这些数据具有高度私密性,并且数据量大,所以我们不可能把这些数据拿到云端服务器进行集中训练。论文提出了一种分布式机器学习方法称为联邦学习(Federal Learning),在该框架中,服务器将全局模型下发给客户,客户端利用本地数据集进行训练,并将训练后的权重上传到服务器,从而实现全局模型的更新。

1.2 本文贡献

  • 提出了从分散的存储于各个移动设备的数据中训练模型是一个重要的研究方向
  • 提出了一个简单实用的算法来解决这种在非中心化设置下的学习问题
  • 做了大量实验来评估所提算法

        具体来说,本文介绍了“联邦平均”算法,这种算法融合了客户端上的局部随机梯度下降计算与服务器上的模型平均。作者使用该算法进行了大量实验,结果表明了这种算法对于不平衡且非独立同分布的数据具有很好的鲁棒性,并且使得在非中心存储的数据上进行深度网络训练所需的通信轮次减少了几个数量级。

1.3 联邦学习特性

  • 从多个移动设备中存储的真实数据中进行模型训练比从存储在数据中心的数据中进行模型训练更具优势
  • 由于数据具有隐私,且多个移动设备所存储的数据总量很大,因此不适合将其上传至数据中心再进行模型训练
  • 对于监督学习任务,数据中的标签信息可以从用户与应用程序的交互中推断出来

1.4 联邦优化

        传统分布式学习关注点在于如何将一个大型神经网络训练分布式进行,数据仍然可能是在几个大的训练中心存储。而联邦学习更关注数据本身,利用联邦学习保证了数据不出本地,并根据数据的特点,对学习模型进行改进。相比于典型的分布式优化问题,联邦优化具有几个关键特性:

  • Non-IID:数据的特征和分布在不同参与方间存在差异
  • Unbalanced:一些用户会更多地使用服务或应用程序,导致本地训练数据量存在差
  • Massively distributed:参与优化的用户数>>平均每个用户的数据量
  • Limited communication:无法保证客户端和服务器端的高效通信

 本文重点关注优化任务中非独立同分布和不平衡问题,以及通信受限的临界属性。

注:独立同分布假设(IID)

        非凸神经网络的目标函数:

对于一个机器学习的问题来说,有,即用模型参数w预测实例的损失。

        设有K个client,第k个client的数据点为P_{k},对应的数据集数量为n_{k}=\left | P_{k} \right |上式可写为:

P_{k}上的数据集是随机均匀采样的,称IID设置,此时有:

不成立则称Non-IID。 

1.5 相关工作

        相关工作中,2010年通过迭代平均本地训练的模型来对感知机进行分布式训练,2015年研究了语音识别深度神经网络的分布式训练,在2015论文里研究了使用“软”平均的异步训练方法。这些工作都考虑的是数据中心化背景下的分布式训练,没有考虑具有数据不平衡且非独立同分布特点的联邦学习任务。但是它们提供了一种思路,即通过迭代平均本地训练模型的算法来解决联邦学习的问题。与本文的研究动机相似在这篇论文中讨论了保护设备中的用户数据的隐私的优点。而在这篇论文中,作者关注于训练深度网络,强调隐私的重要性以及通过在每一轮通信中仅共享一部分参数,进而降低通信开销;但是,他们也没有考虑数据的不平衡以及非独立同分布性,并且他们的研究工作缺乏实验评估。

1.6 联邦学习框架图

2 算法介绍

2.1 联邦随机梯度下降(FedSGD)

设置固定的学习率η,对K个客户端的数据计算其损失梯度:

中心服务器聚合每个客户端计算的梯度,以此来更新模型参数:

其中,

2.2 联邦平均算法(FedAvg)

在客户端进行局部模型的更新:

中心服务器对每个客户端更新后的参数进行加权平均:

每个客户端可以独立地更新模型参数多次,然后再将更新好的参数发送给中心服务器进行加权平均:

FedAvg的计算量与三个参数有关:

  • C:每轮训练选择客户端的比例
  • E:每个客户端更新参数的循环次数所设计的一个因子
  • B:客户端更新参数时,每次梯度下降所使用的数据量

对于一个拥有n_{k}个数据样本的客户端,每轮本地参数更新的次数为:

注:FedSGD只是FedAvg的一个特例,即当参数E=1,B=∞时,FedAvg等价于FedSGD。
 
FedSGD和FedAvg的关系示意图:
地址:https://blog.csdn.net/biongbiongdou/article/details/104358321

3 实验设计与实现

3.1 模型初始化

实验设置
  • 数据集:MNIST中600个无重复的独立同分布样本
  • E=20; C=1; B=50; 中心服务器聚合一次
  • 不同模型使用不同/相同的初始化模型,并通过θ对两模型参数进行加权求和
       

研究模型平均对模型效果的影响:

        这里有两种情况,一种是不同模型使用不同的初始化模型;一种是不同模型使用相同的初始化模型。并且可以通过参数控制权重比进行模型的加权求和。

        可看到,采用不同的初始化参数进行模型平均后,平均模型的效果变差,模型性能比两个父模型都差;采用相同的初始化参数进行模型平均后,对模型的平均可以显著的减少整个训练集的损失,模型性能优于两个父模型。

        该结论是用于实现联邦学习的重要支撑,在每一轮训练时,server发布全局模型,使各个client采用相同的参数模型进行训练,可以有效的减少训练集的损失。

3.2 数据集的设置

        初步研究包括两个数据集三个模型族,前两个模型用于识别MNIST数据集,后一个用于实现莎士比亚作品集单词预测。

3.2.1 MNIST数据集

2NN:拥有两个隐藏层,每层200个神经元的多层感知机模型,ReLu激活;

CNN:两个卷积核大小为5X5的卷积层(分别是32通道和64通道,每层后都有一个2X2的最大池化层);

IDD:数据随机打乱分给100个客户端,每个客户端600个样例;

Non-IDD:按数字标签将数据集划分为200个大小为300的碎片,每个客户端两个碎片;

  • 3.2.2 莎士比亚作品集

LSTM:将输入字符嵌入到一个已学习的8维空间中,然后通过两个LSTM层处理嵌入的字符,每层256个节点,最后,第二个LSTM层的输出被发送到每一个字符有一个节点的softmax输出层,使用unroll的80个字符长度进行训练;

Unbalanced-Non-IID:每个角色形成一个客户端,共1146个客户端;

Balanced-IID:直接将数据集划分给1146个客户端;

3.3 实验优化

        在数据中心存储的优化中,通信开销相对较小,计算开销占主导地位。而在联邦优化中,任何一个单一设备所具有的数据量较少,且现代移动设备有相对快的处理器所以这里更关注通信开销因此,我们想要使用额外的计算来减少训练模型所需通信的轮次主要有两个方法,分别是提高并行度以及增加每个客户端的计算量。

3.3.1 增加并行性

固定参数E,对C和B进行讨论。

  •  当B=∞时,增加客户端比例,效果提升的优势较小;
  • 当B=10时,有显著改善,特别是在Non-IID情况下;
  • 在B=10,当C≥0.1时,收敛速度有明显改进,当用户达到一定数量时,收敛增加的速度不再明显。

3.3.2 增加客户端计算量

对于增加每个客户端的计算量,可以通过减小B或者增加E来实现。

  • 每轮增加更多的本地SGD更新可以显著降低通信成本;
  • 对于Unbalanced-Non-IDD的莎士比亚数据减少通信轮数倍数更多,推测可能某些客户端有相对较大的本地数据集,使得增加本地训练更有价值;

 将上述实验结果用折线图的形式展示,这里蓝色线表示的是联邦随机梯度下降的结果:

  • FedAvg相比FedSGD不仅降低通信轮数,还具有更高的测试精度。推测是平均模型产生了类似Dropout的正则化效益; 

 3.4 探究客户端数据集的过度优化

        在E=5以及E=25的设置下,对于大的本地更新次数而言,联邦平均的训练损失会停滞或发散;因此在实际应用时,对于一些模型,在训练后期减少本地训练周期将有助于收敛。 

3.5 CIFAR实验

在CTFAR数据集上进行实验,模型是TensorFlow教程中的模型包括两个卷积层,两个全连接层和一个线性传输层,大约10^6个参数。下表给出了baselineSGD、FedSGD和FedAvg达到三种不同精度目标的通信轮数。

不同学习率下FedSGD和FedAvg的曲线:

3.6 大规模LSTM实验

 为了证明我们的方法对于解决实际问题的有效性,我们进行了一项大规模单词预测任务。

训练集包含来自大型社交网络的100万个公共帖子。我们根据作者对帖子进行分组,总共有超过50个客户端。我们将每个客户的数据集限制为最多5000个单词。模型是一个256节点的LSTM,其词汇量为10000个单词。每个单词的输入和输出嵌入为192维,并与模型共同训练;总共有4950544个参数,使用10个字符的unroll。

对于联邦平均和联邦随机梯度下降的最佳学习率曲线:

  • 相同准确率的情况下,FedAvg的通信轮数更少;测试精度方差更小;
  • E=1比E=5的表现效果更好; 

4 总结展望

         我们的实验表明,联邦学习可以在实践中实现,因为它可以使用相对较少的几轮通信来训练高质量的模型,这一点在各种模型体系结构上得到了证明:一个多层感知器、两个不同的卷积NNs、一个两层LSTM和一个大规模LSTM。虽然联邦学习提供了许多实用的隐私保护,但是通过差分隐私、安全多方计算提供了可以提供更有力的保障,或者他们的组合是未来工作的一个有趣方向。请注意,这两类技术最自然地应用于像FedAvg这样的同步算法。

参考文章:

https://blog.csdn.net/qq_41605740/article/details/124584939?spm=1001.2014.3001.5506

https://blog.csdn.net/weixin_45662974/article/details/119464191?spm=1001.2014.3001.5506 

https://zhuanlan.zhihu.com/p/515756280 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/949181.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【HCIP】18.防火墙

区域隔离,以防火墙的接口为中心定义区域,在防火墙中不同区域互访使用策略来进行控制 NGFW,下一代防火墙,除了是否对他通过进行判断,也可以对安全进行判断(例如是否是病毒,DDOS攻击)…

常见的下载方式

一. 使用 window.open() 使用场景 // 1. 先封装一个实习下载的函数 export const download (path) > {window.open(下载的接口,例如:/fs/download?path path) } // 2. 使用:在需要下载的地方调用download函数,传入下载的u…

Data Rescue Professional for Mac:专业的数据恢复工具

在数字化时代,我们的生活和工作离不开电脑和存储设备。但是,意外情况时常发生,例如误删除文件、格式化硬盘、病毒攻击等,这些都可能导致重要的数据丢失。面对数据丢失,我们迫切需要一款可靠的数据恢复工具。今天&#…

ASEMI整流桥GBU816的原理和应用

编辑-Z 摘要:整流桥GBU816是一种用于将交流电转换为直流电的电子元器件。本文将从原理、结构、应用以及优点等四个方面对整流桥GBU816进行详细的阐述。 1、整流桥GBU816的原理 整流桥GBU816由四个二极管组成,分别连接在一个桥形电路中。当输入交流电通…

TikTok选品分析:越南7月家电销量第一,这款吸尘器凭什么?

随着经济发展,人们的生活向智能化、便捷化发展,消费者的消费喜好也随之产生变化。家电也不例外,传统吸尘器因其体积较大、清洁不便正逐渐被淘汰。取而代之的是手持吸尘器,其凭借轻便、多功能的特点迅速赢得消费者的喜爱。 过去一…

腾讯云国际代充-GPU服务器安装驱动教程NVIDIA Tesla

腾讯云国际站GPU 云服务器是基于 GPU 的快速、稳定、弹性的计算服务,主要应用于深度学习训练/推理、图形图像处理以及科学计算等场景。 GPU 云服务器提供和标准腾讯云国际 CVM 云服务器一致的方便快捷的管理方式。 GPU 云服务器通过其强大的快速处理海量数据的计算性…

【Python】利用python-docx生成word版本学生花名册

如图,可以用python创建word文档,生成一个学生的花名册。生成的过程:先下载第三方依赖包,安装依赖包,然后引入依赖文件,创建docx文件,添加标题,创建表头,创建表格正文&…

创作纪念日-我的第1024天

机缘 不知不觉已经成为创作者的第1024天啦… … 刚开始接触博客的初衷就是为了记笔记📒、记总结📝,或许对于当时就等同于是为了找工作。坚持学习并持续输出博客一年后,这时我发现再写博客,不在是为了找一份工作&…

比亚迪宋L高调亮相成都车展,媒介盒子多家媒体助阵

哈喽,大家好,今天媒介盒子小编又来跟大家分享媒体推广的干货知识了,本篇分享的主要内容是:比亚迪宋L的营销策略。 比亚迪宋L又于2023年8月25日在成都车展上首次亮相,该车将配备比亚迪黑科技中的CTB技术、云辇-C底盘系统和iTAC系统等,预计将在今年第四季…

python教程:如何写类?

目录标题 前言类的定义知识点扩展:构建和初始化1. __ new__(cls,[…)2. __ init__(self,[…)3. __ del__(self) 尾语 前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! python更多源码/资料/解答/教程等 点击此处跳转文末名片免费获取 类的定义 Python中&#…

正中优配:A股早盘三大股指微涨 华为概念表现活跃

周三(8月30日),到上午收盘,三大股指团体收涨。其间上证指数涨0.06%,报3137.72点;深证成指和创业板指别离涨0.33%、0.12%;沪深两市合计成交额6423.91亿元,总体来看,两市个…

一文搞懂CAN和CAN FD总线协议

一、CAN与CAN FD的概念 1、CAN是什么 控制器局域网总线(CAN,Controller Area Network)是一种用于实时应用的串行通讯协议总线,它可以使用双绞线来传输信号,是世界上应用最广泛的现场总线之一。 CAN协议用于汽车中各种…

正中优配:创业板怎么开通

作为我国资本市场的一个重要组成部分,股票市场一直是出资者追逐高收益的抢手挑选。而近年来,创业板作为我国股票市场上的一颗新星,备受创业者、出资者的关注。但关于一些新手出资者来说,可能对“创业板怎样注册”这个问题还比较陌…

一文看懂开发者需要了解的信创概念

信创这个概念对于大家来说并不陌生,至少我们在海量的新闻中会时不时的听到这个概念,特别是在西方国家对中国进行技术封锁加剧时,证券市场中它还会时不时成为一个风口板块。 其实“信创”理解起来也并不困难,就像它的字面意思&…

手把手教你Jenkins整合Jmeter实现自动化接口测试

01、在机器上安装jmeter 下载:http://jmeter.apache.org/download_jmeter.cgi 这里我用了一台Windows安装jmeter用来写接口测试的脚本,启动前修改jmeter.properties 中 jmeter.save.saveservice.output_format值为xml。 编写接口测试脚本: …

外贸软件鞋类行业管理难点及解决方案

鞋子作为一种常见的商品,在出口外贸中占据着重要的地位。近几年,随着我国经济的建设步伐的不断加快,对外贸易活跃度也随之得以提升,中产阶层的消费人群及需求量都在不断增长,其中鞋业也经历了急剧的发展,成…

汽车自适应巡航系统控制策略研究

目 录 第一章 绪论 .............................................................................................................................. 1 1.1 研究背景及意义 ..........................................................................................…

C# 如何将使用的Dll嵌入到.exe应用程序中?

文章目录 前言详细实操简要步骤 前言 有没有想自己开发的exe保留一点神秘,不想让他人知道软件使用了哪些dll; 又或许是客户觉得一个软件里面的dll文件太多了,能不能简单一点,直接双击.exe就可以直接运行了,别搞那么多乱七八糟的。…

Android开发仿美团购物左右联动列表

概述 Android开发左右联动列表,仿照美团外卖点餐时,左右列表可以联动。 详细 Android开发仿美团购物左右联动列表 概述 左右联动列表是仿照美团外卖点餐时,左右列表可以联动。比如右边列表会有小项对应左边的,滑动时会置顶&a…

数字IC验证高频面试问题整理—附答案(三)

最近大家无不在讨论IC秋招,秋招想必缺的就是面试题目了。这不就来了~ 共150道验证高频面试题整理~含答案(文末可领取全部题目) Q1.二进制码、格雷码、独热码的特点 二进制码:基本的机器语言,每一位只能是0或1&…