[论文阅读]MaIL: Improving Imitation Learning with Mamba

news2024/12/26 0:56:08

Abstract

这项工作介绍了mamba模仿学习(mail),这是一种新颖的模仿学习(il)架构,为最先进的(sota)变换器策略提供了一种计算高效的替代方案。基于变压器的策略由于能够处理具有固有非马尔可夫行为的人类记录数据而取得了显著成果。然而,它们的高性能伴随着大型模型的缺点,这使得有效的训练变得复杂。虽然状态空间模型(ssms)以其效率而闻名,但它们无法与变压器的性能相匹配。mamba显著提高了ssms和竞争对手对transformers的性能,使其成为il政策的一个有吸引力的替代方案。mail利用mamba作为骨干,并引入了一种形式化,允许在编码器-解码器结构中使用mamba。这种形式化使其成为一种通用的架构,既可以用作独立的策略,也可以用作更高级架构的一部分,例如扩散过程中的扩散器

对LIBERO IL benchmark和三个真实机器人实验的广泛评估表明,mail:i)在所有libero任务中都优于transformer,ii)即使在小数据集下也能实现良好的性能,iii)能够有效地处理多模态感官输入,iv)与transformer相比,对输入噪声更具鲁棒性

Introduction

这里,当前的方法要么使用仅解码器结构[5],要么使用解码器-编码器架构[6]。这些架构中哪一个擅长通常取决于任务。变压器的性能通常伴随着难以训练的大型模型,特别是在数据稀缺的领域。处理观测序列的另一种概念是状态空间模型[12]。这些模型假设观测值(嵌入)之间存在线性关系,通常在计算上更高效。最近的方法,如选择性状态空间模型mamba[13],严格提高了状态空间模型的性能,并在许多任务中与变压器竞争。由于其在推理速度、内存使用和效率方面的特性,mamba是一个有吸引力的il策略模型

邮件可以用作独立的策略,也可以用作更高级流程的一部分,例如扩散流程中的扩散器。我们以两种变体实现邮件。在仅解码器的变体中,mail处理噪声动作和观测特征[5]以及扩散过程的时间嵌入,并输出去噪动作。

Related Works

Sequence Models.

变压器中的自我关注机制允许并行处理序列,有效地解决了rnn在顺序数据处理中的局限性[17,18,19,20]。然而,结构化状态空间模型[12,22,22,13]为变压器提供了一种有吸引力的替代方案。变压器在序列长度上按二次缩放,而结构化状态空间模型则按线性缩放[13]

最近的工作[13]依赖于关联扫描,它也允许并行计算,但还允许输入相关的可学习矩阵[13]

Imitation Learning (IL).

早期的模仿学习方法主要侧重于学习状态-动作对之间的一一映射。但这些方法忽略了历史中包含的丰富时间信息。随后的方法结合了rnns来编码观测序列,证明了利用历史观测可以提高模型性能。然而,这些方法存在基于rnn架构的固有局限性,包括表示能力有限、序列建模时间长以及训练时间慢,因为它们不适合大规模并行化。

Transformer可以对长序列进行建模,同时通过并行序列处理保持训练效率。这一趋势延伸到具有多模态感官输入的il[36,37,38,39],其中变换器对图像和语言序列进行编码

最近,扩散模型在模仿学习中表现出了优越性[5,40,6,41,39]。由于其强大的泛化能力和丰富的表示能力来捕捉多模态动作分布,它们已成为模仿学习领域的sota

Preliminaries

3.1 Mamba: Selective State-Space Models

mamba[13]通过使用选择性扫描算子改进了结构化状态空间序列模型(ssms)

输入序列,输出序列,b、l、d分别表示批量大小、序列长度和维度

标准ssm定义了时不变参数和时间步长向量∆将x(l)的输入映射到隐藏状态,然后可以将其投影到输出y(l)

mamba通过使ssm参数成为输入的函数来实现选择机制

线性是指线性投影层,softplus是relu的平滑近似。那么输出可以通过以下方式计算

