华为FinalMLP

news2024/12/24 9:35:00

FinalMLP:An Enhanced Two-Stream MLP model for CTR Prediction

摘要

Two-Stream model:因为一个普通的MLP网络不足以学到丰富的特征交叉信息,因此大家提出了很多实用MLP和其他专用网络结合来学习。
MLP是隐式地学习特征交叉,当前很多工作主要在另外一个stream中显式的增强特征交叉。本文提出的两个stream都用MLP网络,训练的好一样能达到惊人的效果。而且提出的可插拔式使用的特征选择层和交叉融合层,可以得到性能更强的two-stream MLP模型。

简介

单个MLP网络很难学到丰富的特征交叉信息。很多模型结构提出了是为了学习显式的特征交叉,像FM、CIN、AFN,虽然这些模型能很好学习到一阶、二阶、三阶等交叉特征,但是没法像MLP网络学到深层次交叉信息,因此很多two-stream的模型提出了,结合MLP网络和显式交叉网络,结合两者优点,像Wide&Deep、DeepFM、DCN、xDeepFM、AutoInt+,这些two-stream模型中,MLP网络学习隐式的特征交叉,另外一个stream学习显式的特征交叉。

很多two-stream模型都验证了对于单个MLP网络的效果,但是没有对比过结合两个MLP网络的two-stream模型(称为DualMLP),本文就做了对比,尽管DualMLP结构很简单,但是效果惊人。

two-stream模型可以视作两个并行网络的集成,每个stream可以从不同视角学到特征交叉的信息。比如Wide&Deep、DeepFM,一个stream去学习低阶的特征交叉,另外一个stream学习高阶的特征交叉;DCN、AutoInt+一个stream去学习显式的特征交叉,另外一个stream学习隐式的特征交叉;xDeepFM进一步从vector-wise和bit-wise视角学习特征交叉。这些都验证了两个stream中网络的差异对效果有重要影响。

本文的two-stream中两个stream都是MLP网络,差异性在于网络的层数和隐层单元数,实验发现可以实现更好的效果。同时将DualMLP作为base,在此基础上面增大两个stream的差异性,可以进一步提升DualMLP的效果。当前的two-stream模型在结合两个stream的时候通过sum或者concat,这个简单操作可能浪费了更高水平交叉(stream-level)的机会。

FinalMLP:intergrates Feature selection and interaction aggregation layers on top of two MLP module networks。即结合了特征选择层和交叉融合层的双流MLP网络,特征选择层是通过gate网络得到特征重要性进行soft特征选择,每个stream通过选择不同重要度的特征,增大各个stream的差异性。交叉融合层则是提出了一个二阶的双线性融合融合,同时为了减低计算复杂度,将计算分成k个组,也就是多头双线性融合。

背景及相关工作

Framework of Two-Stream CTR Models

框架图
在这里插入图片描述

特征Embedding

高维稀疏到稠密的表示

特征选择

可选的层,本文提出的是软选择,通过特征的重要性权重选择

特征交叉

通过两个不同的并行的网络进行交叉

两个网络的融合(Stream-level Fusion)

假设最后预估的概率为 y ^ \hat y y^ o 1 \mathbf o_1 o1 o 2 \mathbf o_2 o2是两个stream的输出表示, F \mathcal{F} F表示融合操作,通常是sum或者concat。 w w w表示将输出映射成一维的线性函数。
y ^ = σ ( w T F ( o 1 , o 2 ) ) \hat y = \sigma (w^T \mathcal{F} (\mathbf o_1, \mathbf o_2)) y^=σ(wTF(o1,o2))

代表性的Two-Stream CTR Models

Wide&Deep:一个线性网络(line stream)和一个MLP网络(deep stream)
DeepFM:在wide侧用FM替换,二阶显式交叉
DCN:一个cross网络做高阶显式交叉,另外一个stream是MLP做隐式交叉
xDeepFM:使用CIN通过vector-wise方式高阶交叉,另外一个stream通过bit-wise方式隐式交叉
AutoInt+:使用自注意力网络学习高阶交叉,融合AutoInt和MLP作为two-stream
AFN+:融合AFN和MLP作为two-stream
DeepIM:一个交互机器组件IM(interaction machine module)学习高阶特征交叉,融合IM和MLP作为two-stream
MaskNet:使用两个MaskNet作为two-stream
DCN-V2:通过一个更具表现力的cross网络来做显式特征交叉,使用cross网络和MLP作为two-stream
EDCN:并不是严格的two-stream模型,提出的一个桥接模块,桥接两个stream隐层的,这个操作限制每个stream的隐层必须有相同的层数和神经单元数,降低了灵活性

Two-Stream MLP Model

