RLHF学习

news2024/12/31 3:23:22

整体流程

三个步骤分解:

  1. 预训练一个语言模型 (LM) ;
  2. 聚合问答数据并训练一个奖励模型 (Reward Model,RM) ;
  3. 用强化学习 (RL) 方式微调 LM。

在这里插入图片描述

在这里插入图片描述

RW

RM 的训练是 RLHF 区别于旧范式的开端。这一模型接收一系列文本并返回一个标量奖励,数值上对应人的偏好。我们可以用端到端的方式用 LM 建模,或者用模块化的系统建模 (比如对输出进行排名,再将排名转换为奖励) 。这一奖励数值将对后续无缝接入现有的 RL 算法至关重要。

  • 关于模型选择方面:
    RM 可以是另一个经过微调的 LM,也可以是根据偏好数据从头开始训练的 LM。例如 Anthropic 提出了一种特殊的预训练方式,即用偏好模型预训练 (Preference Model Pretraining,PMP) 来替换一般预训练后的微调过程。因为前者被认为对样本数据的利用率更高。但对于哪种 RM 更好尚无定论。

  • 过程:
    在这里插入图片描述

  • Bradley-Terry(BT)模型是一个常见选择(在可以获得多个排序答案的情况下,Plackett-Luce 是更一般的排序模型)

  • **排序损失:**在最后一层 transformer 层后添加一个线性层以获得奖励值的标量预测。为了确保奖励函数具有较低的方差,之前的工作会对奖励进行归一化
    在这里插入图片描述

RW代码

from dataclasses import dataclass, field
from typing import Optional

import tyro
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig

from trl import RewardConfig, RewardTrainer, is_xpu_available


tqdm.pandas()


@dataclass
class ScriptArguments:
    model_name: str = "facebook/opt-350m"
    """the model name"""
    dataset_name: str = "Anthropic/hh-rlhf"
    """the dataset name"""
    dataset_text_field: str = "text"
    """the text field of the dataset"""
    eval_split: str = "none"
    """the dataset split to evaluate on; default to 'none' (no evaluation)"""
    load_in_8bit: bool = False
    """load the model in 8 bits precision"""
    load_in_4bit: bool = False
    """load the model in 4 bits precision"""
    trust_remote_code: bool = True
    """Enable `trust_remote_code`"""
    reward_config: RewardConfig = field(
        default_factory=lambda: RewardConfig(
            output_dir="output",
            per_device_train_batch_size=64,
            num_train_epochs=1,
            gradient_accumulation_steps=16,
            gradient_checkpointing=True,
            gradient_checkpointing_kwargs={"use_reentrant": False},
            learning_rate=1.41e-5,
            report_to="tensorboard",
            remove_unused_columns=False,
            optim="adamw_torch",
            logging_steps=500,
            evaluation_strategy="no",
            max_length=512,
        )
    )
    use_peft: bool = False
    """whether to use peft"""
    peft_config: Optional[LoraConfig] = field(
        default_factory=lambda: LoraConfig(
            r=16,
            lora_alpha=16,
            bias="none",
            task_type="SEQ_CLS",
            modules_to_save=["scores"],
        ),
    )


args = tyro.cli(ScriptArguments)
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"


# Step 1: Load the model
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif args.load_in_8bit or args.load_in_4bit:
    quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
    # Copy the model to each device
    device_map = (
        {"": f"xpu:{Accelerator().local_process_index}"}
        if is_xpu_available()
        else {"": Accelerator().local_process_index}
    )
else:
    device_map = None
    quantization_config = None

model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name,
    quantization_config=quantization_config,
    device_map=device_map,
    trust_remote_code=args.trust_remote_code,
    num_labels=1,
)

# Step 2: Load the dataset and pre-process it
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
train_dataset = load_dataset(args.dataset_name, split="train")


# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets
def preprocess_function(examples):
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
        tokenized_chosen = tokenizer(chosen)
        tokenized_rejected = tokenizer(rejected)

        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

    return new_examples


# Preprocess the dataset and filter out examples that are longer than args.max_length
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=4,
)
train_dataset = train_dataset.filter(
    lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
    and len(x["input_ids_rejected"]) <= args.reward_config.max_length
)

if args.eval_split == "none":
    eval_dataset = None
else:
    eval_dataset = load_dataset(args.dataset_name, split=args.eval_split)

    eval_dataset = eval_dataset.map(
        preprocess_function,
        batched=True,
        num_proc=4,
    )
    eval_dataset = eval_dataset.filter(
        lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
        and len(x["input_ids_rejected"]) <= args.reward_config.max_length
    )


