单头注意力机制(ScaledDotProductAttention) python实现

news2024/11/17 7:55:47

输入是query和 key-value,注意力机制首先计算query与每个key的关联性(compatibility),每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出。
在这里插入图片描述

import torch
import torch.nn as nn

class ScaledDotProductAttention(nn.Module):
    """ Scaled Dot-Product Attention """

    def __init__(self, scale):
        super().__init__()

        self.scale = scale
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):
        u = torch.bmm(q, k.transpose(1, 2)) # 1.Matmul
        u = u / self.scale # 2.Scale

        if mask is not None:
            u = u.masked_fill(mask, -np.inf) # 3.Mask

        attn = self.softmax(u) # 4.Softmax
        output = torch.bmm(attn, v) # 5.Output

        return attn, output


if __name__ == "__main__":
    n_q, n_k, n_v = 2, 4, 4
    d_q, d_k, d_v = 128, 128, 64
    batch = 2
    q = torch.randn(batch, n_q, d_q)
    k = torch.randn(batch, n_k, d_k)
    v = torch.randn(batch, n_v, d_v)
    mask = torch.zeros(batch, n_q, n_k).bool()

    attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))
    attn, output = attention(q, k, v, mask=mask)

    print(attn)
    print(output)

运行结果:


tensor([[[0.4165, 0.3548, 0.1667, 0.0620],
         [0.0381, 0.3595, 0.4584, 0.1439]],

        [[0.3611, 0.1587, 0.2078, 0.2723],
         [0.1603, 0.0530, 0.0670, 0.7198]]])
