详解FedAvg:联邦学习的开山之作

news2024/11/28 16:46:36

FedAvg:2017年 开山之作

论文地址:https://proceedings.mlr.press/v54/mcmahan17a/mcmahan17a.pdf
源码地址:https://github.com/shaoxiongji/federated-learning
针对的问题:移动设备中有大量的数据,但显然我们不能收集这些数据到云端以进行集中训练,所以引入了一种分布式的机器学习方法,即联邦学习Federal Learning。在FL中,server将全局模型下放给各client,client利用本地的数据去训练模型,并将训练后的权重上传到server,从而实现全局模型的更新。
论文贡献

  1. 提出了联邦学习这个研究方向,简单来说就是从分散的存储于各设备的数据中训练模型;
  2. 提出了FedAvg算法;
  3. 通过实验验证了FedAvg的可靠性;

总结一下就是,本文提出了FedAvg算法,这种算法融合了client上的局部随机梯度下降和server上的模型平均。作者用该算法做了不少实验,结果表明FedAvg对于unbalanced且non-iid的数据有很好的鲁棒性,并且使得在非数据中心存储的数据上进行深度网络训练所需的通信轮次减少了好几个数量级。
算法介绍

  1. 联邦随机梯度下降算法FedSGD

设定固定的学习率η,对K个client的数据计算损失梯度:
g k = ▽ F k ( w t ) g_k=\bigtriangledown F_k(w_t) gk=Fk(wt)
server将聚合每个服务器计算的梯度,以此来更新模型参数:
w t + 1 ← w t − η ∑ k = 1 K n k n g k = w t − η ▽ f ( w t ) w_{t+1}\leftarrow w_t-\eta\sum\limits_{k=1}^K\frac{n_k}{n}g_k=w_t-\eta\bigtriangledown f(w_t) wt+1wtηk=1Knnkgk=wtηf(wt)

  1. 联邦平均算法FedAvg:

在client进行局部模型的更新:
w t + 1 k ← w t − η g k w_{t+1}^k\leftarrow w_t-\eta g_k wt+1kwtηgk
server对每个client更新后的权重进行加权平均:
w t + 1 ← ∑ k = 1 K n k n w t + 1 k w_{t+1}\leftarrow \sum_{k=1}^K \frac{n_k}{n}w_{t+1}^k wt+1k=1Knnkwt+1k
注意,在这里每个client可以在本地独立地多次更新本地权重,然后将更好的权重参数发给server进行加权平均。这样做的好处是不用每更新一次就去聚合,这大大减少了通信量。
FedAvg的计算量与3个参数有关:

  • C:每轮训练选择client的比例,每一轮通信时只选择C*K个client;(K为client总数)
  • E:每个client更新本地权重时,在本地数据集上训练E轮;
  • B:client更新权重时,每次梯度下降所使用的数据量,即本地数据集的batch size;

对于一个拥有 n k n_k nk个数据样本的client,每轮通信本地参数的更新次数为:
u k = E × n k B u_k=E\times\frac{n_k}{B} uk=E×Bnk
所以我们可知,FedSGD只是FedAvg的一个特例,即当参数 E = 1 , B = ∞ E=1,B=\infty E=1B=时,FedAvg等价于FedSGD。注: B = ∞ B=\infty B=意味着batch size大小就是本地数据集大小。
下面为FedAvg的算法流程图:
FedAvg算法流程图
实验设计与实现
Q1:在训练伊始,需不需要对模型进行统一初始化?
image.png
可见,采用不同的初始化参数进行模型平均,模型性能比两个父模型都差(左图);而统一初始化后,对模型的平均可以显著减少整个训练集的loss,模型性能优于两个父模型(右图)。
该结论是实现FL的重要支持,在每一轮通信时,server有必要发布全局模型,使各client采用相同的参数在本地数据集上进行训练,可以有效减少loss。
Q2:数据集怎么设置?
原文中主要研究了MNIST数据集和一个莎士比亚作品集构建的数据集,但我们在这里主要关注MNIST数据集和Cifar-10数据集,这两个数据集也是以后FL领域工作最常用的。
在模型选择方面,作者选择了多层感知机MLP和卷积神经网络CNN。
在数据集划分方面,作者假设有100个client,对于MNIST数据集,进行了iid和non-iid两种划分:

  • MNIST-iid:数据随机打乱分给100个client,每个client得到600个样例;
  • MNIST-non-iid:按数字label将数据集划分为200个大小为300的碎片,每个client两个碎片,意味着每个client至多只能获得两种label的样例;

