trl - 微调、对齐大模型的全栈工具

news2024/9/20 5:49:07

trl

文章目录

    • 一、关于 TRL
      • 亮点
    • 二、安装
      • 1、Python包
      • 2、从源码安装
      • 3、存储库
    • 三、命令行界面(CLI)
    • 四、如何使用
      • 1、`SFTTrainer`
      • 2、`RewardTrainer`
      • 3、`PPOTrainer`
      • 4、`DPOTrainer`
    • 五、其它
      • 开发 & 贡献
      • 参考文献
        • 最近策略优化 PPO
        • 直接偏好优化 DPO


一、关于 TRL

TRL : Transformer Reinforcement Learning

Full stack library to fine-tune and align large language models.

Train transformer language models with reinforcement learning.

  • github : https://github.com/huggingface/trl
  • 文档:https://huggingface.co/docs/trl/index

trl库是一个全栈工具,用于使用监督微调步骤(SFT)、奖励建模(RM)和近似策略优化(PPO)以及直接偏好优化(DPO)等方法微调和对齐转换器语言和扩散模型。

该库建立在transformers库之上,因此允许使用那里可用的任何模型架构。


亮点

  • Efficient and scalable
    • acceleratetrl的支柱,它允许使用DDP和DeepSpeed等方法将模型训练从单个GPU扩展到大规模多节点集群。
    • PEFT是完全集成的,即使是最大的模型也可以通过量化和LoRA或QLoRA等方法在适度的硬件上训练。
    • unsloth也是集成的,允许使用专用内核显着加快训练速度。
  • CLI:使用CLI,您可以使用单个命令和灵活的配置系统微调LLM并与之聊天,而无需编写任何代码。
  • Trainers:培训师类是一个抽象,可以轻松应用许多微调方法,如SFTTrainerDPOTrainerRewardTrainerPPOTrainerCPOTrainerORPOTrainer
  • AutoModelsAutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead 类为模型添加了一个额外的值头,允许使用RL算法(如PPO)训练它们。
  • Examples:使用BERT情感分类器训练GPT2以生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j毒性更小,StackLlama示例等。以下是示例。

二、安装


1、Python包

使用pip安装库:

pip install trl

2、从源码安装

如果您想在正式发布之前使用最新功能,您可以从源代码安装:

pip install git+https://github.com/huggingface/trl.git

3、存储库

如果您想使用这些示例,您可以使用以下命令克隆存储库:

git clone https://github.com/huggingface/trl.git

三、命令行界面(CLI)

您可以使用TRL命令行界面(CLI)快速开始使用监督微调(SFT)、直接偏好优化(DPO)并使用聊天CLI测试对齐的模型:

SFT:

trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb

DPO:

trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style --output_dir opt-sft-hh-rlhf 

聊天:

trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat

在 relevant documentation section 阅读有关CLI的更多信息,或使用--help获取更多详细信息。


四、如何使用

为了获得更多的灵活性和对训练的控制,您可以使用专用的训练类 来微调Python中的模型。


1、SFTTrainer

这是如何使用库中的SFTTrainer的基本示例。

SFTTrainer 是围绕transformersTrainer的轻型包装器,可轻松微调自定义数据集上的语言模型或适配器。

# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()

2、RewardTrainer

这是如何使用库中的RewardTrainer的基本示例。

RewardTrainer transformers Trainer 的包装器,可轻松微调自定义偏好数据集上的奖励模型或适配器。

# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()

3、PPOTrainer

这是如何使用库中的PPOTrainer的基本示例。

基于查询,语言模型创建一个响应,然后对其进行评估。评估可以是循环中的人或另一个模型的输出。

# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
ref_model = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# initialize trainer
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor  = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

4、DPOTrainer

DPOTrainer是使用直接偏好优化算法的培训师,这是如何使用库中的DPOTrainer的基本示例DPOTrainertransformersTrainer的包装器,可轻松微调自定义偏好数据集上的奖励模型或适配器。

