TabR:检索增强能否让深度学习在表格数据上超过梯度增强模型?

news2024/9/23 19:25:27

这是一篇7月新发布的论文,他提出了使用自然语言处理的检索增强Retrieval Augmented技术,目的是让深度学习在表格数据上超过梯度增强模型。

检索增强一直是NLP中研究的一个方向,但是引入了检索增强的表格深度学习模型在当前实现与非基于检索的模型相比几乎没有改进。所以论文作者提出了一个新的TabR模型,模型通过增加一个类似注意力的检索组件来改进现有模型。据说,这种注意力机制的细节可以显著提高表格数据任务的性能。TabR模型在表格数据上的平均性能优于其他DL模型,在几个数据集上设置了新的标准,在某些情况下甚至超过了GBDT模型,特别是在通常被视为GBDT友好的数据集上。

TabR

表格数据集通常被表示为特征和标签对{(xi, yi)},其中xi和yi分别是第i个对象的特征和标签。一般有三种类型的主要任务:二元分类、多类分类和回归。

对于表格数据我们会将数据集分为训练部分、验证部分和测试部分,模型对“输入”或“目标”对象进行预测。当使用检索技术时,检索是在一组“上下文候选”或“候选”中完成的,被检索的对象称为“上下文对象”或简称为“上下文”。同一组候选对象用于所有输入对象。

论文的实验设置涉及调优和评估协议,其中需要超参数调优和基于验证集性能的早期停止。然后在15个随机种子的平均测试集上测试最佳超参数,并在算法比较中考虑标准偏差。

论文作者的目标是将检索功能集成到传统的前馈网络中。该过程包括通过编码器传递目标对象及其上下文候选者,然后检索组件会对目标对象进行的表示,最后预测器进行预测。

编码器和预测器模块很简单简单,因为它们不是工作的重点。检索模块对目标对象的表示以及候选对象的表示和标签进行操作。这个模块可以看作是注意力机制的一般化版本。

这个过程包括几个步骤:

  • 如果编码器包含至少一个块,则将表示进行规范化;
  • 根据与目标对象的相似性定义上下文对象;
  • 基于softmax函数对上下文对象的相似性分配权重;
  • 定义上下文对象的值;
  • 使用值和权重输出加权聚合。

上下文大小设置为一个较大的值96,softmax函数会自动选择有效的上下文大小。

检索模块是最重要的部分

作者探讨了检索模块的不同实现,特别是相似度模块和值模块。并且说明了是通过一下几个步骤得到最终的模型。

1、作者评估了传统注意力的相似性和值模块,发现该配置与多层感知器(MLP)相似,因此不能证明使用检索组件是合理的。

2、然后他们将上下文标签添加到值模块中,但发现这并没有改进,这表明传统注意力的相似性模块可能是瓶颈。

3、为了改进相似度模块,作者删除了查询的概念,并用L2距离替换点积。这种调整使得几个数据集上性能的显著跃升。

4、值模块也进行改进,灵感来自最近提出的DNNR(用于回归问题的kNN算法的广义版本)。新的值模块带来了进一步的性能改进。

5、最后,作者创建模型TabR。在相似性模块中省略缩放项,不包括目标对象在其自身的上下文中(使用交叉注意),平均而言会得到更好的结果。

生成的TabR模型为基于检索的表格深度学习问题提供了一种健壮的方法。

作者也强调了TabR模型的两个主要局限性:

与所有检索增强模型一样,从应用程序的角度来看,使用真实的训练对象进行预测可能会带来一些问题,例如隐私和道德问题。

TabR的检索组件虽然比以前的工作更有效,但会产生明显的开销。所以它可能无法有效地扩展以处理真正的大型数据集。

实验结果

作者将TabR与现有的检索增强解决方案和最先进的参数模型进行比较。除了完全配置的TabR,他们还使用了一个简化版本,TabR- s,它不使用特征嵌入,只有一个线性编码器和一个块预测器。

与全参数深度学习模型的比较表明,TabR在几个数据集上优于大多数模型,除了MI数据集,在其他数据集也很有竞争力。在许多数据集上,它比多层感知器(MLP)提供了显著的提升。

与GBDT模型相比,调整后的TabR在几个数据集上也有明显的改进,并且在其他数据集上保持竞争力(除了MI数据集),并且TabR的平均表现也优于GBDT模型。

总之,TabR将自己确立为表格数据问题的强大深度学习解决方案,展示了强大的平均性能,并在几个数据集上设置了新的基准。它的基于检索的方法具有良好的潜力,并且在某些数据集上可以明显优于梯度增强的决策树。

一些研究

1、冻结上下文以更快地训练TabR

在TabR的原始实现中,由于需要对所有候选对象进行编码并计算每个训练批次的相似度,因此在大型数据集上的训练可能很慢。作者提到在完整的“Weather prediction”数据集上训练一个TabR需要18个多小时,该数据集有300多万个对象。

