DataParallel里为什么会显存不均匀以及如何解决

news2025/1/22 16:05:09

878dce4c518892bdc598bd2e8c53877d.pnge018fb6487dbbfde56a2212f04ab2d10.png

作者:台运鹏 (正在寻找internship...)
主页:https://yunpengtai.top

鉴于网上此类教程有不少模糊不清,对原理不得其法,代码也难跑通,故而花了几天细究了一下相关原理和实现,欢迎批评指正!

关于此部分的代码,可以去https://github.com/sherlcok314159/dl-tools查看

「在开始前,我需要特别致谢一下一位挚友,他送了我双显卡的机器来赞助我做个人研究,否则多卡的相关实验就得付费在云平台上跑了,感谢好朋友一路以来的支持,这份恩情值得一辈子铭记!这篇文章作为礼物赠与挚友。」

Why Parallel

我们在两种情况下进行并行化训练[1]

  1. 「模型一张卡放不下」:我们需要将模型不同的结构放置到不同的GPU上运行,这种情况叫ModelParallel(MP)

  2. 「一张卡的batch size(bs)过小」:有些时候数据的最大长度调的比较高(e.g., 512),可用的bs就很小,较小的bs会导致收敛不稳定,因而将数据分发到多个GPU上进行并行训练,这种情况叫DataParallel(DP)。当然,DP肯定还可以加速训练,常见于大模型的训练中

这里只讲一下DP在pytorch中的原理和相关实现,即DataParallel和DistributedParallel

Data Parallel

实现原理

实现就是循环往复一个过程:数据分发,模型复制,各自前向传播,汇聚输出,计算损失,梯度回传,梯度汇聚更新,可以参见下图[2]

fd959ce18d7163bd79baa58540e00565.png

pytorch中部分关键源码[3]截取如下:

def data_parallel(
 module, 
 input, 
 device_ids, 
 output_device=None
):
    if not device_ids:
        return module(input)

    if output_device is None:
        output_device = device_ids[0]

    # 复制模型
    replicas = nn.parallel.replicate(module, device_ids)
    # 拆分数据
    inputs = nn.parallel.scatter(input, device_ids)
    replicas = replicas[:len(inputs)]
    # 各自前向传播
    outputs = nn.parallel.parallel_apply(replicas, inputs)
    # 汇聚输出
    return nn.parallel.gather(outputs, output_device)

代码使用

因为运行时会将数据平均拆分到GPU上,所以我们准备数据的时候, batch size = per_gpu_batch_size * n_gpus

同时,需要注意主GPU需要进行汇聚等操作,因而需要比单卡运行时多留出一些空间

import torch.nn as nn
# device_ids默认所有可使用的设备
# output_device默认cuda:0
net = nn.DataParallel(model, device_ids=[0, 1, 2], 
                      output_device=None, dim=0)
# input_var can be on any device, including CPU
output = net(input_var)

接下来看个更详细的例子[4],需要注意的是被DP包裹之后涉及到模型相关的,需要调用DP.module,比如加载模型

class Model(nn.Module):
    # Our model
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        # for convenience
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())
        return output

bs, input_size, output_size = 6, 8, 10
# define inputs
inputs = torch.randn((bs, input_size)).cuda()
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [6, xxx] -> [2, ...], [2, ...], [2, ...] on 3 GPUs
  model = nn.DataParallel(model)
# 先DataParallel,再cuda
model = model.cuda()
outputs = model(inputs)
print("Outside: input size", inputs.size(),
   "output_size", outputs.size())
# assume 2 GPUS are available
# Let's use 2 GPUs!
#    In Model: input size torch.Size([3, 8]) output size torch.Size([3, 10])
#    In Model: input size torch.Size([3, 8]) output size torch.Size([3, 10])
# Outside: input size torch.Size([6, 8]) output_size torch.Size([6, 10])

# save the model
torch.save(model.module.state_dict(), PATH)
# load again
model.module.load_state_dict(torch.load(PATH))
# do anything you want

如果经常使用huggingface,这里有两个误区需要小心:

# data parallel object has no save_pretrained
model = xxx.from_pretrained(PATH)
model = nn.DataParallel(model).cuda()
model.save_pretrained(NEW_PATH) # error
# 因为model被DP wrap了,得先取出模型 #
model.module.save_pretrained(NEW_PATH)
# HF实现貌似是返回N个loss(N为GPU数量)
# 然后对N个loss取mean
outputs = model(**inputs)
loss, logits = outputs.loss, outputs.logits
loss = loss.mean()
loss.backward()

# 返回的logits是汇聚后的
# HF实现和我们手动算loss有细微差异
# 手动算略好于HF
loss2 = loss_fct(logits, labels)
assert loss != loss2
True

显存不均匀

了解前面的原理后,就会明白为什么会显存不均匀。因为GPU0比其他GPU多了汇聚的工作,得留一些显存,而其他GPU显然是不需要的。那么,解决方案就是让其他GPU的batch size开大点,GPU0维持原状,即不按照默认实现的平分数据

首先我们继承原来的DataParallel(此处参考[5])),这里我们给定第一个GPU的bs就可以,这个是实际的bs而不是乘上梯度后的。假如你想要总的bs为64,梯度累积为2,一共2张GPU,而一张最多只能18,那么保险一点GPU0设置为14,GPU1是18,也就是说你DataLoader每个batch大小是32,gpu0_bsz=14

class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

核心代码就在于我们重新分配chunk_sizes,实现思路就是将总的减去第一个GPU的再除以剩下的设备,源码的话有些死板,用的时候不妨参考我的[6]

def scatter(self, inputs, kwargs, device_ids):
    # 不同于源码,获取batch size更加灵活
    # 支持只有kwargs的情况,如model(**inputs)
    if len(inputs) > 0:
        bsz = inputs[0].size(self.dim)
    elif kwargs:
        bsz = list(kwargs.values())[0].size(self.dim)
    else:
        raise ValueError("You must pass inputs to the model!")

    num_dev = len(self.device_ids)
    gpu0_bsz = self.gpu0_bsz
    # 除第一块之外每块GPU的bsz
    bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
    if gpu0_bsz < bsz_unit:
        # adapt the chunk sizes
        chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
        delta = bsz - sum(chunk_sizes)
        # 补足偏移量
        # 会有显存溢出的风险,因而最好给定的bsz是可以整除的
        # e.g., 总的=52 => bsz_0=16, bsz_1=bsz_2=18
        # 总的=53 => bsz_0=16, bsz_1=19, bsz_2=18
        for i in range(delta):
            chunk_sizes[i + 1] += 1
        if gpu0_bsz == 0:
            chunk_sizes = chunk_sizes[1:]
    else:
        return super().scatter(inputs, kwargs, device_ids)

    return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

优缺点

  • 优点:便于操作,理解简单

  • 缺点:GPU分配不均匀;每次更新完都得销毁「线程」(运行程序后会有一个进程,一个进程可以有很多个线程)重新复制模型,因而速度慢

参考资料

[1]

Pytorch中多GPU并行计算教程: https://blog.csdn.net/qq_37541097/article/details/109736159

[2]

DDP: https://www.cnblogs.com/ljwgis/p/15471530.html

[3]

MULTI-GPU EXAMPLES: https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html?highlight=dataparallel

[4]

OPTIONAL: DATA PARALLELISM: https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html

[5]

transformer-xl: https://github.com/kimiyoung/transformer-xl

[6]

Balanced Data Parallel: https://github.com/sherlcok314159/dl-tools/blob/main/balanced_data_parallel/README.md


📝论文解读投稿,让你的文章被更多不同背景、不同方向的人看到,不被石沉大海,或许还能增加不少引用的呦~ 投稿加下面微信备注“投稿”即可。

最近文章

COLING'22 | SelfMix:针对带噪数据集的半监督学习方法

ACMMM 2022 | 首个针对跨语言跨模态检索的噪声鲁棒研究工作

ACM MM 2022 Oral  | PRVR: 新的文本到视频跨模态检索子任务

统计机器学习方法 for NLP:基于CRF的词性标注

统计机器学习方法 for NLP:基于HMM的词性标注


点击这里进群—>加入NLP交流群

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/87379.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

