大模型探索式轨迹优化:基于试错的自主智能体学习新方法

news2024/9/26 5:23:22

 人工智能咨询培训老师叶梓 转载标明出处

现有的开源LLMs在构建智能体方面的效果远不如GPT-4。标准的构建开源LLM智能体的方法涉及模仿学习,即基于专家轨迹对LLMs进行微调。然而,这些方法完全依赖于专家演示,由于对目标环境探索不足而可能产生次优策略,限制了它们的泛化能力。为了解决这一问题,来自北京大学、加州大学洛杉矶分校、俄亥俄州立大学和伊利诺伊大学香槟分校的研究者们提出了一种名为探索式轨迹优化(Exploration-based Trajectory Optimization, ETO)的新型学习方法。该方法允许智能体从探索失败中学习,通过迭代优化框架提高性能。

方法

ETO通过行为克隆开始训练基础智能体,然后通过迭代的方式不断从试错中增强策略。图1展示了探索式轨迹优化(ETO)的过程。在这一流程中,Agent首先通过行为克隆学习基础任务执行策略,然后在实际环境中探索并收集失败的轨迹。这些失败轨迹与先前收集的专家成功轨迹形成对比,Agent利用这些对比信息通过直接偏好优化(DPO)等技术更新其策略。这个过程循环进行,以提高Agent在完成任务时的性能和适应性。

首先,研究者们使用行为克隆(BC)来训练一个基础智能体。行为克隆是通过在专家互动轨迹数据上进行监督式微调,从而为构建强大的智能体打下坚实的基础。在这项工作中,研究者们采用了ReAct风格的轨迹来进行BC,这种方法在每次行动之前还会生成“思考链”(Chain-of-Thought, CoT)理由。

为了简化表示,研究者们用带有CoT的行动来表示。给定一个专家轨迹数据集D,其中包含多个轨迹,通过在自回归损失上微调大模型(LLM)来获得基础智能体πbase。这个过程中,智能体的参数θ被优化以最大化专家轨迹的概率。

然而,仅依赖专家轨迹的行为克隆无法使智能体探索环境,可能导致次优策略。为了训练更强大的智能体,让模型探索失败轨迹非常重要。研究者们采用了强化学习的方法,通过试错来让智能体主动探索环境并获得奖励,从而细化策略。

图 2 展示了ETO的迭代探索-训练循环的概述。从通过行为克隆训练的基础LLM智能体开始,该方法允许智能体迭代地收集失败轨迹,并通过不断从对比失败-成功轨迹对中学习来更新其策略。

在探索阶段,基础智能体πbase根据行为克隆的训练数据指令探索环境,收集轨迹。然后,环境会返回与轨迹相对应的奖励。研究者们基于最终奖励构建失败-成功轨迹对,并且只收集奖励不同的轨迹对。如果智能体生成的轨迹和专家轨迹都成功完成了任务,那么这个对就会被丢弃。最终,研究者们得到了对比轨迹数据集Dp。

在训练阶段,智能体的策略通过模拟轨迹对中的对比失败-成功信息来更新。给定轨迹对ew ≻ el| u,可以通过Bradley-Terry模型来模拟失败-成功关系。然后,使用最大似然法来获得最优策略πθ,这个优化目标旨在增加成功轨迹ew的可能性,并减少失败轨迹el的可能性,同时保持基本智能体能力的约束项。

为了进一步提高智能体的性能,ETO采用了迭代的探索-训练方式。在训练阶段之后,智能体的策略可以用来收集新的失败案例,并创建对比轨迹对。然后,这些新数据被用来通过轨迹对比学习进一步增强智能体。

ETO算法的输入包括专家轨迹数据集D、行为克隆步数T1、ETO迭代次数I、训练阶段步数T2和初始LLM策略πθ。算法首先进行行为克隆,然后迭代地从探索失败中学习,最后输出最终策略πθ。

通过迭代地探索环境以收集失败轨迹,并在训练阶段学习“失败-成功”轨迹对的对比信息来更新策略,ETO能够显著提高智能体的性能。这种方法不仅提高了智能体在已知任务上的性能,还增强了其在未知任务上的泛化能力。

想要掌握如何将大模型的力量发挥到极致吗?叶老师带您深入了解 Llama Factory —— 一款革命性的大模型微调工具。9月22日晚,实战专家1小时讲解让您轻松上手,学习如何使用 Llama Factory 微调模型。

加下方微信或评论留言,即可参加线上直播分享,叶老师亲自指导,互动沟通,全面掌握Llama Factory。关注享粉丝福利,限时免费录播讲解。

LLaMA Factory 支持多种预训练模型和微调算法。它提供灵活的运算精度和优化算法选择,以及丰富的实验监控工具。开源特性和社区支持使其易于使用,适合各类用户快速提升模型性能。

实验

实验在三个代表性的智能体数据集上进行,包括用于网络导航的WebShop、用于模拟科学实验的ScienceWorld,以及用于具身家务任务的ALFWorld。这些环境提供了不同任务的测试场景,并且能够以部分可观测马尔可夫决策过程(POMDP)的形式正式描述。

