自然语言处理: 第九章DeepSpeed的实践

news2024/11/18 12:25:49

理论基础

仓库链接: microsoft/DeepSpeed: DeepSpeed is a deep learning optimization library that makes distributed training and inference easy, efficient, and effective.

DeepSpees正如它官网介绍的一样,它为深度学习模型提供了一站式的快速以及大规模的训练及推理框架,能在尽可能利用你手中的算力去作深度学习的应用

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-jy4c1ZHS-1691735368213)(image/09_DeepSpeed/1691662528184.png)]




而本次主要介绍的是它在4月份发布的一站式的端到端的RLHF板块,如同其主页介绍一样,整合了一个B2B的训练流程如下图:

  • 步骤1:监督微调(SFT) —— 使用精选的人类回答来微调预训练的语言模型以应对各种查询;
  • 步骤2:奖励模型微调 —— 使用一个包含人类对同一查询的多个答案打分的数据集来训练一个独立的(通常比 SFT 小的)奖励模型(RW);
  • 步骤3:RLHF 训练 —— 利用PPO算法,根据 RW 模型的奖励反馈进一步微调 SFT 模型。

在步骤3中,我们提供了两个额外的功能,以帮助提高模型质量:

  • 指数移动平均(EMA) —— 可以选择基于 EMA 的检查点进行最终评估, 详情可以参考:【炼丹技巧】指数移动平均(EMA)
  • 混合训练 —— 将预训练目标(即下一个单词预测)与 PPO 目标混合,以防止在像 SQuAD2.0 这样的公开基准测试中的性能损失

这两个训练功能,EMA 和混合训练,常常被其他的开源框架所忽略,因为它们并不会妨碍训练的进行。然而,根据 InstructGPT,EMA 通常比传统的最终训练模型提供更好的响应质量,而混合训练可以帮助模型保持预训练基准解决能力。因此,我们为用户提供这些功能,以便充分获得 InstructGPT 中描述的训练体验,并争取更高的模型质量。

除了与 InstructGPT 论文高度一致外,我们还提供了一项方便的功能,以支持研究人员和从业者使用多个数据资源训练他们自己的 RLHF 模型:

  • 数据抽象和混合能力 : DeepSpeed-Chat 能够使用多个不同来源的数据集训练模型以获得更好的模型质量。它配备了(1)一个抽象数据集层,以统一不同数据集的格式;以及(2)数据拆分/混合功能,以便多个数据集在 3 个训练阶段中被适当地混合然后拆分。

在我们之前的章节中,你可以看到使用整个 DeepSpeed-Chat 训练模型在多轮对话中的表现。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-tr8QWSPD-1691735368215)(image/09_DeepSpeed/1691732870229.png)]




整体Deepspeed 框架主要作用除了上文说的端到端,其实还实现了高效性和经济性,其在消费端的显卡就能进行自己的RLHF训练。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-oVCkLLVv-1691735368215)(image/09_DeepSpeed/1691733927697.png)]



运行

安装



训练

如果你只拥有一个消费级的显卡, 你可以尝试下训练一个13b参数的OPT,大约需要两个小时左右

python train.py --actor-model facebook/opt-1.3b --reward-model facebook/opt-350m --deployment-type single_gpu

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-1gBk1wkN-1691735368216)(image/09_DeepSpeed/1691734960987.png)]

如果你拥有强大的算力群, 你可以尝试下训练一个660b参数的OPT,大约需要9小时左右

python train.py --actor-model facebook/opt-66b --reward-model facebook/opt-350m --deployment-type multi_node

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eJocDNWt-1691735368217)(image/09_DeepSpeed/1691735259164.png)]

同样微软提供了自定义的API 可以自己定义训练流程:


engine = DeepSpeedRLHFEngine(
  actor_model_name_or_path=args.actor_model_name_or_path,
  critic_model_name_or_path=args.critic_model_name_or_path,
  tokenizer=tokenizer,
  num_total_iters=num_total_iters,
  args=args)

trainer = DeepSpeedPPOTrainer(engine=engine, args=args)

