【FATE Federated Learning】get output data error: output_dir = result["directory"] KeyError: 'directory'

2024/12/26 12:12:59

The error message is vague and gives very little to go on.

After several weeks of troubleshooting, I found the following causes:

  • The predict function in the custom trainer does not return a valid return value.
  • Possibly, the custom network does not end with a softmax. (If it doesn't, just add one.)

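Fixing either one of the two should be enough. Many networks are not classification models, and for those a properly written predict alone should keep this error from appearing.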
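For the second cause, "ending the network with softmax" means turning the model's raw outputs (logits) into per-class probabilities that are non-negative and sum to 1. The sketch below shows that transformation in plain Python so it runs without PyTorch; in an actual PyTorch model the usual approach is to append `torch.nn.Softmax(dim=1)` as the last layer (for example inside an `nn.Sequential`). That FATE needs this is the author's empirical observation; the sketch only illustrates what the extra layer computes.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    # subtract the largest logit so math.exp() cannot overflow
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_rows(batch: List[List[float]]) -> List[List[float]]:
    """Apply softmax to every row, like torch.nn.Softmax(dim=1) on a 2-D tensor."""
    return [softmax(row) for row in batch]


if __name__ == "__main__":
    probs = softmax_rows([[2.0, 1.0, 0.1], [1000.0, 1000.0, 0.0]])
    for row in probs:
        # each row is a probability distribution over the classes
        print([round(p, 3) for p in row], round(sum(row), 6))
```

Subtracting the row maximum does not change the result mathematically, but it keeps large logits such as 1000.0 from overflowing.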

_predict should look like this (you can use the FedAvgTrainer code as a reference):

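def _predict(self, dataset: Dataset):
    # switch to eval mode
    dataset.eval()
    self.model.eval()

    # Placeholder outputs: one random score and one integer label per sample.
    # The values are meaningless (torch.rand(...).int() is always 0), but the
    # (sample ids, scores, labels) structure matches FATE's interface.
    sample_ids = dataset.get_sample_ids()
    length = len(sample_ids)
    ret_rs = torch.rand(length, 1)
    ret_label = torch.rand(length, 1).int()

    return sample_ids, ret_rs, ret_label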

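This is a throwaway function I wrote. Its contents are meaningless, but it conforms to the FATE framework's interface, and once I added it, get output data worked and the output was displayed in FATEBoard.
So the _predict function is evidently called somewhere we don't see (or at least can't easily find by hand), and it has to return data in a specific format.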
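The dummy function works because it returns three things of matching length: the sample ids, an (n, 1) score tensor and an (n, 1) integer label tensor. For a segmentation model like the SATrainer below, one hypothetical way to return meaningful values instead of random ones is to collapse each predicted probability map into a single score (here, the mean foreground probability) and threshold it into a label. The helper names `to_predict_output` and `check_predict_output`, the mean-probability score and the 0.5 threshold are all assumptions for illustration, not FATE APIs; plain Python lists stand in for tensors so the sketch runs on its own.

```python
from typing import List, Sequence, Tuple

# A probability map is a 2-D grid of per-pixel foreground probabilities.
ProbMap = Sequence[Sequence[float]]


def to_predict_output(
    sample_ids: Sequence[str],
    prob_maps: Sequence[ProbMap],
    threshold: float = 0.5,
) -> Tuple[List[str], List[List[float]], List[List[int]]]:
    """Hypothetical helper: reduce each probability map to one score and one label."""
    if len(sample_ids) != len(prob_maps):
        raise ValueError("need exactly one prediction per sample id")
    scores, labels = [], []
    for prob_map in prob_maps:
        pixels = [p for row in prob_map for p in row]
        score = sum(pixels) / len(pixels)          # mean foreground probability
        scores.append([score])                     # shape (n, 1), like ret_rs
        labels.append([int(score >= threshold)])   # shape (n, 1), like ret_label
    return list(sample_ids), scores, labels


def check_predict_output(ids, scores, labels) -> None:
    """Hypothetical sanity check for the (ids, scores, labels) triple."""
    n = len(ids)
    if n == 0:
        raise ValueError("predict returned no samples")
    if len(scores) != n or len(labels) != n:
        raise ValueError(f"length mismatch: {n} ids, {len(scores)} scores, {len(labels)} labels")
    if any(len(row) != 1 for row in scores) or any(len(row) != 1 for row in labels):
        raise ValueError("scores and labels must each have exactly one column")
    if any(not isinstance(row[0], int) for row in labels):
        raise ValueError("labels must be integers")


if __name__ == "__main__":
    ids, scores, labels = to_predict_output(
        ["a", "b"],
        [[[0.9, 0.8], [0.7, 0.6]], [[0.1, 0.2], [0.0, 0.1]]],
    )
    check_predict_output(ids, scores, labels)
    print(ids, scores, labels)
```

Running a check like this on whatever your own predict returns, before handing it to the framework, is a cheap way to catch an empty or mismatched result early.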

Finally, here is my custom SATrainer. Notice that I never call the predict function anywhere in it myself.

import pandas as pd
from federatedml.model_base import Metric, MetricMeta
import torch.distributed as dist
from federatedml.nn.backend.utils import distributed_util
from federatedml.nn.backend.utils import deepspeed_util
import apex
import torch
import torch as t
import tqdm
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorClient as SecureAggClient
from federatedml.framework.homo.aggregator.secure_aggregator import SecureAggregatorServer as SecureAggServer
from federatedml.nn.dataset.base import Dataset
from federatedml.nn.homo.trainer.trainer_base import TrainerBase
from federatedml.util import LOGGER, consts
from federatedml.optim.convergence import converge_func_factory


class SATrainer(TrainerBase):
    """

    Parameters
    ----------
    epochs: int >0, epochs to train
    batch_size: int, -1 means full batch
    secure_aggregate: bool, default is True, whether to use secure aggregation. if enabled, will add random number
                            mask to local models. These random number masks will eventually cancel out to get 0.
    weighted_aggregation: bool, whether add weight to each local model when doing aggregation.
                         if True, according to the original paper, weight of a client is: n_local / n_global, where n_local
                         is the sample number locally and n_global is the sample number of all clients.
                         if False, simply averaging these models.

    early_stop: None, 'diff' or 'abs'. if None, disable early stop; if 'diff', use the loss difference between
                two epochs as early stop condition, if differences < tol, stop training ; if 'abs', if loss < tol,
                stop training
    tol: float, tol value for early stop

    aggregate_every_n_epoch: None or int. if None, aggregate model at the end of every epoch, if int, aggregate
                             every n epochs.
    cuda: bool, use cuda or not
    pin_memory: bool, for pytorch DataLoader
    shuffle: bool, for pytorch DataLoader
    data_loader_worker: int, for pytorch DataLoader, number of workers when loading data
    validation_freqs: None or int. if int, validate your model and send validate results to fate-board every n epoch.
                      if is binary classification task, will use metrics 'auc', 'ks', 'gain', 'lift', 'precision'
                      if is multi classification task, will use metrics 'precision', 'recall', 'accuracy'
                      if is regression task, will use metrics 'mse', 'mae', 'rmse', 'explained_variance', 'r2_score'
    checkpoint_save_freqs: save model every n epoch, if None, will not save checkpoint.
    task_type: str, 'auto', 'binary', 'multi', 'regression'
               this option decides the return format of this trainer, and the evaluation type when running validation.
               if auto, will automatically infer your task type from labels and predict results.
    """

    def __init__(self, epochs=10, batch_size=512,  # training parameter
                 early_stop=None, tol=0.0001,  # early stop parameters
                 secure_aggregate=True, weighted_aggregation=True, aggregate_every_n_epoch=None,  # federation
                 cuda=True, pin_memory=True, shuffle=True, data_loader_worker=0,  # GPU & dataloader
                 validation_freqs=None,  # validation configuration
                 checkpoint_save_freqs=None,  # checkpoint configuration
                 task_type='auto'
                 ):

        super(SATrainer, self).__init__()

        # training parameters
        self.epochs = epochs
        self.tol = tol
        self.validation_freq = validation_freqs
        self.save_freq = checkpoint_save_freqs

        self.task_type = task_type
        task_type_allow = [
            consts.BINARY,
            consts.REGRESSION,
            consts.MULTY,
            'auto']
        assert self.task_type in task_type_allow, 'task type must in {}'.format(
            task_type_allow)

        # aggregation param
        self.secure_aggregate = secure_aggregate
        self.weighted_aggregation = weighted_aggregation
        self.aggregate_every_n_epoch = aggregate_every_n_epoch

        # GPU
        self.cuda = cuda
        if not torch.cuda.is_available() and self.cuda:
            raise ValueError('Cuda is not available on this machine')

        # data loader
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.shuffle = shuffle
        self.data_loader_worker = data_loader_worker

        self.early_stop = early_stop
        early_stop_type = ['diff', 'abs']
        if early_stop is not None:
            assert early_stop in early_stop_type, 'early stop type must be in {}, but got {}' \
                .format(early_stop_type, early_stop)

        # communicate suffix
        self.comm_suffix = 'fedavg'

        # check param correctness
        self.check_trainer_param([self.epochs,
                                  self.validation_freq,
                                  self.save_freq,
                                  self.aggregate_every_n_epoch],
                                 ['epochs',
                                  'validation_freq',
                                  'save_freq',
                                  'aggregate_every_n_epoch'],
                                 self.is_pos_int,
                                 '{} is not a positive int')
        self.check_trainer_param([self.secure_aggregate, self.weighted_aggregation, self.pin_memory], [
                                 'secure_aggregate', 'weighted_aggregation', 'pin_memory'], self.is_bool, '{} is not a bool')
        self.check_trainer_param(
            [self.tol], ['tol'], self.is_float, '{} is not a float')

    def _init_aggregator(self, train_set):
        # compute round to aggregate
        cur_agg_round = 0
        if self.aggregate_every_n_epoch is not None:
            aggregate_round = self.epochs // self.aggregate_every_n_epoch
        else:
            aggregate_round = self.epochs

        # initialize fed avg client
        if self.fed_mode:
            if self.weighted_aggregation:
                sample_num = len(train_set)
            else:
                sample_num = 1.0

            if not distributed_util.is_distributed() or distributed_util.is_rank_0():
                client_agg = SecureAggClient(
                    True, aggregate_weight=sample_num, communicate_match_suffix=self.comm_suffix)
            else:
                client_agg = None
        else:
            client_agg = None

        return client_agg, aggregate_round

    def set_model(self, model: t.nn.Module):
        self.model = model
        if self.cuda:
            self.model = self.model.cuda()

    
    def train(
            self,
            train_set: Dataset,
            validate_set: Dataset = None,
            optimizer: t.optim.Optimizer = None,
            loss=None,
            extra_dict={}):

        if self.cuda:
            self.model = self.model.cuda()

        if optimizer is None or loss is None:
            raise ValueError(
                'optimizer or loss is None')

        self.model, optimizer = apex.amp.initialize(self.model, optimizer, opt_level='O2')

        if self.batch_size > len(train_set) or self.batch_size == -1:
            self.batch_size = len(train_set)
        dl = DataLoader(
            train_set,
            batch_size=self.batch_size,
            pin_memory=self.pin_memory,
            shuffle=self.shuffle,
            num_workers=self.data_loader_worker)

        # compute round to aggregate
        cur_agg_round = 0
        client_agg, aggregate_round = self._init_aggregator(train_set)



        # running var
        cur_epoch = 0
        loss_history = []
        need_stop = False
        evaluation_summary = {}


        # training process
        for i in range(self.epochs):

            if i+1 in [64, 96]:
                optimizer.param_groups[0]['lr'] *= 0.5
                optimizer.param_groups[1]['lr'] *= 0.5

            cur_epoch = i
            LOGGER.info('epoch is {}'.format(i))
            epoch_loss = 0.0
            ce_epoch=0.0
            dice_epoch=0.0
            dice_loss_epoch=0.0
            batch_idx = 0
            

            # for better user interface
            if not self.fed_mode:
                to_iterate = tqdm.tqdm(dl)
            else:
                to_iterate = dl

            for image, mask in to_iterate:
                if self.cuda:
                    image, mask = self.to_cuda(
                        image), self.to_cuda(mask)
                    self.model.cuda()

                image,mask=image.float(),mask.float()
                rand  = np.random.choice([256, 288, 320, 352], p=[0.1, 0.2, 0.3, 0.4])
                image = F.interpolate(image, size=(rand, rand), mode='bilinear')
                mask  = F.interpolate(mask.unsqueeze(1),  size=(rand, rand), mode='nearest').squeeze(1)

                pred = self.model(image)
                pred = F.interpolate(pred, size=mask.shape[1:], mode='bilinear', align_corners=True)[:,0,:,:]
                # LOGGER.info(f'pred {pred.shape}, mask {mask.shape}')
                loss_ce, loss_dice = loss(pred, mask)
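                # accumulate detached values so the autograd graph is not kept alive across batches
                dice_epoch = dice_epoch + 1 - loss_dice.detach()
                ce_epoch = ce_epoch + loss_ce.detach()
                dice_loss_epoch = dice_loss_epoch + loss_dice.detach()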
                

                optimizer.zero_grad()
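                # apex mixed precision: scale the loss before backward to avoid fp16 gradient underflow
                with apex.amp.scale_loss(loss_ce + loss_dice, optimizer) as scaled_loss:
                    scaled_loss.backward()

                # for printing/logging: accumulate the unscaled loss, detached from the graph
                # (scaled_loss is multiplied by apex's loss scale, so logging it would inflate the value)
                epoch_loss += (loss_ce + loss_dice).detach()
                # plain PyTorch version without apex:
                # batch_loss = loss_ce + loss_dice
                # batch_loss.backward()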
                
                optimizer.step()
                
                
                if self.fed_mode:
                    LOGGER.debug(
                        'epoch {} batch {} finished'.format(
                            i, batch_idx))
                batch_idx += 1

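            # loss compute: the accumulated values are per-batch means (assuming loss() returns means),
            # so average over the number of batches rather than the number of samples
            num_batches = len(dl)
            epoch_loss = epoch_loss / num_batches
            ce_epoch = ce_epoch / num_batches
            dice_epoch = dice_epoch / num_batches
            dice_loss_epoch = dice_loss_epoch / num_batches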

            if not distributed_util.is_distributed() or distributed_util.is_rank_0():
                
                self.callback_loss(epoch_loss.item(),i)
                # self._tracker.log_metric_data(
                #     metric_name="loss",
                #     metric_namespace="train",
                #     metrics=[Metric(epoch_idx, loss)],
                # )
                self.callback_metric('Dice',dice_epoch.item(),'train',i)
                self.callback_metric('Dice Loss',dice_loss_epoch.item(),'train',i)
                self.callback_metric('CE',ce_epoch.item(),'train',i)
                self.callback_metric('LR',optimizer.param_groups[0]['lr'],'train',i)
                
                loss_history.append(float(epoch_loss))
                LOGGER.info('epoch loss is {}'.format(epoch_loss.item()))
            

            # federation process, if running local mode, cancel federation
            if client_agg is not None or distributed_util.is_distributed():
                if not (self.aggregate_every_n_epoch is not None and (i + 1) % self.aggregate_every_n_epoch != 0):

                    # model averaging, only aggregate trainable param
                    if self._deepspeed_zero_3:
                        deepspeed_util.gather_model(self.model)

                    if not distributed_util.is_distributed() or distributed_util.is_rank_0():
                        self.model = client_agg.model_aggregation(self.model)
                        if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1:
                            self._share_model()
                    else:
                        self._share_model()

                    # agg loss and get converge status
                    if not distributed_util.is_distributed() or distributed_util.is_rank_0():
                        converge_status = client_agg.loss_aggregation(epoch_loss.item())
                        cur_agg_round += 1
                        if distributed_util.is_distributed() and distributed_util.get_num_workers() > 1:
                            self._sync_converge_status(converge_status)
                    else:
                        converge_status = self._sync_converge_status()

                    if not distributed_util.is_distributed() or distributed_util.is_rank_0():
                        LOGGER.info(
                            'model averaging finished, aggregate round {}/{}'.format(
                                cur_agg_round, aggregate_round))

                    if converge_status:
                        LOGGER.info('early stop triggered, stop training')
                        need_stop = True
                    
            # save check point process
            # if self.save_freq is not None and ((i + 1) % self.save_freq == 0):
            #     if self._deepspeed_zero_3:
            #         deepspeed_util.gather_model(self.model)

            # if not distributed_util.is_distributed() or distributed_util.is_rank_0():
            #     if self.save_freq is not None and ((i + 1) % self.save_freq == 0):

            #         if self.save_to_local_dir:
            #             self.local_checkpoint(
            #                 self.model, i, optimizer, converge_status=need_stop, loss_history=loss_history)
            #         else:
            #             self.checkpoint(
            #                 self.model, i, optimizer, converge_status=need_stop, loss_history=loss_history)
            #         LOGGER.info('save checkpoint : epoch {}'.format(i))

            # if meet stop condition then stop
            if need_stop:
                break
            
        # post-process
        # if self._deepspeed_zero_3:
        #     deepspeed_util.gather_model(self.model)

        # if not distributed_util.is_distributed() or distributed_util.is_rank_0():
        #     best_epoch = int(np.array(loss_history).argmin())

        #     if self.save_to_local_dir:
        #         self.local_save(model=self.model, optimizer=optimizer, epoch_idx=cur_epoch, loss_history=loss_history,
        #                         converge_status=need_stop, best_epoch=best_epoch)
        #     else:
        #         self.save(model=self.model, optimizer=optimizer, epoch_idx=cur_epoch, loss_history=loss_history,
        #                   converge_status=need_stop, best_epoch=best_epoch)

        #     best_epoch = int(np.array(loss_history).argmin())
        #     self.summary({
        #         'best_epoch': best_epoch,
        #         'loss_history': loss_history,
        #         'need_stop': need_stop,
        #         'metrics_summary': evaluation_summary
        #     })

        
    def _predict(self, dataset: Dataset):

        pred_result = []

        # switch eval mode
        dataset.eval()
        self.model.eval()

        
        labels = []
        # with torch.no_grad():

        #     for images, masks in DataLoader(
        #             dataset, self.batch_size):
        #         if self.cuda:
        #             images,masks = self.to_cuda(images,masks)
        #         pred = self.model(images)
        #         pred_result.append(pred)
        #         # labels.append(batch_label)

        #     ret_rs = torch.concat(pred_result, axis=0)
        #     ret_label = torch.concat(labels, axis=0)

        # # switch back to train mode
        # dataset.train()
        # self.model.train()
        
        
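        # Placeholder outputs: meaningless values (torch.rand(...).int() is always 0) that only
        # need to match the (sample ids, scores, labels) structure the framework expects.
        sample_ids = dataset.get_sample_ids()
        length = len(sample_ids)
        ret_rs = torch.rand(length, 1)
        ret_label = torch.rand(length, 1).int()

        return sample_ids, ret_rs, ret_label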
        

    def predict(self, dataset: Dataset):

        ids, ret_rs, ret_label = self._predict(dataset)

        if self.fed_mode:
            return self.format_predict_result(
                ids, ret_rs, ret_label, task_type=self.task_type)
        else:
            return ret_rs, ret_label

    def server_aggregate_procedure(self, extra_data={}):

        # converge status
        check_converge = False
        converge_func = None
        if self.early_stop:
            check_converge = True
            converge_func = converge_func_factory(
                self.early_stop, self.tol).is_converge
            LOGGER.info(
                'check early stop, converge func is {}'.format(converge_func))

        LOGGER.info('server running aggregate procedure')
        server_agg = SecureAggServer(True, communicate_match_suffix=self.comm_suffix)

        # aggregate and broadcast models
        for i in range(self.epochs):
            if not (self.aggregate_every_n_epoch is not None and (i + 1) % self.aggregate_every_n_epoch != 0):

                # model aggregate
                server_agg.model_aggregation()
                converge_status = False

                # loss aggregate
                agg_loss, converge_status = server_agg.loss_aggregation(
                    check_converge=check_converge, converge_func=converge_func)
                
                self.callback_loss(agg_loss, i)
                

                # save check point process
                if self.save_freq is not None and ((i + 1) % self.save_freq == 0):
                    self.checkpoint(epoch_idx=i)
                    LOGGER.info('save checkpoint : epoch {}'.format(i))

                # check stop condition
                if converge_status:
                    LOGGER.debug('stop triggered, stop aggregation')
                    break

        LOGGER.info('server aggregation process done')
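The weighted_aggregation option in the docstring, together with _init_aggregator passing len(train_set) (or 1.0) as aggregate_weight, describes FedAvg-style averaging in which client k gets weight n_k / n_global. The plain-Python sketch below shows only that arithmetic on flat parameter lists; the function name `weighted_average` is hypothetical, and this is not FATE's SecureAggregator, which (per the docstring) additionally masks each local model with random numbers that cancel out when summed.

```python
from typing import List, Sequence


def weighted_average(models: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Average flat parameter vectors, weighting client k by weights[k] / sum(weights)."""
    if not models or len(models) != len(weights):
        raise ValueError("need one weight per client model")
    total = float(sum(weights))
    n_params = len(models[0])
    if any(len(m) != n_params for m in models):
        raise ValueError("all client models must have the same number of parameters")
    return [
        sum(w * m[i] for m, w in zip(models, weights)) / total
        for i in range(n_params)
    ]


if __name__ == "__main__":
    clients = [[1.0, 2.0], [3.0, 6.0]]
    # weighted_aggregation=True: weight = local sample count (300 vs 100 samples)
    print(weighted_average(clients, [300, 100]))
    # weighted_aggregation=False: every client weight is 1.0, i.e. a plain mean
    print(weighted_average(clients, [1.0, 1.0]))
```

With weights 300 and 100 the larger client pulls the result toward its parameters; with equal weights the result is the simple mean, matching the "if False, simply averaging these models" case.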

This article was contributed by an internet user and reflects only the author's own views. If reposting, please credit the source: http://www.coloradmin.cn/o/708536.html
