推荐算法—widedeep原理知识总结代码实现

news2025/1/17 23:13:50

wide&deep原理知识总结代码实现

  • 1. Wide&Deep 模型的结构
    • 1.1 模型的记忆能力
    • 1.2 模型的泛化能力
  • 2. Wide&Deep 模型的应用场景
  • 3. Wide&Deep 模型的代码实现
    • 3.1 tensorflow实现
    • 3.2 pytorch实现

今天,总结一个在业界有着巨大影响力的推荐模型,Google 的 Wide&Deep。可以说,只要掌握了 Wide&Deep,就抓住了深度推荐模型这几年发展的一个主要方向。

1. Wide&Deep 模型的结构

Wide&Deep 模型的结构
上图就是 Wide&Deep 模型的结构图,它是由左侧的 Wide 部分和右侧的 Deep 部分组成的。Wide 部分的结构很简单,就是把输入层直接连接到输出层,中间没有做任何处理。Deep 层的结构稍复杂,就是常见的Embedding+MLP 的模型结构。

Wide 部分的主要作用是让模型具有较强的“记忆能力”(Memorization),而 Deep 部分的主要作用是让模型具有“泛化能力”(Generalization),因为只有这样的结构特点,才能让模型兼具逻辑回归和深度神经网络的优点,也就是既能快速处理和记忆大量历史行为特征,又具有强大的表达能力,这就是 Google 提出这个模型的动机。

1.1 模型的记忆能力

所谓的 “记忆能力”,可以被宽泛地理解为模型直接学习历史数据中物品或者特征的“共现频率”,并且把它们直接作为推荐依据的能力。
就像我们在电影推荐中可以发现一系列的规则,比如,看了 A 电影的用户经常喜欢看电影 B,这种“因为 A 所以 B”式的规则,非常直接也非常有价值。

1.2 模型的泛化能力

“泛化能力”指的是模型对于新鲜样本、以及从未出现过的特征组合的预测能力。
看一个例子:假设,我们知道 25 岁的男性用户喜欢看电影 A,35 岁的女性用户也喜欢看电影 A。如果我们想让一个只有记忆能力的模型回答,“35 岁的男性喜不喜欢看电影 A”这样的问题,这个模型就会“说”,我从来没学过这样的知识啊,没法回答你。这就体现出泛化能力的重要性了。模型有了很强的泛化能力之后,才能够对一些非常稀疏的,甚至从未出现过的情况作出尽量“靠谱”的预测。

事实上,矩阵分解就是为了解决协同过滤“泛化能力”不强而诞生的。因为协同过滤只会“死板”地使用用户的原始行为特征,而矩阵分解因为生成了用户和物品的隐向量,所以就可以计算任意两个用户和物品之间的相似度了。这就是泛化能力强的另一个例子。

2. Wide&Deep 模型的应用场景

Wide&Deep 模型是由 Google 的应用商店团队 Google Play 提出的,在 Google Play 为用户推荐 APP 这样的应用场景下,Wide&Deep 模型的推荐目标就显而易见了,就是应该尽量推荐那些用户可能喜欢,愿意安装的应用。那具体到 Wide&Deep 模型中,Google Play 团队是如何为 Wide 部分和 Deep 部分挑选特征的呢?

请添加图片描述
先从右边 Wide 部分的特征看起,只利用了两个特征的交叉,这两个特征是“已安装应用”和“当前曝光应用”。这样一来,Wide 部分想学到的知识就非常直观,就是希望记忆好“如果 A 所以 B”这样的简单规则。在 Google Play 的场景下,就是希望记住“如果用户已经安装了应用 A,是否会安装 B”这样的规则。

再来看看左边的 Deep 部分,就是一个非常典型的 Embedding+MLP 结构。其中的输入特征很多,有用户年龄、属性特征、设备类型,还有已安装应用的 Embedding 等。把这些特征一股脑地放进多层神经网络里面去学习之后,它们互相之间会发生多重的交叉组合,这最终会让模型具备很强的泛化能力。

总的来说,Wide&Deep 通过组合 Wide 部分的线性模型和 Deep 部分的深度网络,取各自所长,就能得到一个综合能力更强的组合模型。

3. Wide&Deep 模型的代码实现

3.1 tensorflow实现