商品上下游第六讲-交易中心-商品秒杀

交易中心-商品秒杀设计 文章目录 交易中心-商品秒杀设计1、项目背景2、主要技术3、项目职责4、项目实现4.1、需求分析4.2、核心流程4.3、关键链路技术方案4.4、库存处理方式1、库存超卖问题订单层面的控制4.5、限流,熔断,降级4.6、超职购小程序—接口梳理4.7、缓存的设计4.8、…

mybatis入门(一)

什么是 MyBatis &#xff1f; MyBatis 是一款优秀的持久层框架&#xff0c;它支持定制化 SQL、存储过程以及高级映射。MyBatis 避免了几乎所有的 JDBC 代码和手动设置参数以及获取结果集。MyBatis 可以使用简单的 XML 或注解来配置和映射原生信息&#xff0c;将接口和 Java 的…

如何使用匈牙利算法解决多维度的约束条件问题

&#x1f37f;*★,*:.☆欢迎您/$:*.★* &#x1f37f; 正文 假设 一个项目 有三个 维度的参数 A B C 都要 组合后最小 分别求解 a b c 三个维度的最优组合 如果三个组合方案刚好 重叠 那么说明有一个使得三个方案最优的 解 如果没有 那么若选择某个方案 其他维度的参数 的值 是…

直播倒计时 2 天 | SOFAChannel#31 RPC 框架设计的考和量

SOFARPC 是蚂蚁集团开源的一款基于 Java 实现的 RPC 服务框架&#xff0c;为应用之间提供远程服务调用能力&#xff0c;具有高可伸缩性&#xff0c;高容错性&#xff0c;目前蚂蚁集团所有的业务的相互间的 RPC 调用都是采用 SOFARPC。SOFARPC 为用户提供了负载均衡&#xff0c;…

Android -- 每日一问:回调函数和观察者模式的区别?

知识点 观察者模式 网上很容易查到观察者模式的定义&#xff1a; 观察者模式定义了对象间的一种一对多依赖关系&#xff0c;使得每当一个对象改变状态&#xff0c;则所有依赖于它的对象都会得到通知并被自动更新。 Android中大量的使用了观察者模式。你可能已经用过ListView…

基于51单片机的舞蹈机器人步进机仿真设计

程序运行图&#xff1a; 仿真原理图&#xff1a; 部分程序&#xff1a; #include "reg51.h" #include "intrins.H" //8步式步进电机脉冲序列 //unsigned char steps[8] {0x77,0x33,0xbb,0x99,0xdd,0xcc,0xee,0x66}; unsigned char steps[8] {0x2,0x…

Vue2快速入门

Vue 介绍 Vue 是一套构建用户界面的渐进式前端框架只关注视图层&#xff0c;并且非常容易学习&#xff0c;还可以很方便的与其它库或已有项目整合通过尽可能简单的API来实现响应数据的绑定和组合的视图组件特点易用&#xff1a;在有HTML CSS JavaScript的基础上&#xff0c;快速…

拓扑排序(数据结构之图的应用)

我们先搞清楚一个概念&#xff1a; 什么是出度与入度&#xff1f; 在有向图中&#xff0c;箭头是具有方向的&#xff0c;从一个顶点指向另一个顶点&#xff0c;这样一来&#xff0c;每个顶点被指向的箭头个数&#xff0c;就是它的入度。从这个顶点指出去的箭头个数&#xff0c…

不锈钢风淋室的使用需要注意哪些事项

风淋室的使用需要注意哪些事项 一、风淋室的操作说明&#xff1a; 1) 接通380V&#xff0c;50HZ电源(L1、L2、L3-火线&#xff0c;N-零线&#xff0c;E-接地线)&#xff0c;打开工作、照明开关&#xff0c;确认风机与照明工作正常&#xff0c;此时&#xff0c;风/货淋室处于初…

原创 | Attention is all you need 论文解析(附代码)

作者&#xff1a;杨金珊审校&#xff1a;陈之炎本文约4300字&#xff0c;建议阅读8分钟“Attention is all you need”一文在注意力机制的使用方面取得了很大的进步&#xff0c;对Transformer模型做出了重大改进。目前NLP任务中的最著名模型&#xff08;例如GPT-2或BERT&#x…

