使用PyTorch和Flower 进行联邦学习

news2024/12/26 8:04:47

本文将介绍如何使用 Flower 构建现有机器学习工作的联邦学习版本。我们将使用 PyTorch 在 CIFAR-10 数据集上训练卷积神经网络,然后将展示如何修改训练代码以联邦的方式运行训练。

什么是联邦学习?

我们将在这篇文章中区分两种主要方法:集中式和联邦式(本文的图例表示如下)

集中式

每个设备都会将其数据发送到全局服务器,然后服务器将使用它来训练全局模型。训练完成后服务器将经过训练的全局模型发送到设备。

这并不是我们所说的联邦学习的解决方案,传输了数据,会带来很多问题

联邦式

每个设备都不会与服务器共享数据,而是将数据保存在本地并用它来训练模型。模型的权重会被发送到全局服务器,然后全局服务器会将收到的所有权重聚合到一个全局模型中,服务器最终将经过训练的全局模型发送到设备。这种方式是一般形式的联邦学习,它的主要优点是保护用户的隐私,避免数据泄露。

我们先完成集中式训练代码,因为该训练模式基本上与传统的PyTorch 训练相同,然后再将其改为联邦学习的方式。

集中式 PyTorch 训练

让我们创建一个名为 cifar.py 的新文件,其中包含在 CIFAR-10 上进行传统(集中式)训练所需的所有组件。首先,需要导入所有的包(例如 torch 和 torchvision)。我们现在没有导入任何用于联邦学习的包。可以稍后再进行导入。

 fromtypingimportTuple, Dict
 importtorch
 importtorch.nnasnn
 importtorch.nn.functionalasF
 importtorchvision
 importtorchvision.transformsastransforms
 fromtorchimportTensor
 fromtorchvision.datasetsimportCIFAR10

模型架构(一个非常简单的卷积神经网络)在 Net() 类中定义。

 classNet(nn.Module):
     def__init__(self) ->None:
         super(Net, self).__init__()
         self.conv1=nn.Conv2d(3, 6, 5)
         self.pool=nn.MaxPool2d(2, 2)
         self.conv2=nn.Conv2d(6, 16, 5)
         self.fc1=nn.Linear(16*5*5, 120)
         self.fc2=nn.Linear(120, 84)
         self.fc3=nn.Linear(84, 10)
     defforward(self, x: Tensor) ->Tensor:
         x=self.pool(F.relu(self.conv1(x)))
         x=self.pool(F.relu(self.conv2(x)))
         x=x.view(-1, 16*5*5)
         x=F.relu(self.fc1(x))
         x=F.relu(self.fc2(x))
         x=self.fc3(x)
         returnx

load_data() 函数加载 CIFAR-10 训练和测试集。转换在加载后规范化了数据。

 DATA_ROOT="~/data/cifar-10"
 
 defload_data() ->Tuple[
     torch.utils.data.DataLoader, 
     torch.utils.data.DataLoader, 
     Dict
 ]:
     """Load CIFAR-10 (training and test set)."""
     transform=transforms.Compose(
         [transforms.ToTensor(),
          transforms.Normalize(
               (0.5, 0.5, 0.5), 
               (0.5, 0.5, 0.5)
          )
         ]
     )
     trainset=CIFAR10(DATA_ROOT, 
                        train=True, 
                        download=True, 
                        transform=transform)
     trainloader=torch.utils.data.DataLoader(trainset,
                                               batch_size=32, 
                                               shuffle=True)
     testset=CIFAR10(DATA_ROOT, 
                       train=False, 
                       download=True, 
                       transform=transform)
     testloader=torch.utils.data.DataLoader(testset, 
                                              batch_size=32, 
                                              shuffle=False)
     num_examples= {"trainset" : len(trainset), "testset" : len(testset)}
     returntrainloader, testloader, num_examples

我们现在需要定义训练函数 train(),它循环遍历训练集、计算损失、反向传播,然后对每批训练执行一个优化步骤。

