【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model)

news2025/1/15 11:37:29

 

目录

一、引言

1.1 往期回顾

1.2 本期概要

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

2.2 技术优缺点

2.3 业务代码实践

三、总结


一、引言

在朴素的深度学习ctr预估模型中(如DNN),通常以一个行为为预估目标,比如通过ctr预估点击率。但实际推荐系统业务场景中,更多是多种目标融合的结果,比如视频推荐,会存在视频点击率、视频完整播放率、视频播放时长等多个目标,而多种目标如何更好的融合,在工业界与学术界均有较多内容产出,由于该环节对实际业务影响最为直接,特开此专栏对推荐系统深度学习多目标问题进行讲述。

1.1 往期回顾

上一篇文章主要介绍了推荐系统多目标算法中的“样本Loss加权”,该方法在训练时Loss乘以样本权重实现对多种目标的加权,通过引导Loss梯度的学习方向,让模型参数朝着你设定的权重方向去学习。

1.2 本期概要

今天进一步深化,主要介绍Shared-Bottom Multi-task Model算法,该算法中文可译为“底部共享多任务模型”,该算法设定多个任务,每个任务设定多个目标,通过“Loss计算时调整每个任务的权重”,亦或是“每个塔单元内,多目标Loss计算时调整每个目标的权重”进行多任务多目标的调整。

二、Shared-Bottom Multi-task Model(SBMM)

2.1 技术原理

Shared-Bottom Multi-task Model(SBMM)全称为底层共享多任务模型,主要由底层共享网络、多任务塔、多目标输出构成。核心原理:通过构造多任务多目标样本数据,在Loss计算环节,将各任务Loss求和(或加权求和),对Loss求导(求梯度)后,逐步后向传播迭代。

  1. 底部网络:Shared-Bottom 网络通常位于底部,可以为一个DNN网络,或者emb+pooling+mlp的方式对input输入的稀疏(sparse)特征进行稠密(dense)化。
  2. ​​​​​​​多个任务塔:底部网络上层接N个任务塔(Tower),每个塔根据需要可以定义为简单或复杂的多层感知器(mlp)网络。每个塔可以对应特定的场景,比如一二级页面场景。
  3. 多个目标:每个任务塔(Tower)可以输出多个学习目标,每个学习目标还可以像上一篇文章一样进行样本Loss加权。每个目标可以对应一种特定的指标行为,比如点击、时长、下单等。

2.2 技术优缺点

相比于上一篇文章提到的样本Loss加权融合法,以及后续文章将会介绍的MoE、MMoE方法,有如下优缺点:

优点:

  • 可以对多级场景任务进行建模,使得ctcvr等点击后转化问题可以被深度学习
  • 浅层参数共享,互相补充学习,任务相关性越高,模型的loss可以降低到更低

缺点: 

  • 跷跷板问题:任务没有好的相关性时,这种Hard parameter sharing会损害效果

2.3 业务代码实践

我们以小红书推荐场景为例,用户在一级发现页场景中停留并点击了“误杀3”中的一个视频笔记,在二级场景视频播放页中观看并点赞了视频。

跨场景多目标建模:我们定义一个SBMM算法结构,底层是一个3层的MLP(64,32,16),MLP出来后接一级场景Tower和二级场景Tower,一级场景任务中分别定义视频一级页“是否停留”、“停留时长”、“是否点击”,二级场景任务中分别定义“点击后播放时长”,“播放后是否点赞”

伪代码:

导入 pytorch 库
定义 SharedBottomMultiTaskModel 类 继承自 nn.Module:
    定义 __init__ 方法 参数 (self, 输入维度, 隐藏层1大小, 隐藏层2大小, 隐藏层3大小, 输出任务1维度, 输出任务2维度):
        初始化共享底部的三层全连接层
        初始化任务1的三层全连接层
        初始化任务2的三层全连接层

    定义 forward 方法 参数 (self, 输入数据):
        计算输入数据通过共享底部后的输出
        从共享底部输出分别计算任务1和任务2的结果
        返回任务1和任务2的结果

生成虚拟样本数据:
    创建训练集和测试集

