分布式机器学习(Parameter Server)

news2025/1/10 23:21:15

分布式机器学习中,参数服务器(Parameter Server)用于管理和共享模型参数,其基本思想是将模型参数存储在一个或多个中央服务器上,并通过网络将这些参数共享给参与训练的各个计算节点。每个计算节点可以从参数服务器中获取当前模型参数,并将计算结果返回给参数服务器进行更新。

为了保持模型一致性,通常采用下列两种方法:

  1. 将模型参数保存在一个集中的节点上,当一个计算节点要进行模型训练时,可从集中节点获取参数,进行模型训练,然后将更新后的模型推送回集中节点。由于所有计算节点都从同一个集中节点获取参数,因此可以保证模型一致性。
  2. 每个计算节点都保存模型参数的副本,因此要定期强制同步模型副本,每个计算节点使用自己的训练数据分区来训练本地模型副本。在每个训练迭代后,由于使用不同的输入数据进行训练,存储在不同计算节点上的模型副本可能会有所不同。因此,每一次训练迭代后插入一个全局同步的步骤,这将对不同计算节点上的参数进行平均,以便以完全分布式的方式保证模型的一致性,即All-Reduce范式

PS架构

在该架构中,包含两个角色:parameter server和worker

parameter server将被视为master节点在Master/Worker架构,而worker将充当计算节点负责模型训练

整个系统的工作流程分为4个阶段:

  1. Pull Weights: 所有worker从参数服务器获取权重参数
  2. Push Gradients: 每一个worker使用本地的训练数据训练本地模型,生成本地梯度,之后将梯度上传参数服务器
  3. Aggregate Gradients:收集到所有计算节点发送的梯度后,对梯度进行求和
  4. Model Update:计算出累加梯度,参数服务器使用这个累加梯度来更新位于集中服务器上的模型参数

可见,上述的Pull Weights和Push Gradients涉及到通信,首先对于Pull Weights来说,参数服务器同时向worker发送权重,这是一对多的通信模式,称为fan-out通信模式。假设每个节点(参数服务器和工作节点)的通信带宽都为1。假设在这个数据并行训练作业中有N个工作节点,由于集中式参数服务器需要同时将模型发送给N个工作节点,因此每个工作节点的发送带宽(BW)仅为1/N。另一方面,每个工作节点的接收带宽为1,远大于参数服务器的发送带宽1/N。因此,在拉取权重阶段,参数服务器端存在通信瓶颈。

对于Push Gradients来说,所有的worker并发地发送梯度给参数服务器,称为fan-in通信模式,参数服务器同样存在通信瓶颈。

基于上述讨论,通信瓶颈总是发生在参数服务器端,将通过负载均衡解决这个问题

将模型划分为N个参数服务器,每个参数服务器负责更新1/N的模型参数。实际上是将模型参数分片(sharded model)并存储在多个参数服务器上,可以缓解参数服务器一侧的网络瓶颈问题,使得参数服务器之间的通信负载减少,提高整体的通信效率。

代码实现

定义网络结构:如上定义了一个简单的CNN

实现参数服务器:

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
 
        self.conv1 = nn.Conv2d(1,32,3,1).to(device)
        self.dropout1 = nn.Dropout2d(0.5).to(device)
        self.conv2 = nn.Conv2d(32,64,3,1).to(device)
        self.dropout2 = nn.Dropout2d(0.75).to(device)
        self.fc1 = nn.Linear(9216,128).to(device)
        self.fc2 = nn.Linear(128,20).to(device)
        self.fc3 = nn.Linear(20,10).to(device)
 
    def forward(self,x):
        x = self.conv1(x)
        x = self.dropout1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.dropout2(x)
        x = F.max_pool2d(x,2)
        x = torch.flatten(x,1)
 
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
 
        output = F.log_softmax(x,dim=1)
 
        return output

如上定义了一个简单的CNN

实现参数服务器:

