【联邦学习】联邦学习量化——non-iid数据集下的仿真

news2025/1/11 11:46:40

文章目录

    • 改进项目背景
    • 量化函数的改进
    • non-iid数据集的设置
      • Fedlab划分数据集的踩雷

改进项目背景

在前面的项目中,虽然对联邦学习中,各个ue训练出来的模型上传的参数进行了量化,并仿真的相关结果。但是仍有一些俺不是非常符合场景的情况,需要改进的方向如下:

  1. 量化函数需要重写,将前面的只对小数点后进行0~1的量化改成自适应的在一段范围之内的数组量化。
  2. 信道仿真函数,在真实的通信环境中,一个信道的速率模拟可以由一个基于正态分布的初始速率,每隔一定时间加减均匀分布的变化值。
  3. 根据量化的程度不同,模型分为基础值和增量值,先传数据量较少的基础值,如果通信条件好的画
  4. 需要模仿接收方BS的接收规则:首先应该计算传输过程中的耗时,如果耗时超过了一个等待的门限,那么这个模型就不会被纳入聚合的model们里。如果接收了基础值之后,BS还会等待一段时间,如果增量值没到达,那只能用基础值去参与聚合了。这样对于BS来说就陷入一个博弈:是采用更精确的量化模型去提高自己模型的准确度呢,还是采用更少的量化程度的模型来保证在信道上能够正确传输。而我正是要仿真这样一个场景。
  5. 联邦学习框架的使用:PySyft在自己一个电脑上的仿真完全没用!完全可以抛弃框架自己写代码去模拟模型聚合与信道上的传输过程。
  6. 急需一个能够仿真出non-iid数据集的库,方便后续的仿真代码编写。

量化函数的改进

对原始数组进行线性变化,映射在一定范围内:
V q = Q × ( V x − min ⁡ ( V x ) ) V_q=Q\times(V_x-\min(V_x)) Vq=Q×(Vxmin(Vx))
V x ′ = V q / Q + min ⁡ ( V x ) V_x'=V_q/Q+\min(V_x) Vx=Vq/Q+min(Vx)
Q = S / R , R = max ⁡ ( V x ) + m i n ( V x ) , S = 1 < < b i t s − 1 Q=S/R,R=\max(V_x)+min(V_x),S=1<<bits-1 Q=S/R,R=max(Vx)+min(Vx),S=1<<bits1
其中 V x V_x Vx表示原浮点数, V q V_q Vq表示量化后的定点数值, V x ′ V_x' Vx表示根据量化参数还原出的浮点数,bits为量化比特位数。
传输的时候只需要传输低比特矩阵 V q V_q Vq和参数 Q , S , R Q,S,R Q,S,R等,在接收端即可还原成浮点数。

总而言之,举例:一个正弦函数值数组,经过4bit量化后呈现如下效果:
在这里插入图片描述
而神经网络中同一个层的tensor,数值分布恰好在同一个数量范围内,适合这样的数组量化:我们选择这个tensor中的最大值和最小值,以此为范围进行量化。

#4bit量化前:
tensor([ 0.0201,  0.0059,  0.0153, -0.0319, -0.0419,  0.0025, -0.0467, -0.0022,
         0.0106,  0.0512, -0.0321, -0.0190, -0.0409,  0.0128,  0.0191,  0.0479,
        -0.0289, -0.0515, -0.0237, -0.0473, -0.0420, -0.0156, -0.0371,  0.0184,
         0.0014,  0.0103, -0.0436, -0.0375,  0.0042, -0.0070,  0.0027,  0.0168])
#4bit量化后:
tensor([ 0.0215,  0.0072,  0.0143, -0.0287, -0.0430,  0.0000, -0.0502,  0.0000,
         0.0072,  0.0502, -0.0287, -0.0215, -0.0430,  0.0143,  0.0215,  0.0502,
        -0.0287, -0.0502, -0.0215, -0.0502, -0.0430, -0.0143, -0.0359,  0.0215,
         0.0000,  0.0072, -0.0430, -0.0359,  0.0072, -0.0072,  0.0000,  0.0143])

具体的函数如下:

def Quant(Vx, Q, RQM):
    return round(Q * Vx) - RQM


def QuantRevert(VxQuant, Q, RQM):
    return (VxQuant + RQM) / Q


def ListQuant(data_list, quant_bits):
    # 数组范围估计
    data_min = min(data_list)
    data_max = max(data_list)

    # 量化参数估计
    Q = ((1 << quant_bits) - 1) * 1.0 / (data_max - data_min)
    RQM = (int)(np.round(Q*data_min))

    # 产生量化后的数组
    quant_data_list = []
    for x in data_list:
        quant_data = Quant(x, Q, RQM)
        quant_data_list.append(quant_data)
    quant_data_list = np.array(quant_data_list)
    return (Q, RQM, quant_data_list)


