深度学习 - 51.推荐场景下的 Attention And Multi-Head Attention 简单实现 By Keras

news2024/11/23 20:33:46

目录

一.引言

二.计算流程

1.Attention 结构

2.Multi-Head Attention 结构

三.计算实现

1.Item、序列样本生成

2.OwnAttention Layer 实现

2.1 init 初始化

2.2 build 参数构建

2.3 call 逻辑调用

3.OwnAttention Layer 测试

四.总结


一.引言

Attention And Multi-Head Attention 一文中我们简单介绍了 Attention 与 Multi-Head Attention 在推荐场景下的计算,本文通过 Keras 自定义 Layer 的方式,实现 OwnAttention Layer 实现两种 Attention 的功能。

二.计算流程

1.Attention 结构

• 输入

Query 为候选 Item

Keys 为用户行为序列 Goods id,key_i 代表第 i 个行为 good

Values 与 Keys 相同

• 计算

lookup 获取 query、keys、values 向量

query 向量 + keys 向量通过 ActivationUnit 获取每个 key_i 对应的权重 weight_i

weight_i softmax 归一化,此步骤可选

将 weight_i 与 value_i  加权平均得到 Attention Output

2.Multi-Head Attention 结构

• 输入

Query 为候选 Item

Keys 为用户行为序列 Goods id,key_i 代表第 i 个行为 good

Values 与 Keys 相同

 

• 计算

lookup 获取 query、keys、values 向量

原始向量先经过一次 Linear 层

根据 head 的数量,将向量 Split 分为多个子向量,代表不同子空间

每一个 Head 下的子向量执行 Scaled Dot-Product Attention 得到权重 Weight

与子空间 Value 加权平均得到输出

输出再通过一次 Linear 层并 Concat 得到  Attention Output

三.计算实现

1.Item、序列样本生成

def genSamples(_batch_size=5, _T_k=10, _N=1000, seed=0):
    np.random.seed(seed)
    # 用户历史序列
    user_history = np.random.randint(0, N, size=(batch_size, _T_k))
    # 候选 Item
    user_candidate = np.random.randint(0, N, size=(batch_size, 1))
    return user_history, user_candidate

batch_size 为样本数,T_k 为行为数,N 为 Goods 总数,模拟数据,主要为了跑通逻辑:

    # 用户历史行为序列 && 候选商品 ID
    batch_size, T_k, N = 5, 10, 1000
    history, candidate = genSamples(batch_size, T_k, N)
    print(history[0:5])
    print(candidate[0:5])

 

2.OwnAttention Layer 实现

2.1 init 初始化

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.layers import *
from tensorflow.keras.layers import Layer

class OwnAttention(Layer):

    def __init__(self, _mode='Attention', _is_weight_normalization=True, **kwargs):
        self.activation_unit = None
        self.DNN = None
        self.LastDNN = None
        self.kernel = None
        self.N = 10000
        self.T_k = 10
        self.emd_dim = 8
        self.num_heads = 2
        self.mode = _mode
        self.is_weight_normalization = _is_weight_normalization
        super().__init__(**kwargs)

N、T_k、emd_dim 分别代表商品库大小、序列长度与向量维度

mode 供分两种 'Attention' 与 'Multi-Head Attention' 分别代表两种 Attention 模式

is_weight_normalization 权重是否归一化,这个根据自己场景与内积的量纲决定

2.2 build 参数构建

    def build(self, input_shape):
        # 获取 Item 向量
        self.kernel = self.add_weight(name='seq_emb',
                                      shape=(self.N, self.emd_dim),
                                      initializer='he_normal',
                                      trainable=True)
        # Multi-Head Linear
        self.DNN = Dense(self.emd_dim, activation='relu')
        self.LastDNN = Dense(self.emd_dim, activation='relu')

        # Activation Unit
        self.activation_unit = Dense(1, activation='relu')

        super(OwnAttention, self).build(input_shape)

kernel 为商品 id 对应的 Embedding 层,维度为 N x emd_dim

DNN 为 Multi-Head 的首层 Linear

LastDNN 为 Multi-Head 的末层 Linear