class ParamServer(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Net()
 
        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")
 
        self.optimizer = optim.SGD(self.model.parameters(),lr=0.5)
 
    def get_weights(self):
        return self.model.state_dict()
 
    def update_model(self,grads):
        for para,grad in zip(self.model.parameters(),grads):
            para.grad = grad
 
        self.optimizer.step()
        self.optimizer.zero_grad()

get_weights获取权重参数,update_model更新模型,采用SGD优化器

实现worker:

class Worker(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Net()
        if torch.cuda.is_available():
            self.input_device = torch.device("cuda:0")
        else:
            self.input_device = torch.device("cpu")
 
    def pull_weights(self,model_params):
        self.model.load_state_dict(model_params)
 
    def push_gradients(self,batch_idx,data,target):
        data,target = data.to(self.input_device),target.to(self.input_device)
        output = self.model(data)
        data.requires_grad = True
        loss = F.nll_loss(output,target)
        loss.backward()
        grads = []
 
        for layer in self.parameters():
            grad = layer.grad
            grads.append(grad)
 
        print(f"batch {batch_idx} training :: loss {loss.item()}")
 
        return grads

Pull_weights获取模型参数,push_gradients上传梯度

训练

训练数据集为MNIST

import torch
from torchvision import datasets,transforms
 
from network import Net
from worker import *
from server import *
 
train_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=True,
               transform = transforms.Compose([transforms.ToTensor(),
               transforms.Normalize((0.1307,),(0.3081,))])),
               batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('./mnist_data', download=True, train=False,
              transform = transforms.Compose([transforms.ToTensor(),
              transforms.Normalize((0.1307,),(0.3081,))])),
              batch_size=128, shuffle=True)
 
def main():
    server = ParamServer()
    worker = Worker()
 
    for batch_idx, (data,target) in enumerate(train_loader):
        params = server.get_weights()
        worker.pull_weights(params)
        grads = worker.push_gradients(batch_idx,data,target)
        server.update_model(grads)
 
    print("Done Training")
 
if __name__ == "__main__":
    main()

