Transformer详解(3)-多头自注意力机制

news2024/10/5 18:22:07

attention

在这里插入图片描述

在这里插入图片描述

multi-head attention

在这里插入图片描述
在这里插入图片描述

pytorch代码实现

import math
import torch
from torch import nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, heads=8, d_model=128, droput=0.1):
        super().__init__()

        self.d_model = d_model  # 128
        self.d_k = d_model // heads  # 128//8=16
        self.h = heads  # 8

        self.q_linear = nn.Linear(d_model, d_model)  # (50,128)*(128,128)=(50,128),其中(128*128)属于权重,在网络训练中学习。
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(droput)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # 矩阵乘法 (32,8,50,16)*(32,8,16,50)->(32,8,50,50)

        if mask is not None:
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)

        scores = F.softmax(scores, dim=-1)

        if dropout is not None:
            scores = dropout(scores)

        output = torch.matmul(scores, v)  # (32,8,50,50)*(32,8,50,16)->(32,8,50,16)
        return output

    def forward(self, q, k, v, mask=None):
        bs = q.size(0)  # batch_size 大小  这里的例子是32
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.k_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.k_linear(v).view(bs, -1, self.h, self.d_k)
        # (32,50,128)->(32,50,128)->(32,50,8,16)  8*16=128 每个embedding拆成的8份,也就是8个头

        k = k.transpose(1, 2)  # (32,50,8,16)->(32,8,50,16)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)

        scores = self.attention(q, k, v, self.d_k, mask, self.dropout)  # (32,8,50,16)
        concat = scores.transpose(1, 2).contiguous().view(bs, -1, self.d_model)  # (32,50,128)
        output = self.out(concat)  # (32,50,128)

        return output


if __name__ == '__main__':
    multi_head_attention = MultiHeadAttention(8, 128)
    normal_tensor = torch.randn(32, 50, 128)  # 随机生成均值为0,方差为1的正态分布。batch_size=32,序列长度=50,embedding维度=128。
    x = torch.sigmoid(normal_tensor)  # 把每个数缩放到(0,1)
    output = multi_head_attention(x, x, x)
    print('done')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1689300.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

通过el-tree自定义渲染网页版工作目录,实现鼠标悬浮显示完整名称、用icon区分文件和文件夹等需求

