推荐系统(十五):基于双塔模型的多目标商品召回/推荐系统

news2025/3/31 11:23:53

在电商推荐场景中,用户行为通常呈现漏斗形态:曝光→点击→转化。本文基于TensorFlow构建了一个支持多任务学习的双塔推荐模型,可同时预测点击率(CTR)和转化率(CVR)。通过用户塔和商品塔的分离式设计,模型既能捕捉用户兴趣偏好,又能理解商品特征,最终通过向量相似度计算生成推荐结果。

1.系统架构

参考文章《推荐系统(六)双塔模型》

2.核心实现步骤

2.1 模拟数据构造

生成包含100用户、200商品和1000次交互的模拟数据集:

"""
Part-1:模拟数据构造

本部分模拟真实场景,人工构造用户数据、商品数据、用户-商品交互数据(点击、购买),并进行必要的预处
"""
# 计算机生成的随机数本质是伪随机数,由算法基于初始种子值(seed)生成固定序列。设置相同的种子会得到相同的随机数序列
# 在机器学习中,随机性会影响模型训练、数据划分(如训练集/测试集分割)、参数初始化等环节。设置种子后,多次运行代码会得到相同结果,便于调试和验证
np.random.seed(42)
tf.random.set_seed(42)
# 用户特征:用户ID、年龄、性别、职业
# 商品特征:商品ID、类别、品牌、价格
num_users = 100
num_items = 200
num_interactions = 1000

# 生成用户特征
user_data = {
   
    'user_id': np.arange(1, num_users + 1),
    'user_age': np.random.randint(18, 65, size=num_users),
    'user_gender': np.random.choice(['male', 'female'], size=num_users),
    'user_occupation': np.random.choice(['student', 'worker', 'teacher'], size=num_users),
    'city_code': np.random.randint(1, 2856, size=num_users),  # 城市编码,中国有 2856 个城市
    'device_type': np.random.randint(0, 5, size=num_users)  # 设备类型(0=Android,1=iOS等)
}

# 生成商品特征
item_data = {
   
    'item_id': np.arange(1, num_items + 1),
    'item_category': np.random.choice(['electronics', 'books', 'clothing'], size=num_items),
    'item_brand': np.random.choice(['brandA', 'brandB', 'brandC'], size=num_items),
    'item_price': np.random.randint(1, 199, size=num_items)
}

# 生成用户-商品交互数据
# 包括:点击和转化(购买)数据
interactions = []
for _ in range(num_interactions):
    user_id = np.random.randint(1, num_users + 1)
    item_id = np.random.randint(1, num_items + 1)
    # 点击标签。0: 未点击, 1: 点击。在真实场景中可通过客户端埋点上报获得用户的点击行为数据
    click_label = np.random.randint(0, 2)
    # 转化标签。由于转化的前提是点击,因此点击和转化之间是一个漏斗关系——转化显著低于点击
    conversion_label = 0
    if click_label == 1:
        conversion_label = np.random.binomial(1, 0.3)  # 假设点击后30%转化率
    interactions.append([user_id, item_id, click_label, conversion_label])

interaction_df = pd.DataFrame(
    interactions,
    columns=['user_id', 'item_id', 'click_label', 'conversion_label'])

# 合并用户特征、商品特征和交互数据
user_df = pd.DataFrame(user_data)
item_df = pd.DataFrame(item_data)
df = interaction_df.merge(user_df, on='user_id').merge(item_df, on='item_id')

# 划分训练集和测试集
labels = df[['click_label', 'conversion_label']]
features = df.drop(['click_label', 'conversion_label'], axis=1)
train_features, test_features, train_labels, test_labels = train_test_split(
    features, labels,
    test_size=0.2,
    random_state=42
)

2.2 特征工程

对不同类型的特征进行差异化处理:
在这里插入图片描述

