【FATE联邦学习】自定义数据集自定义神经网络模型下的横向纵向训练

news2024/9/23 5:30:20

前言

代码大部分来自

  • https://fate.readthedocs.io/en/latest/tutorial/pipeline/nn_tutorial/Hetero-NN-Customize-Dataset/#example-implement-a-simple-image-dataset
  • https://fate.readthedocs.io/en/latest/tutorial/pipeline/nn_tutorial/Homo-NN-Customize-your-Dataset/

但是官方的文档不完整,我自己记录完整一下。

我用的是mnist数据集https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/mnist.zip。
目录结构如下,横向的话,加载两次mnist就可以,而纵向一方加载mnist_guest带标签,一方加载mnist_host没有标签。mnist12两个文件夹没有用,不用管。
在这里插入图片描述

由于官方demo中的需要使用jupyter,不适合普通Python代码,本文给出此例子。在Python的解释器上,要注意在环境变量里加入FATE的安装包里的bin/init_env.sh里面的Python解释器路径,否则federatedml库会找不到。

横向

自定义数据集

自定义数据集,然后再本地测试一下。

import os
from torchvision.datasets import ImageFolder
from torchvision import transforms
from federatedml.nn.dataset.base import Dataset

class MNISTDataset(Dataset):
    
    def __init__(self, flatten_feature=False): # flatten feature or not 
        super(MNISTDataset, self).__init__()
        self.image_folder = None
        self.ids = None
        self.flatten_feature = flatten_feature
        
    def load(self, path):  # read data from path, and set sample ids
        # read using ImageFolder
        self.image_folder = ImageFolder(root=path, transform=transforms.Compose([transforms.ToTensor()]))
        # filename as the image id
        ids = []
        for image_name in self.image_folder.imgs:
            ids.append(image_name[0].split('/')[-1].replace('.jpg', ''))
        self.ids = ids
        return self

    def get_sample_ids(self):  # implement the get sample id interface, simply return ids
        return self.ids
    
    def __len__(self,):  # return the length of the dataset
        return len(self.image_folder)
    
    def __getitem__(self, idx): # get item
        ret = self.image_folder[idx]
        if self.flatten_feature:
            img = ret[0][0].flatten() # return flatten tensor 784-dim
            return img, ret[1] # return tensor and label
        else:
            return ret



ds = MNISTDataset(flatten_feature=True)
ds.load('mnist/')
# print(len(ds))
# print(ds[0])
# print(ds.get_sample_ids()[0])

成功输出后,要手动在FAET/federatedml.nn.datasets下新建数据集文件,把上文的代码扩充成组件类的形式,如下

import torch
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np
# 这里的包缺什么补什么哈

class MNISTDataset(Dataset):
    
    def __init__(self, flatten_feature=False): # flatten feature or not 
        super(MNISTDataset, self).__init__()
        self.image_folder = None
        self.ids = None
        self.flatten_feature = flatten_feature
        
    def load(self, path):  # read data from path, and set sample ids
        # read using ImageFolder
        self.image_folder = ImageFolder(root=path, transform=transforms.Compose([transforms.ToTensor()]))
        # filename as the image id
        ids = []
        for image_name in self.image_folder.imgs:
            ids.append(image_name[0].split('/')[-1].replace('.jpg', ''))
        self.ids = ids
        return self

    def get_sample_ids(self):  # implement the get sample id interface, simply return ids
        return self.ids
    
    def __len__(self,):  # return the length of the dataset
        return len(self.image_folder)
    
    def __getitem__(self, idx): # get item
        ret = self.image_folder[idx]
        if self.flatten_feature:
            img = ret[0][0].flatten() # return flatten tensor 784-dim
            return img, ret[1] # return tensor and label
        else:
            return ret


if __name__ == '__main__':
    pass

这样就完成了他官方文档所谓的“手动添加”了。添加后federatedml的目录应该是这样的在这里插入图片描述文件名称要和下文的dataset param对应
添加后,FATE就“认识”我们自建的数据集了。
下文中的local test是不需要做手动添加的步骤的,但是local只是做个测试。生产中没什么用……