# Step 4: Define the LoraConfig
if args.use_peft:
    peft_config = args.peft_config
else:
    peft_config = None

# Step 5: Define the Trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args.reward_config,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
)

trainer.train()

RLHF

  • 动手学强化学习: https://hrl.boyuai.com/chapter/2/actor-critic%E7%AE%97%E6%B3%95

让我们首先将微调任务表述为 RL 问题。

  • 首先,该 策略 (policy) 是一个接受提示并返回一系列文本 (或文本的概率分布) 的 LM。
  • 这个策略的 行动空间 (action space) 是 LM 的词表对应的所有词元 (一般在 50k 数量级)
  • 观察空间 (observation space) 是可能的输入词元序列,也比较大 (词汇量 ^ 输入标记的数量) 。
  • 奖励函数 是偏好模型和策略转变约束 (Policy shift constraint) 的结合。
    在这里插入图片描述

在这里插入图片描述

  • KL散度这一项被用于惩罚 RL 策略在每个训练批次中生成大幅偏离初始模型,以确保模型输出合理连贯的文本。如果去掉这一惩罚项可能导致模型在优化中生成乱码文本来愚弄奖励模型提供高奖励值。

可视化进度条的一种方法:

with tqdm(total=int(num_episodes / 10), desc='Iteration %d' % i) as pbar:
	for i_episode in range(episode):
		if (i_episode + 1) % 10 == 0:
	                pbar.set_postfix({
	                    'episode':
	                    '%d' % (num_episodes / 10 * i + i_episode + 1),
	                    'return':
	                    '%.3f' % np.mean(return_list[-10:])
	                })
     pbar.update(1)

策略梯度

  • 基于值函数的方法主要是学习值函数,然后根据值函数导出一个策略,学习过程中并不存在一个显式的策略;而基于策略的方法则是直接显式地学习一个目标策略。

AC算法

  • 基于值函数的方法只学习一个价值函数,而基于策略的方法只学习一个策略函数

  • Actor-Critic 算法本质上是基于策略的算法,因为这一系列算法的目标都是优化一个带参数的策略,只是会额外学习价值函数,从而帮助策略函数更好地学习。

  • Actor-Critic 算法估计一个动作价值函数 Q Q Q,代替蒙特卡洛采样得到的回报,这便是 Q ( s , a ) Q(s,a) Q(s,a)。这个时候,我们可以把状态价值函数 V V V作为基线,从 Q Q Q函数减去这个 V V V函数则得到了函数 A A A,我们称之为优势函数(advantage function)

Actor-Critic 分为两个部分:Actor(策略网络)和 Critic(价值网络)

  • Actor 要做的是与环境交互,并在 Critic 价值函数的指导下用策略梯度学习一个更好的策略。
  • Critic 要做的是通过 Actor 与环境交互收集的数据学习一个价值函数,这个价值函数会用于判断在当前状态什么动作是好的,什么动作不是好的,进而帮助 Actor 进行策略更新。

在这里插入图片描述
在这里插入图片描述

PPO

PPO惩罚

PPO-惩罚(PPO-Penalty)用拉格朗日乘数法直接将 KL 散度的限制放进了目标函数中,这就变成了一个无约束的优化问题,在迭代的过程中不断更新 KL 散度前的系数。
在这里插入图片描述

PPO截断

在这里插入图片描述

  • 对于连续动作,让策略网络输出连续动作高斯分布(Gaussian distribution)的均值和标准差。后续的连续动作则在该高斯分布中采样得到。
PPO的训练中存在的问题

PPO会找捷径,只要有机会,PPO 算法就会利用这些缺陷。

  1. 显然,当从概率低于 SFT 模型的策略中采样令牌时,这将导致负 KL 惩罚。但平均而言,它将是正的,否则您将无法从策略中正确采样。使用 KL 惩罚项是为了推动模型的输出保持接近基本策略的输出。一般来说,KL 散度衡量两个分布之间的距离,并且始终为正值。

  2. 某些生成策略可以强制生成某些token或抑制某些token。例如,当批量生成完成的序列时,会进行填充;当设置最小长度时,EOS 令牌会被抑制。该模型可以为那些导致负 KL 的标记分配非常高或非常低的概率。当 PPO 算法针对奖励进行优化时,它会追逐这些负面惩罚,从而导致不稳定。

    • 生成响应时需要小心,我们建议在采用更复杂的生成方法之前始终先使用简单的采样策略。
  3. 损失偶尔会出现峰值,这可能会导致进一步的不稳定。

  4. 字符串的重复会导致奖励的突然增加。

DPO

