联邦学习开山之作论文解读与Pytorch实现FedAvg

news2025/1/11 21:40:43

参考文献:McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.

参考的文章:

1.联邦学习代码解读,超详细-CSDN博客

2.联邦学习开山之作代码解读与收获_联邦学习代码-CSDN博客

3.数据独立同分布vs非独立同分布:https://www.zhihu.com/question/395555567/answer/3214082688

目录

​​​​​​​Part One.论文解读

Part Two. FedAvg 代码解读

1.代码整体结构

2.options.py

2.1 federated arguments

2.2 model arguments

2.3 other argument

3. models.py

3.1 MLP多层感知机

3.2 LeNet-5

3.3 CNN卷积神经网络

3.4 modelC自定义模型

4. sampling.py

4.0 数据分布

4.1 mnist_iid()

4.2 mnist_noniid()

4.3 mnist_noniid_unequal

4.4 cifar10

5. update.py

5.1 DatasetSplite(Dataset)

5.2 LocalUpdate

5.2.1准备工作

5.2.2 train_val_test()

5.2.3 update_weights() 本地权重更新

5.2.4 评估函数:inference(self,model)

5.2.5 test_inference(self,model)

6. utils.py

6.1 get_dataset(args)

6.2 average_weights(w)

6.3 exp_details(args)

7. 主函数 federated_main.py

7.1 库的引用

7.2 主函数

7.3 建立模型

7.4 模型训练

7.5 测试

8.作图PLOTTING


Part One.论文解读

对文论的理解就直接贴之前的汇报PPT了。

回看去年的汇报ppt,感觉当时的理解并不是很深入,比如对数据的独立同分布和非独立同分布的认知就几乎为零,完全没有当回事。然后在简单弄懂了训练流程后,就开始疯狂啃收敛性证明,感觉也是被带偏了,浪费了很多时间。

这个暑假突然悟了,准备按自己的节奏走。回顾一下万能Baseline FedAvg的代码。

Part Two. FedAvg 代码解读

1.代码整体结构

2.options.py

从最简单的超参数开始。可以看到参数分为三类:

  • federated arguments 联邦参数
  • model arguments 模型参数
  • other arguments 其他参数

超参数代码框架:

import argparse
def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments (Notation for the arguments followed from paper)
    parser.add_argument('--epochs', type=int, default=10, help="number of rounds of training:R")
    
    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    
    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--seed', type=int, default=1, help='random seed')

    args = parser.parse_args()
    return args

2.1 federated arguments

  • epochs:训练轮数,R,默认10轮
  • num_users:用户数K,默认100个用户
  • frac:用户选取比例C,默认0.1
  • local_ep:本地训练次数E,默认10
  • local_bs:本地训练小批次的大小,默认32
  • lr:学习率,默认0.01
  • momentum:SGD动量,默认0.5
# federated arguments (Notation for the arguments followed from paper)
    parser.add_argument('--epochs', type=int, default=10, help="number of rounds of training:R")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=10, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=32, help="local batch size: B")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")

2.2 model arguments

  • model:模型,默认mlp多层感知机
  • kernel_num:卷积核数量,默认9
  • kernel_size:卷积核大小,默认3、4、5
  • num_channels:图像通道数,默认1
  • norm:归一化方法,默认batch_norm
  • num_filters:过滤器数量,默认32
  • max_pool:最大池化,默认True
# model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5', help='comma-separated kernel size to use for convolution')
    parser.add_argument('--num_channels', type=int, default=1, help="number of channels of imgs")
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets -- 32 for mini-imagenet, 64 for omiglot.")
    parser.add_argument('--max_pool', type=str, default='True',help="Whether use max pooling rather than strided convolutions")

2.3 other argument

  • dataset:数据集,默认mnist
  • num_classes:分类数量,默认10
  • gpu:gpu,默认为使用
  • optimizer:优化器,默认sgd
  • iid:独立同分布,默认1(独立同分布)
  • unequal:平均分配数据集,默认0(平均分配)
  • stopping_rounds:停止轮数,默认10 (??和前面epochs也是训练轮数,这里早停设置?)
  • verbose:日志显示,默认1,(0不输出,1输出带进度条的日志,2输出不带进度条的日志)
  • seed:随机种子,默认7
