推荐算法实战项目:FNN 原理以及案例实战(附完整 Python 代码)

news2024/9/22 21:33:25

本文要介绍的是FNN模型,出自于张伟楠老师于2016年发表的论文《Deep Learning over Multi-field Categorical Data》。

论文提出了两种深度学习模型,分别叫做FNN(Factorisation Machine supported Neural Network)和SNN(Sampling-based Neural Network),本文只介绍FNN模型。

其实学习FNN模型之前,强烈建议先学习FM模型,因为FNN模型其实可以看做是由一个FM模型和一个MLP组成的。

FM的引入是为了高效地将输入稀疏特征映射到稠密特征,从而加快FNN模型的训练。

介绍

用户行为预测在许多网页级应用上发挥着重要的作用,比如网页搜索、推荐系统、赞助搜索、以及广告展示等。

在在线广告中,举个例子,对目标用户群体的定位能力是区别于传统线下广告的关键优势。所有的定位技术,都依赖于预测是否特定的用户认为这个广告是相关的,给出用户在特定的场景中点击的概率。

目前大部分的CTR预测都是线性模型,如逻辑回归,朴素贝叶斯,FTRL逻辑回归和贝叶斯逻辑回归等。所有的这些都是基于使用one-hot编码的大量稀疏特征。线性模型简单,有效,但是性能偏差,因为无法学习到特征之间的相互关系。

非线性模型可以通过特征间的组合提高模型的能力。如FMs,将二值化的特征映射成连续的低维空间,通过内积获取特征间的相互关系。GBDT梯度提升树,通过树的构建过程,自动的学习特征的组合。

然而,这些方法并不能利用所有可能的组合。此外,许多模型仍然需要依靠手工进行特征工程,来决定如何进行特征的组合。另一个问题是,已有的CTR模型在对复杂数据间的潜在的模式上的表达能力是非常有限的。所以,它们的泛化能力是非常受限的。

​深度学习在CV和NLP上取得了很大的成功,在非监督的训练中,神经网络可以从原始的特征中学习到高维的特征表示,这个能力也可以用在CTR上。在CTR中,大部分的输入特征是来自各个领域的,而且是离散的类别特征。比如用户所在的城市信息(London,Paris,Beijing),设备类型(PC,Mobile),广告类别(Sports, Electronics)等等,并且特征之间的相互依赖是未知的。

因此,我们抱着极大地兴趣想了解一下,深度学习方法是如何在大规模的多特征域的离散类别数据上通过学习特征表示来提高CTR任务的估计准确度的。然而,大规模的输入特征空间需要调整大量的参数,这毫无疑问在计算上是非常昂贵的。与物理世界的图像或者音频数据不同,在推荐系统或者在线搜索等系统中,输入数据都是及其稀疏的。

举个例子,假如我们有100万个二进制输入特征,以及100个隐层,那么这大概需要1亿个连接才能构建第一层神经网络。

模型

先直接给出论文中的FNN模型图,如下:

FNN模型图

上图的右侧已经标明了每一层的含义,下面从模型自顶向下的角度详细解释一下。

  • 输出层
    模型的输出是一个实数作为预测的CTR,即特定用户在指定上下文的条件下点击给定广告的概率。

这里给出FM的方程:

为了进一步展示FM与FNN的关系,这里再补上一张图加以说明:

FM和FNN Embedding层各参数的对应关系

使用预训练的FM来初始化FNN模型的第一层参数可以有效地学习特征表示,并且绕开了高维二值输入带来的计算复杂度高问题。​更进一步,隐含层的权重(除了FM层)可以通过预训练的RBM来进行初始化。FM的权重可以通过SGD来进行更新,我们只需要更新那些不为0的单元,这样可以减少大量的计算。通过预训练FM层和其他的层进行初始化之后,再通过监督学习的方法进行微调,使用交叉熵的损失函数:

