联邦学习 (FL) 中常见的3种模型聚合方法的 Tensorflow 示例

news2025/1/17 14:03:56

联合学习 (FL) 是一种出色的 ML 方法,它使多个设备(例如物联网 (IoT) 设备)或计算机能够在模型训练完成时进行协作,而无需共享它们的数据。

“客户端”是 FL 中使用的计算机和设备,它们可以彼此完全分离并且拥有各自不同的数据,这些数据可以应用同不隐私策略,并由不同的组织拥有,并且彼此不能相互访问。

使用 FL,模型可以在没有数据的情况下从更广泛的数据源中学习。FL 的广泛使用的领域如下:

  • 卫生保健
  • 物联网 (IoT)
  • 移动设备

由于数据隐私对于许多应用程序(例如医疗数据)来说是一个大问题,因此 FL 主要用于保护客户的隐私而不与任何其他客户或方共享他们的数据。FL的客户端与中央服务器共享他们的模型更新以聚合更新后的全局模型。全局模型被发送回客户端,客户端可以使用它进行预测或对本地数据采取其他操作。

FL的关键概念

数据隐私:适用于敏感或隐私数据应用。

数据分布:训练分布在大量设备或服务器上;模型应该能够泛化到新的数据。

模型聚合:跨不同客户端更新的模型并且聚合生成单一的全局模型,模型的聚合方式如下:

  • 简单平均:对所有客户端进行平均
  • 加权平均:在平均每个模型之前,根据模型的质量,或其训练数据的数量进行加权。
  • 联邦平均:这在减少通信开销方面很有用,并有助于提高考虑模型更新和使用的本地数据差异的全局模型的收敛性。
  • 混合方法:结合上面多种模型聚合技术。

通信开销:客户端与服务器之间模型更新的传输,需要考虑通信协议和模型更新的频率。

收敛性:FL中的一个关键因素是模型收敛到一个关于数据的分布式性质的良好解决方案。

实现FL的简单步骤

  1. 定义模型体系结构
  2. 将数据划分为客户端数据集
  3. 在客户端数据集上训练模型
  4. 更新全局模型
  5. 重复上面的学习过程

Tensorflow代码示例

首先我们先建立一个简单的服务端:

 importtensorflowastf
 
 # Set up a server and some client devices
 server=tf.keras.server.Server()
 devices= [tf.keras.server.ClientDevice(worker_id=i) foriinrange(4)]
 
 # Define a simple model and compile it
 inputs=tf.keras.Input(shape=(10,))
 outputs=tf.keras.layers.Dense(2, activation='softmax')(inputs)
 model=tf.keras.Model(inputs=inputs, outputs=outputs)
 model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
 
 # Define a federated dataset and iterate over it
 federated_dataset=tf.keras.experimental.get_federated_dataset(devices, model, x=X, y=y)
 forx, yinfederated_dataset:
     # Train the model on the client data
     model.fit(x, y)

然后我们实现模型聚合步骤:

1、简单平均

 # Average the updated model weights
 model_weights=model.get_weights()
 fordeviceindevices:
     device_weights=device.get_weights()
     fori, (model_weight, device_weight) inenumerate(zip(model_weights, device_weights)):
         model_weights[i] = (model_weight+device_weight) /len(devices)
 
 # Update the model with the averaged weights
 model.set_weights(model_weights)

2、加权平均

 # Average the updated model weights using weights based on the quality of the model or the amount of data used to train it
     model_weights=model.get_weights()
     total_weight=0
     fordeviceindevices:
         device_weights=device.get_weights()
         weight=compute_weight(device)  # Replace this with a function that returns the weight for the device
         total_weight+=weight
         fori, (model_weight, device_weight) inenumerate(zip(model_weights, device_weights)):
             model_weights[i] =model_weight+ (device_weight-model_weight) * (weight/total_weight)
 
 # Update the model with the averaged weights    
 model.set_weights(model_weights)

3、联邦平均

 # Use federated averaging to aggregate the updated models
 model_weights=model.get_weights()
 client_weights= []
 fordeviceindevices:
     client_weights.append(device.get_weights())
 server_weights=model_weights
 for_inrange(num_rounds):
     fori, deviceinenumerate(devices):
         device.set_weights(server_weights)
         model.fit(x[i], y[i])
         client_weights[i] =model.get_weights()
     server_weights=server.federated_average(client_weights)
 
 # Update the model with the averaged weights
 model.set_weights(server_weights)

以上就是联邦学习中最基本的3个模型聚合方法,希望对你有所帮助

https://avoid.overfit.cn/post/d426b291716c48409d3b68704545f6d0

作者:Dr Roushanak Rahmat, PhD

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/155736.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于Java springmvc+mybatis酒店信息管理系统设计和实现

基于Java springmvcmybatis酒店信息管理系统设计和实现 博主介绍:5年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 超级帅帅吴 Java毕设项目精品实战案例《500套》 欢迎点赞 收藏 ⭐留言 文末获取…

程序员接私活的几个平台和建议,避免掉坑!

大家对于程序员接私活这件事的看法,褒贬不一。但是你如果确实用钱,价格又合适,那就大胆去接。 如果不那么缺钱,那么接私活之前先考虑清楚,如果自己将空余时间用在接私活所产生的价值是不是大于提升自己。如果是的话&a…

2022年 大学生工程训练比赛[物料搬运]

本人和团结参加了2022年大学生工程训练(简称工训赛)校赛选拔,准备了几个月的时间和花费了较多的资金,由于疫情等多种情况,很遗憾未能参加湖南省省赛,过了这么久还是写个博客记录参赛准备和调试过程。 目录 一、比赛要求 二、整体…

第十章面向对象编程(高级部分)

