Deep Crossing:深度交叉网络在推荐系统中的应用

news2025/2/5 18:47:32

实验和完整代码

完整代码实现和jupyter运行:https://github.com/Myolive-Lin/RecSys--deep-learning-recommendation-system/tree/main

引言

在机器学习和深度学习领域,特征工程一直是一个关键步骤,尤其是对于大规模的推荐系统和广告点击率预测任务。传统的特征工程通常依赖于手动设计的组合特征,这些特征虽然有效,但在大规模数据场景下,其开发和维护成本极高。Deep Crossing 是一种新型的深度学习模型,能够自动学习特征组合,无需手动设计组合特征,从而在大规模数据上实现高效建模。

背景知识

Deep Crossing 是由微软研究院提出的一种深度神经网络模型,专门用于处理大规模稀疏特征数据。该模型的核心思想是通过嵌入层(Embedding Layer)、残差单元(Residual Units)和评分层(Scoring Layer)自动学习特征之间的复杂交互关系。Deep Crossing 的主要贡献在于它能够自动发现重要的特征组合,而无需依赖于手动设计的组合特征。

1. 模型结构

Deep Crossing 的网络结构主要包括以下几个部分:

  1. Embedding 层
    • 将稀疏的类别特征嵌入到低维的稠密向量中。每个类别特征都有一个对应的嵌入矩阵,嵌入矩阵的大小为 (类别数, 嵌入维度)
    • 例如,对于用户 ID 和项目 ID 等类别特征,可以将其嵌入到一个低维的稠密向量中,以便神经网络能够更好地处理。
  2. 残差单元(Residual Units)
    • 残差单元是 Deep Crossing 的核心部分,用于学习特征之间的复杂交互关系。每个残差单元包含两个全连接层(nn.Linear),中间通过非线性激活函数(ReLU)和批量归一化(BatchNorm)进行处理。
    • 残差单元的输出通过残差连接(Residual Connection)与输入相加,从而保留了输入的特征信息,避免了梯度消失问题。
  3. 评分层(Scoring Layer)
    • 评分层是一个全连接层,用于将经过残差单元处理后的特征向量映射到最终的预测值。输出层通常使用 Sigmoid 函数将输出值映射到 [0, 1] 范围内,表示预测的概率。

模型结构如下:

其中Feature #1 和 Features #n都是分类型数据,Feature #2是数值型数据

残差模块结构如下:

随着网络的加深,梯度在反向传播过程中可能会逐渐衰减(梯度消失)或指数级增长(梯度爆炸)。残差连接(Residual Connection) 通过 恒等映射(Identity Mapping),使梯度可以直接沿着跳跃连接传播,从而减轻梯度消失或爆炸的问题。这对于深度神经网络(DNN)而言尤为重要。

数学上,假设残差模块的输入为 x \mathbf{x} x,非线性变换为 F ( x ) F(\mathbf{x}) F(x),则输出为:

y = F ( x ) + x y=F(x)+x y=F(x)+x

这样,在反向传播时,梯度可以通过 F ( x ) F(\mathbf{x}) F(x) 传播,也可以通过恒等映射直接传播:

∂ y ∂ x = ∂ F ( x ) ∂ x + 1 \frac{\partial \mathbf{y}}{\partial \mathbf{x}} = \frac{\partial F(\mathbf{x})}{\partial \mathbf{x}}+ 1 xy=xF(x)+1

这保证了梯度不会因层数加深而过度衰减。


此外,从模型的表达能力来看,由于残差模块能够直接建模

F ( x ) = H ( x ) − x F(x) = H(x) - x F(x)=H(x)x

模型学习的是输入和输出之间的残差,而不是直接拟合输出 H ( X ) H(X) H(X),使得模型更容易优化,也能学习到更复杂的特征交互关系。

2. 模型理论框架

2.1 整体架构

Deep Crossing采用经典的Embedding+MLP范式,其数学表达为:

y ^ = σ ( W ( L ) ⋅ h ( L − 1 ) + b ( L ) ) \hat{y} = \sigma(W^{(L)} \cdot h^{(L-1)} + b^{(L)}) y^=σ(W(L)h(L1)+b(L))

