推荐算法实战项目:NFM 原理以及案例实战(附完整 Python 代码)

news2024/9/28 9:27:44

本文要介绍的是由新加坡国立大学的研究人员在论文《Neural Factorization Machines for Sparse Predictive Analytics∗》中提出的NFM模型。

NFM模型全称是Neural Factorization Machines,通过名字也可以看出,这又是一个基于FM模型改进得到的网络。无论是FM模型还是其改进模型FFM,归根结底是一个二阶特征交叉的模型。受到组合爆炸问题的困扰,FM几乎不可能扩展到三阶及其以上,这就不可避免地限制了FM模型的表达能力。

而深度学习网络理论上有拟合任何复杂函数的能力,因此有没有可能使用DNN的更强的表达能力来对FM模型进行扩展呢?这也正是NFM模型出现的缘由。

现存问题

FM的问题

论文作者分析了在大规模特征组合场景下,传统方法比如LR、GBDT等方法在特征组合时候的缺陷,同时引入了FM算法和DNN方法。忽略应用领域,作者将特征组合的方式分为了两种:

  • 基于FM的线性模型
  • 基于神经网络的非线性模型

作者指出基于FM的方法具有很强的通用性,可以推广到很多领域。它是一种通用的预测器,可以和任何实值特征向量一起进行监督学习。同时作者也分析了FM模型的表达能力有限问题。先忽略效率,作者指出FM实际上仍然属于多变量线性模型家族。

不幸的是,真实世界的数据往往都是高度非线性并且不能够被线性模型准确地模拟出来。正因如此,FM可能缺乏足够强的表达能来对具有复杂固有结构和规则的真实数据进行建模。

DNN的问题

DNN的天然优势是可以以一种隐式的方式来学习任意顺序的组合特征。越深层次的网络结构可以学到更加高阶的特征表示,然而越深的网络结构就越难优化,因为随着网络深度的加深,臭名昭著的梯度消失、爆炸,过度拟合,网络退化等问题就会凸显出来。

为了实际展示DNNs的优化困难问题,作者画出了Wide&Deep和Deep&Cross模型在Frappe数据集上的训练和测试误差随着训练轮次的变化关系图,如下:

可以看到随机初始化网络的参数会导致很糟糕的性能表现。而使用FM预训练好的参数来初始化网络会提高训练网络的效率。

NFM模型

接下来给出作者提出的模型框架图,如下:

NFM模型

NFM模型主要是想结合FM模型以及DNN来对稀疏数据进行建模。与FM类似,NFM也是一个可以使用任意实值特征向量的通用的预测器。对于一个稀疏输入向量,NFM通过以下公式来估计目标值:

上式中的前两部分是线性回归部分。第三项是NFM的核心,它是用一个多层前向神经网络用来对特征交互进行建模,即使用一个表达能力更强的函数来代替FM中二阶隐向量内积的部分。NFM模型和FM模型的关系如下:

接下来自底向上分层介绍NFM模型。

1. Embedding 层

Embedding层是一个全连接层,它将每个稀疏特征映射成一个稠密向量表示。请注意,我们根据输入特征值重新调整了embedding向量,而不是简单地查找embedding表,以便覆盖所有实值特征。

2.Bi-Interaction 层

我们将embedding向量集合传入Bi-Interaction层,它通过执行池化操作将embedding向量集合转换成单个向量。具体操作如下:

维的向量,此向量编码了embedding空间中任意两个特征之间的交互。
值得指出的是Bi-Interaction层的池化操作并没有引入额外的模型参数,更重要的是,它可以在线性时间内高效计算。这个特性与平均、最大池化操作以及聚合操作类似,这些操作比较简单,被广泛地运用在神经网络中。为了展示Bi-Interaction层的线性时间复杂度,我们可以将上式改写成下面的形式:

3.Hidden Layers

在Bi-Interaction层的上面便是一系列堆叠而成的全连接层,它们能够学习到特征之间的更高阶交互。正式地,全连接层的定义如下:

通过指定非线性的激活函数,比如sigmoid、tanh、ReLU,模型可以以一种非线性的方式学习到高阶的特征交互。

4. 预测层