本文提出的两个stream都是MLP,称为DualMLP,两个MLP网络(隐层数及unit数不同)表示如下
o 1 = M L P 1 ( h 1 ) \mathbf o_1 = MLP_1(\mathbf h_1) o1=MLP1(h1)
o 2 = M L P 1 ( h 2 ) \mathbf o_2 = MLP_1(\mathbf h_2) o2=MLP1(h2)

Stream-Specific Feature Selection

受MMoE启发,每个stream根据gate网络对特征进行差异化选择,特征选择层定义如下
g 1 = G a t e 1 ( x 1 ) , g 2 = G a t e 1 ( x 2 ) \mathbf g_1 = Gate_1(\mathbf x1), \mathbf g_2 = Gate_1(\mathbf x2) g1=Gate1(x1),g2=Gate1(x2)
h 1 = 2 σ ( g 1 ) ⊙ e , h 2 = 2 σ ( g 2 ) ⊙ e \mathbf h_1 = 2 \sigma(\mathbf g_1) \odot \mathbf e, \mathbf h_2 = 2 \sigma(\mathbf g_2) \odot \mathbf e h1=2σ(g1)eh2=2σ(g2)e
这里 G a t e i Gate_i Gatei表示stream中MLP基于的门控网络,是以选择的特征集 x i \mathbf x_i xi作为输入,两个stream的输入可以是不同的特征子集。输出是各个特征的权重 g i \mathbf g_i gi,这里乘以2主要是为了权重均值为1。

下面有个示例图,输入分别是user、item特征集
在这里插入图片描述

交叉融合Stream-Level Interaction Aggregation

Bilinear Fusion

当前都是sum或者concat融合,借鉴在CV领域广泛使用的双线性pooling,提出双线性交叉融合层,去融合两个stream的输出,表示如下
y ^ = σ ( b + w 1 T o 1 + w 2 T o 2 + o 1 T W 3 o 2 ) \hat y = \sigma (b + \mathbf w_1^T \mathbf o_1 + \mathbf w_2^T \mathbf o_2 + \mathbf o_1^T \mathbf W_3 \mathbf o_2) y^=σ(b+w1To1+w2To2+o1TW3o2)
其中, b ∈ R , w 1 ∈ R d 1 × 1 , w 2 ∈ R d 2 × 1 W 3 ∈ R d 1 × d 2 b\in R,\mathbf w_1 \in R^{d_1 \times 1}, \mathbf w_2 \in R^{d_2 \times 1} \mathbf W_3 \in R^{d_1 \times d_2} bR,w1Rd1×1,w2Rd2×1W3Rd1×d2,这里 d 1 d_1 d1 d 2 d_2 d2表示 o 1 \mathbf o_1 o1 o 2 \mathbf o_2 o2的维度。

