如何编写联邦学习训练框架——Pytorch实现

news2025/1/16 3:56:51

联邦学习框架实现

联邦学习训练过程由服务器客户端两部分组成。
客户端将本地数据训练得到的模型上传服务器,服务器通过聚合客户端上传的服务器再次下发新一轮的模型,原理很简单,那么我们开始动手写代码。
在这里插入图片描述

1. 客户端部分:

客户端部分很简单,我们需要做的就是获取全局模型,利用本地数据进行训练,然后返回模型参数差。
这里我是推荐使用name_parameters进行参数更新提交,因为torch的训练模型中带有许多保存临时值的层,都提交没有意义。
我们首先建立一个字典pre_model={}来保存原先的模型数值。这个地方有两个作用:

  1. 保存先前的模型数值,用来计算新模型参数和旧模型参数之间的差,model的parameters减去new_model的parameters。
  2. 用来还原全局模型。

这里的2解释一下,因为是单机训练模型,我们没有真实地给客户端分配一个模型,因为如果要真的分配模型,那么就需要deepcopy一个model给每一个客户端,每一个模型大小都不小,开个十几个二十几个16G内存就被占满了,而且速度也没有明显变快。因此我们采取的方法是,客户端的服务器共享一个模型,客户端每次训练解释后恢复全局模型为训练前的参数值,服务器等到全部客户端结束一轮训练后,更新此全局模型。

代码编写部分:模型训练部分与集中式机器学习训练完全一致,就是训练完之后需要复位以及return diff。

class Client(object):

	client_id = 0

	def __init__(
		self,
		batch_size,
		lr,
		momentum,
		model_parameter,
		local_epochs,
		model,
		train_dataset,
	) -> None:

		Client.client_id += 1
		self.client_id = Client.client_id
		self.batch_size = batch_size
		self.lr = lr
		self.momentum = momentum
		self.model_parameter = model_parameter
		self.local_epochs = local_epochs
		self.local_model = model
		self.train_dataset = train_dataset
		self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)

	def local_train(self):

		# record the previous model parameters
		# 1. calculate diff
		# 2. restoring the global model
		pre_model = {}
		if self.model_parameter == "all":
			for name, param in self.local_model.state_dict().items():
				pre_model[name] = param.clone()
		else:
			for name, param in self.local_model.named_parameters():
				pre_model[name] = param.clone()

		optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.lr, momentum=self.momentum)
		epoch = self.local_epochs
		self.local_model.train()

		for _ in range(epoch):
			for _, batch in enumerate(self.train_loader):
				data, target = batch

				if torch.cuda.is_available():
					data = data.cuda()
					target = target.cuda()
			
				optimizer.zero_grad()
				output = self.local_model(data)
				loss = torch.nn.functional.cross_entropy(output, target)
				loss.backward()
				optimizer.step()
				
		print(f"{self.client_id} complete!")

		# record the differences between the local model and the global model
		diff = {}

		for name, param in pre_model.items():
			diff[name] = self.local_model.state_dict()[name] - param

		for name, param in pre_model.items():
			self.local_model.state_dict()[name] = param

		return diff

这段代码以后还有改写,使用的损失函数,优化器后期都会改写成可以修改的方式。

2. 服务器部分:

服务器部分包含了一个聚合部分和一个模型评估部分。聚合很简单,目前就写了fedavg,还是平均的聚合,clients_diff是一个元素为diff字典的列表,我们通过遍历此数组将每个客户端的参数差相加,最后乘以一定的权重加在最后的全局模型上,得到本轮迭代的结果。

class Server(Model):
	
	def __init__(
		self,
		model_name,
		batch_size,
		lamda,
		eval_dataset
	):
		super().__init__(model_name, eval_dataset)
		self.global_model = self.model
		self.lamda = lamda
		self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

	def model_aggregation(self, clients_diff):

		weight_accumulator = {}
		for name, params in clients_diff[0].items():
			weight_accumulator[name] = torch.zeros_like(params)

		for _, client_diff in enumerate(clients_diff):
			for name, params in client_diff.items():
				weight_accumulator[name].add_(params)

		for name, params in weight_accumulator.items():
			update_per_layer = params * self.lamda

			if params.type() != update_per_layer.type():
				params.add_(update_per_layer.to(torch.int64))
			else:
				params.add_(update_per_layer)

	def model_eval(self):
		self.global_model.eval()
		total_loss = 0.0
		correct = 0
		dataset_size = 0

		for batch_id, batch in enumerate(self.eval_loader):
			data, target = batch 
			dataset_size += data.size()[0]
			
			if torch.cuda.is_available():
				data = data.cuda()
				target = target.cuda()
				
			output = self.global_model(data)
			total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
			pred = output.data.max(1)[1]
			correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()

		acc = 100.0 * (float(correct) / float(dataset_size))
		total_l = total_loss / dataset_size
		return acc, total_l

3. 主函数部分

