注意力FM模型AFM

news2024/11/27 12:39:37

1. 概述

在CTR预估任务中,对模型特征的探索是一个重要的分支方向,尤其是特征的交叉,从早起的线性模型Logistic Regression开始,研究者在其中加入了人工的交叉特征,对最终的预估效果起到了正向的效果,但是人工的方式毕竟需要大量的人力,能否自动挖掘出特征的交叉成了研究的重要方向,随着Factorization Machines[1]的提出,模型能够自动处理二阶的特征交叉,极大减轻了人工交叉的工作量。

但是在FM中,每一个交叉特征的权重是一致的,但是在实际的工作中,不同的交叉特征应该具备不同的权重,尤其是较少使用到的权重,对于统一的权重会影响到模型的最终效果。AFM(Attentional Factorization Machines)[2]模型在FM模型的基础上,引入了Attention机制,通过Attention的网络对FM模型中的交叉特征赋予不同的权重。

2. 算法原理

2.1. FM模型中的交叉特征

FM模型中包含了两个部分,一部分是线性部分,另一部分是二阶的交叉部分,其表达式如下所示:

y ^ F M ( x ) = w 0 + ∑ i = 1 n w i x i ⏟ + ∑ i = 1 n ∑ j = i + 1 n w ^ i j x i x j ⏟ linear    regression pair-wise    feature omteractions \begin{matrix} \hat{y}_{FM}\left ( \mathbf{x} \right )= & \underbrace{w_0+\sum_{i=1}^{n}w_ix_i} & + & \underbrace{\sum_{i=1}^{n}\sum_{j=i+1}^{n}\hat{w}_{ij}x_ix_j} \\ & \textrm{linear\;regression} & & \textrm{pair-wise\;feature omteractions} \\ \end{matrix} y^FM(x)= w0+i=1nwixilinearregression+ i=1nj=i+1nw^ijxixjpair-wisefeature omteractions
其中, w ^ i j \hat{w}_{ij} w^ij表示的是交叉特征 x i x j x_ix_j xixj的权重,在FM算法中,为了方便计算,为每一个特征赋予了一个 k k k维的向量: v i ∈ R k \mathbf{v}_i\in \mathbb{R}^k viRk,则 w ^ i j \hat{w}_{ij} w^ij可以表示为:

w ^ i j = v i T v j \hat{w}_{ij}=\mathbf{v}_i^T\mathbf{v}_j w^ij=viTvj

对于具体为甚么上述的这样的计算方式可以方便计算,可以参见参考[3]。既然上面说 w ^ i j \hat{w}_{ij} w^ij表示的是交叉特征 x i x j x_ix_j xixj的权重,那么为什么还说在FM模型中的每个交叉特征的权重是一致的,这个怎么理解?如果将FM模型放入到神经网络的框架下,FM模型的结构可以由下图表示:

在这里插入图片描述

对于每一个特征都赋予一个 k k k维的向量,如上图中的第二个特征 x 2 x_2 x2 k k k维向量为 v 2 \mathbf{v}_2 v2,同理,第四个特征 x 4 x_4 x4 k k k维向量为 v 4 \mathbf{v}_4 v4,这里类似于对原始特征的Embedding,最终 x 2 x_2 x2 x 4 x_4 x4的交叉特征可以表示为: ( v 2 ⊙ v 4 ) x 2 x 4 \left ( \mathbf{v}_2\odot \mathbf{v}_4 \right )x_2x_4 (v2v4)x2x4,其中, ⊙ \odot 表示的是元素的乘积。最终,将所有的交叉特征相加便得到了交叉部分 y 2 y_2 y2

y 2 = p T ∑ ( i , j ) ∈ R x ( v i ⊙ v i ) x i x j + b y_2= \mathbf{p}^T\sum_{\left ( i,j \right )\in \mathfrak{R}_x}\left ( \mathbf{v}_i\odot \mathbf{v}_i \right )x_ix_j+b y2=pT(i,j)Rx(vivi)xixj+b