tensor([[[ 2.2813e-01, -6.3289e-01,  1.3624e+00,  8.4069e-01,  8.1762e-02,
          -6.3727e-01, -6.3929e-01, -1.0091e+00,  3.7668e-01, -2.9384e-01,
          -6.2543e-02, -4.4706e-01,  3.8331e-01,  2.2979e-02, -1.1968e+00,
          -3.7061e-01, -1.9007e-01, -1.7616e-01,  3.6516e-01,  1.1321e-01,
          -9.5077e-01, -1.3449e+00, -1.2594e+00,  4.2644e-01, -6.3195e-01,
          -5.2016e-01, -2.5782e-01, -2.4116e-01,  1.7582e-01, -1.5177e+00,
          -9.3120e-01, -4.9671e-01, -4.5024e-01, -1.0746e+00,  5.4357e-01,
          -6.2079e-01,  5.1379e-01,  5.6308e-02, -6.3830e-01, -3.6174e-01,
          -3.0044e-01, -3.0946e-01, -5.0303e-01, -1.8382e-01,  1.1064e+00,
          -7.5142e-01, -1.5372e-01, -3.3204e-01, -7.9568e-01,  1.3108e-01,
          -8.6041e-01,  2.5165e-01,  8.8248e-02,  3.7294e-01, -5.2247e-02,
           4.8462e-01, -7.4389e-01, -5.4351e-01, -9.7697e-01, -9.3327e-01,
          -4.4550e-02,  6.1108e-01, -5.4613e-01,  2.3962e-01],
         [ 6.9032e-02,  9.0591e-01,  8.3206e-01,  1.3668e+00,  1.8095e-02,
          -7.3172e-02, -3.0873e-01, -9.2571e-01,  4.3452e-01, -4.7707e-02,
          -3.0431e-01, -1.7578e-01,  4.0575e-01, -4.4958e-01, -4.9809e-01,
          -1.7263e-02, -3.8684e-01,  2.8536e-01,  4.1150e-02, -3.7069e-01,
          -7.2903e-01, -2.5185e-01, -1.0011e-01,  9.0434e-01, -7.8387e-02,
           6.9680e-01,  5.3684e-01,  2.8456e-01,  2.2887e-01, -1.7423e+00,
          -4.4135e-01, -2.9209e-01,  1.7053e-01, -6.4208e-01,  1.7977e-01,
           1.3822e-01, -1.7873e-01, -4.7619e-01, -6.7788e-01, -5.3340e-01,
           3.1518e-01, -5.6127e-02,  2.2175e-01, -3.9524e-01,  5.4478e-01,
          -5.7730e-01,  5.8043e-01, -3.0143e-01, -5.7146e-01,  1.5063e-05,
          -6.8221e-01, -1.3456e-02, -6.5192e-01,  7.4233e-02,  3.1776e-01,
           3.1504e-01, -9.5457e-01, -8.9894e-01, -7.8422e-01, -4.1440e-01,
          -9.4272e-02,  2.7226e-01, -7.0286e-01,  8.9388e-01]],

        [[-7.6068e-02,  1.6911e-01,  5.1532e-02, -5.3612e-02,  2.4258e-02,
           1.6490e-01,  7.4469e-01, -1.1471e+00, -4.5234e-01,  1.0684e-01,
           1.0929e+00, -5.8079e-01,  1.7665e-01, -2.0187e-02, -3.3850e-01,
           4.4517e-01, -4.5871e-01,  6.7840e-01, -4.3617e-01,  7.6141e-01,
           3.8135e-02, -2.3898e-01,  3.2086e-01,  4.1481e-01, -1.8267e-01,
           8.4337e-01,  7.8504e-02, -1.0101e+00,  5.0766e-02,  2.3338e-01,
          -3.5572e-01,  1.3751e-01, -4.9570e-02,  4.8627e-01, -3.3225e-01,
           6.5361e-01,  2.8979e-01,  9.9991e-02,  8.6995e-01, -7.2569e-02,
           2.5490e-01, -2.6418e-01,  6.1185e-01, -7.7243e-01, -4.6956e-01,
          -3.1459e-01, -2.1278e-01,  9.1588e-01, -2.1349e-02, -5.0036e-01,
           3.6214e-01,  1.3723e-02,  1.2322e-01, -5.3018e-01,  2.4809e-01,
          -3.2042e-01,  2.4807e-01, -1.5764e-01, -2.6655e-01,  1.8610e-01,
          -1.6585e-01,  2.3454e-01,  3.1852e-01,  6.1627e-01],
         [-1.7126e-01,  8.6634e-01,  4.7069e-01, -8.1842e-01, -6.2145e-01,
          -3.8596e-02,  1.2991e+00, -8.4528e-01, -1.5742e+00,  1.2813e+00,
           1.1197e+00, -1.2562e+00,  7.3848e-01,  2.2198e-02, -4.1664e-01,
           1.1044e+00, -1.2744e+00, -1.6599e-01, -6.4863e-01,  1.1497e+00,
          -1.4236e-01, -1.2829e-01, -2.7600e-01,  4.7095e-01, -5.1933e-02,
           8.7453e-01, -6.4251e-01, -4.2953e-01,  3.5337e-01, -2.2782e-01,
           2.5079e-01,  1.7728e-01,  6.4826e-01,  2.4980e-01,  8.3032e-02,
           2.1247e+00, -3.0265e-01, -1.9821e-01,  9.7439e-01, -3.6237e-01,
          -2.6392e-01, -5.1498e-01,  1.3055e+00, -9.1860e-01, -6.9769e-01,
           6.5717e-01,  5.8009e-01,  3.6944e-01,  2.0414e-01, -9.0271e-01,
           4.5972e-01,  9.4667e-01,  1.3700e-02, -2.7962e-01,  3.7535e-01,
          -4.1842e-01, -6.2615e-01,  6.8238e-03, -3.4866e-01,  5.7681e-01,
          -5.5240e-01,  1.8245e-01,  6.2508e-01,  6.0020e-01]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1517090.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AS-V1000 视频监控平台产品介绍:web客户端功能介绍(上)

目 录 一、引言 1.1 AS-V1000视频监控平台介绍 1.2 平台服务器配置说明 二、软件概述 2.1 软件用途 2.2 登陆界面 2.3 主界面 2.4 视频浏览 三、web端主要功能介绍 3.1 多画面风格 3.1.1风格切换 3.1.2 切换效果 3.2截屏 3.2.1 单画面截屏 3.2.2 …

【当前全网最详细】WebUI中使用Instant_ID来控制生成对象面部的用法

🎈为什么有这篇文章 中文网络上或者B站很多UP,在讲述WebUI中使用这个controlnet来换脸的时候,要么讲的过于复杂,要么就是没有讲清楚,所以这里整理下详细的使用方法,并记录下生成的内容。 如果懒得看文字可…

微信小程序-day01

文章目录 前言微信小程序介绍 一、为什么要学习微信小程序?二、微信小程序的历史创建开发环境1.注册账号2.获取APPID 三、下载微信开发者工具1.创建微信小程序项目2.填写相关信息3.项目创建成功 四、小程序目录结构项目的主体组成结构 总结 前言 微信小程序介绍 微信小程序&…

visa卡支持美区苹果Apple id绑定

苹果手机我相信大家都很熟悉,所以很多小伙伴都需要绑定卡来进行一系列的体验,这里我使用的是559666 在绑定之前我们需要先开一张visa卡,点击获取 开卡步骤如下,按图片步骤即可开卡 卡片信息在卡中心cvc安全码里面

STM32F407_外部中断

这里写目录标题 前言1、EXTI概述2、外部中断配置流程完整代码 前言 注释很详细,放心食用。 1、EXTI概述 STM32F4的每个IO都可以作为外部中断的中断输入口,这点也是STM32F4的强大之处。STM32F407的中断控制器支持22个外部中断/事件请求。每个中断线上都设…

hex文件格式解析

本文框架 1. hex文件格式1.1 数据长度1.2 地址域1.3 数据类型1.4 数据域1.5 CRC校验域 本文对hex文件格式进行解析,介绍各部分组成及其含义,在此mark下,方便后续开发脚本对hex文件进行操作。 1. hex文件格式 Intel HEX文件是由一行行符合Int…

Explain详解与索引优化最佳实践

Explain工具介绍 使用EXPLAIN关键字可以模拟优化器执行SQL语句,分析你的查询语句或是结构的性能瓶颈 在select语句之前增加explain关键字,MySQL会在查询前设置一个标记,执行查询会返回执行计划的信息,而不是执行这条SQL 注意: 如果from中包含子查询,仍会执行该子查询,将结果…

Gemma: Open Models Based on Gemini Research and Technology

Gemma: Open Models Based on Gemini Research and Technology 相关链接:arxiv 关键字:Gemma、Google DeepMind、open models、language understanding、reasoning 摘要 这项工作介绍了Gemma,一系列轻量级、最先进的开放模型,基于…

笔记本的显示器都是核显输出,还要独显干啥呢?

前言 今天小白还在睡梦中,就接到一个朋友发来的消息:笔记本的显示器都是直接在核显上的,没有改独显的选项。 如果是这样,那笔记本还有独立显卡什么事情?笔记本加了独显难道就只是为了圈钱? 其实并不是这样…

06双体系Java学习之算术运算符,赋值运算符,关系运算符

// 二元运算符//CtrlD : 复制当前行到下一行int a 10;int b 20;int c 25;int d 25;System.out.println(ab);System.out.println(a-b);System.out.println(a*b);System.out.println(a/(double)b);赋值运算符 关系运算符 package operator;public class Demo03 {public stati…

ModuleNotFoundError: No module named ‘sklearn.cross_validation‘

一、问题分析 ModuleNotFoundError: No module named sklearn.cross_validation 英文先翻译一遍,模块未找到问题,这里涉及到sklearn这个模块,Sklearn (全称 SciKit-Learn),是基于 Python 语言的机器学习工…

力扣每日一题 合并后数组中的最大元素 贪心

Problem: 2789. 合并后数组中的最大元素 思路 贪心:从右向左合并,尽可能的多合并,直到不能合并,更新答案,找前一阶段的最大合并值 复杂度 时间复杂度: O ( n ) O(n) O(n) 空间复杂度: O ( 1 ) O(1) O(1) Code …

1456.定长子串中元音的最大数目

题目:给你字符串 s 和整数 k 。 请返回字符串 s 中长度为 k 的单个子字符串中可能包含的最大元音字母数。 英文中的 元音字母 为(a, e, i, o, u)。 解题思路: 1.右侧新进入窗口的字母为元音字母,左侧移出窗口的字母…

C语言【典型算法编程题】总结

以下最全总结! 一,分支结构 1,if 编写程序,从键盘上输入三角形的三个边长(实数),判断这三个边能否构成三角形(构成三角形的条件为:任意两边之和大于第三边),如果能构成三角形,则计算三角形的面积并输出(保留2位小数);如果不能构成三角形,则输出“Flase”字符…

AJAX 03 XMLHttpRequest、Promise、封装简易版 axios

AJAX 学习 AJAX 3 原理01 XMLHttpRequest① XHR 定义② XHR & axios 关系③ 使用 XHR④ XHR查询参数案例:地区查询(URLSearchParams)⑤ XHR数据提交 POST 02 PromisePromise 使用Promise - 三种状态案例:使用Promise XHR 获取…

解析找不到msvcr120.dll无法继续执行此代码的多种修复方法

在计算机使用过程中,我们经常会遇到一些错误提示,其中之一就是“msvcr120.dll丢失”。这个错误通常会导致某些程序无法正常运行。为了解决这个问题,本文将介绍5种修复msvcr120.dll丢失的方法。 一,msvcr120.dll丢失会出现哪些问题…

sql注入重学

sql基本操作 基本查询语句 union (必须得是前面的列与后面的列相同才可以查询) 看第二局uses表中的列有3列,而emails中的列只有两列,所有无法成功查询 这就相当于我们再加了一列 group by (分组) 相当于将其分为10列…

Python 闭包和nonlocal声明

闭包是针对嵌套函数环境的概念,它的作用是延伸函数的作用域。简单来说,闭包就是一个函数,但它可以保存着上层函数作用域中的变量,使得这些变量可以在函数中使用。而nonlocal声明的作用就是允许函数重新绑定局部作用域以外且非全局…

3、鸿蒙学习-在AGC创建HarmonyOS 项目或应用

项目和应用介绍 关于项目 项目是资源、应用的组织实体。资源包括服务器、数据库、存储,以及您的应用、终端用户的数据等。在您使用部分服务时,您是数据的控制者,数据将按照您设置的数据处理位置来存储在指定区域。 通常,您不需…

paraview处理openfoam对称模型

paraview处理openfoam对称模型 步骤如下: 导入对称模型,以openfoam中xx\tutorials\incompressible\SRFSimpleFoam\mixer中的搅拌器为例;使用ctrl+space,查找transform,在Filters中也能找到;经过三次transform,可以移动旋转出对称的其他3部分;经过此三次移动旋转,并不能…