横向训练

import os
from torchvision.datasets import ImageFolder
from torchvision import transforms
from federatedml.nn.dataset.base import Dataset

# test local process
# from federatedml.nn.homo.trainer.fedavg_trainer import FedAVGTrainer
# trainer = FedAVGTrainer(epochs=3, batch_size=256, shuffle=True, data_loader_worker=8, pin_memory=False) # set parameter

# trainer.local_mode() 

# import torch as t
# from pipeline import fate_torch_hook
# fate_torch_hook(t)
# # our simple classification model:
# model = t.nn.Sequential(
#     t.nn.Linear(784, 32),
#     t.nn.ReLU(),
#     t.nn.Linear(32, 10),
#     t.nn.Softmax(dim=1)
# )

# trainer.set_model(model) # set model

# optimizer = t.optim.Adam(model.parameters(), lr=0.01)  # optimizer
# loss = t.nn.CrossEntropyLoss()  # loss function
# trainer.train(train_set=ds, optimizer=optimizer, loss=loss)  # use dataset we just developed



# 必须在federatedml.nn.datasets目录下  手动加入新的数据集的信息!https://blog.csdn.net/Yonggie/article/details/129404212
# real training
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model

t = fate_torch_hook(t)

import os
# bind data path to name & namespace
fate_project_path = os.path.abspath('./')
host = 1
guest = 2
arbiter = 3
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,
                                                                            arbiter=arbiter)

data_0 = {"name": "mnist_guest", "namespace": "experiment"}
data_1 = {"name": "mnist_host", "namespace": "experiment"}

data_path_0 = fate_project_path + '/mnist'
data_path_1 = fate_project_path + '/mnist'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)



reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)



from pipeline.component.nn import DatasetParam

dataset_param = DatasetParam(dataset_name='mnist_dataset', flatten_feature=True)  # specify dataset, and its init parameters


from pipeline.component.homo_nn import TrainerParam  # Interface

# our simple classification model:
model = t.nn.Sequential(
    t.nn.Linear(784, 32),
    t.nn.ReLU(),
    t.nn.Linear(32, 10),
    t.nn.Softmax(dim=1)
)

nn_component = HomoNN(name='nn_0',
                      model=model, # model
                      loss=t.nn.CrossEntropyLoss(),  # loss
                      optimizer=t.optim.Adam(model.parameters(), lr=0.01), # optimizer
                      dataset=dataset_param,  # dataset
                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=2, batch_size=1024, validation_freqs=1),
                      torch_seed=100 # random seed
                      )


pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))


pipeline.compile()
pipeline.fit()


# print result and summary
pipeline.get_component('nn_0').get_output_data()
pipeline.get_component('nn_0').get_summary()

纵向

会用到mnist_host和mnist guest,下载

guest data: https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/mnist_guest.zip

host data: https://webank-ai-1251170195.cos.ap-guangzhou.myqcloud.com/fate/examples/data/mnist_host.zip

你查看一下数据集的格式。FATE里面的纵向,都是一方有标签,一方没标签,跟我所认知的合并数据集那种场景有差别。

纵向数据集

做法参考横向那里,我这里只给出新建的类的代码,跟横向的有一点点差别。

import torch
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms
import numpy as np

class MNISTDataset(Dataset):
    
    def __init__(self, return_label=True):  
        super(MNISTDataset, self).__init__() 
        self.return_label = return_label
        self.image_folder = None
        self.ids = None
        
    def load(self, path):  
        
        self.image_folder = ImageFolder(root=path, transform=transforms.Compose([transforms.ToTensor()]))
        ids = []
        for image_name in self.image_folder.imgs:
            ids.append(image_name[0].split('/')[-1].replace('.jpg', ''))
        self.ids = ids

        return self

    def get_sample_ids(self, ):
        return self.ids
        
    def get_classes(self, ):
        return np.unique(self.image_folder.targets).tolist()
    
    def __len__(self,):  
        return len(self.image_folder)
    
    def __getitem__(self, idx): # get item 
        ret = self.image_folder[idx]
        img = ret[0][0].flatten() # flatten tensor 784 dims
        if self.return_label:
            return img, ret[1] # img & label
        else:
            return img # no label, for host

