paddle2.3-基于联邦学习实现FedAVg算法

news2024/11/24 1:16:33

目录

1. 联邦学习介绍

2. 实验流程

3. 数据加载

4. 模型构建

5. 数据采样函数

6. 模型训练


1. 联邦学习介绍

联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。

联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。

本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区

2. 实验流程

联邦学习的基本流程是:

1. server初始化模型参数,所有的clients将这个初始模型下载到本地;

2. clients利用本地产生的数据进行SGD训练;

3. 选取K个clients将训练得到的模型参数上传到server;

4. server对得到的模型参数整合,所有的clients下载新的模型。

5. 重复执行2-5,直至收敛或达到预期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 数据加载

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型构建

class CNN(nn.Layer):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1=nn.Conv2D(1,32,5)
        self.relu = nn.ReLU()
        self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)
        self.conv2=nn.Conv2D(32,64,5)
        self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)
        self.fc1=nn.Linear(1024,512)
        self.fc2=nn.Linear(512,10)
        # self.softmax = nn.Softmax()
    def forward(self,inputs):
        x = self.conv1(inputs)
        x = self.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool2(x)
        
        x=paddle.reshape(x,[-1,1024])
        x = self.relu(self.fc1(x))
        y = self.fc2(x)
        return y

5. 数据采样函数

# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):
  num_items_per_client = int(len(dataset)/clients)
  client_dict = {}
  image_idxs = [i for i in range(len(dataset))]
  for i in range(clients):
    client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据
    image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除
    client_dict[i] = list(client_dict[i])

  return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):
  shard_idxs = [i for i in range(total_shards)]
  client_dict = {i: np.array([], dtype='int64') for i in range(clients)}
  idxs = np.arange(len(dataset))
  data_labels = Label

  label_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠
  label_idxs = label_idxs[:, label_idxs[1,:].argsort()]
  idxs = label_idxs[0,:]

  for i in range(clients):
    rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) 
    shard_idxs = list(set(shard_idxs) - rand_set)

    for rand in rand_set:
      client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接
  
  return client_dict

class MNISTDataset(Dataset):
    def __init__(self, data,label):
        self.data = data
        self.label = label

    def __getitem__(self, idx):
        image=np.array(self.data[idx]).astype('float32')
        image=np.reshape(image,[1,28,28])
        label=np.array(self.label[idx]).astype('int64')
        return image, label

    def __len__(self):
        return len(self.label)

6. 模型训练

class ClientUpdate(object):
    def __init__(self, data, label, batch_size, learning_rate, epochs):
        dataset = MNISTDataset(data,label)
        self.train_loader = DataLoader(dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    drop_last=True)
        self.learning_rate = learning_rate
        self.epochs = epochs
        
    def train(self, model):
        optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())
        criterion = nn.CrossEntropyLoss(reduction='mean')
        model.train()
        e_loss = []
        for epoch in range(1,self.epochs+1):
            train_loss = []
            for image,label in self.train_loader:
                # image=paddle.to_tensor(image)
                # label=paddle.to_tensor(label.reshape([label.shape[0],1]))
                output=model(image)
                loss= criterion(output,label)
                # print(loss)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
                train_loss.append(loss.numpy()[0])
            t_loss=sum(train_loss)/len(train_loss)
            e_loss.append(t_loss)
        total_loss=sum(e_loss)/len(e_loss)
        return model.state_dict(), total_loss

train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):
    global_weights = model.state_dict()
    train_loss = []
    start = time.time()
    # clients与server之间通信
    for curr_round in range(1, rounds+1):
        w, local_loss = [], []
        m = max(int(C*K), 1) # 随机选取参与更新的clients
        S_t = np.random.choice(range(K), m, replace=False)
        for k in S_t:
            # print(data_dict[k])
            sub_data = ds[data_dict[k]]
            sub_y = L[data_dict[k]]
            local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)
            weights, loss = local_update.train(model)
            w.append(weights)
            local_loss.append(loss)

        # 更新global weights
        weights_avg = w[0]
        for k in weights_avg.keys():
            for i in range(1, len(w)):
                # weights_avg[k] += (num[i]/sum(num))*w[i][k]
                weights_avg[k]=weights_avg[k]+w[i][k]   
            weights_avg[k]=weights_avg[k]/len(w)
            global_weights[k].set_value(weights_avg[k])
        # global_weights = weights_avg
        # print(global_weights)
    #模型加载最新的参数
        model.load_dict(global_weights)

        loss_avg = sum(local_loss) / len(local_loss)
        if curr_round % 10 == 0:
            print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))
        train_loss.append(loss_avg)

    end = time.time()
    fig, ax = plt.subplots()
    x_axis = np.arange(1, rounds+1)
    y_axis = np.array(train_loss)
    ax.plot(x_axis, y_axis, 'tab:'+plt_color)

    ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)
    ax.grid()
    fig.savefig(plt_title+'.jpg', format='jpg')
    print("Training Done!")
    print("Total time taken to Train: {}".format(end-start))
  
    return model.state_dict()

