推荐系统实战5——EasyRec 在DSSM召回模型中添加负采样构建CTR点击平台

news2024/12/25 22:31:17

推荐系统实战5——EasyRec 在DSSM召回模型中添加负采样构建CTR点击平台

  • 学习前言
  • EasyRec仓库地址
  • DSSM实现思路
    • 一、DSSM整体结构解析
    • 二、网络结构解析
      • 1、Embedding层的构建
      • 2、网络层的构建
      • 3、相似度计算
    • 三、训练部分解析
  • 训练自己的DSSM模型
    • 一、数据集的准备
    • 二、Config配置文件的设置
    • 三、开始网络训练
    • 四、训练结果的评估
    • 五、训练结果的预测
      • 1、训练模型的导出
      • 2、导出模型的预测

学习前言

当物品池很大上百万甚至是上亿的时候,不能仅考虑少量的正样本与负样本,因为物品太多,大多数物品都是负样本,此时双塔召回模型常常需要针对每个正样本采样一千甚至一万的负样本才能达到比较好的召回效果,
在这里插入图片描述

EasyRec仓库地址

官方库地址:
https://github.com/alibaba/EasyRec
带注释的Config地址:
https://github.com/bubbliiiing/EasyRec-Config

DSSM实现思路

一、DSSM整体结构解析

在这里插入图片描述
DSSM的论文地址为:
Learning Deep Structured Semantic Models for Web Search using Clickthrough Data

DSSM全称为Deep Structured Semantic Model,是微软发表的一篇论文,其核心思想是将user和item映射到到共同维度的语义空间中,通过最大化user和item语义向量之间的余弦相似度,达到检索的目的。

在推荐场景基于DSSM的双塔结构,根据user的特点与item的特点计算余弦相似度,进行个性化召回。

结构如上图所示,这是论文中的图,最左侧的输入为user特征,右侧所有的输入均为item特征,二者均经过了全连接与激活函数进行了非线性映射到共同维度,在获得高层次语义信息后,利用二者的语义信息计算余弦相似度,选择较高相似度的作为召回结果。

DSSM模型的思路非常简单,获得user和item的特征后,获得二者的相似度,然后进行推荐。但普通的DSSM存在一些小问题,面对现在市面上大多数的业务场景,很多业务的备选item都是上百万甚至上亿的,比如购物,选中的衣服只有十来款,没有选中的衣服有数十万款,双塔召回模型需要在items中针对每个正样本采样一千甚至一万的负样本才能达到比较好的召回效果,此时正负样本的比例巨大。

二、网络结构解析

1、Embedding层的构建

在这里插入图片描述
对于推荐系统而言,输入常常是字符串形式,因为不是矩阵,字符串本身无法被网络直接处理,EasyRec是基于tensorflow构建的,在tensorflow中,可以使用tf.string_to_hash_bucket_fast将输入进来的字符串转化成一个固定的数字。具体转换方式如下所示:

import tensorflow as tf
if tf.__version__ >= '2.0':
    tf = tf.compat.v1
sparse_id_values = tf.string_to_hash_bucket_fast("hello", 10)
# 此时的输出为:
sparse_id_values = 6

对任意一个字符串,我们都可以将其转化成固定的数字,这个数字处于0到hash_bucket_size之间,之后之后在代码中会建立一个可查询的embedding表,他的shape为:
(hash_bucket_size, embedding_dim)
这是一个hash_bucket_size行,embedding_dim列的矩阵,当我们通过一个字符串获得一个固定的数字后,我们会通过这个固定的数字sparse_id_values,获得其中第sparse_id_values行。

比如上述的例子中,我们假设hash_bucket_size等于10,embedding_dim等于32。如果输入的字符串为hello,我们获得的sparse_id_values=6。我们此时就会获取embedding表的第6行,作为这个数据的embedding。

在EasyRec的Config中,我们只需要在feature_config指定对应的标签名、embedding_dim、hash_bucket_size就读取数据,将数据转化成特定长度的Embedding了。

如下所示:

#------------------------------------------------------#
#   用于作为特征的数据,不包括label
#------------------------------------------------------#
feature_config: {
  features: {
    input_names: "hour"
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 50
  }
}

2、网络层的构建

