【AAAI 2024】解锁深度表格学习(Deep Tabular Learning)的关键:算术特征交互

news2025/1/11 16:57:15

近日,阿里云人工智能平台PAI与浙江大学吴健、应豪超老师团队合作论文《Arithmetic Feature Interaction is Necessary for Deep Tabular Learning》正式在国际人工智能顶会AAAI-2024上发表。本项工作聚焦于深度表格学习中的一个核心问题:在处理结构化表格数据(tabular data)时,深度模型是否拥有有效的归纳偏差(inductive bias)。我们提出算术特征交互(arithmetic feature interaction)对深度表格学习是至关重要的假设,并通过创建合成数据集以及设计实现一种支持上述交互的AMFormer架构(一种修改的Transformer架构)来验证这一假设。实验结果表明,AMFormer在合成数据集表现出显著更优的细粒度表格数据建模、训练样本效率和泛化能力,并在真实数据的对比上超过一众基准方法,成为深度表格学习新的SOTA(state-of-the-art)模型。

背景

图1:结构化表格数据示例,引用自[Borisov et al.]
图1:结构化表格数据示例,引用自[Borisov et al.]

结构化表格数据——这些数据往往以表(Table)的形式存储于数据库或数仓中——作为一种在金融、市场营销、医学科学和推荐系统等多个领域广泛使用的重要数据格式,其分析一直是机器学习研究的热点。表格数据(图1)通常同时包含数值型(numerical)特征和类目型(categorical)特征,并往往伴随有特征缺失、噪声、类别不平衡(class imblanance)等数据质量问题,且缺少时序性、局部性等有效的先验归纳偏差,极大地带来了分析上的挑战。传统的树集成模型(如,XGBoost、LightGBM、CatBoost)因在处理数据质量问题上的鲁棒性,依然是工业界实际建模的主流选择,但其效果很大程度依赖于特征工程产出的原始特征质量。

随着深度学习的流行,研究者试图引入深度学习端到端建模,从而减少在处理表格数据时对特征工程的依赖。相关的研究工作至少可以可以分成四大类:(1)在传统建模方法中叠加深度学习模块(通常是多层感知机MLP),如Wide&Deep、DeepFMs;(2)形状函数(shape function)采用深度学习建模的广义加性模型(generalized additive model),如 NAM、NBM、SIAN;(3)树结构启发的深度模型,如NODE、Net-DNF;(4)基于Transformer架构的模型,如AutoInt、DCAP、FT-Transformer。尽管如此,深度学习在表格数据上相比树模型的提升并不显著且持续,其有效性仍然存在疑问,表格数据因此被视为深度学习尚未征服的最后堡垒。

算术特征交互在深度表格学习的“必要性”

我们认为现有的深度表格学习方法效果不尽如人意的关键症结在于没有找到有效的建模归纳偏差,并进一步提出算术特征交互对深度表格学习是至关重要的假设。本节介绍我们通过创建一个合成数据集,并对比引入算数特征交互前后的模型效果,来验证该假设。

合成数据集的构造方法如下:我们设计了一个包含八个特征(

图片

)的合成数据集。

图片

图2:合成数据集上的结果对比。图中+x%表示AMFormer相比Transformer的相对提升。
标图2:合成数据集上的结果对比。图中+x%表示AMFormer相比Transformer的相对提升。

在上述数据中,我们将引入了算数特征交互的AMFormer架构与经典的XGBoost和Transformer架构对比。实验结果显示:

以上结果共同证实了算术特征交互在深度表格学习中的显著意义。

算法架构

图片
图3:AMFormer架构,其中L表示模型层数。

