特征交叉-CAN学习笔记代码解读

news2025/2/25 5:29:41

一 核心模块coaction

  1. 对于每个特征对(feature_pairs)
  2. weight, bias 来自于P_induction
  3. P_fead是MLP的input

举个例子:如果是用户ID和产品ID的co-action,且产品ID是做induction,用户ID是做feed。

  • step1 用户ID/产品ID都先形成一个向量:对于产品ID,用parameter lookup获取一个可学习的P_induction(这个维度是(wi+bi) * L depth of mlp); 用户ID则直接形成一个向量P_fead
  • step2 P_induction 这个向量逐层(MLP层),reshape成MLP网络的weight 和bias;
  • step3 weight和bias作为MLP的参数,利用P_feed 作为input,进行MLP前向运算,得到特征交互结果
  1. 代码解读
#### CAN config #####
weight_emb_w = [[16, 8], [8,4]] # micro-mlp的参数dimension
weight_emb_b = [0, 0]           # bias参数
orders = 3  # 特征的阶数,文章提到了,要做高阶特征交叉,直接是P_feed^c, c就是阶数
order_indep = False # True
WEIGHT_EMB_DIM = (sum([w[0]*w[1] for w in weight_emb_w]) + sum(weight_emb_b)) # * orders 这个是供每一个micro-mlp拆解w&b需要的dimension总和
INDEP_NUM = 1
if order_indep:
    INDEP_NUM *= orders
###### 这一部分对应图中绿色和橙色部分,主要是把P_feed&P_induction的嵌入表示得到 ##########
if self.use_coaction:
   # batch_ph batch输入的数据;his_batch_ph历史批次数据; his_batch_embedded 历史嵌入表示
   ph_dict = {
       "item": [self.mid_batch_ph, self.mid_his_batch_ph, self.mid_his_batch_embedded],
       "cate": [self.cate_batch_ph, self.cate_his_batch_ph, self.cate_his_batch_embedded]
   }
   ### p_induction ####
   self.mlp_batch_embedded = [] # induction embedding
   with tf.device(device):
       # 定义可训练的嵌入矩阵,在这里n_mid是item id的数量
       self.item_mlp_embeddings_var = tf.get_variable("item_mlp_embedding_var", [n_mid, INDEP_NUM * WEIGHT_EMB_DIM], trainable=True)
       self.cate_mlp_embeddings_var = tf.get_variable("cate_mlp_embedding_var", [n_cate, INDEP_NUM * WEIGHT_EMB_DIM], trainable=True)
       # 通过embedding_lookup在上一步初始化好的矩阵中找到对应的embedding表示
       self.mlp_batch_embedded.append(tf.nn.embedding_lookup(self.item_mlp_embeddings_var, ph_dict['item'][0]))
       self.mlp_batch_embedded.append(tf.nn.embedding_lookup(self.cate_mlp_embeddings_var, ph_dict['cate'][0]))
       #########P_feed input ########
       self.input_batch_embedded = []
       self.item_input_embeddings_var = tf.get_variable("item_input_embedding_var", [n_mid, weight_emb_w[0][0] * INDEP_NUM], trainable=True)
       self.cate_input_embeddings_var = tf.get_variable("cate_input_embedding_var", [n_cate, weight_emb_w[0][0] * INDEP_NUM], trainable=True)  
         self.input_batch_embedded.append(tf.nn.embedding_lookup(self.item_input_embeddings_var, ph_dict['item'][1]))
       self.input_batch_embedded.append(tf.nn.embedding_lookup(self.cate_input_embeddings_var, ph_dict['cate'][1]))