"""
Part-2:特征工程

本部分对原始用户数据、商品数据、用户-商品交互数据进行分类处理,加工为模型训练需要的特征
    1.数值型特征:如用户年龄、价格,少数场景下可直接使用,但最好进行标准化,从而消除量纲差异
    2.类别型特征:需要进行 Embedding 处理
    3.ID类特征:如用户ID、商品ID,属于高维稀疏特征,需要embedding处理为低维稠密特征

关于 Embedding 处理:
    1.无论是通过tf.keras.layers.Embedding还是feature_column.embedding_column,Embedding层的初始值通常是随机生成的(例如均匀分布或截断正态分布)
    2.在模型训练过程中,Embedding向量会通过反向传播不断更新,使得模型的预测结果(如用户-物品相似度)与目标(如点击标签)更接近
    3.训练后的Embedding向量会收敛到某种有意义的表示,与初始化的随机值完全不同
    
关于标准化处理:
    1.如果使用 feature_column 的 normalizer_fn:模型自动处理,无需手动干预
    2.如果手动标准化:必须保存训练阶段的参数(均值和标准差),并在预测时加载这些参数进行标准化
"""
""" 
用户特征预处理 
"""
# 高维稀疏特征处理
# 过程:先将用户ID定义为类别型特征,num_buckets=num_users 表示用户ID的取值范围是 [0, num_users-1] 的整数;然后,embedding处理
# 注意:在模拟数据中用户和商品数量较少(100用户/200商品),直接使用 ID embedding 容易导致尾部 ID 无法充分训练
# 双塔模型通常需要权衡记忆(ID特征)与泛化(属性特征)能力
user_id = feature_column.categorical_column_with_identity('user_id', num_buckets=num_users)
user_id_emb = feature_column.embedding_column(user_id, dimension=8)
# 数值特征处理
# StandardScaler 是 Scikit-Learn 提供的标准化工具,它会将数据转换为均值为 0、标准差为 1 的分布。
# 标准化(或采用归一化)可以消除不同特征间的量纲差异(例如年龄范围是 0-100,价格范围是 0-10000),使模型训练更稳定
scaler_age = StandardScaler()
df['user_age'] = scaler_age.fit_transform(df[['user_age']])
user_age = feature_column.numeric_column('user_age')

# 类别特征处理
# 先映射,后嵌入,生成低维稠密向量
# 将性别字符串(如“male”“female”)映射为整数ID,输入数据中的性别字符串会被转换为 0(male)或 1(female),然后进行嵌入转换,生成低维稠密向量
user_gender = feature_column.categorical_column_with_vocabulary_list(
    'user_gender', ['male', 'female'])
user_gender_emb = feature_column.embedding_column(user_gender, dimension=2)
# 将职业字符串映射为整数ID(如“student”→0,“worker”→1,依此类推),然后进行嵌入转换,生成低维稠密向量
user_occupation = feature_column.categorical_column_with_vocabulary_list(
    'user_occupation', ['student', 'worker', 'teacher'])
user_occupation_emb = feature_column.embedding_column(user_occupation, dimension=2)
# 用户所在城市编码embedding
# 城市ID的可能取值范围(1到2855,共2855个值,需设置为max_id + 1)
city_code_column = feature_column.categorical_column_with_identity(key='city_code', num_buckets=2856)
city_code_emb = feature_column.embedding_column(city_code_column, dimension=8)
# 用户设备编码embedding
device_types_column = feature_column.categorical_column_with_identity(key='device_type', num_buckets=5)
device_types_emb = feature_column.embedding_column(device_types_column, dimension=8)

""" 
商品特征预处理 
"""
# 高维稀疏特征处理
# 与 user_id 类似,商品ID被定义为 [0, num_items-1] 的整数类别
item_id = feature_column.categorical_column_with_identity('item_id', num_buckets=num_items)
item_id_emb = feature_column.embedding_column(item_id, dimension=8)

# 数值特征处理
# StandardScaler 是 Scikit-Learn 提供的标准化工具,它会将数据转换为均值为 0、标准差为 1 的分布。
# 标准化(或采用归一化)可以消除不同特征间的量纲差异(例如年龄范围是 0-100,价格范围是 0-10000),使模型训练更稳定
scaler_price = StandardScaler()
df['item_price'] = scaler_price.fit_transform(df[['item_price']])
item_price = feature_column.numeric_column('item_price')