# other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--gpu', default=None, help="To use cuda, set to a specific GPU ID. Default set to use CPU.")
    parser.add_argument('--optimizer', type=str, default='sgd', help="type of optimizer")
    parser.add_argument('--iid', type=int, default=1, help='Default set to IID. Set to 0 for non-IID.')
    parser.add_argument('--unequal', type=int, default=0, help='whether to use unequal data splits for non-i.i.d setting (use 0 for equal splits)')
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', type=int, default=1, help='verbose')
    parser.add_argument('--seed', type=int, default=1, help='random seed')

3. models.py

接着,说模型,模型这也相对简单。联邦的实验一般不需要自己去设计模型,选择已有模型完成实验即可,联邦的侧重点不在模型本身。这里主要使用了3个模型完成训练:

  • MLP多层感知机
  • LeNet-5
  • CNN卷积神经网络
  • modelC自定义模型

3.1 MLP多层感知机

class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        # 将输入特征映射到隐藏层
        self.layer_input = nn.Linear(dim_in, dim_hidden)
        self.relu = nn.ReLU()
        # Dropout 防止过拟合,提高模型的泛化能力
        self.dropout = nn.Dropout()
        # 将隐藏层的输出映射到输出层
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)
        # 将输出转换为概率分布,从而得到每个类别的概率
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        #  MLP 模型的前向传播
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return self.softmax(x)

3.2 LeNet-5

class MyLeNet5(nn.Module):
    def __init__(self):
        super(MyLeNet5, self).__init__()
        self.c1 = nn.Conv2d(in_channels=1, out_channels=6,kernel_size=5, padding=2)
        self.Sigmoid = nn.Sigmoid()
        self.s2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c3 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.s4 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.c5 = nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5)
        self.flatten = nn.Flatten()
        self.f6 = nn.Linear(120, 84)
        self.output = nn.Linear(84, 10)

    # forward():定义前向传播过程,描述了各层之间的连接关系
    def forward(self, x):
        x = self.Sigmoid(self.c1(x))
        x = self.s2(x)
        x = self.Sigmoid(self.c3(x))
        x = self.s4(x)
        x = self.c5(x)
        x = self.flatten(x)
        x = self.f6(x)
        x = self.output(x)
        return x

3.3 CNN卷积神经网络