本节介绍AMFormer架构(图3),并重点介绍算数特征交互的引入。AMFormer架构借鉴了经典的Transformer框架,并引入了Arithmetic Block来增强模型的算术特征交互能力。在AMFormer中,我们首先将原始特征转换为具有代表性的嵌入向量,对于数值特征,我们使用一个1输入d输出的线性层;对于类别特征,则使用一个d维的嵌入查询表。之后,这些初始嵌入通过L个顺序层进行处理,这些层增强了嵌入向量中的上下文和交互元素。每一层中的算术模块采用了并行的加法和乘法注意力机制,以刻意促进算术特征之间的交互。为了促进梯度流动和增强特征表示,我们保留了残差连接和前馈网络。最终,依据这些丰富的嵌入向量,AMFormer使用分类或回归头部生成最终输出。

算术模块的关键组件包括并行注意力机制和提示标记。为了补偿需要算术特征交互的特征,我们在AMFormer中配置了并行注意力机制,这些机制负责提取有意义的加法和乘法交互候选者。这些交互候选随着会沿着候选维度被串联(concatenate)起来,并通过一个下采样的线性层进行融合,使得AMFormer的每一层都能有效捕捉算术特征交互,即特征上的四则算法运算。为了防止由特征冗余引起的过拟合并提升模型在超大规模特征数据集上的伸缩,我们放弃了原始Transformer架构中平方复杂度的自注意力机制,而是使用两组提示向量(prompt token vectors)作为加法和乘法查询。这种方法为AMFormer提供了有限的特征交互自由度,并且作为一个附带效果,优化了内存占用和训练效率。

以上是AMFormer在架构层引入的主要创新,关于模型更详细的实现细节可以参考原文以及我们的开源实现。

进一步实验结果

图片
表1:真实数据集统计以及评估指标。标题

为了进一步展示AMFormer的效果,我们挑选了四个真实数据集进行实验。被挑选数据集覆盖了二分类、多分类以及回归任务,数据集统计如表1所示。

图片
标表2:AMFormer以及基准方法的性能对比,其中括号内的数字表示该方法在当前数据集上表现的排名,最优以及次优的结果分别以加粗以及下划线突出。题

我们一共测试了包含传统树模型(XGBoost)、树架构深度学习方法(NODE)、高阶特征交互(DCN-V2、DCAP)以及Transformer派生架构(AutoInt、FT-Trans)在内的六个基准算法以及两个AMFormer实现(分别选择AutoInt、FT-Trans做基础架构,即AMF-A和AMF-F),结果汇总在表2中。

在一系列对比实验中,AMFormer表现更突出。结果显示,基于MLP的深度学习方法如DCN-V2在表格数据上的性能不尽如人意,而基于Transformer架构的模型显示出更大的潜力,但未能始终超过树模型XGBoost。我们的AMFormer在四个不同的数据集上,与所有六个基准模型相比,表现一致更优:在分类任务中,它将AutoInt和FT-transformer的准确率或AUC提升至少0.5%,最高达到1.23%(EP)和4.96%(CO);在回归任务中,它也显著减少了平均平方误差。相比其它深度表格学习方法,AMFormer具有更好的鲁棒和稳定性,这使得在性能排序中AMFormer断层式优于其它基准算法,这些实验结果充分证明了AMFormer在深度表格学习中的必要性和优越性。

结论

本工作研究了深度模型在表格数据上的有效归纳偏置。我们提出,算术特征交互对于表格深度学习是必要的,并将这一理念融入Transformer架构中,创建了AMFormer。我们在合成数据和真实世界数据上验证了AMFormer的有效性。合成数据的结果展示了其在精细表格数据建模、训练数据效率以及泛化方面的优越能力。此外,对真实世界数据的广泛实验进一步确认了其一致的有效性。因此,我们相信AMFormer为深度表格学习设定了强有力的归纳偏置。

进一步阅读:

● 论文标题:

Arithmetic Feature Interaction is Necessary for Deep Tabular Learning

● 论文作者:

程奕、胡仁君、应豪超、施兴、吴健、林伟

● 论文PDF链接:

https://arxiv.org/abs/2402.02334

● 代码链接:

