推荐算法实战项目:AFM 原理以及案例实战(附完整 Python 代码)

news2025/1/17 22:04:46

本文要介绍的是由浙江大学联合新加坡国立大学提出的AFM模型。通过名字也可以看出,此模型又是基于FM模型的改进,其中A代表”Attention“,即AFM模型实际上是在FM模型中引入了注意力机制改进得来的。

之所以要在FM模型中引入注意力机制,是因为传统的FM模型对所有的交叉特征都平等对待,即每个交叉特征的权重都是相同的(都为1)。而在实际应用中,不同交叉特征的重要程度往往是不一样的。

如果”一视同仁“地对待所有的交叉特征,不考虑不同特征对结果的影响程度,事实上消解了大量有价值的信息。

AFM 论文地址:这里

推荐系统中的注意力机制

这里再举个例子,说明一下注意力机制是如何在推荐系统中派上用场的。注意力机制基于假设——不同的交叉特征对结果的影响程度不同,以更直观的业务场景为例,用户对不同交叉特征的关注程度应该是不同的。

举例来说,如果应用场景是预测一位男性用户是否会购买一款键盘的可能性,那么**”性别=男”“购买历史包含鼠标“这一交叉特征,很可能比”性别=男”“年龄=30“**这一交叉特征重要,模型应该投入更多的”注意力“在前面的特征上。

正因如此,将注意力机制引入推荐系统中也显得理所当然了。

模型

在介绍AFM模型之前,先给出FM模型的方程:

FM模型方程

Pair-wise 交互层

Pair-wise 每个交叉向量都是通过对两个不同的向量进行内积来计算的。可以通过以下公式来描述:

Attention-based Pooling层

下面看一下作者是如何将注意力机制加入到FM模型中去的,具体如下:

作者提出了通过MLP来参数化注意力分数,作者称之为”注意力网络“,其定义如下:

AFM模型

下面给出完整的AFM框架图:

AFM框架

AFM模型的整体方程为:

完整源码&技术交流

技术要学会分享、交流,不建议闭门造车。一个人走的很快、一堆人可以走的更远。

文章中的完整源码、资料、数据、技术交流提升, 均可加知识星球交流群获取,群友已超过2000人,添加时切记的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、添加微信号:mlc2060,备注:来自 获取推荐资料
方式②、微信搜索公众号:机器学习社区,后台回复:推荐资料

代码实践

模型部分:

import torch
import torch.nn as nn
from BaseModel.basemodel import BaseModel

class AFM(BaseModel):
    def __init__(self, config, dense_features_cols, sparse_features_cols):
        super(AFM, self).__init__(config)
        self.num_fields = config['num_fields']
        self.embed_dim = config['embed_dim']
        self.l2_reg_w = config['l2_reg_w']

        # 稠密和稀疏特征的数量
        self.num_dense_feature = dense_features_cols.__len__()
        self.num_sparse_feature = sparse_features_cols.__len__()

        # AFM的线性部分,对应 ∑W_i*X_i, 这里包含了稠密和稀疏特征
        self.linear_model = nn.Linear(self.num_dense_feature + self.num_sparse_feature, 1)

        # AFM的Embedding层,只是针对稀疏特征,有待改进。
        self.embedding_layers = nn.ModuleList([
            nn.Embedding(num_embeddings=feat_dim, embedding_dim=config['embed_dim'])
                for feat_dim in sparse_features_cols
        ])

        # Attention Network
        self.attention = torch.nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.projection = torch.nn.Linear(self.embed_dim, 1, bias=False)
        self.attention_dropout = nn.Dropout(config['dropout_rate'])

        # prediction layer
        self.predict_layer = torch.nn.Linear(self.embed_dim, 1)

    def forward(self, x):
        # 先区分出稀疏特征和稠密特征,这里是按照列来划分的,即所有的行都要进行筛选
        dense_input, sparse_inputs = x[:, :self.num_dense_feature], x[:, self.num_dense_feature:]
        sparse_inputs = sparse_inputs.long()

        # 求出线性部分
        linear_logit = self.linear_model(x)

        # 求出稀疏特征的embedding向量
        sparse_embeds = [self.embedding_layers[i](sparse_inputs[:, i]) for i in range(sparse_inputs.shape[1])]
        sparse_embeds = torch.cat(sparse_embeds, axis=-1)
        sparse_embeds = sparse_embeds.view(-1, self.num_sparse_feature, self.embed_dim)

        # calculate inner product
        row, col = list(), list()
        for i in range(self.num_fields - 1):
            for j in range(i + 1, self.num_fields):
                row.append(i), col.append(j)
        p, q = sparse_embeds[:, row], sparse_embeds[:, col]
        inner_product = p * q

        # 通过Attention network得到注意力分数
        attention_scores = torch.relu(self.attention(inner_product))
        attention_scores = torch.softmax(self.projection(attention_scores), dim=1)

        # dim=1 按行求和
        attention_output = torch.sum(attention_scores * inner_product, dim=1)
        attention_output = self.attention_dropout(attention_output)

        # Prodict Layer
        # for regression problem with MSELoss
        y_pred = self.predict_layer(attention_output) + linear_logit
        # for classifier problem with LogLoss
        # y_pred = torch.sigmoid(y_pred)
        return y_pred