其中是离散化的[42]个具有时间步长∆的对应项。由于时变模型只能以循环方式计算,mamba进一步实现了一种硬件感知方法来高效计算选择性ssm。

图1:d-ma:mamba去噪架构集成了用于状态编码的resnet-18和用于动作编码的动作编码器。状态序列的长度为K,而扩散步骤t处的动作序列的长度是J。在将输入馈送到曼巴模块之前,位置编码(PE)和时间编码(TE)增强了输入,其中sk和ak共享相同的位置编码。曼巴模块有n×曼巴块,详细结构[13]如左图所示。mamba模块的输出由线性输出层处理,从而实现一步去噪操作。mamba块中的符号×表示矩阵乘法,σ表示silu激活函数。

 3.2 Policy Representations

在这项工作中,我们使用了两种策略表示:行为克隆(bc)和去噪扩散策略(ddps)。为了清楚起见,我们关注的是非连续性的情况。

Behavioral Cloning

行为克隆假设参数化的条件高斯分布作为策略表示,即最大化模型参数θ的可能性简化为均方误差(mse)损失,其中使用演示数据中的状态-动作对来近似s、a的期望。

Denoising Diffusion Policies

去噪扩散策略利用去噪函数从马尔可夫链中采样开始,对给定的观测值s产生无噪声动作

训练去噪函数,通过最小化损失来预测噪声动作的源噪声

,t上的期望对应于中的均匀采样。

4 Mamba for Imitation Learning

从成功的仅解码器(D-Tr)和编码器-解码器(ED-Tr)变换器中汲取灵感,我们提出了两种基于mamba的架构:仅解码器mamba(D-Ma)和编码器解码器mamba(ED-Ma)

这些架构充当策略的参数化。具体来说,当采用行为克隆(bc)时,这些架构将条件高斯分布的均值μθ参数化。当使用去噪扩散策略(ddps)时,这些架构将去噪函数εθ参数化。鉴于前一种情况的简单性,我们将重点介绍ddps背景下的这些架构

4.1 Decoder-Only Mamba

与仅解码器转换器类似,我们使用mamba块来处理输入。图1显示了纯解码器mamba架构的概述。ddps的解码器专用mamba被设计为学习去噪函数εθ,该函数接受一系列观测值,噪声动作和diffusion step t,来生成噪声较小的动作序列使用时间嵌入te对扩散步骤进行编码。使用resnet-18对观测值进行编码,在不同时间步长的图像之间共享权重。动作编码器enca用于对有噪声的动作输入进行标记。此外,位置嵌入pe被应用于观测和动作。然后,时间嵌入、状态嵌入和动作嵌入将被输入到mamba解码器decm中。mamba解码器是通过堆叠多个具有残差连接和层归一化的mamba块来实现的。算法1中示出了完整的推理例程

4.2 Encoder-Decoder Mamba

与仅包含自注意机制的解码器变压器相比,具有交叉注意的编码器-解码器变压器是一种更灵活有效的设计,可以处理复杂的输入输出关系,特别是在输入和输出序列结构不同的情况下。然而,由于目标和源共享相同的序列长度,mamba没有提供这样的机制来支持编码器-解码器结构。

我们提出了一种称为mamba聚合的新方法,用于设计mamba的编解码器版本。可视化可以在图2中找到。mamba编码器encm用于处理时间嵌入和状态嵌入,mamba解码器decm用于处理噪声嵌入。由于em和dm的输入具有不同长度的序列,我们建议添加可学习变量来补充每个序列

图2:ed-ma:与d-ma模型不同,ed-ma包含用于处理时间嵌入和状态嵌入的mamba编码器,以及用于处理噪声动作的mamba解码器。为了聚合来自编码器和解码器的信息,将可学习的动作变量引入编码器输入,将可习得的时间变量和状态变量引入解码器输出,以进行序列对齐。

act对比:

5 Experiments 

我们的调查侧重于以下关键问题:

q1)MaIL能否实现与变压器相当或更优的性能?

q2)MaIL可以使用多模式输入,如语言指令吗?

q3)MaIL如何有效地处理观察中的连续信息?

5.1 Baselines

我们的实验包含四种架构:去卷积变换器(d-tr)、编解码器变换器(ed-tr)、仅解码器mamba(d-ma)、编解码mamba(ed-ma)。

为了进行公平的比较,我们使用resnet18对每种方法的视觉输入进行编码。对于使用语言指令的任务,我们使用预训练的clip模型[43]来获得相应的语言嵌入,该嵌入用于所有方法的训练和推理。

基于上述设置,我们实施以下模仿学习策略:

行为克隆(bc)我们实现了一种用变压器和曼巴结构的mse损失训练的vanilla行为克隆策略。

深度学习论文中的黑话总结 (ngui.cc)

基于bc中相同结构的去噪扩散策略(ddp),我们进一步使用离散去噪过程实现了一种扩散策略[44]。我们为每种架构使用16个扩散时间步长进行训练和采样。

5.2 Simulation Evaluation

LIBERO

评估是使用libero基准进行的,该基准包括五个不同的任务套件:LIBERO-Spatial, LIBERO-Object, LIBERO-Goal, LIBERO-Long, and LIBERO90。每个任务套件包括10个任务和50个人类演示,但libero-90除外,它包含90个任务,50个演示。每个任务套件都旨在测试机器人学习和操纵能力的不同方面。任务可视化如图3所示。更多细节见附录c。

Evaluation Protocol

我们分别在五个libero任务套件中比较了每种方法。除了libero-90包含900个轨迹外,我们没有使用完整的演示,而是为每个子任务只使用了20%的演示,每个任务套件总共使用了100个轨迹。我们调整变压器和曼巴的超参数,确保它们的参数量相似。所有模型都训练了50个epoch,我们使用最后一个检查点进行评估。遵循libero的官方基准设置,我们为每个子任务执行了20次部署 rollouts,每个任务套件总共进行了200次评估,但libero-90除外,它包括1800次评估。我们报告了超过3个种子的每个任务套件的平均成功率。

Main Results.

我们在表1中报告了主要结果。我们基于mamba的架构d-ma和ed-ma在基于bc策略的所有libero任务套件中的表现明显优于基于转换器的方法

表1:libero基准测试的性能,其中“w/o语言”表示我们不使用语言指令,“w/language”表示我们使用从预训练的剪辑模型生成的语言令牌,h1和h5分别表示使用当前状态和5步历史状态

具体来说,基于曼巴的模型在libero-object和libero-90中的成功率提高了近30%。当使用ddp策略时,我们的模型始终超过变压器基线,在大多数任务中性能提高超过5%。这些结果证实了q1,表明mail的性能优于变压器。

为了解决q2问题,我们使用额外的语言嵌入作为输入,将邮件与libero-target和libero-90上的transformers进行了比较。我们观察到,在这些任务中,基于mamba的方法有了显著改进,表明邮件有效地利用了多模态输入。

鉴于最近的视觉模仿学习作品使用历史观察作为输入,我们用1和5个历史观察来评估这些方法。我们发现历史信息并不总是能提高绩效。 只有在libero对象中h5模型的表现优于h1模型,而在其他任务中,h5模型取得了类似或更差的结果。

基于mamba的h5模型的性能再次始终优于基于transformer的模型,这表明mail能够有效地捕获连续的观察特征,回答了问题3

Ablation on Observation Occlusions

为了进一步了解transformer和mamba的顺序学习能力,我们随机屏蔽图像区域并测试模型的性能下降。结果如图4所示。而对于零遮挡,变压器架构可以与mamba相当,添加遮挡会更快地降低变压器的性能,表明mamba可以更好地从历史序列中提取重要信息。

Ablation on Dataset Size