使用反向传播的链式法则,FNN模型(包含FM)的权重可以被高效地更新,举个例子,FM层的权重可以通过下述公式来进行更新:


FNN模型分析

FNN模型的特点:

  1. 采用FM预训练得到的隐含层及其权重作为神经网络的第一层的初始值,之后再不断堆叠全连接层,最终输出预测的点击率。
  2. 可以将FNN理解成一种特殊的embedding+MLP,其要求第一层嵌入后的各特征域特征维度一致,并且嵌入权重的初始化是FM预训练好的。
  3. 这不是一个端到端的训练过程,有贪心训练的思路。而且如果不考虑预训练过程,模型网络结构也没有考虑低阶特征组合。

FNN模型的优缺点:

  • 优点
    1.引入DNN对特征进行更高阶组合,减少特征工程,能在一定程度上增强FM的学习能力,这种尝试为后续深度推荐模型的发展提供了新的思路。

  • 缺点
    1.两阶段训练模式,在应用过程中不方便,且模型能力受限于FM表征能力的上限。
    2.FNN专注于高阶组合特征,但是却没有对低阶特征进行建模。

完整源码&技术交流

技术要学会分享、交流,不建议闭门造车。一个人走的很快、一堆人可以走的更远。

文章中的完整源码、资料、数据、技术交流提升, 均可加知识星球交流群获取,群友已超过2000人,添加时切记的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、添加微信号:mlc2060,备注:来自 获取推荐资料
方式②、微信搜索公众号:机器学习社区,后台回复:推荐资料

部分代码

模型部分代码:

import torch
import torch.nn as nn
from BaseModel.basemodel import BaseModel

class FNN(BaseModel):
    def __init__(self, config, dense_features_cols, sparse_features_cols):
        super(FNN, self).__init__(config)
        # 稠密和稀疏特征的数量
        self.num_dense_feature = dense_features_cols.__len__()
        self.num_sparse_feature = sparse_features_cols.__len__()

        # FNN的线性部分,对应 ∑WiXi
        self.embedding_layers_1 = nn.ModuleList([
            nn.Embedding(num_embeddings=feat_dim, embedding_dim=1)
                for feat_dim in sparse_features_cols
        ])

        # FNN的Interaction部分,对应∑∑<Vi,Vj>XiXj
        self.embedding_layers_2 = nn.ModuleList([
            nn.Embedding(num_embeddings=feat_dim, embedding_dim=config['embed_dim'])
                for feat_dim in sparse_features_cols
        ])

        # FNN的DNN部分
        self.hidden_layers = [self.num_dense_feature + self.num_sparse_feature*(config['embed_dim']+1)] + config['dnn_hidden_units']
        self.dnn_layers = nn.ModuleList([
            nn.Linear(in_features=layer[0], out_features=layer[1])\
                for layer in list(zip(self.hidden_layers[:-1], self.hidden_layers[1:]))
        ])
        self.dnn_linear = nn.Linear(self.hidden_layers[-1], 1, bias=False)

    def forward(self, x):
        # 先区分出稀疏特征和稠密特征,这里是按照列来划分的,即所有的行都要进行筛选
        dense_input, sparse_inputs = x[:, :self.num_dense_feature], x[:, self.num_dense_feature:]
        sparse_inputs = sparse_inputs.long()

        # 求出线性部分
        linear_logit = [self.embedding_layers_1[i](sparse_inputs[:, i]) for i in range(sparse_inputs.shape[1])]
        linear_logit = torch.cat(linear_logit, axis=-1)

        # 求出稀疏特征的embedding向量
        sparse_embeds = [self.embedding_layers_2[i](sparse_inputs[:, i]) for i in range(sparse_inputs.shape[1])]
        sparse_embeds = torch.cat(sparse_embeds, axis=-1)

        dnn_input = torch.cat((dense_input, linear_logit, sparse_embeds), dim=-1)

        # DNN 层
        dnn_output = dnn_input
        for dnn in self.dnn_layers:
            dnn_output = dnn(dnn_output)
            dnn_output = torch.tanh(dnn_output)
        dnn_logit = self.dnn_linear(dnn_output)

        # Final
        y_pred = torch.sigmoid(dnn_logit)

        return y_pred

