LLM微调(三)| 大模型中RLHF + Reward Model + PPO技术解析

news2024/11/20 12:40:08

        本文将深入探讨RLHF(Reinforcement Learning with Human Feedback)、RM(reward model)和PPO(Proximal Policy Optimizer)算法的概念。然后,通过代码演示使用RLHF训练自己的大模型和奖励模型RM。最后,简要深入研究模型毒性和幻觉,以及如何创建一个更面向模型的产品或更有益、诚实、无害、可靠,并与人类反馈对齐的生成人工智能的生命周期。

一、RLHF(Reinforcement Learning with Human Feedback)

图片

       先来举一个简单的例子——想象一下,我们正在创建一个LLM会话式人工智能产品模型,它可以为经历艰难时期的人类提供治疗,如果我们训练了一个大模型,但没有使其与人类保持一致,它通过药物滥用等方式为这些人提供让他们感觉更好和最佳的非法方式,这将导致伤害、缺乏有效的可靠性和帮助。正如OpenAI CTO所说,大模型领域正在蓬勃发展,大模型更可靠、更一致、产生更少的幻觉,唯一可能的方法是使用来自不同人群的人类反馈,以及其他方式,如RAG、Langchain,来提供基于上下文的响应。生成人工智能生命周期可以最大限度地提高了帮助性,最大限度地减少了困难,避免了与危险话题的讨论和参与。

       在深入了解RLHF之前,我们先介绍一下强化学习的基本原理,如下图所示:

图片

     RL是Agent与环境Environment不断交互的过程,首先Agent处于Environment的某个state状态下,然后执行一个action,就会对环境产生影响,从而进入另一个state下,如果对Environment是好的或者是期待的,那么会得到正向的reward,否则是负向的,最终一般是让整个迭代过程中累积reward最大。

二、在大模型的什么环节使用RL呢?

图片

       这里有Agent、Environment和大模型的Current Context,在这种情况下,策略就是知道我们预训练或者微调过的LLM模型。现在我们希望能够在给定的域中生成文本,对吗?因此,我们采取行动,LLM获取当前上下文窗口和环境上下文,并基于该动作,获得奖励。带着奖励的策略就是人类反馈的地方。

三、奖励模型Reward Model介绍

       基于人类的反馈数据来训练一个奖励模型,该模型会在RLHF中被调用,并且不需要人类的参与,就可以根据用户不同的Prompt来分配不同的奖励reward,这个过程被称为”Rollout“。

那么如何构建人类反馈的数据集呢?

图片

数据集格式,如下图所示:

图片

四、奖励模型Reward Model训练

有了人类反馈的数据集,我们就可以基于如下流程来训练RM模型:

图片

五、使用RLHF (PPO & KL Divergence)进行微调

  1. 把一个Prompt数据集输入给初始LLM中;

  2. 给instruct LLM输入大量的Prompts,并得到一些回复;

  3. 把Prompt补全输入给已经训练好的RM模型,RM会生成对应的score,然后把这些score输入给RL算法;

  4. 我们在这里使用的RL算法是PPO,会根据Prompt生成一些回复,对平均值进行排序,使用反向传播来评估响应,最后将最优的回复输入给instruct LLM;

  5. 进行几次迭代后,会得到一个奖励模型,但这有一个不利的方面。

PS:如果我们的模型不断接受积极价值观的训练,然后开始提供奇怪、模糊和不符合人类的输出,会怎么样?

图片

        为了解决上述问题,我们采用如下流程:

图片

       首先使用参考模型,冻结其中的所有权重,作为我们人类对齐模型的参考点,然后基于这种迁移,我们使用KL散度惩罚添加到奖励中,这样当模型产生幻觉时,它会使模型回到参考模型附近,以提供积极但不奇怪的积极反应。我们可以使用PEFT适配器来训练我们的PPO模型,并使模型在推出时越来越一致。

六、使用RLHF (PEFT + LORA + PPO)微调实践

6.1 安装相关的包

!pip install --upgrade pip!pip install --disable-pip-version-check \    torch==1.13.1 \    torchdata==0.5.1 --quiet​​​​​
!pip install \    transformers==4.27.2 \    datasets==2.11.0 \    evaluate==0.4.0 \    rouge_score==0.1.2 \    peft==0.3.0 --quiet# Installing the Reinforcement Learning library directly from github.!pip install git+https://github.com/lvwerra/trl.git@25fa1bd

6.2 导入相关的包​​​​​​​