其中, R x = { ( i , j ) } i ∈ χ , j ∈ χ , j > i \mathfrak{R}_x=\left\{\left ( i,j \right ) \right\}_{i\in \chi ,j\in \chi,j>i} Rx={(i,j)}iχ,jχ,j>i p ∈ R k \mathbf{p}\in \mathbb{R}^k pRk b ∈ R b\in \mathbb{R} bR,在上述的FM中, p = 1 \mathbf{p}=\mathbf{1} p=1 b = 0 b=0 b=0。在相加的过程中,对于每一部分的交叉特征的权重都是一致的,这就会导致上面说的统一的权重会影响到模型的最终效果。我们希望对于每一部分的交叉特征能够有不同的权重,即:

y 2 = p T ∑ ( i , j ) ∈ R x a i , j ( v i ⊙ v i ) x i x j + b y_2=\mathbf{p}^T\sum_{\left ( i,j \right )\in \mathfrak{R}_x}a_{i,j}\left ( \mathbf{v}_i\odot \mathbf{v}_i \right )x_ix_j+b y2=pT(i,j)Rxai,j(vivi)xixj+b

其中, a i , j a_{i,j} ai,j表示的是第 i i i j j j交叉特征部分的权重。

2.2. AFM的网络结构

在注意力FM模型AFM(Attentional Factorization Machines)中,是在FM的基础上引入了Attention机制,通过Attention网络学习到每个交叉特征的权重 a i , j a_{i,j} ai,j,AFM的网络结构如下图所示:

在这里插入图片描述

上述在Pair-wise Interaction Layer和Prediction Score之间的SUM Pooling上增加了Attention的网络,具体的数学表达式如下所示:

y ^ A F M ( x ) = w 0 + ∑ i = 1 n w i x i + p T ∑ i = 1 n ∑ j = i + 1 n a i j ( v i ⊙ v j ) x i x j \hat{y}_{AFM}\left ( \mathbf{x} \right )=w_0+\sum_{i=1}^{n}w_ix_i+\mathbf{p}^T\sum_{i=1}^{n}\sum_{j=i+1}^{n}a_{ij}\left ( \mathbf{v}_i\odot \mathbf{v}_j \right )x_ix_j y^AFM(x)=w0+i=1nwixi+pTi=1nj=i+1naij(vivj)xixj

2.3. Attention网络

对于Attention网络部分,需要计算出对于不同的交叉特征部分的权重 a i j a_{ij} aij,其中,网络的输入为 ( v i ⊙ v j ) x i x j \left ( \mathbf{v}_i\odot \mathbf{v}_j \right )x_ix_j (vivj)xixj a i j a_{ij} aij的计算过程如下:

a i j ′ = h T R e L U ( W ( v i ⊙ v j ) x i x j + b ) a i j = e x p ( a i j ′ ) ∑ ( i , j ) ∈ R x e x p ( a i j ′ ) \begin{matrix} a^{'}_{ij}=\mathbf{h}^TReLU\left ( \mathbf{W}\left ( \mathbf{v}_i\odot \mathbf{v}_j \right )x_ix_j+\mathbf{b} \right ) \\ a_{ij}=\frac{exp\left ( a^{'}_{ij} \right )}{\sum_{\left ( i,j \right )\in \mathfrak{R}_x}exp\left ( a^{'}_{ij} \right )} \end{matrix} aij=hTReLU(W(vivj)xixj+b)aij=(i,j)Rxexp(aij)exp(aij)

参考[4]中给出了具体的AFM的实现,下面是Attention网络的具体实现方法:

def call(self, inputs, training=None, **kwargs):
	if K.ndim(inputs[0]) != 3:
		raise ValueError(
			"Unexpected inputs dimensions %d, expect to be 3 dimensions" % (K.ndim(inputs)))

	embeds_vec_list = inputs # 交叉特征部分
	row = []
	col = []

	for r, c in itertools.combinations(embeds_vec_list, 2):
		row.append(r)
		col.append(c)

	p = tf.concat(row, axis=1)
	q = tf.concat(col, axis=1)
	inner_product = p * q

	bi_interaction = inner_product
	attention_temp = tf.nn.relu(tf.nn.bias_add(tf.tensordot(
		bi_interaction, self.attention_W, axes=(-1, 0)), self.attention_b)) # 计算网络输出,上述公式的第一部分
	#  Dense(self.attention_factor,'relu',kernel_regularizer=l2(self.l2_reg_w))(bi_interaction)
	self.normalized_att_score = softmax(tf.tensordot(
		attention_temp, self.projection_h, axes=(-1, 0)), dim=1) # 归一化,上述公式的第二部分
	attention_output = reduce_sum(
		self.normalized_att_score * bi_interaction, axis=1) # 加权求和

	attention_output = self.dropout(attention_output, training=training)  # training,防止过拟合

	afm_out = self.tensordot([attention_output, self.projection_p]) # 乘以向量,做最终的输出
	return afm_out

3. 总结

AFM模型在FM模型的基础上,引入了Attention机制,通过Attention的网络对FM模型中的交叉特征赋予不同的权重。

参考文献

[1] Rendle S. Factorization machines[C]//2010 IEEE International conference on data mining. IEEE, 2010: 995-1000.

[2] Xiao J, Ye H, He X, et al. Attentional factorization machines: Learning the weight of feature interactions via attention networks[J]. arXiv preprint arXiv:1708.04617, 2017.

[3] 简单易学的机器学习算法——因子分解机(Factorization Machine)

[4] DeepCTR

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/166897.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为MPLS跨域C1方案实验配置

目录 配置接域内IGP路由协议与LDP协议 配置IPv4的BGP邻居 配置PE之间的Vpnv4邻居 配置PE与CE设备对接命令 ASBR上手工为PE地址分配标签 MPLS隧道——跨域解决方案C1、C2讲解_静下心来敲木鱼的博客-CSDN博客_route-policy rr permit node 10 if-match mpls-labelhttps://bl…

IB地理课选课指南,SL还是HL适合呢?

IB地理科的标准级别(Standard Level, SL)课程跟高级级别(Higher Level,HL)课程的最大不同处在于,考卷的数量跟题目的数量是不同的。可是,两者之间的教学内容和科目指引(S…

二十八、Kubernetes中job详解

1、概述 在kubernetes中,有很多类型的pod控制器,每种都有自己的适合的场景,常见的有下面这些: ReplicationController:比较原始的pod控制器,已经被废弃,由ReplicaSet替代 ReplicaSet&#xff…

CentOS 7 升级安装 Python 3.9 版本

由于 yum install python3 默认安装的 Python 版本较低,现如今有更高版本的 Python 需求,就想用编译安装的方法安装一个较高版本的 Python,顺道记录一下安装过程。 注意:不要卸载自带的 python2,由于 yum 指令需要 pyt…

idea中代码git的版本穿梭Git Rest三种模式详解(soft,mixed,hard)

使用Git进行版本控制开发时难免会遇到回顾的情况,这里来解释下该如何正确的回滚 文章目录1.本地仓库回滚2.远程仓库回滚2.1错误案例2.2正确操作3.代码提交到错误的分支解决4.Git Rest三种模式详解(soft,mixed,hard)4.1操作演示reset --hard&a…

【论文简述】FlowFormer:A Transformer Architecture for Optical Flow(ECCV 2022)

一、论文简述 1. 第一作者:Zhaoyang Huang、Xiaoyu Shi 2. 发表年份:2022 3. 发表期刊:ECCV 4. 关键词:光流、代价体、Transformer、GRU 5. 探索动机:现有的方法对代价体的信息利用有限。 6. 工作目标&#xff1…

RabbitMQ 部署及配置详解(集群部署)

RabbitMQ 集群是一个或 多个节点,每个节点共享用户、虚拟主机、 队列、交换、绑定、运行时参数和其他分布式状态。一、RabbitMQ 集群可以通过多种方式形成:通过在配置文件中列出群集节点以声明方式以声明方式使用基于 DNS 的发现以声明方式使用 AWS &…

Java中的LinkedList

文章目录前言一、LinkedList的使用1.1 什么是LinkedList1.2 LinkedList的使用1.2.1 LinkedList的构造1.2.2 LinkedList的其他常用方法介绍1.2.3 LinkedList的遍历二、LinkedList的模拟实现三、ArrayList和LinkedList的区别总结前言 上一节中我们讲解了Java中的链表&#xff0c…

vue3.0中echarts实现中图地图的省份切换,并解决多次切换后地图卡死的情况

一、echarts安装及地图的准备 1、安装echarts npm install echarts2、下载china.js等json文件到项目中的文件夹 map的下载地址&#xff1a; 等审核 二、代码说明 <template><div class"center-body"><div class"map" id"map"…

fork函数详解

文章目录fork函数例子详解工作原理GDB 多进程调试fork函数 fork系统调用用于创建一个新进程&#xff0c;称为子进程&#xff0c;它与进程&#xff08;称为系统调用fork的进程&#xff09;同时运行&#xff0c;此进程称为父进程。创建新的子进程后&#xff0c;两个进程将执行fo…

jvm系列(2)--类加载子系统

目录第2章-类加载子系统内存结构概述简图详细图类加载器子系统类加载器ClassLoader角色类加载过程概述加载阶段链接阶段验证(Verify)准备(Prepare)解析(Resolve)初始化阶段类的初始化时机clinit()1&#xff0c;2&#xff0c;3说明4说明5说明6说明类加载器的分类概述虚拟机自带的…

【web安全】——文件上传的绕过方式

作者名&#xff1a;白昼安全主页面链接&#xff1a; 主页传送门创作初心&#xff1a; 舞台再大&#xff0c;你不上台&#xff0c;永远是观众&#xff0c;没人会关心你努不努力&#xff0c;摔的痛不痛&#xff0c;他们只会看你最后站在什么位置&#xff0c;然后羡慕或鄙夷座右铭…

价值创造链路及经营计划

“价值创造过程最主要的环节是建立链接&#xff0c;北京万柳书院在网上热议&#xff0c;其背后是人与人的大量链接&#xff0c;近期热议的湖南卫视春晚亦如是&#xff0c;这种链接为价值的设计、沟通、传递创造条件&#xff1b;企业以客户为中心设计产品&#xff0c;往大了说是…

C++ string类的初步了解

目录 一. 为什么学习string类&#xff1f; 1.C语言中的字符串 2.string类 二. string类的常用接口说明 1.构造 2.容量 size和length capacity clear empty reserve resize 3.元素访问 operator[] at front、back 4.迭代器 ​编辑begin、end rbegin、rend …

数据结构初阶:排序

本期博客我们来到了初阶数据结构最后一个知识点&#xff1a;排序 排序&#xff0c;我们从小到大就一直在接触&#xff0c;按身高、成绩、学号等等不同的排序我们已经历许多&#xff0c;那么各位是按怎样的方法进行排序的呢&#xff1f; 废话不多说这期博客我们对各种排序方法…

测试开发 | 测试平台开发-前端开发之数据展示与分析

本文节选自霍格沃兹测试学院内部教材测试平台的数据展示与分析&#xff0c;我们主要使用开源工具ECharts来进行数据的展示与分析。ECharts简介与安装ECharts是一款基于JavaScript的数据可视化图表库&#xff0c;提供直观&#xff0c;生动&#xff0c;可交互&#xff0c;可个性化…

Unity 使用OpenXR和XR Interaction Toolkit 开发 HTCVive(Vive Cosmos)

Unity 使用OpenXR和XR Interaction Toolkit 开发 HTCVive&#xff08;Vive Cosmos&#xff09; 提示&#xff1a;作者是 Unity 2020.3 以上版本做的开发。开发VR程序需要安装 Steam&#xff0c;SteamVR, (Vive Cosmos,需要再安装VIVEPORT,VIVEConsole) OpenXR 控制设备 &#x…

OpenCV(12)-OpenCV的机器学习

OpenCV的机器学习 基本概念 计算机视觉是机器学习的一种应用&#xff0c;而且是最有价的应用 人脸识别 哈尔(Haar)级联方法深度学习方法(DNN) Haar人脸识别方法 哈尔(Haar)级联方法是专门为解决人脸识别而推出的&#xff0c;在深度学习还不流行时&#xff0c;哈尔已可以商…

Android 深入系统完全讲解(21)

关键性 EGLSurface 代码位置 继续再看看&#xff0c;代码跑到 C 里面去了。 然后关键点&#xff1a; 获取本地窗口&#xff0c;创建 Surface&#xff0c;然后 toEGLHandle 进行包裹&#xff0c;变成 EGL 上下文。 EGLSurface 。 绘制的设计本质逻辑 在这里就回归一点&#xff…

Unity学习笔记--File.ReadAllLines和File.ReadAllText的使用以及注意事项(一定要看到最后!!!)

目录前言一、File.ReadAllLines参数返回例子二、File.ReadAllText参数返回例子注意事项可能出现的问题总结前言 最近在做文件存储以及读取的时候&#xff0c;需要用到C#给我们提供的类&#xff1a;File 具体使用方法可以看官方文档&#xff1a;C# File 类 这篇文章只会说File.…