# wide and deep model architecture
# deep part for all input features
deep = tf.keras.layers.DenseFeatures(numerical_columns + categorical_columns)(inputs)
deep = tf.keras.layers.Dense(128, activation='relu')(deep)
deep = tf.keras.layers.Dense(128, activation='relu')(deep)
# wide part for cross feature
wide = tf.keras.layers.DenseFeatures(crossed_feature)(inputs)
both = tf.keras.layers.concatenate([deep, wide])
output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(both)
model = tf.keras.Model(inputs, output_layer)

Deep 部分,它是输入层加两层 128 维隐层的结构,它的输入是类别型 Embedding 向量和数值型特征。
Wide 部分直接把输入特征连接到了输出层就可以了。但是,这里要重点关注一下 Wide 部分所用的特征 crossed_feature。


movie_feature = tf.feature_column.categorical_column_with_identity(key='movieId', num_buckets=1001)
rated_movie_feature = tf.feature_column.categorical_column_with_identity(key='userRatedMovie1', num_buckets=1001)
crossed_feature = tf.feature_column.crossed_column([movie_feature, rated_movie_feature], 10000)

在 Deep 部分和 Wide 部分都构建完后,要使用 concatenate layer 把两部分连接起来,形成一个完整的特征向量,输入到最终的 sigmoid 神经元中,产生推荐分数。

3.2 pytorch实现

#Wide部分
class LR_Layer(nn.Module):
    def __init__(self,enc_dict):
        super(LR_Layer, self).__init__()
        self.enc_dict = enc_dict
        self.emb_layer = EmbeddingLayer(enc_dict=self.enc_dict,embedding_dim=1)
        self.dnn_input_dim = get_dnn_input_dim(self.enc_dict, 1)
        self.fc = nn.Linear(self.dnn_input_dim,1)
        
    def forward(self,data):
        sparse_emb = self.emb_layer(data)
        sparse_emb = torch.stack(sparse_emb,dim=1).flatten(1) #[batch,num_sparse*emb]
        dense_input = get_linear_input(self.enc_dict, data)  #[batch,num_dense]
        dnn_input = torch.cat((sparse_emb, dense_input), dim=1) # [batch,num_sparse*emb + num_dense]
        out = self.fc(dnn_input)
        return out
#DNN部分
class MLP_Layer(nn.Module):
    def __init__(self,
                 input_dim,
                 output_dim=None,
                 hidden_units=[],
                 hidden_activations="ReLU",
                 final_activation=None,
                 dropout_rates=0,
                 batch_norm=False,
                 use_bias=True):
        super(MLP_Layer, self).__init__()
        dense_layers = []
        if not isinstance(dropout_rates, list):
            dropout_rates = [dropout_rates] * len(hidden_units)
        if not isinstance(hidden_activations, list):
            hidden_activations = [hidden_activations] * len(hidden_units)
        hidden_activations = [set_activation(x) for x in hidden_activations]
        hidden_units = [input_dim] + hidden_units
        for idx in range(len(hidden_units) - 1):
            dense_layers.append(nn.Linear(hidden_units[idx], hidden_units[idx + 1], bias=use_bias))
            if batch_norm:
                dense_layers.append(nn.BatchNorm1d(hidden_units[idx + 1]))
            if hidden_activations[idx]:
                dense_layers.append(hidden_activations[idx])
            if dropout_rates[idx] > 0:
                dense_layers.append(nn.Dropout(p=dropout_rates[idx]))
        if output_dim is not None:
            dense_layers.append(nn.Linear(hidden_units[-1], output_dim, bias=use_bias))
        if final_activation is not None:
            dense_layers.append(set_activation(final_activation))
        self.dnn = nn.Sequential(*dense_layers)  # * used to unpack list

    def forward(self, inputs):
        return self.dnn(inputs)