################这一部分是P_induction&P_feed在MLP的使用#######################
if self.use_coaction:
    # p_feed/input
    input_batch = self.input_batch_embedded
    tmp_sum, tmp_seq = [], []
    if INDEP_NUM == 2:
        # 文章说明了是feature pairs,mlp_batch&input_batch都包含了两个部分,要分别组合
        for i, mlp_batch in enumerate(self.mlp_batch_embedded):
            for j, input_batch in enumerate(self.input_batch_embedded):
                coaction_sum, coaction_seq = gen_coaction(
                    mlp_batch[:, WEIGHT_EMB_DIM * j:  WEIGHT_EMB_DIM * (j+1)], 
                    input_batch[:, :, weight_emb_w[0][0] * i: weight_emb_w[0][0] * (i+1)],  
                    EMBEDDING_DIM, 
                    mode=CALC_MODE,
                    mask=self.mask) 
                
                tmp_sum.append(coaction_sum)
                tmp_seq.append(coaction_seq)
    else:
        for i, (mlp_batch, input_batch) in enumerate(zip(self.mlp_batch_embedded, self.input_batch_embedded)):
            coaction_sum, coaction_seq = gen_coaction(
                  mlp_batch[:, :INDEP_NUM * WEIGHT_EMB_DIM], 
                  input_batch[:, :, :weight_emb_w[0][0]],  
                  EMBEDDING_DIM, 
                  mode=CALC_MODE, 
                  mask=self.mask) 
            
            tmp_sum.append(coaction_sum)
            tmp_seq.append(coaction_seq)
            
    self.coaction_sum = tf.concat(tmp_sum, axis=1) # sum pooling
    self.cross.append(self.coaction_sum)   # concat              
###### core interaction 核心运算 #########
def gen_coaction(ad, his_items, dim, mode="can", mask=None):
    """
    ad: induct
    his_items 待交互seq
    """
    weight, bias = [], []
    idx = 0
    weight_orders = []
    bias_orders = []
    # 拆解得到weight&bias参数
    for i in range(orders):
        for w, b in zip(weight_emb_w, weight_emb_b):
            weight.append(tf.reshape(ad[:, idx:idx+w[0]*w[1]], [-1, w[0], w[1]]))
            idx += w[0] * w[1]
            if b == 0:
                bias.append(None)
            else:
                bias.append(tf.reshape(ad[:, idx:idx+b], [-1, 1, b]))
                idx += b
        weight_orders.append(weight)
        bias_orders.append(bias)
        if not order_indep:
            break
 
    if mode == "can":
        out_seq = []
        hh = []
        # 高阶特征处理,explicit deal with
        for i in range(orders):
            hh.append(his_items**(i+1))
        #hh = [sum(hh)]
        for i, h in enumerate(hh):
            if order_indep:
                weight, bias = weight_orders[i], bias_orders[i]
            else:
                weight, bias = weight_orders[0], bias_orders[0]
            # 模拟MLP forward calculation
            for j, (w, b) in enumerate(zip(weight, bias)):
                h  = tf.matmul(h, w)
                if b is not None:
                    h = h + b
                if j != len(weight)-1:
                    h = tf.nn.tanh(h)
                out_seq.append(h)
        out_seq = tf.concat(out_seq, 2)
        if mask is not None:
            mask = tf.expand_dims(mask, axis=-1) 
            out_seq = out_seq * mask
            
    # 序列交互结果做sum_pooling
    out = tf.reduce_sum(out_seq, 1)
    if keep_fake_carte_seq and mode=="emb":
        return out, out_seq
    return out, None

二 文章中的应用
整体的模型结构两部分构成:

  • co-action作为核心形成的一部分,对于用户的序列特征,一一作用后做sum-pooling,对于非序列特征,作用后直接输出
  • DIEN作为核心形成的一部分

两部分concat以后加一个DNN常规操作,看起来就像是用co-action做显式的特征交叉,然后DIEN做之前的序列建模。
在这里插入图片描述
三 一些其他细节补充

  1. can 部分高阶特征处理: 直接把待交叉特征p_fead 做c阶运算后,再与p_induction进行作用
  2. 在文章场景,p_induction是target_item,也就是产品

