Federated Learning Framework Implementation
Federated learning training consists of two parts: a server and a set of clients.
Each client trains a model on its local data and uploads the result to the server; the server aggregates the clients' uploads and sends a new model back down for the next round. The principle is simple, so let's start writing the code.
1. Client side:
The client side is straightforward: fetch the global model, train it on local data, and return the difference in model parameters.
Here I recommend submitting the update via named_parameters, because a torch model's state_dict also contains many buffers that only hold temporary values (for example BatchNorm's running statistics), and there is no point in uploading all of them.
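To see the difference concretely, here is a tiny check (it uses torchvision's resnet18 purely as an example model, not as part of the framework):

```python
# Keys that appear in state_dict() but not in named_parameters() are buffers,
# e.g. BatchNorm's running_mean / running_var / num_batches_tracked.
from torchvision.models import resnet18

model = resnet18()
param_names = {name for name, _ in model.named_parameters()}
buffer_names = [name for name in model.state_dict() if name not in param_names]
print(buffer_names[:3])
# e.g. ['bn1.running_mean', 'bn1.running_var', 'bn1.num_batches_tracked']
```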
We first create a dictionary pre_model = {} to save the original model values. It serves two purposes:
- Keeping the pre-training parameter values so we can compute the difference between the new and the old parameters, i.e. the trained model's parameters minus the values saved in pre_model.
- Restoring the global model.
Point 2 needs a word of explanation. Since this is a single-machine simulation, we do not actually give each client its own model: that would mean deepcopying one model per client, and since each model is not small, ten or twenty copies would exhaust 16 GB of memory without making anything noticeably faster. Instead, the clients and the server share a single model. After each client finishes its local training, it restores the global model to the parameter values it had before training, and the server only updates this global model once all clients have finished the round.
Code: the training loop itself is exactly the same as centralized machine learning; the only additions are resetting the model after training and returning diff.
```python
import torch


class Client(object):
    client_id = 0

    def __init__(
        self,
        batch_size,
        lr,
        momentum,
        model_parameter,
        local_epochs,
        model,
        train_dataset,
    ) -> None:
        Client.client_id += 1
        self.client_id = Client.client_id
        self.batch_size = batch_size
        self.lr = lr
        self.momentum = momentum
        self.model_parameter = model_parameter
        self.local_epochs = local_epochs
        self.local_model = model
        self.train_dataset = train_dataset
        self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)

    def local_train(self):
        # record the previous (global) model parameters, used to
        # 1. calculate diff
        # 2. restore the global model after local training
        pre_model = {}
        if self.model_parameter == "all":
            for name, param in self.local_model.state_dict().items():
                pre_model[name] = param.clone()
        else:
            # only trainable parameters, without buffers such as BatchNorm statistics
            for name, param in self.local_model.named_parameters():
                pre_model[name] = param.detach().clone()
        optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.lr, momentum=self.momentum)
        self.local_model.train()
        for _ in range(self.local_epochs):
            for data, target in self.train_loader:
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()
                optimizer.zero_grad()
                output = self.local_model(data)
                loss = torch.nn.functional.cross_entropy(output, target)
                loss.backward()
                optimizer.step()
        print(f"client {self.client_id} complete!")
        # record the differences between the local model and the global model
        diff = {}
        local_state = self.local_model.state_dict()
        for name, param in pre_model.items():
            diff[name] = local_state[name] - param
        # restore the shared model to the global parameters it had before training
        for name, param in pre_model.items():
            local_state[name].copy_(param)
        return diff
```
This code will be revised later: the loss function and the optimizer will both be rewritten so they can be configured.
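As a rough idea of what that refactor could look like (only a sketch, not the final design; loss_fn and optimizer_fn are not yet parameters of the class above):

```python
# Hypothetical knobs that a later version of Client could accept
loss_fn = torch.nn.CrossEntropyLoss()
optimizer_fn = lambda params: torch.optim.SGD(params, lr=0.01, momentum=0.9)

# local_train() would then replace its hard-coded calls with
#   optimizer = optimizer_fn(self.local_model.parameters())
#   loss = loss_fn(output, target)
```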
2. Server side:
The server side contains an aggregation step and a model-evaluation step. Aggregation is simple; for now only FedAvg is implemented, as a plain average. clients_diff is a list whose elements are the diff dictionaries. We iterate over this list to sum up every client's parameter differences, multiply the sum by a weight, and add it onto the global model to obtain this round's result.
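Written out, one round of aggregation is (matching the code below, with $\lambda$ the weight lamda and $\Delta_i$ the diff returned by client $i$):

$$
w_{t+1} = w_t + \lambda \sum_{i=1}^{k} \Delta_i, \qquad \Delta_i = w_i^{\text{local}} - w_t
$$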
```python
class Server(Model):
    def __init__(
        self,
        model_name,
        batch_size,
        lamda,
        eval_dataset
    ):
        super().__init__(model_name, eval_dataset)
        self.global_model = self.model
        self.lamda = lamda
        self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)

    def model_aggregation(self, clients_diff):
        # sum up the parameter differences uploaded by the selected clients
        weight_accumulator = {}
        for name, params in clients_diff[0].items():
            weight_accumulator[name] = torch.zeros_like(params)
        for client_diff in clients_diff:
            for name, params in client_diff.items():
                weight_accumulator[name].add_(params)
        # apply the weighted sum to the global model
        for name, params in weight_accumulator.items():
            update_per_layer = params * self.lamda
            global_param = self.global_model.state_dict()[name]
            if global_param.type() != update_per_layer.type():
                # integer buffers (e.g. BatchNorm's num_batches_tracked) need an explicit cast
                global_param.add_(update_per_layer.to(torch.int64))
            else:
                global_param.add_(update_per_layer)

    def model_eval(self):
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        dataset_size = 0
        with torch.no_grad():
            for data, target in self.eval_loader:
                dataset_size += data.size()[0]
                if torch.cuda.is_available():
                    data = data.cuda()
                    target = target.cuda()
                output = self.global_model(data)
                total_loss += torch.nn.functional.cross_entropy(output, target, reduction='sum').item()
                pred = output.max(1)[1]
                correct += pred.eq(target.view_as(pred)).cpu().sum().item()
        acc = 100.0 * (float(correct) / float(dataset_size))
        total_l = total_loss / dataset_size
        return acc, total_l
```
3. Main function
The main function simply instantiates one server and a group of clients. The most important part of it is splitting the dataset; two methods are currently implemented: an even (IID) split and a Dirichlet split.
For more on the Dirichlet split, see: Dirichlet distribution.
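The split functions themselves live in the distribution module; the main function only consumes the per-client index lists they return. A minimal sketch of what they can look like (the repo's versions may differ in detail):

```python
import numpy as np

def split_iid(n_clients, data_len):
    # shuffle all sample indices and cut them into n_clients roughly equal chunks
    idcs = np.random.permutation(data_len)
    return np.array_split(idcs, n_clients)

def dirichlet_split_noniid(train_labels, alpha, n_clients):
    # for each class, draw client proportions from Dir(alpha) and split that
    # class's sample indices among the clients accordingly
    train_labels = np.array(train_labels)
    n_classes = train_labels.max() + 1
    label_distribution = np.random.dirichlet([alpha] * n_clients, n_classes)  # (n_classes, n_clients)
    class_idcs = [np.argwhere(train_labels == y).flatten() for y in range(n_classes)]
    client_idcs = [[] for _ in range(n_clients)]
    for c_idcs, fracs in zip(class_idcs, label_distribution):
        split_points = (np.cumsum(fracs)[:-1] * len(c_idcs)).astype(int)
        for i, idcs in enumerate(np.split(c_idcs, split_points)):
            client_idcs[i].append(idcs)
    return [np.concatenate(idcs) for idcs in client_idcs]
```

The smaller alpha is, the more skewed each client's label distribution becomes; large alpha approaches the IID split.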
```python
import json
import random

from torch.utils.data import Subset

import datasets       # project module providing get_dataset
import distribution   # project module providing the split functions
# Client and Server are the classes defined above

if __name__ == '__main__':
    # load the configuration file
    with open('./conf.json', 'r') as f:
        conf = json.load(f)
    # load dataset
    train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["dataset"])
    server = Server(batch_size=conf["batch_size"],
                    lamda=conf["lambda"],
                    model_name=conf["model_name"],
                    eval_dataset=eval_datasets)
    # array holding all clients
    clients = []
    if conf["data_distribution"] == 'iid':
        n_clients = conf["num_models"]
        data_len = len(train_datasets)
        subset_indices = distribution.split_iid(n_clients, data_len)
        for idx in subset_indices:
            subset_dataset = Subset(train_datasets, idx)
            clients.append(Client(batch_size=conf["batch_size"],
                                  lr=conf["lr"],
                                  momentum=conf["momentum"],
                                  model_parameter=conf["model_parameter"],
                                  local_epochs=conf["local_epochs"],
                                  model=server.global_model,
                                  train_dataset=subset_dataset))
    elif conf["data_distribution"] == 'dirichlet':
        n_clients = conf["num_models"]
        dirichlet_alpha = conf["dirichlet_alpha"]
        train_labels = train_datasets.targets
        # returns one index array per client
        client_idcs = distribution.dirichlet_split_noniid(train_labels, alpha=dirichlet_alpha, n_clients=n_clients)
        for subset_indices in client_idcs:
            subset_dataset = Subset(train_datasets, subset_indices)
            clients.append(Client(batch_size=conf["batch_size"],
                                  lr=conf["lr"],
                                  momentum=conf["momentum"],
                                  model_parameter=conf["model_parameter"],
                                  local_epochs=conf["local_epochs"],
                                  model=server.global_model,
                                  train_dataset=subset_dataset))
    accuracy = []
    losses = []
    for e in range(conf["global_epochs"]):
        # randomly choose k clients for this round
        candidates = random.sample(clients, conf["k"])
        # clients_weight records the diff returned by every selected client
        clients_weight = []
        for c in candidates:
            diff = c.local_train()
            clients_weight.append(diff)
        server.model_aggregation(clients_diff=clients_weight)
        acc, loss = server.model_eval()
        accuracy.append(acc)
        losses.append(loss)
        print(f"Epoch {e:d}, acc: {acc:f}, loss: {loss:f}\n")
```
4. Model and dataset parts
These two parts simply use the datasets and models that ship with torch; if you want to use your own model or dataset, write them exactly as you would in ordinary PyTorch.
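As an illustration only (not necessarily what the repo's datasets module does), a get_dataset built directly on torchvision could look like this:

```python
# Sketch of a torchvision-based get_dataset returning (train_dataset, eval_dataset)
from torchvision import datasets, transforms

def get_dataset(data_dir, name):
    if name == "mnist":
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        train = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
        eval_ = datasets.MNIST(data_dir, train=False, transform=transform)
    elif name == "cifar10":
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        train = datasets.CIFAR10(data_dir, train=True, download=True, transform=transform)
        eval_ = datasets.CIFAR10(data_dir, train=False, transform=transform)
    return train, eval_
```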
The full code is available in the federated learning code framework repository; if it helps you, a like, favorite, and follow would be much appreciated~