def ListQuantRevert(quant_data_list, Q, RQM):
    quant_revert_data_list = []
    for quant_data in quant_data_list:
        # 量化数据还原为原始浮点数据
        revert_quant_data = QuantRevert(quant_data, Q, RQM)
        quant_revert_data_list.append(revert_quant_data)
    quant_revert_data_list = np.array(quant_revert_data_list)
    return quant_revert_data_list

non-iid数据集的设置

信道变化->BS接收到ue的模型数量变化->聚合时用于平均的model数量变化

⬆️只有数据集是non-iid的时候,model数量变化才能明显表现出对性能的影响。
如果各个ue在相同数据集上训练相同batch,再进行聚合平均,聚合的model数量对性能影响不大。
我寻找到了一个由本校学长参与开发的一个联邦学习函数库Fedlab,除了数据集的处理,库还提供了别的很多在联邦学习中非常有用的函数,如BS和客户机的交流通信函数等,在这里把Github上的repo贴一下:
https://github.com/SMILELab-FL

点开才发现,这个repo居然是同校计算机学院的一位博士学长创建和维护的,后来还在飞书上联系到了他。各位如果有兴趣的话,非常建议在repo的issue上提出问题,他们都会即使解答的。
另外,如果不想下载Fedlab这个库,或者对Dirichlet划分是数学原理感兴趣的,可以参考下面这个:
https://zhuanlan.zhihu.com/p/468992765
按Dirichlet分布划分Non-IID数据集
在这里插入图片描述
由于Dataloader在每次加载时数据的索引不变,因此在多轮测试的时候,每个ue上的数据分布不会变(区别于完全随机)。

Fedlab划分数据集的踩雷

一开始,我是直接按照这个文档来的:https://zhuanlan.zhihu.com/p/411308268,刚好我也需要采用CIFAR10数据集,但是在其中有这样一句导入包:from fedlab.utils.dataset.sampler import SubsetSampler,是错误的,检查源码也发现dataset里面根本就没有sampler,这让我十分抓狂。后面询问之后才知道,原来sampler的效率太低了,他们已经在新版本放弃不用了,新的划分方案直接看github中的tutorial文件夹部分。于是我找到了如下:
partitioned_cifar10的用法

class PartitionedCIFAR10(root, path, dataname, num_clients, download=True, preprocess=False, balance=True, partition=‘iid’, unbalance_sgm=0, num_shards=None, dir_alpha=None, verbose=True, seed=None, transform=None, target_transform=None)

我们就需要先实例化这一个类,然后利用这个类提供的几个函数来实现数据集的划分与加载。这个类有如此多的参数,那么具体每个参数什么含义,我们在使用的时候又该如何设置呢?

  • root (str) – Path to download raw dataset. 和pytorch的datasets一样,填'/cifar10'

  • path (str) – Path to save partitioned subdataset.预训练好的.pkl文件名,我填'/cifar10_hetero_dir.pkl'

  • dataname (str) – “cifar10” or “cifar100”填‘/cifar10’

  • num_clients (int) – Number of clients.要分成几份,对应ue的个数

  • download (bool) – Whether to download the raw dataset.同pytorch里的datasets

  • preprocess (bool) – Whether to preprocess the dataset.是否预划分,这个第一次必须填true,后面就可以填false了

  • balance (bool, optional) – Balanced partition over all clients or not. Default as True.false

  • partition (str, optional) – Partition type, only “iid”, shards, “dirichlet” are supported. Default as “iid”.填'dirichlet'

  • unbalance_sgm (float, optional) – Log-normal distribution variance for unbalanced data partition over clients. Default as 0 for balanced partition.可选项,没填

  • num_shards (int, optional) – Number of shards in non-iid “shards” partition. Only works if partition=“shards”. Default as None.可选项,没填

  • dir_alpha (float, optional) – Dirichlet distribution parameter for non-iid partition. Only works if partition=“dirichlet”. Default as None.0.3

  • verbose (bool, optional) – Whether to print partition process. Default as True.可选项,没填

  • seed (int, optional) – Random seed. Default as None.2022

  • transform (callable, optional) – A function/transform that takes in an PIL image and returns a transformed version.同pytorch里的datasets,对图像进行预处理,然后转化为tensor

  • target_transform (callable, optional) – A function/transform that takes in the target and transforms it.可选项,没填

其中一个非常容易错的点。PartitionedCIFAR10这个类中的preprocess实际上就是按照所选的划分模式,对整个数据集贴标签形成一个字典,标注每条数据属于哪个ue,保存在一个.pkl文件中。后面在训练加载数据的时候,就按照这个字典从数据集中取数据,就完成了non-iid的划分啦。
同时transform里面和pytorch的写法一样的,可以对图片进行大小的更改,进行normalize等操作,当然一定别忘记了必须Totensor()将图片转化为tensor,然而我发现忘记totensor,改了之后,发现还是报错!仔细以看才知道,原来是改了transform,但是忘记了重新preprocess一下,导致还是按照旧的方式去加载,自然错啦。