# 类别特征处理
# 先映射,后嵌入,生成低维稠密向量
# 分别将商品类别和品牌字符串映射为整数ID,(如“electronics”→0,“books”→1,依此类推),然后进行嵌入转换,生成低维稠密向量
item_category = feature_column.categorical_column_with_vocabulary_list(
    'item_category', ['electronics', 'books', 'clothing'])
item_category_emb = feature_column.embedding_column(item_category, dimension=2)
item_brand = feature_column.categorical_column_with_vocabulary_list(
    'item_brand', ['brandA', 'brandB', 'brandC'])
item_brand_emb = feature_column.embedding_column(item_brand, dimension=2)

# 打印前几行数据以查看结构(设置display.max_columns为None,强制显示全部列)
pd.set_option('display.max_columns', None)
print(df.head())

2.3 双塔模型架构

模型采用分离式双塔结构,最后通过点积计算相似度:

"""
Part-3:模型架构设计

    1.User_Tower + Item_Tower 构成"双塔"结构
    2.多任务学习:基于TensorFlow Estimator构建多任务学习模型,主要用于同时预测点击率(CTR)和转化率(CVR)
"""
# 用户塔特征列
user_tower_columns = [
    user_id_emb,
    user_age,
    user_gender_emb,
    user_occupation_emb,
    city_code_emb,
    device_types_emb
]

# 商品塔特征列
item_tower_columns = [
    item_id_emb,
    item_category_emb,
    item_brand_emb,
    item_price
]


# 自定义多任务模型
def model_fn(features, labels, mode, params):
    """
    自定义多任务模型:基于TensorFlow Estimator的多任务学习模型,主要用于同时预测点击率(CTR)和转化率(CVR)
    """
    # 通过 DenseFeatures 层,将不同的特征列(用户塔和商品塔)转换为模型可用的输入
    # 用户塔
    user_tower = tf.keras.layers.DenseFeatures(params['user_tower_columns'])(features)
    user_tower = tf.keras.layers.Dense(64, activation='relu')(user_tower)
    # 用户塔的输出层是一个32维的向量,本质就是用户Embedding
    user_tower = tf.keras.layers.Dense(32, activation='relu')(user_tower)

    # 商品塔
    item_tower = tf.keras.layers.DenseFeatures(params['item_tower_columns'])(features)
    item_tower = tf.keras.layers.Dense(64, activation='relu')(item_tower)
    # 商品塔的输出层是一个32维的向量,本质就是商品Embedding
    item_tower = tf.keras.layers.Dense(32, activation='relu')(item_tower)

    # 点积交互(即用户Embedding和商品Embedding求取余弦相似度)
    interaction = tf.keras.layers.Dot(axes=1)([user_tower, item_tower])

    # CTR预测头
    ctr_logits = tf.keras.layers.Dense(1)(interaction)
    ctr_pred = tf.sigmoid(ctr_logits)

    # CVR预测头(共享交互层)
    cvr_logits = tf.keras.layers.Dense(1)(interaction)
    cvr_pred = tf.sigmoid(cvr_logits)

    # 损失计算
    if mode != tf.estimator.ModeKeys.PREDICT:
        # CTR损失计算
        # 所有样本的点击标签(click_label)均参与计算,使用Sigmoid交叉熵。
        ctr_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                labels=labels['click_label'],
                logits=tf.squeeze(ctr_logits, axis=-1)  # 压缩最后一个维度
            )
        )

        # 仅点击样本的转化标签(conversion_label)参与计算CVR损失(通过click_mask过滤)
        click_mask = tf.cast(labels['click_label'], tf.bool)
        cvr_labels_masked = tf.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2323341.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【MLP-BEV(10)】BEVPooling V1和BEVPooling V2的view_transformer,进行鱼眼图片实践

