强化学习-RLHF-PPO入门

news2024/12/26 2:22:45

一、定义

  1. 强化学习微调分类
  2. RM模型 数据集格式
  3. 训练流程
  4. Reward 模型训练流程(分类模型,积极为1,消极为0) AutoModelForSequenceClassification
  5. Reward 模型训练案例
  6. PPO模型训练流程
  7. PPO模型训练案例

二、实现

  1. 强化学习微调分类
    RLHF:基于人类反馈对语言模型进行强化学习, 分两步:
    1. RM (Reward Model)奖励模型建模,构造人类偏好排序数据集,训练奖励模型,用来建模人类偏好。
    2 RL(Reinforcement Learning)强化学习,用奖励模型来训练SFT模型,生成模型使用奖励或惩罚来更新其策略,以便生成更高质量、更符合人类偏好的文本.
    DPO(Direct Preference Optimization): 直接偏好优化方法,DPO通过直接优化语言模型来实现对其行为的精确控制,而无需使用复杂的强化学习,也可以有效学习到人类偏好。
    RLHF主要是进行对齐微调, 目标是将大语言模型的行为与人类的价值观或偏好对齐。
    PPO: (Proximal Policy Optimization,近端策略优化)是一种在强化学习领域广泛使用的算法.

  2. RM模型 数据集格式

{conversations:  [0: 
 {from:  
"human",value:  "国会的转发 美国国会由众议院和参议院组成,每两年换届一次(参议员任期为6年,但参议院选举是错位的。是更常见地转发国会议员还是来自国会外部?"}],
chosen:  {from:  "gpt",value:  "计算推文的政党边际概率,我们可以使用以下代码这表明大多数转发不是来自国会议员,而是来自国会之外。"},
rejected:  {from:  "gpt",value:  "回答问题的第(计算转发国会议员或来自国会以外的人的边际概率"}}

其中chosen 代表是好的回答, rejected代表的是不好的回答

  1. 训练流程
    在这里插入图片描述
    训练reward Model---->PPO模型

  2. Reward 模型训练流程(激励模型为深度学习模型)
    数据处理:

def preprocess_function(examples):
    new_examples = {
        "input_ids_chosen": [],
        "attention_mask_chosen": [],
        "input_ids_rejected": [],
        "attention_mask_rejected": [],
    }
    for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
        tokenized_chosen = tokenizer(chosen)
        tokenized_rejected = tokenizer(rejected)

        new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
        new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
        new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
        new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
    return new_examples

训练求损失:AutoModelForSequenceClassification 分类模型

model = AutoModelForSequenceClassification.from_pretrained(
    model_config.model_name_or_path, num_labels=1, **model_kwargs
)
def compute_loss(
    self,
    model: Union[PreTrainedModel, nn.Module],
    inputs: Dict[str, Union[torch.Tensor, Any]],
    return_outputs=False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
    if not self.use_reward_data_collator:
        warnings.warn(
            "The current compute_loss is implemented for RewardDataCollatorWithPadding,"
            " if you are using a custom data collator make sure you know what you are doing or"
            " implement your own compute_loss method."
        )
    rewards_chosen = model(
        input_ids=inputs["input_ids_chosen"],
        attention_mask=inputs["attention_mask_chosen"],
        return_dict=True,
    )["logits"]
    rewards_rejected = model(
        input_ids=inputs["input_ids_rejected"],
        attention_mask=inputs["attention_mask_rejected"],
        return_dict=True,
    )["logits"]
    # calculate loss, optionally modulate with margin
    if "margin" in inputs:
        loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
    else:
        loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

    if return_outputs:
        return loss, {
            "rewards_chosen": rewards_chosen,
            "rewards_rejected": rewards_rejected,
        }
    return loss
  1. Reward 模型训练案例
    https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py

  2. PPO模型训练流程

在这里插入图片描述

步骤:
1. 语言模型预测
2. 激活模型评估(分类模型),1 代表积极,0 代表消极
3. PPO算法优化。
数据:

def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
    sample["query"] = tokenizer.decode(sample["input_ids"])
    return sample
  1. 加载模型, 参考模型(参考模型可以为None)
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
# Get response from gpt2    待训练的模型响应,参考模型响应
response_tensors, ref_response_tensors = ppo_trainer.generate(
    query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs
)
batch["response"] = tokenizer.batch_decode(response_tensors)
batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)
  1. 激活模型评估(分类模型),1 代表积极,0 代表消极
