在 Cloud TPU 上训练 DLRM 和 DCN (TF 2.x)

news2025/1/22 19:54:20

本教程介绍如何训练 DLRM 和 DCN v2 排名模型, 用于预测点击率 (CTR) 等任务。查看以下语言版本的备注 设置以运行 DLRM 或 DCN 模型,了解如何设置参数 来训练 DLRM 或 DCN v2 排名模型。

模型输入是数值特征和分类特征,输出是标量 (例如点击概率)。模型可以 Cloud TPU。深度排名模型都属于内存密集型模型(对于嵌入, 表和查询),以及深度网络 (MLP) 的计算密集型应用。TPU 两者兼顾

该模型将 TPUEmbedding 层用于分类特征。TPU 嵌入 支持具有快速查找功能的大型嵌入表, 会根据 TPU Pod 的大小线性扩缩。您最多只能使用 90 GB 的嵌入表 用于 TPU v3-8,5.6 TB 用于 v3-512 Pod,22.4 TB 用于 v3-2048 TPU Pod。

模型代码位于 TensorFlow Recommenders 库中, 其中介绍了输入流水线、配置和训练循环, TensorFlow Model Garden。

目标

  • 设置训练环境
  • 使用合成数据运行训练作业
  • 验证输出结果

费用

在本文档中,您将使用 Google Cloud 的以下收费组件:

  • Compute Engine
  • Cloud TPU
  • Cloud Storage

您可使用价格计算器根据您的预计使用情况来估算费用。 Google Cloud 新用户可能有资格申请免费试用。

准备工作

在开始学习本教程之前,请检查您的 Google Cloud 项目是否已正确设置。

  1. In the Google Cloud console, on the project selector page, select or create a Google Cloud project.

    Go to project selector

  2. 确保您的 Google Cloud 项目已启用结算功能。

本演示使用 Google Cloud 的收费组件。请查看 Cloud TPU 价格页面估算您的费用。请务必在创建应用时清理您创建的 TPU 资源 以免产生不必要的费用

设置资源

本部分介绍如何设置本教程使用的 Cloud Storage 存储桶、虚拟机和 Cloud TPU 资源。

  1. 打开一个 Cloud Shell 窗口。

    打开 Cloud Shell

  2. 为项目 ID 创建一个变量。

    export PROJECT_ID=project-id
  1. 配置 Google Cloud CLI 以使用要在其中创建项目的项目 Cloud TPU。

如需详细了解 gcloud 命令,请参阅 Google Cloud CLI 参考文档。

gcloud config set project ${PROJECT_ID}

当您第一次在新的 Cloud Shell 虚拟机中运行此命令时,系统会显示 Authorize Cloud Shell 页面。点击底部的 Authorize 允许 gcloud 使用您的凭据进行 API 调用。

  1. 为 Cloud TPU 项目创建服务账号。
gcloud beta services identity create --service tpu.googleapis.com --project $PROJECT_ID

该命令会返回一个 Cloud TPU 服务账号,其格式如下:

service-PROJECT_NUMBER@cloud-tpu.iam.gserviceaccount.com
  1. 使用以下命令创建 Cloud Storage 存储桶,其中 --location 选项指定应创建存储桶的区域。如需详细了解可用区和区域,请参阅类型和可用区:
 gcloud storage buckets create gs://bucket-name --project=${PROJECT_ID} --location=europe-west4

此 Cloud Storage 存储分区存储您用于训练模型的数据和训练结果。本教程中使用的 gcloud compute tpus execution-groups 工具会为您在上一步中设置的 Cloud TPU 服务账号设置默认权限。如果您需要更精细的权限,请查看访问级层权限。

存储分区位置必须要与 Compute Engine(虚拟机)和 Cloud TPU 节点位于同一地区。

  1. 使用 gcloud 命令启动 Compute Engine 虚拟机和 Cloud TPU。
$ gcloud compute tpus tpu-vm create dlrm-dcn-tutorial \
        --zone=europe-west4-a \
        --accelerator-type=v3-8 \
        --version=tpu-vm-tf-2.17.0-se

命令标志说明

`zone`

