Transformer——多头注意力机制(Pytorch)

news2025/1/10 23:43:39

1. 原理图

2. 代码

import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    def __init__(self, embed_size, heads):
        super(Multi_Head_Self_Attention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)

    def forward(self,queries, keys, values, mask):
        N = queries.shape[0]  # batch_size
        query_len = queries.shape[1]  # sequence_length
        key_len = keys.shape[1]  # sequence_length 
        value_len = values.shape[1]  # sequence_length

        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)

        # Split the embedding into self.heads pieces
        # batch_size, sequence_length, embed_size(512) --> 
        # batch_size, sequence_length, heads(8), head_dim(64)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)

        # batch_size, sequence_length, heads(8), head_dim(64) --> 
        # batch_size, heads(8), sequence_length, head_dim(64)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Scaled dot-product attention
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** (1/2))

        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))
        # batch_size, heads(8), sequence_length, sequence_length
        attention = torch.softmax(score, dim=-1)

        out = torch.matmul(attention, values)
        # batch_size, heads(8), sequence_length, head_dim(64) -->
        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, sequence_length, embed_size(512)
        # 为了方便送入后面的网络
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)
        out = self.fc_out(out)

        return out
    

batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None

Q = torch.randn(batch_size, sequence_length, embed_size)  
K = torch.randn(batch_size, sequence_length, embed_size)  
V = torch.randn(batch_size, sequence_length, embed_size)  

model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1921737.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Shiro550反序列化漏洞分析

shiro搭建教程可以在网上自行搜索 漏洞发现 进入shiro界面后,burp抓包,选择remember me并进行登录。观察burp抓到的包 登录之后服务器返回一个Cookie Remember me 之后用户的访问都带着这个Cookie 这个Cookie很长,可能会在里面存在一定的信…

springboot增加过滤器后中文乱码