for prompt_batch in prompt_train_dataloader:
  out = trainer.generate_experience(prompt_batch)
  actor_loss, critic_loss = trainer.train_rlhf(out)

推理

将训练好的模型导入提供的简易 API,你不仅可以与模型进行问答互动,还可以进行对话式交流。

Human:        Do you know Microsoft?
Assistant:    Yes, I’m a big fan of Microsoft.  What about them?
Human:        Can you explain it to a 6-year old child? I wonder how I should describe it
Assistant:    Sure, I can try.  Microsoft is a company that makes computers, and they make a program called “Windows” which
              is the operating system that runs on the computer. It’s like the “brain” of the computer.  It’s where all the
              programs and files are stored. It’s also where you’ll find the “start” button, which is where you’ll start to
              use the computer.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/905829.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【SA8295P 源码分析】03 - SA8295P QNX Host上电开机流程分析

【SA8295P 源码分析】03 - SA8295P QNX Host上电开机流程分析 一、阶段1 固件开机自检 (SM BIST):APPS PBL加载XBL后触发 INT_RESET进行Warm Reset二、阶段2 固件开机自检 (SM BIST):加载TZ,初始Hypervisor,启动QNX Kernel&#x…

k8编写yaml文件小工具

在刚接触k8s的时候觉得yaml资源文件非常的难写,完全看不懂,经过一段时间的摸索学习,发现k8s平台中是提供了一系列的工具和技巧的,可以帮助我们很好的编写资源文件,提升编写yaml文件的能力,常用的命令工具是…

python爬虫9:实战2

python爬虫9:实战2 前言 ​ python实现网络爬虫非常简单,只需要掌握一定的基础知识和一定的库使用技巧即可。本系列目标旨在梳理相关知识点,方便以后复习。 申明 ​ 本系列所涉及的代码仅用于个人研究与讨论,并不会对网站产生不好…

时序预测 | MATLAB实现SO-CNN-LSTM蛇群算法优化卷积长短期记忆神经网络时间序列预测

时序预测 | MATLAB实现SO-CNN-LSTM蛇群算法优化卷积长短期记忆神经网络时间序列预测 目录 时序预测 | MATLAB实现SO-CNN-LSTM蛇群算法优化卷积长短期记忆神经网络时间序列预测预测效果基本介绍程序设计学习总结参考资料 预测效果 基本介绍 时序预测 | MATLAB实现SO-CNN-LSTM蛇群…

深入解析淘宝API,实现高效商务应用

淘宝API的基本调用 1. API文档与SDK 淘宝API官方提供了详细的API文档,包含了API的使用说明、参数列表、示例代码等内容。开发者可以通过文档了解每个API接口的具体功能和使用方法。此外,淘宝API还提供了多种编程语言的SDK,方便开发者进行快速…

桌游新篇:3.1 UserCase分析