activation_unit 用于计算加权权重

Tips:

关于 activation_unit,除了上面的简单实现外,还可以加入 goods 对应的 Position Embedding 或者加入其它 SideInfo 侧信息辅助决策。

2.3 call 逻辑调用

    def call(self, inputs, **kwargs):
        _history, _candidate = inputs

        Q = tf.nn.embedding_lookup(self.kernel, _candidate)
        K = tf.nn.embedding_lookup(self.kernel, _history)
        V = tf.nn.embedding_lookup(self.kernel, _history)

        print("Q Shape: %s \nK Shape: %s \nV Shape: %s" % (Q.shape, K.shape, V.shape))

第一步 lookup 获取 id 对应的 Embedding,BS=5、T_k=1、emd_dim=8:

Q Shape: (5, 1, 8) 
K Shape: (5, 10, 8) 
V Shape: (5, 10, 8)

• mode = 'Attention'

        if self.mode == 'Attention':
            # 获取 Attention 权重
            # [None, T_k, emd_dim] -> [None, T_k, 1] -> [None, 1, T_k]
            din_out = self.activation_unit(K)
            din_out = tf.transpose(din_out, (0, 2, 1))

            # 构建 Mask [None, 1, T_k]
            seq_mask = tf.equal(_history, tf.zeros_like(_history))
            seq_mask = tf.expand_dims(seq_mask, axis=1)

            # 权重归一化, 权重不使用 softmax 归一化则默认为 0 填充 [None, 1, T_k]
            if self.is_weight_normalization:
                paddings = tf.ones_like(din_out) * (-2 ** 32 + 1)
            else:
                paddings = tf.zeros_like(din_out)

            # 归一化 + Padding 的 Attention 权重 [None, 1, T_k]
            din_out = tf.where(seq_mask, paddings, din_out)

            if self.is_weight_normalization:
                din_out = tf.nn.softmax(din_out, axis=2)

            # Attention 输出
            output = tf.matmul(din_out, V)
            output = tf.squeeze(output)
            return output

计算逻辑与维度可参考上面的文字注释,这里增加了 padding 与 weight_normalization,din_out 为最终的加权权重,V 为 values 即 lookup 得到的序列 Embedding。

• mode = 'Multi-Head Attention'

        elif self.mode == 'Multi-Head Attention':

            # Linear
            Q = self.DNN(Q)  # [None, T_q, emd_dim]
            K = self.DNN(K)  # [None, T_k, emd_dim]
            V = self.DNN(V)  # [None, T_k, emd_dim]

            # Split And Concat
            Q_ = tf.concat(tf.split(Q, self.num_heads, axis=2), axis=0)  # [h*None, T_q, emd_dim/h]
            K_ = tf.concat(tf.split(K, self.num_heads, axis=2), axis=0)  # [h*None, T_k, emd_dim/h]
            V_ = tf.concat(tf.split(V, self.num_heads, axis=2), axis=0)  # [h*None, T_k, emd_dim/h]

            # Scaled Dot-Product
            # [h*None, T_q, emd_dim/h] x [h*None, emd_dim/h, T_k] -> [h*None, T_q, T_k]
            d_k = Q_.shape[-1]
            weight = tf.matmul(Q_, tf.transpose(K_, [0, 2, 1]))
            weight = weight / (d_k ** 0.5)
            weight = tf.nn.softmax(weight)

            # Weighted-Sum
            # [h*None, T_q, T_k] * [h*None, T_k, emd_dim/h] -> [h*None, T_q, emd_dim/h]
            weighted = tf.matmul(weight, V_)
            print("Weight Shape: %s Value Shape: %s Weighted-Sum Shape: %s" % (weight.shape, V_.shape, weighted.shape))

            # Concat && Linear
            # [None, T_q, emd_dim]
            concat = tf.squeeze(tf.concat(tf.split(weighted, self.num_heads, axis=0), axis=2))
            multiHeadOutput = self.LastDNN(concat)

            return multiHeadOutput

Split 负责根据 head 数量将原始向量拆分为多个向量子空间,d_k 为缩放系数,这个可以根据自己场景决定,与上面 Attention 不同的是前后增加了两个 Linear 层,除此之外,实际应用时这里可能还需要 Paddiing 与 Dropout。