记录一下小问题 public class RepeatableHttpServletWrapper extends HttpServletRequestWrapper {private byte[] body;public RepeatableHttpServletWrapper(HttpServletRequest request) throws IOException {super(request);request.setCharacterEncoding("UTF-8&q…

数据建设实践之大数据平台(一)准备环境

大数据组件版本信息 zookeeper-3.5.7hadoop-3.3.5mysql-5.7.28apache-hive-3.1.3spark-3.3.1dataxapache-dolphinscheduler-3.1.9大数据技术架构 大数据组件部署规划 node101node102node103node104node105datax datax datax ZK ZK ZK RM RM NM

Git的命令使用与IDEA内置git图形化的使用

Git 简介 Git 是分布式版本控制系统,它可以帮助开发人员跟踪和管理代码的更改。Git 可以记录代码的历史记录,并允许您在不同版本之间切换。 通过历史记录可以查看: 进行了哪些更改?谁进行了更改?何时进行了更改&#…

nodejs模板引擎(二)

虽然Jade现在已经被更名为Pug,但它的使用方式并没有太大的改变。下面是如何在Node.js中使用Pug(原Jade)模板引擎的基本步骤: 1. 安装 Pug 首先,你需要安装Pug模块。在你的项目目录中,使用npm来安装&#…

gradle 和 java 版本对应关系

文章目录 gradle 和 java 版本对应关系原地址 gradle 和 java 版本对应关系 原地址 https://docs.gradle.org/current/userguide/compatibility.html#compatibility

超市暑期(7-8月)生鲜之蔬果商品及营销操作建议!

生鲜经营的思路现在越来越被重视,越来越做的更精细化,营销方法和手段越来越多,如何正确地运用好营销策略,如何做到这个季节的生鲜经营既能保持新鲜,又能保持盈利呢? 7-8月份蔬菜重点商品及季节性商品 叶菜…

无人驾驶大热,新能源汽车智能化中的算网支持

来源新华社:百度“萝卜快跑”全无人驾驶汽车行驶在路上 当前,新能源汽车产业数智化已成为全球汽车产业数字化转型的焦点。一方面,随着人工智能、大数据、云计算等技术的深度融合,新能源汽车在自动驾驶、智能互联、能源管理等方面…

从零设计一个神经网络:实现手写数字识别

前言 为了能够更好的理解神经网络,从手写数字识别这个小任务来逐层弄清楚神经网络的工作原理以及一般流程是非常合适的。 这篇文章就来手写完成一个数字识别的任务,来说明如何设计、实现并训练一个标准的前馈神经网络,以期对神经网络有一个…

AI编程助手-Tabnine的使用体验

文章目录 一,安装使用1,VSCode安装Tabnine插件2,使用 三,Tabnine的工作原理1,深度学习的力量2,注意事项:最大化Tabnine的效能 在编程的世界里,每一行代码都承载着创造者的智慧与汗水…

ubuntu安装YOLOV8环境

文章目录 前言 前言 ubuntu20.04 使用vmware虚拟机 1、安装python sudo apt-get install python3 python3-pip2,安装虚拟环境 sudo apt install python3.8-venv3,创建虚拟环境 python3 -m venv yolov8-env4,进入虚拟环境 source yolov8…

测试人必会 K8S 操作之 Dashboard

在云计算和微服务架构的时代,Kubernetes (K8S) 已成为管理容器化应用的标准。然而,对于许多新手来说,K8S 的操作和管理常常显得复杂而神秘。特别是,当你第一次接触 K8S Dashboard 时,你是否也感到有些无所适从&#xf…

十大CRM系统对比:选出最适合你的工具

本文将分享10款优质CRM系统:纷享销客、Zoho CRM、HubSpot、Salesforce、悟空CRM、销售易、Pipedrive、Oracle CRM、Insightly、SugarCRM。 在选择CRM系统时,很多企业主和管理者都面临着一个难题:市面上的品牌众多,到底哪个才是最…

《昇思25天学习打卡营第14天|SSD目标检测》

SSD(Single Shot MultiBox Detector)是一种用于目标检测的深度学习算法。它的设计旨在同时检测多个对象,并确定它们在图像中的位置和类别。与其他目标检测算法相比,SSD具有速度快和精度高的特点,在实时检测应用中非常受…

python 代码设计贪吃蛇

代码: # -*- codeing utf-8 -*- import tkinter as tk import random from tkinter import messageboxclass Snake:def __init__(self, master):self.master masterself.master.title("Snake")# 创建画布self.canvas tk.Canvas(self.master, width400,…

Centos忘记密码,重置root密码

Centos忘记密码,重置root密码 操作环境:Centos7.6 1、选择包含rescue的选项,按e进入编辑模式 首先,我们需要重启系统,进入开机引导菜单界面。在这里,我们可以看到系统的内核版本和启动参数等信息。我们需…

期权专题12:期权保证金和期权盈亏

目录 1. 期权保证金 1.1 计算逻辑 1.2 代码复现 1.3 实际案例 2. 期权盈亏 2.1 价格走势 2.2 计算公式 2.2.1 卖出期权 2.2.2 买入期权 免责声明:本文由作者参考相关资料,并结合自身实践和思考独立完成,对全文内容的准确性、完整性或…

龙迅#LT8644EX适用于HDMI2.0 4进4出矩阵应用,分辨率最高支持4K60HZ!

1. 概述 LT8644EX是一款 1616 数字交叉点开关,具有 16 个差分 CML 兼容输入和 16 个差分 CML 输出。该LT8644EX针对每个端口的数据速率高达 6 Gbps 的不归零 (NRZ) 信令进行了优化。每个端口都提供可编程的输入均衡电平和可编程输出摆幅。…

10个Python函数参数进阶用法及代码优化

目录 1. 默认参数值:让函数更加灵活 2. 关键字参数:清晰的调用方式 3. *args:拥抱不确定数量的位置参数 4. **kwargs:处理不确定数量的关键字参数 5. 参数解包:简化多参数的传递 6. 命名关键字参数:限…

【第31章】MyBatis-Plus之注解配置

文章目录 前言一、注解介绍二、注解列表总结 前言 本文详细介绍了 MyBatisPlus 注解的用法及属性,提供了源码链接以便深入理解。欢迎通过下方链接查看注解类的源码。 Mybatis-Plus Annotation 源码 一、注解介绍 Mybatis-Plus注解统一存放在com.baomidou.mybatis…