# imports
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = DPOTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()

五、其它

开发 & 贡献

如果您想为trl做出贡献或根据您的需求对其进行定制,请务必阅读贡献指南并确保您进行了开发安装:

git clone https://github.com/huggingface/trl.git
cd trl/
make dev

参考文献


最近策略优化 PPO

PPO实现在很大程度上遵循D. Ziegler等人的**“来自人类偏好的微调语言模型”**论文中介绍的结构。[论文,代码]。


直接偏好优化 DPO

DPO基于E. Mitchell等人的**《直接偏好优化:您的语言模型是秘密的奖励模型》**的原始实现。[论文,代码]


2024-07-17(三)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1931712.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

安全防御,防火墙配置NAT转换智能选举综合实验

目录: 一、实验拓扑图 二、实验需求 三、实验大致思路 四、实验步骤 1、防火墙的相关配置 2、ISP的配置 2.1 接口ip地址配置: 3、新增设备地址配置 4、多对多的NAT策略配置,但是要保存一个公网ip不能用来转换,使得办公区的…

c++入门----类与对象(上)

大家好啊,好久没有更新了。因为本人的愚笨,想与大家分享的话肯定还得自己明白了才能给大家分享吧。所以这几天都在内部消化。好给大家优质的文章。当然我写的肯定还是很有问题的,希望大家可以在评论区里面指出来。好,废话不多说&a…

LabVIEW 与 PLC 通讯方式

在工业自动化中,LabVIEW 与 PLC(可编程逻辑控制器)的通信至关重要,常见的通信方式包括 OPC、Modbus、EtherNet/IP、Profibus/Profinet 和 Serial(RS232/RS485)。这些通信协议各有特点和应用场景&#xff0c…

从图表访问Data Store Memory

Simulink模型将全局变量实现为数据存储,可以是数据存储内存块,也可以是Simulink.Signal的实例。您可以使用数据存储在多个Simulink块之间共享数据,而无需显式的输入或输出连接来将数据从一个块传递到另一个块。Stateflow图表通过符号化地读取…

警惕预言成真!3本预警、On Hold已被剔除,新增8本SCI/SSCI被除名!7月WOS更新(附下载)