实验设置:

数据集:实验涉及的数据集统计信息在表 1中展示,包括训练集、已见测试集(Seen)、未见测试集(Unseen)以及专家轨迹的平均交互轮数(Turns)。

训练配置:主要使用Llama-2-7BChat作为构建LLM智能体的基础模型,并在8个NVIDIA A100 80G GPU上进行实验。对于SFT阶段,批量大小为64,学习率设置为1e-5,并使用余弦调度器。ETO训练阶段的批量大小为32,学习率设置为1e-6,DPO损失中的β值根据数据集不同而有所调整。

基线对比:

与ETO进行对比的基线方法包括SFT行为克隆、Best-of-N采样、RFT(拒绝采样微调)和PPO(近端策略优化)。此外,还比较了GPT-3.5-Turbo、GPT-4以及未调优的Llama-2-7BChat。

评估:

所有方法都使用ReAct风格的交互格式进行评估,即在每个动作之前生成CoT理由。评估指标主要是测试集中所有任务实例的平均奖励。

实验结果:

表 2展示了ETO与基线方法在三个智能体数据集上的性能对比。ETO在所有数据集上的性能均显著超过了SFT模仿学习,并且在WebShop数据集上甚至超过了GPT-4的平均奖励,显示出非凡的性能。

在图 3中,展示了ETO与SFT基础智能体以及Oracle智能体在ScienceWorld-Seen测试集上的得分轨迹案例。ETO能够在更少的行动步骤中达到比SFT基础智能体更高的奖励,在某些情况下,如15-90和19-23,甚至超过了Oracle智能体,更早地达到了满分。

表 3 展示了基于不同基础LLMs(包括Llama-2-13B-Chat和Mistral-7B)的ETO方法的结果。这些结果表明,ETO能够一致地提升智能体的性能,即使是在更强大的LLMs上。

在图 4 中,展示了ETO迭代次数对性能的影响。结果表明,在WebShop和ScienceWorld数据集上,ETO在前两次迭代中能够提升智能体的性能,但进一步增加迭代次数并没有带来持续的性能提升,反而在第三次迭代后性能开始下降。

表 4 展示了不同对比数据构建策略的比较结果。轨迹级别的对比(trajectory-wise contrastive)提供了最佳性能,而步级别的对比(step-wise contrastive)则表现出较低的稳定性。

在没有专家轨迹的挑战性场景中,研究者们探索了ETO方法的潜力(表5)。在这种情况下,智能体需要依赖自我对弈来提升能力。实验结果表明,单独使用ETO并不能在没有行为克隆的情况下提升性能,而RFT(拒绝采样微调)显示出在没有专家轨迹的情况下增强智能体能力的潜力。当将RFT与ETO结合使用时,可以进一步增强智能体的性能。

这些实验结果不仅展示了ETO方法在不同任务和不同智能体上的有效性,还揭示了通过试错学习进一步提升智能体泛化能力的可能性。

论文链接:https://arxiv.org/pdf/2403.02502

代码链接:github.com/Yifan-Song793/ETO

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2137831.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows11+ubuntu20.04.6双系统安装

记录win11和ubuntu20.04.6在单个硬盘上安装的主要流程 系统说明 BIOS模式: UEFI 硬盘: 1TB固态 内存: 32GB 步骤 1、 准备两个不小于16GB的U盘,一个用于装Windows,一个用于装ubuntu,注意8G的U盘虽然能够…

操作系统知识点-进程与线程,一文搞懂!

本文图片均来自王道考研 一、进程的概念、组成和特征 进程(Process)是计算机中的一个核心概念,它是对正在运行的程序的一个抽象表示。在计算机科学中,一个进程是系统进行资源分配和调度的一个独立单元,是操作系统结构…

Python数据分析 Pandas基本操作

