deepfm内容理解

news2024/9/23 17:17:15

对于CTR问题,被证明的最有效的提升任务表现的策略是特征组合(Feature Interaction);

两个问题:

如何更好地学习特征组合,进而更加精确地描述数据的特点;

如何更高效的学习特征组合。

DNN局限 :当我们使用DNN网络解决推荐问题的时候存在网络参数过于庞大的问题,这是因为在进行特征处理的时候我们需要使用one-hot编码来处理离散特征,这会导致输入的维度猛增。

为了解决DNN参数量过大的局限性,可以采用非常经典的Field思想,将OneHot特征转换为Dense Vector,通过增加全连接层就可以实现高阶的特征组合。

黑色的线 和 红色的线 进行concat

self定义 

deep_features = deep_features
fm_features = fm_features  #稀疏的特征
deep_dims = sum([fea.embed_dim for fea in deep_features])  #8
fm_dims = sum([fea.embed_dim for fea in fm_features])  #368   = 23*16           #稀疏的特征embedding化
linear = LR(fm_dims)  # 1-odrder interaction   低阶信息   (fc): Linear(in_features=368, out_features=1, bias=True)
fm = FM(reduce_sum=True)  # 2-odrder interaction    #FM将一阶特征和二阶特征cancat
embedding = EmbeddingLayer(deep_features + fm_features)
mlp = MLP(deep_dims, **mlp_params)

 forward


input_deep = embedding(x, deep_features, squeeze_dim=True)  #[batch_size, deep_dims]    torch.Size([10, 8])
input_fm = embedding(x, fm_features, squeeze_dim=False)  #[batch_size, num_fields, embed_dim]   torch.Size([10, 23, 16])
y_linear = linear(input_fm.flatten(start_dim=1))  #torch.Size([10, 1])  对应的稀疏特征 经过线性层变为1
y_fm = fm(input_fm)  #torch.Size([10, 1])    #对稀疏特征做一阶 二阶处理 
y_deep = mlp(input_deep)  #[batch_size, 1]  #torch.Size([10, 1])
y = y_linear + y_fm + y_deep          
# return torch.sigmoid(y.squeeze(1))

定义的一些函数: 

import torch.nn as nn
class LR(nn.Module):
    """Logistic Regression Module. It is the one Non-linear 
    transformation for input feature.

    Args:
        input_dim (int): input size of Linear module.
        sigmoid (bool): whether to add sigmoid function before output.

    Shape:
        - Input: `(batch_size, input_dim)`
        - Output: `(batch_size, 1)`
    """

    def __init__(self, input_dim, sigmoid=False):
        super().__init__()
        self.sigmoid = sigmoid
        self.fc = nn.Linear(input_dim, 1, bias=True)

    def forward(self, x):
        if self.sigmoid:
            return torch.sigmoid(self.fc(x))
        else:
            return self.fc(x)
        

class FM(nn.Module):
    """The Factorization Machine module, mentioned in the `DeepFM paper
    <https://arxiv.org/pdf/1703.04247.pdf>`. It is used to learn 2nd-order 
    feature interactions.

    Args:
        reduce_sum (bool): whether to sum in embed_dim (default = `True`).

    Shape:
        - Input: `(batch_size, num_features, embed_dim)`
        - Output: `(batch_size, 1)`` or ``(batch_size, embed_dim)`
    """

    def __init__(self, reduce_sum=True):
        super().__init__()
        self.reduce_sum = reduce_sum

    def forward(self, x):
        square_of_sum = torch.sum(x, dim=1)**2
        sum_of_square = torch.sum(x**2, dim=1)
        ix = square_of_sum - sum_of_square
        if self.reduce_sum:
            ix = torch.sum(ix, dim=1, keepdim=True)
        return 0.5 * ix

参考资料:

推荐系统遇上深度学习(三)--DeepFM模型理论和实践 - 简书 (jianshu.com)