在这里插入图片描述
在完成Embedding层的构建后,我们会将获取到的所有特征的Embedding结果通过concat堆叠到一起。因此,无论是user还是item最终都会变成一个指定长度的向量。

在本博文中,我们称它们为embedding后的特征,它们的shape为[batch_size, embedding_size]。作为一个二维矩阵,第一维度为batch_size即批次大小,第二维度为embedding_size即该个体的特征长度。

在DSSM中,无论是user还是item,它们的网络层构建并不复杂,只是普通的DNN(全连接神经网络)。本博文使用EasyRec自带的电商演示数据集为例子进行解析,每个个体存在若干的特征,每个特征embedding后的长度为16,user存在10个特征,item存在5个特征。由于我们在DSSM召回模型中添加负采样,所以user和item的batch_size个数不同,此处我们假设基础的batch_size为4096,负采样的个数为1024。
user在embedding后的特征堆叠为[4096, 160];
item在embedding后的特征堆叠为[5120, 80]。

即存在4096个基础user,每个user需要和4096个基础item和1024个负采样的item进行匹配。

对于user而言,输入特征在embedding后的shape为[4096, 160]。在经过四次Dense-Bn-ReLU的特征映射后,网络最终输出为[4096, 32]。
对于item而言,输入的embedding的batch_size为4096(与user配对)+ 1024(额外的负采样),假设embedding的特征后的特征长度为80。在经过四次Dense-Bn-ReLU的特征映射后,网络最终输出为[5120, 32]。

在EasyRec的Config中,我们只需要在model_config部分指定对应的模型名称、每个模型所需的特征以及每个模型的构建方式,就可以对user和item各自进行构建了。如下所示,model_class表示的是模型类别,feature_groups表示的是每个模型所需的特征,dssm是模型下的模型参数。

model_config:{
  model_class: "DSSM"
  feature_groups: {
    group_name: 'user'
    feature_names: 'user_id'
    feature_names: 'cms_segid'
    feature_names: 'cms_group_id'
    feature_names: 'age_level'
    feature_names: 'pvalue_level'
    feature_names: 'shopping_level'
    feature_names: 'occupation'
    feature_names: 'new_user_class_level'
    feature_names: 'tag_category_list'
    feature_names: 'tag_brand_list'
    wide_deep:DEEP
  }
  feature_groups: {
    group_name: "item"
    feature_names: 'adgroup_id'
    feature_names: 'cate_id'
    feature_names: 'campaign_id'
    feature_names: 'customer'
    feature_names: 'brand'
    feature_names: 'price'
    feature_names: 'pid'
    wide_deep:DEEP
  }
  dssm {
    user_tower {
      id: "user_id"
      dnn {
        hidden_units: [256, 128, 64, 32]
        # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
      }
    }
    item_tower {
      id: "adgroup_id"
      dnn {
        hidden_units: [256, 128, 64, 32]
      }
    }
    l2_regularization: 1e-6
  }
  embedding_regularization: 5e-5
}

在EasyRec的源码中,DNN部分的构建代码为,在代码中,只是对hidden_units进行循环,循环时构建Dense、BN、Relu层:

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import tensorflow as tf
from easy_rec.python.utils.load_class import load_by_path

if tf.__version__ >= '2.0':
  tf = tf.compat.v1