if __name__ == '__main__':
    pass

在这里插入图片描述

纵向训练

详细的注释都放在里面了。

import numpy as np
from federatedml.nn.dataset.base import Dataset
from torchvision.datasets import ImageFolder
from torchvision import transforms



# 本地定义的
# class MNISTHetero(Dataset):
    
#     def __init__(self, return_label=True):  
#         super(MNISTHetero, self).__init__() 
#         self.return_label = return_label
#         self.image_folder = None
#         self.ids = None
        
#     def load(self, path):  
        
#         self.image_folder = ImageFolder(root=path, transform=transforms.Compose([transforms.ToTensor()]))
#         ids = []
#         for image_name in self.image_folder.imgs:
#             ids.append(image_name[0].split('/')[-1].replace('.jpg', ''))
#         self.ids = ids

#         return self

#     def get_sample_ids(self, ):
#         return self.ids
        
#     def get_classes(self, ):
#         return np.unique(self.image_folder.targets).tolist()
    
#     def __len__(self,):  
#         return len(self.image_folder)
    
#     def __getitem__(self, idx): # get item 
#         ret = self.image_folder[idx]
#         img = ret[0][0].flatten() # flatten tensor 784 dims
#         if self.return_label:
#             return img, ret[1] # img & label
#         else:
#             return img # no label, for host


# test guest dataset
# ds = MNISTHetero().load('mnist_guest/')
# print(len(ds))
# print(ds[0][0]) 
# print(ds.get_classes())
# print(ds.get_sample_ids()[0: 10])

# test host dataset
# ds = MNISTHetero(return_label=False).load('mnist_host')
# print(len(ds))
# print(ds[0]) # no label


import os
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HeteroNN
from pipeline.component.hetero_nn import DatasetParam
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model

fate_torch_hook(t)

# bind path to fate name&namespace
fate_project_path = os.path.abspath('./')
guest = 4
host = 3

pipeline_img = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host)

guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}

guest_data_path = fate_project_path + '/mnist_guest'
host_data_path = fate_project_path + '/mnist_host'
pipeline_img.bind_table(name='mnist_guest', namespace='experiment', path=guest_data_path)
pipeline_img.bind_table(name='mnist_host', namespace='experiment', path=host_data_path)



guest_data = {"name": "mnist_guest", "namespace": "experiment"}
host_data = {"name": "mnist_host", "namespace": "experiment"}
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=guest_data)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=host_data)


# 这里为什么要这样定义,可以看文档,有模型https://fate.readthedocs.io/en/latest/federatedml_component/hetero_nn/
hetero_nn_0 = HeteroNN(name="hetero_nn_0", epochs=3,
                       interactive_layer_lr=0.01, batch_size=512, task_type='classification', seed=100
                       )
guest_nn_0 = hetero_nn_0.get_party_instance(role='guest', party_id=guest)
host_nn_0 = hetero_nn_0.get_party_instance(role='host', party_id=host)

# define model
# image features 784, guest bottom model
# our simple classification model:
guest_bottom = t.nn.Sequential(
    t.nn.Linear(784, 8),
    t.nn.ReLU()
)

# image features 784, host bottom model
host_bottom = t.nn.Sequential(
    t.nn.Linear(784, 8),
    t.nn.ReLU()
)

# Top Model, a classifier
guest_top = t.nn.Sequential(
    nn.Linear(8, 10),
    nn.Softmax(dim=1)
)

# interactive layer define
interactive_layer = t.nn.InteractiveLayer(out_dim=8, guest_dim=8, host_dim=8)

# add models, 根据文档定义,guest要add2个,host只有一个
guest_nn_0.add_top_model(guest_top)
guest_nn_0.add_bottom_model(guest_bottom)
host_nn_0.add_bottom_model(host_bottom)

