展心展力 metaapp:基于 DeepRec 的稀疏模型训练实践

news2024/11/24 16:32:24

作者

metaapp-推荐广告研发部:臧若舟,朱越,司灵通

1 背景

推荐场景大模型在国内的使用很早,早在 10 年前甚至更早,百度已经用上了自研的大规模分布式的 parameter server 系统结合上游自研的 worker 来实现 TB 级别的万亿参数的稀疏模型。后来,各家平台也陆续基于这种方案,开发了自己的分布式训练系统,普遍特点是大量使用 id embedding,因此参数量巨大,模型大小也非常夸张。当然,随着开源训练工具 TensorFlow/Pytorch 的流行,使用 TensorFlow/Pytorch 作为 worker,结合自研 ps 的方案也十分流行。究其原因,以 TensorFlow 为例,虽然内置了分布式训练系统,但是对于大规模 id embedding 的支持却非常糟糕,无法作为完整的平台使用。而使用 TensorFlow+ 自研 ps 的方案也存在不少问题,比如自研 ps 一般对于特征输入都有特定的要求、二次开发成本比较高等。

一个典型的分布式 worker-ps 架构

2 业务介绍

metaapp- 推荐广告研发部,主要负责 metaapp 拳头产品 233 乐园的首页信息流的推荐和广告系统,是比较传统的推广搜组。我们在 2020 年之前也是采用了 TensorFlow+ 自研分布式 ps 的方案,模型大小在接近 TB 级别(业务体量较小),整个方案的迭代和维护成本都比较高。

在这种背景下,经过多方考量,阿里云机器学习平台 PAI 开源的 DeepRec(脱胎于 PAI-TF),作为支持了淘宝搜索、猜你喜欢、定向、直通车等核心业务的训练平台,直接基于 TensorFlow 做二次开发,针对稀疏模型在分布式、图优化、算子、Runtime 等方面进行了深度的性能优化,并且完全开源。

而因为我们公司本身跟阿里云有着深度的合作,阿里云也主动介绍了当时还是内部项目的 DeepRec 给我们尝试。在近 2 年的工作后,DeepRec 已经全量用于我们的模型训练和线上 inference,并且取得了显著的性能提升和成本下降。

3 稀疏模型训练

3.1 EmbeddingVariable 多级存储

由于模型参数量大,一些特征的 embedding 大小达到了接近 TB 级别,完全基于内存存储对于成本的要求过高,因此自然而然就会想到多级存储:将最热的 embedding 放在显存或者内存里,其余的可以分级放在 PMEM、SSD 等成本较低的存储介质中。而 DeepRec 中 提供了基于 EmbeddingVariable 的 Embedding 多级存储功能。DeepRec 目前对于 embedding 存放在各种存储介质的支持已经相当完善。

下面介绍下我们团队升级 DeepRec 在存储这一块的过程和经验:

3.1.1 compaction 的性能问题

我们原本基于自研的分布式 parameter server,而当时 PMEM 类的存储介质还不普及,因此我们选择了比较高性能的 SSD 作为多级存储介质。于是我们自然而然采用了类 leveldb(rocksdb)的方案作为 SSD 存储方案。但这种方案在模型训练时,由于参数不断增加和更新,后台会进行频繁的 compaction,此时会有严重的写放大问题导致 ps 的读取时间大大延长,从而导致模型训练的瓶颈几乎都在 ps 侧。ps:据说 rocksdb 在 2022 年底的 7.5.3 版本大幅改进了 compaction 的性能,在后台 compaction 时几乎不会影响读取的性能。

3.1.2 DeepRec 的方案

而在早期我们试用 DeepRec 时,DeepRec 的 EmbeddingVariable 对于 SSD 存储的方案同样是基于 leveldb,因此同样遇到了跟我们自研的方案类似的问题。后续我们将此问题的测试结果反馈给了 DeepRec 相关的同学,他们基于此后续推出了基于 SSDHASH 的存储方案,大大提升了 compaction 时的读取性能,因此模型训练基于不再受困于 ps 的读取性能问题。后续又进一步了基于 SSDHASH 的同步和异步两种 compaction 的方式。使用同步 compaction 时,向 SSD 写入数据和 compaction 将会使用同一个线程,异步时则各使用一个线程。这里也推荐大家使用这种方案。

3.1.3 压缩模型大小

进一步的,如果能把模型大小控制在数十 GB,那 ps 的性能就可以进一步提升了。因为采用 DeepRec,自定义各种压缩方式的算子变得非常轻松。我们调研并实现了了多篇 embedding 压缩方向的 paper,最后采用了 binary code 的方式实现了 embedding 的 multihash 方案,可以自由控制 embedding 的大小。我们尝试在最大的特征 uid embedding 上使用了 multihash,把模型大小从 800GB 降低到 40GB 以下,auc 的损失仅在千分之三左右,线上点击率下降了 1.5%;进一步的,我们通过优化序列推荐模型,更好的通过序列特征建模了用户的个性化,可以发现在序列模型的基础上把 uid embedding 换成 multihash 的方案,对于线上点击率的影响仅有 0.3% 左右,因此可以放心全量 multihash 方案。我们也把基于 multihash 的 embedding variable 算子以 pr 的形式提交给了 DeepRec。

