联邦学习中的模型聚合

news2025/1/9 16:46:49

目录

联邦学习中的模型聚合

1.client-server 算法

2. fully decentralized(完全去中心化)算法


联邦学习中的模型聚合

在联邦学习的情景下引入了多任务学习,其采用的手段是使每个client/task节点的训练数据分布不同,从而使各任务节点学习到不同的模型,且每个任务节点以及全局(global)的模型都由多个分量模型集成。该论文最关键与核心的地方在于将各任务节点学习到的模型进行聚合/通信,依据模型聚合方式的不同,可以将模型采用的算法分为client-server方法,和fully decentralized(完全去中心化)的方法

因为有多种任务聚合器(Aggregator)要实现,采取的措施是先实现Aggregator抽象基类,实现好一些通用方法,并规定好抽象方法的接口,然后具体的任务聚合类继承抽象基类,然后做具体的实现。

我们先来看任务聚合器(Aggregator)这一抽象基类

class Aggregator(ABC):
    r"""Aggregator的基类. `Aggregator`规定了client之间的通信"""
    def __init__(
            self,
            clients,
            global_learners_ensemble,
            log_freq,
            global_train_logger,
            global_test_logger,
            sampling_rate=1.,
            sample_with_replacement=False,
            test_clients=None,
            verbose=0,
            seed=None,
            *args,
            **kwargs
    ):

        rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
        self.rng = random.Random(rng_seed) # 随机数生成器
        self.np_rng = np.random.default_rng(rng_seed) # numpy随机数生成器

        if test_clients is None:
            test_clients = []

        self.clients = clients #  List[Client]
        self.test_clients = test_clients #  List[Client]

        self.global_learners_ensemble = global_learners_ensemble # List[Learner]
        self.device = self.global_learners_ensemble.device


        self.log_freq = log_freq
        self.verbose = verbose
        # verbose: 调整输出打印的冗余度(verbosity), 
        # `0` 表示quiet(无任何打印输出), `1` 显示日志, `2` 显示所有局部日志; 默认是 `0`
        self.global_train_logger = global_train_logger
        self.global_test_logger = global_test_logger

        self.model_dim = self.global_learners_ensemble.model_dim # #模型特征维度

        self.n_clients = len(clients)
        self.n_test_clients = len(test_clients)
        self.n_learners = len(self.global_learners_ensemble)

        # 存储为每个client分配的权重(权重为0-1之间的小数)
        self.clients_weights =\
            torch.tensor(
                [client.n_train_samples for client in self.clients],
                dtype=torch.float32
            )
        self.clients_weights = self.clients_weights / self.clients_weights.sum()

        self.sampling_rate = sampling_rate  #  clients在每一轮使用的比例,默认为`1.`
        self.sample_with_replacement = sample_with_replacement #对client进行采用是可重复还是无重复的,with_replacement=True表示可重复的,否则是不可重复的

        # 每轮迭代需要使用到的client个数
        self.n_clients_per_round = max(1, int(self.sampling_rate * self.n_clients))

        # 采样得到的client列表
        self.sampled_clients = list()

        # 记载当前的迭代通信轮数
        self.c_round = 0 
        self.write_logs()

    @abstractmethod
    def mix(self): 
        """
        该方法用于完成各client之间的权重参数与通信操作
        """
        pass

    @abstractmethod
    def update_clients(self): 
        """
        该方法用于将所有全局分量模型拷贝到各个client,相当于boardcast操作
        """
        pass

    def update_test_clients(self):
        """
        将全局(gobal)的所有分量模型都拷贝到各个client上
        """

    def write_logs(self):
        """
        对全局(global)的train和test数据集的loss和acc做记录
        需要对所有client的所有样本做累加,然后除以所有client的样本总数做平均。
        """

    def save_state(self, dir_path):
        """
        保存aggregator的模型state,。例如, `global_learners_ensemble`中每个分量模型'learner'的state字典(以`.pt`文件格式),以及`self.clients` 中每个client的 `learners_weights` (注意,这个权重不是模型内部的参数,而是进行继承的时候对各个分量模型赋予的权重,包含train和test两部分,以一个大小为n_clients(n_test_clients)× n_learners的numpy数组的格式,即`.npy` 文件)。
        """

    def load_state(self, dir_path):
        """
        加载aggregator的模型state,即save_state方法里保存的那些
        """

    def sample_clients(self):
        """
        对clients进行采样,
        如果self.sample_with_replacement为True,则为可重复采样,
        否则,则为不可重复采用。
        最终得到一个clients子集列表并赋予self.sampled_clients
        """

