Transformer - 注意⼒机制

news2024/12/23 3:55:08

Transformer - 注意⼒机制

flyfish

计算过程

flyfish

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import math

def attention(query, key, value, mask=None, dropout=None):

     # query的最后⼀维的⼤⼩, ⼀般情况下就等同于词嵌⼊维度, 命名为d_k
     d_k = query.size(-1)

     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
     print("scores.shape:",scores.shape)#scores.shape: torch.Size([1, 12, 12])

     if mask is not None:
         scores = scores.masked_fill(mask == 0, -1e9)


     p_attn = F.softmax(scores, dim = -1)

     if dropout is not None:
         p_attn = dropout(p_attn)

     return torch.matmul(p_attn, value), p_attn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

       
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x +  self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
#在测试attention的时候需要位置编码PositionalEncoding


# 词嵌⼊维度是8维
d_model = 8
# 置0⽐率为0.1
dropout = 0.1
# 句⼦最⼤⻓度
max_len=12

x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)
                           
pe_result = pe(x)

print("pe_result:", pe_result)
query = key = value = pe_result
print("pe_result.shape:",pe_result.shape)

#没有mask的输出情况
#pe_result.shape: torch.Size([1, 12, 8])
attn, p_attn = attention(query, key, value)
print("no mask\n")
print("attn:", attn)
print("p_attn:", p_attn)

#scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 除以math.sqrt(d_k) 表示这个注意力就是 缩放点积注意力,如果没有,那么就是 点积注意力
#当Q=K=V时,又叫⾃注意⼒机制

#有mask的输出情况
print("mask\n")
mask = torch.zeros(1, max_len, max_len)
attn, p_attn = attention(query, key, value, mask=mask)
print("attn:", attn)
print("p_attn:", p_attn)
pe_result: tensor([[[ 0.0000e+00,  1.1111e+00,  0.0000e+00,  1.1111e+00,  0.0000e+00,
           1.1111e+00,  0.0000e+00,  1.1111e+00],
         [ 9.3497e-01,  6.0034e-01,  1.1093e-01,  1.1056e+00,  1.1111e-02,
           1.1111e+00,  1.1111e-03,  1.1111e+00],
         [ 1.0103e+00, -4.6239e-01,  2.2074e-01,  1.0890e+00,  2.2221e-02,
           0.0000e+00,  2.2222e-03,  1.1111e+00],
         [ 1.5680e-01, -1.1000e+00,  0.0000e+00,  1.0615e+00,  3.3328e-02,
           0.0000e+00,  3.3333e-03,  1.1111e+00],
         [-8.4089e-01, -7.2627e-01,  4.3269e-01,  1.0234e+00,  4.4433e-02,
           1.1102e+00,  4.4444e-03,  1.1111e+00],
         [-1.0655e+00,  3.1518e-01,  5.3270e-01,  0.0000e+00,  5.5532e-02,
           1.1097e+00,  5.5555e-03,  1.1111e+00],
         [-3.1046e-01,  1.0669e+00,  6.2738e-01,  9.1704e-01,  0.0000e+00,
           1.1091e+00,  6.6666e-03,  0.0000e+00],
         [ 7.2999e-01,  8.3767e-01,  7.1580e-01,  8.4982e-01,  7.7714e-02,
           1.1084e+00,  7.7777e-03,  1.1111e+00],
         [ 1.0993e+00, -1.6167e-01,  7.9706e-01,  7.7412e-01,  8.8794e-02,
           1.1076e+00,  8.8888e-03,  1.1111e+00],
         [ 4.5791e-01, -0.0000e+00,  8.7036e-01,  6.9068e-01,  9.9865e-02,
           1.1066e+00,  9.9999e-03,  1.1111e+00],
         [-6.0447e-01, -9.3230e-01,  9.3497e-01,  6.0034e-01,  1.1093e-01,
           1.1056e+00,  1.1111e-02,  1.1111e+00],
         [-1.1111e+00,  4.9174e-03,  9.9023e-01,  5.0400e-01,  1.2198e-01,
           1.1044e+00,  1.2222e-02,  1.1110e+00]]])
pe_result.shape: torch.Size([1, 12, 8])
scores.shape: torch.Size([1, 12, 12])
no mask

