Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

news2025/1/15 22:30:45

Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

目录

trl的简介

1、亮点

2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

trl的安装

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

(2)、如何使用库中的RewardTrainer

(3)、如何使用库中的PPOTrainer

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略


trl的简介

          TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。

GitHub地址:GitHub - huggingface/trl: Train transformer language models with reinforcement learning.

1、亮点

>> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。

>> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。

>> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。

>> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。

>> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。

2、PPO是如何工作的PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

通过PPO对语言模型进行微调大致包括三个步骤:

Rollout

Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。

Evaluation

Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值

Optimization

Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。

这个过程在下面的示意图中说明。

trl的安装

pip install trl

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。

# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()

(2)、如何使用库中的RewardTrainer

以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。

# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()

(3)、如何使用库中的PPOTrainer

以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。

# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# initialize trainer
ppo_config = PPOConfig(
    batch_size=1,
)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor  = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133865725

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133873621

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1099205.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LeetCode 2 两数相加

题目描述 链接:https://leetcode.cn/problems/add-two-numbers/?envTypefeatured-list&envId2ckc81c?envTypefeatured-list&envId2ckc81c 难度:中等 给你两个 非空 的链表,表示两个非负的整数。它们每位数字都是按照 逆序 的方式…

软件设计之工厂方法模式

工厂方法模式指定义一个创建对象的接口,让子类决定实例化哪一个类。 结构关系如下: 可以看到,客户端创建了两个接口,一个AbstractFactory,负责创建产品,一个Product,负责产品的实现。ConcreteF…

基于ssm008医院门诊挂号系统+jsp【附PPT|开题|任务书|万字文档(LW)和搭建文档】

主要功能 后台登录:4个角色 管理员: ①个人中心、修改密码、个人信息 ②药房管理、护士管理、医生管理、病人信息管理、科室信息管理、挂号管理、诊断信息管理、病例库管理、开药信息管理、药品信息管理、收费信息管理 药房: ①个人中心、修…

CSS阶详细解析一

CSS进阶 目标:掌握复合选择器作用和写法;使用background属性添加背景效果 01-复合选择器 定义:由两个或多个基础选择器,通过不同的方式组合而成。 作用:更准确、更高效的选择目标元素(标签)。…

计算机算法分析与设计(11)---贪心算法(活动安排问题和背包问题)

文章目录 一、贪心算法概述二、活动安排问题2.1 问题概述2.2 代码编写 三、背包问题3.1 问题描述3.2 代码编写 一、贪心算法概述 1. 贪心算法的定义:贪心算法是指在对问题求解时,总是做出在当前看来是最好的选择。也就是说,不从整体最优上加以…

CICD 流程学习(四)搜素服务与消息队列

一 搜索服务 1 Lucene概念 Lucene是一种高性能、可伸缩的信息搜索 (IR)库,在2000年开源,最初由鼎鼎大名的Doug Cutting开发。是基于Java实现的高性能的开源项目 Lucene采用了基于倒排表的设计原理,可以非常高效地实现文本查找&#xff0…

GEO生信数据挖掘(九)WGCNA分析

第六节,我们使用结核病基因数据,做了一个数据预处理的实操案例。例子中结核类型,包括结核,潜隐进展,对照和潜隐,四个类别。第七节延续上个数据,进行了差异分析。 第八节对差异基因进行富集分析。…

windows内网渗透正向代理

内网渗透正向代理 文章目录 内网渗透正向代理1 正向代理图2 环境准备2.1 正向代理需求: 3 网卡配置3.1 【redream】主机3.2 【base】主机双网卡3.3 【yvkong】网卡设置 4 启动4.1【redream】网卡配置:4.2【base】网卡配置:4.3【yvkong】网卡地…

ArcGis打开影像显示全黑解决方法

我们加载图像,显示如下: 解决方法: 问题分析:Gamma值高于1影像亮化,低于1影像暗化。栅格影像导入进来呈现黑色,可能是因为影像的“Gamma校正”设置出现问题,影响了影像的拉伸度、亮度、对比度等…