3.OwnAttention Layer 测试

• mode = 'Attention'

    mode = 'Attention'
    attention = OwnAttention(mode)
    attention_output = attention([history, candidate])
    print("%s Output Shape: %s" % (mode, attention_output.shape))
Attention Output Shape: (5, 8)

• mode = 'Multi-Head Attention'

    mode = 'Multi-Head Attention'
    attention = OwnAttention(mode)
    attention_output = attention([history, candidate])
    print("%s Output Shape: %s" % (mode, attention_output.shape))
Weight Shape: (10, 1, 10) Value Shape: (10, 10, 4) Weighted-Sum Shape: (10, 1, 4)
Multi-Head Attention Output Shape: (5, 8)

四.总结

实现的比较简单,主要是粗略了解 Attention 与 Multi-Head Attention 的实现流程,实际应用场景下,如果 Goods 商品库的 N 太大,也可以采用 Hash 的方式,在牺牲一定性能的情况下弥补工程上的不足。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/565745.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

minio在window和linux下部署权限控制添加,JAVA代码实现

minio部署启用 参考官方,根据不同的操作系统,按照步骤部署 minio官网地址https://www.minio.org.cn/docs/minio/windows/index.html minio权限添加 minio权限添加https://blog.csdn.net/xnian_/article/details/130841657 windows环境部署 安装最小I…

chatgpt赋能python:Python与USB的结合——打造更高效的设备连接