class DNN:
  def __init__(self, dnn_config, l2_reg, name='dnn', is_training=False):
    """Initializes a `DNN` Layer.

    Args:
      dnn_config: instance of easy_rec.python.protos.dnn_pb2.DNN
      l2_reg: l2 regularizer
      name: scope of the DNN, so that the parameters could be separated from other dnns
      is_training: train phase or not, impact batchnorm and dropout
    """
    self._config = dnn_config
    self._l2_reg = l2_reg
    self._name = name
    self._is_training = is_training
    logging.info('dnn activation function = %s' % self._config.activation)
    self.activation = load_by_path(self._config.activation)

  @property
  def hidden_units(self):
    return self._config.hidden_units

  @property
  def dropout_ratio(self):
    return self._config.dropout_ratio

  def __call__(self, deep_fea, hidden_layer_feature_output=False):
    hidden_units_len = len(self.hidden_units)
    if hidden_units_len == 1 and self.hidden_units[0] == 0:
      return deep_fea

    hidden_feature_dict = {}
    for i, unit in enumerate(self.hidden_units):
      deep_fea = tf.layers.dense(
          inputs=deep_fea,
          units=unit,
          kernel_regularizer=self._l2_reg,
          activation=None,
          name='%s/dnn_%d' % (self._name, i))
      if self._config.use_bn:
        deep_fea = tf.layers.batch_normalization(
            deep_fea,
            training=self._is_training,
            trainable=True,
            name='%s/dnn_%d/bn' % (self._name, i))
      deep_fea = self.activation(
          deep_fea, name='%s/dnn_%d/act' % (self._name, i))
      if len(self.dropout_ratio) > 0 and self._is_training:
        assert self.dropout_ratio[
            i] < 1, 'invalid dropout_ratio: %.3f' % self.dropout_ratio[i]
        deep_fea = tf.nn.dropout(
            deep_fea,
            keep_prob=1 - self.dropout_ratio[i],
            name='%s/%d/dropout' % (self._name, i))

      if hidden_layer_feature_output:
        hidden_feature_dict['hidden_layer' + str(i)] = deep_fea
        if (i + 1 == hidden_units_len):
          hidden_feature_dict['hidden_layer_end'] = deep_fea
          return hidden_feature_dict
    else:
      return deep_fea

因此无论是user还是item,每个个体最终都会变成一个指定长度的向量,长度在本博文中为32。在本博文中,我们将映射后的结果称为user和item的语义向量

在l2标准化后,就可以计算user和item各自的余弦相似度了。

if self._loss_type == LossType.CLASSIFICATION:
  if self._model_config.simi_func == Similarity.COSINE:
    user_tower_emb = self.norm(user_tower_emb)
    item_tower_emb = self.norm(item_tower_emb)

3、相似度计算

在这里插入图片描述
在网络层的构建中,我们获得了user和item的语义向量,它们的shape分别为[4096, 32]和[5120, 32]。我们此时需要计算语义向量之间的相似度。此时我们求取list_wise_sim。

首先对item的语义向量进行转置,此时可以获得一个[32, 5120]的矩阵。然后利用user的语义向量叉乘item转置后的语义向量,即 [ 4096 , 32 ] × [ 32 , 5120 ] [4096, 32]\times[32, 5120] [4096,32]×[32,5120],完成后可以获得一个[4096, 5120]的矩阵,代表了4096个user和5120个item(包括负样本)的相似度。

具体运算代码为:

simple_user_item_sim = tf.matmul(user_emb, tf.transpose(simple_item_emb))

在EasyRec中,还会对获取到的相似度进行下一轮缩放,其中sim_w和sim_b是可以训练的参数。

y_pred = user_item_sim * tf.abs(sim_w) + sim_b
self._prediction_dict['probs'] = tf.nn.softmax(y_pred)

最终在对y_pred取softmax后,将其再次映射到0-1之间。

此时,如果softmax越接近于1,代表该user和该item相似性很高,应当进行推荐;如果softmax越接近于0,代表该user和该item相似性很低,不应当进行推荐。

三、训练部分解析

在这里插入图片描述
网络的训练部分并不复杂,在网络计算好余弦相似度后,我们对最终取softmax的结果通过事先设定好的标签进行交叉熵的计算。

在输入进来数据时,假设batch_size为4096,4096对user和item是匹配的。在上述获得的[4096, 5120]矩阵中,对角线部分的user和item是配对的。
在这里插入图片描述
因此我们可以创建一个在对角线上为1的矩阵作为标签。然后利用多分类交叉熵求取损失。

在代码中的实现方式为:

batch_size = tf.shape(self._prediction_dict['probs'])[0]
indices = tf.range(batch_size)
indices = tf.concat([indices[:, None], indices[:, None]], axis=1)
hit_prob = tf.gather_nd(self._prediction_dict['probs'][:batch_size, :batch_size], indices)
self._loss_dict['cross_entropy_loss'] = -tf.reduce_mean(tf.log(hit_prob + 1e-12)) * self._sample_weight

训练自己的DSSM模型

在训练自己的DSSM模型之前,需要首先配置好EasyRec环境。

本博文以EasyRec自带的电商示例数据集进行解析。数据集位于EasyRec根目录下,分别位于下面两个位置。
“data/test/tb_data/taobao_train_data”
“data/test/tb_data/taobao_test_data”