10.1 类变量和类方法(关键字static) 10.1.31类变量快速入门 思考: 如果,设计一个 int count 表示总人数,我们在创建一个小孩时,就把 count 加 1,并且 count 是所有对象共享的就 ok 了! package com.hspedu.static_;public class ChildGame {…

MS【1】:Metric

文章目录前言1. Dice Loss1.1. Dice coefficient1.2. F1 score - Dice1.3. Dice Loss2. Sensitivity & Specificity2.1. Sensitivity2.2. Specificity3. Hausdorff distance3.1. 概念3.2. 单向 Hausdorff distance3.3. 双向 Hausdorff distance3.4. 部分 Hausdorff distanc…

使用ResNet18实现CIFAR100数据集的训练

如果对你有用的话,希望能够点赞支持一下,这样我就能有更多的动力更新更多的学习笔记了。😄😄 使用ResNet进行CIFAR-10数据集进行测试,这里使用的是将CIFAR-10数据集的分辨率扩大到32X32,因为算力相关的…

二、数据仓库模型设计

数据仓库模型设计一、数据模型二、关系模型三、维度模型1、事实表(1)事务事实表(2)周期快照事实表(3)累计快照事实表(4)无事实的事实表2、维度表3、维度模型类型(1&#…

LVGL学习笔记16 - 进度条Bar

目录 1. Parts 2. 模式 2.1 LV_BAR_MODE_SYMMETRICAL:对称模式 2.2 LV_BAR_MODE_RANGE:范围模式 3. 动画 4. 样式 4.1 方向 4.2 渐变色 4.3 增加边框 4.4 滚动条方向 进度条有一个背景和一个指示器组成,通过lv_bar_create创建对象。…

mysql多表查询

一、关联查询(联合查询) 1.1 什么是关联查询 关联查询:两个或者多个表,一起查询。 前提条件: 这些一起查询的表之间是有关系的(一对一、一对多),它们之间一定是有关联字段&#x…

初识IL2CPP

在Unity中进行打包时,有两种打包方式选择:Mono和IL2CPP Mono和IL2Cpp是Unity的脚本后处理方式,通过脚本后处理实现Unity的跨平台 1.Mono (1). Mono组成组件: C#编辑器,CLI虚拟机,以及核心类别程序库 (2).跨平台过程 Mo…

【Linux】多线程概念

目录🌈前言🌸1、Linux线程概念🍡1.1、概念🍢1.2、线程的优点🍧1.3、线程的缺点🍨1.4、线程的异常和用途🌺2、Linux下进程 vs 线程🌈前言 这篇文章给大家带来线程的学习!…

PID算法入门(一)

1.简介 PID是Proportional(比例), Integral(积分), Differential(微分)的首字母缩写,他是一种结合比例,积分,微分三个环节于一体的闭环控制算法. 2.PID各环节 2.1比例环节 成比例地反应控制系统的偏差信号,即输出&a…

Codeforces Round #843 (Div. 2) A1 —— D

题目地址:Dashboard - Codeforces Round #843 (Div. 2) - Codeforces一个不知名大学生,江湖人称菜狗 original author: jacky Li Email : 3435673055qq.com Time of completion:2023.1.11 Last edited: 2023.1.11 目录 ​编辑 A1. Gardener…

读论文——day61 目标检测模型的决策依据与可信度分析

目标检测模型的决策依据与可信度分析本文贡献及原文1 相关工作(略看)1.3 目标检测模型2 背景知识(LIME)2.2 LIME3 目标检测决策依据及可信度分析3.1 决策依据3.2 对目标检测模型的预测进行可信度评价4 基于 LIME 的目标检测模型解…

(第四章)OpenGL超级宝典学习:必要的数学知识

必要的数学知识 前言 在本章当中,作者着重介绍了几个和3D图形学重要的数学知识,线性代数基础好的同学可以直接绕过本章,说实话这篇博客写到这里,我是非常犹豫的,本章节的内容可以说是很基础,但是相当…

SSM框架01_Spring

有一个效应叫知识诅咒:自己一旦知道了某事,就无法想象这件事在未知者眼中的样子。00-Spring课程介绍01-初识Spring今天所学的Spring其实是Spring家族中的Spring Framework;Spring Fra是Spring家族中其他框架的底层基础,学好Spring可以为其他S…

Morse1题解

原理摩尔斯电码和电报简单说一下电报和摩尔斯电码的原理最简单的电报模型就是一个电源,一个开关和一个电磁铁当需要长距离使用时候,需要用到继电器按下开关,电磁铁会吸引磁铁长按开关,电磁铁就会闭合一段时间,留下一划…

Jenkins集成GitLab Webhooks自动化构建

JenkinsGitLab Webhooks自动构建项目1 构建步骤1.1 Jenkins中设置构建触发器1.2 Build Authorization Token Root插件安装1.3 GitLab配置Webhooks2 测试webhooks2.1 测试推送事件2.2 测试合并请求事件2.3 代码修改提交测试1 构建步骤 1.1 Jenkins中设置构建触发器 这里先随便写…

Markdown与DITA比较

Markdown是一种轻量级标记语言,创始人为John Gruber。它允许人们使用易读易写的纯文本格式编写文档,然后转换成有效的HTML文档。这种语言吸收了很多在电子邮件中已有的纯文本标记的特性。由于Markdown的轻量化、易读易写特性,并且对于图片&am…

第一章Mybatis基础操作学习

文章目录MyBatis简介MyBatis历史MyBatis特性和其它持久化层技术对比搭建MyBatis开发环境创建maven工程创建MyBatis的核心配置文件创建mapper接口创建MyBatis的映射文件通过junit测试功能加入log4j日志功能不带参数的增删改查Mapper接口的编写对应Mapper接口的xml文件编写核心配…