Python与USB的结合——打造更高效的设备连接 Python作为一门广泛应用于各个领域的编程语言,在计算机硬件设备方面也有着广泛的运用。在设备连接这一领域中,Python的成功案例就是其与USB的结合。 什么是USB? USB即通用串行总线(…

win11 重装 NVIDIA 驱动

文章目录 win11 重装 NVIDIA 驱动1. 安装并使用驱动卸载工具 DDU2. 下载并安装 NVIDIA Toolkit3. 查看 CUDA 版本 win11 重装 NVIDIA 驱动 1. 安装并使用驱动卸载工具 DDU 浏览器搜索并下载安装 DDU win R 输入 msconfig 进入安全模式 重启后在安全模式下打开 DDU 完成卸…

讯飞星火_VS_文心一言

获得讯飞星火认知大模型体验授权,第一时间来测试一下效果,使用申请手机号登录后,需要同意讯飞SparkDesk体验规则,如下图所示: 同意之后就可以进行体验了,界面如下: 讯飞星火效果体验 以下Promp…

JavaScript实现循环读入整数进行累加,直到累加的和大于1000为止的代码

以下为实现循环读入整数进行累加,直到累加的和大于1000为止的程序代码和运行截图 目录 前言 一、循环读入整数进行累加,直到累加的和大于1000为止 1.1 运行流程及思想 1.2 代码段 1.3 JavaScript语句代码 1.4 运行截图 前言 1.若有选择&#xff0…

day3 - 图像在不同色彩空间间的转换

本期将了解三种不同的颜色空间,RGB,HSV,GRAY。会使用OpenCV来读取三种颜色空间,并且操作不同空间的转换,观察不同颜色空间的特点。 完成本期内容,你可以: 了解RGB,HSV,G…

雷达中的无源和有源的区别

常规雷达探测目标时,需要源源不断地发射无线电波,所以叫有源雷达( active radar)。有源雷达的优点是能自主搜索目标,因为它接收的是自己发射的电磁波,所以灵敏度高,分辨率好。但这种雷达易受目标的电磁干扰&#xff0c…

chatgpt赋能python:Python:一门强大的编程语言

Python:一门强大的编程语言 Python是一款高级编程语言,以其简单易用和多功能而闻名于世。Python首次发布于1989年,如今已成为许多开发者的首选编程语言。Python特别适合于数据处理、机器学习、人工智能等领域。 为什么选择Python&#xff1…

chatgpt赋能python:PythonWMS:优化仓库管理的新选择

Python WMS: 优化仓库管理的新选择 在现代商业环境中,仓库管理对于公司的供应链管理至关重要。然而,传统的仓库管理系统(WMS)经常过于复杂或桎梏化,不能适应快速变化的市场需求。现在,随着Python WMS的出现…

jQuery-基本过滤器

<!DOCTYPE HTML> <html> <head> <meta http-equiv"Content-Type" content"text/html; charsetUTF-8"> <title>基本过滤器</title> <style type"text/css"> …

Ubuntu安装RabbitMQ server - 在外远程访问

文章目录 前言1.安装erlang 语言2.安装rabbitMQ3. 内网穿透3.1 安装cpolar内网穿透(支持一键自动安装脚本)3.2 创建HTTP隧道 4. 公网远程连接5.固定公网TCP地址5.1 保留一个固定的公网TCP端口地址5.2 配置固定公网TCP端口地址 转载自cpolar内网穿透的文章&#xff1a;无公网IP&…

MyBatis-Plus_04 代码生成器、多数据源(主从)、MyBatisX插件

目录 ①. 代码生成器 ②. 多数据源&#xff08;主从&#xff09; ③. MyBatisX ①. 代码生成器 添加代码生成器依赖 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-generator</artifactId><version>3.5.1&…

chatgpt赋能python:PythonTika:解析各种格式的文档

Python Tika: 解析各种格式的文档 简介 Python Tika是一个基于Apache Tika的Python库&#xff0c;可以解析各种格式的文档&#xff0c;如PDF、Microsoft Office、OpenOffice、XML、HTML、TXT等等。它提供了一种非常方便的方法来获取文档内容&#xff0c;包括元数据、正文、各…

Vue(Vuex插件)

一、Vuex的介绍 1. 概念 专门在Vue中实现集中式状态数据管理的一个Vue插件&#xff0c;对Vue的应用中多个组件的共享状态进行集中式的管理(读/写)&#xff0c;也是一种组件间通信的方式&#xff0c;且适用于任意组件间通信。 2. 了解vuex地址 https://github.com/vuejs/vuexh…

BLIP-2:salesforce提出基于冻结视觉编码器和LLM模型参数的高效训练多模态大模型

论文链接&#xff1a;https://arxiv.org/abs/2301.12597 项目代码&#xff1a;https://github.com/salesforce/LAVIS/tree/main/projects/blip2 体验地址&#xff1a;https://huggingface.co/spaces/Salesforce/BLIP2 文档介绍&#xff1a;https://huggingface.co/docs/tran…

浅谈数据中心供配电设计应用以及监控产品选型

摘 要&#xff1a;近年来&#xff0c;随着数据中心的迅猛发展&#xff0c;数据中心的能耗问题也越来越突出&#xff0c;有关数据中心的能源管理和供配电设计已经成为热门问题&#xff0c;合理可靠的数据中心配电系统方案&#xff0c;是提高数据中心电能使用效率&#xff0c;降低…

图片翻译怎么弄?如何把图片翻译成中文?

在使用社交媒体时&#xff0c;可能会遇到来自世界各地的异文化信息&#xff0c;这时我们可以借助图片翻译的方法帮助我们更好地了解这些信息&#xff0c;促进跨文化交流。那么图片翻译怎么弄呢&#xff1f;图片翻译的方法有哪些呢&#xff1f;这篇文章给你推荐三个非常好用的图…

深入理解Java虚拟机:JVM高级特性与最佳实践-总结-11

深入理解Java虚拟机&#xff1a;JVM高级特性与最佳实践-总结-11 Java内存模型与线程概述硬件的效率与一致性Java内存模型主内存与工作内存内存间交互操作 Java内存模型与线程 概述 多任务处理在现代计算机操作系统中几乎已是一项必备的功能了。在许多场景下&#xff0c;让计算…

22WPF----Prism框架

1.关于Prism框架 官网&#xff1a;Prism Library 文档可以参考 源码地址&#xff1a;https://github.com/PrismLibrary/Prism 版本8.1 Prism框架10历史、微软&#xff0c;最新版本使用 2、功能说明 Prism提供了一组设计模式的实现&#xff0c;有助于编写结构良好的且可维…