主函数就是实例化一个服务器,和一群客户端。主函数部分里比较重要的就是数据集的划分,目前实现了两种方法,一种是平均划分,一种是dirichlet划分。
关于dirichlet划分可以详见:Dirichlet分布

if __name__ == '__main__':

    # load the configure file
    with open('./conf.json', 'r') as f:
        conf = json.load(f)

    # load dataset
    train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["dataset"])

    server = Server(batch_size=conf["batch_size"],
                    lamda=conf["lambda"],
                    model_name=conf["model_name"],
                    eval_dataset=eval_datasets)

    # total clients array
    clients = []

    if conf["data_distribution"] == 'iid':
        n_clients = conf["num_models"]
        data_len = len(train_datasets)
        subset_indices = distribution.split_iid(n_clients, data_len)
        for idx in subset_indices:
            subset_dataset = Subset(train_datasets, idx)
            clients.append(Client(batch_size=conf["batch_size"],
                                  lr=conf["lr"],
                                  momentum=conf["momentum"],
                                  model_parameter=conf["model_parameter"],
                                  local_epochs=conf["local_epochs"],
                                  model=server.global_model,
                                  train_dataset=subset_dataset))

    elif conf["data_distribution"] == 'dirichlet':
        n_clients = conf["num_models"]
        dirichlet_alpha = conf["dirichlet_alpha"]
        train_labels = train_datasets.targets
        # return an array: every client's index
        client_idcs = distribution.dirichlet_split_noniid(train_labels, alpha=dirichlet_alpha, n_clients=n_clients)

        for c, subset_indices in enumerate(client_idcs):
            subset_dataset = Subset(train_datasets, subset_indices)
            clients.append(Client(batch_size=conf["batch_size"],
                                  lr=conf["lr"],
                                  momentum=conf["momentum"],
                                  model_parameter=conf["model_parameter"],
                                  local_epochs=conf["local_epochs"],
                                  model=server.global_model,
                                  train_dataset=subset_dataset))

    accuracy = []
    losses = []

    for e in range(conf["global_epochs"]):

        # random choice k clients
        candidates = random.sample(clients, conf["k"])

        # clients_weight recode the diffs of every client
        clients_weight = []

        for _, c in enumerate(candidates):
            diff = c.local_train()
            clients_weight.append(diff)

        server.model_aggregation(clients_diff=clients_weight)

        acc, loss = server.model_eval()
        accuracy.append(acc)
        losses.append(loss)

        print(f"Epoch {e:d}, acc: {acc:f}, loss: {loss:f}\n")

4. 模型部分和数据集部分

这两个部分就直接使用了torch里自带的数据集和模型,如果想要使用自己的模型和数据集,就和平时pytorch里自己编写模型和数据集一样。

代码地址见联邦学习代码框架,如果对你有帮助的话,可不可以给个三连~

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/735245.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LVS - DR群集

文章目录 一、DR模式 LVS负载均衡群集1.数据包流向分析 二、LVS-DR模式的特点三、LVS-DR中的ARP问题四、DR模式 LVS负载均衡群集部署1.环境准备2.配置负载调度器(192.168.40.104)2.1.配置虚拟 IP 地址(VIP:192.168.40.180&#xf…

RabbitMQ在SpringBoot中的高级应用(1)

启动RabbitMQ 1. 在虚拟机中启动RabbitMQ,要先切换到root用户下: su root 2.关闭防火墙: systemctl stop firewalld 3.rabbitmq-server start # 启用服务 4.rabbitmq-server -detached # 后台启动 1.消息确认机制 有两种确认的方式: 自动ACK:RabbitMQ将消息发送给…

一些有意思的耗尽型MOS恒流源阻抗对比

貌似没有什么管子能超过DN2540,测试的环境差别不大,LD1014D因为本身耐压太低(25V),而且达不到1mA这个值,因此,测试的时候相应降低了电压,选择了2mA的电流,并将负载电阻减…

Pytorch-ResNet50-MINIST Classify 网络实现流程

分两个文件讲解:1、train.py训练文件 2、test.py测试文件. 1、train.py训练文件 1)从主函数入口开始,设置相关参数 # 主函数入口 if __name__ __main__:# ----------------------------## 是否使用Cuda# 没有GPU可以设置成Fasle# -…

IDEA+SpringBoot+mybatis+bootstrap+jquery+Mysql车险理赔管理系统

IDEASpringBootmybatisbootstrapjqueryMysql车险理赔管理系统 一、系统介绍1.环境配置 二、系统展示1. 管理员登录2.编辑个人信息3.用户管理4.添加用户5.申请理赔管理6.赔偿金发放管理7.待调查事故保单8.已调查记录9.现场勘察管理10.勘察记录11.我的保险管理12.我的理赔管理 三…

Atcoder Beginner Contest 309——D-F讲解

前言 由于最近期末考试,所以之前几场都没打,给大家带了不便,非常抱歉。 这个暑假,我将会持续更新,并给大家带了更好理解的题解!希望大家多多支持。 由于, A ∼ C A\sim C A∼C 题比较简单&am…