from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfigfrom datasets import load_datasetfrom peft import PeftModel, PeftConfig, LoraConfig, TaskType
# trl: Transformer Reinforcement Learning libraryfrom trl import PPOTrainer, PPOConfig, AutoModelForSeq2SeqLMWithValueHeadfrom trl import create_reference_modelfrom trl.core import LengthSampler
import torchimport evaluate
import numpy as npimport pandas as pd
# tqdm library makes the loops show a smart progress meter.from tqdm import tqdmtqdm.pandas()

6.3 加载LLaMA 2模型​​​​​​​

from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-34b-Instruct-hf")model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-34b-Instruct-hf")huggingface_dataset_name = "knkarthick/dialogsum"dataset_original = load_dataset(huggingface_dataset_name)dataset_original

6.4 预处理数据集​​​​​​​

def build_dataset(model_name,    dataset_name,    input_min_text_length,    input_max_text_length):“””Preprocess the dataset and split it into train and test parts.Parameters:- model_name (str): Tokenizer model name.- dataset_name (str): Name of the dataset to load.- input_min_text_length (int): Minimum length of the dialogues.- input_max_text_length (int): Maximum length of the dialogues.Returns:- dataset_splits (datasets.dataset_dict.DatasetDict): Preprocessed dataset containing train and test parts.“””    # load dataset (only “train” part will be enough for this lab).    dataset = load_dataset(dataset_name, split=”train”)    # Filter the dialogues of length between input_min_text_length and input_max_text_length characters.    dataset = dataset.filter(lambda x: len(x[“dialogue”]) > input_min_text_length and len(x[“dialogue”]) <= input_max_text_length, batched=False)    # Prepare tokenizer. Setting device_map=”auto” allows to switch between GPU and CPU automatically.    tokenizer = AutoTokenizer.from_pretrained(model_name, device_map=”auto”)    def tokenize(sample):        # Wrap each dialogue with the instruction.        prompt = f”””        Summarize the following conversation.        {sample[“dialogue”]}        Summary:        “””        sample[“input_ids”] = tokenizer.encode(prompt)        # This must be called “query”, which is a requirement of our PPO library.        sample[“query”] = tokenizer.decode(sample[“input_ids”])        return sample    # Tokenize each dialogue.    dataset = dataset.map(tokenize, batched=False)    dataset.set_format(type=”torch”)# Split the dataset into train and test parts.    dataset_splits = dataset.train_test_split(test_size=0.2, shuffle=False, seed=42)    return dataset_splitsdataset = build_dataset(model_name=model_name,    dataset_name=huggingface_dataset_name,    input_min_text_length=200,    input_max_text_length=1000)print(dataset)

6.5 抽取模型参数​​​​​​​

def print_number_of_trainable_model_parameters(model):    trainable_model_params = 0    all_model_params = 0    for _, param in model.named_parameters():        all_model_params += param.numel()        if param.requires_grad:            trainable_model_params += param.numel()    return f"\ntrainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {100 * trainable_model_params / all_model_params:.2f}%"

6.6 将适配器添加到原始salesforce代码生成模型中。现在,我们需要将它们传递到构建的PEFT模型,也将is_trainable=True。​​​​​​​