模型的评估在函数 test() 中定义。该函数遍历所有测试样本并根据测试数据集测量模型的损失。

 deftrain(
     net: Net,
     trainloader: torch.utils.data.DataLoader,
     epochs: int,
     device: torch.device,
 ) ->None:
     """Train the network."""
     # Define loss and optimizer
     criterion=nn.CrossEntropyLoss()
     optimizer=torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
     print(f"Training {epochs} epoch(s) w/ {len(trainloader)} batches each")
     # Train the network
     forepochinrange(epochs):  # loop over the dataset multiple times
         running_loss=0.0
         fori, datainenumerate(trainloader, 0):
             images, labels=data[0].to(device), data[1].to(device)
             # zero the parameter gradients
             optimizer.zero_grad()
             # forward + backward + optimize
             outputs=net(images)
             loss=criterion(outputs, labels)
             loss.backward()
             optimizer.step()
             # print statistics
             running_loss+=loss.item()
             ifi%100==99:  # print every 100 mini-batches
                 print("[%d, %5d] loss: %.3f"% (epoch+1, 
                                                 i+1, 
                                                 running_loss/2000))
                 running_loss=0.0
 deftest(
     net: Net,
     testloader: torch.utils.data.DataLoader,
     device: torch.device,
 ) ->Tuple[float, float]:
     """Validate the network on the entire test set."""
     criterion=nn.CrossEntropyLoss()
     correct=0
     total=0
     loss=0.0
     withtorch.no_grad():
         fordataintestloader:
             images, labels=data[0].to(device), data[1].to(device)
             outputs=net(images)
             loss+=criterion(outputs, labels).item()
             _, predicted=torch.max(outputs.data, 1)
             total+=labels.size(0)
             correct+= (predicted==labels).sum().item()
     accuracy=correct/total
     returnloss, accuracy

定义了数据加载、模型架构、训练和评估后,我们可以将所有内容放在一起并在 CIFAR-10 上训练我们的 CNN。

 defmain():
     DEVICE=torch.device("cuda:0"iftorch.cuda.is_available() else"cpu")
     print("Centralized PyTorch training")
     print("Load data")
     trainloader, testloader, _=load_data()
     print("Start training")
     net=Net().to(DEVICE)
     train(net=net, trainloader=trainloader, epochs=2, device=DEVICE)
     print("Evaluate model")
     loss, accuracy=test(net=net, testloader=testloader, device=DEVICE)
     print("Loss: ", loss)
     print("Accuracy: ", accuracy)
 
 if__name__=="__main__":
     main()

现在就可以直接运行了:

 python3 cifar.py

到目前为止,如果你以前使用过 PyTorch,这一切看起来应该相当熟悉。下面开始进入正题,我们开始构建一个简单的联邦学习系统,该系统由一个服务器和两个客户端组成。

PyTorch的联邦学习

我们已经在单个数据集 (CIFAR-10) 上训练了模型, 我们称之为集中学习。这种集中学习的概念是我们以前常用的方式。通常,如果你想以联邦学习的方式运行,则必须更改大部分代码并从头开始设置所有内容。但是,这里有一个包 Flower,它可以将预先存在的代码以联邦学习运行(当然需要少量的修改)。

既然是联邦学习,我们必须有服务器,然后 cifar.py 代码也需要连接到服务器的客户端。服务器向客户端发送模型参数。客户端运行训练并更新参数。更新后的参数被发送回服务器,服务器对所有接收到的参数更新进行平均,这就是联邦学习的一个简单的流程。

我们这个例子是由一台服务器和两个客户端组成。我们先设置server.py。服务端需要导入Flower包flwr,然后使用 start_server 函数启动服务器并告诉它执行三轮联邦学习。

 importflwrasfl
 
 if__name__=="__main__":
     fl.server.start_server(
         server_address="0.0.0.0:8080", 
         config=fl.server.ServerConfig(num_rounds=3)
      )

然后就可以启动服务器了:

 python3 server.py