class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class CNNFashion_Mnist(nn.Module):
    def __init__(self, args):
        super(CNNFashion_Mnist, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.fc = nn.Linear(7*7*32, 10)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out


class CNNCifar(nn.Module):
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

3.4 modelC自定义模型

class modelC(nn.Module):
    def __init__(self, input_size, n_classes=10, **kwargs):
        super(AllConvNet, self).__init__()
        self.conv1 = nn.Conv2d(input_size, 96, 3, padding=1)
        self.conv2 = nn.Conv2d(96, 96, 3, padding=1)
        self.conv3 = nn.Conv2d(96, 96, 3, padding=1, stride=2)
        self.conv4 = nn.Conv2d(96, 192, 3, padding=1)
        self.conv5 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv6 = nn.Conv2d(192, 192, 3, padding=1, stride=2)
        self.conv7 = nn.Conv2d(192, 192, 3, padding=1)
        self.conv8 = nn.Conv2d(192, 192, 1)

        self.class_conv = nn.Conv2d(192, n_classes, 1)


    def forward(self, x):
        x_drop = F.dropout(x, .2)
        conv1_out = F.relu(self.conv1(x_drop))
        conv2_out = F.relu(self.conv2(conv1_out))
        conv3_out = F.relu(self.conv3(conv2_out))
        conv3_out_drop = F.dropout(conv3_out, .5)
        conv4_out = F.relu(self.conv4(conv3_out_drop))
        conv5_out = F.relu(self.conv5(conv4_out))
        conv6_out = F.relu(self.conv6(conv5_out))
        conv6_out_drop = F.dropout(conv6_out, .5)
        conv7_out = F.relu(self.conv7(conv6_out_drop))
        conv8_out = F.relu(self.conv8(conv7_out))

        class_out = F.relu(self.class_conv(conv8_out))
        pool_out = F.adaptive_avg_pool2d(class_out, 1)
        pool_out.squeeze_(-1)
        pool_out.squeeze_(-1)
        return pool_out

4. sampling.py

4.0 数据分布

独立同分布与非独立同分布参考知乎上的一篇文章。

参考:https://www.zhihu.com/question/395555567/answer/3214082688

联邦学习的non-iid的定义可参考《《Advances and Open Problems in Federated Learning》》

参考文献:Kairouz P, McMahan H B, Avent B, et al. Advances and open problems in federated learning[J]. Foundations and trends® in machine learning, 2021, 14(1–2): 1-210.

全文copy过来

作者:反向人
链接:https://www.zhihu.com/question/395555567/answer/3214082688
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
 

1、Feature distribution skew( 特征分布偏差)

所谓一千个人里有一千个哈姆雷特,每个人审美都不一样,我们“判断的角度”不一样。比如有人会看他的五官身材,有人会看谈吐涵养,

2、Lable distribution skew (标签分布偏差)

我们“判断的标准/结果”不一样,有人会说他好看或者难看,有人会说友善还是粗鲁。那第三种和第四种是不是就更好理解啦~

3、Same label, different features (相同的标签,不同的特征)

4、Same features, different label (相同的特征,不同的标签)

5、Quantity skew or unbalancedness ( 数量倾斜或不平衡)

心中默念“输入特征,输出标签”,你就明白前四种Non-IID分布啦。第五种更简单啦,即不同客户拥有不同的数据量。

在传统的机器学习中,通常假设训练数据是独立同分布的(IID),这意味着每个数据样本都是独立地从相同的概率分布中抽取的,因此样本之间是相互独立的且具有相同的分布特性。

但在联邦学习中,由于数据存储在不同的本地设备上,这些设备可能采集不同类型的数据、数据量不同、数据质量不同,或者数据在不同时间和地点收集,因此不同设备上的数据样本之间可能具有不同的分布特性或相关性,而不满足独立同分布的假设。这就是联邦学习的Non-IID,通常有两个问题

  1. 模型收敛困难:当本地数据的分布不同或者数据质量差异较大时,全局模型的收敛可能会受到影响,因为不同设备的本地模型更新可能不容易合并。
  2. 性能不稳定:由于非iid数据分布,全局模型可能在某些设备上表现良好,但在其他设备上表现较差,导致性能不稳定。

而Non-IID经常伴随着异构性这三个字一起出现。

具体来说,非独立同分布是异构性的一种表现,异构性比非独立同分布更广泛。

异构性:通常是指系统或群体中包含多种不同类型、属性、特性或性质的元素。一篇联邦学习的论文把异构性分成了三种:

1、设备异构性:不同设备可能拥有不同的硬件性能,包括CPU、GPU、内存等,导致计算能力不同;网络速度和稳定性

2、统计异构性:设备的数据可能来自不同的数据源、采集方式、环境条件或时间段,导致数据的统计性质存在差异

3、数据异构性:来自不同的数据源,设备的数据可以是多种类型(数值/文本/图像)

4.1 mnist_iid()

灵魂拷问:这种随机分的样本真的就独立同分布吗?

参考4.2 noniid的抽取方法,先分类,在每个类别里面随机不重复平均抽

def mnist_iid(dataset, num_users):
    """
    从 MNIST 数据集中采样独立同分布 (IID) 的客户端数据
    超参数:param dataset
    超参数:param num_users:
    return: dict of image index
    """

    # 计算每个客户端的采样数量(平均分配)
    num_items = int(len(dataset)/num_users)
    # print("num_users is a", type(num_users), "with value", num_users)
    # print("num_items is a", type(num_items), "with value", num_items)

    # 初始化客户端数据字典dict_users存储客户端数据
    # 创建列表all_idxs,包含数据集的所有索引,从 0 到 len(dataset)-1.
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]

    # 采样客户端数据
    # 对于每个客户端使用 np.random.choice() 从 all_idxs 列表中随机选择num_items个唯一的索引
    # 随机索引,且不重复,确保每个客户端获得唯一的一组样本.
    # 选择的索引被存储为 dict_users[i] 字典中的一个集合. 使用集合可以确保没有重复的索引.
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
        # print("dict_users[i] is a", type(dict_users[i]), "with value", dict_users[i])
    return dict_users

4.2 mnist_noniid()