https://github.com/aigc-apps/AMFormer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1519981.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【网络安全】 MSF提权

本文章仅用于信息安全学习,请遵守相关法律法规,严禁用于非法途径。若读者因此作出任何危害网络安全的行为,后果自负,与作者无关。 环境准备: 名称系统位数IP攻击机Kali Linux6410.3.0.231客户端Windows 76410.3.0.234…

基于GA优化的CNN-GRU-Attention的时间序列回归预测matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1卷积神经网络(CNN)在时间序列中的应用 4.2 长短时记忆网络(LSTM)处理序列依赖关系 4.3 注意力机制(Attention) 4…

CSS之字体镂空

方法一(有缺陷)&#xff1a; <!DOCTYPE html> <html> <head> <meta charset"utf-8"> <title>Examples</title> <style> .num1{-webkit-text-stroke: 0.4px red; }</style> </head> <body><div clas…

神策分析 Copilot 成功通过网信办算法备案,数据分析 AI 化全面落地

近日&#xff0c;神策数据严格遵循《互联网信息服务深度合成管理规定》&#xff0c;已完成智能数据问答算法备案。该算法基于大模型技术&#xff0c;专注于为客户提供数据指标查询和数据洞察方面的专业回答。 神策分析 Copilot 运用神策数据智能数据问答算法&#xff0c;聚焦分…

Spring Web MVC 入门使用

1. 什么是Spring Web MVC Spring Web MVC是基于Servlet API 构建的原始Web框架&#xff0c;从一开始就包含在Spring框架中。 Servlet 是一套Java Web 开发的规范&#xff0c;或者说是一套Java Web 开发的技术标准。只有规范并不能做任何事情&#xff0c;必须要有人去实现它&a…

ZooKeeper命令和监控详解

ZooKeeper监控命令详解 在分布式系统中&#xff0c;ZooKeeper作为一个非常重要的协调服务&#xff0c;它的健康状态直接影响到整个系统的可靠性和稳定性。因此&#xff0c;对ZooKeeper进行有效监控是非常必要的。本文将详细介绍ZooKeeper提供的命令行工具zkCli.sh&#xff0c;…

Prometheus 轻量化部署和使用

文章目录 说明Prometheus简介Grafana简介prometheus和Grafana的关系环境准备&#xff08;docker&#xff09;docker安装时间时区问题&#xff08;我的代码中&#xff09;dockers镜像加速和服务器时区设置 数据库准备(mysql、redis)mysql配置redis配置 Prometheus、grafana下载和…

Linux远程连接本地数据库(docker)

1. 安装docker 参考上一篇文章 CentOS安装Docker 2. Linux中安装Mysql 2.1 docker拉取mysql镜像 拉取镜像 docker pull mysql查看镜像列表 docker images2.2 运行mysql容器 运行一个名字为mysql的mysql容器&#xff0c;其连接端口号为3306&#xff0c;密码为123456 docker r…

H266开源视频编码器VVENC现状

VVenC 是由 Fraunhofer HHI 研究团队开发的&#xff0c;主要是视频编码系统组。HHI 是欧洲最大的研究组织 Fraunhofer 协会的成员&#xff0c;该协会是德国的一个大型非营利性组织。源代码在&#xff1a; https://github.com/fraunhoferhhi/vvenc VVenC几乎与H.266视频标准同时…

React18 后台管理模板项目:现代、高效与灵活

&#x1f389; 给大家推荐一款React18TypescriptVitezustandAntdunocss且超级好用的中后台管理框架 项目地址 码云&#xff1a;https://gitee.com/nideweixiaonuannuande/xt-admin-react18github&#xff1a;https://github.com/1245488569/xt-admin-react18 演示地址 http…

AI人工智能培训讲师ChatGPT讲师叶梓培训简历及提纲ChatGPT等AI技术在医疗领域的应用