本周投稿推荐 SCI • 能源科学类,1.5-2.0(25天来稿即录) • IEEE计算机类,4.0-5.0(48天录用) • 生物医学制药类(2天逢投必中) EI • 各领域沾边均可(2天录用&…

精益思维在数字工厂建设中的具体应用

在数字化浪潮席卷全球的今天,数字工厂建设已成为企业转型升级的必由之路。然而,如何确保数字工厂的高效运行和持续创新,成为摆在众多企业面前的难题。精益思维,作为一种追求持续改进和卓越绩效的管理理念,正成为助力数…

iPhone手机怎么识别藏文?藏语翻译通App功能介绍:藏文OCR识别提取文字

在工作学习的过程中,遇到不会的藏文,也不知道怎么把文字打出来,这个时候可以试试《藏语翻译通》App的图片识别功能,支持拍照识别和图片识别,拍一拍就能提取藏文文字,并支持一键翻译和复制分享。 跟着小编的…

汽车免拆诊断案例 | 2017 款林肯大陆车发动机偶尔无法起动

故障现象 一辆2017款林肯大陆车,搭载2.0T发动机,累计行驶里程约为7.5万km。车主进厂反映,有时按下起动按钮,起动机不工作,发动机无法起动,组合仪表点亮正常;多次按下起动按钮,发动机…

01大学物理电磁篇 静电场

5-6 静电场的环路定理 电势能 5-7电势 5-8电场强度与电势梯度

背部筋膜炎最有效的治疗方法

背部筋膜炎症状:背部筋膜炎引起的疼痛通常是钝痛或酸痛,且这种疼痛是无菌性炎症产生的炎症因子、疼痛因子刺激局部神经引起的。疼痛主要发生在腰背部,特别是两侧腰肌和髂嵴上方可能会更加明显。长时间不活动或活动过度都可能诱发疼痛。疼痛可…

使用element UI Cascader 级联选择器实现省/市/区选择

<template><div><label>位置</label><el-cascader:options"pcaTextArr"v-model"selectedOptions"change"handleChangeAddress":props"{expandTrigger: hover,multiple: true,checkStrictly: true,emitPath: fal…

windows 打包pyd文件

1.新建一个py文件&#xff0c;myunit.py&#xff0c;里面的代码是: class Adder: def __init__(self, a, b): self.a a self.b b def add(self): return self.a self.b 2.新建一个py文件&#xff0c;setup.py&#xff0c;里面的代码是: from setuptools import setup fro…

完整教程 linux下安装百度网盘以及相关依赖库,安装完成之后启动没反应 或者 报错

完整教程 linux下安装百度网盘以及相关依赖库&#xff0c;安装完成之后启动没反应 或者 报错。 配置国内镜像源&#xff1a; yum -y install wget mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.bak wget -O /etc/yum.repos.d/CentOS-Base.repo ht…

MySQL运维实战之Clone插件(10.1)使用Clone插件

作者&#xff1a;俊达 clone插件介绍 mysql 8.0.17版本引入了clone插件。使用clone插件可以对本地l或远程的mysql实例进行clone操作。clone插件会拷贝innodb存储引擎表&#xff0c;clone得到的是原数据库的一个一致性的快照&#xff0c;可以使用该快照数据来启动新的实例。cl…

服务器基础2

服务器基础复习02 1.网络管理 nmcli nmcli是NetworkManager的一个命令行工具&#xff0c;它提供了使用命令行配置由NetworkManager管理网络连接的方法。nmcli命令的基本格式为&#xff1a; nmcli [OPTIONS] OBJECT { COMMAND | help }其中&#xff0c;OBJECT选项可以是genera…

PHP旅游门票预订系统小程序源码

旅游门票预订系统&#xff1a;轻松规划&#xff0c;畅游无忧&#x1f30d; &#x1f3ab;【一键预订&#xff0c;说走就走】 还在为排队购票浪费时间而烦恼吗&#xff1f;旅游门票预订系统让你告别长龙&#xff0c;享受说走就走的旅行&#xff01;只需在手机或电脑上轻轻一点…

通过Dockerfile构建镜像

案例一&#xff1a; 使用Dockerfile构建tomcat镜像 cd /opt mkdir tomcat cd tomcat/ 上传tomcat所需的依赖包 使用tar xf 解压三个压缩包vim Dockerfile FROM centos:7 LABEL function"tomcat image" author"tc" createtime"2024-07-16"ADD j…

橙芯创想:香橙派AIPRO解锁升腾LLM与Stable Diffusion的创意密码

文章目录 引言 一. 香橙派AI PRO配置以及展示优秀的扩展能力实物展示 二、Ascend-LLM模型部署开机xshell连接香橙派实战运行部署运行结果分析开发版表现 三、Stable Diffusion文生图性能表现 四、体验总结性能噪音便捷性 引言 在科技的浪潮中&#xff0c;一场融合智慧与创意的盛…

AI 和平台工程对云原生演进的影响:将云之旅自动化到光速

2024 年和云原生 AI 技术的曙光标志着计算能力的重大飞跃。我们正在经历一个新时代&#xff0c;人工智能 &#xff08;AI&#xff09; 和平台工程融合在一起&#xff0c;改变云计算格局。人工智能现在正在与云计算融合&#xff0c;我们正在经历一个人工智能超越传统界限的时代&…

IDEA实现NPM项目的自打包自发布自部署

目录 前言 正文 操作背景 NPM自发布 Package自发布 NPM部署 尾声 &#x1f52d; Hi,I’m Pleasure1234&#x1f331; I’m currently learning Vue.js,SpringBoot,Computer Security and so on.&#x1f46f; I’m studying in University of Nottingham Ningbo China&#x1f…