【数集项目之 MCDF】(四) 整形器 formatter

根据上一章的arbiter结构图&#xff0c;结合设计文档中MCDF的整体结构图&#xff0c;可以发现formatter整形器模块是arbiter的上级&#xff0c;负责最终的数据输出&#xff0c;与外界数据接收端相连。 第一节 fromatter文档理解 设计文档formatter的部分时序介绍如下 如图所示…

钡铼技术S274数据遥测终端机

钡铼技术S274数据遥测终端机功能特点&#xff1a; 内置 2 路 DC 直流电源输出&#xff0c;无需单独额外增加变送器的电源适配器&#xff0c;节省布线成本&#xff1b;  采用完备的防掉线机制&#xff0c;保证数据终端永远在线&#xff0c;掉线重发数据以及掉线短信通知用户…

第38篇 网络(八)TCP(二)

导语 在上一节里我们使用TCP服务器发送一个字符串&#xff0c;然后在TCP客户端进行接收。在这一节将重新写一个客户端程序和一个服务器程序&#xff0c;这次实现客户端进行文件的发送&#xff0c;服务器进行文件的接收。有了上一节的基础&#xff0c;这一节的内容就很好理解了…

postgresql_internals-14 学习笔记(三)冻结、rebuild

一、 Freezing 冻结 1. 引入原因 简单说来就是目前pg事务id只有32位&#xff0c;大业务量下很可能用完&#xff0c;触发事务id回卷&#xff08;循环使用&#xff09;。而pg是根据事务id大小判断可见性的&#xff0c;如果新事务却使用了小id&#xff0c;旧事务将可以看到新事务…

win下 conda 虚拟环境没有名字怎么进入

本文主要介绍windows下&#xff0c;在conda 虚拟环境名字消失后的解决办法。主要介绍两种解决方案。 文章目录前言解决方案一&#xff1a;往.condarc文件中添加envs_dirs1. 设置envs_dirs2. 重新查看虚拟环境解决方案二&#xff1a;直接通过path 激活虚拟环境总结前言 我们都知…

Grafana 监控大屏可视化图表

Grafana 系列文章&#xff0c;版本&#xff1a;OOS v9.3.1 Grafana 的介绍和安装Grafana监控大屏配置参数介绍&#xff08;一&#xff09;Grafana监控大屏配置参数介绍&#xff08;二&#xff09;Grafana监控大屏可视化图表 前面我们以Time series 图表为例&#xff0c;学习了面…

每天投递一两个公司,我连续投了三个月

作者&#xff1a;阿秀校招八股文学习网站&#xff1a;https://interviewguide.cn你好&#xff0c;我是阿秀。阿秀以前在秋招的时候投递过八九十份简历&#xff0c;当时还没有简历一键上传功能&#xff0c;很多时候都需要自己去那些公司注册账号&#xff0c;然后找到校园招聘模块…

手机备忘录误删恢复的操作方法

手机备忘录在使用的过程中&#xff0c;会有多种不同的操作&#xff0c;通过不同的操作来实现不同的效果。对于有的内容来说&#xff0c;是可以过期删除的&#xff0c;但是在删除这个操作的过程当中&#xff0c;如果不小心把有用的东西误删了&#xff0c;那么恢复误删内容的操作…

行内块元素因换行带来的间隔问题及解决方法

行内块元素因换行带来的间隔问题 先看一个案例所展示的行内块元素因换行带来的间隔问题,俺直接上截图 再来一张截图可以更加清楚地看见行内块元素之间因换行而带来的间隔 从上方所有图片可以看出,行内块元素之间一行并排放置时编译器中的换行操作会在浏览器渲染时带来行内块元…

鸡汤来了

这几天&#xff0c;网上铺天盖地都是各种感染新冠的消息&#xff0c;连我一直关注的和菜头也感染上了&#xff0c;关键是连怎么感染的都不知道。他写道&#xff1a;我也很委屈。自从北京开始比拼首阳之后&#xff0c;我的确是缩在家里&#xff0c;想着越晚感染越好。为了达到这…