四 用tf2/torch重构

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class CAN_Model(nn.Module):
    def __init__(self, n_uid, n_mid, n_cate, n_carte, EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE, use_negsampling=False, use_softmax=True, use_coaction=False, use_cartes=False):
        super(CAN_Model, self).__init__()
        
        self.n_uid = n_uid
        self.n_mid = n_mid
        self.n_cate = n_cate
        self.n_carte = n_carte
        self.EMBEDDING_DIM = EMBEDDING_DIM
        self.HIDDEN_SIZE = HIDDEN_SIZE
        self.ATTENTION_SIZE = ATTENTION_SIZE
        self.use_negsampling = use_negsampling
        self.use_softmax = use_softmax
        self.use_coaction = use_coaction
        self.use_cartes = use_cartes

        self.uid_embeddings = nn.Embedding(n_uid, EMBEDDING_DIM)
        self.mid_embeddings = nn.Embedding(n_mid, EMBEDDING_DIM)
        self.cate_embeddings = nn.Embedding(n_cate, EMBEDDING_DIM)

        if use_cartes:
            self.carte_embeddings = nn.ModuleList([nn.Embedding(num, EMBEDDING_DIM) for num in n_carte])

        if self.use_coaction:
            self.item_mlp_embeddings = nn.Parameter(torch.randn(n_mid, INDEP_NUM * WEIGHT_EMB_DIM))
            self.cate_mlp_embeddings = nn.Parameter(torch.randn(n_cate, INDEP_NUM * WEIGHT_EMB_DIM))
            self.input_batch_embeddings = nn.ModuleList([nn.Embedding(n_mid, weight_emb_w[0][0] * INDEP_NUM), nn.Embedding(n_cate, weight_emb_w[0][0] * INDEP_NUM)])

        self.fc1 = nn.Linear(200, 80)
        self.fc2 = nn.Linear(80, 2 if use_softmax else 1)

    def forward(self, uid, mid, cate, mid_his, cate_his, mask, target, seq_len, lr, carte=None):
        # Embedding lookups
        uid_emb = self.uid_embeddings(uid)
        mid_emb = self.mid_embeddings(mid)
        cate_emb = self.cate_embeddings(cate)
        mid_his_emb = self.mid_embeddings(mid_his)
        cate_his_emb = self.cate_embeddings(cate_his)

        if self.use_cartes:
            carte_emb = [emb(carte[:, i, :]) for i, emb in enumerate(self.carte_embeddings)]

        # Co-action logic (if enabled)
        if self.use_coaction:
            # This is a simplified version of the co-action implementation from the original TensorFlow code
            mlp_embedded_item = self.item_mlp_embeddings[mid]
            mlp_embedded_cate = self.cate_mlp_embeddings[cate]
            input_embedded_item = self.input_batch_embeddings[0](mid_his)
            input_embedded_cate = self.input_batch_embeddings[1](cate_his)
            # Further coaction operations can be added based on your logic

        # Concatenate item and category embeddings
        item_eb = torch.cat([mid_emb, cate_emb], dim=1)
        item_his_eb = torch.cat([mid_his_emb, cate_his_emb], dim=2)
        item_his_eb_sum = item_his_eb.sum(dim=1)

        if self.use_negsampling:
            # Assuming the negative sampling implementation would need its own logic.
            pass

        # FC layers
        x = self.fc1(item_eb)
        x = F.relu(x)
        x = self.fc2(x)

        # Loss computation
        if self.use_softmax:
            y_hat = F.softmax(x, dim=-1)
            loss = F.cross_entropy(y_hat, target)
        else:
            y_hat = torch.sigmoid(x)
            loss = F.binary_cross_entropy_with_logits(x, target)

        return loss, y_hat

    def auxiliary_loss(self, h_states, click_seq, noclick_seq, mask):
        mask = mask.float()
        click_input = torch.cat([h_states, click_seq], dim=-1)
        noclick_input = torch.cat([h_states, noclick_seq], dim=-1)
        click_prop = self.auxiliary_net(click_input)[:, :, 0]
        noclick_prop = self.auxiliary_net(noclick_input)[:, :, 0]
        click_loss = -torch.log(click_prop) * mask
        noclick_loss = -torch.log(1.0 - noclick_prop) * mask
        loss = (click_loss + noclick_loss).mean()
        return loss

    def auxiliary_net(self, in_):
        x = F.relu(self.fc1(in_))
        x = F.relu(self.fc2(x))
        return x

    def train_step(self, data, optimizer):
        optimizer.zero_grad()
        loss, y_hat = self(data)
        loss.backward()
        optimizer.step()
        return loss.item()

    def evaluate(self, data):
        with torch.no_grad():
            loss, y_hat = self(data)
        return loss.item(), y_hat