电商示例数据集包含若干特征,保存在文本文件中,尽管后缀不是csv,但实际上是csv格式,具体如下所示:

clk:点击记录
buy:是否购买

以下为商品特征:
pid:商品pid码
adgroup_id:商品广告单元id
cate_id:商品种类id
campaign_id:商品公司id
customer:商品顾客
brand:商品品牌
price:商品价格

以下为用户特征:
user_id:用户id
cms_segid:微群ID
cms_group_id:一个特征
final gender code:性别
age level:年龄层次
pvalue level:消费档次
shopping level:购物深度
occupation:是否工作
new_user_class_level:城市等级
tag_category_list:点击的种类列表
tag_brand_list:点击的品牌列表

一、数据集的准备

本文使用文本格式进行训练,训练前需要自己制作好数据集,如果没有自己的数据集,可以通过示例的电商数据集进行尝试。

准备好的数据集一般存放在data/test文件夹中。在示例数据集里:

data/test/tb_data/taobao_train_data代表的是训练集,模型基于该文件进行梯度下降。
data/test/tb_data/taobao_test_data代表的是验证集(测试集),这里不单独划分一个测试集,验证集和测试集共用。

csv中直接存放特征的值即可,特征之间以’,'隔开,不需要存放特征名,如图所示,每一列的数据代表什么特征我们自己需要清楚。
在这里插入图片描述

二、Config配置文件的设置

Config配置文件中需要设置多方面的内容,采用prototxt格式,配置顺序为:
数据集的地址、模型保存的地址、训练相关参数设置、评估情况、数据集内容情况、数据集特征情况、模型情况。

具体的构建方式如下:

#------------------------------------------------------#
#   训练用的数据文件地址
#------------------------------------------------------#
train_input_path: "data/test/tb_data/taobao_train_data"
#------------------------------------------------------#
#   评估用的数据文件地址
#------------------------------------------------------#
eval_input_path: "data/test/tb_data/taobao_test_data"
#------------------------------------------------------#
#   训练好的权值保存的路径
#------------------------------------------------------#
model_dir: "experiments/dssm_neg_sampler_taobao_ckpt"

#------------------------------------------------------#
#   训练相关的参数
#------------------------------------------------------#
train_config {
  log_step_count_steps: 100
  #------------------------------------------------------#
  #   optimizer_config      优化器参数
  #------------------------------------------------------#
  optimizer_config: {
    #------------------------------------------------------#
    #   adam_optimizer                    Adam优化器
    #   learning_rate                     学习率下降方式
    #   exponential_decay_learning_rate   指数下降
    #   initial_learning_rate             初始学习率
    #   decay_steps                       学习率衰减步长
    #   decay_factor                      衰减倍数
    #   min_learning_rate                 最低学习率
    #------------------------------------------------------#
    adam_optimizer: {
      learning_rate: {
        exponential_decay_learning_rate {
          initial_learning_rate : 0.001
          decay_steps           : 1000
          decay_factor          : 0.5
          min_learning_rate     : 0.00001
        }
      }
    }
    use_moving_average: false
  }
  #------------------------------------------------------#
  #   sync_replicas             
  #   save_checkpoints_steps    保存周期
  #   log_step_count_steps      log记录周期
  #   num_steps                 总训练步长
  #------------------------------------------------------#
  sync_replicas           : true
  save_checkpoints_steps  : 100
  log_step_count_steps    : 100
  num_steps               : 2500
}

#------------------------------------------------------#
#   评估参数
#   推荐系统一般使用AUC进行评估
#------------------------------------------------------#
eval_config {
  metrics_set: {
    auc {}
  }
}