现代C++新特性 扩展的聚合类型(C++17 C++20)(PC浏览效果更佳)

文字版PDF文档链接:现代C新特性(文字版)-C文档类资源-CSDN下载 1.聚合类型的新定义 C17标准对聚合类型的定义做出了大幅修改,即从基类公开且非虚继承的类也可能是一个聚合。同时聚合类型还需要满足常规条件。 1.没有用户提供的构造函数。…

用C语言写一个压缩文件的程序

本篇目录 数据在计算机中的表现形式huffman 编码将文件的二进制每4位划分,统计其值在文件中出现的次数构建二叉树搜索二叉树的叶子节点运行并输出新的编码文件写入部分写入文件首部写入数据部分压缩运行调试解压缩部分解压缩测试为可执行文件配置环境变量总结完整代…

23数字图像置乱技术(matlab程序)

1.简述 一、引言 所谓“置乱”,就是将图像的信息次序打乱,a像素移动到b像素位置上,b像素移动到c像素位置上,……,使其变换成杂乱无章难以辨认的图片。数字图像置乱技术属于加密技术,是指发送发借助数学或者…

Python实现PSO粒子群优化算法优化Catboost分类模型(CatBoostClassifier算法)项目实战

说明:这是一个机器学习实战项目(附带数据代码文档视频讲解),如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 PSO是粒子群优化算法(Particle Swarm Optimization)的英文缩写,是一…

《低代码指南》——轻流5.0发布,无代码引擎矩阵全面升级

7月6日,由轻流主办「无代码无边界 202376Day|轻流无代码探索者大会」于上海顺利举行。轻流也在会上重磅发布了更加开放、灵活、低门槛的轻流5.0,和全面升级的专有轻流。 轻流5.0全面迭代升级了轻流的无代码引擎矩阵(表单引擎、流程引擎、报表引擎、门户引擎、数据引擎)。…

软件测试项目实战,电商项目测试实例 - 业务测试(重点)

目录:导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结(尾部小惊喜) 前言 支付功能怎么测试…

pytest自动化测试实战之执行参数

上一篇介绍了如何运行pytest代码,以及用例的一些执行规则,执行用例发现我们中间print输出的内容,结果没有给我们展示出来,那是因为pytest执行时,后面需要带上一些参数。 参数内容 我们可以在cmd中通过输入 pytest -h…

域名捡漏的好方法,希望能够帮到你:域霸扫描器 V0.44 绿色免费版,供大家学习研究参考

高速扫描域名的工具,一均程序每小时五万条。 扫描域名是否注册,注册商是谁,域名的注册日期与过期日期。 供大家学习研究参考! 下载:https://download.csdn.net/download/weixin_43097956/88025564

【SpringBoot——Error记录】

IDEA正常安装后,运行按钮为灰色解决方法尝试 解决方法一(本人适用)解决方法二 解决方法一(本人适用) 检查创建项目时JDK是否添加,版本是否正确。 解决方法二 点击左下角的Structure 参考链接&#xff1…

回归预测 | MATLAB实现WOA-CNN-LSTM鲸鱼算法优化卷积长短期记忆神经网络多输入单输出回归预测

回归预测 | MATLAB实现WOA-CNN-LSTM鲸鱼算法优化卷积长短期记忆神经网络多输入单输出回归预测 目录 回归预测 | MATLAB实现WOA-CNN-LSTM鲸鱼算法优化卷积长短期记忆神经网络多输入单输出回归预测预测效果基本介绍模型描述程序设计学习总结参考资料 预测效果 基本介绍 回归预测 …

node中的数据持久化之mongoDB

一、什么是mongoDB MongoDB是一种开源的非关系型数据库,正如它的名字所表示的,MongoDB支持的数据结构非常松散,是一种以bson格式(一种json的存储形式)的文档存储方式为主,支持的数据结构类型更加丰富的NoS…

mysql多表查询练习题

创建表及插入数据 create table if not exists dept3( deptno varchar(20) primary key , -- 部门号 name varchar(20) -- 部门名字 ); -- 创建员工表 create table if not exists emp3( eid varchar(20) primary key , -- 员工编号 ename varchar(20), -- 员工名字 age int, -…

换零钱——最小钱币张数(贪心算法)

贪心算法:根据给定钱币面值列表,输出给定钱币金额的最小张数。 (本笔记适合学完python基本数据结构,初通 Python 的 coder 翻阅) 【学习的细节是欢悦的历程】 Python 官网:https://www.python.org/ Free:大咖免费“圣…

CS EXE上线主机+文件下载上传键盘记录

前言 书接上文,CobaltStrike_1_部署教程及CS制作office宏文档钓鱼教程,该篇介绍【使用CS生成对应exe木马,上线主机;对上线主机进行,文件下载,文件上传,键盘记录】。 PS:文章仅供学习…