作者注意到在训练过程中,平均训练对象的上下文(即,根据相似度模块S,前m个候选对象及其分布)趋于稳定,这为优化提供了机会。在一定数量的epoch之后,他们提出了一个“上下文冻结”,即最后一次计算所有训练对象的最新上下文,然后在其余的训练中重用。

这种简单的技术可以加速TabR的训练,并且不会在指标上造成重大损失。在上面提到的完整的“Weather prediction”数据集上,它使速度提高了近7倍(将训练时间从18小时9分钟减少到3小时15分钟),同时仍然保持有竞争力的均方根误差(RMSE)值。

2、用新的训练数据更新TabR不需要再训练(初步探索)

在现实世界的场景中,在机器学习模型已经训练完之后,通常会收到新的、看不见的训练数据。作者测试了TabR在不需要再训练的情况下合并新数据的能力,方法是将新数据添加到候选检索集中。

他们使用完整的“Weather prediction”数据集进行了这个测试。结果表明在线更新可以有效地将新数据整合到训练好的TabR模型中。这种方法可以通过在数据子集上训练模型并从完整数据集中检索模型来将TabR扩展到更大的数据集。

3、使用检索组件增强XGBoost

作者试图通过结合类似于TabR中的检索组件来提高XGBoost的性能。这种方法涉及在原始特征空间中找到与给定输入对象最接近的96个训练对象(匹配TabR的上下文大小)。然后对这些最近邻的特征和标签进行平均,将标签按原样用于回归任务,并将其转换为用于分类任务的单一编码。

将这些平均数据与目标对象的特征和标签连接起来,形成XGBoost的新输入向量。但是该策略并没有显著提高XGBoost的性能。试图改变邻居的数量也没有产生任何显著的改善。

总结

深度学习模型在表格类数据上一直没有超越梯度增强模型,TabR还在这个方向继续努力。

如果你对他感兴趣,一下是论文和源代码:

https://avoid.overfit.cn/post/9e8cc5f506af4b368516876e108a62c7

作者:Andrew Lukyanenko

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/834077.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构——单链表OJ题(第二弹)

单链表OJ题 前言一、返回链表开始入环的第一个节点思路一思路二 二、返回链表的深度拷贝总结 前言 此次练习题有两道! 有点小难度,但相信难不住大家的! 我也会给出两道OJ题的链接,大家也赶快去试一试吧 一、返回链表开始入环的第…

【unity】Pico VR 开发笔记(视角移动)

【unity】Pico VR 开发笔记(视角移动) 视角移动是简单的基础功能,这里区别于头显定位获得的小范围位移,是长距离不影响安全边界的位移方式。的常见的位移方式有两种,其一是触发后瞬间传送到指定位置,其次是…

【雕爷学编程】MicroPython动手做(32)——物联网之MQTT

MQTT (Message Queuing Telemetry Transport)消息队列遥测传输协议,是一种基于发布/订阅(publish/subscribe)模式的"轻量级"通讯协议,该协议构建于TCP/IP协议上,由IBM在1999年发布。M…

il汇编整数相加

在这里尝试了IL汇编字符串连接; IL汇编字符串连接_bcbobo21cn的博客-CSDN博客 下面来看一下IL汇编整数相加; 大概的看一下一些资料,下面语句, ldc.i4 20 ldc.i4 30 add 看上去像是,装载整数20到一个类似于…

《Web安全基础》03. SQL 注入

web 1:简要 SQL 注入2:MySQL 注入2.1:信息获取2.2:跨库攻击2.3:文件读写2.4:常见防护 3:注入方法3.1:类型方法明确3.2:盲注3.3:编码3.4:二次注入3…

【前瞻】视频技术的发展趋势讨论以及应用场景

视频技术的发展可以追溯到19世纪初期的早期实验。到20世纪初期,电视技术的发明和普及促进了视频技术的进一步发展。 1)数字化:数字化技术的发明和发展使得视频技术更加先进。数字电视信号具有更高的清晰度和更大的带宽,可以更快地…

Knife4j系列--解决下载文件乱码的问题

原文网址:Knife4j系列--解决下载文件乱码的问题_IT利刃出鞘的博客-CSDN博客 简介 说明 本文介绍Knife4j如何解决下载文件不显示下载按钮,返回的是乱码的问题。 相关网址 文件下载始终是Knife4j.txt Issue #I374SP 萧明/knife4j - Gitee.com 注意…

Webpack怎么使用?

Webpack 使用 前几篇文章中已经介绍了如何初始化包管理器 package.json 这里不再重复介绍,如有需要请查看 搭建工程化项目。 安装 :::warning 注意 请确保你已经安装了 yarn,如有需要请查看 搭建工程化开发环境。 ::: 通过命令 yarn add webpack web…