# opt, loss
optimizer = t.optim.Adam(lr=0.01) 
loss = t.nn.CrossEntropyLoss()

# use DatasetParam to specify dataset and pass parameters
# 注意和你手动加入的文件库名字要对应
guest_nn_0.add_dataset(DatasetParam(dataset_name='mnist_hetero', return_label=True))
host_nn_0.add_dataset(DatasetParam(dataset_name='mnist_hetero', return_label=False))

hetero_nn_0.set_interactive_layer(interactive_layer)
hetero_nn_0.compile(optimizer=optimizer, loss=loss)




pipeline_img.add_component(reader_0)
pipeline_img.add_component(hetero_nn_0, data=Data(train_data=reader_0.output.data))
pipeline_img.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=hetero_nn_0.output.data))
pipeline_img.compile()

pipeline_img.fit()


# 获得结果
pipeline_img.get_component('hetero_nn_0').get_output_data()  # get result

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/400529.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[沧海月明珠有泪]两数求和

自己写的像屎山一样的代码,实在难以入眼。学习下人家优秀的代码思想粗看这个代码,用的是递归的思想前面三行的意思:初始化两个链表。第四行:把两个链表的值相加赋给sum第五行:判断是否大于9第六行:如果l1的…

map 、multimap

目录 1.基本概念,键值对 2.map的构造和赋值 3.map的大小和交换,size,empty,swap 4.map的插入和删除,insert(make_pair),clear,erase,[]利用key访问vale 5.map的查找和统计,find,count 6.map容器的排序,自定义排序,仿函数 6.…

5MW风电永磁直驱发电机-1200V直流并网MATLAB仿真模型

MATLAB2016b运行。主体模型:风机传动模块、PMSG模块、蓄电池模块、超级电容模块、无穷大电源。蓄电池控制、风机控制、逆变器控制。风机输出功率:直流母线电压:逆变器输出电压:逆变器输出电流:混合储能荷电状态&#x…

2023年金三银四跳槽季,阿里巴巴 Java10W 字面经,首次公布

Java 面试 “金三银四”这个字眼对于程序员应该是再熟悉不过的了,每年的金三银四都会有很多程序员找工作、跳槽等一系列的安排。说实话,面试中 7 分靠能力,3 分靠技能;在刚开始的时候介绍项目都是技能中的重中之重,它…

超图iServer扩展开发记录Restlet 1

在“REST 服务发布机制简述”中,讲述了 REST 服务发布的过程,资源的信息保存在资源配置文件里,并通过 REST 应用上下文传递给 REST 应用对象,从而在 HTTP 请求到达 REST 应用对象的时候,能够找到合适的资源实现来处理。…

Jwt简介

目录前言What is JSON Web Token?When should you use JSON Web Tokens?What is the JSON Web Token structure?HeaderPayloadSignaturePutting all togetherHow do JSON Web Tokens work?Why should we use JSON Web Tokens?前言 技术文档这种东西,我一直认为…

数枝营销与纷享销客达成战略合作,共同推动B2B企业营与销一体化

近日,营销咨询与数字化服务商数枝营销同国内知名SaaS CRM厂商纷享销客举行了战略合作签约仪式,双方就促进B2B企业的“营与销协同增长”将展开全面合作。纷享销客创始人兼CEO罗旭与数枝营销创始人黄海钧 另据工商信息显示,数枝营销&#xff08…

webshell管理工具-菜刀的管理操作

什么是webshell Webshell是一种运行在Web服务器上的脚本程序,通常由黑客使用来绕过服务器安全措施和获取对受攻击服务器的控制权。Webshell通常是通过利用Web应用程序中的漏洞或者弱密码等安全问题而被植入到服务器上的。 一旦Webshell被植入到服务器上&#xff0…

基于应用理解的协议栈优化

作者:余兵 移动互联网时代,不同的应用追求的产品体验差异性很大。 应用商店和图片等下载类型业务追求速度、越快越好,短视频关注起播、拖拽响应速度和观看过程卡不卡,直播追求画质清晰、高码率和直播过程流畅;而游戏则…