1.client-server 算法

这种方式的通信/聚合方法也称中心化(centralized)方法,因为该方法在每一轮迭代最后将所有client的权重数据汇集到server节点。这种方法的优化迭代部分的伪代码示意如下:

 

落实到具体代码实现上,这种方法的Aggregator设计如下:

class CentralizedAggregator(Aggregator):
    r""" 标准的中心化Aggreagator
    所有clients在每一轮迭代末和average client完全同步.
    """
    def mix(self):
        self.sample_clients()

        # 对self.sampled_clients中每个client的参数进行优化
        for client in self.sampled_clients:
            # 相当于伪代码第11行调用的LocalSolver函数
            client.step()

        # 遍历global模型(self.global_learners_ensemble) 中每一个分量模型(learner)
        # 相当于伪代码第13行
        for learner_id, learner in enumerate(self.global_learners_ensemble):
            # 获取所有client中对应learner_id的分量模型
            learners = [client.learners_ensemble[learner_id] for client in self.clients]
            # global模型的分量模型为所有client对应分量模型取平均,相当于伪代码第14行
            average_learners(learners, learner, weights=self.clients_weights)

        # 将更新后的模型赋予所有clients,相当于伪代码第5行的boardcast操作
        self.update_clients()

        # 通信轮数+1
        self.c_round += 1

        if self.c_round % self.log_freq == 0:
            self.write_logs()

    def update_clients(self):
        """
        此函数负责将所有全局分量模型拷贝到各个client,相当于伪代码中第5行的boardcast操作
        """
        for client in self.clients:
            for learner_id, learner in enumerate(client.learners_ensemble):
                copy_model(learner.model, self.global_learners_ensemble[learner_id].model)

                if callable(getattr(learner.optimizer, "set_initial_params", None)):
                    learner.optimizer.set_initial_params(
                        self.global_learners_ensemble[learner_id].model.parameters()
                    )

2. fully decentralized(完全去中心化)算法

这种方法之所以被称为去中心化的,因为该方法在每一轮迭代不需要所有client的权重数据汇集到一个特定的server节点,而只需要完成每个节点和其邻居进行通信(参数共享)即可。这种方法的优化迭代部分的伪代码示意如下:

落实到具体代码实现上,这种方法的Aggregator设计如下:

 

class DecentralizedAggregator(Aggregator):
    def __init__(
            self,
            clients,
            global_learners_ensemble,
            mixing_matrix,
            log_freq,
            global_train_logger,
            global_test_logger,
            sampling_rate=1.,
            sample_with_replacement=True,
            test_clients=None,
            verbose=0,
            seed=None):

        super(DecentralizedAggregator, self).__init__(
            clients=clients,
            global_learners_ensemble=global_learners_ensemble,
            log_freq=log_freq,
            global_train_logger=global_train_logger,
            global_test_logger=global_test_logger,
            sampling_rate=sampling_rate,
            sample_with_replacement=sample_with_replacement,
            test_clients=test_clients,
            verbose=verbose,
            seed=seed
        )

        self.mixing_matrix = mixing_matrix
        assert self.sampling_rate >= 1, "partial sampling is not supported with DecentralizedAggregator"

    def update_clients(self):
        pass

    def mix(self):
        
        # 对各clients的模型参数进行优化
        for client in self.clients:
            client.step()

        # 存储每个模型各参数混合的权重
        # 行对应不同的client,列对应单个模型中不同的参数
        # (注意:每个分量有独立的mixing_matrix)
        mixing_matrix = torch.tensor(
            self.mixing_matrix.copy(),
            dtype=torch.float32,
            device=self.device
        )

        # 遍历global模型(self.global_learners_ensemble) 中每一个分量模型(learner)
        # 相当于伪代码第14行
        for learner_id, global_learner in enumerate(self.global_learners_ensemble):
            # 用于将指定learner_id的各client的模型state读出暂存
            state_dicts = [client.learners_ensemble[learner_id].model.state_dict() for client in self.clients]

            # 遍历global模型中的各参数, key对应模型中参数的名称
            for key, param in global_learner.model.state_dict().items():
                shape_ = param.shape
                models_params = torch.zeros(self.n_clients, int(np.prod(shape_)), device=self.device)

                for ii, sd in enumerate(state_dicts):
                    # models_params的第ii个下标存储的是第ii个client的(名为key的)参数
                    models_params[ii] = sd[key].view(1, -1) 

                # models_params的每一行是一个client的参数
                # @符号表示矩阵乘/矩阵向量乘
                # 故这里表示每个client参数是其他所有client参数的混合
                models_params = mixing_matrix @ models_params

                for ii, sd in enumerate(state_dicts):
                    # 将第ii个client的(名为key的)参数存入state_dicts中对应位置
                    sd[key] = models_params[ii].view(shape_)

            # 将更新好的参数从state_dicts存入各client节点的模型中
            for client_id, client in enumerate(self.clients):
                client.learners_ensemble[learner_id].model.load_state_dict(state_dicts[client_id])

        # 通信轮数+1
        self.c_round += 1

        if self.c_round % self.log_freq == 0:
            self.write_logs()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/696937.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[python] 进度条使用