[腾讯云Cloud Studio实战训练营]无门槛使用GPT+Cloud Studio辅助编程完成Excel自动工资结算

目录 前言一、Cloud Studio产品介绍1.1 注册Cloud Studio 二、项目实验2.1 选择合适的开发环境2.2 实验项目介绍2.3 实验步骤三、总结 前言 chatgpt简单介绍: ChatGPT是一种基于GPT的自然语言处理模型,专门用于生成对话式文本。它是OpenAI于2021年发布的&#xff0…

2023华数杯数学建模C题

母亲身心健康对婴儿成长的影响 母亲是婴儿生命中最重要的人之一,她不仅为婴儿提供营养物质和身体保护, 还为婴儿提供情感支持和安全感。母亲心理健康状态的不良状况,如抑郁、焦虑、 压力等,可能会对婴儿的认知、情感、社会行为等方…

【万字长文】SpringBoot整合Atomikos实现多数据源分布式事务(提供Gitee源码)

前言:在最近的实际开发的过程中,遇到了在多数据源的情况下要保证原子性的问题,这个问题当时遇到了也是思考了一段时间,后来通过搜集大量资料与学习,最后是采用了分布式事务来解决这个问题,在讲解之前&#…

pytorch的CrossEntropyLoss交叉熵损失函数默认是平均值

pytorch中使用nn.CrossEntropyLoss()创建出来的交叉熵损失函数计算损失默认是求平均值的,即多个样本输入后获取的是一个均值标量,而不是样本大小的向量。 net nn.Linear(4, 2) loss nn.CrossEntropyLoss() X torch.rand(10, 4) y torch.ones(10, dt…

最新ChatGPT商用网站源码+支持ai绘画+GPT4.0+Prompt角色+MJ以图生图+思维导图生成!

使用Nestjs和Vue3框架技术,持续集成AI能力到系统! 同步mj图片重新生成指令 同步 Vary 指令 单张图片对比加强 Vary(Strong) | Vary(Subtle) 同步 Zoom 指令 单张图片无限缩放 Zoom out 2x | Zoom out 1.5x 新增GPT联网提问功能、签到功能 支持微信环境静…

Jenkins Gerrit Trigger实践

1.创建Gerrit Trigger 2.jenkins master节点生成gerrit用户的密钥 这里的用户名得写登录gerrit后个人信息中的 Username 3.gerrit 配置刚刚jenkins生成密钥的公钥 4.gerrit 用户加入群组 不加这个群组,下一步测试就会报错“User aeshare has no capability conn…

阿里云ssl免费数字证书快过期 如何更换

1.登陆阿里云 找到ssl 查看快过期的证书 数字证书管理服务-ssl证书 2.创建免费的证书,对应过期证书的域名 3.下载新证书 pem key放在本地 此处记录本地的下载路径 /Users/dorsey/Downloads/10791167_lzzabc.cn_nginx/lzzabc.cn.pem /Users/dorsey/Downloads/1…

sheetJs / xlsx-js-style 纯前端实现导出 excel 表格及自定义单元格样式

文章目录 一、安装二、创建基础工作表三、设置单元格宽度/高度/隐藏单元格四、分配数字格式五、超链接六、单元格注释七、公式八、合并单元格九、自定义单元格样式十、项目地址 一、安装 xlsx 地址:https://www.npmjs.com/package/xlsxSheetJs 地址:htt…

中断管理

其实,关于中断的概念和定义在之前已经反复学习和实践了,这节主要讲和FreeRTOS相关的中断知识。 中断优先级 任何中断的优先级都大于任务! 在我们的操作系统,中断同样是具有优先级的,并且我们也可以设置它的优先级&a…

Quic协议 0-RTT

目录 1、Quic协议 2、Quic直接通过TLS握手进行建立链接,TLS是1.3版本 3.1、通过缓存服务器公钥实现0-RTT,服务器 通过kdf密钥派生机制,来产生会话加密key,保证数据向前安全性 3.2、通过PKN来实现数据重传保证数据完整性&…

Yalmip入门教程(4)-约束条件定义的相关函数

博客中所有内容均来源于自己学习过程中积累的经验以及对yalmip官方文档的翻译:https://yalmip.github.io/tutorials/ 这篇博客将详细介绍yalmip工具箱中约束条件定义等相关函数的用法。 1.约束条件定义的相关函数 1.1 alldifferent函数 alldifferent函数用于表示…

计算机毕设 深度学习疫情社交安全距离检测算法 - python opencv cnn

文章目录 0 前言1 课题背景2 实现效果3 相关技术3.1 YOLOV43.2 基于 DeepSort 算法的行人跟踪 4 最后 0 前言 🔥 这两年开始毕业设计和毕业答辩的要求和难度不断提升,传统的毕设题目缺少创新和亮点,往往达不到毕业答辩的要求,这两…