# Example of using the model
n_uid = 1000
n_mid = 1000
n_cate = 500
n_carte = [10, 20]  # Example carte sizes
EMBEDDING_DIM = 128
HIDDEN_SIZE = 256
ATTENTION_SIZE = 128

model = CAN_Model(n_uid, n_mid, n_cate, n_carte, EMBEDDING_DIM, HIDDEN_SIZE, ATTENTION_SIZE)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Example data
uid = torch.randint(0, n_uid, (32,))
mid = torch.randint(0, n_mid, (32,))
cate = torch.randint(0, n_cate, (32,))
mid_his = torch.randint(0, n_mid, (32, 5))
cate_his = torch.randint(0, n_cate, (32, 5))
mask = torch.ones(32, 5)
target = torch.randint(0, 2, (32,))
seq_len = torch.randint(1, 5, (32,))
lr = 0.001

# Training step
loss = model.train_step((uid, mid, cate, mid_his, cate_his, mask, target, seq_len, lr), optimizer)
print(f"Loss: {loss}")

Reference:

  1. 文章形成思路历程
  2. CAN: Feature Co-Action for Click-Through Rate Prediction-21年,阿里
  3. Implementation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2258569.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

EfficientNet与复合缩放理论(Compound Scaling Theory) 详解(MATLAB)

1.EfficientNet网络与模型复合缩放 1.1 EfficientNet网络简介 1.1.1 提出背景、动机与过程 EfficientNet是一种高效的卷积神经网络(CNN),由Google的研究团队Tan等人在2019年提出。EfficientNet的设计目标是提高网络的性能,同时减…

SQL语句在MySQL中如何执行

MySQL的基础架构 首先就是客户端,其次Server服务层,大多数MySQL的核心服务都在这一层,包括连接、分析、优化、缓存以及所有的内置函数(时间、日期、加密函数),所有跨存储引擎功能都在这一层实现&#xff1…

开源低代码平台-Microi吾码-表单控件数据源绑定配置

