paddle2.3-基于联邦学习实现FedAVg算法-CNN

news2024/11/29 5:33:43

目录

1. 联邦学习介绍

2. 实验流程

3. 数据加载

4. 模型构建

5. 数据采样函数

6. 模型训练


1. 联邦学习介绍

联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。

联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。

本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区

2. 实验流程

联邦学习的基本流程是:

1. server初始化模型参数,所有的clients将这个初始模型下载到本地;

2. clients利用本地产生的数据进行SGD训练;

3. 选取K个clients将训练得到的模型参数上传到server;

4. server对得到的模型参数整合,所有的clients下载新的模型。

5. 重复执行2-5,直至收敛或达到预期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 数据加载

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型构建

class CNN(nn.Layer):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1=nn.Conv2D(1,32,5)
        self.relu = nn.ReLU()
        self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)
        self.conv2=nn.Conv2D(32,64,5)
        self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)
        self.fc1=nn.Linear(1024,512)
        self.fc2=nn.Linear(512,10)
        # self.softmax = nn.Softmax()
    def forward(self,inputs):
        x = self.conv1(inputs)
        x = self.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool2(x)
        
        x=paddle.reshape(x,[-1,1024])
        x = self.relu(self.fc1(x))
        y = self.fc2(x)
        return y

5. 数据采样函数

# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):
  num_items_per_client = int(len(dataset)/clients)
  client_dict = {}
  image_idxs = [i for i in range(len(dataset))]
  for i in range(clients):
    client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据
    image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除
    client_dict[i] = list(client_dict[i])

  return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):
  shard_idxs = [i for i in range(total_shards)]
  client_dict = {i: np.array([], dtype='int64') for i in range(clients)}
  idxs = np.arange(len(dataset))
  data_labels = Label

  label_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠
  label_idxs = label_idxs[:, label_idxs[1,:].argsort()]
  idxs = label_idxs[0,:]

  for i in range(clients):
    rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) 
    shard_idxs = list(set(shard_idxs) - rand_set)

    for rand in rand_set:
      client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接
  
  return client_dict

class MNISTDataset(Dataset):
    def __init__(self, data,label):
        self.data = data
        self.label = label

    def __getitem__(self, idx):
        image=np.array(self.data[idx]).astype('float32')
        image=np.reshape(image,[1,28,28])
        label=np.array(self.label[idx]).astype('int64')
        return image, label

    def __len__(self):
        return len(self.label)

6. 模型训练

class ClientUpdate(object):
    def __init__(self, data, label, batch_size, learning_rate, epochs):
        dataset = MNISTDataset(data,label)
        self.train_loader = DataLoader(dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    drop_last=True)
        self.learning_rate = learning_rate
        self.epochs = epochs
        
    def train(self, model):
        optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())
        criterion = nn.CrossEntropyLoss(reduction='mean')
        model.train()
        e_loss = []
        for epoch in range(1,self.epochs+1):
            train_loss = []
            for image,label in self.train_loader:
                # image=paddle.to_tensor(image)
                # label=paddle.to_tensor(label.reshape([label.shape[0],1]))
                output=model(image)
                loss= criterion(output,label)
                # print(loss)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
                train_loss.append(loss.numpy()[0])
            t_loss=sum(train_loss)/len(train_loss)
            e_loss.append(t_loss)
        total_loss=sum(e_loss)/len(e_loss)
        return model.state_dict(), total_loss

train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):
    global_weights = model.state_dict()
    train_loss = []
    start = time.time()
    # clients与server之间通信
    for curr_round in range(1, rounds+1):
        w, local_loss = [], []
        m = max(int(C*K), 1) # 随机选取参与更新的clients
        S_t = np.random.choice(range(K), m, replace=False)
        for k in S_t:
            # print(data_dict[k])
            sub_data = ds[data_dict[k]]
            sub_y = L[data_dict[k]]
            local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)
            weights, loss = local_update.train(model)
            w.append(weights)
            local_loss.append(loss)

        # 更新global weights
        weights_avg = w[0]
        for k in weights_avg.keys():
            for i in range(1, len(w)):
                # weights_avg[k] += (num[i]/sum(num))*w[i][k]
                weights_avg[k]=weights_avg[k]+w[i][k]   
            weights_avg[k]=weights_avg[k]/len(w)
            global_weights[k].set_value(weights_avg[k])
        # global_weights = weights_avg
        # print(global_weights)
    #模型加载最新的参数
        model.load_dict(global_weights)

        loss_avg = sum(local_loss) / len(local_loss)
        if curr_round % 10 == 0:
            print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))
        train_loss.append(loss_avg)

    end = time.time()
    fig, ax = plt.subplots()
    x_axis = np.arange(1, rounds+1)
    y_axis = np.array(train_loss)
    ax.plot(x_axis, y_axis, 'tab:'+plt_color)

    ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)
    ax.grid()
    fig.savefig(plt_title+'.jpg', format='jpg')
    print("Training Done!")
    print("Total time taken to Train: {}".format(end-start))
  
    return model.state_dict()

