Scaled Dot-Product Attention

news2024/11/20 10:31:33

Scaled Dot-Product Attention

flyfish

Attention ( Q , K , V ) = softmax ( Q K T d k ) V {\text{Attention}}(Q, K, V) = \text{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(self, q, k, v, mask=None):

        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e9)

        attn = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn

PyTorch官网的实现
torch.nn.functional.scaled_dot_product_attention

# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

在这里插入图片描述
点乘
在这里插入图片描述
计算过程
在这里插入图片描述

(1, 2, 3) • (7, 9, 11) = 1×7 + 2×9 + 3×11= 58
(1, 2, 3) • (8, 10, 12) = 1×8 + 2×10 + 3×12= 64
(4, 5, 6) • (7, 9, 11) = 4×7 + 5×9 + 6×11= 139
(4, 5, 6) • (8, 10, 12) = 4×8 + 5×10 + 6×12= 154

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1478378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

源码框架-​1.Spring底层核心原理解析

目录 Spring中核心知识点: Bean的创建过程 推断构造方法 AOP大致流程 Spring事务 Spring中核心知识点: Bean的生命周期底层原理依赖注入底层原理初始化底层原理推断构造方法底层原理AOP底层原理Spring事务底层原理 ps:这篇文章中都只是大致流程,后续会针对每…

【vuex之五大核心概念】

vuex:五大核心概念 一、state状态1.state的含义2.如何访问以及使用仓库的数据(1)通过store直接访问获取store对象 (2)通过辅助函数MapState 二、mutations1.作用2.严格模式3.操作流程定义 mutations 对象,对象中存放修…

IEEE Transactions on Industrial Electronics工业电子TIE修改稿注意事项及提交须知

一、背景 兔年末投了一篇TIE,手稿初次提交的注意事项也整理成了博客IEEE Transactions on Industrial Electronics工业电子TIE论文投稿须知,获得了许多点赞和收藏。最近也收到了审稿结果,给的意见是大修major revision,总之只要不…

FinalShell连接Linux

远程连接linux 我们使用VMware可以得到Linux虚拟机,但是在/Mware中操作Linux的命令行页面不太方便,主要是: 内容的复制、粘贴跨越VMware不方便 文件的上传、下载跨越VMware不方便 不方便也就是和Linux系统的各类交互,跨越VMwar 到Linux操作系…

vue a-table 实现指定字段相同数据合并行

vue a-table 实现相同数据合并行 实现效果代码实现cloums数据格式数据源格式合并代码 实现效果 代码实现 cloums数据格式 const getColumns function () {return [{title: "分类",dataIndex: "checked",width: "150px",customRender: (text, …

外贸贸易术语FCA是什么?

在国际贸易实践中,FOB是最早应用于国际贸易的术语之一,一直以来都是使用比例最高的贸易术语。但近年来又一匹“黑马”脱颖而出——“FCA”术语。 01 概念和应用 《Incoterms2020》将11个贸易术语分为适用于任何运输方式或多种运输方式和适用于海运和内…

k8s部署 多master节点负载均衡以及集群高可用

一、k8s 添加多master节点实验 1、master02节点初始化操作 2、在master01节点基础上,完成master02节点部署 ①从master01节点复制所需要的文件 需要从master01节点复制etcd数据库所需要的ssl证书、kubernetes安装目录(二进制文件、组件与apiserver通信…

面试官:谈一谈Cookie和Session的区别?

我先解释一下Cookie,它是客户端浏览器用来保存服务端数据的一种机制,当我们通过浏览器去进行网页访问的时候,服务器可以把一些状态数据以key-value的形式写入到Cookie里面,存储到客户端浏览器。下一次这个浏览器再访问服务器的时候…

DPU是什么?

问题描述: DPU是什么? 解答: DPU(Data Processing Unit)是以数据为中心构造的专用处理器,采用软件 定义技术路线支撑基础设施层资源虚拟化,支持存储、安全、服务质量管理等 基础设施层服务。…

【vmware安装群晖】

vmware安装群晖 vmware安装群辉: vmware版本:17pro 下载链接, https://customerconnect.vmware.com/cn/downloads/details?downloadGroupWKST-1751-WIN&productId1376&rPId116859 激活码可自行搜索 教程: https://b…

云时代【4】—— 资源隔离与控制技术

云时代【4】—— 资源隔离与控制技术 二、资源隔离与控制技术(一)NameSpace1. 基本介绍2. 相关 Linux 指令实战1:隔离进程实战2:隔离文件系统 (二)CGroups1. 基本介绍2. 相关 Linux 指令实战1:c…

sora技术报告阅读

sora是一个在可变持续时间、分辨率和宽高比的视频和图像上联合训练文本条件扩散模型。 需要将所有类型的视觉数据转化为统一表示的方法,使得能够对生成模型进行大规模训练。 Sora是一个通用的视觉数据模型,它可以生成不同持续时间、宽高比和分辨率的视…

ctfshow——反序列化

文章目录 web 254——啥也没web 255——反序列化对变量进行赋值(1)web 256——反序列化对变量进行赋值(2)web 257——对象注入web 258——对象注入(绕过preg_match)web 259 web 254——啥也没 这里就是使用GET传输,use…

小程序中使用echarts地图

一、下载并安装echarts 1、下载echarts-for-weixin组件 echarts-for-weixin项目提供了一个小程序组件,用这种方式可以在小程序中方便地使用 ECharts。 下载ec-canvas项目(下载地址) ​​ 注意:下载的 ec-canvas 中的echarts的版本…

k8s.gcr.io/pause:3.2镜像丢失解决

文章目录 前言错误信息临时解决推荐解决onetwo 前言 使用Kubernetes(k8s)时遇到了镜像拉取的问题,导致Pod沙盒创建失败。错误显示在尝试从k8s.gcr.io拉取pause:3.2镜像时遇到了超时问题,这通常是因为网络问题或者镜像仓库服务器的…

【Go-Zero】测试API查询信息无法返回数据库信息与api、rpc文件编写规范

【Go-Zero】测试API查询信息无法返回数据库信息与api、rpc文件编写规范 大家好 我是寸铁👊 总结了一篇测试API查询信息无法返回数据库信息与api、rpc文件编写规范的文章✨ 喜欢的小伙伴可以点点关注 💝 问题背景 大家好,我是寸铁&#xff01…

TypeScript 中类的理解及应用场景

👩 个人主页:不爱吃糖的程序媛 🙋‍♂️ 作者简介:前端领域新星创作者、CSDN内容合伙人,专注于前端各领域技术,成长的路上共同学习共同进步,一起加油呀! ✨系列专栏:前端…

应用多元统计分析--多元数据的直观表示(R语言)

例1.2 为了研究全国31个省、市、自治区2018年城镇居民生活消费的分布规律,根据调查资料做区域消费类型划分。 指标: 食品x1:人均食品支出(元/人) 衣着x2:人均衣着商品支出(元/人) 居住x3:人均居住支出(元/人) 生活x4…

map和set的简单介绍

由于博主的能力有限,所以为了方便大家对于map和set的学习,我放一个官方的map和set的链接供大家参考: https://cplusplus.com/ 在初阶阶段,我们已经接触过STL中的部分容器,比如:vector、list、deque&#x…

【AI+应用】怎么快速制作一个类chatGPT套壳网站

最近有人问我, 看了我之前写的一篇文章 [人工智能] AI浪潮下Sora对于普通人的机会 , 怎么做一个类chatGPT的套壳网站,是从0开始做么。 对于普通人来说,万事不懂先AI, AI找不到答案搜索google或百度。对于程序员来说…