对于Cifar-10数据集,做了iid划分。
Q3:实验咋做的?
作者指出,相比于传统模式下训练模型时计算开销为主通信开销较小的情况,在FL中,通信开销才是大头,因此减少通信开销才是我们需要关注的,作者提出可以通过加大计算以减少训练模型所需的通信轮数。作者提出主要有两种方法:提高并行度、增加每个client的计算量
而FedAvg的计算量在前面我们也给出过,再来看一下:
u k = E × n k B u_k=E\times\frac{n_k}{B} uk=E×Bnk
提高并行度:固定参数E,对C和B进行讨论。注:此处C=0时,算法也会选择一个client参与,详见上面的算法流程图。
2NN测试集acc 97%,CNN测试集acc 99%所需的通信轮数

  • B = ∞ B=\infty B=时,增加client的比例C,效果提升的优势较小;
  • B = 10 B=10 B=10时,效果显著改善了,特别是在non-iid情况下;
  • B = 10 , C ≥ 10 B=10,C\geq10 B=10,C10时,收敛速度明显改进,当client到一定数量后,收敛速度增加也不明显了;

增加每个client的计算量:根据公式,可以通过增加E或者减小B实现。
对测试集到达期望acc所需的通信轮数

  • 每个通信轮次内增加更多的本地SGD可以显著降低通信成本;
  • 对于unbalanced-non-iid的莎士比亚数据集减少的通信轮数更多,推测可能某些client有相对较大的本地数据集,这种情况下增加了本地训练的价值;

Q4:FedAvg VS FedSGD?
image.png
蓝色实现即为FedSGD。由图可知,FedAvg相比FedSGD不仅降低通信轮数,还具有更高的测试精度。推测是平均模型产生了类似Dropout的正则化效益。
Q5:加大每个client的计算量会不会导致过拟合?
image.png
加大每个client的计算量(主要体现在加大E),确实可能导致训练损失停滞或发散。所以在实际应用时,在训练后期减少各client的E,或者在loss有震荡的苗头时即刻停止,这样做有助于收敛。
Q6:在Cifar-10数据集上的表现如何?
如下图所示:
image.png
image.png
针对第一张图的一点吐槽,你去拿分布式深度学习去pk单机上的深度学习,去比通信轮数,这不是太不公平了。。。
总结展望
作者证明了FL在实践中是可行的,能够用相对较少的通信轮数训练出高质量的模型。并且提出未来的一个方向就是通过差分隐私、安全多方技术等隐私保护技术去组合FL以提供隐私保护。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1802490.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

电调, GPS与飞塔

电调油门行程校准: 断电-----油门推到最高-------电调上电-------滴滴------油门推到最低---滴滴滴---校准完成。 http://【【教程】油门行程校准(航模,电机,电调)】https://www.bilibili.com/video/BV1yJ411J7aX?v…

区间预测 | Matlab实现QRCNN-BiGRU-Attention分位数回归卷积双向门控循环单元注意力机制时序区间预测

区间预测 | Matlab实现QRCNN-BiGRU-Attention分位数回归卷积双向门控循环单元注意力机制时序区间预测 目录 区间预测 | Matlab实现QRCNN-BiGRU-Attention分位数回归卷积双向门控循环单元注意力机制时序区间预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实…

企业办公网安全管控挑战与解决方案

在数字化浪潮的推动下,企业正经历前所未有的变革。然而,随之而来的是一系列复杂的网络安全风险和挑战。我们的网络边界不再清晰,各种设备轻松接入企业网络,这不仅带来了便利,也极大地增加了安全风险。想象一下&#xf…

[AI Google] 双子座模型家族迎来新突破:更快的模型、更长的上下文、AI代理等更多功能

Google发布了Gemini模型家族的更新,包括新的1.5 Flash模型,该模型旨在提高速度和效率,以及Project Astra,这是对未来AI助手愿景的展示。1.5 Flash是专为大规模高频任务优化的轻量级模型,具有突破性的长上下文窗口。同时…

opencv 在飞行堡垒8中调用camera导致设备消失

简介 使用 OpenCV 库时, 在最后调用cv::destroyAllWindows()之后设备管理器中的摄像头设备消失了, 看看是怎么触发的, 后面再慢慢研究RootCause是什么。 步骤 设备管理器原来摄像头显示 1. 代码 main.cpp Note: 1. haarcascade_frontalface_default…

什么是助听器呢?

助听器是一种用于改善听力障碍患者听觉能力的装置。它通过放大声音,使原本听不到或听不清的声音能够被听力受损者感知,从而提高其交流能力和生活质量。 助听器的基本工作原理是,将声音转化为电信号,经过内部电路处理后&#xff0c…

算法006:查找总价格为目标值的两个商品

. - 力扣(LeetCode). - 备战技术面试?力扣提供海量技术面试资源,帮助你高效提升编程技能,轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/he-wei-sde-liang-ge-shu-zi-lcof/ 题干说的很复杂,简化一…