距离上一次停止更新这个系列有将近9个月了。 工作这么久,学会了一件事,就是想清楚再动手。当然,后续工作已经渐渐展开了,而且当下属于天时地利人和(既有当前MR设备带来的硬件buff,又有大语言模型&#xff…

SOPC之NIOS Ⅱ实现电机转速PID控制

通过FPGA开发板上的NIOS Ⅱ搭建电机控制的硬件平台,包括电机正反转、编码器的读取,再通过软件部分实现PID算法对电机速度进行控制,使其能够渐近设定的编码器目标值。 一、PID算法 PID算法(Proportional-Integral-Derivative Algo…

21-注意点说明:scoped样式冲突 / data

组件的三大组成部分 - 注意点说明 组件的样式冲突 scoped 默认情况:写在组件中的样式会 全局生效 -> 因此很容易造成多个组件之间的样式冲突问题 1.全局样式: 默认组件中的样式会作用到全局 2.局部样式: 可以给组件加上 scoped 属性,可以让样式只作用于当前组件 scoped原理…

《有效调节情绪,保持工作心态平和》

工作中,我们有时会遇到各种挑战和困难,这些挑战和困难可能引发我们的负面情绪,例如焦虑、愤怒和沮丧等。然而,保持稳定的情绪是实现高效工作的重要因素之一。本文将分享如何在工作中保持稳定的情绪。 首先,让我们来谈谈…

Spring Boot 如何通过jdbc+HikariDataSource 完成对Mysql 操作

😀前言 本篇博文是关于Spring Boot 如何通过jdbcHikariDataSource 完成对Mysql 操作的说明,希望你能够喜欢😊 🏠个人主页:晨犀主页 🧑个人简介:大家好,我是晨犀,希望我的…

Python多组数据三维绘图系统

文章目录 增添和删除坐标数据更改绘图逻辑源代码 Python绘图系统: 基础:将matplotlib嵌入到tkinter 📈简单的绘图系统 📈数据导入📈三维绘图系统自定义控件:坐标设置控件📉坐标列表控件 增添和…

录屏有哪些讲究?有哪些好用的录屏软件?

在如今数字时代,视频分享已经成为一种流行的传播方式。为了制作高质量的视频内容,录屏已经成为了一种必备的技能。但是,要想制作出令人满意的录屏视频,需要了解一些讲究和使用一些好用的录屏软件。 录屏是一种视觉传达方式&#x…

【prism】发布订阅和取消订阅,进一步梳理

一个对象对应一个事件订阅 一个事件是可以被重复订阅的,如果一个事件被订阅了三次,那边发布一次该事件,就会触发三次事件订阅: 通过观察Prism的事件聚合器对象,发现它此时包含了三个事件对象,其中第三个事件订阅数量达到了3! 这样的话,如果调用一次 Publish ,那么S…

Android 获取 SHA256 签名

在 Android Studio 中的 Terminal ,输入命令: keytool -list -v -keystore debug.keystore 如果出现以下提示: keytool -genkey -v -keystore debug.keystore -alias androiddebugkey -keyalg RSA -validity 10000 按照提示输入相关信息,…

SIP 7英寸触摸屏寻呼主机

SV-8006TP SIP7英寸触摸屏寻呼主机 一、描述 SV-8006TP是我司的一款SIP桌面式对讲广播主机,具有10/100M以太网接口,从网络接口接收网络的音频数据,提供立体声音频输出。 SV-8006TP寻呼话筒可以通过麦克风或者本地线路输入对终端进行分区广…

Java【手撕双指针】LeetCode 283. “移动零“, 图文详解思路分析 + 代码

文章目录 前言一、移动零1, 题目2, 思路分析3, 代码展示 前言 各位读者好, 我是小陈, 这是我的个人主页, 希望我的专栏能够帮助到你: 📕 JavaSE基础: 基础语法, 类和对象, 封装继承多态, 接口, 综合小练习图书管理系统等 📗 Java数据结构: 顺序表, 链表,…

传统图像处理之直方图均衡化

重要说明:本文从网上资料整理而来,仅记录博主学习相关知识点的过程,侵删。 一、参考资料 直方图均衡化的原理及实现 图像处理之直方图均衡化 二、直方图 1. 直方图的概念 图像的灰度直方图,描述了图像中灰度分布情况&#xf…

BaiChuan13B多轮对话微调范例

前方干货预警:这可能是你能够找到的,最容易理解,最容易跑通的,适用于多轮对话数据集的大模型高效微调范例。 我们构造了一个修改大模型自我认知的3轮对话的玩具数据集,使用QLoRA算法,只需要5分钟的训练时间…

antd5源码调试环境启动(MacOS)

将源码下载至本地 这里antd5 版本是5.8.3 $ git clone gitgithub.com:ant-design/ant-design.git $ cd ant-design $ npm install $ npm start前提:安装python3、node版本18.14.0(这是本人当前下载的版本) python3安装教程可参考:https://…

达梦数据库读写分离集群原理

概述 本文就达梦数据库读写分离原理进行介绍。 达梦读写分离集群特点: 可以配置8个即时备库或8个实时备库;读写操作自动分离、负载均衡;提供数据同步;备库故障自动处理,故障恢复自动数据同步等功能,也支持…