Text-to-SQL小白入门(九)InstructGPT论文:教你如何训练ChatGPT

news2024/12/23 23:25:26

论文概述

InstructGPT和ChatGPT 的训练流程基本一致 ,ChatGPT是改进后的InstructGPT,比如InstructGPT是基于GPT-3训练,而ChatGPT是基于GPT-3.5训练。

基本信息

  • 英文标题:Training language models to follow instructions with human feedback
  • 中文标题:通过人类反馈的指令训练语言模型
  • 发表时间:2023年3月 arxiv
  • 作者单位:Open AI
  • 论文链接:https://arxiv.org/pdf/2203.02155.pdf
  • 代码链接:GitHub - openai/following-instructions-human-feedback

学习InstructGPT论文之前,想了解了基本的LLM或者RLHF流程,可以看看组织「eosphoros-ai」(今年的8000+star的开源项目DB-GPT的开源社区)提出的LLM+Text2SQL汇总项目:https://github.com/eosphoros-ai/Awesome-Text2SQL,里面也收集了一些微调SFT(lora, qlora, p-tuning等),RLHF相关的论文(比如RLHF,RRHF,RLTF, RRTF, RLAIF等等),目前也有300+的star,持续更新中,欢迎围观使用star!

摘要

背景

使语言模型更大并不能使它们更好地遵循用户的意图。例如,大型语言模型可能生成不真实的(untruthful)有害的(toxic)或对用户没有帮助(not helpful)的输出。

贡献/方法

在本文中,作者展示了一种方法,通过使用人类反馈进行微调,在广泛的任务中使语言模型与用户意图保持一致。

  • 先使用有监督微调SFT
  • 然后收集一批rank排序的模型输出
  • 再使用人类反馈的强化学习rlhf微调
  • 最终得到的模型叫做InstructGPT

结果:参数量小了100倍,性能差不多。 真实性⬆️、有毒⬇️、精度⬇️(轻微)

结果惊艳:

  • 1.3b参数的InstructGPT的模型输出和175b GPT-3的输出很类似。
  • 在公共NLP数据集上,InstructGPT模型显示出真实性的改进和有毒输出生成的减少,同时性能下降最小

结论:

尽管InstructGPT仍然会犯一些简单的错误,但结果表明,根据人类反馈进行微调是使语言模型与人类意图保持一致的一个有希望的方向

结果

API prompt distribution

  • 参数说明:
    • 横坐标是模型参数大小,纵坐标是和175B GPT SFT比较赢的概率(比如绿色的线条,横坐标为175B时候,赢的概率刚好为0.5,此时就是175B GPT SFT vs 175B GPT SFT )
    • GPT就是最普通的模型
    • GPT(prompted)就是给几个例子few-shot
    • SFT 有监督微调
    • PPO 用强化学习
    • PPO-ptx: 在PPO算法期间,使用pretraining mix (但是几乎没有什么效果)
  • 对比的模型是SFT 175B,可以发现的是1.3B PPO或者PPO-ptx已经超过0.5的概率赢175B,说明方法很有效。
  • InstructGPT就是PPO-ptx

论文还在 public NLP dataset进行了实验,InstructGPT模型在公有NLP数据集上有“对齐税”导致性能下降,可能是因为API prompt 训练的原因。

论文还公布了qualitative results,InstructGPT模型泛化能力很强,具体实验参考原论文。

结论

对齐研究alignment research的影响

  • 提高模型对齐度的成本比预训练低。
  • InstructGPT泛化能力强,可以推广到没有监督数据的领域。
  • 通过微调,可以减少性能下降
  • 验证了对齐技术在现实生活中应用

对齐的是什么?

人类偏好,人类价值观 --> 标注者的偏好、OpenAI 研究人员的偏好、API 用户的偏好。

核心方法

RLHF架构图