与以往的 RLHF 方法(先学习一个奖励函数,然后通过强化学习优化)不同,我们的方法跳过了奖励建模步骤,直接使用偏好数据优化语言模型。

  • 我们的核心观点是利用从奖励函数到最优策略的解析映射,将对奖励函数的损失转化为对策略的损失。这种变量转换的方法使我们能够跳过显式的奖励建模步骤,同时仍然在现有的人类偏好模型(如 Bradley-Terry 模型)下进行优化。实质上,策略网络既代表语言模型,又代表奖励。

在这里插入图片描述
在这里插入图片描述

RLHF开源工具

  • TRL
  • RL4LM

TRL实践

  • demo.py
# 0. imports
import torch
from transformers import GPT2Tokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer


# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# 2. initialize trainer
ppo_config = {"batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

# 3. encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)

# 4. generate model response
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 20,
}
response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs)
response_txt = tokenizer.decode(response_tensor[0])

# 5. define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]

# 6. train model with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1415657.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1、PDManer 快速入门

文章目录 序言一、快速入门1.1 PDMan 介绍1.2 特点1.3 下载和安装 小结 序言 本人长期以来一直从事于应用软件的研发以及项目实施工作&#xff0c;经常做数据库建模&#xff08;数据表设计&#xff09;。有一款称心如意的数据库建模工具&#xff0c;自然能够事半功倍&#xff0…

【算法路线图】算法小抄题解-一文理解算法体系-费元星

做研发多年&#xff0c;对算法理解一直不够成体系&#xff0c;基本是每次在面试的时候才会去重点看算法&#xff0c;刷一些题&#xff0c;因此在这里&#xff0c;把我多年的总结发出来&#xff0c;希望晚辈站在一个高的位置学习。 最新链接&#xff1a;有道云笔记 -----------…

阿里云部署配置幻兽帕鲁Palworld联机服务器详细教程

阿里云作为国内领先的云计算服务提供商&#xff0c;为企业和个人提供了丰富的云服务。本文将为大家详细介绍如何在阿里云上配置幻兽帕鲁Palworld联机服务器&#xff0c;以便与更多玩家共同体验游戏的乐趣。 第一步&#xff1a;登录服务器创建页 1、进入幻兽帕鲁联机服务快速部…

设计模式⑩ :用类来实现

文章目录 一、前言二、Command 模式1. 介绍2.应用3. 总结 三、Interpreter 模式1. 介绍2. 应用3. 总结 参考文章 一、前言 有时候不想动脑子&#xff0c;就懒得看源码又不像浪费时间所以会看看书&#xff0c;但是又记不住&#xff0c;所以决定开始写"抄书"系列。本系…

GCP :Stackdriver Logging

官方介绍 Logs Explorer 利用 Logs Explorer&#xff0c;您可以通过灵活的查询语句、丰富的直方图视觉呈现、简单的字段浏览器以及保存查询的功能&#xff0c;对日志进行搜索、排序和分析。设置提醒以便在您包含的日志中出现特定消息时通知您&#xff0c;或者使用 Cloud Moni…

GPT-SoVITS 测试

开箱直用版&#xff08;使用 AutoDL&#xff09; step1 打开地址 https://www.codewithgpu.com/i/RVC-Boss/GPT-SoVITS/GPT-SoVITS-Official 选择 AutoDL创建实例&#xff0c;选择 3080ti 机器 step2 创建好实例之后&#xff0c;进入命令行&#xff0c;输入命令 echo {}>…

Kubernetes成本优化

云原生可以帮助团队更精细化利用资源&#xff0c;但如果缺乏工具的帮助&#xff0c;很难采取适当的措施优化资源的使用。本文介绍了若干用于可视化Kubernetes资源使用情况的工具&#xff0c;并且可以自定义策略优化资源使用&#xff0c;实现更好的成本优化。原文: Kubernetes C…

【计算机二级考试C语言】C强制类型转换

C 强制类型转换 强制类型转换是把变量从一种类型转换为另一种数据类型。例如&#xff0c;如果您想存储一个 long 类型的值到一个简单的整型中&#xff0c;您需要把 long 类型强制转换为 int 类型。您可以使用强制类型转换运算符来把值显式地从一种类型转换为另一种类型&#x…

【NodeJS】004- NodeJS的模块化与包管理工具

模块化 1. 介绍 1.1.什么是模块化与模块 ? 将一个复杂的程序文件依据一定规则(规范)拆分成多个文件的过程称之为 模块化 其中拆分出的 每个文件就是一个模块 ,模块的内部数据是私有的,不过模块可以暴露内部数据以便其他模块使用 1.2 什么是模块化项目 ? 编码时是按照模…