后面的使用,PartitionedCIFAR10提供了两个比较有用的函数;
在这里插入图片描述

他们的返回值就是pytorch中的datasetdataloader了。用法也和pytorch中的一样:

hetero = PartitionedCIFAR10(
    root='/cifar10',
    path='/cifar10_hetero_dir.pkl',
    dataname="cifar10",
    num_clients=train_args['num_clients'],
    download=False,
    preprocess=False,
    balance=False,
    partition="dirichlet",
    seed=2022,
    dir_alpha=0.3,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4750, 0.4750, 0.4750], std=[0.2008, 0.2008, 0.2008])]
    ),
    target_transform=transforms.ToTensor()
)

for id, ue in enumerate(UE_list):
    train_loader = hetero.get_dataloader(
        id, batch_size=train_args['batch_size'])
    for batch_idx, (data, target) in enumerate(train_loader):
        # if batch_idx > 100:
        #     break
        ue_data = data.to(device)
        ue_target = target.to(device)
        loss = ue.train(ue_data, ue_target)

后面的就是按照正常方式去训练啦,最后也是成功得到了一个不同bit压缩的联邦学习训练效果对比图:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/55187.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《Hive性能调优实战》读书笔记

很不错的一本书。章节划分清晰明了&#xff0c;可根据个人需要读相应的章节。Hive各个方面的知识体系都有涉及。可作为工具书&#xff0c;常读常新&#xff0c;值得翻阅。 第2章 Hive问题排查与调优思路 优化方法 PL-SQL和T-SQL经验总结&#xff1a; 通过改写SQL&#xff0…

Hust计算机组成原理实验

文章目录logisim的使用1.添加门电路2.添加引脚3.添加导线4.添加文本5.测试电路补充工具实验一&#xff0c;运算器实验8位串行可控加减法器分析一位全加器八位串行加法器即可给出8位串行可控加减法器回答问题CLA74182&#xff08;先行进位加法器&#xff09;参数公式问题位快速加…

10个最常见的JavaScript问题

如今&#xff0c;JavaScript几乎是所有现代web应用程序的核心。这就是为什么JavaScript问题以及找出导致这些问题的错误是web开发人员的首要任务。 用于单页应用程序&#xff08;SPA&#xff09;开发、图形和动画以及服务器端JavaScript平台的强大的基于JavaScript的库和框架并…

opencv c++ 二值图像、阈值计算方法、全局阈值、自适应阈值

1、图像定义&#xff1a; 彩色图像 &#xff1a;三通道&#xff0c;像素值一般为0~255&#xff1b; 灰度图像&#xff1a;单通道&#xff0c;像素值一般为0~255&#xff1b; 二值图像&#xff1a;单通道&#xff0c;像素值一般为0&#xff08;黑色&#xff09;、255&#xff08…

Python计算器(包含机制转换)

实现思路&#xff1a; 要优先处理内层括号运算&#xff0d;&#xff0d;外层括号运算&#xff0d;&#xff0d;先乘除后加减的原则&#xff1a; 1、正则处理用户输入的字符串&#xff0c;然后对其进行判断&#xff0c;判断计算公式是否有括号&#xff0c;有就先将计算公式进行…