#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

Round: 10... 	Average Loss: [0.024]
Round: 20... 	Average Loss: [0.015]
Round: 30... 	Average Loss: [0.008]
Round: 40... 	Average Loss: [0.003]
Round: 50... 	Average Loss: [0.004]
Round: 60... 	Average Loss: [0.002]
Round: 70... 	Average Loss: [0.002]
Round: 80... 	Average Loss: [0.002]
Round: 90... 	Average Loss: [0.001]
Round: 100... 	Average Loss: [0.]
Training Done!
Total time taken to Train: 759.6239657402039

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1055339.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【k8s】集群搭建篇

文章目录 搭建kubernetes集群kubeadm初始化操作安装软件(master、所有node节点)Kubernetes Master初始化Kubernetes Node加入集群部署 CNI 网络插件测试 kubernetes 集群停止服务并删除原来的配置 二进制搭建(单master集群)初始化操作部署etcd集群安装Docker部署master节点解压…

SpringBoot 如何使用 Spring Data MongoDB 访问 MongoDB

使用 Spring Boot 和 Spring Data MongoDB 访问 MongoDB 数据库 在现代应用程序开发中,许多应用都依赖于数据库来存储和检索数据。MongoDB 是一个流行的 NoSQL 数据库,而 Spring Boot 是一个广泛使用的 Java 开发框架。本文将介绍如何使用 Spring Boot …

28383-2012 卷筒料凹版印刷机 学习笔记

声明 本文是学习GB-T 28383-2012 卷筒料凹版印刷机. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了卷筒料凹版印刷机的型式、基本参数、要求、试验方法、检验规则、标志、包装、运输与 贮存。 本标准适用于机组式的卷筒料凹版…

网络协议--链路层