注意这里实现的FNN模型跟论文中的并不完全一样。论文中描述的FNN模型的第一层的参数是通过预训练好的FM来初始化的,因此模型需要分为两个阶段来训练。这里为了简化,直接使用了两个Embedding层来代替FM中应该学习得到的参数,使得网络可以以端到端的方式训练,简化代码实现。

测试部分代码:

import torch
from FNN.network import FNN
import torch.utils.data as Data
from DeepCrossing.trainer import Trainer
from Utils.criteo_loader import getTestData, getTrainData

fnn_config = \
{
    'embed_dim': 8, # 用于控制稀疏特征经过Embedding层后的稠密特征大小
    'dnn_hidden_units': [128, 128],
    'num_epoch': 150,
    'batch_size': 64,
    'lr': 1e-3,
    'l2_regularization': 1e-4,
    'device_id': 0,
    'use_cuda': False,
    'train_file': '../Data/criteo/processed_data/train_set.csv',
    'fea_file': '../Data/criteo/processed_data/fea_col.npy',
    'validate_file': '../Data/criteo/processed_data/val_set.csv',
    'test_file': '../Data/criteo/processed_data/test_set.csv',
    'model_name': '../TrainedModels/FNN.model'
}

if __name__ == "__main__":
    ####################################################################################
    # FNN 模型
    ####################################################################################
    training_data, training_label, dense_features_col, sparse_features_col = getTrainData(fnn_config['train_file'], fnn_config['fea_file'])
    train_dataset = Data.TensorDataset(torch.tensor(training_data).float(), torch.tensor(training_label).float())

    test_data = getTestData(fnn_config['test_file'])
    test_dataset = Data.TensorDataset(torch.tensor(test_data).float())

    fnn = FNN(fnn_config, dense_features_cols=dense_features_col, sparse_features_cols=sparse_features_col)
    print(fnn)
    ####################################################################################
    # 模型训练阶段
    ####################################################################################
    # # 实例化模型训练器
    trainer = Trainer(model=fnn, config=fnn_config)
    # 训练
    trainer.train(train_dataset)
    # 保存模型
    trainer.save()

    ####################################################################################
    # 模型测试阶段
    ####################################################################################
    fnn.eval()
    if fnn_config['use_cuda']:
        fnn.loadModel(map_location=lambda storage, loc: storage.cuda(fnn_config['device_id']))
        fnn = fnn.cuda()
    else:
        fnn.loadModel(map_location=torch.device('cpu'))

    y_pred_probs = fnn(torch.tensor(test_data).float())
    y_pred = torch.where(y_pred_probs>0.5, torch.ones_like(y_pred_probs), torch.zeros_like(y_pred_probs))
    print("Test Data CTR Predict...\n ", y_pred.view(-1))

在criteo数据集上的部分测试结果,输出的是每一个测试数据的点击率预估:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/488950.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

如何利用 Kotlin 特性封装 DataStore

Jetpack DataStore是一种数据存储解决方案&#xff0c;由于使用了 Kotlin 协程或者 RxJava 以异步、一致的事务方式存储数据&#xff0c;用法相较于其它存储方案 (SharedPreferences、MMKV) 会更加特别&#xff0c;所以目前网上都没有什么比较好的 DataStore 封装。 个人了解了…

(十)Shapefile文件创建——创建Shapefile和dBASE

&#xff08;十&#xff09; Shapefile文件创建——创建Shapefile和dBASE ArcCatalog 可以创建新的 Shapefile 和 dBASE表&#xff0c;并可进行属性项及索引的操作定义 Shapefile 的坐标系统。当在目录中改变 Shapefile 的结构和特性 (Properties)时必须使用 ArcMap 来更新或重…