表单控件数据源绑定配置 平台简介普通数据源数据源引擎Sql数据源通过其它字段来动态绑定数据源关于绑定数据源后的显示字段和存储字段 平台简介 技术框架:.NET8 Redis MySql/SqlServer/Oracle Vue2/3 Element-UI/Element-Plus平台始于2014年(基于Av…

Y3编辑器文档4:触发器1(对话、装备、特效、行为树、排行榜、不同步问题)

文章目录 一、触发器简介1.1 触发器界面1.2 ECA语句编辑及快捷键1.3 参数设置1.4 变量设置1.5 实体触发器1.6 函数库与触发器复用 二、触发器的多层结构2.1 子触发器(在游戏内对新的事件进行注册)2.2 触发器变量作用域2.3 复合条件2.4 循环2.5 计时器2.6…

Redis原理—4.核心原理摘要

大纲(9870字) 1.Redis服务器的Socket网络连接建立 2.Redis多路复用监听与文件事件模型 3.基于队列串行化的文件事件处理机制 4.完整的Redis Server网络通信流程 5.Redis串行化单线程模型为什么能高并发 6.Redis内核级请求处理流程与原理 7.Redis通信协议与内核级请求数据…

轻量级日志管理平台:Grafana Loki搭建及应用(详细篇)

前言 Grafana Loki是Grafana Lab团队提供的一个水平可扩展、高可用性、多租户的日志聚合系统,与其他日志系统不同的是,Loki最初设计的理念是为了为日志建立标签索引,而非将原日志内容进行索引。 现在目前成熟的方案基本上都是:L…

【规范一】JAVA静态代码规范

1.规范的划分 将Java代码规范分为 风格规范 和 质量规范 ,主要是因为这两种规范关注的方面不同,各自解决的问题也不同。下面详细解释为什么需要将代码规范分为这两种 1.1 风格规范(Coding Style Guidelines) 风格规范主要关注代码…

Angular由一个bug说起之十二:网页页面持续占用CPU过高

随着网络日益发达,网页的内容也更加丰富,形式也更加多样化。而随之而来的性能问题也不容小觑。这篇文章我会根据我在实践中遇到的一个问题来总结,我在面对性能问题的一些解决步骤,希望能对大家有所启发。 查找问题原因 我接触的…

WordPress全能CDN插件_自动刷新预热_缓存优化|国内国外集成CDN配置

WordPress全网独家原创CDN插件 自动刷新预热 缓存优化 国内国外集成CDN配置 支持白山云 cdnfly Cloudflare PS:目前国内集成了CDNfly,白山云国外集成了Cloudflare,更新手动刷新,全站刷新,优化提交线程,根据网友建议适配阿里云,le…

唇形同步视频生成工具:Wav2Lip

一、模型介绍 今天介绍一个唇形同步的工具-Wav2Lip;Wav2Lip是一种用于生成唇形同步(lip-sync)视频的深度学习算法,它能够根据输入的音频流自动为给定的人脸视频添加准确的口型动作。 (Paper) Wav2Lip模型…

【汽车】-- 燃油发动机3缸和4缸

3缸和4缸燃油发动机是小轿车常见的发动机配置。以下从结构特点、性能、经济性等方面对两者进行对比,并分析优缺点及使用注意事项: 1. 结构与运行原理 3缸发动机 特点:少一个气缸,内部零部件更少,整体结构更紧凑。优点…

[NeurlPS 2022] STaR 开源代码实现解读

STaR 方法代码开源,这里给出一个中文代码解读地址:repo入口点:iteration_train.py;关键代码:device_train.py, device_inference.py, and create_finetune_tfrecords.py;基于 JAX、RAY,在 Googl…

欢迪迈手机商城设计与实现

文末获取源码和万字论文,制作不易,感谢点赞支持。 题目:欢迪迈手机商城设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管…

【鸿蒙实战开发】数据的下拉刷新与上拉加载

本章介绍 本章主要介绍 ArkUI 开发中最常用的场景下拉刷新, 上拉加载,在本章中介绍的内容在实际开发过程当中会高频的使用,所以同学们要牢记本章的内容。下面就让我们开始今天的讲解吧! List 组件 在 ArkUI 中List容器组件也可以实现数据滚动的效果&a…

UnityShaderLab 实现程序化形状(一)

1.实现一个长宽可变的矩形: 代码: fixed4 frag (v2f i) : SV_Target{return saturate(length(saturate(abs(i.uv - 0.5)-0.13)))/0.03;} 2.实现一个半径可变的圆形: 代码: fixed4 frag (v2f i) : SV_Target{return (distance(a…

高阶数据结构--B树B+树实现原理B树模拟实现--Java

目录 一、B-树概念 二、B-树插入分析 1.用序列{53, 139, 75, 49, 145, 36, 101}构建B树的过程如下: 2.插入过程总结 三、B树插入实现 四、B树 1.B树概念 2.B树的特性 五、B树应用 1.索引 2.Mysql索引 3.InnoDB 一、B-树概念 1970 年, R.Bayer 和…

网络安全——防火墙

基本概念 防火墙是一个系统,通过过滤传输数据达到防止未经授权的网络传输侵入私有网络,阻止不必要流量的同时允许必要流量进入。防火墙旨在私有和共有网络间建立一道安全屏障,因为网上总有黑客和恶意攻击入侵私有网络来破坏,防火…

基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 多图推理

基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 多图推理 flyfish 基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_LoRA配置如何写 基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_单图推理 基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_原模型_单图推理 基于Q…

Ant Design Pro实战--day01

下载nvm https://nvm.uihtm.com/nvm-1.1.12-setup.zip 下载node.js 16.16.0 //非此版本会报错 nvm install 16.16.0 安装Ant Design pro //安装脚手架 npm i ant-design/pro-cli -g //下载项目 pro create myapp //选择版本 simple 安装依赖 npm install 启动umi yarn add u…

一、为什么要学习麒麟?

麒麟认证:开启职业晋升之门 当前,就业难已经成为一个普遍的社会问题。许多大学生毕业后面临着找工作的困境,他们往往发现自己很难找到满意的职位。即使有幸找到了工作,也经常需要应对工作压力大、薪资低等问题。除此之外&#xff…