文章目录 先说说 BEVPoolv1步骤1:3D点生成步骤2 2D特征采样和BEV特征生成特点再谈谈BEVPoolv2步骤1:3D点生成步骤2: 计算索引关系步骤3: `voxel_pooling`计算鱼眼图片进行实践步骤1、3D点生成(基于Kannala-Brandt 进行调整)步骤2、2D特征采样和BEV特征生成(1) 体素化 (Voxe…

Elasticsearch:使用 Azure AI 文档智能解析 PDF 文本和表格数据

作者:来自 Elastic James Williams 了解如何使用 Azure AI 文档智能解析包含文本和表格数据的 PDF 文档。 Azure AI 文档智能是一个强大的工具,用于从 PDF 中提取结构化数据。它可以有效地提取文本和表格数据。提取的数据可以索引到 Elastic Cloud Serve…

【 <二> 丹方改良:Spring 时代的 JavaWeb】之 Spring Boot 中的 AOP:实现日志记录与性能监控

<前文回顾> 点击此处查看 合集 https://blog.csdn.net/foyodesigner/category_12907601.html?fromshareblogcolumn&sharetypeblogcolumn&sharerId12907601&sharereferPC&sharesourceFoyoDesigner&sharefromfrom_link <今日更新> 一、开篇整…

多模态大模型训练范式演进与前瞻

本文从多模态大模型相关概念出发&#xff0c;并以Flamingo 模型为例&#xff0c;探讨了基于多模态大模型训练的演进与前瞻。新一代训练范式包括统一架构、数据工程革新和动态适应机制&#xff0c;以提升跨模态推理能力和长视频理解。 多模态大模型 定义 什么是多模态大模型&…

游戏引擎学习第187天

看起来观众解决了上次的bug 昨天遇到了一个相对困难的bug&#xff0c;可以说它相当棘手。刚开始的时候&#xff0c;没有立刻想到什么合适的解决办法&#xff0c;所以今天得从头开始&#xff0c;逐步验证之前的假设&#xff0c;收集足够的信息&#xff0c;逐一排查可能的原因&a…

HarmonyOS NEXT 关于鸿蒙的一多开发(一次开发,多端部署) 1+8+N

官方定义 定义&#xff1a;一套代码工程&#xff0c;一次开发上架&#xff0c;多端按需部署。 目标&#xff1a;支撑开发者快速高效的开发支持多种终端设备形态的应用&#xff0c;实现对不同设备兼容的同时&#xff0c;提供跨设备的流转、迁移和协同的分布式体验。 什么是18…

当Kafka化身抽水马桶:论组件并发提升与系统可用性的量子纠缠关系

《当Kafka化身抽水马桶&#xff1a;论组件并发提升与系统可用性的量子纠缠关系》 引言&#xff1a;一场OOM引发的血案 某个月黑风高的夜晚&#xff0c;监控系统突然发出刺耳的警报——我们的数据发现流水线集体扑街。事后复盘发现&#xff1a;Kafka集群、Gateway、Discovery服…

Dify+ollama+vanna 实现text2sql 智能数据分析 -01

新鲜出炉-今天安装vanna踩过的坑 今天的任务是安装vanna这个工具&#xff0c;因为dify中自己写的查询向量数据库和执行sql这两步太慢了大概要20S&#xff0c;所以想用下这个工具&#xff0c;看是否会快一点。后面会把这个vanna封装成一个工具让dify调用。 环境说明 我是在本…

uniapp uni-swipe-action滑动内容排版改造

在uniapp开发中 默认的uni-swipe-action滑动组件 按钮里的文字都是横排的 不能换行的 如果是在一些小屏设备 比如PDA这种&#xff0c;同时按钮文字又都是4个字 多按钮的情况 就会发现滑动一下都直接满屏了 观看体验都不好 但默认的官方组件又没有样式的设置&#xff0c;下面就告…

电脑卡怎么办?提升电脑流畅的方法

电脑已经成为我们工作、学习和娱乐不可或缺的伙伴。然而&#xff0c;随着使用时间的增长&#xff0c;许多用户会遇到电脑运行变慢、卡顿的情况&#xff0c;这不仅影响了工作效率&#xff0c;也大大降低了使用体验。本文将为大家分析电脑卡顿的常见原因&#xff0c;并提供一套实…

SpringBoot报错解决方案

org.apache.tomcat.util.http.fileupload.impl.SizeLimitExceededException: the request was rejected because its size (31297934) exceeds the configured maximum (10485760) 文件上传大小超过限制

软件需求未明确非功能性指标(如并发量)的后果

软件需求未明确非功能性指标&#xff08;如并发量&#xff09;可能带来的严重后果包括&#xff1a;系统性能下降、用户体验恶化、稳定性降低、安全风险增加、后期维护成本高企。其中&#xff0c;系统性能下降尤为显著。当软件系统在设计和开发阶段未明确并发量需求时&#xff0…

VScode-i18n-ally-Vue

参考这篇文章&#xff0c;做Vue项目的国际化配置&#xff0c;本篇文章主要解释&#xff0c;下载了i18n之后&#xff0c;该如何对Vscode进行配置 https://juejin.cn/post/7271964525998309428 i18n Ally全局配置项 Vscode中安装i18n Ally插件&#xff0c;并设置其配置项&#…

Spring Boot项目快速创建-开发流程(笔记)

主要流程&#xff1a; 前端发送网络请求->controller->调用service->操纵mapper->操作数据库->对entity数据对象赋值->返回前端 前期准备&#xff1a; maven、mysql下载好 跟学视频&#xff0c;感谢老师&#xff1a; https://www.bilibili.com/video/BV1gm4…

车架号查询车牌号接口如何用Java对接

一、什么是车架号查询车牌号接口&#xff1f; 车架号查询车牌号接口&#xff0c;即传入车架号&#xff0c;返回车牌号、车型编码、初次登记日期信息。车架号又称车辆VIN码&#xff0c;车辆识别码。 二、如何用Java对接该接口&#xff1f; 下面我们以阿里云接口为例&#xff0…

npm : 无法加载文件 C:\Program Files\nodejs\npm.ps1,因为在此系统上禁止运行脚本的处理方法

1、安装了node.js后&#xff0c;windows powershell中直接输入npm&#xff0c;然后就报错 2、出现原因&#xff1a;权限不够 系统禁用了脚本的执行&#xff0c;所以我们在windows powershell输入npm -v的时候&#xff0c;就会报上面的错误。 3、解决 Set-ExecutionPolicy Un…

数仓开发那些事(11)

某神州优秀员工&#xff1a;一闪&#xff0c;领导说要给我涨米。 一闪&#xff1a;。。。。&#xff08;着急的团团转&#xff09; 老运维&#xff1a;Oi&#xff0c;两个吊毛&#xff0c;看看你们的hadoop集群&#xff0c;健康度30分&#xff0c;怎么还在抽思谋克&#xff1f…

从零开始完成冒泡排序(0基础)——C语言版

文章目录 前言一、冒泡排序的基本思想二、冒泡排序的执行过程&#xff08;一&#xff09;第一轮排序&#xff08;二&#xff09;第二轮排序&#xff08;三&#xff09;第三轮排序&#xff08;四&#xff09;第四轮排序 三、冒泡排序的代码实现&#xff08;C语言&#xff09;&am…

工业级POE交换机:助力智能化与自动化发展

随着工业互联网、物联网&#xff08;IoT&#xff09;和自动化技术的快速发展&#xff0c;网络设备在工业领域的应用日益广泛。然而&#xff0c;在严苛环境下&#xff0c;传统网络设备往往难以应对复杂的温湿度变化、电磁干扰和供电不稳定等挑战。为同时满足数据传输与供电一体化…

使用ZYNQ芯片和LVGL框架实现用户高刷新UI设计系列教程(第五讲)

在上一讲我们讲解了按键回调函数的自定义函数的用法&#xff0c;这一讲继续讲解回调函数的另一种用法。 首先我们将上一讲做好的按键名称以及自定义回调事件中的按键名称修改&#xff0c;改为默认模式为“open”当点击按键时进入回调函数将按键名称改为“close”&#xff0c;具…