o 1 T W 3 o 2 \mathbf o_1^T \mathbf W_3 \mathbf o_2 o1TW3o2表示 o 1 \mathbf o_1 o1 o 2 \mathbf o_2 o2二阶双线性交叉,当 W 3 \mathbf W_3 W3是单位矩阵,那就是点乘,如果是零矩阵,就是concat融合( b + [ w 1 , w 2 ] T [ o 1 , o 2 ] b+[\mathbf w_1, \mathbf w_2]^T[\mathbf o_1, \mathbf o_2] b+[w1,w2]T[o1,o2]

这个双线性融合和FM也有关联,FM,使用 m m m维的特征向量 x \mathbf x x建模二阶交叉,可以表示为
y ^ = σ ( b + w T x + x T u p p e r ( P P T ) x ) \hat y = \sigma (b + \mathbf w^T \mathbf x + \mathbf x^T \mathcal{upper} (\mathbf P \mathbf P^T) \mathbf x) y^=σ(b+wTx+xTupper(PPT)x)

其中, b ∈ R , w ∈ R m × 1 , P ∈ R m × d b\in R,\mathbf w \in R^{m \times 1}, \mathbf P \in R^{m \times d} bR,wRm×1,PRm×d,其实FM是双线性融合的特例,当 o 1 = o 2 \mathbf o_1 = \mathbf o_2 o1=o2

但是这么做有个缺点,当 o 1 \mathbf o_1 o1 o 2 \mathbf o_2 o2维度较大时,例如1000维,双线性映射矩阵 W 3 ∈ R 1000 × 1000 \mathbf W_3\in R^{1000 \times 1000} W3R1000×1000参数量太大。

多头双线性融合

借鉴多头注意力的思想,将 o 1 \mathbf o_1 o1 o 2 \mathbf o_2 o2拆分为 k k k个子空间
o 1 = [ o 11 , o 12 , . . . , o 1 k ] \mathbf o_1 = [\mathbf o_{11}, \mathbf o_{12}, ..., \mathbf o_{1k}] o1=[o11,o12,...,o1k]
o 2 = [ o 21 , o 22 , . . . , o 2 k ] \mathbf o_2 = [\mathbf o_{21}, \mathbf o_{22}, ..., \mathbf o_{2k}] o2=[o21,o22,...,o2k]
k k k是超参数,在各个子空间分别进行双线性映射
y ^ = σ ( ∑ j = 1 k B F ( o 1 j , o 2 j ) ) \hat y = \sigma ( \sum_{j = 1} ^k BF(\mathbf o_{1j}, \mathbf o_{2j})) y^=σ(j=1kBF(o1j,o2j))
这样就把参数量由 d 1 d 2 d_1d_2 d1d2变为 d 1 d 2 / k d_1d_2/k d1d2/k

模型训练

L = − 1 N ∑ ( y l o g ( y ^ ) + ( 1 − y ) l o g ( 1 − y ^ ) ) L = - \frac {1} {N} \sum(y \mathcal log(\hat y) + (1-y) \mathcal log(1-\hat y)) L=N1(ylog(y^)+(1y)log(1y^))

实验

模型实现基于FuxiCTR,一个开源的预估CTR库。embedding_size = 10, batch_size = 4096, 默认的MLP层单元数[400, 400, 400]。对于DualMLP和FinalMLP,两个MLP设置1-3层,学习率设置为1e-3或者1e-5。

比较单个MLP和显式交叉网络

单个MLP效果非常惊人
在这里插入图片描述

DualMLP和FinalMLP

可以看到在two-stream模型中,DualMLP和FinalMLP效果完胜。
在这里插入图片描述

Ablation Studies

对比下面几个模块,说明提出的特征选择层及双线性融合层是有效果的
DualMLP
w/o FS:去掉特征选择模块
Sum:FinalMLP使用sum融合
Concat:FinalMLP使用Concat融合
EWP:FinalMLP使用Elemen-wise乘融合
在这里插入图片描述

多头双线性融合

拆分为多个子组后,效果更好,但是需要调整超参数 k k k.
在这里插入图片描述

总结

这个论文仅用MLP网络就实现了这么强的效果,和一般认知还是有些diff的,说明MLP网络只要调整的好,效果也是相当惊人的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/579717.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

分布式网络通信框架(二)——RPC通信原理和技术选型

项目实现功能 技术选型 黄色部分:设计rpc方法参数的打包和解析,也就是数据的序列化和反序列化,用protobuf做RPC方法调用的序列化和反序列化。 使用protobuf的好处: protobuf是二进制存储,xml和json是文本存储; pro…

哈希应用: 位图 + 布隆过滤器

文章目录 哈希应用: 位图 布隆过滤器1. 位图1.1 提出问题1.2 位图概念1.3 位图实现1.4 位图应用1.4.1 变形题1代码 1.4.2 变形题21.4.3 找文件交集思路1思路2 1.4.4 总结 1.5 位图优缺点 2. 哈希切割3. 布隆过滤器3.1 提出问题3.2 布隆过滤器概念3.3 布隆过滤器的各个接口3.3.…

MySQL---优化(insert、order by 、group by 、limit、子查询)

1. insert语句优化 当进行数据的insert操作的时候,可以考虑采用以下几种优化方案: -- 如果需要同时对一张表插入很多行数据时,应该尽量使用多个值表的insert语句,这种方式将大大的缩减 -- 客户端与数据库之间的连接、关闭等消耗。使得效率比…

R-Meta分析与【文献计量分析、贝叶斯、机器学习等】多技术融合实践与拓展

Meta分析是针对某一科研问题,根据明确的搜索策略、选择筛选文献标准、采用严格的评价方法,对来源不同的研究成果进行收集、合并及定量统计分析的方法,最早出现于“循证医学”,现已广泛应用于农林生态,资源环境等方面。…

ARM体系结构

目录 ARM体系架构 一、ARM公司概述 二、ARM产品系列 三、指令、指令集 指令 指令集 ARM指令集 ARM指令集 Thumb指令集 (属于ARM指令集) 四、编译原理 五、ARM数据类型 字节序 大端对齐 小端对齐 六、ARM工作模式 1.AR…

Java中synchronized锁的深入理解

使用范围 synchronized使用上用于同步方法或者同步代码块在锁实现上是基于对象去实现使用中用于对static修饰的便是class类锁使用中用于对非static修饰的便是当前对象锁 synchronized的优化 在jdk1.6中对synchronized做了相关的优化 锁消除 在synchronized修饰的代码块中…

如何实现局域网下设备之间的互通互联和外网访问?

两台电脑怎么在同一路由下访问共享文件夹?两台不同系统的电脑在同一个路由器下访问共享文件夹进行数据共享,从本质上说就是在同一个局域网下设备之间的互通互联,这就需要我们搭建一个内网文件共享服务器来实现此功能 ,比如常见的W…

linux系统中通配符与常用转义字符

通配符 在平时我们使用使用linux系统的过程中会遇到忘记文件名称的问题,这时候呢,通配符就发挥它的作用啦。 顾名思义啦,通配符就是用来匹配信息的符号,如何(*)代表零个或多个字符,(…

Unity烟花特效实现(附源码)

Unity烟花特效 附代码 写在前面效果代码地址核心步骤 写在后面 写在前面 朋友过生,不知道送什么礼物,就想着用自己所学知识做个特效当礼物吧,嘿。 主要参考了 这位up的视频 ,感谢 效果 代码地址 https://github.com/hahahappyb…

【LeetCode热题100】打开第5天:最长回文子串

文章目录 最长回文子串⛅前言🔒题目🔑题解 最长回文子串 ⛅前言 大家好,我是知识汲取者,欢迎来到我的LeetCode热题100刷题专栏! 精选 100 道力扣(LeetCode)上最热门的题目,适合初识…

linux高级---k8s中的五种控制器

文章目录 一、k8s的控制器类型二、pod与控制器之间的关系三、状态与无状态化对特点四、Deployment1、Deployment的资源清单文件2、在配置清单中调用deployment控制器3、镜像更新4、金丝雀发布5、删除Deployment 五、Statefulset六、DaemonSet1、daemonset的资源清单文件2、在配…

车载T-BOX

Telematics BOX,简称车载T-BOX,车载T-BOX主要用于和后台系统/手机APP通信,实现手机APP的车辆信息显示与控制 目录 1、车载T-BOX的定义 2、车载T-BOX的主要功能 2.1、数据采集和存储 2.2、远程查询和控制 2.3、道路救援 2.4、故障诊断 …

vue2_计算属性

目录 计算属性 计算属性缓存vs方法 计算属性vs侦听属性 getter和setter 计算属性和监听器 前端调用api实现问答 侦听器 计算属性 鉴于能在插值表达式中写js表达式;这样做也一定程度上违背了设计插值表达式的初衷;特别是: 其实就相当于…

nginx(七十九)rewrite模块指令再探

一 rewrite模块再探 ① 知识回顾 1) 结合自己遇到过的案例场景2) 关注一些易错点、难点3) 本文内容很杂,建议读者选取感兴趣的阅读 rewrite模块 rewrite功能 ② nginx中利用if 等价&&多条件 需求背景: 1) nginx不支持&&、||、and、or等逻辑…