叶梓&#xff0c;上海交通大学计算机专业博士毕业&#xff0c;高级工程师。主研方向&#xff1a;数据挖掘、机器学习、人工智能。历任国内知名上市IT企业的AI技术总监、资深技术专家&#xff0c;市级行业大数据平台技术负责人。 长期负责城市信息化智能平台的建设工作&#xff…

在react中使用tailwindcss

安装tailwind css npm i -D tailwindcssnpm:tailwindcss/postcss7-compat postcss^7 autoprefixer^9安装 CRACO 由于 Create React App 不能让您覆盖原生的 PostCSS 配置&#xff0c;所以我们还需要安装 CRACO 才能配置 Tailwind。 npm install craco/craco配置CRACO 在项目根…

uni app 打肉肉(打飞机)小游戏

都给老婆和孩子写了 合十 钓鱼了&#xff0c;给自己写个打飞机吧。没找飞机怪兽的图片。就用馒头和肉肉代替了。有问题不要私信我。自己改哈 <template><view class"page_main"><view class"contentone"><canvas class"canvas…

吴恩达机器学习笔记 二十一 迁移学习 预训练

迁移学习&#xff08;transfer learning&#xff09;&#xff1a;直接把神经网络拿来&#xff0c;前面的参数可以直接用&#xff0c;把最后一层改了。 两种训练参数的方式&#xff1a; 1.只训练输出层的参数 2.训练所有参数 当只有一个小数据集的时候&#xff0c;第一种方法…

uniapp小程序:使用uni.getLocation通过腾讯地图获取相关地址信息详情(超详细)

先看运行结果&#xff1a; 流程&#xff1a; 1、在edge网页搜索腾讯位置服务 搜索后点击这里 已经有账号的就进行登录&#xff0c;没有账号的进行注册即可 点击控制台&#xff1a; 进去后点击成员管理---->我的应用---->创建应用 输入相应的参数应用名称&#xff08;随便…

Docker:常用命令

文章目录 docker作用常用指令 docker 作用 Docker 是一种容器化平台&#xff0c;可以让开发者打包应用程序及其依赖项&#xff0c;并以容器的形式进行发布、交付和运行。 Docker 的一些主要作用&#xff1a; 应用程序隔离&#xff1a;Docker 使用容器技术&#xff0c;将应用程…

NCP1271D65R2G中文资料规格书PDF数据手册引脚图参数图片价格功能特性描述

产品描述&#xff1a; NCP1271 是成功的 7 引脚电流模式 NCP12XX 系列的新一代引脚-引脚兼容新产品。该控制器通过使用可调节 Soft Skip 模式和集成的高电压启动 FET&#xff0c;实现了卓越的待机功耗。此专属 Soft Skip 还大大降低了噪音的风险。 因此可以在箝位网络中使用不…

音频提取:分享几个常用方法,简单好用

有时候我们会在视频中发现一首非常好听的歌曲&#xff0c;但是我们并不需要视频本身。 这时&#xff0c;我们可以提取视频中的音频&#xff0c;将其转化为音频文件&#xff0c;然后在任何时间、任何地点进行欣赏。 下面给大家分享几个提取视频中音频的几个方法&#xff0c;供…

vue/uniapp路由history模式下宝塔空间链接打开新窗口显示404解决方法

vue/uniapp路由history模式下宝塔空间链接打开新窗口显示404&#xff0c;或者域名后带路径参数刷新就报404 解决方法&#xff1a; 宝塔中站点配置修改&#xff1a;【配置文件】中添加下面代码&#xff0c;具体如图&#xff1a; location / {try_files $uri $uri/ /index.html…

【Go语言】Go语言中的函数

Go语言中的函数 Go语言中&#xff0c;函数主要有三种类型&#xff1a; 普通函数 匿名函数&#xff08;闭包&#xff09; 类方法 1 函数定义 Go语言函数的基本组成包括&#xff1a;关键字func、函数名、参数列表、返回值、函数体和返回语句。Go语言是强类型语言&#xff0…