我们还要在 client.py 中定义客户端逻辑,主要就是将之前在 cifar.py 中定义的集中训练的代码进行整合:

 fromcollectionsimportOrderedDict
 fromtypingimportDict, List, Tuple
 importnumpyasnp
 importtorch
 importcifar
 importflwrasfl
 
 DEVICE: str=torch.device("cuda:0"iftorch.cuda.is_available() else"cpu")

Flower 客户端需要实现 flwr.client.Client 或 flwr.client.NumPyClient 类。这里的实现将基于 flwr.client.NumPyClient,我们将其称为 CifarClient。因为我们使用了 NumPy ,而PyTorch 或 TensorFlow/Keras)都是直接是吃NumPy的互操作,所以使用NumPyClient 比 Client 更容易。

完成我们的CifarClient需要实现四个方法,两个获取/设置模型参数的方法,一个训练模型的方法,一个测试模型的方法:

1、set_parameters

这个方法有2个作用:

  • 在从服务器接收的本地模型上设置模型参数
  • 遍历作为 NumPy ndarray 接收的模型参数列表

2、get_parameters

获取模型参数并将它们作为 NumPy ndarray 的列表返回(这是 flwr.client.NumPyClient 所需要的)

3、fit

一看就知道,这是训练本地模型的方法,它有3个作用:

  • 使用从服务器接收到的参数更新本地模型的参数
  • 在本地训练集上训练模型
  • 训练本地模型,并将权重上传服务器

4、evaluate

验证模型的方法:

  • 从服务器接收到的参数更新本地模型的参数
  • 在本地测试集上评估更新后的模型
  • 将本地损失和准确率等指标返回给服务器

我们先前在 cifar.py 中定义的函数 train() 和 test()可以作为 fit 和 evaluate 使用。所以在这里真正要做的是通过我们的 NumPyClient 类告诉 Flower 已经定义的哪些函数,剩下的两个方法实现起来也不复杂:

 classCifarClient(fl.client.NumPyClient):
     """Flower client implementing CIFAR-10 image classification using
     PyTorch."""
     def__init__(
         self,
         model: cifar.Net,
         trainloader: torch.utils.data.DataLoader,
         testloader: torch.utils.data.DataLoader,
         num_examples: Dict,
     ) ->None:
         self.model=model
         self.trainloader=trainloader
         self.testloader=testloader
         self.num_examples=num_examples
 
     defget_parameters(self, config) ->List[np.ndarray]:
         # Return model parameters as a list of NumPy ndarrays
         return [val.cpu().numpy() for_, valinself.model.state_dict().items()]
     
     defset_parameters(self, parameters: List[np.ndarray]) ->None:
         # Set model parameters from a list of NumPy ndarrays
         params_dict=zip(self.model.state_dict().keys(), parameters)
         state_dict=OrderedDict({k: torch.tensor(v) fork, vinparams_dict})
         self.model.load_state_dict(state_dict, strict=True)
 
     deffit(
         self, parameters: List[np.ndarray], config: Dict[str, str]
     ) ->Tuple[List[np.ndarray], int, Dict]:
         # Set model parameters, train model, return updated model parameters
         self.set_parameters(parameters)
         cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
         returnself.get_parameters(config={}), self.num_examples["trainset"], {}
 
     defevaluate(
         self, parameters: List[np.ndarray], config: Dict[str, str]
     ) ->Tuple[float, int, Dict]:
         # Set model parameters, evaluate model on local test dataset, return result
         self.set_parameters(parameters)
         loss, accuracy=cifar.test(self.model, self.testloader, device=DEVICE)
         returnfloat(loss), self.num_examples["testset"], {"accuracy": float(accuracy)}

最后我们要定义一个函数来加载模型和数据,创建并启动这个CifarClient客户端。

 defmain() ->None:
     """Load data, start CifarClient."""
     # Load model and data
     model=cifar.Net()
     model.to(DEVICE)
     trainloader, testloader, num_examples=cifar.load_data()
     # Start client
     client=CifarClient(model, trainloader, testloader, num_examples)
     fl.client.start_numpy_client(server_address="0.0.0.0:8080", client)
 
 if__name__=="__main__":
     main()