#------------------------------------------------------#
#   数据集的各类数据情况
#------------------------------------------------------#
data_config {
  #------------------------------------------------------#
  #   separator   代表分隔符,默认为","
  #------------------------------------------------------#
  separator: ","
  #------------------------------------------------------#
  #   需要注意的时,此处的数据顺序需要和csv中一样。
  #   input_name  代表该列数据的名称
  #   input_type  代表该列数据的数据类别,默认是STRING。
  #   default_val 代表默认值,可以不设置
  #------------------------------------------------------#
  input_fields {
    input_name:'clk'
    input_type: INT32
  }
  input_fields {
    input_name:'buy'
    input_type: INT32
  }
  input_fields {
    input_name: 'pid'
    input_type: STRING
  }
  input_fields {
    input_name: 'adgroup_id'
    input_type: STRING
  }
  input_fields {
    input_name: 'cate_id'
    input_type: STRING
  }
  input_fields {
    input_name: 'campaign_id'
    input_type: STRING
  }
  input_fields {
    input_name: 'customer'
    input_type: STRING
  }
  input_fields {
    input_name: 'brand'
    input_type: STRING
  }
  input_fields {
    input_name: 'user_id'
    input_type: STRING
  }
  input_fields {
    input_name: 'cms_segid'
    input_type: STRING
  }
  input_fields {
    input_name: 'cms_group_id'
    input_type: STRING
  }
  input_fields {
    input_name: 'final_gender_code'
    input_type: STRING
  }
  input_fields {
    input_name: 'age_level'
    input_type: STRING
  }
  input_fields {
    input_name: 'pvalue_level'
    input_type: STRING
  }
  input_fields {
    input_name: 'shopping_level'
    input_type: STRING
  }
  input_fields {
    input_name: 'occupation'
    input_type: STRING
  }
  input_fields {
    input_name: 'new_user_class_level'
    input_type: STRING
  }
  input_fields {
    input_name: 'tag_category_list'
    input_type: STRING
  }
  input_fields {
    input_name: 'tag_brand_list'
    input_type: STRING
  }
  input_fields {
    input_name: 'price'
    input_type: INT32
  }

  #------------------------------------------------------#
  #   列名必须在data_config中出现过,代表为标签
  #------------------------------------------------------#
  label_fields: 'clk'
  
  #------------------------------------------------------#
  #   batch_size    批次大小
  #   prefetch_size 提高数据加载的速度,防止数据瓶颈
  #	  num_epochs    训练时取num_steps和num_epochs中的小值
  #					看哪个先达到就结束
  #------------------------------------------------------#
  batch_size: 4096
  num_epochs: 10000
  prefetch_size: 32
  #---------------------------------------------------------------------------#
  #   CSVInput      表示数据格式是CSV,注意要配合separator使用
  #   OdpsInputV2   如果在MaxCompute上运行EasyRec, 则使用OdpsInputV2
  #   OdpsInputV3   如果在本地或者EMR上访问MaxCompute Table, 则使用OdpsInputV3
  #---------------------------------------------------------------------------#
  input_type: CSVInput
  
  #---------------------------------------------------------------------------#
  #   负采样的数据集地址
  #---------------------------------------------------------------------------#
  negative_sampler {
    input_path: 'data/test/tb_data/taobao_ad_feature_gl'
    num_sample: 1024
    num_eval_sample: 2048
    attr_fields: 'adgroup_id'
    attr_fields: 'cate_id'
    attr_fields: 'campaign_id'
    attr_fields: 'customer'
    attr_fields: 'brand'
    item_id_field: 'adgroup_id'
  }
} 

#------------------------------------------------------#
#   用于作为特征的数据,不包括label
#------------------------------------------------------#
feature_config: {
  #---------------------------------------------------------------------------#
  #   具体设置可参考https://easyrec.readthedocs.io/en/latest/feature/feature.html
  #   input_names       代表该列数据的名称
  #   feature_type      特征类别
  #   embedding_dim     该列数据在经过Embedding处理后的特征长度
  #   hash_bucket_size  将变量hash之后去模
  #---------------------------------------------------------------------------#
  features: {
    input_names: 'pid'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 10
  }
  features: {
    input_names: 'adgroup_id'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 100000
  }
  features: {
    input_names: 'cate_id'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 10000
  }
  features: {
    input_names: 'campaign_id'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 100000
  }
  features: {
    input_names: 'customer'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 100000
  }
  features: {
    input_names: 'brand'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 100000
  }
  features: {
    input_names: 'user_id'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 100000
  }
  features: {
    input_names: 'cms_segid'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 100
  }
  features: {
    input_names: 'cms_group_id'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 100
  }
  features: {
    input_names: 'final_gender_code'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 10
  }
  features: {
    input_names: 'age_level'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 10
  }
  features: {
    input_names: 'pvalue_level'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 10
  }
  features: {
    input_names: 'shopping_level'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 10
  }
  features: {
    input_names: 'occupation'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 10
  }
  features: {
    input_names: 'new_user_class_level'
    feature_type: IdFeature
    embedding_dim: 16
    hash_bucket_size: 10
  }
  features: {
     input_names: 'tag_category_list'
     feature_type: TagFeature
     separator: '|'
     hash_bucket_size: 100000
     embedding_dim: 16
  }
  features: {
     input_names: 'tag_brand_list'
     feature_type: TagFeature
     separator: '|'
     hash_bucket_size: 100000
     embedding_dim: 16
  }
  features: {
    input_names: 'price'
    feature_type: IdFeature
    embedding_dim: 16
    num_buckets: 50
  }
}