Python数据分析 Pandas基本操作 一、Series基础操作 ​ Series是pandas的基础数据结构,它可以用来创建一个带索引的一维数组,下面开始介绍它的基础操作 1、创建Series 1)使用数据创建Series: import pandas as pd pd.Series(1…

学习笔记JVM篇(三)

一、垃圾回收机制 垃圾回收(Garbage Collection)机制,是自动回收无用对象从而释放内存的一种机制。Java之所以相对简单,很大程度是归功于垃圾回收机制。(例如C语言申请内存后要手动的释放) 优点&#xff…

基于less和scss 循环生成css

效果 一、less代码 复制代码 item-count: 12; // 生成多少个 .item 类.item-loop(n) when (n > 0) {.icon{n} {background: url(../../assets/images/menu/icon{n}.png) no-repeat;background-size: 100% 100%;}.item-loop(n - 1);}.item-loop(item-count);二、scss代码 f…

在线查看 Android 系统源代码 Android Code Search

在线查看 Android 系统源代码 Android Code Search 1. Android Code Search2. Android2.1. platform/superproject2.2. build/envsetup.sh2.3. build/make/envsetup.sh References 1. Android Code Search https://cs.android.com/ Android https://cs.android.com/android An…

PCIe进阶之TL:Address Spaces, Transaction Types, and Usage

1 Transaction Layer Overview 如上图为PCIe设备的一个分层结构,从上层逻辑看,事务层的关键点是: 流水线式的完整的 split-transaction 协议事务层数据包(TLP)的排序和处理基于信用的流控制机制可选支持的数据中毒功能和端到端数据完整性检测功能事务层包含以下内容: TLP…

【C++】标准库IO查漏补缺

【C】标准库 IO 查漏补缺 文章目录 系统I/O1. 概述2. cout 与 cerr3. cerr 和 clog4. 缓冲区5. 与 printf 的比较 系统I/O 1. 概述 标准库提供的 IO 接口,包含在 iostream 文件中 输入流: cin输出流:cout / cerr / clog。 输入流只有一个 cin&#x…

MFC工控项目实例之十六输入信号验证

承接专栏《MFC工控项目实例之十五定时刷新PC6325A模拟量输入》 验证选定的输入信号实时状态 在BoardTest.cpp文件中添加代码 void CBoardTest::OnButton2() {// TODO: Add your control notification handler code hereisThreadBegin true; //运行线程执行pThre…

medium_socnet

0x00前言 靶场要安装在virtualbox (最新版)。否者会出现一些问题。 攻击机:kali2024 靶机:medium_socnet 0x01信息搜集 因为把靶机和虚拟机啊放在了同一网段。 所以我先使用了 arp-scan,查看有多少同一网段ipUP 。 经过推断…

OSS对象资源管理

1、登录aliyun 1.1、什么是OSS?有什么用? OSS 是“Object Storage Service”的缩写,中文常称为“对象存储服务”。OSS 是一种互联网云存储服务,主要用于海量数据的存储与管理。 相较于nginx,OSS更灵活,不…

点云深度学习系列:Sam2Point——基于提示的点云分割

文章:SAM2POINT:Segment Any 3D as Videos in Zero-shot and Promptable Manners 代码:https://github.com/ZiyuGuo99/SAM2Point Demo:https://huggingface.co/spaces/ZiyuG/SAM2Point 1)摘要 文章介绍了SAM2POINT,这是…

跟《经济学人》学英文:2024年09月14日这期 People are splurging like never before on their pets

People are splurging like never before on their pets Would you buy your furry companion a cologne? like never before:从未有过;未曾发生过 splurge:挥霍;浪费;破费;大量花费;过度消…

python 读取excel数据存储到mysql

一、安装依赖 pip install mysql-connector-python 二、mysql添加表students CREATE TABLE students (ID int(11) NOT NULL AUTO_INCREMENT,Name varchar(50) DEFAULT NULL,Sex varchar(50) DEFAULT NULL,PRIMARY KEY (ID) ) ENGINEInnoDB AUTO_INCREMENT13 DEFAULT CHARSETu…

S32K3 工具篇5:如何使用lauterbach下载调试elf文件

S32K3 工具篇5:如何使用lauterbach下载调试elf文件 一,利用trace32现有flash脚本烧录elf二,debug 现有elf文件 之前写过如何在S32DS中使用lauterbach下载,但是对于RTD EB MCAL的代码,通常情况下是使用命令的方式去编译…

Spring Boot母婴商城:安全、便捷、高效

2 相关技术 2.1 SSM框架介绍 本课题程序开发使用到的框架技术,英文名称缩写是SSM,在JavaWeb开发中使用的流行框架有SSH、SSM、SpringMVC等,作为一个课题程序采用SSH框架也可以,SSM框架也可以,SpringMVC也可以。SSH框架…

C语言 | Leetcode C语言题解之第399题除法求值

题目: 题解: /*** Note: The returned array must be malloced, assume caller calls free().*/typedef struct hash_node_t {char *key;double val;int distinguish_flag; // 用于区分不同的关系struct hash_node_t *p_next; }HASH_NODE_T;typedef str…

clip论文阅读(Learning Transferable Visual Models From Natural Language Supervision)

目录 摘要训练pre-train model的过程将pre-train model应用于下游任务应用(待更新) 论文/项目地址:https://github.com/OpenAI/CLIP 提供了clip的pre-trained model的权重,也可安装使用pre-trained model 摘要 使用标签标注的图…

【IEEEACM Fellow、CCF组委】第三届人工智能与智能信息处理国际学术会议(AIIIP 2024)

第三届人工智能与智能信息处理国际学术会议(AIIIP 2024) 2024 3rd International Conference on Artificial Intelligence and Intelligent Information Processing 中国-天津 | 2024年10月25-27日 | 会议官网:www.aiiip.net 会…

【CTF MISC】XCTF GFSJ1086 [简单] 简单的base编码 Writeup(Base64编码+循环解码+Base92编码)

[简单] 简单的base编码 你懂base编码吗? 工具 在线BASE92编码解码:https://ctf.bugku.com/tool/base92 解法 Vm0wd2QyUXlVWGxWV0d4V1YwZDRWMVl3WkRSV01WbDNXa1JTVjAxV2JETlhhMUpUVmpBeFYySkVUbGhoTVVwVVZtcEJlRll5U2tWVWJHaG9UVlZ3VlZadGNFSmxSbGw1V…