拟在其中创建 Cloud TPU 的[区域](https://cloud.google.com/tpu/docs/types-zones?hl=zh-cn)。

`accelerator-type`

加速器类型指定要创建的 Cloud TPU 的版本和大小。 如需详细了解每个 TPU 版本支持的加速器类型,请参阅 [TPU 版本](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm?hl=zh-cn#versions)。

`version`

Cloud TPU [软件版本](https://cloud.google.com/tpu/docs/supported-tpu-versions?hl=zh-cn#tpu_software_versions)。
  1. 使用 SSH 连接到 Compute Engine 实例。连接到网络后 您的 Shell 提示符会从 username@projectname 更改为 username@vm-name
gcloud compute tpus tpu-vm ssh dlrm-dcn-tutorial --zone=europe-west4-a

设置 Cloud Storage 存储分区变量

设置以下环境变量,将 bucket-name 替换为 Cloud Storage 存储分区的名称:

(vm)$ export STORAGE_BUCKET=gs://bucket-name
(vm)$ export PYTHONPATH="/usr/share/tpu/models/:${PYTHONPATH}"
(vm)$ export EXPERIMENT_NAME=dlrm-exp

为 TPU 名称设置环境变量。

  (vm)$ export TPU_NAME=local

训练应用预期能够访问您在 Cloud Storage 中的训练数据。在训练期间,训练应用还会使用您的 Cloud Storage 存储分区来存储检查点。

进行设置以使用合成数据运行 DLRM 或 DCN 模型

该模型可以使用各种数据集进行训练。最常用的两个数据集是 Criteo TB 和 Criteo Kaggle。本教程通过设置标志 use_synthetic_data=True 使用合成数据进行训练。

合成数据集仅用于了解如何使用 Cloud TPU 和验证端到端性能。准确率 和已保存的模型就没有意义了。

请访问 Criteo Terabyte 和 Criteo Kaggle 网站,了解如何下载和预处理这些数据集。

  1. 安装必需的软件包。
    (vm)$ pip3 install tensorflow-recommenders
    (vm)$ pip3 install -r /usr/share/tpu/models/official/requirements.txt
  1. 切换到脚本目录。
 (vm)$ cd /usr/share/tpu/models/official/recommendation/ranking
  1. 运行训练脚本。它使用类似 Criteo 的虚构数据集来训练 DLRM 模型。训练大约需要 20 分钟。

    auto
    export EMBEDDING_DIM=32
    
    python3 train.py --mode=train_and_eval 
         --model_dir=${STORAGE_BUCKET}/model_dirs/${EXPERIMENT_NAME} --params_override="
         runtime:
             distribution_strategy: 'tpu'
         task:
             use_synthetic_data: true
             train_data:
                 input_path: '${DATA_DIR}/train/*'
                 global_batch_size: 16384
             validation_data:
                 input_path: '${DATA_DIR}/eval/*'
                 global_batch_size: 16384
             model:
                 num_dense_features: 13
                 bottom_mlp: [512,256,${EMBEDDING_DIM}]
                 embedding_dim: ${EMBEDDING_DIM}
                 top_mlp: [1024,1024,512,256,1]
                 interaction: 'dot'
                 vocab_sizes: [39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63,
                     38532951, 2953546, 403346, 10, 2208, 11938, 155, 4, 976, 14,
                     39979771, 25641295, 39664984, 585935, 12972, 108, 36]
         trainer:
             use_orbit: false
             validation_interval: 1000
             checkpoint_interval: 1000
             validation_steps: 500
             train_steps: 1000
             steps_per_loop: 1000
         "
    

此训练在 v3-8 TPU 上运行大约 10 分钟。完成后,您将看到如下所示的消息:

在这里插入图片描述

清除数据

为避免因本教程中使用的资源导致您的 Google Cloud 账号产生费用,请删除包含这些资源的项目,或者保留项目但删除各个资源。

  1. 断开与 Compute Engine 实例的连接(如果您尚未这样做):
(vm)$ exit

你的提示现在应为username@projectname,表明你处于 运行命令的 Cloud Shell 实际效果

  1. 删除您的 Cloud TPU 资源。
 $ gcloud compute tpus tpu-vm delete dlrm-dcn-tutorial \
      --zone=europe-west4-a
  1. 通过运行 gcloud compute tpus execution-groups list 验证资源是否已删除。删除操作可能需要几分钟时间才能完成。以下命令的输出 不应包含本教程中创建的任何资源:
$ gcloud compute tpus tpu-vm list --zone=europe-west4-a
  1. 使用 gcloud CLI 删除 Cloud Storage 存储桶。将 bucket-name 替换为您的 Cloud Storage 存储分区的名称。
$ gcloud storage rm gs://bucket-name --recursive

后续步骤

TensorFlow Cloud TPU 教程通常使用示例数据集来训练模型。此训练的结果不能用于推理。接收者 使用模型进行推理,就可以使用公开可用的 或您自己的数据集。在 Cloud TPU 上训练的 TensorFlow 模型 通常需要将数据集 TFRecord 格式。

您可以使用数据集转换工具 示例,用于转换图片 转换为 TFRecord 格式。如果您使用的不是图片 分类模型,您需要将数据集转换为 TFRecord 格式 。如需了解详情,请参阅 TFRecord 和 tf.Example。

超参数调节

要使用数据集提升模型的性能,您可以调整模型的 超参数。您可以找到所有用户共有的超参数的相关信息 支持 TPU 的 GitHub。 如需了解特定于模型的超参数,请参阅源代码 每个代码 模型。如需详细了解超参数调优,请参阅概览 超参数调优 和 Tune 超参数。

推断

训练模型后,即可将其用于推理(也称为 预测)。您可以使用 Cloud TPU 推断转换器 工具来准备和优化 在 Cloud TPU v5e 上用于推理的 TensorFlow 模型。有关 如需了解如何在 Cloud TPU v5e 上进行推理,请参阅 Cloud TPU v5e 推理 简介。

文章来源:google cloud

推荐阅读

  • 在 Cloud TPU 上训练 Mask RCNN (TF 2.x)
  • 在 Cloud TPU 上训练 ShapeMask (TF 2.x)
  • 在 Cloud TPU 上训练 RetinaNet (TF 2.x)

更多芯擎AI开发板干货请关注芯擎AI开发板专栏。欢迎添加极术小姐姐微信(id:aijishu20)加入技术交流群,请备注研究方向。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2103576.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【HuggingFace Transformers】LlamaRotaryEmbedding源码解析

LlamaRotaryEmbedding源码解析 1. LlamaRotaryEmbedding类 介绍2. 逆频率向量3. LlamaRotaryEmbedding类 源码解析3.1 transformers v4.44.2版3.2 transformers v4.41.1版 1. LlamaRotaryEmbedding类 介绍 在LLaMa模型中,LlamaRotaryEmbedding类实现了Rotary Posit…

Elasticsearch 向量数据库本地部署 及操作方法

elasticsearch是个分布式向量数据库,支持多种查找模式。此外还拥有 Metadata、Filtering、Hybrid Search、Delete、Store Documents、Async等能力。本文仅是记录本地测试途中遇到的问题。 一,环境部署 下载软件 首先去官网,选择适合平台下…

Kafka-设计原理

ControllerLeader - PartitionRebalance消息发布机制HW与LEO日志分段 Controller Kafka核心总控制器Controller:在Kafka集群中会有一个或者多个broker,其中有一个broker会被选举为控制器(Kafka Controller),它负责管理…

Hyper-v 安装 centOS

一.Hyper-v安装 1. 右键此电脑,点击属性,查看自己的window版本 如果是专业版或者企业版,则无需额外操作,如果是家庭版,则需要先运行一个脚本来进行安装。 参考这一篇:window10 家庭版如何开启Hyper-v-CSDN…

FPGA开发:初识FPGA

FPGA是什么? FPGA的全称是现场可编程门阵列(Field Programmable Gate Array),一种以数字电路为主的集成芯片,属于可编程逻辑器件PLD的一种。简单来说,就是能用代码编程,直接修改FPGA芯片中数字…

OceanBase 关于 place_group_by HINT的使用

PLACE_GROUP_BY Hint 表示在多表关联时,如果满足单表查询后直接进行group by 的情形下,在跟其它表进行关联统计,减少表内部联接。 NO_PLACE_GROUP_BY Hint 表示在多表关联时,在关联后才对结果进行group by。 使用place_group_by …

二百五十九、Java——采集Kafka数据,解析成一条条数据,写入另一Kafka中(一般JSON)

一、目的 由于部分数据类型频率为1s,从而数据规模特别大,因此完整的JSON放在Hive中解析起来,尤其是在单机环境下,效率特别慢,无法满足业务需求。 而Flume的拦截器并不能很好的转换数据,因为只能采用Java方…

启动.cmd文件一闪而过,看不到报错信息

在window的环境中,双击.cmd文件,有报错信息,但是一闪而过 例如启动zookeeper时,没有zoo.cfg文件会报错,但是启动一闪而过,你看不到报错信息 有文本工具编辑cmd文件,在最后添加 pause 再次启…

Linux 之 lsblk 【可用块的设备信息】

功能介绍 在 Linux 系统中,“lsblk”(list block devices)命令用于列出所有可用的块设备信息 应用场景 查看存储设备信息:“lsblk” 命令可以帮助你快速了解系统中的存储设备,包括硬盘、固态硬盘、U 盘等。你可以查…

9_4_QTextEdit

QTextEdit //核心属性//获取文本 toPlainText(); toHtml(); toMarkdown(); //输入框为空时的提示功能 placeHolderText(); //只读 readOnly();//定义文本光标 QTextcursor cursorcursor.position(); cursor.selectedText();//核心信号//文本改变 textChanged(); //选中范围 se…

【黑马点评】附近商户

需求 选择商铺类型后,按照距离当前用户所在位置从近到远的顺序,分页展示该类型的所有商铺。 接口: 参数: typeId:商铺类型current:页码x:经度y:纬度 返回值:所有typeId…

LVS 负载均衡集群指南

1. 引言 LVS (Linux Virtual Server) 虚拟服务器,是 Linux 内核中实现的负载均衡技术,以其高性能、高可靠性和高可用性而闻名。LVS 工作在 TCP/IP 协议栈的第四层 (传输层),通过将流量分配到多个后端服务器,提高系统性能、可用性…

硬件工程师笔试面试知识器件篇——电阻

目录 1、电阻 1.1 基础 电阻原理图 阻实物图 1.1.1、定义 1.1.2、工作原理 1.1.3、类型 1.1.4、材料 1.1.5、标记 1.1.6、应用 1.1.7、特性 1.1.8、测量 1.1.9、计算 1.1.10、颜色编码 1.1.11、公差 1.1.12、功率 1.1.13、重要性 1.2、相关问题 1.2.1、电阻…

数组和指针 笔试题(1)

目录 0.复习 1.笔试题1 2.笔试题2 3.笔试题3 4.笔试题4 5.笔试题5 0.复习 在做笔试题之前,我们首先复习一下数组名的理解 数组名的所有情况: 1.&数组名,取出的是整个数组的地址 2.sizeof(数组名)&#x…

LLM常见问题(Attention 优化部分)

1. 传统 Attention 存在哪些问题? 传统的 Attention 机制忽略了源端或目标端句子中词与词之间的依赖关系。传统的 Attention 机制过度依赖 Encoder-Decoder 架构上。传统的 Attention 机制依赖于Decoder的循环解码器,所以依赖于 RNN,LSTM 等循环结构。传…

【Transformer】Tokenization

文章目录 直观理解分词方式词粒度-Word字粒度-Character子词粒度-Subword(目前最常使用) 词表大小的影响参考资料 直观理解 在理解Transformer或者大模型对输入进行tokenize之前,需要理解什么是token? 理工科的兄弟姐妹们应该都…

027集——goto语句用法——C#学习笔记

goto语句可指定代码的跳行运行: 实例如下: 代码如下: using System; using System.Collections.Generic; using System.Linq; using System.Security.Policy; using System.Text; using System.Threading.Tasks;namespace ConsoleApp2 { //…

采用基于企业服务总线(ESB)的面向服务架构(SOA)集成方案实现统一管理维护的银行信息系统

目录 案例 【题目】 【问题 1】(7 分) 【问题 2】(12 分) 【问题 3】(6 分) 【答案】 【问题 1】解析 【问题 2】解析 【问题 3】解析 相关推荐 案例 阅读以下关于 Web 系统设计的叙述,在答题纸上回答问题 1 至问题 3。 【题目】 某银行拟将以分行为主体…

是噱头还是低成本新宠?加州大学用视觉追踪实现跨平台的机器手全掌控?

导读: 在当今科技飞速发展的时代,机器人的应用越来越广泛。从工业生产到医疗保健,从物流运输到家庭服务,机器人正在逐渐改变我们的生活方式。而机器人的有效操作和控制,离不开高效的遥操作系统。今天,我们要…

OHIF Viewer (3.9版本最新版) 适配移动端——最后一篇

根据一些调用资料和尝试,OHIF 的底层用的是Cornerstonejs ,这个是基于web端写的,如果说写在微信小程序里,确实有很多报错, 第一个问题就是 npm下载的依赖, 一、运行环境差异 微信小程序的运行环境与传统的 Node.js 环境有很大不同。小程序在微信客户端中运行,有严格的…