from tqdm import tqdm# 创建一个示例字典 my_dict {a: 1, b: 2, c: 3}# 使用tqdm遍历字典的键 for key in tqdm(my_dict.keys()):# 在这里编写你的代码# 这部分代码将会在进度条中显示pass# 使用tqdm遍历字典的值 for value in tqdm(my_dict.values()):# 在这里编写你的代码#…

查看 git的 config 配置

git config --list // 查看全部配置信息git config user.name // 查看指定配置信息 查看某一个配置信息 git config --global user.email 参考 如何查看gitconfig配置_笔记大全_设计学院

牛客BM21 旋转数组的最小数字

描述 有一个长度为 n 的非降序数组,比如[1,2,3,4,5],将它进行旋转,即把一个数组最开始的若干个元素搬到数组的末尾,变成一个旋转数组,比如变成了[3,4,5,1,2],或者[4,5,1,2,3]这样的。请问,给定…

IDEA远程DeBug调试

1. 介绍 当我们在开发过程中遇到一些复杂的问题或需要对代码进行调试时,远程调试是一种非常有用的工具。使用 IntelliJ IDEA 进行远程调试可以让你在远程服务器上的应用程序中设置断点、查看变量和执行调试操作。 远程调试的好处如下: 提供更方便的调试…

大众汽车车载娱乐系统曝安全漏洞,可被远程控制

根据GitHub的一份报告,大众汽车Discover Media信息娱乐系统的漏洞是在2023年2月28日发现的。 该漏洞可能会使未打补丁的系统遭到拒绝服务(DoS)攻击。该漏洞起初是由大众汽车的用户发现的,随后大众汽车方面确认了该漏洞&#xff0…

Golang 一个支持错误堆栈, 错误码, 错误链的工具库

介绍 来腾讯之后主要使用go, 在业务开发中需要一个支持错误码对外返回, 堆栈打印等能力的错误工具库, 先开始使用pkg/errors, 但该项目已经只读, 上次更新是几年前, 而且有一些点比如调整堆栈深度等没有支持, 后续根据业务的需要抽取了一个通用库, 且做了一些优化, 详见下方. …

Apikit 自学日记:发起文档测试-RPC

以DUBBO接口为例,进入某个DUBBO协议的API文档详情页,点击文档上方 测试 标签,即可进入 API 测试页,系统会根据API文档的定义的请求报文自动生成测试界面并且填充测试数据。 对RPC/DUBBO接口发起测试 填写请求报文参数值 此测试D…

Spring(8) Springboot自动配置原理

目录 1.背景2.SpringBootApplication 注解3.三大注解4.EnableAutoConfiguration 注解5.spring.factories6.示例:RedisAutoConfiguration 类 1.背景 Springboot 的自动配置原理,是Springboot中最高频的一道面试题,也是Springboot框架最核心的…