lora_config = LoraConfig(    r=32, # Rank    lora_alpha=32,    target_modules=["q", "v"],    lora_dropout=0.05,    bias="none",    task_type=TaskType.SEQ_2_SEQ_LM # FLAN-T5)​​​​​​
model = AutoModelForSeq2SeqLM.from_pretrained(model_name,                                               torch_dtype=torch.bfloat16)peft_model = PeftModel.from_pretrained(model,                                        '/kaggle/input/generative-ai-with-llms-lab-3/lab_3/peft-dialogue-summary-checkpoint-from-s3/',                                        lora_config=lora_config,                                       torch_dtype=torch.bfloat16,                                        device_map="auto",                                                                              is_trainable=True)print(f'PEFT model parameters to be updated:\n{print_number_of_trainable_model_parameters(peft_model)}\n')​​​​​​​
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(peft_model,torch_dtype=torch.bfloat16,is_trainable=True)print(f'PPO model parameters to be updated (ValueHead + 769 params):\n{print_number_of_trainable_model_parameters(ppo_model)}\n')print(ppo_model.v_head)​​​​​​​
ref_model = create_reference_model(ppo_model)print(f'Reference model parameters to be updated:\n{print_number_of_trainable_model_parameters(ref_model)}\n')

  使用Meta AI基于RoBERTa的仇恨言论模型(https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target)作为奖励模型。这个模型将输出logits,然后预测两类的概率:notate和hate。输出另一个状态的logits将被视为正奖励。然后,模型将使用这些奖励值通过PPO进行微调。​​​​​​​

toxicity_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"toxicity_tokenizer = AutoTokenizer.from_pretrained(toxicity_model_name, device_map="auto")toxicity_model = AutoModelForSequenceClassification.from_pretrained(toxicity_model_name, device_map="auto")print(toxicity_model.config.id2label)​​​​​​
non_toxic_text = "#Person 1# tells Tommy that he didn't like the movie."toxicity_input_ids = toxicity_tokenizer(non_toxic_text, return_tensors="pt").input_idslogits = toxicity_model(input_ids=toxicity_input_ids).logitsprint(f'logits [not hate, hate]: {logits.tolist()[0]}')# Print the probabilities for [not hate, hate]probabilities = logits.softmax(dim=-1).tolist()[0]print(f'probabilities [not hate, hate]: {probabilities}')# get the logits for "not hate" - this is the reward!not_hate_index = 0nothate_reward = (logits[:, not_hate_index]).tolist()print(f'reward (high): {nothate_reward}')

6.7 评估模型的毒性​​​​​​​

toxicity_evaluator = evaluate.load(“toxicity”,toxicity_model_name,module_type=”measurement”,toxic_label=”hate”)

参考文献

[1] https://medium.com/@madhur.prashant7/rlhf-reward-model-ppo-on-llms-dfc92ec3885f

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1294226.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

上网监控软件——安全与隐私的平衡

网络已经成为人们生活和工作中不可或缺的一部分。然而&#xff0c;随着网络使用的普及&#xff0c;网络安全问题也日益突出。上网监控软件作为网络安全领域的一个重要组成部分&#xff0c;在保护企业和家庭网络安全方面发挥着重要作用。 本文将探讨上网监控软件的背景、功能、优…

Java第二十一章

一.网络程序设计基础 1.网络协议 网络协议规定了计算机之间连接的物理、机械(网线与网卡的连接规定)、电气(有效的电平范围)等特征&#xff0c;计算机之间的相互寻址规则&#xff0c;数据发送冲突的解决方式&#xff0c;长数据如何分段传送与接收等内容.就像不同的国家有不同的…

AI烟火识别智能视频分析系统解决方案

引言 随着城市化进程的加快和高科技的迅猛发展&#xff0c;传统的消防系统逐渐显露出局限性。在这种背景下&#xff0c;AI烟火识别智慧消防解决方案应运而生&#xff0c;它融合了最新的AI技术&#xff0c;旨在提高火灾的预防、检测、应对和控制能力&#xff0c;保护人民生命财…

Sbatch, Salloc提交任务相关

salloc 申请计算节点&#xff0c;然后登录到申请到的计算节点上运行指令&#xff1b; salloc的参数与sbatch相同&#xff0c;该部分先介绍一个简单的使用案例&#xff1b;随后介绍一个GPU的使用案例&#xff1b;最后介绍一个跨节点使用案例&#xff1b; 首先是一个简单的例子&a…

基于springboot实现的仿天猫商城项目

一、系统架构 前端&#xff1a;jsp | js | css | jquery 后端&#xff1a;springboot | mybatis-plus 环境&#xff1a;jdk1.7 | mysql | maven 二、代码及数据库 三、功能介绍 01. web端-首页 02. web端-商品查询 03. web端-商品详情 04. web端-购物车 05. web端-订单…

【网络安全】-《网络安全法》制定背景和核心内容

文章目录 1. 背景介绍1.1 数字时代的崛起1.2 中国网络安全形势 2. 《网络安全法》核心内容2.1 法律适用范围2.2 个人信息保护2.3 关键信息基础设施保护2.4 网络安全监管和应急响应2.5 网络产品和服务安全管理2.6 法律责任和处罚 3. 法律的意义和影响3.1 维护国家安全3.2 保护个…

《使用ThinkPHP6开发项目》 - 设置项目环境变量

《使用ThinkPHP6开发项目》 - 安装ThinkPHP框架-CSDN博客 在上一编我们讲了ThinkPHP6框架的创建&#xff0c;创建完成ThinkPHP6框架后&#xff0c;我们这里就可以开始设置我们的环境变量了。 安装完成ThinkPHP6框架生成的项目文件 修改项目配置我们修改项目config文件夹里的对…

<JavaEE> 多线程编程中的“等待和通知机制”:wait 和 notify 方法

目录 一、等待和通知机制的概念 二、wait() 方法 2.1 wait() 方法的使用 2.2 超时等待 2.3 异常唤醒 2.4 唤醒等待的方法 三、notify() 方法 四、notifyAll() 方法 五、wait 和 sleep 的对比 一、等待和通知机制的概念 1&#xff09;什么是等待和通知机制&#xff1f…

2023年4K投影仪怎么选?极米H6 4K高亮版怎么样?

随着人们生活水平的不断提升&#xff0c;投影仪也逐渐成为了家家户户的必备家居好物。近十年来&#xff0c;中国投影仪市场规模增长数倍&#xff0c;年均增长率大幅提高。从近10年的发展趋势来看&#xff0c;投影仪行业处于高速发展期。 此前&#xff0c;极米科技推出的极米H6…

crmeb本地开发配置代理

crmeb 是一个开源的商城系统&#xff0c; v5 版本是一个前后端分离的项目&#xff0c; 我们从git仓库中下载下来的是一个文件夹&#xff0c;其结构是这样的 我的系统没有使用docker &#xff0c;使用的是 laragon 的系统 所以首先我们要在 nginx 中配置 之后&#xff0c; 我们…

IDEA使用git从远程仓库获取项目

将地址填入url中 然后直接clone就行

Ant Design Vue 年选择器

文章目录 参考文档效果展示实现过程 参考文档 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; DatePicker 日期选择框 大佬&#xff1a;搬砖小匠&#xff08;Ant Design vue 只选择年&#xff09; 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案…

C语言——指针(五)

&#x1f4dd;前言&#xff1a; 上篇文章C语言——指针&#xff08;四&#xff09;更加深入的介绍了不同类型指针的特点&#xff0c;这篇文章主要想记录一下函数与指针的结合运用以及const和assert关于指针的用法&#xff1a; 1&#xff0c;函数与指针 2&#xff0c;const 3&am…

十五届蓝桥杯分享会(一)

注&#xff1a;省赛4月&#xff0c;决赛6月 一、蓝桥杯整体介绍 1.十四届蓝桥杯软件电子赛参赛人数&#xff1a;C 8w&#xff0c;java/python 2w&#xff0c;web 4k&#xff0c;单片机 1.8w&#xff0c;嵌入式/EDA5k&#xff0c;物联网 300 1.1设计类参赛人数&#xff1a;平…

STL(一)(pair篇)

1.pair的定义和结构 在c中,pair是一个模板类,用于表示一对值的组合它位于<utility>头文件中 pair的定义如下: template<class T1, class T2> struct pair{T1 first; //第一个值T2 second; //第二个值//构造函数pair();pair(const T1&x,const T2&y);//比较…

域名与SSL证书

域名是互联网上的地址标识符&#xff0c;它通过DNS&#xff08;Domain Name System&#xff09;将易于记忆的人类可读的网址转换为计算机可以理解的IP地址。当用户在浏览器中输入一个网址时&#xff0c;实际上是通过DNS解析到对应的服务器IP地址&#xff0c;从而访问到相应的网…

诚邀莅临,共商发展丨“交汇未来”行业大模型高峰论坛

大会简介 今年以来&#xff0c;以ChatGPT为典型代表的大模型在全球数字科技界引起极大关注&#xff0c;其强大的数据处理能力和泛化性能使得其在各个领域都有广泛的应用前景&#xff0c;驱动千行百业的数字化转型升级&#xff0c;成为新型工业化和实体经济的重要推动力&#x…

【C语言】vfprintf函数

vfprintf 是 C 语言中的一个函数&#xff0c;它是 fprintf 函数的变体&#xff0c;用于格式化输出到文件中。vfprintf 函数接受一个格式化字符串和一个指向可变参数列表的指针&#xff0c;这个列表通常是通过 va_list 类型来传递的。vfprintf 函数的主要用途是在需要处理不定数…

数据分析基础之《matplotlib(5)—直方图》

一、直方图介绍 1、什么是直方图 直方图&#xff0c;形状类似柱状图却有着与柱状图完全不同的含义。直方图牵涉统计学的概念&#xff0c;首先要对数据进行分组&#xff0c;然后统计每个分组内数据元的数量。在坐标系中&#xff0c;横轴标出每个组的端点&#xff0c;纵轴表示频…

PyQt5 - 鼠标连点器

文章目录 ⭐️前言⭐️鼠标连点器 ⭐️前言 本次设计的鼠标连点器主要是对QVBoxLayout、QHBoxLayout和QStackedWidget进行一个回顾复习&#xff0c;加深对它们的理解&#xff0c;提高运用的熟练度。 ⭐️鼠标连点器 如以下代码所示&#xff0c;设计两个QWidget控件&#xff…