基础背景知识

  • RLHF方法最早是2017年提出:Deep reinforcement learning from human preferences(2017)
  • 在2020年RLHF文章「Learning to summarize from human feedback(2020」中,RM训练使用了两个模型在相同input情况下的output进行比较,使用交叉熵损失。——InstructGPT使用KL散度
  • PPO算法,也是Open AI 2017年提出的:Proximal policy optimization algorithms(2017),这篇文章的作者「John Schulman」也在InstructGPT作者名单中。

这个图也是经典大图了,RLHF实践参考的范式,RLHF主要分成了3个阶段:

  • 第一阶段:SFT
  • 第二阶段:RM
  • 第三阶段:RL (使用PPO算法:proximal policy optimization 最近策略优化),对第三阶段进行一个简单解释:
    • 输入一个标注数据,模型经过PPO算法输出一个response
    • RM模型对response打分
    • 根据打分score更新PPO策略。

PPO算法具体是什么呢?——(留个坑,后续补上)

详情参考论文:Schulman, J., Wolski, F., Dhariwal, P., Radford, A., and Klimov, O. (2017). Proximal policy optimization algorithms. arXiv preprint arXiv:1707.06347.

SFT

数据格式

  • prompt - output

更直观一点,以一个具体的小任务比如Text2SQL为例子,构造的数据集如下所示:

来源知乎文档:Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习

{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age  >  56"}

实验参数

参数如下:

  • base model——GPT-3
  • epoch——16
  • lr decay——cosine
  • dropout——0.2

选择最终的SFT模型时,是根据验证集上的RM分数。

惊讶点:

  • 1个epoch后已经过拟合了,但是为了后续的RM分数,还是多跑几轮epoch

RM

数据格式

  • prompy-chosen-rejected

同样的,以Text2SQL任务为例子,构造的数据集如下所示:

{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","chosen": "SELECT count(*) FROM head WHERE age  >  56","rejected":"SELECT COUNT(head_name) FROM head WHERE age > 56;"}

实验参数

  • base model: 是GPT-3 SFT之后的模型,但是去掉了最后一层
    • 因为原始模型输入是prompt,输出是response
    • 现在需要模型输入是prompt + response,输出是score
  • 参数量仅选择的6B大小

为什么RM模型选6B,不是175B?

    • 6B 减少计算量
    • 175B 训练不稳定
  • 标注者,需要对K=4 和 K=9之间的response进行排序,会产生C(k, 2)个两两比较pair
  • 一个epoch中,对所有的C(k, 2)比较对训练,一次传播loss

损失函数:

  • x代表输入的prompt;y_w代表chosen_data; y_l代表rejected_data; D代表实验数据集
  • r_θ(x,y)代表RM模型输入prompt x和response y的输出得分

最后要对奖励归一化,使得平均奖励为0。

RL

数据格式

  • prompt-output

和SFT阶段数据格式一致。

{"prompt": "I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\"\n##Instruction:\ndepartment_management contains tables such as department, head, management. Table department has columns such as Department_ID, Name, Creation, Ranking, Budget_in_Billions, Num_Employees. Department_ID is the primary key.\nTable head has columns such as head_ID, name, born_state, age. head_ID is the primary key.\nTable management has columns such as department_ID, head_ID, temporary_acting. department_ID is the primary key.\nThe head_ID of management is the foreign key of head_ID of head.\nThe department_ID of management is the foreign key of Department_ID of department.\n###Input:\nHow many heads of the departments are older than 56 ?\n\n###Response:","output": "SELECT count(*) FROM head WHERE age  >  56"}

实验参数

1.RM可以和RL重复多轮迭代——这样构建更多数据,越来越趋近于人类偏好。

  • SFT训练->训练一个RM->训练一个RL->不断重复下面的步骤:
    • 构建RM数据->重新训练一个RM->重新训练一个RL->
    • 构建RM数据->重新训练一个RM->重新训练一个RL->
    • 构建RM数据->重新训练一个RM->重新训练一个RL->

2.实践中,大部分的比较数据来源于SFT的数据,少部分数据来源于RL模型的比较数据。

  • 继2020文章「Learning to summarize from human feedback」之后,作者再次使用PPO对环境中的SFT模型进行了微调。
  • 额外增加了 KL散度。
  • 额外增加了预训练梯度——目的是为了减少在NLP数据集上性能倒退,所以InstructGPT模型 == PPO-ptx

  • π^RL代表学习到的强化学习RL模型; π^SFT代表SFT阶段训练的模型。

为什么用π表示?为什么用除法表示?这就是强化学习的基本概念

从状态State到动作Action的过程就称之为一个策略Policy,一般用π表示(可以理解为一个函数表示),也就是在强化学习阶段需要找到一个关系:a=π(s) 或者是 π(a|s), a 就是action, s就是state

  • D_pretrain代表预训练阶段的数据分布;D_π^RL代表强化学习阶段的数据分布
  • r_θ(x,y)代表RM模型输入prompt x和response y的输出得分
  • β是控制KL奖励的系数; γ是控制预训练梯度的系数,如果是普通的PPO,那么γ=0

数据收集

之前听一个大学教授的讲座,有个观点很有意思:Open AI做大模型为什么比谷歌强,因为包括transformer在内的一些创新模型大多是谷歌研究的,那为什么Open AI在大模型领域为什么比谷歌强?答:因为Open AI在数据清洗,数据质量把控这方面做的很好。——所以数据是相当重要的!

API数据

为了训练本文的最终InstructGPT

prompt dataset 主要由OpenAI 的API获得,用户和API交互,把这些数据收集起来(前提是用户使用的时候就告知数据要被收集),此时的API是早期的InstructGPT模型,并且没有使用用户在生产中使用API的数据。

API数据分布如下,主要有9类。

那么问题来了?早期的InstructGPT模型的训练数据怎么来?

  • 通过人工标注的有监督学习训练得到的

对API收集的数据做了一些处理:

  • 去除重复的提示:通过检查公共前缀(感觉回到了leetcode刷题,求两个字符串的最长公共前缀)
  • 每个用户不超过200条prompt:应该是避免单独个体的偏好
  • 基于用户id,划分train,val,test——这样验证集和测试集就不包含来自训练集中的用户的数据
    • 比如训练数据用id 1, 2, 3, 4的所有数据
    • 测试的数据用id 5的数据。
  • 过滤掉了个人身份信息的数据

人工标注数据

主要是为了训练早期的InstructGPT

标注者被要求手写以下三种类型的prompt:

  • plain:标记人员提出任意的简单任务,同时保证任务的多样性
  • few-shot:标注人员提出一条指令instruction,以及该指令的多个查询/响应对(query/response)
  • user-based:标注人员在OpenAI 提供的API中获取用例,标注人员需要给出这些用例相对应的instruction

数据量级

数据中96%以上是英文,其它20个语种例如中文,法语,西班牙语等加起来不到4%,这可能导致InstructGPT/ChatGPT能进行其它语种的生成时,效果应该远不如英文

  • SFT 数据,大概13k
  • RM 数据,大概33k
  • PPO数据,大概31k

论文还有大量的附录数据详情,可以参考论文原文,比如标注人员分布,数据示例,数据标注等等,不得不说,Open AI数据扎实,正文20页,附录48页,总共68页。

其他文章

Text-to-SQL小白入门(一)综述文章学习

Text-to-SQL小白入门(二)Transformer学习

Text-to-SQL小白入门(三)IRNet:引入中间表示SemQL

Text-to-SQL小白入门(四)指令进化大模型WizardLM

Text-to-SQL小白入门(五)开源代码大模型Code Llama

Text-to-SQL小白入门(六)Awesome-Text2SQL项目介绍

Text-to-SQL小白入门(七)PanGu-Coder2论文——RRTF

Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1263420.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【鸿蒙应用ArkTS开发系列】- 选择图片、文件和拍照功能实现

文章目录 前言创建多媒体Demo工程创建MediaBean 实体类创建MediaHelper工具类API标记弃用问题动态申请多媒体访问权限实现选择图片显示功能打包测试 前言 在使用App的时候,我们经常会在一些社交软件中聊天时发一些图片或者文件之类的多媒体文件,那在鸿蒙…

当TinyMCE富文本编辑器遇到Vue3+nuxt+ts项目,分享引入成功案例及过程中踩的那些坑

文章目录 前言遇到的坑插入上传图片插件上传图片请求与返回值处理本地文件引入报错解决源码 前言 如果你的前端项目技术栈使用的是Vue3nuxtts,并且老大让你集成一下那个传说中非常丝滑的TinyMCE富文本编辑器,那么恭喜你和我一样中大奖了。 网上找了好久…

uniapp 导航分类

商品分类数据&#xff0c;包括分类名称和对应的商品列表点击弹出 列表的内容 展示效果如下&#xff1a; 代码展示 ①div部分 <view class"container"><view class"menu-bar"><view class"menu"><view class"menu-sc…

CSDN最新最全python+pytest接口自动化(12)-自动化用例编写思路 (使用pytest编写一个测试脚本)

经过之前的学习铺垫&#xff0c;我们尝试着利用pytest框架编写一条接口自动化测试用例&#xff0c;来厘清接口自动化用例编写的思路。 我们在百度搜索天气查询&#xff0c;会出现如下图所示结果&#xff1a; 接下来&#xff0c;我们以该天气查询接口为例&#xff0c;编写接口测…

C语言——字符函数和字符串函数(上)

在编程的过程中&#xff0c;我们经常要处理字符和字符串&#xff0c;为了方便操作字符和字符串&#xff0c;C语⾔标准库中提供了⼀系列库函数&#xff0c;接下来我们就学习⼀下这些函数。 一、 字符分类函数 C语⾔中有⼀系列的函数是专⻔做字符分类的&#xff0c;也就是⼀个字…

[含泪解决]OSError: [Errno 99] Cannot assign requested address__踩坑记录——app.py绑定IP失败

踩坑记录下。 是这个样子的&#xff0c;前几天帮别人部署Python的Flask项目到云服务器上&#xff0c;然后在 app.run(host"xxx.xxx.xxx.xxx",port8080) 这行代码中&#xff0c;xxx.xxx.xxx.xxx代表我的IP地址&#xff0c;port代表我的端口号。 然后不是要部署到服…

Layui框架弹出框form表单中单选按钮状态不刷新

1、问题描述 如下图&#xff1a;当我们点击编辑按钮的时候&#xff0c;实现如果性别(stu_sex)的值为0男生被选中&#xff0c;如果性别的值为1&#xff0c;女生被选中。但是在使用Layui框架的过程中&#xff0c;发现性别的单选按钮无法实现刷新&#xff0c;使用不正常。 1.1、…

PLC:200smart

PLC&#xff1a;200smart 第十章、数据类型、数据存储1、数据类型1.1、有符号数1.2、有符号数 2、传送指令 第十一章、比较指令、整数、浮点数的运算1、比较指令1、运算指令1.1、浮点数运算1.2、整数运算 第十章、数据类型、数据存储 1、数据类型 数据类型分为两大类 无符号数…

k8s部署jenkins

1.先决条件 1.因为国内的容器镜像加速器无法实时更新docker hub上的镜像资源.所以可以自己进行jenkins的容器镜像创建,. 2.这里用到了storageClass k8s的动态制备.详情参考: k8s-StoargClass的使用-基于nfs-CSDN博客 3.安装docker服务.(用于构建docker image) 2.构建jenki…

使用docker-compose优雅部署nacos

查看代码中引入nacos版本 在应用的pom.xml中搜索nacos关键字&#xff0c;找到相关的nacos依赖 点击以来左边的图标&#xff0c;找到依赖管理器中的pom.xml&#xff0c;并全局搜索nacos&#xff0c;即可找到对应的nacos客户端版本 使用docker-compose部署nacos version: 3s…

React Native 更换淘宝镜像提升包下载速度

React Native 更换淘宝镜像提升包下载速度 每次运行项目的时候都是卡在包下载的命令上&#xff0c;每次一等就要 1h20m 极度崩溃&#xff0c;那是因maven镜像源为Google导致无法正常下载。 那么我们就可以切换maven镜像源&#xff0c;方法如下&#xff1a; 找到项目下的**/an…

20分钟拥有自己的ChatGPT4,高效低成本,小白必看

准备工作 1、准备一个3.5的账号 2、一张虚拟卡 开始步骤 从ChatGPT第一版发布到现在&#xff0c;还不到一年的时间中&#xff0c;可是它使用的GPT架构已经从3.5版本进化到现在的4.0版本&#xff0c;随之而来的是其能力的极大提升。下面是GPT-4在其官网的介绍中的一句话&…

C++基础 -8- 函数重载

函数重载格式(图片代码段呈现) #include "iostream"using namespace std;void rlxy(int a) {cout << "int a"<< endl; }void rlxy(char a) {cout << "char a"<< endl; }int main() {rlxy(99);rlxy(c); }函数重载的依据…

Redis 主从架构,Redis 分区,Redis哈希槽的概念,为什么要做Redis分区

文章目录 Redis 主从架构redis replication 的核心机制redis 主从复制的核心原理过程原理Redis集群的主从复制模型是怎样的&#xff1f;生产环境中的 redis 是怎么部署的&#xff1f;机器是什么配置&#xff1f;你往内存里写的是什么数据&#xff1f;说说Redis哈希槽的概念&…

Pytest:让测试断言变得轻松愉快!

前言 断言是完整的测试用例中不可或缺的因素&#xff0c;用例只有加入断言&#xff0c;将实际结果与预期结果进行比对&#xff0c;才能判断它的通过与否。 unittest 框架提供了其特有的断言方式&#xff0c;如&#xff1a;assertEqual、assertTrue、assertIn等&#xff0c;py…

Win7 SP1 x64 Google Chrome 字体模糊

1 打开 Google Chrome &#xff0c;地址栏输入 chrome://version/ &#xff0c;字体模糊。 2 Microsoft Update Catalog 搜索现在更新 kb2670838 &#xff0c;安装&#xff0c;重启电脑。 3 打开 Google Chrome&#xff0c;地址栏输入 chrome://version/ &#xff0c;字体正常。…

Vue和React配置解决跨域,proxy代理两步搞定

Vue配置&#xff1a; 第一步&#xff1a; 找到 vite.config.js 文件 进行如下代码配置 import { defineConfig } from "vite"; import vue from "vitejs/plugin-vue"; export default defineConfig({plugins: [vue()],server: {/*** /api 是代理标识*/p…

Python入门05 print函数

目录 1 Python中的内置函数2 print函数介绍3 print函数的用途总结 1 Python中的内置函数 Python中内置了很多函数&#xff0c;我们可以直接调用&#xff0c;以下是一些常见的函数&#xff1a; abs()&#xff1a;返回一个数的绝对值。all()&#xff1a;判断一个可迭代对象中的…

酒水代理商城小程序开发搭建攻略

随着互联网的快速发展&#xff0c;线上商城已成为越来越多人的选择。对于酒水代理行业来说&#xff0c;拥有一个专属的线上商城小程序能够大大提升业务效率&#xff0c;拓展销售渠道。本文将手把手教你如何开发搭建一个酒水代理商城小程序。 步骤一&#xff1a;登录乔拓云网后台…

技巧-PyTorch中num_works的作用和实验测试

简介 在 PyTorch 中&#xff0c;num_workers 是 DataLoader 中的一个参数&#xff0c;用于控制数据加载的并发线程数。它允许您在数据加载过程中使用多个线程&#xff0c;以提高数据加载的效率。 具体来说&#xff0c;num_workers 参数指定了 DataLoader 在加载数据时将创建的…