FTP客户端lftp

目录 准备 1 lftp介绍 2 lftp语法 3 lftp选项 4 下载 4.1 服务端 4.2 客户端 5 上传 5.1 客户端 5.2 服务端 准备 两台虚拟机(且保证互通)关闭防火墙和SeLinux。 关闭防火墙 systemctl stop firewalld 关闭SeLinux setenforce 0 vi /etc/s…

12-k8s-HPA自动扩缩容

文章目录 一、k8s弹性伸缩类型二、HPA原理三、metrics-server插件四、创建nginx提供负载测试五、部署HPA master操作即可 一、k8s弹性伸缩类型 Cluster-Autoscale: 集群容量(node数量)自动伸缩,跟自动化部署相关的,依赖iaas的弹性伸缩,主要用…

学术特稿 | 著名书法家项国就:中国古代书法章草美学展现的形式分析

、 论文入编:大型综合美术类核心期刊《新美域》杂志2023年第七期。 中国古代书法章草美学展现的形式分析 摘要:本文旨在探讨中国古代书法风格章草的美学特点、审美价值以及代表性作品和艺术家。章草的美学特点体现在简洁流畅的笔画、清晰规整的字形结构以…

[架构之路-237]:目标系统 - 纵向分层 - 网络通信 - DNS的递归查询和迭代查询

目录 一、DNS协议与DNS系统架构 1.1 什么是DNS协议 1.2 为什么需要DNS协议 1.3 DNS系统架构 二、DNS系统的查询方式 2.1 递归与迭代的比较 2.2 DNS递归查询 2.3 DNS迭代查询 一、DNS协议与DNS系统架构 1.1 什么是DNS协议 DNS(Domain Name System&#xff…

nRF52832蓝牙从机

具体内容直接参考《nRF52832开发指南》 本文仅对关键内容和容易搞错的内容进行描述。 广播事件 扫描事件 连接事件 从机框架 日志配置和初始化log_init,具体参考手册和示例程序;APP定时器初始化timers_init,具体参考手册和示例程序&…

EFLK与logstash过滤

目录 一、Filebeat工作原理: 二、为什么要使用Filebeat: 三、Filebeat和Logstash的区别: 四、logstash 的过滤插件: 五、FilebeatELK 部署: 1. 安装filebeat: 2. 设置 filebeat 的主配置文件&#xff1…

SpringBoot面试题7:SpringBoot支持什么前端模板?

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:SpringBoot支持什么前端模板? Spring Boot支持多种前端模板,其中包括以下几种常用的: Thymeleaf:Thymeleaf是一种服务器端Java模板引擎,能够…

【2】c++11新特性(稳定性和兼容性)—>超长整型 long long

c11标准要求long long整型可以在不同的平台上有不同的长度,但是至少64位,long long整型有两种: 有符号long long:–对应类型的数值可以使用LL或者ll后缀 long long num1 123456789LL; long long num2 123456789ll;无符号unsign…

Linux性能优化--实用工具:性能工具助手

8.0 概述 本章介绍一些在Linux系统上可用的实用程序,它们能够加强性能工具的有效性和可用性。实用工具本身不是性能工具,但是当它们与性能工具一起使用时,它们可以帮助完成如下功能:自动执行繁琐的任务、分析性能统计数据&#x…

【机器学习】逻辑回归

文章目录 逻辑回归定义损失函数正则化 sklearn里面的逻辑回归 逻辑回归 逻辑回归,是一种名为“回归”的线性分类器,其本质是由线性回归变化而来的,一种广泛使用于分类问题中的广义回归算法。 线性回归是机器学习中最简单的的回归算法&#…

基本分段存储管理方式(分段,段表,地址转换以及与分页管理对比)

1.分段 1.进程的地址空间: 按照程序自身的逻辑关系划分为若干个段,每个段都有一个段名 (在低级语言中,程序员使用段名来编程),每段从0开始编址. 2.内存分配规则: 以段为单位进行分配,每个段在内存中占据…