#Wide&Deep
class WDL(nn.Module):
    def __init__(self,
                 embedding_dim=40,
                 hidden_units=[64, 64, 64],
                 loss_fun = 'torch.nn.BCELoss()',
                 enc_dict=None):
        super(WDL, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.hidden_units = hidden_units
        self.loss_fun = eval(loss_fun)
#         self.loss_fun = torch.nn.BCELoss()
        self.enc_dict = enc_dict
        
        self.embedding_layer = EmbeddingLayer(enc_dict=self.enc_dict, embedding_dim=self.embedding_dim)
        #Wide部分
        self.lr = LR_Layer(enc_dict=self.enc_dict)
        # Deep部分
        self.dnn_input_dim = get_dnn_input_dim(self.enc_dict, self.embedding_dim) # num_sprase*emb + num_dense
        self.dnn = MLP_Layer(input_dim=self.dnn_input_dim, output_dim=1, hidden_units=self.hidden_units,
                                 hidden_activations='relu', dropout_rates=0)
        
    def forward(self,data):
        #Wide
        wide_logit = self.lr(data) #Batch,1

        #Deep
        sparse_emb = self.embedding_layer(data)
        sparse_emb = torch.stack(sparse_emb,dim=1).flatten(1) #[Batch,num_sparse_fea*embedding_dim]

        dense_input = get_linear_input(self.enc_dict, data)
        dnn_input = torch.cat((sparse_emb, dense_input), dim=1)#[Batch,num_sparse_fea*embedding_dim+num_dense]
        deep_logit = self.dnn(dnn_input)

        #Wide+Deep
        y_pred = (wide_logit+deep_logit).sigmoid()
#         return y_pred

        #输出
        loss = self.loss_fun(y_pred.squeeze(-1),data['label'])
        output_dict = {'pred':y_pred,'loss':loss}
        return output_dict

最后,有个问题,什么样的特征应该放进wide,什么样的特征应该放进deep部分呢?

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/364951.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

设计模式.工厂模式.黑马跟学笔记

设计模式.工厂模式4.创建型模式4.2 工厂模式4.2.1 概述4.2.2 简单工厂模式4.2.2.1 结构4.2.2.2 实现4.2.2.4 优缺点4.2.2.3 扩展4.2.3 工厂方法模式4.2.3.1 概念4.2.3.2 结构4.2.3.3 实现4.2.3.4 优缺点4.2.4 抽象工厂模式4.2.4.1 概念4.2.4.2 结构4.2.4.2 实现4.2.4.3 优缺点4…

C语言进阶(六)—— 结构体

1. 结构体基础知识1.1 结构体类型的定义struct Person{char name[64];int age; };typedef struct _PERSON{char name[64];int age; }Person;注意:定义结构体类型时不要直接给成员赋值,结构体只是一个类型,编译器还没有为其分配空间&#xff0…

【Kubernetes 入门实战课】Day02——初识容器

系列文章目录 【Kubernetes 入门实战课】Day01——搭建kubernetes实验环境(一) 文章目录系列文章目录前言一、Docker的诞生二、Docker的形态1、Docker Desktop2、Docker Engine二、Docker的安装1、服务器连接外网安装2、服务器不通外网三、Docker的使用三、Docker的架构总结前…

JavaWeb11-死锁

目录 1.死锁定义 1.1.代码演示 1.2.使用jconsole/jvisualvm/jmc查看死锁 ①使用jconsole:最简单。 ②使用jvisualvm:(Java虚拟机)更方便,更直观,更智能,更高级,是合适的选择。 …

Melis4.0[D1s]:2.启动流程(GUI桌面加载部分)跟踪笔记

文章目录0. 控制台输出信息等级设置0.1 设置log level 4 无法正常启动1.宏观启动流程1.1 控制台入口函数finsh_thread_entry()执行《startup.sh》1.2 《startup.sh》启动桌面GUI模块1.2.1 《startup.sh》加载 desktop.mod1.2.2 desktop.mod加载 init.axf1.2.3 init.axf 介绍1.…

C#与三菱PLC MC协议通信,Java与三菱PLC MC协议通信

三菱PLC的MC协议是一种常用的通信协议,用于实现三菱PLC与其他设备之间的通信。以下是一些关于MC协议的基本信息:协议格式MC协议的通信数据格式如下:数据头网络编号PC编号目标模块IO编号目标模块站号本机模块IO编号本机模块站号请求数据长度请…

Linud SSH与SCP的配置

目录 配置SSH协议 配置服务器通过密钥进行认证 配置SCP完成文件传输 ssh协议讲解 SSH协议理论讲解_静下心来敲木鱼的博客-CSDN博客https://blog.csdn.net/m0_49864110/article/details/128500490?ops_request_misc%257B%2522request%255Fid%2522%253A%2522167704203816800…

不加大资金投入,仅凭智能名片如何解决企业营销难题的?

中国90%以上的中小企业想要竞争和发展,就必须推广自己的品牌,提高自己的知名度。在小程序之前,APP是主流,但大多数中小企业负担不起APP的开发和昂贵的营销成本。 进入微信互联网时代后,为了帮助企业以更低的成本获得…

浙大MEM现场小组复试经验分享

作为2019年上岸浙大MEM项目的学姐一枚,很高兴收到杭州达立易考教育老师的邀请,给大家分享下现场面试的经历。先来看下复试流程是怎么样的。1、体检所有考生须参加。体检需在复试前完成(未体检考生不得参加复试)。 2、资格审查&…

历时半年!从外包到现在阿里网易25K,分享一下自己的涨薪经验

前言 首先自我介绍一下,本人普通一本毕业,年初被老东家裁员干掉了,之后一直住在朋友那混吃等死,转折是今年年后,二月初的时候和大佬吃了个饭,觉得自己不能这样下去了,拿着某大佬给我的面试资料…

你知道IT运维的本质是什么吗?

大家好,我是技福的小咖老师。 之前看到个文章,说运维的本质是“可视化”,甚至还有人说是DevOps。不可否认,“可视化”是运维过程中非常重要的一个环节;DevOps则是开发运维一体化非常重要的工具。 究其根本&#xff0…

【09-JVM面试专题-实例化过程详细讲讲?对象的基本结构你知道吗?TLAB堆上内存分配是怎么样的?你了解这个TLAB吗?】

实例化过程详细讲讲?对象的基本结构你知道吗?TLAB堆上内存分配是怎么样的?你了解这个TLAB吗? 实例化过程详细讲讲?对象的基本结构你知道吗?TLAB堆上内存分配是怎么样的?你了解这个TLAB吗&#x…

EMR Studio Workspace 访问 Github ( 公网Git仓库 )

EMR Studio Workspace访问公网Git仓库 会遇到很多问题,由于EMR Studio不能给出任何有用的错误信息,导致排查起来非常麻烦。下面总结了若干项注意事项,可以避免踩坑。如果你遇到了同样的问题,请根据以下部分或全部建议去修正你的环境,问题即可解决。本文地址:https://laur…

pc端集成企业微信的扫码登录及遇到的问题

集成步骤: 1、在企业微信后台中添加应用 2、记录下应用的相关信息,在后文要用到 3、引入企业微信js 旧版:http://rescdn.qqmail.com/node/ww/wwopenmng/js/sso/wwLogin-1.0.0.js 新版(20210803更新):http…

kotlin学习教程

kotlin的方法 可以直接调用 不用 new? 2.kotlin关于字符串 用 $拼接变量 3.kotlin 类 方法 变量 可以同级的,同级的 方法 和 变量(常量) 是 生成了 一个新的 xxxKt.class ,并且都是 static的, 4.kotlin的类,方法,默认…

JVM调优方式

对JVM内存的系统级的调优主要的目的是减少GC的频率和Full GC的次数。 1.Full GC 会对整个堆进行整理,包括Young、Tenured和Perm。Full GC因为需要对整个堆进行回收,所以比较慢,因此应该尽可能减少Full GC的次数。 2.导致Full GC的原因 1)年老…

消息中间件

为什么要使用消息中间件同步通信:耗时长,受网络波动影响,不能保证高成功率,耦合性高。1.同步方式(耗时长):同步方式的问题:当一个用户提交订单到成功需要300ms300ms300ms20ms 920ms…

民锋国际期货:2023,既艰难又充满希望,既纷乱又有无数机会。

不管是官方还是民间,各种信号都表明,2023年是一个拼经济的年份。 通货膨胀带来的需求量的增加,与中国经济高速发展带来的供给量增加,二者共同构成了我们的物价。 做一个长期主义者,做一个坚定看好中国未来的人&#…

SpringBoot(powernode)(内含教学视频+源代码)

SpringBoot(powernode)(内含教学视频源代码) 教学视频源代码下载链接地址:https://download.csdn.net/download/weixin_46411355/87484637 目录SpringBoot(powernode)(内含教学视频…

AcWing3490.小平方——学习笔记

目录 题目 代码 AC结果 思路 题目 3490. 小平方 - AcWing题库https://www.acwing.com/problem/content/3493/ 代码 import java.util.Scanner;public class Main {public static void main(String[] args){Scanner input new Scanner(System.in);int target input.nextI…