在criteo数据集上测试,测试代码如下:

import torch
from AFM.network import AFM
from DeepCrossing.trainer import Trainer
import torch.utils.data as Data
from Utils.criteo_loader import getTestData, getTrainData

afm_config = \
{
    'num_fields': 26, # 这里配置的只是稀疏特征的个数
    'embed_dim': 8, # 用于控制稀疏特征经过Embedding层后的稠密特征大小
    'seed': 1024,
    'l2_reg_w': 0.001,
    'dropout_rate': 0.1,
    'num_epoch': 200,
    'batch_size': 64,
    'lr': 1e-3,
    'l2_regularization': 1e-4,
    'device_id': 0,
    'use_cuda': False,
    'train_file': '../Data/criteo/processed_data/train_set.csv',
    'fea_file': '../Data/criteo/processed_data/fea_col.npy',
    'validate_file': '../Data/criteo/processed_data/val_set.csv',
    'test_file': '../Data/criteo/processed_data/test_set.csv',
    'model_name': '../TrainedModels/AFM.model'
}

if __name__ == "__main__":
    ####################################################################################
    # AFM 模型
    ####################################################################################
    training_data, training_label, dense_features_col, sparse_features_col = getTrainData(afm_config['train_file'], afm_config['fea_file'])
    train_dataset = Data.TensorDataset(torch.tensor(training_data).float(), torch.tensor(training_label).float())

    test_data = getTestData(afm_config['test_file'])
    test_dataset = Data.TensorDataset(torch.tensor(test_data).float())

    afm = AFM(afm_config, dense_features_cols=dense_features_col, sparse_features_cols=sparse_features_col)
    ####################################################################################
    # 模型训练阶段
    ####################################################################################
    # # 实例化模型训练器
    trainer = Trainer(model=afm, config=afm_config)
    # 训练
    trainer.train(train_dataset)
    # 保存模型
    trainer.save()

    ####################################################################################
    # 模型测试阶段
    ####################################################################################
    afm.eval()
    if afm_config['use_cuda']:
        afm.loadModel(map_location=lambda storage, loc: storage.cuda(afm_config['device_id']))
        afm = afm.cuda()
    else:
        afm.loadModel(map_location=torch.device('cpu'))

    y_pred_probs = afm(torch.tensor(test_data).float())
    y_pred = torch.where(y_pred_probs>0.5, torch.ones_like(y_pred_probs), torch.zeros_like(y_pred_probs))
    print("Test Data CTR Predict...\n ", y_pred.view(-1))

点击率预估结果如下(预测用户会点击输出为1,反之为0):

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/487607.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PC3-管理员操作

token无效可能,就是token过期了需要配置::: history 安装可以跳路由在ts文件中:因为在ts文件中还需要store,清空token // 安装可以跳路由在ts文件中npm i history 防止接口出现 token 无效,登…

【C++】AVL平衡二叉树源码剖析

目录 概述 算法 左单旋 右单旋 左右双旋 右左双旋 源码 AVLTree.h test.cpp 概述 AVL树也叫平衡二叉搜索树,是二叉搜索树的进化版,设计是原理是弥补二叉搜索树的缺陷:当插入的数据接近于有序数列时,二叉搜索树的性能严重…

20天能拿下PMP吗?

新版大纲,专注于人员、过程、业务环境三个领域,内容贯穿价值交付范围(包括预测、敏捷和混合的方法)。除了考试时间由240分钟变更为230分钟、200道单选题变为180道(包含单选和多选)之外,新考纲还…

【Ubuntu18配置Anaconda深度学习环境】

参考:Ubuntu18配置与ROS 兼容的深度学习环境(Anaconda3PyTorch1.10python3.8cuda10.2) 1. 前言 之前在Window下安装了Anaconda,熟悉了一下安装过程,Ubuntu18.04下最难的应该就是和ROS Melodic的兼容问题。ROS1是基于P…

Linux常用命令——inotifywait命令

在线Linux命令查询工具 inotifywait 异步文件系统监控机制 补充说明 Inotify一种强大的、细粒度的、异步文件系统监控机制,它满足各种各样的文件监控需要,可以监控文件系统的访问属性、读写属性、权限属性、删除创建、移动等操作,也就是可…

数据结构之第十一章、排序算法

一、排序的概念及引用 1.1排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 1.1.1排序的稳定性 稳定性:假定在待排序的记录序列中,存在多个具…

数据结构之第九章、优先级队列(堆)