其中 h ( l ) h^{(l)} h(l)表示第 l l l层隐藏状态,包含以下核心组件:

1. 特征嵌入层

​ 对类别型特征 c i ∈ R d i c_i \in \mathbb{R}^{d_i} ciRdi进行降维:

e i = E i T c i , E i ∈ R d i × k e_i = E_i^T c_i, \quad E_i \in \mathbb{R}^{d_i \times k} ei=EiTci,EiRdi×k

​ 数值型特征直接标准化处理:

v j = x j − μ j σ j v_j = \frac{x_j - \mu_j}{\sigma_j} vj=σjxjμj

2. 特征堆叠层

​ 将各特征向量拼接:

h ( 0 ) = [ e 1 ; e 2 ; . . . ; e m ; v 1 ; v 2 ; . . . ; v n ] h^{(0)} = [e_1; e_2; ...; e_m; v_1; v_2; ...; v_n] h(0)=[e1;e2;...;em;v1;v2;...;vn]

3. 残差层

采用改进的残差单元(受ResNet启发):

h ( l ) = f ( W 2 ( l ) ⋅ ReLU ( W 1 ( l ) h ( l − 1 ) + b 1 ( l ) ) + b 2 ( l ) ) + h ( l − 1 ) h^{(l)} = f(W_2^{(l)} \cdot \text{ReLU}(W_1^{(l)} h^{(l-1)} + b_1^{(l)}) + b_2^{(l)}) + h^{(l-1)}\\ h(l)=f(W2(l)ReLU(W1(l)h(l1)+b1(l))+b2(l))+h(l1)
其中f为激活函数,实验表明ReLU效果最优。

4. 评分层

最终预测层实现为:

p = sigmoid ( W ( L ) h ( L − 1 ) + b ( L ) ) p = \text{sigmoid}(W^{(L)} h^{(L-1)} + b^{(L)}) p=sigmoid(W(L)h(L1)+b(L))

3. 代码实现

残差模块

#残差网络块
class ResidualUnit(nn.Module):
    def __init__(self, input_dim, hidden_dim, dropout_rate):
        super(ResidualUnit, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, input_dim),
            nn.BatchNorm1d(input_dim),
            nn.Dropout(dropout_rate)
        )
        self.relu = nn.ReLU()
    
    def forward(self, x):
        residual = self.layers(x)
        return self.relu(x + residual)
    

Deep Crossing模块

class DeepCrossing(nn.Module):
    def __init__(self, cat_sizes, num_sizes, config):
        super(DeepCrossing, self).__init__()

        #Embedding层
        self.embeddings = nn.ModuleList([
            nn.Embedding(size, config.embedding_dim ) for size in cat_sizes #生成对应 Embedding层    
        ])

        #计算总特征维度
        total_dim = len(cat_sizes) * config.embedding_dim + num_sizes

        #多层Residual units
        self.res_uint = nn.Sequential()
        for _ in range(config.num_residual_units):
            self.res_uint.append(
                ResidualUnit(total_dim, config.hidden_dim, config.dropout_rate)
            )

        #scoring层
        self.fc = nn.Linear(total_dim,1)

    def forward(self, x_cat, x_num):
        #处理类别特征,注意x_cat 每一列都是一个类别特征,采用类似Ordinal Encoder
        embeddings = []
        for i in range(len(self.embeddings)):
            embeddings.append(self.embeddings[i](x_cat[:,i]))
        
        x = torch.cat(embeddings, dim = 1) #拼接起来

        #拼接数值特征
        x = torch.cat([x,x_num], dim = 1)

        #残差单元
        x = self.res_uint(x)

        #输出层
        return torch.sigmoid(self.fc(x)).squeeze()
        
        

4. 实验

由于没有合适的数据,使用sklearn中make_classification方法生成的数据进行实验如下:
在这里插入图片描述

实验结果表明,Deep Crossing 模型在训练和测试集上都表现良好,损失逐渐减小,AUC 分数逐渐提高,且训练和测试结果接近,说明模型能够有效地学习特征之间的交互关系,并具有良好的泛化能力。这些结果验证了 Deep Crossing 模型在处理大规模稀疏数据和自动特征学习方面的优势。

总结