2. 获取激励值
# Compute sentiment score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)                      #激励函数
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]   #激励值
  1. PPO算法优化。
#   问题query  、  模型响应   、激励值
#其中上图优化模块,均在step 方法中实现。
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
step内部:
        all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
            self.model,
            queries,
            responses,
            model_inputs,
            response_masks=response_masks,
            return_logits=full_kl_penalty,
        )
        with self.optional_peft_ctx():
            ref_logprobs, ref_logits_or_none, _, _ = self.batched_forward_pass(
                self.model if self.is_peft_model else self.ref_model,
                queries,
                responses,
                model_inputs,
                return_logits=full_kl_penalty,
            )
  1. PPO模型训练案例
    https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1861708.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

深度学习 —— 1.单一神经元

深度学习初级课程 1.单一神经元2.深度神经网络3.随机梯度下降法4.过拟合和欠拟合5.剪枝、批量标准化6.二分类 前言 本套课程仍为 kaggle 课程《Intro to Deep Learning》,仍按之前《机器学习》系列课程模式进行。前一系列《Keras入门教程》内容,与本系列…

eNSP中三层交换机的配置和使用

一、拓扑图 1.新建拓扑图 2.PC端配置 PC1: PC2&#xff1a; 二、基本命令配置 1.S1配置 <Huawei>system-view [Huawei]sysname S1 [S1]vlan 10 //在交换机 S1 上创建 VLAN 10 [S1-vlan10]vlan 20 // 在交换机 S1 上创建 VLAN 20 [S1-vlan20]quit //退出 VLAN 配置…

基于JSP的在线教育资源管理系统

开头语&#xff1a; 你好呀&#xff0c;我是计算机学长猫哥&#xff01;如果你对在线教育资源管理系统感兴趣或者有相关需求&#xff0c;欢迎在文末找到我的联系方式。 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;JSP技术 工具&#xff1a;IDE、N…

基于Java仓储出入库管理系统设计和实现(源码+LW+调试文档+讲解等)

&#x1f497;博主介绍&#xff1a;✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者&#xff0c;博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌&#x1f497; &#x1f31f;文末获取源码数据库&#x1f31f;感兴趣的可以先收藏起来&#xff0c;还…

代码随想录第34天|贪心算法