attn: tensor([[[ 1.0590e-01,  2.7361e-01,  4.9333e-01,  8.3999e-01,  5.0599e-02,
           1.0079e+00,  5.6491e-03,  1.0138e+00],
         [ 2.7554e-01,  2.0916e-01,  4.9203e-01,  8.6593e-01,  5.2177e-02,
           9.7066e-01,  5.6513e-03,  1.0398e+00],
         [ 2.8765e-01, -3.8825e-02,  4.7812e-01,  8.7535e-01,  5.4246e-02,
           8.4157e-01,  5.7015e-03,  1.0659e+00],
         [ 9.3666e-02, -1.8286e-01,  4.8727e-01,  8.5124e-01,  5.7070e-02,
           8.2547e-01,  5.9523e-03,  1.0712e+00],
         [-1.6747e-01, -1.0274e-01,  5.6960e-01,  7.7584e-01,  6.3699e-02,
           9.6958e-01,  6.7169e-03,  1.0546e+00],
         [-2.2646e-01,  6.8462e-02,  5.8668e-01,  7.2227e-01,  6.3119e-02,
           1.0233e+00,  6.8004e-03,  1.0310e+00],
         [ 8.8945e-04,  2.7654e-01,  5.3750e-01,  8.0958e-01,  5.2289e-02,
           1.0259e+00,  6.1360e-03,  9.6094e-01],
         [ 2.2231e-01,  2.2832e-01,  5.2263e-01,  8.4111e-01,  5.4828e-02,
           9.9655e-01,  5.9765e-03,  1.0298e+00],
         [ 2.6388e-01,  7.2239e-02,  5.3800e-01,  8.4070e-01,  5.8958e-02,
           9.5033e-01,  6.2306e-03,  1.0564e+00],
         [ 1.2822e-01,  7.4518e-02,  5.5305e-01,  8.1381e-01,  6.0125e-02,
           9.7442e-01,  6.4089e-03,  1.0462e+00],
         [-1.5757e-01, -1.3194e-01,  5.9562e-01,  7.6069e-01,  6.7079e-02,
           9.7264e-01,  7.0187e-03,  1.0607e+00],
         [-2.3505e-01,  5.6245e-03,  6.0160e-01,  7.3040e-01,  6.5491e-02,
           1.0176e+00,  7.0038e-03,  1.0367e+00]]])
p_attn: tensor([[[0.1488, 0.1215, 0.0514, 0.0396, 0.0698, 0.0703, 0.0875, 0.1205,
          0.0790, 0.0814, 0.0544, 0.0757],
         [0.1170, 0.1434, 0.0757, 0.0489, 0.0590, 0.0460, 0.0642, 0.1304,
          0.1161, 0.0943, 0.0527, 0.0524],
         [0.0716, 0.1094, 0.1341, 0.1067, 0.0716, 0.0379, 0.0407, 0.0930,
          0.1221, 0.0921, 0.0713, 0.0494],
         [0.0597, 0.0765, 0.1155, 0.1397, 0.1127, 0.0506, 0.0359, 0.0627,
          0.0918, 0.0806, 0.1056, 0.0688],
         [0.0692, 0.0607, 0.0509, 0.0740, 0.1475, 0.0846, 0.0509, 0.0607,
          0.0692, 0.0788, 0.1342, 0.1194],
         [0.0887, 0.0601, 0.0343, 0.0423, 0.1076, 0.1341, 0.0721, 0.0748,
          0.0591, 0.0777, 0.1057, 0.1435],
         [0.1232, 0.0938, 0.0411, 0.0335, 0.0722, 0.0804, 0.1351, 0.1103,
          0.0722, 0.0814, 0.0633, 0.0935],
         [0.1124, 0.1263, 0.0623, 0.0388, 0.0571, 0.0553, 0.0731, 0.1388,
          0.1134, 0.1001, 0.0571, 0.0652],
         [0.0758, 0.1157, 0.0841, 0.0584, 0.0670, 0.0450, 0.0492, 0.1166,
          0.1429, 0.1101, 0.0763, 0.0588],
         [0.0822, 0.0989, 0.0668, 0.0540, 0.0803, 0.0622, 0.0584, 0.1084,
          0.1158, 0.1046, 0.0879, 0.0804],
         [0.0548, 0.0551, 0.0515, 0.0705, 0.1364, 0.0845, 0.0454, 0.0617,
          0.0801, 0.0877, 0.1499, 0.1224],
         [0.0763, 0.0548, 0.0357, 0.0459, 0.1213, 0.1146, 0.0669, 0.0703,
          0.0616, 0.0802, 0.1224, 0.1499]]])