目录 一、优先级队列 1.1概念 二、优先级队列的模拟实现 2.1堆的概念 2.2堆的存储方式 2.3堆的创建 2.3.1堆向下调整 2.3.2堆的创建 2.3.3建堆的时间复杂度 2.4堆的插入与删除 2.4.1堆的插入 2.4.2堆的删除 2.5用堆模拟实现优先级队列 三、常用接口介绍 3.1Priori…

计算机组成原理与体系结构

目录 第一章、计算机组成原理与体系结构1、数据的表示1.1.进制转换1.2.原码、反码、补码、移码1.3.数据的表述 2、计算机结构3、Flynn分类法4、CISC与RISC5、流水线技术5.1、流水线的基本概念5.2、流水线的计算5.3、流水线吞吐率计算5.4、流水线加速比计算5.5、流水线的效率 6、…

Python小姿势 - Python的多线程编程

Python的多线程编程 Python的多线程编程提供了一个非常简单的方法来让一个Python程序同时运行多个任务。这个方法通过创建新的线程来实现,线程可以被视为一个单独的执行流程。 为了创建一个新线程,我们需要使用Python的_thread模块中的start_new_thread(…

【IDEA】简单入门:请求数据库表数据

目录 修改编辑与控制台字体大小 二、sprintboot项目入门 【1】直接开始配置Controller 【2】直接请求数据库中的数据,返回json格式 (0)整合PostgreSQL框架 (2)实体entity类 (3)控制类Mai…

快速了解车联网V2X通信

自动驾驶拥有极其巨大的潜力,有可能改变我们的出行方式。它不仅有望永远改变车辆的设计和制造,还会永远改变汽车的所有权乃至整个交通运输业务。要实现全自动驾驶的目标,开发人员需要开发极为复杂的软件,软件中融入的人工智能(AI)…

机械硬盘和固态硬盘有什么区别?如何使用?

案例:怎么区分机械硬盘和固态硬盘? 【我知道硬盘可以用来储存数据,但我不知道机械硬盘和固态硬盘的区别,有没有小伙伴可以详细解释一下。】 硬盘可以用来储存数据,常见的硬盘有两种,分别是机器硬盘和固态…

C++11多线程编程——线程池的实现

学一门新技术,还是要问那个问题,为什么我们需要这个技术,这个技术能解决什么痛点。 一、为何需要线程池 那么为什么我们需要线程池技术呢?多线程编程用的好好的,干嘛还要引入线程池这个东西呢?引入一个新的…

发展文旅夜游项目有哪些好处

夜晚的城市,总是充满着无限的魅力和活力,而文旅夜游更是让这份魅力和活力得到了更好的展现和发挥。新起典文旅科技认为文旅夜游不仅仅是一种旅游方式,更是可以增加城市夜间经济、丰富文化娱乐生活、缓解白天拥堵、提高旅游体验、促进文化交流…

HTTP的特点

灵活可扩展 HTTP 协议最初诞生的时候就比较简单,本着开放的精神只规定了报文的基本格式,比如用空格分隔单词,用换行分隔字段,“headerbody”等,报文里的各个组成部分都没有做严格的语法语义限制,可以由开发…

大厂面试NLP工程师,会考察你哪些方面的能力?

你好,我是周磊。 相信你已经知道,一名AI算法工程师,不但需要基础能力扎实,更要具备良好的工程落地能力。那在NLP工程师面试的时候,你知道面试官会从哪些维度去考察你这两方面的能力吗? 今天我就结合我的一…

一种用于大坝水库边坡内部振弦式应变计组

1用途 多向应变计组适用于长期埋设在水工结构物或其它混凝土结构物内,测量结构物内部各个方向上的应变量,并可同步测量埋设点的温度。 应变计按方向和支数安装在应变计安装支座上,组成多向应变计组,用于测量大体积混凝土中各方向…

SpringCloud------热部署(三)

SpringCloud------热部署(三) Devtools是热部署插件,引入热部署实现高效自测。 步骤: 1.Adding devtools to your project 2.Adding plugin to your project 3.Enabling automatic build 4.Update the value of 点击 ctrlshiftal…

大型互联网企业大流量高并发电商领域核心项目已上线(完整流程+项目白皮书)

说在前面的话 面对近年来网络的飞速发展,大家已经都习惯了网络购物,从而出现了一些衍生品例如:某宝/某东/拼夕夕等大型网站以及购物APP~ 并且从而导致很多大型互联网企业以及中小厂都需要有完整的项目经验,以及优秀处理超大流量…

Mysql数据库迁移|如何把一台服务器的mysql数据库迁移到另一台服务器上的myql中

前言 那么这里博主先安利一下一些干货满满的专栏啦! Linux专栏https://blog.csdn.net/yu_cblog/category_11786077.html?spm1001.2014.3001.5482操作系统专栏https://blog.csdn.net/yu_cblog/category_12165502.html?spm1001.2014.3001.5482手撕数据结构https:/…