鉴于邮件仅在20%的演示中表现良好,我们有兴趣随着数据集大小的增加来评估其可扩展性。我们在libero空间任务上使用bc策略将基于mamba的模型与transformer模型进行了比较。结果如图4所示。很明显,当数据稀缺时,基于mamba的模型明显优于transformer,并且随着数据集大小的增加,其性能也相当。

5.3 Real Robot Evaluation

我们基于7自由度franka熊猫机器人设计了三个具有挑战性的任务,利用模型的视觉输入。位于机器人前方不同角度的两个摄像头提供视觉数据。一个图像被裁剪并调整为(128,256,3),而另一个图像则调整为(256,256,3)。

整个设置如图5所示。这些图像在每个时间步上堆叠以形成观察结果。我们从输入中排除了机器人状态,因为之前的研究报告称,包括它们可能会导致性能不佳[7]。

动作空间是8维的,包括关节位置和夹持器状态。下面详细介绍的任务设置如图6-8所示。相应的结果如表2-4所示。

我们使用ddp-h1模型将ed-tr与ed-ma进行了比较。我们对每种方法训练了100个迭代周期(收敛),并使用最终的检查点对模型进行了评估。对于每个任务,我们为对象执行了20个具有不同初始状态的展开。为了确保公平比较,我们对变压器和曼巴评估使用了相同的初始状态。从结果来看,基于mamba的方法与变压器模型取得了相当的结果。

6 Limitations

虽然mail在较小的数据集大小下表现出了出色的性能,但随着数据集的扩展,它的优势变得不那么明显。当在更大的数据集上训练时,mail的结果与transformer模型相当,但并不超过后者。

此外,mamba的设计是为了快速高效地处理大规模序列。然而,在序列相对较短的模仿学习策略的背景下,Transformer的推理时间与mamba相似。这降低了mamba在这些场景中的性能效率优势。

7 Conclusion

总之,这项工作提出了一种新的模仿学习(il)策略架构mail,它弥合了处理观察序列的效率和性能之间的差距。通过利用状态空间模型的优势并对其进行严格改进,mail为传统上基于大型复杂变压器的策略提供了一种有竞争力的替代方案。在编码器-解码器结构中引入mamba增强了其通用性,使其既适合独立使用,也适合集成到扩散过程等高级架构中。对libero-il基准测试和真实机器人实验的广泛评估表明,mail不仅匹配而且超越了现有基线的性能,使其成为一种有前景的il任务方法。

B Transformer Architecture

我们描述了扩散策略中的两种基于变换器的架构:仅解码器模型(图9)和编码器-解码器模型(见图10)。这两种架构都利用了变压器模型的优势来有效地处理顺序数据并捕获长期依赖关系。

图9:仅解码器学习块。该架构集成了用于状态编码的resnet-18和用于地平线j动作的动作编码器,这两个组件都馈入了一个自我注意机制。位置编码(pe)和时间编码(te)增强了输入。最终,自我关注的输出被馈送到线性输出层,以预测未来的行为

图10:编码器-解码器学习块。此图说明了为策略学习设计的编码器-解码器转换器块的架构。在编码器中,状态使用resnet-18进行编码,通过时间编码(te)和位置编码(pe)进行增强,并通过自我注意进行处理。解码器然后利用对编码动作的自注意,并采用交叉注意来整合来自编码器的编码状态。最终,交叉注意力的输出被馈送到线性输出层,以预测未来的行动。

D Model Details

D.1 Parameter Comparison

我们还评估了配备rtx2060gpu的本地pc上的推理时间,使用32的批处理大小,以确保在表5中的相同条件下评估所有模型。

D.2 Training Details

我们在表6中列出了基于transformer和基于mamba的策略的训练超参数。为了确保公平比较,我们将两种策略的超参数调整到同一水平。

这些策略是使用libero提供的人类专家演示进行训练的,在主实验中,我们只对每个任务进行10次演示。

所有模型都在配备4个a100 gpu的集群上训练,批处理大小为256,使用3个不同的种子在50个迭代周期内进行训练。最后,我们计算了这3个种子的平均成功率。