最后一个隐层的输出[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-n0PNzlJx-1683131080992)(https://math.jianshu.com/math?formula=z_l)]通过以下方式转换成最终的预测分数:

总结一下,NFM预测模型的计算公式可以被概括为:

相比于FM模型而言,NFM模型仅仅是多了参数,即DNN部分的参数,这部分参数是用来学习特征之间的更高阶交互。

NFM和FM的关系

FM可以认为是一个浅层的线性模型,它可以被看做是NFM模型的一个特例,即不包含隐层。那么其方程如下:

NFM和Wide&Deep和DeepCross的关系

NFM与现存的几种深度学习解决方法都有着相同的多层神经网络结构。关键的不同点在于BI-Interaction池化组件,这个在NFM中是唯一的。如果我们将Bi-Interaction池化层换成一个concatenation层,并应用一个塔型MLP的隐层,那么我们就可以还原Wide&Deep模型。concatenation操作的一个明显缺陷是它并没有考虑不同特征之间的交互。因此,这些深度学习方法只能依靠接下来的全连接层来学习有意义的特征交互,然而不幸的是,这在实践中通常很难进行训练。

完整源码&技术交流

技术要学会分享、交流,不建议闭门造车。一个人走的很快、一堆人可以走的更远。

文章中的完整源码、资料、数据、技术交流提升, 均可加知识星球交流群获取,群友已超过2000人,添加时切记的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、添加微信号:mlc2060,备注:来自 获取推荐资料
方式②、微信搜索公众号:机器学习社区,后台回复:推荐资料

部分代码

模型部分代码(主要包含了B-Interaction模型和NFM模型):

import torch
import torch.nn as nn
from BaseModel.basemodel import BaseModel

class BiInteractionPooling(nn.Module):
    """Bi-Interaction Layer used in Neural FM,compress the
      pairwise element-wise product of features into one single vector.
      Input shape
        - A 3D tensor with shape:``(batch_size,field_size,embedding_size)``.
      Output shape
        - 3D tensor with shape: ``(batch_size,1,embedding_size)``.
    """
    def __init__(self):
        super(BiInteractionPooling, self).__init__()

    def forward(self, inputs):
        concated_embeds_value = inputs
        square_of_sum = torch.pow(
            torch.sum(concated_embeds_value, dim=1, keepdim=True), 2)
        sum_of_square = torch.sum(
            concated_embeds_value * concated_embeds_value, dim=1, keepdim=True)
        cross_term = 0.5 * (square_of_sum - sum_of_square)
        return cross_term

class NFM(BaseModel):
    def __init__(self, config, dense_features_cols, sparse_features_cols):
        super(NFM, self).__init__(config)
        # 稠密和稀疏特征的数量
        self.num_dense_feature = dense_features_cols.__len__()
        self.num_sparse_feature = sparse_features_cols.__len__()

        # NFM的线性部分,对应 ∑WiXi
        self.linear_model = nn.Linear(self.num_dense_feature + self.num_sparse_feature, 1)

        # NFM的Embedding层
        self.embedding_layers = nn.ModuleList([
            nn.Embedding(num_embeddings=feat_dim, embedding_dim=config['embed_dim'])
                for feat_dim in sparse_features_cols
        ])

        # B-Interaction 层
        self.bi_pooling = BiInteractionPooling()
        self.bi_dropout = config['bi_dropout']
        if self.bi_dropout > 0:
            self.dropout = nn.Dropout(self.bi_dropout)

        # NFM的DNN部分
        self.hidden_layers = [self.num_dense_feature + config['embed_dim']] + config['dnn_hidden_units']
        self.dnn_layers = nn.ModuleList([
            nn.Linear(in_features=layer[0], out_features=layer[1])\
                for layer in list(zip(self.hidden_layers[:-1], self.hidden_layers[1:]))
        ])
        self.dnn_linear = nn.Linear(self.hidden_layers[-1], 1, bias=False)

    def forward(self, x):
        # 先区分出稀疏特征和稠密特征,这里是按照列来划分的,即所有的行都要进行筛选
        dense_input, sparse_inputs = x[:, :self.num_dense_feature], x[:, self.num_dense_feature:]
        sparse_inputs = sparse_inputs.long()

        # 求出线性部分
        linear_logit = self.linear_model(x)

        # 求出稀疏特征的embedding向量
        sparse_embeds = [self.embedding_layers[i](sparse_inputs[:, i]) for i in range(sparse_inputs.shape[1])]
        sparse_embeds = torch.cat(sparse_embeds, axis=-1)

        # 送入B-Interaction层
        fm_input = sparse_embeds.view(-1, self.num_sparse_feature, self._config['embed_dim'])
        # print(fm_input)
        # print(fm_input.shape)

        bi_out = self.bi_pooling(fm_input)
        if self.bi_dropout:
            bi_out = self.dropout(bi_out)

        bi_out = bi_out.view(-1, self._config['embed_dim'])
        # 将结果聚合起来
        dnn_input = torch.cat((dense_input, bi_out), dim=-1)

        # DNN 层
        dnn_output = dnn_input
        for dnn in self.dnn_layers:
            dnn_output = dnn(dnn_output)
            # dnn_output = nn.BatchNormalize(dnn_output)
            dnn_output = torch.relu(dnn_output)
        dnn_logit = self.dnn_linear(dnn_output)

        # Final
        logit = linear_logit + dnn_logit
        y_pred = torch.sigmoid(logit)

        return y_pred

依旧使用criteo数据集的小样本来做demo,测试部分代码如下:

import torch
from DeepCrossing.trainer import Trainer
from NFM.network import NFM
import torch.utils.data as Data
from Utils.criteo_loader import getTestData, getTrainData

nfm_config = \
{
    'embed_dim': 8, # 用于控制稀疏特征经过Embedding层后的稠密特征大小
    'dnn_hidden_units': [128, 128],
    'num_dense_features': 13,
    'bi_dropout': 0.5,
    'num_epoch': 500,
    'batch_size': 128,
    'lr': 1e-3,
    'l2_regularization': 1e-4,
    'device_id': 0,
    'use_cuda': False,
    'train_file': '../Data/criteo/processed_data/train_set.csv',
    'fea_file': '../Data/criteo/processed_data/fea_col.npy',
    'validate_file': '../Data/criteo/processed_data/val_set.csv',
    'test_file': '../Data/criteo/processed_data/test_set.csv',
    'model_name': '../TrainedModels/NFM.model'
}

def toOneHot(x, MaxList):
    res = []
    for i in range(len(x)):
        t = torch.zeros(MaxList[i])
        t[int(x[i])] = 1
        res.append(t)
    return torch.cat(res, -1)

if __name__ == "__main__":
    ####################################################################################
    # NFM 模型
    ####################################################################################
    training_data, training_label, dense_features_col, sparse_features_col = getTrainData(nfm_config['train_file'], nfm_config['fea_file'])
    train_dataset = Data.TensorDataset(torch.tensor(training_data).float(), torch.tensor(training_label).float())

    test_data = getTestData(nfm_config['test_file'])
    test_dataset = Data.TensorDataset(torch.tensor(test_data).float())

    nfm = NFM(nfm_config, dense_features_cols=dense_features_col, sparse_features_cols=sparse_features_col)
    print(nfm)
    ####################################################################################
    # 模型训练阶段
    ####################################################################################
    # # 实例化模型训练器
    trainer = Trainer(model=nfm, config=nfm_config)
    # 训练
    trainer.train(train_dataset)
    # 保存模型
    trainer.save()

    ####################################################################################
    # 模型测试阶段
    ####################################################################################
    nfm.eval()
    if nfm_config['use_cuda']:
        nfm.loadModel(map_location=lambda storage, loc: storage.cuda(nfm_config['device_id']))
        nfm = nfm.cuda()
    else:
        nfm.loadModel(map_location=torch.device('cpu'))

    y_pred_probs = nfm(torch.tensor(test_data).float())
    y_pred = torch.where(y_pred_probs>0.5, torch.ones_like(y_pred_probs), torch.zeros_like(y_pred_probs))
    print("Test Data CTR Predict...\n ", y_pred.view(-1))

输出的点击率预估部分结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/487195.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringBoot+vue文件上传下载预览大文件分片上传文件上传进度

文章目录 学习链接上传文件前端后端代码 下载文件a标签下载前端代码后台代码 动态a标签下载前端代码 axios 动态a标签前端代码 浏览器直接输入 预览文件前端代码后端代码 分片上传前后端分别md5加密spark-md5commons-codec 分片上传实现前端代码后端代码 学习链接 Blob &…

zynq基于XDMA实现PCIE X8视频采集卡 提供工程源码和QT上位机程序和技术支持

目录 1、前言2、我已有的PCIE方案3、基于zynq架构的PCIE4、总体设计思路和方案视频输入通路视频输出通路PCIE输出上位机通路 5、vivado工程详解6、SDK 工程详解7、驱动安装8、QT上位机软件9、上板调试验证10、福利:工程代码的获取 1、前言 PCIE(PCI Exp…

智能无人蜂群作战系统适应性进化模型仿真研究

源自:系统仿真学报 作者:李志强, 李元龙, 殷来祥, 马向平 摘 要 智能无人蜂群作战系统主要由有限行为能力的大规模作战个体组成,一般不具备应对复杂战场环境和作战对手变化的适应能力。采用遗传算法与增强学习相结合的方法探索构建基于个体…

Apache Flink (最新版本) 远程代码执行

路虽远&#xff0c;行则将至&#xff1b;事虽难&#xff0c;做则必成 Apache Flink < 1.9.1(最新版本) 远程代码执行 CVE-2020-17518 漏洞描述 近日,有安全研究员公开了一个Apache Flink的任意Jar包上传导致远程代码执行的漏洞. 漏洞影响 Apache Flink < 1.9.1(最新…

通过频谱规划软件摆脱频谱监测硬件限制

背景 随着无线通信技术的发展,电磁频谱被逐渐扩充&#xff0c;从几kHz到1THz的频段慢慢被各种技术填充与覆盖。在任意时刻任意地点&#xff0c;5G、WiFi、GNSS、广播电台、航空通信等&#xff0c;都离不开无线通信。 电磁频谱&#xff08;EMS&#xff09;被广泛用于生活中&am…

基于matlab使用均匀矩形阵列进行电子扫描

一、前言 本示例模拟定期扫描预定义监视区域的相控阵雷达。该单基地雷达使用900元件矩形阵列。介绍了根据规范推导雷达参数的步骤。合成接收到的脉冲后&#xff0c;进行检测和距离估计。最后&#xff0c;利用多普勒估计得到每个目标的速度。 二、雷达定义 首先&#xff0c;我们…

M301H-BYT代工-支持Hi3798 MV300H/MV300/MV310芯片-当贝纯净桌面-强刷卡刷固件包

M301H-BYT代工-支持Hi3798 MV300H&#xff0f;MV300&#xff0f;MV310芯片-当贝纯净桌面-强刷卡刷固件包 特点&#xff1a; 1、适用于对应型号的电视盒子刷机&#xff1b; 2、开放原厂固件屏蔽的市场安装和u盘安装apk&#xff1b; 3、修改dns&#xff0c;三网通用&#xff…

算法学习-图像的数据格式BGR

OpenCV学习——图像的BGR格式解读 1. opencv读取的图片数据格式2. BGR含义 1. opencv读取的图片数据格式 opencv读取的图片数据格式为numpy的nparray格式。 一张二维图片是由像素点构成&#xff0c;如下图所示&#xff1a; 其中行与列确定了像素点的位置&#xff0c;值确定了…

美团企业版:地利尚可,天时不足

配图来自Canva可画 近年来入局B端逐渐成为各家互联网大厂的必然选项&#xff0c;美团自然不甘心落于人后。 4月13日&#xff0c;美团企业版正式上线&#xff0c;面向企业客户推出一站式企业消费管理服务&#xff0c;覆盖团餐、差旅等场景&#xff0c;同时推出“企航计划”&am…

电脑视频删除了怎么恢复回来?很着急

案例分享&#xff1a;“电脑视频删除了怎么恢复回来&#xff1f;我是一名影楼的摄像师&#xff0c;我的主要工作就是拍摄婚礼视频&#xff0c;最近拍了一场婚礼视频&#xff0c;当时由于相机的内存不足&#xff0c;于是将宣传片等视频都导入进了电脑里面&#xff0c;清空摄像机…

自定义控件 (?/N) - 颜料 Paint

参考来源 一、颜色 1.1 直接设置颜色 1.1.1 setColor( ) public void setColor(ColorInt int color) paint.setColor(Color.RED) paint.setColor(Color.parseColor("#009688")) 1.1.2 setARGB( ) public void setARGB(int a, int r, int g, int b) paint.se…

多商户商城系统-v2.2.3版本发布

likeshop多商户商城系统-v2.2.3版本发布了&#xff01;主要更新内容如下 新增 1.用户端退出账号功能 优化 1.平台添加营业执照保存异常问题 2.平台端分销商品优化-只显示参与分销的商品 3.优化订单详情显示营销价格标签 4.平台交易设置增加默认值 5.种草社区评论调整&a…

如何下载安装驱动

1 打开浏览器 这里以Edge浏览器举例 第一步打开桌面上的Edge浏览器 如果您的桌面上没有 那么找到搜索栏 搜索Edge 然后打开 打开之后一般是这样 然后把我发送您的地址 驱动下载地址 https://t.lenovo.com.cn/yfeyfYyD &#xff08;这个网址只是一个例子&#xff09; 删除掉前…

MQ主流中间件

MQ主流中间件 目前&#xff0c;在消息中间件的领域中&#xff0c;主流的组件包括以下几种&#xff1a; Apache Kafka&#xff1a;一个分布式流处理平台&#xff0c;可以用于构建实时数据管道和流式应用程序。 RabbitMQ&#xff1a;一个实现了 AMQP&#xff08;高级消息队列协…

【Spring Security第一篇】初始Spring Security、表单认证、认证与授权

文章目录 一、初识Spring Security1. Spring Security简介2. 第一个Spring Security项目&#xff08;XML文件配置&#xff09;3. 第一个Spring Security项目&#xff08;自动配置&#xff09;4. 配置Security账户 二、表单认证1. Web 应用中基于密码的身份认证机制2. 默认表单认…

基于 Rainbond 的混合云管理解决方案

内容概要&#xff1a;文章探讨了混合云场景中的难点、要点&#xff0c;以及Rainbond平台在跨云平台的混合云管理方面的解决方案。包括通过通过统一控制台对多集群中的容器进行编排和管理&#xff0c;实现了对混合云中应用的一致性管理。文章还介绍了Rainbond平台在混合云环境下…

程序员的新电脑应该安装那些环境呢?

换新电脑了&#xff0c;那么作为一名程序员需要安装那些软件呢&#xff1f; 电脑系统版本&#xff1a;Windows11 注意&#xff1a;用户名一定要设置成英文的&#xff0c;否则后面会出现一定的问题&#xff01;&#xff01; 1、配置环境 &#xff08;1&#xff09;JDK环境 h…

HIEE300024R4 UAA326A04电流、电压、功率测量机电指示仪表的选用

​ HIEE300024R4 UAA326A04电流、电压、功率测量机电指示仪表的选用 电流、电压、功率测量机电指示仪表的选用 用于测量电流和电压的仪器类型如下 不 乐器 适用于 1个 PMMC&#xff08;永磁动圈&#xff09; 直流电流 2个 动铁式 交直流 3个 电测力计式 …

vcruntime140_1.dll丢失的解决方法

vcruntime140_1.dll是Microsoft Visual C Redistributable中的一个动态链接库&#xff08;DLL&#xff09;文件&#xff0c;是电脑Windows系统中重要的文件&#xff0c;丢失会造成很多软件报错无法运行。有不少小伙伴在打开ps&#xff0c;pr或者游戏的过程中都遇到过这个问题&a…

辅助驾驶功能开发-功能规范篇(16)-2-领航辅助系统NAP-巡航跟车基础功能

接上篇博文 2.3.2.巡航跟车基础功能 巡航跟车基础功能介绍辅助驾驶系统的车速设定,车间时距设定,纵向定速巡航、跟车加减速、起停,横向居中控制,弯道控制等逻辑。 前置条件: (1)NOA功能激活; 2.3.2.1.车速调节 1)激活时初始显示 中控屏设置界面有“融合限速设置”的开…