设备描述符

前言 一直以来对设备描述符这个概念云里雾里的: 什么是设备描述符?设备描述符是个结构体还是结构体指针?为什么要有设备描述符?设备描述符的作用?设备描述符是根据什么定义的? 启发 今天看《Linux那些事…

【喜闻乐见,包教包会】二分图最大匹配:匈牙利算法(洛谷P3386)

🎭不要管上面那玩意。。。 引入 现在,你,是一位酒店的经理。 西装笔挺,清瘦智慧。 金丝眼镜,黑色钢笔。 大理石的地板,黑晶石的办公桌,晶莹的落地玻璃。 而现在,有几个雍容华贵的…

Spring高手之路——深入理解与实现IOC依赖查找与依赖注入

本文从xml开始讲解,注解后面给出 文章目录 1. 一个最基本的 IOC 依赖查找实例2. IOC 的两种实现方式2.1 依赖查找(Dependency Lookup)2.2 依赖注入(Dependency Injection) 3. 在三层架构中的 service 层与 dao 层体会依…

Opencv(图像处理)-基于Python-绘图功能

1.介绍2. line()3.rectangle()4.circle()5. ellipse()6.polylines()7.fillPoly()8. putText()代码示例9.用鼠标在图片上作图 1.介绍 OpenCV为开发者还提供了绘图功能,我们可以通过函数来实现在图片上作图。 2. line() 画线 cv2.line(img,开始点&#x…

G0第23章:GORM基本示例、GORM Model定义、主键、表名、列名的约定

04 GORM基本示例 注意: 本文以MySQL数据库为例,讲解GORM各项功能的主要使用方法。 往下阅读本文前,你需要有一个能够成功连接上的MySQL数据库实例。 Docker快速创建MySQL实例 很多同学如果不会安装MySQL或者懒得安装MySQL,可以使用一下命令…

STL好难(3):vector的使用

目录 1.vector的介绍和使用 2.vector的常见构造: 3.vector的遍历方式 🍉[ ] 下标 🍉通过迭代器进行访问: 🍉范围for: 4.vector的迭代器 🍉begin 和 end 🍉rbegin 和 rend …