嵌入式Linux系统编程 — 3.2 stat、fstat 和 lstat 函数查看文件属性

目录 1 文件有哪些属性 2 stat函数 2.1 stat函数简介 2.2 struct stat 结构体 2.3 struct timespec 结构体 2.4 示例程序 3 fstat 和 lstat 函数 3.1 fstat 函数 3.2 lstat 函数 1 文件有哪些属性 Linux文件属性是对文件和目录的元数据描述,包括文件类型…

浅谈安全用电管理系统对重要用户的安全管理

1用电安全管理的重要性   随着社会经济的不断发展,电网建设力度的不断加大,供电的可靠性和供电质量日益提高,电网结构也在不断完善。但在电网具备供电的条件下,部分高危和重要电力用户未按规定实现双回路电源线路供电&#xff1…

问题:设备管理指标为完好率不低于( ),待修率不高于5%,事故率不高于1%。 #知识分享#经验分享#经验分享

问题:设备管理指标为完好率不低于( ),待修率不高于5%,事故率不高于1%。 A、100% B、95% C、90% D、80% 参考答案如图所示

自动驾驶---Control之LQR控制

1 前言 在前面的系列博客文章中为读者阐述了很多规划相关的知识(可参考下面专栏),本篇博客带领读者朋友们了解控制相关的知识,后续仍会撰写规控相关文档。 在控制理论的发展过程中,人们逐渐认识到对于线性动态系统的控…

vue数组在浏览器里可以看到值, 但是length为空

arr数组 length为0, 检查了代码在created 里调用了 this.getEnergyList(); 和 this.initChart(); 问题就在这里, this.initChart用到了getEnergyList里的数据, 造成了数据异步, 把this.initChart(); 放入 this.getEnergyList(); 方法里即可解决问题

Elasticsearch 认证模拟题 - 13

一、题目 集群中有索引 task3,用 oa、OA、Oa、oA 查询结构是 4 条,使用 dingding 的查询结果是 1 条。通过 reindex 索引 task3 为 task3_new,能够使 task3_new 满足以下查询条件。 使用 oa、OA、Oa、oA、0A、dingding 查询都能够返回 6 条…

【计算机视觉】数字图像处理基础:以像素为单位的图像基本运算(点运算、代数运算、逻辑运算、几何运算、插值)

0、前言 在上篇文章中,我们对什么是数字图像、以及数字图像的组成(离散的像素点)进行了讲解🔗【计算机视觉】数字图像处理基础知识:模拟和数字图像、采样量化、像素的基本关系、灰度直方图、图像的分类。 我们知道&a…

Ruoyi-Vue-Plus 下载启动后菜单无法点击展开,

1.Ruoyi-Vue-Plus框架下载后运行 2.使用mock数据 3.进入页面后无法点击菜单 本以为是动态路由或者菜单逻辑出了问题,最后发现是websocket的问题 解决办法 把这两行代码注释 页面菜单即可点击。 以上。

蓝屏绿屏黑屏?别急,有它们仨【送源码】

使用Windows系统的电脑时,可能会碰到各种问题,导致系统无法正常使用。 这些问题都有一个统一的专业叫法就是bug! 系统一旦出现bug,最明显的特点就是, ①电脑蓝屏 电脑蓝屏是最经典的,从XP时代一直延续到…

数据结构之快速排序算法(快排)【图文详解】

P. S.:以下代码均在VS2019环境下测试,不代表所有编译器均可通过。 P. S.:测试代码均未展示头文件stdio.h的声明,使用时请自行添加。 博主主页:LiUEEEEE                        …

Cesium开发环境搭建(一)

1.下载安装Node.js 进入官网地址下载安装包 Node.js — Download Node.js https://cdn.npmmirror.com/binaries/node/ 选择对应你系统的Node.js版本,这里我选择的是Windows系统、64位 安装完成后,WINR,输入node --version,显示…

全网最强下载神器IDM之如何用IDM下载百度网盘文件不限速 如何用IDM下载百度云资源 IDM激活码免费版下载安装

百度网盘是比较早的网盘类应用,用户群体比较多,但百度网盘对于非会员用户限速比较严重。IDM是非常好用的下载工具,那么我们如何用IDM下载百度网盘文件不限速?我们可以通过多种方法使用IDM下载百度网盘文件。下面我们就来看如何用I…

Windows11系统 和Android 调试桥(Android Debug Bridge,ADB)工具安装,app抓取日志内容

文章目录 目录 文章目录 安装流程 小结 概要安装流程技术细节小结 概要 Android调试桥(ADB)是一种多功能命令行工具,它允许开发者与连接到计算机上的Android设备进行通信和控制。ADB工具的作用包括但不限于: 安装和卸载应用程序&…