3.2 基于 GPU 的分布式训练

在解决了 ps 的性能瓶颈后,模型训练的速度就和模型 Tensor 计算的算力近似线性相关了。而近几年随着序列模型的发展,搜广推的矩阵计算复杂度也在显著提升。此时使用 gpu+ 大 batch size 来代替 cpu 作为 worker 的方案,无论在性能还是成本控制上都有巨大的优势。而阿里云机器学习平台 PAI 开源的 HybridBackend 平台就支持了基于 GPU 的分布式训练方案,并且深度支持了 DeepRec。

可以看到使用 hb 的方案在训练速度上对比 TF-PS 原生方案的优势。

3.2.1 模型参数完全放在显存里

想要充分释放 gpu 的算力,减少因为数据拷贝带来的性能损耗,最好的方案自然是把所有参数都放在 gpu 显存里。上面 2.1.3 提到的压缩模型大小,为这种方案提供了可能性。调大 batch size 则可以进一步提高显卡的利用率。经过测试,在这种情况下,单张 V100 显卡的算力可以超过 20 台 40core worker 节点的算力。

3.2.2 解决了多卡训练丢失数据的问题

在单机多卡训练时,我们发现和单卡训练相比有近 1/3 的数据被丢弃,这是由于 hybridbackend 默认使用所有 worker 按照 row group 平分数据的策略,以提高读取效率。当 group 数目不够均分时,多余的数据会被丢弃,当 parquet 文件较多且比较小时,该问题尤为严重。我们通过使用每个 worker 加载所有的 group,再按照 batch 平分数据的策略,极大地缓解了数据丢失的情况,读取压力也在可接收范围内,后续可考虑将两策略结合降低 worker 的读取压力。

4 模型 inference

4.1 痛点

在我们的实际场景里,线上 inference 的痛点大部分来自于维护成本。因为推荐广告业务场景,需要大量尝试各种模型在线上分配流量做 AB test,因此线上存在的模型量级大概是 10 倍的基线模型量级。而每次上线一个模型,都需要给对应的模型分配相应的资源,并且这个资源跟 AB test 的流量正相关;而每次调整 AB test 流量(比如模型效果不错,放大流量观察)的时候,又需要调整该模型分配的资源。这个过程比较难实现自动化,往往需要算法工程师手动扩缩容。

4.2 基于 Processer 库的 inference 方案解决痛点

上面这个图是我们线上实际的 inference 方案。

4.2.1 单机器运行所有模型

基于上面的痛点,我们给出的方案是使用大规格机器(比如 128C,512G 内存)来做线上 inference,然后每台机器都会有线上所有的模型实例。每台机器运行一个 serving-proxy 会自动的管理所有的模型进程,包括模型上下线、模型更新等。这种方案的好处是整个维护成本基本没有了,所有事情基本都自动化完成了。因为线上整体的流量相对稳定(比如扩大 AB test 模型的流量,自然基线模型流量就减少了,整体是稳定的),所以各个模型之间资源竞争也不需要重新分配资源。

4.2.2 基于 DeepRec 提供的 Processer 库

DeepRec Serving Processor 是用于线上高性能服务的 Library,可以参考文档。因为本身是一个独立的 so 包,我们可以很方便的对接到自己的 Serving RPC 框架中。我们采用 golang 语言来完成了我们自己的 serving rpc 项目,优点自然是开发成本低并且性能不错。

4.2.3 使用 DeepRec 的 Session Group

直接使用 TensorFlow 提供的 C++ 接口调用 Session::Run,无法实现多 Session 并发处理 Request,导致单 Session 无法实现 CPU 的有效利用。如果通过多 Instance 方式(多进程),无法共享底层的 Variable,导致大量使用内存,并且每个 Instance 各自加载一遍模型,严重影响资源的使用率和模型加载效率。

DeepRec 中 SessionGroup 可配置一组 Session,并且通过 Round Robin (支持用户自定义策略)方式将用户请求分发到某一个 Session。SessionGroup 对不同 Session 之间的资源进行隔离,每个 Session 拥有私有的线程池,并且支持每个线程池绑定底层的 CPU Core(numa-aware),可以最大程度地避免共享资源导致的锁冲突开销。SessionGroup 中唯一共享的资源是 Variable,所有 Session 共享底层的 Variable,并且模型加载只需要加载一次。

