排序模型进阶-WideDeepWDL模型导出

news2025/1/18 3:53:19

8.5 排序模型进阶-Wide&Deep

学习目标

  • 目标
  • 应用

8.5.1 wide&deep

 

 

  • Wide部分的输入特征:
    • raw input features and transformed features
    • notice: W&D这里的cross-product transformation:
    • 只在离散特征之间做组合,不管是文本策略型的,还是离散值的;没有连续值特征的啥事,至少在W&D的paper里面是这样使用的。
  • Deep部分的输入特征:
    • raw input+embeding处理
    • 非连续值之外的特征做embedding处理,这里都是策略特征,就是乘以个embedding-matrix。在TensorFlow里面的接口是:tf.feature_column.embedding_column,默认trainable=True.
    • 连续值特征的处理是:将其按照累积分布函数P(X≤x),压缩至[0,1]内。
    • notice:** Wide部分用FTRL+L1来训练;Deep部分用AdaGrad来训练。
  • Wide&Deep在TensorFlow里面的API接口为:tf.estimator.DNNLinearCombinedClassifier

代码:

import tensorflow as tf
from tensorflow.python import keras


class WDL(object):
    """wide&deep模型
    """
    def __init__(self):
        pass

    @staticmethod
    def read_ctr_records():
        # 定义转换函数,输入时序列化的
        def parse_tfrecords_function(example_proto):
            features = {
                "label": tf.FixedLenFeature([], tf.int64),
                "feature": tf.FixedLenFeature([], tf.string)
            }
            parsed_features = tf.parse_single_example(example_proto, features)

            feature = tf.decode_raw(parsed_features['feature'], tf.float64)
            feature = tf.reshape(tf.cast(feature, tf.float32), [1, 121])
            # 特征顺序 1 channel_id,  100 article_vector, 10 user_weights, 10 article_weights
            # 1 channel_id类别型特征, 100维文章向量求平均值当连续特征,10维用户权重求平均值当连续特征
            channel_id = tf.cast(tf.slice(feature, [0, 0], [1, 1]), tf.int32)
            vector = tf.reduce_sum(tf.slice(feature, [0, 1], [1, 100]), axis=1, keep_dims=True)
            user_weights = tf.reduce_sum(tf.slice(feature, [0, 101], [1, 10]), axis=1, keep_dims=True)
            article_weights = tf.reduce_sum(tf.slice(feature, [0, 111], [1, 10]), axis=1, keep_dims=True)

            label = tf.reshape(tf.cast(parsed_features['label'], tf.float32), [1, 1])

            # 构造字典 名称-tensor
            FEATURE_COLUMNS = ['channel_id', 'vector', 'user_weigths', 'article_weights']
            tensor_list = [channel_id, vector, user_weights, article_weights]

            feature_dict = dict(zip(FEATURE_COLUMNS, tensor_list))

            return feature_dict, label

        dataset = tf.data.TFRecordDataset(["./train_ctr_201905.tfrecords"])
        dataset = dataset.map(parse_tfrecords_function)
        dataset = dataset.shuffle(buffer_size=10000)
        dataset = dataset.repeat(10000)
        return dataset.make_one_shot_iterator().get_next()

    def build_estimator(self):
        """建立模型
        :param dataset:
        :return:
        """
        # 离散分类
        article_id = tf.feature_column.categorical_column_with_identity('channel_id', num_buckets=25)
        # 连续类型
        vector = tf.feature_column.numeric_column('vector')
        user_weigths = tf.feature_column.numeric_column('user_weigths')
        article_weights = tf.feature_column.numeric_column('article_weights')

        wide_columns = [article_id]

        # embedding_column用来表示类别型的变量
        deep_columns = [tf.feature_column.embedding_column(article_id, dimension=25),
                        vector, user_weigths, article_weights]

        estimator = tf.estimator.DNNLinearCombinedClassifier(model_dir="./ckpt/wide_and_deep",
                                                             linear_feature_columns=wide_columns,
                                                             dnn_feature_columns=deep_columns,
                                                             dnn_hidden_units=[1024, 512, 256])

        return estimator


if __name__ == '__main__':
    wdl = WDL()
    # dataset = wdl.read_ctr_records()
    # lwf.train(dataset)
    # lwf.train_v2(dataset)
    estimator = wdl.build_estimator()
    estimator.train(input_fn=wdl.read_ctr_records, steps=10000)
    # eval_result = estimator.evaluate(input_fn=wdl.read_ctr_records, steps=10000)
    # print(eval_result)

WDL模型导出

3.2 线上预估

线上流量是模型效果的试金石。离线训练好的模型只有参与到线上真实流量预估,才能发挥其价值。在演化的过程中,我们开发了一套稳定可靠的线上预估体系,提高了模型迭代的效率。

模型同步