#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

W0605 23:22:00.961916 10307 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0605 23:22:00.966121 10307 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.


Round: 10... 	Average Loss: 0.033
Round: 20... 	Average Loss: 0.011
Round: 30... 	Average Loss: 0.012
Round: 40... 	Average Loss: 0.008
Round: 50... 	Average Loss: 0.003
Round: 60... 	Average Loss: 0.002
Round: 70... 	Average Loss: 0.001
Round: 80... 	Average Loss: 0.001
Round: 90... 	Average Loss: 0.001

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1043518.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

8+铜死亡+分型+预后模型

今天给同学们分享一篇铜死亡分型预后模型的生信文章“Signature construction and molecular subtype identification based on cuproptosis-related genes to predict the prognosis and immune activity of patients with hepatocellular carcinoma”,这篇文章于2…

Vue实现Hello World

<div id"aa"> <p>{{h}}</p> </div> <script src"https://cdn.jsdelivr.net/npm/vue2/dist/vue.js"></script> <script> const hello new Vue({ el:#aa, data:{ h : Hello World } }) </script>

Bigemap在地质矿产行业的应用

工具 Bigemap gis office地图软件 BIGEMAP GIS Office-全能版 Bigemap APP_卫星地图APP_高清卫星地图APP 首先 在地质矿产勘查中 &#xff0c;Gis技术可以用来被直接利用到工作分析里面&#xff0c;可以快速获取存储和管理大量空间数据 比如结合卫星影像和无人机遥感数据&am…

ASN1编码查看工具

ASN1编码查看工具 多页面ASN1编码查看工具&#xff0c;支持DER和base64编码&#xff0c;支持数据高亮 多页面ASN1编码查看工具&#xff0c;支持DER和base64编码&#xff0c;支持数据高亮

MySQL学习笔记16

MySQL的用户管理和权限管理&#xff1a; 在MySQL中创建更多的用户&#xff0c;不可能都使用root用户。root用户的权限太高&#xff0c;一个误操作&#xff0c;容易造成很大的失误。 注意&#xff1a;MySQL中不能单纯通过用户名来说明用户&#xff0c;必须加上主机&#xff0c;…

[每周一更]-(第63期):Linux-nsenter命令使用说明

nsenter命令是一个可以在指定进程的命令空间下运行指定程序的命令。它位于util-linux包中。 1、用途 一个最典型的用途就是进入容器的网络命令空间。相当多的容器为了轻量级&#xff0c;是不包含较为基础的命令的&#xff0c; 比如说ip address&#xff0c;ping&#xff0c;t…

ENVI IDL:MODIS SWATH产品的点位-像元提取(另附Python代码)

目录 01 说明 1.1 ENVI IDL本小节实验说明 1.2 Python本小节实验说明 02 代码 2.1 我的ENVI IDL代码 2.2 实验给定ENVI IDL代码 2.3 我的python代码 2.4 实验给定python代码 01 说明 实验说明 提取/coarse_data/chapter_3/modis_swath/目录下所有MODIS气溶胶产品中Image…

复杂度(7.23)

1.算法效率 如何衡量算法的好坏&#xff1f; 这里需要引入算法的复杂度 1.2算法的复杂度 算法在编写成可执行程序后&#xff0c;运行时需要耗费时间资源和空间 ( 内存 ) 资源。因此 衡量一个算法的好坏&#xff0c;一般是从时间和空间两个维度来衡量的 &#xff0c; 即时间复…

docker 安装 nessus新版、awvs15-简单更快捷

一、docker 安装 nessus 参考项目地址&#xff1a; https://github.com/elliot-bia/nessus 介绍&#xff1a;几行代码即可一键安装更新 nessus -推荐 安装好 docker后执行以下命令 #拉取镜像创建容器 docker run -itd --nameramisec_nessus -p 8834:8834 ramisec/nessus …

Cisco Packet Tracer的下载与安装+汉化

一、下载 1、思科官网下载 参考连接&#xff1a;Cisco Packet Tracer - Networking Simulation Tool (netacad.com) 2、百度网盘 链接&#xff1a;https://pan.baidu.com/s/18xzqdACkCXngMzYm5UV_zg 提取码&#xff1a;xuyi 二、安装 双击安装包 点击下一步 选择同意条款&…

VUE指令语法解析标签属性

我们可以在标签体中使用插值语法 {{ }} 来直接读取data中的属性 那我们能使用相同的方法将我们的网址给填入a标签的href属性中吗&#xff1f; 我们运行后会发现并没有给我们变为<a href"https://blog.csdn.net/XunLin233">&#xff0c;而是<a href"{{…

个性化定制企业邮件域名的解决方案

随着互联网的快速发展&#xff0c;电子邮件已经成为了商业沟通的主要方式之一。而企业邮箱作为企业进行正式商务交流的重要工具&#xff0c;其域名和后缀的选择也引起了人们的关注。其中一个关键的问题是&#xff1a;企业邮箱的后缀是否可以自定义&#xff1f;企业邮箱的后缀是…

MSF的安装与使用教程,超详细,附安装包和密钥

MSF简介 Metasploit&#xff08;MSF&#xff09;是一个免费的、可下载的框架 它本身附带数百个已知软件漏洞&#xff0c;是一款专业级漏洞攻击工具。 当H.D. Moore在2003年发布Metasploit时&#xff0c;计算机安全状况也被永久性地改变了&#xff0c;仿佛一夜之间&#xff0…

ipad触控笔有必要买原装吗?性价比触控笔排行榜

随着社会经济的发展&#xff0c;越来越多的人需要用到电容笔。国产的平替电器笔&#xff0c;与苹果原装的电容笔差别并不大&#xff0c;无论是在功能上&#xff0c;还是在触感上&#xff0c;都相差无几&#xff0c;写起字来更是行云流水&#xff0c;让我有些意外的是&#xff0…

el-upload实现复制粘贴图片

前言&#xff1a; 在之前的项目中&#xff0c;利用el-upload实现了上传图片视频的预览。项目上线后&#xff0c;经使用人员反馈&#xff0c;上传图片、视频每次要先保存到本地然后再上传&#xff0c;很是浪费时间&#xff0c;公司客服人员时间又很紧迫&#xff08;因为要响应下…

Xmake v2.8.3 发布,改进 Wasm 并支持 Xmake 源码调试

Xmake 是一个基于 Lua 的轻量级跨平台构建工具。 它非常的轻量&#xff0c;没有任何依赖&#xff0c;因为它内置了 Lua 运行时。 它使用 xmake.lua 维护项目构建&#xff0c;相比 makefile/CMakeLists.txt&#xff0c;配置语法更加简洁直观&#xff0c;对新手非常友好&#x…

描述性统计分析

前言&#xff1a; 本专栏参考教材为《SPSS22.0从入门到精通》&#xff0c;由于软件版本原因&#xff0c;部分内容有所改变&#xff0c;为适应软件版本的变化&#xff0c;特此创作此专栏便于大家学习。本专栏使用软件为&#xff1a;SPSS25.0 本专栏所有的数据文件可在个人主页—…

电脑D盘格式化会有什么影响?电脑D盘格式化了怎么恢复数据

当电脑出现问题时&#xff0c;往往会出现一些提示&#xff0c;例如提示格式化的问题&#xff0c;而最近有位小伙伴也遇到了相似的问题&#xff0c;即D盘一打开就显示格式化&#xff0c;由于不清楚D盘格式化会有什么影响&#xff0c;因此不小心进行了格式化操作&#xff0c;结果…

较真儿学源码系列-PowerJob启动流程源码分析

PowerJob版本&#xff1a;4.3.2-main。 1 简介 PowerJob是全新一代的分布式任务调度与计算框架&#xff0c;官网地址&#xff1a;http://www.powerjob.tech/。其中介绍了PowerJob的功能特点&#xff0c;以及与其他调度框架的对比&#xff0c;这里就不再赘述了。 以上是PowerJob…

基于PSO算法的功率角摆动曲线优化研究(Matlab代码实现)

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…