E.3 Data Collection

遥操作用于收集所有真实机器人任务的数据,其中领导者机器人由人类控制,追随者机器人跟随领导者机器人,如图12所示。物体被放置在跟随机器人的前方,摄像头看不到领导机器人或人类。将引导机器人的当前关节状态作为期望的关节状态发送给跟随机器人。夹持器的状态被认为是二进制的,要么关闭,要么打开。为引导机器人的夹持器设置阈值;如果当前宽度低于阈值,跟随机器人的夹持器将关闭,否则将打开。

E.4 Evaluation

对于评估,使用模型的输出有时会激活机器人的安全机制,因为它违反了一定的约束。为了解决这个问题,在当前关节位置和模型输出之间生成轨迹。然后在每个时间步长将该轨迹的点提供给机器人,而不是模型的原始输出。该轨迹的长度取决于模型的输出与当前机器人状态的距离。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1924490.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

思维+构造,CF 1059C - Sequence Transformation

一、题目 1、题目描述 2、输入输出 2.1输入 2.2输出 3、原题链接 1059C - Sequence Transformation 二、解题报告 1、思路分析 n 1,2,3的情况从样例已知 考虑n > 4的情况 我们考虑要字典序最大,自然要最早出现非1的数,…

老物件线上3D回忆展拓宽了艺术作品的展示空间和时间-深圳华锐视点

在数字技术的浪潮下,3D线上画展为艺术家们开启了一个全新的展示与销售平台。这一创新形式不仅拓宽了艺术作品的展示空间,还为广大观众带来了前所未有的观赏体验。 3D线上画展制作以其独特的互动性,让艺术不再是单一的视觉享受。在这里&#x…

【香菇带你学Linux】Linux环境下gcc编译安装【建议收藏】

文章目录 0. 前言1. 安装前准备工作1.1 创建weihu用户1.2 安装依赖包1.2.1 安装 GMP1.2.2 安装MPFR1.2.3 安装MPC 2. gcc10.0.1版本安装3. 报错解决3. 1. wget下载报错 4. 参考文档 0. 前言 gcc(GNU Compiler Collection)是GNU项目的一部分,…

Leetcode-203-移除链表元素-临时变量作用域-c++

题目详见https://leetcode.cn/problems/remove-linked-list-elements/ 题解代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : val(x), next(nullpt…

动手学深度学习(Pytorch版)代码实践 -注意力机制-Transformer

68Transformer 1. PositionWiseFFN 基于位置的前馈网络 原理:这是一个应用于每个位置的前馈神经网络。它使用相同的多层感知机(MLP)对序列中的每个位置独立进行变换。作用:对输入序列的每个位置独立地进行非线性变换&#xff0c…

【Python】数据分析-Matplotlib绘图

数据分析 Jupyter Notebook Jupyter Notebook: 一款用于编程、文档、笔记和展示的软件。 启动命令: jupyter notebookMatplotlib 设置中文格式:plt.rcParams[font.sans-serif] [KaiTi] # 查看本地所有字体 import matplotlib.font_manager a sorted…

《昇思25天学习打卡营第17天|K近邻算法实现红酒聚类》

K近邻算法原理介绍 K近邻算法(K-Nearest-Neighbor, KNN)是一种用于分类和回归的非参数统计方法,最初由 Cover和Hart于1968年提出是机器学习最基础的算法之一。它正是基于以上思想:要确定一个样本的类别,可以计算它与所…

Linux-指令

希望你开心,希望你健康,希望你幸福,希望你点赞! 最后的最后,关注喵,关注喵,关注喵,大大会看到更多有趣的博客哦!!! 喵喵喵,你对我真的…

韦东山嵌入式linux系列-具体单板的 LED 驱动程序

笔者使用的是STM32MP157的板子 1 怎么写 LED 驱动程序? 详细步骤如下: ① 看原理图确定引脚,确定引脚输出什么电平才能点亮/熄灭 LED ② 看主芯片手册,确定寄存器操作方法:哪些寄存器?哪些位?…