我们使用 session group 后,实测调整到合适的 group 数量,可以提高 50% 的 inference 性能。

4.2.4 基于 oneDNN 的优化

DeepRec 集成了英特尔开源的跨平台深度学习性能加速库 oneDNN(oneAPI Deep Neural Network Library),并且修改 oneDNN 原有的线程池,统一成 DeepRec 的 Eigen 线程池,减少了线程池切换开销,避免了不同线程池之间竞争而导致的性能下降问题。oneDNN 已经针对大量主流算子实现了性能优化,包括 MatMul、BiasAdd、LeakyReLU 等在业务场景中使用到的常见算子,为业务模型提供了强有力的性能支持。更值得一提的是, oneDNN 的算子支持 BF16 数据类型,与搭载 AMX(Advanced Matrix Extensions)指令集的第四代英特尔® 至强® 可扩展处理器同时使用,可显著提升模型训练和推理性能。在 DeepRec Serving Processor 编译选项中,只需加入“--config=mkl_threadpool”,便可轻松开启 oneDNN 优化。

4.2.5 子图优化

子图融合是推理性能优化的常用方法。但是对于本模型中左图所示的子图结构含有 Reshape 算子,原生 tensorflow 并没有对应结构的图优化器以及算子实现,我们通过手动融合来实现,融合前后的子图构成如下图所示。这样减少了多余算子的运行开销,减少了内存访问,提升了计算效率。再结合 oneDNN 加速融合算子,最终业务端到端加速了 10%,CPU 利用率下降 10%。

4.2.6 cost model 的设计

由于大机器的 cpu core 数较多,而我们是一台机器有所有模型的进程,那么所有模型都共享所有 cpu core 显然会造成不必要的资源竞争等。因此给不同模型设计合理的 cost model 就很有必要。我们目前采用比较简单的方式,因为基线模型和需要做 AB test 的模型资源差别较大(流量差距大),我们会给每个基线模型分配对应的 core,然后让所有非基线模型共享一组 core(总体 AB test 的流量有上限)。虽然这个方案很简单,但是取得了非常好的效果,大概有 30% 的性能提升。

5 后续规划

1、cost model 的优化,显然有更好的方案来动态的调整每个模型需要的 core。我们打算开发更好的 cost model 并提供给 DeepRec。

2、开源我们的 inference 架构方案,因为在我们的业务里,基于 DeepRec processor 设计的 inference 架构带来了巨大的便利,并且性能很好,我们预计在上半年会开源我们的 inference 架构方案,欢迎大家到时关注。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/424562.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【LeetCode】剑指 Offer(27)

目录 题目:剑指 Offer 53 - I. 在排序数组中查找数字 I - 力扣(Leetcode) 题目的接口: 解题思路: 代码: 过啦!!! 写在最后: 题目:剑指 Offe…

【机器学习 P19】【实战 P1】 MINST 手写数字识别

MINST 手写数字识别引入数据模型训练模型创建程序模型编译程序模型训练程序模型预测程序完整代码引入数据 MINST数据集是一个经典的手写数字识别数据集,由Yann LeCun等人创建。它包含了来自真实手写数字图片的70000个灰度图像,这些图像是由250个不同的人…

三行Python代码,让数据处理速度提高2到6倍

本文可以教你仅使用 3 行代码,大大加快数据预处理的速度。 Python 是机器学习领域内的首选编程语言,它易于使用,也有很多出色的库来帮助你更快处理数据。但当我们面临大量数据时,一些问题就会显现…… 在默认情况下,…

OpenShift 4 - 使用 virtctl 远程访问 OpenShift Virtualization 的虚拟机

《OpenShift / RHEL / DevSecOps 汇总目录》 说明:本文已经在支持 OpenShift 4.12 的 OpenShift 环境中验证 在《OpenShift 4 - 用 OpenShift Virtualization 运行容器化虚拟机 (视频)》一文中使用了 OpenShift 控制台直接访问运行在 OpenSh…

SQL中去除重复数据的几种方法,我一次性都告诉你​

使用SQL对数据进行提取和分析时,我们经常会遇到数据重复的场景,需要我们对数据进行去重后分析。以某电商公司的销售报表为例,常见的去重方法我们用到distinct 或者group by 语句, 今天介绍一种新的方法,利用窗口函数对…

MIT 6.S965 韩松课程 05