59.合并区间 class Solution { public:struct cmp{bool operator()(vector<int>& a, vector<int>& b) {return a[0] < b[0];}};vector<vector<int>> merge(vector<vector<int>>& intervals) {if (intervals.size() 1)retu…

六款顶级原型设计工具推荐,满足你所有需求!

即时设计作为一款专业原型工具&#xff0c;无论是从功能还是插件库配备情况来看&#xff0c;都是毫无疑问可以进行原型图设计的&#xff0c;而且&#xff0c;即时设计内设海量资源库&#xff0c;可以支持大家通过关键词进行搜索相关资源&#xff0c;并且在线编辑使用&#xff0…

Zookeeper:分布式系统中的协调者

Zookeeper&#xff1a;分布式系统中的协调者 前言&#xff1a;引言Zookeeper是什么&#xff1f; 基本概念Zookeeper 数据模型Znode 类型会话Watcher 应用场景分布式锁配置维护组服务名字服务 典型应用场景数据发布/订阅负载均衡命名服务分布式协调/通知集群管理Master选举 工作…

Open3D kitti数据集bin与pcd的相互转换

目录 一、Kitti数据集简介 1.1数据集内容 1.2数据集结构 二、代码实现 2.1bin转pcd 2.2pcd转bin 三、实现效果 一、Kitti数据集简介 KITTI 数据集是由德国卡尔斯鲁厄理工学院&#xff08;KIT&#xff09;和丰田美国技术研究院&#xff08;Toyota Technological Institut…

昇思25天学习打卡营第2天|快速入门

使用MindSpore实现简单的深度学习模型 环境配置 第一步当然是装包&#xff1a; !pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore2.2.14 import mindspore from mindspore import nn from mindspore.dataset import vision, transforms from mindspore.d…

深入了解MySQL的哈希索引

深入了解MySQL的哈希索引 哈希索引是一种基于哈希表的数据结构&#xff0c;通过对索引键值进行哈希运算&#xff0c;直接定位存储位置&#xff0c;从而实现快速数据访问。哈希索引在等值查询中表现尤为出色&#xff0c;但不适用于范围查询。虽然哈希索引在某些场景下可以显著提…

从特斯拉视角,看智能驾驶研究框架

第一章:回顾历史&#xff0c;智能驾驶的核心主线是算法的演进史&#xff0c;从2017年至今在感知侧规控侧实现算法从规则为主走向端到端。算法方面&#xff0c;2017-2022年&#xff0c;特斯拉在感知侧走向端到端&#xff0c;实现BEVTransformerOccupancy。2021-2023年&#xff0…

Python深度学习技术

原文链接&#xff1a;Python深度学习技术 近年来&#xff0c;伴随着以卷积神经网络&#xff08;CNN&#xff09;为代表的深度学习的快速发展&#xff0c;人工智能迈入了第三次发展浪潮&#xff0c;AI技术在各个领域中的应用越来越广泛。Transformer模型&#xff08;BERT、GPT-…

2024年最新中级会计职称考试题库。

46.甲将一汇票背书转让给乙&#xff0c;但该汇票上未记载乙的名称。其后&#xff0c;乙在该汇票被背书人栏内记载了自己的名称。根据《票据法》的规定&#xff0c;下列有关该汇票背书与记载效力的表述中&#xff0c;正确的是&#xff08;&#xff09;。 A.甲的背书无效&#x…

C语言:sprintf与snprintf

C语言提供了强大的格式化输出的接口&#xff0c;可以输出到不同的文件或者字符串等&#xff0c;以sprintf和snprintf为例介绍一下 sprintf 格式化输出到字符串 函数签名 int sprintf(char *str, const char *format, ...);与printf相比就是多了前面的char*参数&#xff0c;…

创新降重工具助力学术写作:提升论文独创性

现在大部分学校已经进入到论文查重降重的阶段了。如果查重率居高不下&#xff0c;延毕的威胁可能就在眼前。对于即将告别校园的学子们&#xff0c;这无疑是个噩梦。四年磨一剑&#xff0c;谁也不想在最后关头功亏一篑。 查重率过高&#xff0c;无非以下两种原因。要么是作为“…

【激光雷达使用记录】—— 如何在ubuntu中利用ros自带的rviz工具实时可视化雷达点云的数据

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、查看雷达数据的 frame_id1. 查看雷达数据的话题2. 查看数据的frame_id 二、可视化雷达数据总结 前言 RViz&#xff08;ROS Visualization&#xff09;是机…

华为面试题及答案——机器学习(二)

21. 如何评价分类模型的优劣? (1)模型性能指标 准确率(Accuracy): 定义:正确分类的样本数与总样本数之比。适用:当各类样本的数量相对均衡时。精确率(Precision): 定义:预测为正类的样本中实际为正类的比例。适用:当关注假阳性错误的成本较高时(例如垃圾邮件检测…

超细毛搭配超宽设计,一款更呵护牙龈的牙刷

牙龈敏感的时候&#xff0c;刷牙特别难受&#xff0c;最近试了试惠百施&#xff08;EBISU&#xff09;65孔宽头软毛牙刷&#xff0c;感觉它的口腔护理体验很不错。这款牙刷的设计独特&#xff0c;采用宽头设计&#xff0c;一次就能刷两排牙齿&#xff0c;极大地提高了清洁效率。…

2024广东省职业技能大赛云计算赛项实战——集群部署GitLab

集群部署GitLab 前言 题目是这样的&#xff1a; 在Kubernetes集群中新建命名空间gitlab-ci&#xff0c;将GitLab部署到该命名空间下&#xff0c;Deployment和Service名称均为gitlab&#xff0c;以NodePort方式将80端口对外暴露为30880&#xff0c;设置GitLab服务root用户的密…