目录 一、通过el-tree自定义渲染网页版工作目录 1.1、需求介绍 1.2、使用el-tree生成文档目录 1.2.1、官方基础用法 ①效果 ②代码: 1.2.2、自定义文档目录(实现鼠标悬浮显示完整名称、用icon区分文件和文件夹) ①效果(直接效…

JavaScript表达式和运算符

表达式 表达式一般由常量、变量、运算符、子表达式构成。最简单的表达式可以是一个简单的值。常量或变量。例:var a10 运算符 运算符一般用符号来表示,也有些使用关键字表示。运算符由3中类型 1.一元运算符:一个运算符能够结合一个操作数&…

YOLOv8_seg的训练、验证、预测及导出[实例分割实践篇]

实例分割数据集链接,还是和目标检测篇一样,从coco2017val数据集中挑出来person和surfboard两类:链接:百度网盘 请输入提取码 提取码:3xmm 1.实例分割数据划分及配置 1.1实例分割数据划分 从上面得到的数据还不能够直接训练,需要按照一定的比例划分训练集和验证集,并按…

如何远程实时查看监控管理员工电脑屏幕?远程查看多台员工电脑屏幕的方法

现代化企业管理中,远程监控员工电脑屏幕已经成为一种有效的手段,用于提升工作效率、确保信息安全以及维护工作纪律。 而这种事情,也仿佛已经很常见了。 那么一般都如何监控的呢? 一、选择合适的远程监控软件 1.域智盾 域智盾提…

2024 电工杯高校数学建模竞赛(B题)数学建模完整思路+完整代码全解全析

你是否在寻找数学建模比赛的突破点?数学建模进阶思路! 作为经验丰富的数学建模团队,我们将为你带来2024电工杯数学建模竞赛(B题)的全面解析。这个解决方案包不仅包括完整的代码实现,还有详尽的建模过程和解…

markdown 文件渲染工具推荐 obsidian publish

背景 Markdown 是一种轻量级的标记语言,最开始使用它是觉得码字非常方便,从一开始的 word 排版到 markdown ,还不太不习惯,用了 obsidian把一些文字发在网上后,才逐渐发现他的厉害之处。 让人更加专注于内容本身&…

机器学习云环境测试

等待创建完成后,点击 PyTorch 打开,创建一个全新的 notebook 在 Cell 中输入如下代码,并点击 Run 完成后点击 New Cell ,在 New Cell 中输入如下代码 输入完成后点击 Run ,运行 New Cell 。(每个 Cell 代…

Java面试八股之Thread类中的yeild方法有什么作用

Thread类中的yeild方法有什么作用 谦让机制:Thread.yield()方法主要用于实现线程间的礼让或谦让机制。当某个线程执行到yield()方法时,它会主动放弃当前已获得的CPU执行权,从运行状态(Running)转变为可运行状态&#…

朴素贝叶斯+SMSSpamCollections

1. 打开 Jupyter 后,在工作目录中,新建一个文件夹命名为 Test01 ,并且在文件夹中导入数据 集。在网页端界面点击 “upload” 按钮,在弹出的界面中选择要导入的数据集。然后数据集出现 在 jupyter 文件目录中,此时…

基于SSM的大学生兼职管理系统

基于SSM的大学生兼职管理系统的设计与实现~ 开发语言:Java数据库:MySQL技术:SpringSpringMVCMyBatis工具:IDEA/Ecilpse、Navicat、Maven 系统展示 登录界面 企业界面 前台学生界面 管理员界面 摘要 随着大学生兼职市场的日益繁…

VTK9.2.0+QT5.14.0绘制三维显示背景

背景 上一篇绘制点云的博文中,使用的vtkCameraOrientationWidget来绘制的坐标轴,最近又学习到两种新的坐标轴绘制形式。 vtkOrientationMarkerWidget vtkAxesActor 单独使用vtkAxesActor能够绘制出坐标轴,但是会随着鼠标操作旋转和平移时…

Java设计模式 _行为型模式_迭代器模式

一、迭代器模式 1、迭代器模式 迭代器模式(Iterator Pattern)是一种行为型设计模式,用于顺序访问集合对象的元素,不需要关心集合对象的底层表示。如:java中的Iterator接口就是这个工作原理。 2、实现思路 &#xff0…

linux 中 fd 申请和释放管理(两级 bitmap)

linux 中 fd 的几点理解_linux fd-CSDN博客 通过上边的文章,我们可以知道,在 linux 中,fd 有以下几点需要了解: (1)fd 表示进程打开的文件,是进程级别的资源,不是系统级别的资源 …

Rhinoceros v7.5 解锁版安装教程 (3D三维造型软件)

前言 Rhinoceros 中文名称犀牛是一款超强的三维建模工具,全称Rhinoceros,Rhino是美国Robert McNeel & Assoc开发的PC上强大的专业3D造型软件,它可以广泛地应用于三维动画制作、工业制造、科学研究以及机械设计等领域。它能轻易整合3DS M…

Nodejs及stfshow相关例题

Nodejs及stfshow相关例题 Node.js 是一个基于 Chrome V8 引擎的 Javascript 运行环境。可以说nodejs是一个运行环境,或者说是一个 JS 语言解释器而不是某种库。 Node.js可以生成动态页面内容Node.js 可以在服务器上创建、打开、读取、写入、删除和关闭文件Node.js…

设计模式9——适配器模式

写文章的初心主要是用来帮助自己快速的回忆这个模式该怎么用,主要是下面的UML图可以起到大作用,在你学习过一遍以后可能会遗忘,忘记了不要紧,只要看一眼UML图就能想起来了。同时也请大家多多指教。 适配器模式(Adapte…

【Linux】Centos7安装JDK

【Linux】Centos7安装JDK 下载 Oracle 官网下载 JDK17 https://www.oracle.com/cn/java/technologies/downloads/#java17 安装 使用rz命令上传 jdk tar 包,上传失败直接用 xftp 上传 在安装图形界面时,有勾选开发工具,会自动安装 JDK 需要先…

【openlayers系统学习】3.4波段数学计算(计算NDVI)

四、波段数学计算(计算NDVI) 我们已经看到了如何使用 ol/source/GeoTIFF​ 源代码来渲染真彩色和假彩色合成。我们通过将缩放的反射率值直接渲染到红色、绿色或蓝色显示通道中的一个来实现这一点。还可以对来自GeoTIFF(或其他数据瓦片源&…

深度学习之基于Pytorch框架手写数字识别

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 一、项目背景与意义 手写数字识别是数字图像处理领域的一个经典问题,也是深度学习技术的一个常用应用场…

vue期末复习选择题1

1. 下面哪一项描述是错误的?(B) A.$("ul li:gt(5):not(:last)")选取ul标记里面索引值大于5且不是最后一个的li元素B.$("div").find("span")选取div元素的子元素spanC.$("div.showmore > a")选取…