Lecture 05: Quantization (Part 1) 文章目录Lecture 05: Quantization (Part 1)动机数字的数据类型整数定点数浮点数量化基于 K-Means 的量化 [[Han et al., ICLR 2016]](https://arxiv.org/pdf/1510.00149v5.pdf)线性量化 [[Jacob et al. CVPR 2018]](https://arxiv.org/pdf/…

Makefile项目管理-----在Linux下编译c/c++程序

这里写目录标题起因makefile项目管理一、用途:二、 makefile的基础规则1.多文件联合编译2. makefile检测原理3. ALL来指定终极目标三、 makefile的两个函数和clean四、 makefile中的三个自动变量五、模式规则六、 静态模式规则七、 扩展1. 扩展1 伪目标2. 扩展2 可添…

在 Python 中检查字符串是否为 ASCII

使用 str.isascii() 方法检查字符串是否为 ASCII,例如 if my_str.isascii():。 如果字符串为空或字符串中的所有字符都是 ASCII,则 str.isascii() 方法返回 True,否则返回 False。 my_str www.jiyik.comif my_str.isascii():# &#x1f447…

网络安全工程师做什么?

​ 网络安全很复杂。数字化转型、远程工作和不断变化的威胁形势需要不同的工具和不同的技能组合。 系统必须到位以保护端点、身份和无边界网络边界。负责处理这种复杂安全基础设施的工作角色是网络安全工程师。 简而言之,网络安全工程师是负责设计和实施组织安全系…

基于TF-IDF+KMeans聚类算法构建中文文本分类模型(附案例实战)

🤵‍♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞&#x1f4…

UHD安装教程

UHD Universal Hardware Driver,即USRP驱动。 UHD,Windows平台安装教程 uhd驱动安装 http://files.ettus.com/binaries/misc/erllc_uhd_winusb_driver.zip 安装LibUSBx http://files.ettus.com/binaries/uhd/latest_release 下载默认C盘 环境配置 将…

Android FrameWork 知识点与面试题整合~

1.如何对 Android 应用进行性能分析 android 性能主要之响应速度 和UI刷新速度。 首先从函数的耗时来说,有一个工具TraceView 这是androidsdk自带的工作,用于测量函数耗时的。 UI布局的分析,可以有2块,一块就是Hierarchy Viewe…

面试-Sqrt(x)

题目 给你一个非负整数 x ,计算并返回 x 的 算术平方根 。 由于返回类型是整数,结果只保留 整数部分 ,小数部分将被 舍去 。 注意:不允许使用任何内置指数函数和算符,例如 pow(x, 0.5) 或者 x ** 0.5 。 思路 二分查…

项目管理:项目进度难以把控,项目经理应该怎么办?

项目管理中,对进度的管理也是保障整个项目顺利完成的重要条件。项目进度难以把控,项目常常延期,项目经理怎么办?如何跟进整个项目的进度? 对于如何做好项目进度管理,有几点建议,希望能对大家有…

Java实现导出多个excel表打包到zip文件中,供客户端另存为窗口下载

文章目录一、业务背景二、实现思路二、准备工作1.准备data模板.xlsx2.引入poi相关依赖,用于操作excel3.针对WorkBookZIP压缩输入/输出流,相关方法知识点要有所了解三、完整的项目代码四、可能遇到的问题错误场景1:java.io.IOException: Strea…

【RabbitMQ】SpringBoot整合RabbitMQ实现延迟队列、TTL、DLX死信队列

目录 一、TTL 1、什么是TTL 2、设置TTL的两种方式 3、控制台设置TTL 4、SpringBoot实现两种方式设置TTL 1.给消息设置过期时间 2.给队列设置过期时间 二、DLX死信队列 1、什么是死信交换机与死信队列 2、消息何时会成为死信 3、队列如何绑定死信交换机与死信队列 4…

vscode“检测到 #include 错误,请更新 includepath。”的问题解决办法

目录 一.报错更新includepath​编辑 二.原因 三.解决方法 一.报错更新includepath 如图 二.原因 1.没有安装gcc 2.没有配置好环境 winR打开cmd,输入gcc -v,如果安装了gcc,会返回版本 三.解决方法 1.安装MinGW 2.添加MinGW环境变量 将bin文件夹的位置添加到系统环境变量中…

三分钟搭建个人博客技术栈Nuxt3+vite+mysql+koa2

最近也是想入一下Nuxt3的坑,然后就写了一个博客系统,目前已开源github,欢迎大家star!!! 效果预览 网址:http://180.76.121.2:3000/ github地址 https://github.com/ztzzhi/ztzzhi-nuxt3-vite…

MySQL事物(基础篇)

MySQL事务事物的基本概念事物的ACID属性事务的使用事务隔离级别MVCC&ReadViewMySQL是否还存在幻读事物的基本概念 Transaction作为关系型数据库的核心组成,在数据安全方面有着非常重要的作用,本文会一步步解析事务的核心特性,以获得对事…

多云数据存储,理想与现实之间还差着什么?

去年底,“数据二十条”正式颁布,数据要素全面提速已是指日可待。 无疑,数据作为数字经济的基础,其价值的释放依赖于数据的流动、共享和应用。数据要素只有充分地流动和应用起来,才能够实现价值的最大化。 换而言之&a…