DeepFM (datawhalechina.github.io)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/986350.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

ESXI主机扩容(VCSA)

原因分析SCSI扩容 VMware为ESXI虚拟机硬盘扩容(需要先关闭ESXI) ESXI扩容前ESXI扩容 https://blog.csdn.net/tongxin_tongmeng/article/details/132652423 ESXI扩容后

2023年高教社杯数学建模国赛 赛题浅析

2023年国赛如期而至&#xff0c;为了方便大家尽快确定选题&#xff0c;这里将对赛题进行浅析&#xff0c;以分析赛题的主要难点、出题思路以及选择之后可能遇到的难点进行说明&#xff0c;方便大家尽快确定选题。 难度排序 B>A>C 选题人数 C>A>B (预估结果&…

软件工程课件

软件工程 考点概述软件工程概述能力成度模型能力成熟度模型集成软件过程模型逆向工程![ ](https://img-blog.csdnimg.cn/425cea8190fb4c5ab2bf7be5e2ad990e.png) 考点概述 重点章节 软件工程概述 之前老版教程的&#xff0c;之前考过 能力成度模型 记忆 能力等级 和 特点 能力…

【Java】基础练习 --- Stream练习

1.拼接 给定一个字符串数组,使用 Stream 把所有字符串拼接成一个字符串。 String[] arr {"a", "b", "c"}; 输出: abc&#xff08;1&#xff09;源码&#xff1a; package swp.kaifamiao.codes.Java.d0907;import java.util.stream.Stream;/*…

2023高教社杯全国大学生数学建模竞赛E题代码解析

2023高教社杯全国大学生数学建模竞赛E题 黄河水沙监测数据分析 代码解析 因为一些不可抗力&#xff0c;下面仅展示部分python代码&#xff08;第一问的部分&#xff09;&#xff0c;其余代码看文末 首先导入包&#xff1a; import numpy as np import pandas as pd import m…

【Python】matplotlib分格显示

参考&#xff1a;matplotlib图形整合之多个子图一起绘制_matplotlib多子图_王小王-123的博客-CSDN博客 方式一&#xff1a; import matplotlib.pyplot as plt import matplotlib.gridspec as gridspecplt.figure() # 方式一: gridspec # rowspan:行的跨度&#xff0c;colspan…

【代码随想录】Day 49 动态规划10 (买卖股票Ⅰ、Ⅱ)

买卖股票的最佳时机 https://leetcode.cn/problems/best-time-to-buy-and-sell-stock/ dp[i]表示在第i天时&#xff0c;卖/不卖股票能获得的最大利润&#xff1a; 1、卖股票&#xff1a;dp[i] prices[i] -minPrice&#xff08;i天以前的最低价格&#xff09; 2、不卖股票&am…

Python web自动化测试 —— 文件上传

文件上传三种方式&#xff1a; &#xff08;一&#xff09;查看元素标签&#xff0c;如果是input&#xff0c;则可以参照文本框输入的形式进行文件上传 方法&#xff1a;和用户输入是一样的&#xff0c;使用send_keys 1 2 3 4 5 步骤&#xff1a;1、找到定位元素&#xff0c…

ICCV2021 Exploring Cross-Image Pixel Contrast for Semantic Segmentation (Oral)

Exploring Cross-Image Pixel Contrast for Semantic Segmentation 探索语义分割的跨图像像素对比度 Paper&#xff1a;https://openaccess.thecvf.com/content/ICCV2021/html/Wang_Exploring_Cross-Image_Pixel_Contrast_for_Semantic_Segmentation_ICCV_2021_paper.html Co…

k8s 入门到实战--部署应用到 k8s

k8s 入门到实战 01.png 本文提供视频版&#xff1a; 背景 最近这这段时间更新了一些 k8s 相关的博客和视频&#xff0c;也收到了一些反馈&#xff1b;大概分为这几类&#xff1a; 公司已经经历过服务化改造了&#xff0c;但还未接触过云原生。公司部分应用进行了云原生改造&…

Spring MVC拦截器

拦截器&#xff08;Interceptor&#xff09;是 Spring MVC 提供的一种强大的功能组件。它可以对用户请求进行拦截&#xff0c;并在请求进入控制器&#xff08;Controller&#xff09;之前、控制器处理完请求后、甚至是渲染视图后&#xff0c;执行一些指定的操作。 在 Spring MV…

springboot设置热部署

一、常见的三种方式&#xff1a; Springboot中常见的热部署方式有3种: 1.使用springloaded配置pom.xml文件&#xff0c;使用mvn spring-boot:run启动 2.使用springloaded本地加载启动&#xff0c;配置jvm参数 3.使用devtools工具包&#xff0c;操作简单&#xff0c;但是每次…

python开发基础篇1——后端操作K8s API方式

文章目录 一、基本了解1.1 操作k8s API1.2 基本使用 二、数据表格展示K8s常见资源2.1 Namespace2.2 Node2.3 PV2.4 Deployment2.5 DaemonSet2.6 StatefulSet2.7 Pod2.8 Service2.9 Ingress2.10 PVC2.11 ConfigMap2.12 Secret2.13 优化 一、基本了解 操作K8s资源api方式&#xf…

第11篇:ESP32vscode_platformio_idf框架helloworld点亮LED

第1篇:Arduino与ESP32开发板的安装方法 第2篇:ESP32 helloword第一个程序示范点亮板载LED 第3篇:vscode搭建esp32 arduino开发环境 第4篇:vscodeplatformio搭建esp32 arduino开发环境 ​​​​​​第5篇:doit_esp32_devkit_v1使用pmw呼吸灯实验 第6篇:ESP32连接无源喇叭播…

酷克数据推出AI开发工具箱HashML 加速企业级AI应用落地投产

近日&#xff0c;业界领先的国产企业级云数仓厂商酷克数据发布了下一代In-Database高级分析和数据科学工具箱HashML&#xff0c;在业内率先实现为企业提供随数仓部署一步到位、开箱即用的AI能力。 在数字经济时代&#xff0c;描述性分析已经非常成熟并被企业广泛采纳。然而&am…

vue3:3、项目目录和关键文件

关于vsvode的更改 <!-- 加上setup允许在script中直接编写组合式api --> <script setup> // 组件引入后直接用 import HelloWorld from ./components/HelloWorld.vue import TheWelcome from ./components/TheWelcome.vue</script><!-- 1、js放在最上面&am…

JDK源码剖析之PriorityQueue优先级队列

写在前面 版本信息&#xff1a; JDK1.8 PriorityQueue介绍 在数据结构中&#xff0c;队列分为FIFO、LIFO 两种模型&#xff0c;分别为先进先出&#xff0c;后进后出、先进后出&#xff0c;后进先出&#xff08;栈&#xff09; 而一切数据结构都是基于数组或者是链表实现。 在…

线上问诊:可视化展示

系列文章目录 线上问诊&#xff1a;业务数据采集 线上问诊&#xff1a;数仓数据同步 线上问诊&#xff1a;数仓开发(一) 线上问诊&#xff1a;数仓开发(二) 线上问诊&#xff1a;数仓开发(三) 线上问诊&#xff1a;可视化展示 文章目录 系列文章目录前言一、全流程调度1.生产新…

两两交换链表中节点

给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点交换&#xff09;。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4] 输出&#xff1a;[2,1,4…

【python 多线程】初体验+单线程下载器+并行下载器

1.多线程初体验 主线程的id和进程的id是一个 查看进程pid下有多少个线程 ps -T -p pid(base) D:\code\python_project\python_coroutine>C:/ProgramData/Anaconda3/python.exe d:/code/python_project/python_coroutine/01demo.py threading.active_count1 i am producer…