def mnist_noniid(dataset, num_users):
    """
    从 MNIST 数据集中采样非独立同分布 (non-IID) 的客户端数据
    """

    # 数据集划分:
    # 该函数将 MNIST 训练集 (60,000 个样本) 划分为 200 个 "碎片" (shards),每个碎片包含 300 个样本。
    # 函数创建一个列表 idx_shard 来跟踪这 200 个碎片的索引。
    num_shards, num_imgs = 200, 300
    idx_shard = [i for i in range(num_shards)]

    # 函数创建一个字典 dict_users,其中每个键 (客户端索引) 对应一个空的 NumPy 数组。
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()

    # 对标签进行排序:
    # 函数获取训练标签 dataset.train_labels.numpy(),并将其与样本索引 idxs 组合成一个 2D 数组 idxs_labels。
    # 然后对 idxs_labels 按标签列进行排序,得到排序后的索引 idxs
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign 2 shards/client
    # 为每个客户端分配数据:
    # 对于每个客户端 (索引为 i),函数随机选择 2 个碎片 (不重复),并将这 2 个碎片中的所有样本索引添加到 dict_users[i] 中。
    # 选择的碎片索引从 idx_shard 列表中删除,确保每个客户端获得唯一的一组样本。
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]), axis=0)
            print("dict_users[i] is a", type(dict_users[i]), "with value", dict_users[i])
    return dict_users

4.3 mnist_noniid_unequal