#------------------------------------------------------#
#   模型参数设置
#------------------------------------------------------#
model_config:{
  #------------------------------------------------------#
  #   模型种类
  #------------------------------------------------------#
  model_class: "DSSM"
  #------------------------------------------------------#
  #   group_name      指定组名
  #   feature_names   该组的特征
  #   wide_deep       模型的记忆能力和泛化能力
  #------------------------------------------------------#
  feature_groups: {
    group_name: 'user'
    feature_names: 'user_id'
    feature_names: 'cms_segid'
    feature_names: 'cms_group_id'
    feature_names: 'age_level'
    feature_names: 'pvalue_level'
    feature_names: 'shopping_level'
    feature_names: 'occupation'
    feature_names: 'new_user_class_level'
    feature_names: 'tag_category_list'
    feature_names: 'tag_brand_list'
    wide_deep:DEEP
  }
  feature_groups: {
    group_name: "item"
    feature_names: 'adgroup_id'
    feature_names: 'cate_id'
    feature_names: 'campaign_id'
    feature_names: 'customer'
    feature_names: 'brand'
    wide_deep:DEEP
  }
  #------------------------------------------------------#
  #   user_tower          user塔
  #   item_tower          item塔
  #   dnn                 代表全连接网络的神经元个数
  #   l2_regularization   l2正则化情况
  #------------------------------------------------------#
  dssm {
    user_tower {
      id: "user_id"
      dnn {
        hidden_units: [256, 128, 64, 32]
        # dropout_ratio : [0.1, 0.1, 0.1, 0.1]
      }
    }
    item_tower {
      id: "adgroup_id"
      dnn {
        hidden_units: [256, 128, 64, 32]
      }
    }
    simi_func: INNER_PRODUCT
    scale_simi: true
    l2_regularization: 1e-6
  }
  loss_type: SOFTMAX_CROSS_ENTROPY
  embedding_regularization: 5e-5
}

export_config {
}

三、开始网络训练

设置好训练所需的config后,就可以开始模型的训练了,单卡用户可以使用如下指令进行训练:

CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.train_eval --pipeline_config_path samples/model_config/dssm_neg_sampler_on_taobao.config

CPU用户可以使用如下指令进行训练:

python -m easy_rec.python.train_eval --pipeline_config_path samples/model_config/dssm_neg_sampler_on_taobao.config

四、训练结果的评估

在完成模型的训练后,我们可以使用如下指令进行评估:

CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.eval --pipeline_config_path samples/model_config/dssm_neg_sampler_on_taobao.config

CPU用户可以使用如下指令进行评估:

python -m easy_rec.python.eval --pipeline_config_path samples/model_config/dssm_neg_sampler_on_taobao.config

五、训练结果的预测

1、训练模型的导出

无论是否使用GPU,都可以使用以下代码将模型导出为PB模式。下列指令的导出路径为dssm_on_taobao_export。

CUDA_VISIBLE_DEVICES='' python -m easy_rec.python.export --pipeline_config_path samples/model_config/dssm_neg_sampler_on_taobao.config --export_dir experiments/dssm_neg_sampler_taobao_ckpt/export/final

2、导出模型的预测

在完成模型的导出后,就可以利用导出的模型进行预测了,离线预测方式如下:

CUDA_VISIBLE_DEVICES=0 python -m easy_rec.python.predict --input_path 'data/test/tb_data/taobao_test_data' --output_path 'data/test/taobao_test_data_pred_result' --saved_model_dir experiments/dssm_neg_sampler_taobao_ckpt/export/final --reserved_cols 'ALL_COLUMNS' --output_cols 'ALL_COLUMNS'