我们开发了一个高可用的同步组件:用户只需要提供线下训练好的模型的HDFS路径,该组件会自动同步到线上服务机器上。该组件基于HTTPFS实现,它是美团离线计算组提供的HDFS的HTTP方式访问接口。同步过程如下:

  1. 同步前,检查模型md5文件,只有该文件更新了,才需要同步。
  2. 同步时,随机链接HTTPFS机器并限制下载速度。
  3. 同步后,校验模型文件md5值并备份旧模型。

同步过程中,如果发生错误或者超时,都会触发报警并重试。依赖这一组件,我们实现了在2min内可靠的将模型文件同步到线上。

模型计算

当前我们线上有两套并行的预估计算服务。

基于TF Serving的模型服务

TF Serving是TensorFlow官方提供的一套用于在线实时预估的框架。它的突出优点是:和TensorFlow无缝链接,具有很好的扩展性。使用TF serving可以快速支持RNN、LSTM、GAN等多种网络结构,而不需要额外开发代码。这非常有利于我们模型快速实验和迭代。

使用这种方式,线上服务需要将特征发送给TF Serving,这不可避免引入了网络IO,给带宽和预估时延带来压力。我们尝试了以下优化,效果显著。

  1. 并发请求。一个请求会召回很多符合条件的广告。在客户端多个广告并发请求TF Serving,可以有效降低整体预估时延。
  2. 特征ID化。通过将字符串类型的特征名哈希到64位整型空间,可以有效减少传输的数据量,降低使用的带宽。

导出代码:

# 导出serving_model
wide_columns = [tf.feature_column.categorical_column_with_identity('channel_id', num_buckets=25)]
deep_columns = [tf.feature_column.embedding_column(tf.feature_column.categorical_column_with_identity('channel_id', num_buckets=25), dimension=25),
                tf.feature_column.numeric_column('vector'),
                tf.feature_column.numeric_column('user_weigths'),
                tf.feature_column.numeric_column('article_weights')
                ]

columns = wide_columns + deep_columns
feature_spec = tf.feature_column.make_parse_example_spec(columns)
serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
estimator.export_savedmodel("./serving_model/wdl/", serving_input_receiver_fn)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/197863.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《从0开始学大数据》之构建一个大数据平台

在分布式系统中分发执行代码并启动执行,这样的计算方式必然不会很快,即使在一个规模不太大的数据集上进行一次简单计算,MapReduce 也可能需要几分钟,Spark 快一点,也至少需要数秒的时间。而互联网产品处理用户请求&…

微信短视频怎么提取gif?三步教你在线提取gif动画