def mnist_noniid_unequal(dataset, num_users):
    """
    从MNIST数据集中采样非I.I.D.(非独立同分布)且数量不均等的客户端数据
    """
    # 将整个MNIST训练集(60,000张图像)划分为1200个分片(shards),每个分片包含50张图像
    # 创建一个包含 1200 个索引的列表,表示 1200 个分片
    # 创建一个字典 dict_users,键为客户端编号 i,值为一个空的 NumPy 数组
    # 获取 MNIST 训练集的标签数据,并转换为 NumPy 数组
    num_shards, num_imgs = 1200, 50
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([]) for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)
    labels = dataset.train_labels.numpy()

    # sort labels
    # 将图像索引 idxs 和标签 labels 垂直堆叠(vstack)到一个二维 NumPy 数组 idxs_labels 中
    # 现在 idxs_labels 是一个 2 x 60,000 的数组,第一行是图像索引,第二行是对应的标签
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # 设置每个客户端获得的最小和最大分片数量为1和30。
    min_shard = 1
    max_shard = 30

    # Divide the shards into random chunks for every client
    # 使用 NumPy 的 random.randint() 函数生成 num_users 个随机整数
    # 这些随机整数表示每个客户端被分配的分片(shard)数量
    # 每个随机整数都在 min_shard 和 max_shard+1 之间(包括 max_shard)
    random_shard_size = np.random.randint(min_shard, max_shard+1,
                                          size=num_users)

    # 将上一步生成的随机分片数量进行归一化
    # 首先计算所有客户端分片数量的总和
    # 然后将每个客户端的分片数量除以总和,得到一个 0 到 1 之间的比例
    # 最后将这个比例乘以总分片数 num_shards,得到每个客户端应该被分配的实际分片数量
    # 使用 np.around() 将结果四舍五入为整数
    random_shard_size = np.around(random_shard_size / sum(random_shard_size) * num_shards)

    # 将上一步得到的浮点数分片数量转换为整数
    random_shard_size = random_shard_size.astype(int)

    # Assign the shards randomly to each client
    if sum(random_shard_size) > num_shards:
        for i in range(num_users):
            # First assign each client 1 shard to ensure every client has
            # atleast one shard of data
            rand_set = set(np.random.choice(idx_shard, 1, replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        random_shard_size = random_shard_size-1

        # Next, randomly assign the remaining shards
        for i in range(num_users):
            if len(idx_shard) == 0:
                continue
            shard_size = random_shard_size[i]
            if shard_size > len(idx_shard):
                shard_size = len(idx_shard)
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)
    else:

        for i in range(num_users):
            shard_size = random_shard_size[i]
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[i] = np.concatenate(
                    (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

        if len(idx_shard) > 0:
            # Add the leftover shards to the client with minimum images:
            shard_size = len(idx_shard)
            # Add the remaining shard to the client with lowest data
            k = min(dict_users, key=lambda x: len(dict_users.get(x)))
            rand_set = set(np.random.choice(idx_shard, shard_size,
                                            replace=False))
            idx_shard = list(set(idx_shard) - rand_set)
            for rand in rand_set:
                dict_users[k] = np.concatenate(
                    (dict_users[k], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                    axis=0)

    return dict_users

4.4 cifar10

cifar_iid() 和 cifar10_noniid(),思路同上,不赘述。

5. update.py

5.1 DatasetSplite(Dataset)

class DatasetSplit(Dataset):
    # __init__ 方法是 DatasetSplit 类的构造函数
    # 它接受两个参数:
    # dataset: 要被划分的原始数据集。
    # idxs: 一个用于创建子集的索引列表。
    def __init__(self, dataset, idxs):
        # 将传入的原始数据集存储在self.dataset中
        self.dataset = dataset
        # 将输入的索引 idxs 转换为整数列表,并存储在 self.idxs 属性中
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        # 这个方法返回子集的长度,即 self.idxs 列表的长度
        return len(self.idxs)


    def __getitem__(self, item):
        #将图像和标签作为 PyTorch 张量返回
        image, label = self.dataset[self.idxs[item]]
        return torch.tensor(image), torch.tensor(label)

5.2 LocalUpdate

5.2.1准备工作

为客户端本地模型更新做准备,通过封装公共逻辑,使客户端更新的实现更加简洁和模块化:

  • 参数设置
  • 日志记录
  • 数据准备:后续的训练、验证和测试集准备
  • 计算设备准备
  • 损失函数设置
    def __init__(self, args, dataset, idxs, logger):
        self.args = args
        self.logger = logger
        self.trainloader, self.validloader, self.testloader = self.train_val_test(
            dataset, list(idxs))
        # 指定运算设备
        self.device = 'cuda' if args.gpu else 'cpu'
        # 损失函数用的NLL
        self.criterion = nn.NLLLoss().to(self.device)
5.2.2 train_val_test()

train_val_test方法的作用是根据给定的数据集和用户索引,划分出训练集、验证集和测试集的数据加载器。比例为8:1:1.

    def train_val_test(self, dataset, idxs):

        # split indexes for train, validation, and test (80, 10, 10)
        idxs_train = idxs[:int(0.8*len(idxs))]
        idxs_val = idxs[int(0.8*len(idxs)):int(0.9*len(idxs))]
        idxs_test = idxs[int(0.9*len(idxs)):]

        trainloader = DataLoader(DatasetSplit(dataset, idxs_train),
                                 batch_size=self.args.local_bs, shuffle=True)
        validloader = DataLoader(DatasetSplit(dataset, idxs_val),
                                 batch_size=int(len(idxs_val)/10), shuffle=False)
        testloader = DataLoader(DatasetSplit(dataset, idxs_test),
                                batch_size=int(len(idxs_test)/10), shuffle=False)
        return trainloader, validloader, testloader
5.2.3 update_weights() 本地权重更新
  • 输入模型和全局更新的回合数
  • 优化器选择
  • 训练循环
  • 可视化
  • 输出更新后的权重和损失平均值​​​​​​​
    def update_weights(self, model, global_round):
        # Set mode to train model
        model.train()
        epoch_loss = []

        # Set optimizer for the local updates
        if self.args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=self.args.lr,
                                        momentum=0.5)
        elif self.args.optimizer == 'adam':
            optimizer = torch.optim.Adam(model.parameters(), lr=self.args.lr,
                                         weight_decay=1e-4)

        for iter in range(self.args.local_ep):
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.trainloader):
                images, labels = images.to(self.device), labels.to(self.device)
                # pytoch框架训练模型定式5步
                model.zero_grad()
                log_probs = model(images)
                loss = self.criterion(log_probs, labels)
                loss.backward()
                optimizer.step()
                # 打印训练日志并记录损失值,监控训练过程并分析模型性能
                if self.args.verbose and (batch_idx % 10 == 0):
                    print('| Global Round : {} | Local Epoch : {} | [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        global_round, iter, batch_idx * len(images),
                        len(self.trainloader.dataset),
                        100. * batch_idx / len(self.trainloader), loss.item()))
                #保存程序中的数据,然后利用tensorboard工具来进行可视化
                self.logger.add_scalar('loss', loss.item())
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))

        return model.state_dict(), sum(epoch_loss) / len(epoch_loss)
5.2.4 评估函数:inference(self,model)

通取测试集图像和标签,模型出结果后计算loss,然后累加。