Deep Crossing 通过 Residual Network 深度建模特征交互,避免了手工特征工程的复杂性,并在 CTR 预估等任务中表现优异。相比于传统神经网络,残差结构的加入有效缓解了梯度消失问题,使得深度学习在推荐系统领域取得更大突破。

Reference

[1]. Y. Shan, T. R. Hoens, J. Jiao, H. Wang, D. Yu, and J. C. Mao, “Deep Crossing: Web-Scale Modeling without Manually Crafted Combinatorial Features,” in Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, 2016, pp. 255-262.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2293434.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

想品客老师的第十天:类

类是一个优化js面向对象的工具 类的声明 //1、class User{}console.log(typeof User)//function//2、let Hdclass{}//其实跟1差不多class Stu{show(){}//注意这里不用加逗号,对象才加逗号get(){console.log(后盾人)}}let hdnew Stu()hd.get()//后盾人 类的原理 类…

MyBatis-Plus速成指南:条件构造器和常用接口

Wrapper 介绍 Wrapper:条件构造抽象类,最顶端父类 AbstractWrapper:用于查询条件封装,生成 SQL 的 where 条件QueryWrapper:查询条件封装UpdateWrapper:Update 条件封装AbstractLambdaWrapper:使…

(脚本学习)BUU18 [CISCN2019 华北赛区 Day2 Web1]Hack World1

自用 题目 考虑是不是布尔盲注,如何测试:用"1^1^11 1^0^10,就像是真真真等于真,真假真等于假"这个测试 SQL布尔盲注脚本1 import requestsurl "http://8e4a9bf2-c055-4680-91fd-5b969ebc209e.node5.buuoj.cn…

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.25 多线程并行:GIL绕过与真正并发

2.25 多线程并行:GIL绕过与真正并发 目录 #mermaid-svg-JO4lsTIyjOweVkos {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-JO4lsTIyjOweVkos .error-icon{fill:#552222;}#mermaid-svg-JO4lsTIyjOweVkos …

Java 大视界 -- Java 大数据在智能医疗影像诊断中的应用(72)

💖亲爱的朋友们,热烈欢迎来到 青云交的博客!能与诸位在此相逢,我倍感荣幸。在这飞速更迭的时代,我们都渴望一方心灵净土,而 我的博客 正是这样温暖的所在。这里为你呈上趣味与实用兼具的知识,也期待你毫无保留地分享独特见解,愿我们于此携手成长,共赴新程!💖 一、…

【Leetcode刷题记录】1456. 定长子串中元音的最大数目---定长滑动窗口即解题思路总结

1456. 定长子串中元音的最大数目 给你字符串 s 和整数 k 。请返回字符串 s 中长度为 k 的单个子字符串中可能包含的最大元音字母数。 英文中的 元音字母 为(a, e, i, o, u)。 这道题的暴力求解的思路是通过遍历字符串 s 的每一个长度为 k 的子串&#xf…

upload-labs安装与配置