openssl3.2 - 测试程序的学习 - test\aesgcmtest.c

文章目录 openssl3.2 - 测试程序的学习 - test\aesgcmtest.c概述笔记能学到的流程性内容END openssl3.2 - 测试程序的学习 - test\aesgcmtest.c 概述 openssl3.2 - 测试程序的学习 aesgcmtest.c 工程搭建时, 发现没有提供 test_get_options(), cleanup_tests(), 需要自己补上…

公考之判断推理(一、图形推理)

一、前言 判断推理这一题型主要具体分为四种题型&#xff1a; 1.图形推理 2.类比推理 3.定义判断 4.逻辑判断每种题型做题方法又不一样。 才本文采用总分的形式结构。 每一小标题的下面紧接着就是总结。二、图形推理常见的命题形式 图形推理常见的命题形式&#xff1a; 1.…

路飞项目--04

分析后端接口 # 用户板块--原型图--分析需要写哪些接口 多方式登录接口 短信登录接口 发送短信接口 短信注册接口 校验手机号是否注册接口 手机号是否存在接口 思路&#xff1a; 1 用了全局异常捕获&#xff0c;直接抛出异常报错 2 路由用了自定义路由&…

剑指offer——删除链表的节点

题目描述&#xff1a;给定单向链表的头指针和一个要删除的节点的值&#xff0c;定义一个函数删除该节点。返回删除后的链表的头节点。 数据范围&#xff1a; 0 <链表节点值 < 10000 0 <链表长度 < 10000 示例1&#xff1a; 输入&#xff1a;{2,5,1,9}&#xff…

NIO-Selector详解

NIO-Selector详解 Selector概述 Selector选择器&#xff0c;也可以称为多路复⽤器。它是Java NIO的核⼼组件之⼀&#xff0c;⽤于检查⼀个或多个Channel的状态是否处于可读、可写、可连接、可接收等。通过⼀个Selector选择器管理多个Channel&#xff0c;可以实现⼀个线程管理…

STM32标准库——(5)EXTI外部中断

1.中断系统 中断&#xff1a;在主程序运行过程中&#xff0c;出现了特定的中断触发条件&#xff08;中断源&#xff09;&#xff0c;使得CPU暂停当前正在运行的程序&#xff0c;转而去处理中断程序&#xff0c;处理完成后又返回原来被暂停的位置继续运行 中断优先级&#xff…

Ribbon 体系架构解析

前面已经介绍了服务治理相关组件&#xff0c;接下来趁热打铁&#xff0c;快速通关Ribbon&#xff01;前面我们了解了负载均衡的含义&#xff0c;以及客户端和服务端负载均衡模型&#xff0c;接下来我们就来看下SpringCloud 下的客户端负载均衡组件Ribbon 的特点以及工作模型。 …

day04 两两交换链表中的节点、删除链表倒数第N个节点、链表相交、环形链表II

题目链接&#xff1a;leetcode24-两两交换链表中的节点, leetcode19-删除链表倒数第N个节点, leetcode160-链表相交, leetcode142-环形链表II 两两交换链表中的节点 基础题没有什么技巧 解题思路见代码注释 时间复杂度: O(n) 空间复杂度: O(1) Go func swapPairs(head *Li…

Android Handler完全解读

一&#xff0c;概述 Handler在Android中比较基础&#xff0c;本文笔者将对此机制做一个完全解读。读者可简单参考上述类图与时序图&#xff0c;便于后续理解。 二&#xff0c;源码解读 1&#xff0c;主线程伊始 众所周知&#xff0c;通过Zygote的fork方式&#xff0c;新创建…

腾讯云轻量应用Ubuntu服务器如何一键部署幻兽帕鲁Palworld私服?

幻兽帕鲁/Palworld是一款2024年Pocketpair开发的开放世界生存制作游戏&#xff0c;在帕鲁的世界&#xff0c;玩家可以选择与神奇的生物“帕鲁”一同享受悠闲的生活&#xff0c;也可以投身于与偷猎者进行生死搏斗的冒险。而帕鲁可以进行战斗、繁殖、协助玩家做农活&#xff0c;也…

网页转文件下载工具

为了更快捷copy博客 做了个 网页转文件下载工具 1.0.1 更新如下&#xff1a; javaphpjava提供页面转换文件的微服务APIphp调用接口&#xff0c;输出文件下载支持网页转md 1.0.2 更新如下&#xff1a; 样式表切换&#xff0c;白天or黑夜&#xff0c;cookie七天保质期 未…