来源:分布式机器学习(Parameter Server) - N3ptune - 博客园 (cnblogs.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/691807.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

高速电路设计系列分享-信号链精度分析(下)

文章目录 概要整体架构流程技术名词解释技术细节小结 概要 提示:这里可以添加技术概要 在任何设计中,信号链精度分析都可能是一项非常重要的任务,必须充分了解。之前, 我们讨论了在整个信号链累积起来并且最终会影响到转换器的多…

嵌入式软件测试笔记7 | 嵌入式软件测试中基于风险的测试策略如何开展?

7 | 嵌入式软件测试中基于风险的测试策略如何开展? 1 风险评估1.1 分析风险1.2 如何估计故障几率?1.3 导致故障几率较高的因素1.4 估计可能的损失1.5 风险评估的来源1.6 风险的处理 2 主测试计划中的策略2.1 目标2.2 制定策略的步骤2.3 选择质量特性2.4 …

美国签证办理需要户口本吗?

在申请美国签证时,有关所需文件的问题常常令人困惑。关于是否需要提供户口本,知识人网可以向您解释一下相关情况。 首先,需要明确的是,美国签证申请并不要求申请人提供户口本。美国领事馆和大使馆在签证申请过程中通常要求申请人提…

零知识证明(Sigma和Flat-shamir)

概述 定义:大概的定义就是prover可以向verifier证明自己给定的信息是大概率正确的,但是不泄露任何附加信息,包含信息本身。 举例 这里以一个比较经典的例子,即向红绿色盲(无法区分红色和绿色,看红色和绿色…

基于RFID技术的并列式挤奶厅方案

随着现代农业的不断发展,RFID技术已经广泛应用于畜牧业生产中。在奶牛养殖领域,RFID技术可以帮助养殖场管理人员实现奶牛的精准管理,提高生产效率。本文将介绍一种基于RFID技术的并列式挤奶厅方案,该方案可以实现对每头奶牛的精准…

网络安全(黑客)必备工具包

1. NMap 作为Network Mapper的缩写,NMap是一个开源的免费安全扫描工具,可用于安全审计和网络发现。它适用于Windows、Linux、HP-UX、Solaris、BSD变体(包括Mac OS)以及AmigaOS。Nmap可用于探测网络上哪些主机可访问,它们正在运行的操作系统类…

Keil MDK编程环境下的 STM32 IAP下载(学习笔记)

IAP的引入 不同的程序下载方式 ICP ICP(In Circuit Programing)。在电路编程,可通过 CPU 的 Debug Access Port 烧录代码,比如 ARM Cortex 的 Debug Interface 主要是 SWD(Serial Wire Debug) 或 JTAG(Joint Test Action Group); ISP ISP(I…

合宙Air724UG Cat.1模块硬件设计指南--Wifi扫描

概述 Air724UG具有WiFi Scan功能,支持2.4G频段下的802.11b,802.11g,802.11n等WiFi技术协议,结合模块本身支持的蓝牙功能,二者共用一路天线。 Air724UG以主动的方式,在每个信道上发出Probe Request帧&#x…

Java集合框架中取出元素时的比较问题:“==“与equals()方法

今天随便刷力扣的时候看到了最小栈&#xff0c;发现力扣上没做过&#xff0c;题不难&#xff0c;于是做了一下 一开始的代码如下&#xff1a; class MinStack {Deque<Integer> stack;Deque<Integer> minStack;public MinStack() {stack new LinkedList<>()…

HHU商务数据挖掘期末考点复习

文章目录 第一章 概述第二章 商务智能过程2.1四个部分2.2数据仓库与数据库2.3在线分析处理与在线事务处理 第三章 关联分析3.1 频繁模式与关联规则3.2 相关性度量liftcosine 第四章 分类4.1决策树4.1.1 信息熵的概念4.1.2 计算目标变量的信息熵4.1.3 算条件熵4.1.4 信息增益4.1…

从专用模型到大模型

背景&#xff1a; 在开始文章正文之前&#xff0c;我们来讲讲为何突然大模型火了&#xff0c;大模型和专用模型到底有何差异。 大模型火之前专业模型其实已经能够很好的配合做很多很复杂的事情。如果只是从提高工作效率的角度来讲应该是发展模型的专业问题解决能力&#xff0…

SpringSecurity实现Remember-Me实践

【1】基于会话技术的实现 也就是基于Cookie的实现。 ① 登录页面 这里name"remember-me"表示“记住我”的复选框&#xff0c;默认key是remember-me。 <form action"/user/login" method"post"><input type"text" name&q…

Jmeter连接数据库并进行操作

一&#xff1a;加一个JDBC组件 二、填写连接信息&#xff1a; 三&#xff1a;添加JDBC请求 四、填写sql并运行

[centos] 新买的服务器环境搭建

由于去年买的云服务器快过期了,然后最近又新买了一个服务器,所以就写下了这篇文章, 虽然可以镜像搭建,但是本身原服务器就没有多少东西,所以我选择了手动搭建... 再且,也可以帮我再熟悉一下 centos 环境... 当然很多都是我之前OneNote的学习笔记,这里就直接复制和粘贴了(&#…

操作系统3——处理机调度与死锁

本系列博客重点在深圳大学操作系统课程的核心内容梳理&#xff0c;参考书目《计算机操作系统》&#xff08;有问题欢迎在评论区讨论指出&#xff0c;或直接私信联系我&#xff09;。 梗概 本篇博客主要介绍操作系统第三章处理机调度与死锁的相关知识。 目录 一、调度基本概念…

【FFmpeg实战】MP4封装格式分析

原文地址&#xff1a;https://www.cnblogs.com/moonwalk/p/16244932.html 解析工具&#xff1a; https://gpac.github.io/mp4box.js/test/filereader.html (mp4box) 1. 概述 mp4 容器格式相较于 flv、ts 容器格式来说&#xff0c;其定义较为复杂&#xff0c;本篇文章主要记录…

1.计算机是如何工作的(上)

文章目录 1.计算机发展史2.冯诺依曼体系&#xff08;Von Neumann Architecture&#xff09;3.CPU 基本工作流程3.1逻辑门3.1.1电子开关 —— 机械继电器(Mechanical Relay)3.1.2门电路(Gate Circuit) 3.2算术逻辑单元 ALU&#xff08;Arithmetic & Logic Unit&#xff09;3…

BossPlayerCTF 靶场

sudo nmap -sn 192.168.28.0/24 sudo nmap -sT --min-rate 10000 -p- 192.168.28.40 sudo nmap -sT -sV -sC -O -p22,80 192.168.28.40 -oA nmapscan/detail sudo nmap --scriptvuln -p22,80 192.168.28.40 -oA nmapscan/vuln 访问80 查看web源码&#xff1a; robots.txt ec…

AI EXPO 2023 | 图技术激活数据资产论坛圆满落幕

2023年6月25日下午&#xff0c;由新一代人工智能产业技术创新战略联盟、苏州市人工智能协同创新中心与苏州市大数据服务中心协会联合主办&#xff0c;浙江创邻科技有限公司承办的「2023全球人工智能产品应用博览会-图技术激活数据资产主题论坛」在苏州国际博览中心圆满落幕&…

JavaWeb小记——重定向和内部转发

目录 重定向 原理图 重定向语句 重定向特点 内部转发 原理图 请求转发特点 路径的书写 请求域对象request 特点 请求转发特点 重定向特点 重定向 原理图 重定向语句 response.setStatus(302); response.setHeader("location","http://www.baidu.c…