现在各大社交软件上短视频是越来越多,为了方便传播、保存可以将短视频制作成GIF。那么,如何从视频中提取动图呢?很简单,两招就能完成在线视频转换成gif动图的操作,只需要使用【GIF中文网】的视频转gif(http…

OBS使用WebRTC进行腾讯云推流播流

推流: 首先:OBS想要推送WebRTC格式的推流需要满足以下两点: 1:OBS版本在26及以上 2:需要给OBS安装腾讯云插件,而且只支持Windows版。 OBS下载地址:Download | OBSDownload OBS Studio for W…

电路方案分析(十六)带有C2000微控制器且精度为 ±0.1° 的分立式旋转变压器前端参考设计

带有C2000微控制器且精度为 0.1 的分立式旋转变压器前端参考设计 tips:参考Ti设计资源:TIDA-01527 旋转变压器详细介绍: https://blog.csdn.net/qq_41600018/article/details/127597875?spm1001.2014.3001.5501 该参考方案设计…

用投资思维做好招商工作:湘商回归,长沙急企业之所急

在中国经济发展40年后,当下经济发展的底层逻辑已发生了根本性变化。企业发展所面临的问题,投资所思考的方向也已不同以往。一味再强调本地资源优势,介绍当地优惠政策的招商工作方式不再适应当下形式,往往反而会导致忽略企业的真实…

58.Isaac教程--OTG5 直线运动规划器

OTG5 直线运动规划器 ISAAC教程合集地址文章目录OTG5 直线运动规划器最大值和期望值的配置OTG5 的 Flatsim 演示用于直线运动的在线轨迹生成 - V 型 (OTG5) 规划器允许线性运动,同时明确防止曲线。 这在即使与一般直线运动方向有轻微偏差也会导致意外结果的情况下很…

Redis核心技术-高可靠-集群方案(客户端分片、代理分片、Redis Cluster)

Redis在3.0版本前只支持单实例模式,虽然Redis的开发者Antirez早在博客上就提出在Redis 3.0版本中加入集群的功能,但3.0版本等到2015年才发布正式版。 各大企业等不急了,在3.0版本还没发布前为了解决Redis的存储瓶颈,纷纷推出了各…

【每日一题】【LeetCode】【第二十四天】【Python】两个数组的交集 II

解决之路 题目描述 测试案例(部分) 第一次 顺着“两个数组的交集”的思路想,先用集合处理nums1和nums2,然后通过“交集”运算得出列表res,然后循环检查列表res,得出各个元素在两个数组中出现的最小次数&…

2022生化原理I复习资料汇总

文章目录1.2022复习重点及参考题2022年考试复习题:附录:参考答案及复习重点2.2021复习重点及参考题3.往年复习重点及参考题汇总4.复习重点整理及考试题型生化原理I复习资料及往年考题1.2022复习重点及参考题 2022年考试复习题: 1.2021-2022…

Python Stock安装与使用

这个是使用python 开发股票系统。 使用 tushare 获取股票数据,然后使用tornado 进行web 展示。 使用pandas numpy 数据处理。 项目代码 项目代码放到github上面 GitHub - pythonstock/stock: stock,股票系统。使用python进行开发。 因为为了简单&#x…

【Netty学习】七、详解ByteBuf缓冲区

七、详解ByteBuf缓冲区 为了确保引用计数不会混乱,在Netty的业务处理器开发过程中,应该坚持一个原则:retain和release方法应该结对使用。简单地说,在一个方法中,调用了retain,就应该调用一次release。 pub…

视图存储过程存储函数

文章目录视图常见数据库对象视图概述为什么使用视图?视图的理解创建视图创建单表视图创建多表联合视图基于视图创建视图查看视图更新视图的数据一般情况不可更新的视图修改、删除视图修改视图删除视图总结视图优点视图不足存储过程&存储函数存储过程概述理解分类…

NFT Insider #84:The Sandbox与华纳音乐集团合作举办全世界最大的DemoDrop,英超联赛签署NFT协议

引言:NFT Insider由NFT收藏组织WHALE Members、BeepCrypto联合出品,浓缩每周NFT新闻,为大家带来关于NFT最全面、最新鲜、最有价值的讯息。每期周报将从NFT市场数据,艺术新闻类,游戏新闻类,虚拟世界类&#…

[论文分享] How could Neural Networks understand Programs?

前言 读一篇 ICML 2021 的论文How could Neural Networks understand Programs? 程序语义理解是程序设计语言处理(PLP)的一个基本问题。最近基于NLP预训练技术学习代码表示的工作,推动了该方向的前沿。然而,PL和NL的语义有着本质的区别。忽略这些&…

CPP----精选常识100例

1 静态全局变量的作用域 本文件 2 判断一个程序是C还是C编译的 #ifdef __cpluspluscout << "c"; #else cout << "c"; #endif3 C函数传递方式 值传递&#xff0c;引用传递&#xff0c;指针传递 4 虚函数定义及用法 虚函数是C中用于实现多态(p…

vue2 a-tree-select树形结构-懒加载(无限子级)---笔记

实现效果 思维导图 HTML代码&#xff1a;treeData是绑定的数组&#xff0c;onLoadData是懒加载函数 <a-tree-select style"width: 100%; margin-left: 20px" tree-data-simple-mode multiplelabelInValueplaceholder"请选择…" v-decorator"[lea…

史上最详细的KMP算法教程,看这一篇就够了

&#x1f9d1;‍&#x1f4bb; 文章作者&#xff1a;Iareges &#x1f517; 博客主页&#xff1a;https://blog.csdn.net/raelum ⚠️ 转载请注明出处 目录一、BF算法二、KMP算法2.1 字符串基础2.2 next数组2.3 KMP的实现2.4 next数组的生成三、改进的KMP算法3.1 nextval数组3.…

turf.js实现行政区(多边形)图形合并边界提取,掩膜等效果

在做前端行政区展示的时候,可能经常会遇到这样的需求,就是给定一个行政区比如杭州市各个区,县的行政区边界图形,但是我们现在需要一个杭州市的行政区边界,我们是否可以通过前端合并这些行政区,答案当然是可以的,我们可以使用turf.js来实现这个需求。 turf官网:Turf.js…

纯滞后系统的数字Smith预估控制-2

在纯滞后系统的数字Smith预估控制-1的基础上进行Simulink仿真。采用 Simulink 进行数字化仿真&#xff0c;按Smith算法设计Simulink模块。在PI控制中&#xff0c;kp0.5&#xff0c;ki0.01。其响应结果如图1和图2所示。图1 Smith阶跃响应结果图2 只采用PI控制时的阶跃响应结果初…

CDA Level Ⅱ 模拟题(一)

单选1 练习题 【单选题】1/20 一项针对某城市小微企业税收扶持和税收种类的调查&#xff0c;本打算调查500个企业&#xff0c;但忽然发现税务中心数据库中已存有这项调查数据&#xff0c;并且可以有权限获取这份数据&#xff0c;请问这是什么类型的调查方式&#xff1f; A.分层…