在这里中涉及到大量的参数,常用的参数如下。

  • input_path: 输入文件路径;
  • output_path: 输出文件路径,不需要提前创建,会自动创建;
  • save_modeld_dir: 导出的模型目录;
  • reserved_cols: 输入文件需要拷贝到输出文件的列,默认为’ALL_COLUMNS’,则所有的列都被copy到输出文件中。如果不想输入文件拷贝任何信息到输出文件,可以设置其为’';
  • output_cols: 输出文件自身需要保留预测结果中的列,默认’ALL_COLUMNS’,则所有的列都被copy到输出文件中,可以使用下方的设置方式output_cols=“probs double”,代表输出probs,数据类型为double;
  • input_sep: 输入文件的分隔符,默认",";
  • output_sep: 输出文件的分隔符,默认"|"。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/148434.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一键生成分享链接的贺卡制作工具

不用自己动手设计&#xff0c;在线模板帮你轻松搞定新春贺卡设计&#xff0c;免下载的设计工具。跟着小编的设计教程&#xff0c;教你如何使用乔拓云工具&#xff0c;在线搞定你的新春祝福贺卡设计&#xff0c;不用任何设计经验&#xff0c;只需要跟着教程就能搞定的专属贺卡设…

论文笔记:RCLane: Relay Chain Prediction for Lane Detection

RCLane: Relay Chain Prediction for Lane Detection笔记摘要动机模型结构方法其他模型试验结果笔记摘要 该篇论文的核心创新点在于head。论文根据车道线既需要局部信息&#xff0c;也需要全局信息才能很好拟合的特性&#xff0c;设计了相应的算法head。并且论文实验证明该方法…

机器视觉(十一):条码识别

目录&#xff1a; 机器视觉&#xff08;一&#xff09;&#xff1a;概述 机器视觉&#xff08;二&#xff09;&#xff1a;机器视觉硬件技术 机器视觉&#xff08;三&#xff09;&#xff1a;摄像机标定技术 机器视觉&#xff08;四&#xff09;&#xff1a;空域图像增强 …

记一次虚拟机编译c程序错误

file included from /usr/include/stdio.h:74:0, from opendir.c:2: /usr/include/libio.h:302:3: error: unknown type name ‘size_t’ size_t __pad5; ^ /usr/include/libio.h:305:67: error: ‘size_t’ undeclared here (not in a function) ch…

黑马程序员 Maven 教程

Maven 简介 传统项目管理的缺点&#xff1a; (1) jar 包不统一&#xff0c;jar 包不兼容; (2) 工程升级维护过程操作繁琐; Maven 是什么 Maven 的本质是一个项目管理工具&#xff0c;将项目开发和管理过程抽象成一个项目对象模型 (POM) POM (Project Object Model) : 项目对…

二分搜索算法

目录1.概述2.代码实现2.1.最基本的二分搜索2.2.搜索最左侧边界2.3.搜索最右侧边界3.应用本文参考&#xff1a; LABULADONG 的算法网站 《大话数据结构》 1.概述 &#xff08;1&#xff09;二分搜索 (Binary Search)&#xff0c;又称为折半搜索 (Half-interval Search)。它的前…

云收藏系统|基于Springboot实现云收藏系统

作者主页&#xff1a;编程指南针 作者简介&#xff1a;Java领域优质创作者、CSDN博客专家 、掘金特邀作者、多年架构师设计经验、腾讯课堂常驻讲师 主要内容&#xff1a;Java项目、毕业设计、简历模板、学习资料、面试题库、技术互助 收藏点赞不迷路 关注作者有好处 文末获取源…

Java实现队列

目录 一、队列概述 二、队列的模拟实现 1、入队 2、出队 3、取队头元素 4、获取队列长度 三、循环队列 1、入队 2、出队 3、取队头元素 4、取队尾元素 四、面试题 1、用队列实现栈 2、用栈实现队列 一、队列概述 队列也是常见的数据结构&#xff0c;是一…

Mybatis源码解析二:DataSource数据源负责创建连接以及Transaction的事物管理

简介 对于一个成熟的ORM框架来说&#xff0c;数据源的管理以及事务的管理一定是不可或缺的组成&#xff0c;对于Mybatis来说&#xff0c;为了使用方便以及扩展简单也是做了一系列的封装&#xff0c;这一篇主要介绍mybatis是如何管理数据源以及事务的。 数据源DataSource Dat…