mask

scores.shape: torch.Size([1, 12, 12])
attn: tensor([[[0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185],
         [0.0381, 0.0461, 0.5194, 0.8105, 0.0555, 0.9236, 0.0061, 1.0185]]])
p_attn: tensor([[[0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833],
         [0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833,
          0.0833, 0.0833, 0.0833, 0.0833]]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1565061.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

动态规划详解(Dynamic Programming)

目录 引入什么是动态规划?动态规划的特点解题办法解题套路框架举例说明斐波那契数列题目描述解题思路方式一:暴力求解思考 方式二:带备忘录的递归解法方式三:动态规划 推荐练手题目 引入 动态规划问题(Dynamic Progra…

QT背景介绍

🐌博主主页:🐌​倔强的大蜗牛🐌​ 📚专栏分类:QT❤️感谢大家点赞👍收藏⭐评论✍️ 目录 一、QT背景 1.1什么是QT 1.2QT的发展历史 1.3什么是框架、库 1.4QT支持的平台 1.5QT的优点 1.6QT的…

分布式锁 — Redisson 全面解析!

前言 分布式锁主要是解决集群,分布式下数据一致性的问题。在单机的环境下,应用是在同一进程下的,只需要保证单进程多线程环境中的线程安全性,通过 JAVA 提供的 volatile、ReentrantLock、synchronized 以及 concurrent 并发包下一…

JVM_垃圾收集器

GC垃圾收集器 文章目录 GC垃圾收集器GC垃圾回收算法和垃圾收集器关系GC算法主要有以下几种四种主要的垃圾收集器SerialParallelCMSG1垃圾收集器总结查看默认垃圾收集器 默认垃圾收集器有哪些各垃圾收集器的使用范围部分参数说明 新生代下的垃圾收集器并行GC(ParNew)并行回收GC&…

[Python GUI PyQt] PyQt5快速入门

PyQt5快速入门 PyQt5的快速入门0. 写在前面1. 思维导图2. 第一个PyQt5的应用程序3. PyQt5的常用基本控件和布局3.1 PyQt5的常用基本控件3.1.1 按钮控件 QPushButton3.1.2 文本标签控件 QLabel3.1.3 单行输入框控件 QLineEdit3.1.4 A Quick Widgets Demo 3.2 PyQt5的常用基本控件…

morkdown语法转微信公众号排版(免费)

morkdown语法转微信公众号排版(免费) 源码来自githab,有些简单的问题我都修复了。大家可以直接去找原作者的源码,如果githab打不开就从我下载的网盘里下载吧。 效果

在制定OKR的过程中,应该怎么确定目标O的来源或方向?

在制定OKR(Objectives and Key Results,目标与关键成果)的过程中,确定目标O的来源或方向是至关重要的一步。一个明确、合理的目标能够为团队指明方向,激发团队成员的积极性和创造力,进而推动公司的整体发展…

【嵌入式智能产品开发实战】(十五)—— 政安晨:通过ARM-Linux掌握基本技能【GNU C标准与编译器】

目录 GNU C 什么是C语言标准 C语言标准的内容 C语言标准的发展过程 1.K&R C 2.ANSI C 3.C99标准 4.C11标准 编译器对C语言标准的支持 编译器对C语言标准的扩展 政安晨的个人主页:政安晨 欢迎 👍点赞✍评论⭐收藏 收录专栏: 嵌入式智能产品…

信息技术学院大数据技术专业开展专业实训周

四川城市职业学院讯(信息技术学院 陈天伟)日前,为提升学生的工匠精神和职业认知,信息技术学院邀请企业专家入驻眉山校区大数据实训基地,开展数据标识专业实训周。 数据标识是大数据专业的核心技术,数据标识…

在CentOS 7上安装Python 3.7.7

文章目录 一、实战步骤1. 安装编译工具2. 下载Python 3.7.7安装包3. 上传Python 3.7.7安装包4. 解压缩安装包5. 切换目录并编译安装6. 配置Python环境变量7. 使配置生效8. 验证安装是否成功 二、实战总结 一、实战步骤 1. 安装编译工具 在终端中执行以下命令 yum -y groupin…

24年大一训练一(东北林业大学)

前言&#xff1a; 周五晚上的训练赛&#xff0c;以后应该每两周都会有一次。 正文&#xff1a; Problem:A矩阵翻转&#xff1a; #include<bits/stdc.h> using namespace std; int a[55][55]; int main(){int n,m;while(cin>>n>>m){for(int i1;i<n;i){for…

1.Git是用来干嘛的

本文章学习于【GeekHour】一小时Git教程&#xff0c;来自bilibili Git就是一个文件管理系统&#xff0c;这样说吧&#xff0c;当多个人同时在操作一个文件的同时&#xff0c;很容易造成紊乱&#xff0c;git就是保证文件不紊乱产生的 包括集中式管理系统和分布式管理系统 听懂…

每日一题:用c语言写(输入n个数(n小于等于100),输出数字2的出现次数)

目录 一、要求 二、代码 三、结果 ​四、注意 一、要求 二、代码 #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> int main() {//输入n个数&#xff08;n小于等于100&#xff09;&#xff0c;输出数字2的出现次数;int n[100] ;int num 0;int count 0;/…

加域报错:找不到网络路径

在尝试将计算机加入Windows域时&#xff0c;如果收到“找不到网络路径”的错误提示&#xff0c;可能的原因及解决方法如下&#xff1a; 网络连接问题&#xff1a;确保计算机与域控制器之间的物理网络连接是正常的&#xff0c;可以通过ping命令测试与域控制器的连通性。例如&…

【黑马头条】-day05延迟队列文章发布审核-Redis-zSet实现延迟队列-Feign远程调用

文章目录 昨日回顾今日内容1 延迟任务1.1 概述1.2 技术对比1.2.1 DelayQueue1.2.2 RabbitMQ1.2.3 Redis实现1.2.4 总结 2 redis实现延迟任务2.0 实现思路2.1 思考2.2 初步配置实现2.2.1 导入heima-leadnews-schedule模块2.2.2 在Nacos注册配置管理leadnews-schedule2.2.3 导入表…

【单片机家电产品学习记录--红外线】

单片机家电产品学习记录–红外线 红外手势驱动电路&#xff0c;&#xff08;手势控制的LED灯&#xff09; 原理 通过红外线对管&#xff0c;IC搭建的电路&#xff0c;实现灯模式转换。 手势控制灯模式转换&#xff0c;详细说明 转载 1《三色调光LED台灯电路》&#xff0c…

大数据学习第十一天(复习linux指令3)

1、su和exit su命令就是用于账户切换的系统命令 基本语法&#xff1a;su[-] [用户名] 1&#xff09;-表示是否在切换用户后加载变量&#xff0c;建议带上 2&#xff09;参数&#xff1a;用户名&#xff0c;表示切换用户 3&#xff09;切换用户后&#xff0c;可以通过exit命令退…

Redhat 7.9 安装dm8配置文档

Redhat 7.9 安装dm8配置文档 一 创建用户 groupadd -g 12349 dinstall useradd -u 12345 -g dinstall -m -d /home/dmdba -s /bin/bash dmdba passwd dmdba二 创建目录 mkdir /dm8 chown -R dmdba:dinstall /dm8三 配置/etc/security/limits.conf dmdba soft nproc 163…

二叉树结点关键字输出的递归算法实现

在计算机科学中&#xff0c;二叉树是一种重要的数据结构&#xff0c;广泛应用于各种算法和程序设计中。二叉树的遍历是二叉树操作中的基础问题之一&#xff0c;其目的是以某种规则访问二叉树的每个结点&#xff0c;使得每个结点被且仅被访问一次。给定一个具有n个结点的二叉树&…

idea端口占用

报错&#xff1a;Verify the connector‘s configuration, identify and stop any process that‘s listening on port XXXX 翻译&#xff1a; 原因&#xff1a; 解决&#xff1a; 一、重启大法 二、手动关闭 启动spring项目是控制台报错&#xff0c;详细信息如下&#xff…