动态规划 --- 01背包

动态规划 — 01背包 一直到现在都非常害怕动态规划&#xff0c;因为基本上自己都无法想出dp递推式&#xff0c;太难受了 T.T 今天再一次遇到了需要写01背包的情况&#xff0c;根据自己学习的一点点经历&#xff0c;再稍微总结一下01背包吧&#xff0c;虽然是个被认为dp入门的…

自学Python必须知道的优秀社区

国内学习Python网站&#xff1a; 知乎学习平台&#xff1a;Python - 基础入门 - 知学堂黑马程序员视频库&#xff1a;大数据学习路线2023版-黑马程序员大数据学习路线图菜鸟教程&#xff1a;菜鸟教程 - 学的不仅是技术&#xff0c;更是梦想&#xff01;极客学院&#xff1a;极…

香港服务器租用攻略:如何优化用户体验?

服务器是网站、应用程序和其他在线内容的核心&#xff0c;对于在线业务来说是至关重要的。如今&#xff0c;随着互联网的普及和数字化转型&#xff0c;越来越多的企业选择在香港租用服务器&#xff0c;以满足其业务需求。但是&#xff0c;租用服务器并不仅仅是选择一个服务商并…

让chatGPT给我写一个CSS,我太蠢了

前言 CSS这东西&#xff0c;让AI写的确有点难度&#xff0c;毕竟它写出来的东西&#xff0c;没办法直接预览&#xff0c;这是其次。重要的是CSS这东西怎么描述&#xff0c;不好描述啊&#xff0c;比如我让他给我制作一个这样的效果出来&#xff0c;没办法描述&#xff0c;所以…

AcWIng1085. 不要62(数位DP)

文章目录 一、问题二、分析三、代码 一、问题 二、分析 这道题涉及的算法是数位DP。如果大家不懂数位DP的话&#xff0c;可以先去看作者之前的文章&#xff1a;第五十章 动态规划——数位DP模型 假设一个数 n n n&#xff0c;我们先求出从 1 1 1到 n n n当中&#xff0c;所有…

《花雕学AI》如何用ChatGPT提升工作效率:适合不同场合的实用技巧大全

实用技巧分类目录 一、最佳ChatGPT 4提示 二、最佳写作和内容创作ChatGPT提示 三、最佳趣味性ChatGPT提示 四、最佳网络开发的ChatGPT提示 五、最佳音乐主题ChatGPT提示 六、最佳职业主题ChatGPT提示 七、最佳用于教育的ChatGPT提示 八、最佳用于市场营销的ChatGPT提示 九、最…

MEET开发者 | 选择和努力一样重要,专访杭州三汇测试工程师齐雪莲

「MEET开发者」栏目的第二期嘉宾是来自杭州三汇的测试工程师——齐雪莲。她是从小在新疆长大的甘肃人&#xff0c;10岁的时候回到了甘肃&#xff0c;大学又考回了新疆&#xff0c;在塔里木大学就读计算机科学与技术专业。 毕业后齐雪莲入职了三汇新疆办事处任测试一职&#xff…

电脑没有网络连接怎么办 电脑无法连接网络怎么解决

这个问题至少困扰我一周 目录 电脑没有网络连接怎么办? 方法一 方法二 方法三 方法四 方法五 方法六 电脑没有网络连接怎么办? 其中也包括了改IP。。电脑就是不好使 #include <iostream> using namespace std; int main(){system("netsh interface ip s…

日志收集系统:将应用产生的数据通过flume收集后发送到Kafka,整理后保存至hbase

目录 前言&#xff1a;功能描述 第一步&#xff1a;flume拉取日志数据&#xff0c;并将源数据保存至Kafka flume配置文件&#xff1a; users&#xff1a; user_friends_raw&#xff1a; events&#xff1a; train&#xff1a; 第二步&#xff1a;Kafka源数据处理 方式一…