代码定义了一个 inference() 函数,用于在测试集上计算模型的推理准确率和损失:

Chat一下,解释得还不错,直接贴过来

    def inference(self, model):
        """ Returns the inference accuracy and loss.
        """
        model.eval()
        loss, total, correct = 0.0, 0.0, 0.0

        for batch_idx, (images, labels) in enumerate(self.testloader):
            images, labels = images.to(self.device), labels.to(self.device)

            # Inference
            outputs = model(images)
            batch_loss = self.criterion(outputs, labels)
            loss += batch_loss.item()

            # Prediction
            _, pred_labels = torch.max(outputs, 1)
            pred_labels = pred_labels.view(-1)
            correct += torch.sum(torch.eq(pred_labels, labels)).item()
            total += len(labels)

        accuracy = correct/total
        return accuracy, loss
5.2.5 test_inference(self,model)

与LocalUpdate中的inference函数完全一致,只不过这里的输入参数除了args和model,还要指定test_dataset.

还不太懂为啥又测试一次。

6. utils.py

封装了一些工具函数:

  • get_dataset()
  • average_weights()
  • exp_details()

6.1 get_dataset(args)

Chat确实是生产力。

6.2 average_weights(w)

FedAvg 加权平均。

def average_weights(w):
    """
    Returns the average of the weights.
    """
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg

6.3 exp_details(args)

之前的代码没有可视化。这个可以参考。方便实验数据分析。

7. 主函数 federated_main.py

7.1 库的引用

import os
import copy
import time
import pickle
import numpy as np
from tqdm import tqdm

import torch
from tensorboardX import SummaryWriter

from options import args_parser
from update import LocalUpdate, test_inference
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar
from utils import get_dataset, average_weights, exp_details

7.2 主函数

库引用之后直接主函数。先完成一些准备工作(试验方案和实验数据提前规划好):

  • 时间
  • 日志
  • 超参数
  • 可视化参数
  • 计算设备
  • 数据集与用户数据分配
if __name__ == '__main__':
    start_time = time.time()

    # define paths
    path_project = os.path.abspath('..')
    logger = SummaryWriter('../logs')

    args = args_parser()
    exp_details(args)

    if args.gpu_id:
        torch.cuda.set_device(args.gpu_id)
    device = 'cuda' if args.gpu else 'cpu'

    # load dataset and user groups
    train_dataset, test_dataset, user_groups = get_dataset(args)

7.3 建立模型

model.py里面定义的模型。 model.py里面我加了个LeNet-5,这里还没写进去。主要是不会写(触发拖延症,后面再写)。

    # BUILD MODEL
    if args.model == 'cnn':
        # Convolutional neural netork
        if args.dataset == 'mnist':
            global_model = CNNMnist(args=args)
        elif args.dataset == 'fmnist':
            global_model = CNNFashion_Mnist(args=args)
        elif args.dataset == 'cifar':
            global_model = CNNCifar(args=args)

    elif args.model == 'mlp':
        # Multi-layer preceptron
        img_size = train_dataset[0][0].shape
        len_in = 1
        for x in img_size:
            len_in *= x
            global_model = MLP(dim_in=len_in, dim_hidden=64,
                               dim_out=args.num_classes)
    else:
        exit('Error: unrecognized model')

7.4 模型训练

    # Set the model to train and send it to device.
    global_model.to(device)
    global_model.train()
    print(global_model)

    # copy weights
    global_weights = global_model.state_dict()

    # Training
    train_loss, train_accuracy = [], []
    val_acc_list, net_list = [], []
    cv_loss, cv_acc = [], []
    print_every = 2
    val_loss_pre, counter = 0, 0

    for epoch in tqdm(range(args.epochs)):
        local_weights, local_losses = [], []
        print(f'\n | Global Training Round : {epoch+1} |\n')

        global_model.train()
        m = max(int(args.frac * args.num_users), 1)
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)

        for idx in idxs_users:
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            w, loss = local_model.update_weights(
                model=copy.deepcopy(global_model), global_round=epoch)
            local_weights.append(copy.deepcopy(w))
            local_losses.append(copy.deepcopy(loss))

        # update global weights
        global_weights = average_weights(local_weights)

        # update global weights
        global_model.load_state_dict(global_weights)

        loss_avg = sum(local_losses) / len(local_losses)
        train_loss.append(loss_avg)

        # Calculate avg training accuracy over all users at every epoch
        list_acc, list_loss = [], []
        global_model.eval()
        for c in range(args.num_users):
            local_model = LocalUpdate(args=args, dataset=train_dataset,
                                      idxs=user_groups[idx], logger=logger)
            acc, loss = local_model.inference(model=global_model)
            list_acc.append(acc)
            list_loss.append(loss)
        train_accuracy.append(sum(list_acc)/len(list_acc))

        # print global training loss after every 'i' rounds
        if (epoch+1) % print_every == 0:
            print(f' \nAvg Training Stats after {epoch+1} global rounds:')
            print(f'Training Loss : {np.mean(np.array(train_loss))}')
            print('Train Accuracy: {:.2f}% \n'.format(100*train_accuracy[-1]))