链接追踪系列-00.es设置日志保存7天-番外篇

索引生命周期策略 ELK日志我们一般都是按天存储,例如索引名为"zipkin-span-2023-03-24",因为日志量所占的存储是非常大的,我们不能一直保存,而是要定期清理旧的,这里就以保留7天日志为例。 自动清理7天以前…

Pytorch中nn.Sequential()函数创建网络的几种方法

1. 创作灵感 在创建大型网络的时候,如果使用nn.Sequential()将几个有紧密联系的运算组成一个序列,可以使网络的结构更加清晰。 2.应用举例 为了记录nn.Sequential()的用法,搭建以下测试网络&…

node js 快速构建部署 Wiki 风格的文档网站

easy-wiki 快速构建 项目地址 :https://github.com/enncy/easy-wiki 教程文档 :https://enncy.github.io/easy-wiki/index.html 本文将介绍如何通过内置插件快速构建 WIKI 文档,并自带侧边栏,顶部栏,丰富样式等功能 #…

WEB前端03-CSS3基础

CSS3基础 1.CSS基本概念 CSS是Cascading Style Sheets(层叠样式表)的缩写,它是一种对Web文档添加样式的简单机制,是一种表现HTML或XML等文件外观样式的计算机语言,是一种网页排版和布局设计的技术。 CSS的特点 纯C…

maven的settings.xml无法正确配置本地仓库路径

因为以前使用过新版的maven,现在要换个版本使用。 在配置新的本地仓库路径的时候突然发现居然idea居然识别不了我settings.xml里面配置的路径。 我很是震惊,明明之前一直都是这样子配置的。怎么突然间不行了。当我冥思苦想,在网上搜寻资料无果…

02:项目二:感应开关盖垃圾桶

感应开关盖垃圾桶 1、PWM开发SG901.1、怎样通过C51单片机输出PWM波?1.2、通过定时器输出PWM波来控制SG90 2、超声波测距模块的使用3、感应开关盖垃圾桶 需要材料: 1、SG90舵机模块 2、HC-SR04超声波模块 3、震动传感器 4、蜂鸣器 5、若干杜邦线 1、PWM开…

【深度学习 pytorch】迁移学习 (迁移ResNet18)

李宏毅深度学习笔记 《深度学习原理Pytorch实战》 https://blog.csdn.net/peter6768/article/details/135712687 迁移学习 实际应用中很多任务的数据的标注成本很高,无法获得充足的训练数据,这种情况可以使用迁移学习(transfer learning)。假设A、B是两…

第三期闯关基础岛

1、 Linux 基础知识 任务描述完成所需时间闯关任务完成SSH连接与端口映射并运行hello_world.py10min可选任务 1将Linux基础命令在开发机上完成一遍10min可选任务 2使用 VSCODE 远程连接开发机并创建一个conda环境10min可选任务 3创建并运行test.sh文件10min 1.1、SSH连接 使用…

Android Spinner

1. Spinner Spinner是下拉列表,如图3-14所示,通常用于为用户提供选择输入。Spinner有一个重要的属性:spinnerMode,它有2种情况: 属性值为dropdown时,表示Spinner的数据下拉展示,如图1&#xf…

自己动手写一个滑动验证码组件(后端为Spring Boot项目)

近期参加的项目,主管丢给我一个任务,说要支持滑动验证码。我身为50岁的软件攻城狮,当时正背着双手,好像一个受训的保安似的,中规中矩地参加每日站会,心想滑动验证码在今时今日已经是标配了,司空…

jenkins系列-06.harbor

https://github.com/goharbor/harbor/releases?page2 https://github.com/goharbor/harbor/releases/download/v2.3.4/harbor-offline-installer-v2.3.4.tgz harbor官网:https://goharbor.io/ 点击 Download now 链接,会自动跳转到上述github页面&am…