react antd 样式修改

最近在做一个大数据的大屏ui更改,使用的是antd,需要根据ui稿调很多的antd组件样式 特做一个样式修改记录,也给需要的人一些帮助 我们修改的有以下样式: 如何改呢: /*修改 antd 组件样式 */// 仅 drop 下的下拉框改变样…

Linux Host is not allowed to connect to this MySQL server解决方法

先说说这个错误,其实就是我们的MySQL不允许远程登录,所以远程登录失败了,解决方法如下: 在装有MySQL的机器上登录MySQL mysql -u root -p密码 执行use mysql; 执行update user set host % where user root;这一句执行完可能会报…

PoseiSwap IDO、IEO 结束,即将登录 BNB Chain

PoseiSwap 是 Nautilus Chain 上的首个 DEX,其正在基于模块化 Layer3 架构底层,以及Nautilus Chain 所提供的 ZKP 来构建属于自己的 Rollup 应用层,并以订单簿作为交易模型,这为其向更多的功能进行拓展提供了早期基础。

如何打开Windows11上自带安装unbunt系统

首先你看到在你电脑上有一个这样 如果直接鼠标点击打开或者使用powershell打开,也可以打开,但发现只是一堆文件夹而已 正确打开方式,使用unbunt-LTS,打开,这个在哪里? 你可以在电脑Microsoft store输入 u…

春秋云镜cve-2022-32991wp

首先看靶标介绍:该CMS的welcome.php中存在SQL注入攻击 访问此场景,为登录界面,可注册,注册并登陆后找可能存在sql注入的参数,尝试在各个参数后若加一个单引号报错,加两个单引号不报错,说明此参…

CentOS7安装Nginx(tar包安装)

一. 安装环境 操作系统:Centos 7. 最小化安装 服务器地址:*** 二.安装过程: 1. 安装wget yum install wget -y 2. 下载Nginx wget http://nginx.org/download/nginx-1.25.1.tar.gz 官网下载 Nginx: http://nginx.org/en/down…

使用el-menu做侧边栏导航遇到需要点击两次菜单才展开

在根据路由遍历生成侧边导航栏时,遇到一个问题,就是当我点击选中某个垂直菜单时,只有点击第二次它才会展开,第一次在选中垂直菜单之后垂直菜单它就收缩起来了,如下图: 如上图,在我第一次点击选…

Gitlab升级报错二:rails_migration[gitlab-rails] (gitlab::database_migrations line 51)

gitlab-ctl 修改文件目录后出现以下错误:从root --> home 先停掉gitlab: gitlab-ctl stop 单独启动数据库,如果不单独启动数据库,就会报以上错误 sudo gitlab-ctl start postgresql 解决办法: sudo gitlab-rake db:migrat…

Interactive Linear Algebra:免费的交互式线性代数学习教程

本文介绍一个学习线性代数的网站,该网站通过将线性代数中的数学规则可视化,更直观的展示线性代数的运算过程。该网站可以帮助我们更快更高效的学习线性代数。如果有考研的同学或者觉得学习线性代数很枯燥或者很困难的同学,可以了解该网站&…

XILINX 7系列FPGA封装兼容原则及同封装替换注意问题

🏡《电子元器件高级指南》 目录 1,概述2,封装兼容原则3,注意问题4,总结 1,概述 XILINX 7系列的FPGA同封装的元器件一般都是可以兼容的,在一定程度上可以做到PIN TO PIN的替换,本文介…

Windows服务器部署项目自启动

1.下载jar包 https://github.com/kohsuke/winsw 2.重命名 3. 编辑xml文件 <configuration> <id>MyApp</id> <name>MyApp</name> <description>This is MyApp.</description><executable>java</executable> <argum…

NTP服务设置开机自启启动失败

文章目录 前言一、NTP服务设置开机自启启动失败原因二、解决办法 前言 Linux服务器设置了ntpd开机自启动&#xff0c;重启服务器ntpd却没有自启动 一、NTP服务设置开机自启启动失败原因 原因&#xff1a;chrony服务与NTP服务冲突导致开机启动未生效 二、解决办法 关闭chrony服务…