苹果iPhone屏下Touch ID技术专利获批,苹果Find My技术大火

根据美国商标和专利局(USPTO)公示的最新清单,苹果近日获得了屏下 Touch ID 的新技术专利。专利中重点提及了“短波红外线”技术,相关元件位于屏幕下方或者集成到屏幕内。 该专利主要介绍了应用于屏幕 Touch ID 的光学成像系统&…

自动化工具selenium(一)

一)什么是自动化?为什么要做自动化? 自动化测试可以代替一部分手工测试,不能够完全代替手工测试 1)自动化测试相比于手工测试来说人力的投入和时间的投入是非常非常少的,自动化测试能够提高测试效率 2)在回归测试里面,…

被隐藏的过程——预处理

文章目录0. 前言1. 程序的翻译环境和执行环境2. 被隐藏的过程2.1 翻译环境2.2 编译3.2.1 预编译3.2.2 编译2.2.3 汇编2.3 链接2.4 运行环境3. 预处理3.1 预定义符号3.2 #define3.2.1 #define定义标识符3.2.2 #define定义宏3.2.3 #define替换规则3.2.4 #和##3.2.5 带副作用的宏参…

API Gateway vs Load Balancer:选择适合你的网络流量管理组件

本文从对比了 API Gateway 和 Load Balancer 的功能区别,帮助读者更好地了解他们在系统架构中扮演的角色。 作者陈泵,API7.ai 技术工程师。 原文链接 由于互联网技术的发展,网络数据的请求数节节攀升,这使得服务器承受的压力越来…

vue-virtual-scroll-list虚拟列表

当DOM中渲染的列表数据过多时,页面会非常卡顿,非常占用浏览器内存。可以使用虚拟列表来解决这个问题,即使有成百上千条数据,页面DOM元素始终控制在指定数量。 一、参考文档 https://www.npmjs.com/package/vue-virtual-scroll-li…

Web前端学习:章三 -- JavaScript预热(三)

六九:函数的变量提升 函数的变量提升没有var高,var是最高的。 先提var,再提函数 解析: 1、4行打印之前没有定义变量,预解析触发变量提升 2、先提var,再提函数。所以先把var提升到最上面,然后提…

【蓝牙系列】蓝牙5.4到底更新了什么(2)

【蓝牙系列】蓝牙5.4到底更新了什么(2) 一、 背景 上一篇文章讲了蓝牙5.4的PAwR特征,非常适合应用在电子货架标签(ESL)领域, 但是实际应用场景中看,只有PAwR特性是不够的,如何保证广…

【latex】总结最近使用到的画图、表格及公式操作

前言 推荐使用overleaf写latex文章,内含很多会议/期刊的模板,可以直接套用。 https://www.overleaf.com下文都是在写论文过程中比较头疼的部分,有人建议我写完文章,最后再调整格式。但图片过大看起来实在是不适~ 插入图片 \beg…

5GHz 你得先认识DFS

想用Wi-Fi 5GHz?你得先认识DFS! 添加链接描述 无线网络2.4 GHz的频段,因为频道过少、使用技术过多太过拥挤,频宽性能不佳早已不是新闻。在5 GHz的频段,频道数大幅超过2.4 GHz,但其中也有一大部份是DFS频道…

【MySQL高级篇】第04章_逻辑架构

第04章_逻辑架构 1. 逻辑架构剖析 1.1 服务器处理客户端请求 首先MySQL是典型的C/S架构,即Clinet/Server 架构,服务端程序使用的mysqld。 不论客户端进程和服务器进程是采用哪种方式进行通信,最后实现的效果是:客户端进程向服…

线程池的原理

1. 为什么要用线程池降低资源消耗。通过重复利用已创建的线程降低线程创建、销毁线程造成的消耗。提高响应速度。当任务到达时,任务可以不需要等到线程创建就能立即执行。提高线程的可管理性。线程是稀缺资源,如果无限制的创建,不仅会消耗系统…