滚珠螺杆在设备上的应用

滚珠螺杆跟直线导轨一样&#xff0c;是很多机械设备上不可或缺的重要部件&#xff0c;它是确保机器能够具备高加工精度的前提条件&#xff0c;因此本身对于精度的要求也相当地高。今天&#xff0c;我们就来了解一下滚珠螺杆在不同设备上的应用吧&#xff01; 1、大型的加工中心…

磁盘U盘变本地磁盘寻回教程

磁盘损坏怎么恢复&#xff1f;磁盘是我们工作、学习和生活中常用的信息存储工具&#xff0c;因为容量大、价格便宜而深受人们的喜爱&#xff0c;因此磁盘也成为了我们一些重要信息的信息载具。磁盘U盘变本地磁盘寻回教程这时我们该如何恢复我们丢失的数据呢&#xff1f;这个时候…

ubuntu 安装 notepad++,显示中文菜单,并解决中文乱码问题

1.安装notepad sudo snap install notepad-plus-plus sudo snap install wine-platform-runtime2. notepad中文乱码问题 安装完成之后&#xff0c;输入中文会显示“口口…”&#xff0c;实际上并不是缺少什么windows字库&#xff0c;而是刚安装好的notepad默认字体是Courier …

4月VR大数据:PICO平台应用近400款,领跑国内VR生态

Hello大家好&#xff0c;每月一期的VR内容/硬件大数据统计又和大家见面了。 想了解VR软硬件行情么&#xff1f;关注这里就对了。我们会统计Steam平台的用户及内容等数据&#xff0c;每月初准时为你推送&#xff0c;不要错过喔&#xff01; 本数据报告包含&#xff1a;Steam VR硬…

软件测试面试题最牛汇总,不会有人没有这份文档吧

常见的面试题汇总 1、你做了几年的测试、自动化测试&#xff0c;说一下 selenium 的原理是什么&#xff1f; 我做了五年的测试&#xff0c;1年的自动化测试&#xff1b; selenium 它是用 http 协议来连接 webdriver &#xff0c;客户端可以使用 Java 或者 Python 各种编程语言…

一个.Net版本的ChatGPT SDK

ChatGPT大火&#xff0c;用它来写代码、写表白书、写文章、写对联、写报告、写周边… 啥都会&#xff01; 个人、小公司没有能力开发大模型&#xff0c;但基于开放平台&#xff0c;根据特定的场景开发应用&#xff0c;却是非常火热的。 为了避免重复造轮子&#xff0c;今天给…

你真的会跟 ChatGPT 聊天吗?(上)

前言&#xff1a;即使你对文中提及的技术不大了解&#xff0c;你也可以毫无压力地看完这篇描述如何更好地获得 ChatGPT 生成内容的文章。因为我也是利用 Azure OpenAI 等认知服务来学习&#xff0c;然后就这样写出来的。所以&#xff0c;舒服地坐下来&#xff0c;慢慢看吧&…

网络计算模式复习(三)

云计算和网格技术的差别 相对于网格计算&#xff0c;在表现形式上&#xff0c;云计算拥有明显的特点&#xff1a; 低成本&#xff0c;这是最突出的特点虚拟机的支持&#xff0c;得在网络环境下的一些原来比较难做的事情现在比较容易处理镜像部署的执行&#xff0c;这样就能够…

【微服务 | 学成在线】项目易错重难点分析(媒资管理模块篇·下)

文章目录 视频处理视频编码和文件格式文件格式和视频编码方式区别ProcessBuilder分布式任务调度XXL-JOBXXL-JOB配置XXL-JOB使用分片广播技术方案视频处理方案及实现思路分布式锁 视频处理 视频编码和文件格式 什么是视频编码&#xff1f; 同时我们还要知道我们为什么要对视频…