【深度学习】李宏毅2021/2022春深度学习课程笔记 - Adversarial Attack(恶意攻击)

文章目录一、基本概念1.1 动机1.2 恶意攻击的例子1.3 如何攻击&#xff1f;二、White Box vs Black Box三、One Pixel Attack四、Universal Adversarial Attack五、Beyond Image六、Attack in the Physical World七、Adversarial Reprogramming八、Backdoor in Model九、防御9.…

TLS回调函数实现反调试

title: TLS回调函数实现反调试.md date: 2022-06-16 23:40:49.231 updated: 2022-06-16 23:41:11.924 url: /archives/tls回调函数实现反调试 categories: tags: 逆向 TLS回调函数实现反调试 TLS-线程局部存储 先于我们OEP执行 #include<stdlib.h> #include<time.…

使用红黑树封装map、set

map、set如何用红黑树封装 map、set应用&#xff1a;map是一个使用参数K、参数V的类模板&#xff0c;set是只使用参数K的类模板。因为map应用时&#xff0c;需要使用到KV&#xff0c;而set只是存单个值&#xff0c;K。红黑树类的存储 &#xff1a;map和set类中使用红黑树数据成…

Logback配置详解

简介&#xff1a; logback是java的日志开源组件&#xff0c;是log4j创始人写的&#xff0c;性能比log4j要好&#xff0c;目前主要分为3个模块&#xff1a; logback-core:核心代码模块logback-classic:log4j的一个改良版本&#xff0c;同时实现了slf4j的接口&#xff0c;这样你…

树莓派mjpg-streamer实现监控功能

树莓派实现监控功能&#xff0c;调用mjpg-streamer库来实现。mjpg-streamer是一个开源的摄像头媒体流&#xff0c;通过本地获取摄像头的数据&#xff0c;通过http通讯发送&#xff0c;可以通过浏览器访问树莓派的IP地址和端口号就能看到视频流。 实现步骤 1.git clone https:…

关于内核的概念理解

狭义的操作系统可以认为就是内核&#xff0c;比如Linux内核。广义的操作系统则包括内核和一系列应用软件&#xff0c;比如Linux内核编辑器vim编译器gcc命令行解释器&#xff08;shell&#xff09;等&#xff0c;通常称为GNU/Linux。 源代码https://github.com/torvalds/Linux …

Jenkins自动化部署SpringBoot项目(windows环境)

文章目录1、Jenkins介绍1.1、概念1.2、优势1.3、Jenkins目的2、环境准备3、Jenkins下载3.1、下载3.2、运行3.3、问题解决4、Jenkins配置4.1、用户配置4.2、系统配置4.3、全局工具配置-最重要5、新建项目7、测试8、错误解决1、Jenkins介绍 1.1、概念 Jenkins是一个开源软件项目…

自动化测试Seleniums~1

一.什么是自动化测试 1.自动化测试介绍 自动化测试指软件测试的自动化&#xff0c;在预设状态下运行应用程序或者系统&#xff0c;预设条件包括正常和异常&#xff0c;最后评估运行结果。将人为驱动的测试行为转化为机器执行的过程。 将测试人员双手解放&#xff0c;将部分测…

黑马javaWeb Brand综合案例

01-综合案例-环境搭建 02-查询所有-后台&前台

leetcode83周赛

前言&#xff1a; 周赛两题选手,有点意思 830.较大分组的位置 思路&#xff1a;wa了三发&#xff0c;对边界了解的不够清楚 可以有一个小小的优化,时间复杂度O(n) // arr.add(start); //arr.add(i-1); //res.add(arr); res.add(Arrays.asList(start,i - 1));class Solution {pu…

MATLAB-mesh/ezmesh函数三维图形绘制

l ) mesh 函数生成由X、Y和Z指定的网线面&#xff0c;由C指定颜色的三维网格图。具体调用方法如下。mesh(Z):分别以矩阵Z的行、列下标作为x轴和y轴的自变量绘图。mesh(X , Y,Z):最常用的一般调用格式。mesh(X,Y ,Z,C):完整的调用格式&#xff0c;C用于指定图形的颜色&#xff0…