判断二叉树是否是平衡二叉树(c#)

问题描述 给定一棵二叉树&#xff0c;判断其是否为平衡二叉树。 示例 示例1 Input: root [3,9,20,null,null,15,7] Output: true 示例2 Input: root [1,2,2,3,3,null,null,4,4] Output: false 解决方案描述 二叉树的每个节点的左子节点和右子节点的高度差小于等于1&#x…

Windows和Linux混合系统通过AD域实现用户集中认证

一、Windows AD域 1、统一认证简介 管理的Linux服务器和Windows服务器如果很多,如果都用本地用户名管理,要管理和记住几十台甚至上百台服务器的不同账号不同密码,这是很难的。但是如果所有服务器账号密码都设置一样,那又完全没有安全性可言。 什么是服务器的集中认证(统…

数据结构(8)树形结构——B树、B+树(含完整建树过程)

目录 8.1.B树 8.1.1.概述 8.1.2.完整建树过程 8.2.B树 8.1.B树 8.1.1.概述 B树存在的意义&#xff1a; 二叉树在存储数据时可能出现向一边倾斜导致查询效率降低的情况&#xff0c;为了防止二叉树的倾斜&#xff0c;出现了平衡二叉树&#xff0c;通过旋转的方式保证二叉树…

[附源码]计算机毕业设计springboot校园商铺

项目运行 环境配置&#xff1a; Jdk1.8 Tomcat7.0 Mysql HBuilderX&#xff08;Webstorm也行&#xff09; Eclispe&#xff08;IntelliJ IDEA,Eclispe,MyEclispe,Sts都支持&#xff09;。 项目技术&#xff1a; SSM mybatis Maven Vue 等等组成&#xff0c;B/S模式 M…

[附源码]Python计算机毕业设计Django基于web的羽毛球管理系统

项目运行 环境配置&#xff1a; Pychram社区版 python3.7.7 Mysql5.7 HBuilderXlist pipNavicat11Djangonodejs。 项目技术&#xff1a; django python Vue 等等组成&#xff0c;B/S模式 pychram管理等等。 环境需要 1.运行环境&#xff1a;最好是python3.7.7&#xff0c;…

学生HTML个人网页作业作品 HTML+CSS+JavaScript环保页面设计与实现制作

&#x1f380; 精彩专栏推荐&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb;&#x1f447;&#x1f3fb; ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业…

【PPT计时器】如何在wps演示PPT中使用定时器、计时器功能?不使用第三方插件,仅仅使用第三方计时器软件

一、问题背景和解决思路 很多人在展示PPT时&#xff0c;有精确的时间要求&#xff0c;比如五分钟&#xff0c;十分钟。 这时候&#xff0c;我们希望在演示的时候&#xff0c;PPT上附带一个小计时器、或者定时器。 网上有很多加定时器动画的教程&#xff0c;大多数停留在动画…

【D3.js】1.15-反转 SVG 元素

title: 【D3.js】1.15-反转 SVG 元素 date: 2022-12-02 14:07 tags: [JavaScript,CSS,HTML,D3.js,SVG] 文章目录一、学习目标二、题目三、通关代码参考更新svg坐标的y轴是在顶部的&#xff0c;即画出来的rect也是底朝上&#xff0c;如何让rect的底处于底部呢&#xff1f;一、学…

【C++初阶】STL-string的使用

文章目录一.string初识1.STL简介a.STL的组成b.STL和string的关系2.basic_string二.构造函数三.三种遍历方式四.容量相关的函数1.size()2.reserve()–调整容量3.resize()–调整size五.字符串的增删查改1.assign2.replace3.find()4.substr()5.insert()6.相关应用a.替换空格:b.取出…

【Redis-08】面试题之Redis数据结构与对象-RedisObject(上篇)

Redis本质上是一个数据结构服务器&#xff0c;使用C语言编写&#xff0c;是基于内存的一种数据结构存储系统&#xff0c;它可以用作数据库、缓存或者消息中间件。 我们经常使用的redis的数据结构有5种&#xff0c;分别是&#xff1a;string(字符串)、list(列表)、hash(哈希)、s…

string类的模拟实现

目录 一、浅拷贝、深拷贝 二、传统版本写法的String类 三、现代版本写法的String类 四、String类的模拟实现 一、浅拷贝、深拷贝 构造 //构造函数String(const char* str ""){if (nullptr str){assert(false);return;}_str new char[strlen(str) 1];strcpy(_s…

使用 Pandas 和 SQL 进行实用数据分析,让我们用 pandas 和 SQL 进行数据分析并实际理解它们(教程含数据csv)

Pandas是一种快速、强大、灵活且易于使用的开源数据分析和操作工具, 构建于 Python 编程语言之上。 SQL代表结构化查询语言。SQL 允许您从 RDBMS(关系数据库管理系统)访问数据,并可用于数据分析。 Pandas 和 SQL 都广泛用于数据分析。 在这篇博客中,我们将使用pandas和…

做好自己安全第一责任人 嘀嗒全面上线安全带智能语音提醒

2022年12月2日是第十一个“全国交通安全日”&#xff0c;今年主题为“文明守法 平安回家”。 当天&#xff0c;嘀嗒出行启动主题为“共建三方安全观&#xff0c;安全要靠你我他”共塑行动&#xff0c;倡导平台、用户、行业各方形成合力&#xff0c;共塑共创安全文明的新出行之路…

简单的PCI总线INTx中断实现流程

一个简单的PCI总线INTx中断实现流程,如下图所示。 1. 首先,PCI设备通过INTx边带信号产生中断请求,经过中断控制器(Interrupt Controller,PIC)后,转换为INTR信号,并直接发送至CPU; 2. CPU收到INTR信号置位后,意识到了中断请求的发生,但是此时并不知道是什么中断请求…

记一次 .NET 某电子厂OA系统 非托管内存泄露分析

一&#xff1a;背景 1.讲故事 这周有个朋友找到我&#xff0c;说他的程序出现了内存缓慢增长&#xff0c;没有回头的趋势&#xff0c;让我帮忙看下到底怎么回事&#xff0c;据朋友说这个问题已经困扰他快一周了&#xff0c;还是没能找到最终的问题&#xff0c;看样子这个问题…