前言 作者进行upload-labs靶场练习时,在环境上出了很多问题,吃了很多苦头,甚至改了很多配置也没有成功。 upload-labs很多操作都是旧时代的产物了,配置普遍都比较老,比如PHP版本用5.2.17(还有中间件等&am…

从Transformer到世界模型:AGI核心架构演进

文章目录 引言:架构革命推动AGI进化一、Transformer:重新定义序列建模1.1 注意力机制的革命性突破1.2 从NLP到跨模态演进1.3 规模扩展的黄金定律二、通向世界模型的关键跃迁2.1 从语言模型到认知架构2.2 世界模型的核心特征2.3 混合架构的突破三、构建世界模型的技术路径3.1 …

每日一博 - 三高系统架构设计:高性能、高并发、高可用性解析

文章目录 引言一、高性能篇1.1 高性能的核心意义1.2 影响系统性能的因素1.3 高性能优化方法论1.3.1 读优化:缓存与数据库的结合1.3.2 写优化:异步化处理 1.4 高性能优化实践1.4.1 本地缓存 vs 分布式缓存1.4.2 数据库优化 二、高并发篇2.1 高并发的核心意…

【工欲善其事】利用 DeepSeek 实现复杂 Git 操作:从原项目剥离出子版本树并同步到新的代码库中

文章目录 利用 DeepSeek 实现复杂 Git 操作1 背景介绍2 需求描述3 思路分析4 实现过程4.1 第一次需求确认4.2 第二次需求确认4.3 第三次需求确认4.4 V3 模型:中间结果的处理4.5 方案验证,首战告捷 5 总结复盘 利用 DeepSeek 实现复杂 Git 操作 1 背景介绍…

【C++】线程池实现

目录 一、线程池简介线程池的核心组件实现步骤 二、C11实现线程池源码 三、线程池源码解析1. 成员变量2. 构造函数2.1 线程初始化2.2 工作线程逻辑 3. 任务提交(enqueue方法)3.1 方法签名3.2 任务封装3.3 任务入队 4. 析构函数4.1 停机控制 5. 关键技术点解析5.1 完美转发实现5…

数据结构实战之线性表(三)

目录 1.顺序表释放 2.顺序表增加空间 3.合并顺序表 4.线性表之链表实现 1.项目结构以及初始代码 2.初始化链表(不带头结点) 3.链表尾部插入数据并显示 4.链表头部插入数据 5.初始化链表(带头结点) 6.带头结点的链表头部插入数据并显示 7.带头结…

【python】python基于机器学习与数据分析的手机特性关联与分类预测(源码+数据集)【独一无二】

👉博__主👈:米码收割机 👉技__能👈:C/Python语言 👉专__注👈:专注主流机器人、人工智能等相关领域的开发、测试技术。 python基于机器学习与数据分析的手机特性关联与分类…

ZOJ 1007 Numerical Summation of a Series

原题目链接 生成该系列值的表格 对于x 的 2001 个值,x 0.000、0.001、0.002、…、2.000。表中的所有条目的绝对误差必须小于 0.5e-12(精度为 12 位)。此问题基于 Hamming (1962) 的一个问题,当时的大型机按今天的微型计算机标准来…

全面解析文件上传下载删除漏洞:风险与应对

在数字化转型的时代,文件上传、下载与删除功能已经成为各类应用程序的标准配置,从日常办公使用的协同平台,到云端存储服务,再到社交网络应用,这些功能在给用户带来便捷体验、显著提升工作效率的同时,也隐藏…

【C语言深入探索】结构体详解(二):使用场景

目录 一、复杂数据的表示 二、数据的封装 三、多态的模拟 四、回调函数的实现 五、多线程编程 六、通信协议的实现和文件操作 6.1. 使用结构体实现简单通信协议 6.2. 使用结构体进行文件操作 七、图形界面编程 结构体在C语言中具有广泛的应用场景,以下是一…

【大模型】AI 辅助编程操作实战使用详解

目录 一、前言 二、AI 编程介绍 2.1 AI 编程是什么 2.1.1 为什么需要AI辅助编程 2.2 AI 编程主要特点 2.3 AI编程底层核心技术 2.4 AI 编程核心应用场景 三、AI 代码辅助编程解决方案 3.1 AI 大模型平台 3.1.1 AI大模型平台代码生成优缺点 3.2 AI 编码插件 3.3 AI 编…

RK3566-移植5.10内核Ubuntu22.04

说明 记录了本人使用泰山派(RK3566)作为平台并且成功移植5.10.160版本kernel和ubuntu22.04,并且成功配置&连接网络的完整过程。 本文章所用ubuntu下载地址:ubuntu-cdimage-ubuntu-base-releases-22.04-release安装包下载_开源…

从零开始实现一个双向循环链表:C语言实战

文章目录 1链表的再次介绍2为什么选择双向循环链表?3代码实现:从初始化到销毁1. 定义链表节点2. 初始化链表3. 插入和删除节点4. 链表的其他操作5. 打印链表和判断链表是否为空6. 销毁链表 4测试代码5链表种类介绍6链表与顺序表的区别7存储金字塔L0: 寄存…

51单片机 06 定时器

51 单片机的定时器属于单片机的内部资源,其电路的连接和运转均在单片机内部完成。 作用:1、用于计时;2、替代长时间的Delay,提高CPU 运行效率和处理速度。 定时器个数:3个(T0、T1、T2)&#xf…