7.5 测试

测试,保存训练损失和训练精度了,输出时间。

# Test inference after completion of training
    test_acc, test_loss = test_inference(args, global_model, test_dataset)

    print(f' \n Results after {args.epochs} global rounds of training:')
    print("|---- Avg Train Accuracy: {:.2f}%".format(100*train_accuracy[-1]))
    print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

    # Saving the objects train_loss and train_accuracy:
    file_name = '../save/objects/{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}].pkl'.\
        format(args.dataset, args.model, args.epochs, args.frac, args.iid,
               args.local_ep, args.local_bs)

    with open(file_name, 'wb') as f:
        pickle.dump([train_loss, train_accuracy], f)

    print('\n Total Run Time: {0:0.4f}'.format(time.time()-start_time))

8.作图PLOTTING

之前的图是Origin画的。这个作者写的作图代码没有认真看。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1971423.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Object.defineProperty在Vue2双向绑定中的核心原理及应用

目录 1.Object.defineProperty方法 (1)介绍 (2)语法 (3)descriptor属性描述符 2.Object.defineProperty在Vue2双向绑定的核心原理 3.Object.defineProperty在vue2中的应用 (1&#xff09…

专业人士如何选?揭秘4款2024年常用的电脑录屏软件!

在这个数字化时代,无论是教学、演示、游戏直播还是软件操作,电脑录屏软件已经是我们日常工作中的好帮手。但市面上这么多的电脑录屏软件,要想挑一款既专业又好用的,还真是挺让人头疼的。今天,我们就来聊聊四款常用的电…

mybatis开启数据库的驼峰命名

在application.yml文件中添加 mybatis:configuration:map-underscore-to-camel-case: true

powerjob连接postgresql数据库(支持docker部署)

1.先去pg建一个powerjob-product库 2.首先去拉最新的包,然后找到server模块,把mysql的配置文件信息替换成pg的 spring.datasource.hikari.auto-committrue spring.datasource.remote.hibernate.properties.hibernate.dialecttech.powerjob.server.pers…

全自动迷你洗衣机什么牌子好?五款卓越内衣洗衣机大合集!

随着科技的发展,市面上也出现许多便利的小家电。其中被多次讨论起来的莫过于是内衣洗衣机,选择一款耐用、质量优秀的内衣洗衣机,不仅可以减少洗衣负担,还能提供高效的洗涤效果。然而,随着内衣洗衣机的爆火,…

maven仓库密码加密方案原理

前言 有一个要求就是说不能使用明文密码&#xff0c;需要对 settings.xml 文件中的password密码进行加密 原始配置是没有对密码进行加密的 <server><id>gleam-repo</id><username>admin</username><password>admin123</password>&l…

7.2 单变量(多->多),attention/informer

继续上文书写&#xff1a; 1 GRU Attention 收敛速度稳定的很多&#xff0c;你看这些模型是不是很容易搭&#xff0c;像积木一样&#xff1b; def create_model(input_shape, output_length,lr1e-3, warehouse"None"):input Input(shapeinput_shape)conv1 Conv…

怎么给电脑文件加密?实用的四种方法,「重磅来袭」!

小李&#xff1a;“嘿&#xff0c;小张&#xff0c;你上次提到的那个重要项目报告&#xff0c;我放在了电脑里&#xff0c;但总觉得不太安全&#xff0c;万一被误删了或者不小心泄露了怎么办&#xff1f;” 小张&#xff1a;“别担心&#xff0c;小李&#xff0c;给文件加密是…

如何提高工作效率?分享9个高效率工作的方法

如果您的企业正在面临以下问题&#xff1a; 员工敏捷性和生产力降低员工满意度不足利润下降 那么您需要创建一个运营改进指南。 这需要经常更新&#xff0c;因为这不是一次性的努力&#xff0c;而是必须定期进行的持续过程。然而&#xff0c;您的运营改进指南还必须强调优化…

java 垃圾回收器以及JVM调优方式

什么是垃圾&#xff1a; 没有被引用的对象 就是垃圾。 定位的方式 reference count: 引用计数&#xff0c;即在对象上记录着有多少个引用指向它。&#xff08;循环引用无法解决&#xff09; root searching: 根可达算法&#xff0c;根对象包含 线程栈变量&#xff0c;静态变…

bootStrap中操作行详情,删除,修改等操作

点击列表某一行的操作按钮&#xff0c;结合swtich case 出发不同操作

【2024算力大会分会 | SPIE出版】2024云计算、性能计算与深度学习国际学术会议(CCPCDL 2024)

【2024算力大会分会 | SPIE出版】 2024云计算、性能计算与深度学习国际学术会议(CCPCDL 2024) 2024 International conference on Cloud Computing, Performance Computing and Deep Learning CCPCDL往届均已完成EI检索&#xff0c;最快会后4个半月完成&#xff01; 2024中…

postgresql 11.17 开发环境rpm安装及扩展安装

进入postgresql安装文件rpm所在文件夹 cd /data460/software 执行 yum local install *.rpm 提示缺少啥依赖就对应yum安装 最后有个依赖比较特殊 Requires: llvm-toolset-7-clang > 4.0.1 You could try using --skip-broken to work around the problem 需要安装centos-re…

Spring WebFlux 整合 r2dbc 的增删改查案例

无障碍阅读方法 微信公众号关注:张家的小伙子 回复:10205文章目录 无障碍阅读方法说明准备创建mysql数据库和数据表创建一个maven项目添加项目依赖包创建项目基本目录接口启动类编写编写application配置添加跨域请求配置创建实体-数据表映射类创建Dao操作类编写自己的增删改…

VS code 美化之 代码窗背景图 日志2024/8/2

VS code 美化之 代码窗背景图 先看效果: 参考文档: VSCode设置背景图片的两种方式_vscode代码背景-CSDN博客 用插件那个方法我试了,其只会在右侧 侧边栏目出现背景图,可能是我设置不正确吧 而且安装这个插件之后出现弹窗 vscode安装出现问题什么的提示,删除这个拓展就不会有…

时间价值衰减对期权价格有哪些影响?投资必知!

今天带你了解时间价值衰减对期权价格有哪些影响&#xff1f;投资必知&#xff01;期权的时间对期权的价格和价值具有重要影响&#xff0c;这是由于期权的特性和市场机制决定的&#xff0c;其实期权的时间价值是会衰减的。 期权的时间价值&#xff0c;指的是潜在的可能性。 比…

TypeScript(switch判断)

1.switch 语法用法 switch是对某个表达式的值做出判断。然后决定程序执行哪一段代码 case语句中指定的每个值必须具有与表达式兼容的类型 语法switch(表达式){ case 值1&#xff1a; ​ 执行语句块1 break; case 值2&#xff1a; ​ 执行语句块3 break; dfault: //如…

CSDN选择:腾讯cdn缓存跟阿里云cdn对比

在如今互联网迅速发展的时代&#xff0c;内容分发网络&#xff08;CDN&#xff09;变得越来越重要。而在众多CDN提供商中&#xff0c;腾讯云和阿里云的CDN服务无疑是具代表性的两家。那么&#xff0c;这两家的CDN服务究竟有何差异&#xff1f;哪一家更值得选择呢&#xff1f;今…

Python WSGI服务器库之gunicorn使用详解

概要 在部署 Python Web 应用程序时,选择合适的 WSGI 服务器是关键的一步。Gunicorn(Green Unicorn)是一个高性能、易于使用的 Python WSGI HTTP 服务器,适用于各种应用部署场景。Gunicorn 设计简洁,支持多种工作模式,能够有效地管理和处理大量并发请求。本文将详细介绍…

【Canvas与艺术】八角大楼

【成图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head><title>八角大楼</title><style type"text/css">.cen…