这样就完成了。现在可以打开两个额外的终端窗口并运行(因为我们要演示2个客户端的联邦学习)

 python3 client.py

在每个窗口中(请确保前面的服务器正在运行)可以看到你的PyTorch 项目在两个客户端上进行训练了。

总结

本文介绍了如何使用Flower将我们原有pytorch代码改造为联邦学习的方式进行训练,完整的代码可以在这里找到:

https://avoid.overfit.cn/post/8d05a12c208c4f499573c9966d0fe415

作者:Charles Beauville

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/458721.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据库的概念?怎么在linux内安装数据库?怎么使用?

目录 一、概念 二、mysql安装及设置 1.安装mysql 2.数据库服务启动停止 三、数据库基本操作 1、数据库的登录及退出 2、数据表的操作 3、mysql查询操作 一、概念 数据库:是存放数据的仓库,它是一个按数据结构来存储和管理数据的计算机软件系统。数据库管理…

BM38-在二叉树中找到两个节点的最近公共祖先

题目 给定一棵二叉树(保证非空)以及这棵树上的两个节点对应的val值 o1 和 o2,请找到 o1 和 o2 的最近公共祖先节点。 数据范围:树上节点数满足 1≤n≤10^5 , 节点值val满足区间 [0,n) 要求:时间复杂度 O(n) 注:本题保证二叉树…

深入理解Javascript事件处理机制

深入理解javascript事件处理机制 前言 在开发web应用程序时,事件处理机制是javascript中至关重要的一部分。许多高级特性,如事件冒泡、事件捕获和事件委托,都是通过事件处理来实现的。熟练掌握这些技术可以帮助我们更好地组织代码、提高代码…

腾讯多媒体实验室画质增强技术的前沿应用

全真互联时代,音视频技术内核不断更新迭代,LiveVideoStackCon 2022 北京站邀请到腾讯多媒体实验室视频技术研发负责人——夏珍,与大家分享画质增强技术的一些前沿探索和应用研究,在经典影像中非常重要的画质提升技术人脸修复和去压…

告别web.xml映射Servlet、Filter、Listener,解锁注解新方式开发

编译软件:IntelliJ IDEA 2019.2.4 x64 操作系统:win10 x64 位 家庭版 服务器软件:apache-tomcat-8.5.27 目录 一. Servlet、Filter、Listener的注解方式是什么?二. 为什么要使用Servlet、Filter、Listener的注解方式?三…

【架构】互联网应用开发架构演进历程

文章目录 一、背景二、技术架构演进史三、架构演进一: 早期雏形四、架构演进二: 数据库开发(LAMP特长)五、架构演进三: javaweb的雏形六、架构演进四: javaweb的集群发展​七、架构演进五: javaweb的分布式发展八、架构演进六: javaweb的微服务发展​8.1…

开源 AI 辅助编程工具 AutoDev 现已上架 Jetbrains 插件市场

我们非常高兴地宣布 AutoDev v0.2.0 的发布!AutoDev 是一款强大的 AI 辅助编程工具,可以与 Jetbrains 系列 IDE 无缝集成(VS Code 支持正在开发中)。通过与需求管理系统(如 Github Issue 等)直接对接&#…

WPF教程(八)--数据绑定(1)--基础概述

使用WPF可以很方便的设计出强大的用户界面,同时 WPF提供了数据绑定功能。WPF的数据绑定跟Winform与ASP.NET中的数据绑定功能类似,但也有所不同,在 WPF中以通过后台代码绑定、前台XAML中进行绑定,或者两者组合的方式进行数据绑定。…

用python制作剪刀石头布的小游戏

1 问题 在python中我们学习了条件语句,那么我们是否可以通过python中条件判断的功能来写出可以判断胜负的剪刀石头布小游戏呢? 2 方法 导入随机函数,保证胜负的随机性 设置对应数值,写好判断输赢的条件语句 运行并查看结果 代码清单 1 impor…