实例化模型对象
定义损失函数和优化器
训练循环:
    前向传播: 获取预测值
    计算每个任务的损失
    反向传播和优化

PyTorch版本:

算法逻辑

  1. 导入必要的库。
  2. 定义一个类来表示共享底部和特定任务头部的模型结构。
  3. 在初始化方法中定义共享底部和两个独立的任务头部网络层。
  4. 实现前向传播函数,处理输入数据通过共享底部后分发到不同的任务头部。
  5. 生成虚拟样本数据。
  6. 定义损失函数和优化器。
  7. 编写训练循环。
  8. 进行模型预测。
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

class SharedBottomMultiTaskModel(nn.Module):
    def __init__(self, input_dim, hidden1_dim, hidden2_dim, hidden3_dim, output_task1_dim, output_task2_dim):
        super(SharedBottomMultiTaskModel, self).__init__()
        # 定义共享底部的三层全连接层
        self.shared_bottom = nn.Sequential(
            nn.Linear(input_dim, hidden1_dim),
            nn.ReLU(),
            nn.Linear(hidden1_dim, hidden2_dim),
            nn.ReLU(),
            nn.Linear(hidden2_dim, hidden3_dim),
            nn.ReLU()
        )

        # 定义任务1的三层全连接层
        self.task1_head = nn.Sequential(
            nn.Linear(hidden3_dim, hidden2_dim),
            nn.ReLU(),
            nn.Linear(hidden2_dim, output_task1_dim)
        )

        # 定义任务2的三层全连接层
        self.task2_head = nn.Sequential(
            nn.Linear(hidden3_dim, hidden2_dim),
            nn.ReLU(),
            nn.Linear(hidden2_dim, output_task2_dim)
        )

    def forward(self, x):
        # 计算输入数据通过共享底部后的输出
        shared_output = self.shared_bottom(x)

        # 从共享底部输出分别计算任务1和任务2的结果
        task1_output = self.task1_head(shared_output)
        task2_output = self.task2_head(shared_output)

        return task1_output, task2_output

# 构造虚拟样本数据
torch.manual_seed(42)  # 设置随机种子以保证结果可重复
input_dim = 10
task1_dim = 3
task2_dim = 2
num_samples = 1000
X_train = torch.randn(num_samples, input_dim)
y_train_task1 = torch.randn(num_samples, task1_dim)  # 假设任务1的输出维度为task1_dim
y_train_task2 = torch.randn(num_samples, task2_dim)  # 假设任务2的输出维度为task2_dim

# 创建数据加载器
train_dataset = TensorDataset(X_train, y_train_task1, y_train_task2)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 实例化模型对象
model = SharedBottomMultiTaskModel(input_dim, 64, 32, 16, task1_dim, task2_dim)

# 定义损失函数和优化器
criterion_task1 = nn.MSELoss()
criterion_task2 = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (X_batch, y_task1_batch, y_task2_batch) in enumerate(train_loader):
        # 前向传播: 获取预测值
        outputs_task1, outputs_task2 = model(X_batch)

        # 计算每个任务的损失
        loss_task1 = criterion_task1(outputs_task1, y_task1_batch)
        loss_task2 = criterion_task2(outputs_task2, y_task2_batch)
        #print(f'loss_task1:{loss_task1},loss_task2:{loss_task2}')
        total_loss = loss_task1 + loss_task2

        # 反向传播和优化
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}')

# 模型预测
model.eval()
with torch.no_grad():
    test_input = torch.randn(1, input_dim)  # 构造一个测试样本
    pred_task1, pred_task2 = model(test_input)

    print(f'任务1预测结果: {pred_task1}')
    print(f'任务2预测结果: {pred_task2}')

三、总结

本文从技术原理、技术优缺点方面对推荐系统深度学习多任务多目标“Shared-Bottom Multi-task Model”算法进行讲解,该模型使用深度学习模型对多个任务场景多个目标的业务问题进行建模,使得用户在多个场景连续性行为可以被学习,在现实推荐系统业务中是比较基础的方法,后面本专栏还会陆续介绍MoE、MMoE等多任务多目标算法,期待您的关注和支持。

如果您还有时间,欢迎阅读本专栏的其他文章:

【深度学习】多目标融合算法(一):样本Loss加权(Sample Loss Reweight)

【深度学习】多目标融合算法(二):底部共享多任务模型(Shared-Bottom Multi-task Model) ​​​​​​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2276980.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

天机学堂3-ES+Caffeine

文章目录 day05-问答系统表 用户端分页查询问题目标效果代码实现 3.6.管理端分页查询问题ES相关 管理端互动问题分页实现三级分类3.6.5.2.多级缓存3.6.5.3.Caffeine 4.评论相关接口目标效果新增回答或评论 day05-问答系统 效果: 表 互动提问的问题表&#xff1a…

【Docker】Docker部署多种容器

关于docker,Windows上使用Powershell/CMD执行指令,Linux系统直接使用终端执行指令。 docker安装MySQL 拉取MySQL 也可以跳过拉取步骤,直接run,这样本地容器不存在的话,会自动拉取最新/指定的版本。 # 默认拉取最新…

【Flink】Flink内存管理

Flink内存整体结构图: JobManager内存管理 JVM 进程总内存(Total Process Memory)Flink总内存(Total Flink Memory):JVM进程总内存减去JVM Metaspace(元空间)和JVM Overhead(运行时开销)上图解释: JVM进程总内存为2G;JVM运行时开销(JVM Overh…

如何规模化实现完全自动驾驶?Mobileye提出解题“新”思路

在CES 2025上,Mobileye展示了端到端自动驾驶系统Mobileye Drive™,通过高度集成的传感器、算法和计算平台,可以实现自动驾驶功能的全覆盖。 Mobileye创始人兼首席执行官Amnon Shashua教授 期间,Mobileye创始人兼首席执行官Amnon …

Qt 5.14.2 学习记录 —— 십일 QLCDNumber、QProgressBar、QCalendarWidget

文章目录 1、QLCDNumber2、ProgressBar3、QCalendarWidget 1、QLCDNumber 写一个倒计时程序。拖一个LCD Number到界面: 定时器用Qt的QTimer类,这个类的对象会产生一个timeout信号,通过start方法来开启定时器,并且参数中设定触发ti…

简要认识JAVAWeb技术三剑客:HTMLCSSJavaScript

目录 一、web标准二、什么是HTML三、什么是CSS四、什么是JavaScript 黑马JAVAWeb飞书在线讲义地址: https://heuqqdmbyk.feishu.cn/wiki/LYVswfK4eigRIhkW0pvcqgH9nWd 一、web标准 Web标准也称网页标准,由一系列的标准组成,大部分由W3C&…

sosadmin相关命令

sosadmin命令 以下是本人翻译的官方文档,如有不对,还请指出,引用请标明出处。 原本有个对应表可以跳转的,但是CSDN的这个[](#)跳转好像不太一样,必须得用html标签,就懒得改了。 sosadmin help 用法 sosadm…

【C语言】字符串函数详解

文章目录 Ⅰ. strcpy -- 字符串拷贝1、函数介绍2、模拟实现 Ⅱ. strcat -- 字符串追加1、函数介绍2、模拟实现 Ⅲ. strcmp -- 字符串比较1、函数介绍2、模拟实现 Ⅳ. strncpy、strncat、strncmp -- 可限制操作长度Ⅴ. strlen -- 求字符串长度1、函数介绍2、模拟实现&#xff08…

IO进程day6

一、思维导图 二、练习题1 有一个隧道&#xff0c;长1000m&#xff0c;有一辆高铁&#xff0c;每秒100米&#xff0c;有一辆快车&#xff0c;每秒50m 要求模拟这两列火车通过隧道的场景。 #include <stdio.h> #include <unistd.h> #include <pthread.h>pthre…

手撕代码: C++实现按位序列化和反序列化

目录 1.需求 2.流程分析 3.实现过程 4.总结 1.需求 在我们正在开发的项目&#xff0c;有这样一种需求&#xff0c;实现固定格式和自由格式的比特流无线传输。解释一下&#xff0c;固定格式形如下面表格&#xff1a; 每个字段都有位宽、类型等属性&#xff0c;这种固定格式一…

期望最大化算法:机器学习中的隐变量与参数估计的艺术

引言 在机器学习和统计学领域&#xff0c;许多实际问题涉及到含有隐变量的概率模型。例如&#xff0c;在图像识别中&#xff0c;图像的语义信息往往是隐变量&#xff0c;而我们能观测到的只是图像的像素值&#xff1b;在语音识别中&#xff0c;语音对应的文本内容是隐变量&…

2025封禁指定国家ip-安装xtables-addons记录

如何安装和使用 安装lux仓库(该仓库包含xtables-addons所需的依赖环境) # wget http://repo.iotti.biz/CentOS/7/noarch/lux-release-7-1.noarch.rpm # rpm -ivh lux-release-7-1.noarch.rpm 安装xtables-addons。注意&#xff1a;必须先安装kmod-xtables-addons&#xff0c;再…

使用C语言实现栈的插入、删除和排序操作

栈是一种后进先出(LIFO, Last In First Out)的数据结构,这意味着最后插入的元素最先被删除。在C语言中,我们可以通过数组或链表来实现栈。本文将使用数组来实现一个简单的栈,并提供插入(push)、删除(pop)以及排序(这里采用一种简单的排序方法,例如冒泡排序)的操作示…

【Go】Go Gin框架初识(一)

1. 什么是Gin框架 Gin框架&#xff1a;是一个由 Golang 语言开发的 web 框架&#xff0c;能够极大提高开发 web 应用的效率&#xff01; 1.1 什么是web框架 web框架体系图&#xff08;前后端不分离&#xff09;如下图所示&#xff1a; 从上图中我们可以发现一个Web框架最重要…

【深度学习】Windows系统Anaconda + CUDA + cuDNN + Pytorch环境配置

在做深度学习内容之前&#xff0c;为GPU配置anaconda CUDA cuDNN pytorch环境&#xff0c;在网络上参考了很多帖子&#xff0c;但pytorch的安装部分都有些问题或者比较复杂繁琐&#xff0c;这里总结了相对简单快速的配置方式 文章目录 AnacondaCUDAcuDNNpytorchtorchtorchau…

提供的 IP 地址 10.0.0.5 和子网掩码位 /26 来计算相关的网络信息

网络和IP地址计算器 https://www.sojson.com/convert/subnetmask.html提供的 IP 地址 10.0.0.5 和子网掩码位 /26 来计算相关的网络信息。 子网掩码转换 子网掩码 /26 的含义二进制表示:/26 表示前 26 位是网络部分&#xff0c;剩下的 6 位是主机部分。对应的子网掩码为 255…

js中的Object.defineProperty()详解

文章目录 一、Object.defineProperty()二、descriptor属性描述符2.1、数据描述符2.2、访问器描述符2.3、descriptor属性2.3.1、value2.3.2、writable2.3.3、enumerable &#xff08;可遍历性&#xff09;2.3.4、configurable &#xff08;可配置性&#xff09; 三、注意事项 一…

单细胞组学大模型(8)--- scGenePT,scGPT和GenePT的结合,实验数据和文本数据的交融模型

–https://doi.org/10.1101/2024.10.23.619972 研究团队和单位 Theofanis Karaletsos–Head Of AI - Science at Chan Zuckerberg Initiative &#xff08;Chan Zuckerberg Initiative是扎克伯格和他妻子Chan成立的科研&教育机构&#xff09; 研究简介 研究背景&…

微信原生小程序自定义封装组件(以导航navbar为例)

封装 topnav.js const App getApp(); Component({// 组件的属性列表properties: {pageName: String, //中间的titleshowNav: { //判断是否显示左上角的按钮 type: Boolean,value: true},showHome: { //判断是否显示左上角的home按钮type: Boolean,value: true},showLocat…

day06_Spark SQL

文章目录 day06_Spark SQL课程笔记一、今日课程内容二、DataFrame详解&#xff08;掌握&#xff09;5.清洗相关的API6.Spark SQL的Shuffle分区设置7.数据写出操作写出到文件写出到数据库 三、Spark SQL的综合案例&#xff08;掌握&#xff09;1、常见DSL代码整理2、电影分析案例…