2.1 引言 从图1-4中可以看出,在TCP/IP协议族中,链路层主要有三个目的: (1)为IP模块发送和接收IP数据报; (2)为ARP模块发送ARP请求和接收ARP应答; (3&#xf…

28390-2012 幕墙铝型材高速五面加工中心

声明 本文是学习GB-T 28390-2012 幕墙铝型材高速五面加工中心. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了幕墙铝型材高速五面加工中心的分类、技术要求、试验方法、检测规则、标牌、使用说 明书、包装、运输和贮存。 本标…

基于Java的游戏检索系统设计与实现(源码+lw+部署文档+讲解等)

文章目录 前言用户功能已注册用户的功能后台功能管理员功能具体实现截图论文参考详细视频演示为什么选择我自己的网站自己的小程序(小蔡coding)有保障的售后福利 代码参考源码获取 前言 💗博主介绍:✌全网粉丝10W,CSDN特邀作者、博…

oracle GBK未定义编码使用Unicode写入特殊字符e000迁移lightdb-x测试

E:\HS\LightDBSVN\23.3sql文件\迁移工具\caofa\config\application.properties gbk-->uft8: logging.configclasspath:log4j2.xml # ???? etl.global.sourceDatabaseoracle etl.global.targetDatabaselightdb etl.global.showSqlfalse etl.global.fastFailfalse etl.g…

Python操作自动化

迷途小书童 读完需要 3分钟 速读仅需 1 分钟 当我们需要自动化进行一些重复性的任务时,Python 中的 pyautogui 库就可以派上用场了,这个库可以模拟鼠标和键盘的操作,让我们的程序可以像人一样与计算机进行交互。 首先,我们需要安装…

Beats Studio Buds 连接 Windows 11 声音输出不显示设备

Beats Studio Buds 连接 Windows 11 声音输出不显示设备 Beats Studio Buds 蓝牙耳机连接Windows 11电脑后,无法通过耳机播放声音,在声音输出选项中也没有耳机选项。 问题 蓝牙耳机连接电脑。 在声音输出中查看输出设备选项。 解决方法 以管理员身…

LeetCode每日一题 | 309.买卖股票的最佳时机含冷冻期

题目链接&#xff1a; 309. 买卖股票的最佳时机含冷冻期 - 力扣&#xff08;LeetCode&#xff09; 题目描述&#xff1a; 算法图解&#xff1a; 解题代码&#xff1a; class Solution { public:int maxProfit(vector<int>& prices) {int n prices.size();vector&…

求∑(1,n)⌊k/i⌋∗i

对于[k/i]*i,我们可以分两端&#xff0c;前,最多有段&#xff0c;后边从到n&#xff0c;取值范围为1-&#xff0c;所以最多有段&#xff0c;共2*段。对于每段从i开始&#xff0c;其上界jk/(k/i)&#xff08;维持k/i不变最大范围i-j&#xff09;。 计算[k/i]*i时间复杂度降到级…

Android 命令行工具简介

关于作者&#xff1a;CSDN内容合伙人、技术专家&#xff0c; 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 &#xff0c;擅长java后端、移动开发、商业变现、人工智能等&#xff0c;希望大家多多支持。 目录 一、导读二、概览三、相关工具3.1 Android SDK 命令行工…

TempleteMethod

TempleteMethod 动机 在软件构建过程中&#xff0c;对于某一项任务&#xff0c;它常常有稳定的整体操作结构&#xff0c;但各个子步骤却有很多改变的需求&#xff0c;或者由于固有的原因 &#xff08;比如框架与应用之间的关系&#xff09;而无法和任务的整体结构同时实现。如…

数据结构学习笔记(基础)

绪论 数据结构三要素&#xff08;数据的基本单位是数据元素&#xff0c;数据元素可由若干个数据项组成&#xff0c;一个数据项是构成数据元素的不可分割的最小单位&#xff09; 数据&#xff1a;指的是能被计算机识别、存储和加工处理的信息载体&#xff08;如 Word 文档&#…

【设计模式_实验①_第六题】设计模式——接口的实验模拟应用实验作业一

【实验要求】 货车要装载一批货物&#xff0c;货物由三种商品组成&#xff1a;电视、计算机和洗衣机。卡车需要计算出整批货物的重量。 【实验步骤】UML 过程 在这里插入代码片 public interface ComputerWeight {public abstract double computerWeight(); }public class T…

GD32 看门狗

1. 看门狗的概念 2. 独立看门狗 独立看门狗的原理&#xff1a;设定一个重载值。赋值计数器。每来一个脉冲计数值减减。如果计数值减到0。还没有去喂狗就会产生复位。所以在计数值在0~重载值范围必须要喂一次狗。 在键值寄存器(IWDG_KR)中写入0xCCCC&#xff0c;开始启用独立看…

嵌入式Linux应用开发-基础知识-第十八章系统对中断的处理③

嵌入式Linux应用开发-基础知识-第十八章系统对中断的处理③ 第十八章 Linux系统对中断的处理 ③18.5 编写使用中断的按键驱动程序 ③18.5.1 编程思路18.5.1.1 设备树相关18.5.1.2 驱动代码相关 18.5.2 先编写驱动程序18.5.2.1 从设备树获得 GPIO18.5.2.2 从 GPIO获得中断号18.5…

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石①

嵌入式Linux应用开发-基础知识-第十九章驱动程序基石① 第十九章 驱动程序基石①19.1 休眠与唤醒19.1.1 适用场景19.1.2 内核函数19.1.2.1 休眠函数19.1.2.2 唤醒函数 19.1.3 驱动框架19.1.4 编程19.1.4.1 驱动程序关键代码19.1.4.2 应用程序 19.1.5 上机实验19.1.6 使用环形缓…

Android自动化测试之MonkeyRunner--从环境构建、参数讲解、脚本制作到实战技巧

monkeyrunner 概述、环境搭建 monkeyrunner环境搭建 (1) JDK的安装不配置 http://www.oracle.com/technetwork/java/javase/downloads/index.html (2) 安装Python编译器 https://www.python.org/download/ (3) 设置环境变量(配置Monkeyrunner工具至path目彔下也可丌配置) (4) …

【C语言经典100例题-66】(用指针解决)输入3个数a,b,c,按大小顺序输出。

代码&#xff1a; #include<stdio.h> #define _CRT_SECURE_NO_WARNINGS 1//VS编译器使用scanf函数时会报错&#xff0c;所以添加宏定义 swap(p1, p2) int* p1, * p2; {int p;p *p1;*p1 *p2;*p2 p; } int main() {int n1, n2, n3;int* pointer1, * pointer2, * point…