斯坦福| ChatGPT用于生成式搜索引擎的可行性

文|智商掉了一地 随着 ChatGPT 在文本生成领域迈出了重要一步,Bing 浏览器也接入了聊天机器人功能,因此如何保证 Bing Chat 等搜索引擎结果的精确率和真实性也成为了搜索领域的热门话题之一。 当我们使用搜索引擎时,往往希望搜索结…

教你如何进行DNS域名解析

目录 一:DNS系统介绍 1.DNS服务概述 2.DNS域名空间介绍 3.DNS 域名结构 4.DNS解析方式 5.DNS查询方式 (1)递归查询 (2)迭代查询 6.DNS服务器类型: (1)主域名服务器 (2)从域名服务器 (3)缓存域名服务器 (4)…

Android进阶宝典 -- 解读Handler机制核心源码,让ANR无处可藏

其实ANR核心本质就是让UI线程(主线程)等了太久,导致系统判定在主线程做了耗时操作导致ANR。当我们执行任何一个任务的时候,在Framework底层是通过消息机制来维护任务的分发,从下面这个日志可以看到, "…

thrift、go与php

学习一下thrift。 环境 mac m1,go 1.20,php 7.4,thrift 0.18.1 要学习thrift,第一步得先安装 $ brew install thrift学习的计划是用go作为server,php作为client,通过thrift的方式完成一次请求demo。 建…

Java语言的特点和八大基本类型

“byte和short两兄弟去找int问long去哪了” “int摇摇头说不知道” “此时float和double两兄弟也来凑热闹” “共同商议后决定去找char询问” “char面对五人的询问只好说boolean知道” “六人来到boolean的住处发现long竟然在玩猜真假游戏” Java语言的特点 1.简单易学…

个性化学习路径推荐综述

源自:软件学报 作者:云岳 代欢 张育培 尚学群 李战怀 摘 要 近年来, 伴随着现代信息技术的迅猛发展, 以人工智能为代表的新兴技术在教育领域得到了广泛应用, 引发了学习理念和方式的深刻变革. 在这种大背景下, 在线学习超越了时空的限制,…

2023年电信推出新套餐:月租19元=135G流量+长期套餐+无合约期!

在三大运营商推出的流量卡当中,电信可以说是性价比最高的一个,相对于其他两家运营商,完全符合我们低月租,大流量的要求,所以,今天小编介绍的还是电信流量卡。 在这里说一下,小编推荐的卡都是免…

教你怎样用PXE高效的批量网络装机

目录 一:PXE介绍 1.XPE概述 2.PXE批量部署的优点 3.搭建PXE各部作用 (1)PXE(Preboot eXcution Environment) (2)服务端 (3)客户端 二:部署PXE服务 1.安装并启用TFTP服务 2.安…

Tiktok/抖音旋转验证码

声明 本文以教学为基准、本文提供的可操作性不得用于任何商业用途和违法违规场景。 本人对任何原因在使用本人中提供的代码和策略时可能对用户自己或他人造成的任何形式的损失和伤害不承担责任。 如有侵权,请联系我进行删除。 抖音系的旋转验证码,跟得物一样,都是内外圈一起…

blast的-max_target_seqs?

Shah, N., Nute, M.G., Warnow, T., and Pop, M. (2018). Misunderstood parameter of NCBI BLAST impacts the correctness of bioinformatics workflows. Bioinformatics. 杂志Bioinformatics以letter to the editor的形式刊发了来自美国马里兰大学计算机系的Nidhi Shah等人…

基于html+css的图展示42

准备项目 项目开发工具 Visual Studio Code 1.44.2 版本: 1.44.2 提交: ff915844119ce9485abfe8aa9076ec76b5300ddd 日期: 2020-04-16T16:36:23.138Z Electron: 7.1.11 Chrome: 78.0.3904.130 Node.js: 12.8.1 V8: 7.8.279.23-electron.0 OS: Windows_NT x64 10.0.19044 项目…