Self-Attention流程的代码实现【python】

news2024/9/19 19:48:56

文章目录

  • 1、知识回顾
  • 2、Self-attetion实现步骤
  • 3、准备输入
  • 4、初始化参数
  • 5、获取Q,K,V
  • 6、计算attention scores
  • 7、计算softmax
  • 8、给values乘上scores
  • 9、完整代码
  • 10、总结

🍃作者介绍:双非本科大四网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发、数据结构和算法,初步涉猎人工智能和前端开发。
🦅个人主页:@逐梦苍穹
📕所属专栏:人工智能
🌻gitee地址:xzl的人工智能代码仓库
✈ 您的一键三连,是我创作的最大动力🌹

1、知识回顾

关于Self-Attention的一系列理论知识,请看我的另外一篇文章

深入剖析Self-Attention自注意力机制【图解】:https://xzl-tech.blog.csdn.net/article/details/141308634

这篇文章讲到了self-attention的计算过程,如果不想看那么细致的话,我们还是在这里简单复习一下:
image.png
image.png
image.png
那么,需要告诉大家的是,既然是要用代码实现,那肯定是需要以一个矩阵的角度去看待整个self-attention的计算过程,请看下文!

2、Self-attetion实现步骤

这里我们实现的注意力机制是现在比较流行的点积相乘的注意力机制
self-attention机制的实现步骤:

  1. 准备输入
  2. 初始化参数
  3. 获取key,query和value
  4. 给input1计算attention score
  5. 计算softmax
  6. 给value乘上score,获得output

整个过程都是以矩阵的视角在操作的:

image.png

3、准备输入

本文的关注点在于实现过程,所以数据方面我们采用自定义的方式获取:
image.png
这样就会得到如下张量:
image.png

4、初始化参数

在我上一篇剖析self-attention机制的文章中提到,整个self-attention的计算过程,需要学习的只有三个参数,那就是q,k,v对应的权重矩阵:
微信截图_20240819164437.png
这里同样不细讲如何学习,这里的重点在于带大家跑通整个self-attention计算的代码流程,
所以初始化参数如下:
image.png
来看一下输出:
image.png

5、获取Q,K,V

前面初始化了q,k,v对应的权重矩阵,下面获取Q,K,V:
image.png
如图所示,我们可以得到如下表达式:
Q = W q ( I n p u t ) Q=W^q(Input) Q=Wq(Input)
K = W k ( I n p u t ) K=W^k (Input) K=Wk(Input)
V = W v ( I n p u t ) V=W^v (Input) V=Wv(Input)


代码实现:
image.png
得到结果:
image.png

6、计算attention scores

我在上一篇讲解self-attention机制的文章中,关于计算attention scores的过程其实是分步计算的:
image.png
即分步计算 α i , j \alpha_{i,j} αi,j
但是在代码实现上,我们上面已经全部矩阵化了,我们得到的不是单独的 K 1 K^1 K1或者是 K 2 K^2 K2,而是关于 K a l l K^{all} Kall的矩阵( Q a l l Q^{all} Qall V a l l V^{all} Vall同理):
image.png
画成图解就是:
微信截图_20240819170437.png
所以这里计算的attention scores 用代码表示就是:
image.png
输出效果:
image.png

7、计算softmax

同样,这里一口气将所有的 α i , j \alpha_{i,j} αi,j经过 S o f t m a x Softmax Softmax处理:
微信截图_20240819170812.png
代码:
image.png
输出:
image.png

代码里面的dim=-1,指定在最后一个维度上应用 softmax 操作;
在二维张量的情况下,dim=-1 指的是在每一行(行向量)上计算 softmax

8、给values乘上scores

使用经过softmax后的attention score乘以它对应的value值:
image.png
代码:
image.png
输出:
image.png

9、完整代码

完整代码,代码即注释:

# -*- coding: utf-8 -*-
# @Author: CSDN@逐梦苍穹
# @Time: 2024/8/19 17:24
import torch
from torch.nn.functional import softmax

# 输入数据 x,包含3个输入向量,每个向量有4个维度
x = [
    [1, 0, 1, 0],  # 输入向量1
    [0, 2, 0, 2],  # 输入向量2
    [1, 1, 1, 1]   # 输入向量3
]
# 将输入数据转换为 PyTorch 张量,并设置数据类型为 float32
x = torch.tensor(x, dtype=torch.float32)

# 定义键(Key)的权重矩阵,形状为 (4, 3)
w_key = [
    [0, 0, 1],
    [1, 1, 0],
    [0, 1, 0],
    [1, 1, 0]
]
# 定义查询(Query)的权重矩阵,形状为 (4, 3)
w_query = [
    [1, 0, 1],
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1]
]
# 定义值(Value)的权重矩阵,形状为 (4, 3)
w_value = [
    [0, 2, 0],
    [0, 3, 0],
    [1, 0, 3],
    [1, 1, 0]
]

# 将权重矩阵转换为 PyTorch 张量,并设置数据类型为 float32
w_key = torch.tensor(w_key, dtype=torch.float32)
w_query = torch.tensor(w_query, dtype=torch.float32)
w_value = torch.tensor(w_value, dtype=torch.float32)

# 打印权重矩阵以供检查
print("w_key: ", w_key)
print("w_query: ", w_query)
print("w_value: ", w_value)

# 计算 Keys: 将输入 x 与键的权重矩阵相乘,生成键向量
keys = w_key @ x
# 计算 Queries: 将输入 x 与查询的权重矩阵相乘,生成查询向量
querys = w_query @ x
# 计算 Values: 将输入 x 与值的权重矩阵相乘,生成值向量
values = w_value @ x

# 打印键、查询和值向量以供检查
print("Keys: ", keys)
print("Querys: ", querys)
print("Values: ", values)

# 计算注意力分数(Attention Scores):通过键和查询向量的点积计算
# 结果是一个 (4, 4) 的矩阵,其中每个元素表示查询和键之间的相似度
attn_scores = keys @ querys
print("Attention Scores: ", attn_scores)

# 对注意力分数应用 Softmax 函数,将其转换为概率分布
# Softmax 处理后的矩阵形状仍为 (4, 4),表示每个查询对所有键的关注度
attn_scores_softmax = softmax(attn_scores, dim=-1)
print("Attention Scores Softmax: ", attn_scores_softmax)

# 计算加权后的输出值:将值向量与注意力分数进行加权求和
# 结果是一个形状为 (4, 4) 的矩阵,表示经过注意力加权后的最终输出
output = values @ attn_scores_softmax
print("output: ", output)

10、总结

对全文的代码过程做一个总结:
这份代码实现了自注意力机制的核心部分,包括(Key)、查询(Query)和(Value)的计算,以及通过注意力分数进行加权求和的过程

  1. 输入与权重定义
    1. 输入数据 x 包含 3 个向量,每个向量有 4 个维度
    2. 定义了三个权重矩阵 w_keyw_queryw_value,分别用于生成键、查询和值向量。
  2. 计算键、查询和值向量
    1. 将输入 x 分别与 w_keyw_queryw_value 相乘,生成对应的键、查询和值向量
    2. 这个步骤是将输入映射到不同的特征空间,以便进行注意力计算
  3. 计算注意力分数
    1. 通过键向量和查询向量的点积计算注意力分数
    2. 这些分数表示查询向量与键向量之间的相似度,用于决定每个查询向量对不同键向量的关注程度
  4. 应用 Softmax 函数
    1. 对注意力分数进行 softmax 操作,将这些分数转换为概率分布,确保每个查询对所有键的注意力之和为 1
    2. 这一步将注意力分数变为实际的注意力权重
  5. 计算加权后的输出值
    1. 将值向量与注意力权重相乘并求和,得到最终的加权输出
    2. 这一步模拟了注意力机制如何根据注意力权重聚合输入信息,从而生成最终的上下文表示

这些代码完整地展示了自注意力机制的基本工作流程;
通过计算注意力分数并对值向量进行加权求和,自注意力机制能够在输入序列中捕捉到相关信息,从而在各种深度学习任务中生成更具上下文感知的输出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2055080.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OOP篇(Java - 抽象类、类、对象、构造器、接口、内部类、 代码块、枚举)(doing)

目录 一、抽象类 1. 简介 2. 什么时候定义抽象类? 3. 什么是抽象方法? 4. 抽象类的作用是什么? 5. 继承抽象类需要做什么? 6. 抽象类为什么不能创建对象?自己干什么, 创建对象毫无意义 7. final和abstract是什…

【备战蓝桥杯青少组】第三天 放苹果

题 OpenJudge - 666:放苹果 描述 把M个同样的苹果放在N个同样的盘子里,允许有的盘子空着不放,问共有多少种不同的分法?(用K表示)5,1,1和1,5,1 是同一种分法。 输入 第一行…

Linux驱动入门实验班——DHT11、DS18B20模块驱动(附百问网视频链接)

目录 前言 一、DHT11模块 1.通信协议 2.数据格式 3.编程思路 ①入口函数 ②实现read函数 ③编写中断处理函数 ④***编写数据解析函数 ⑤应用程序 二、DS18B20模块 1. 通信时序 ① 初始化时序 ② 写时序 ③ 读时序 2. 常用命令 3. 编程思路 1.启动温度转换 2…

Dragonfly S 5MP工业相机量产 机器视觉应用的新选择

近日,51camera的合作厂商Teledyne FLIR IIS宣布Dragonfly️ S USB 5MP模块化、紧凑型相机现已全面投产,Dragonfly S 5MP是新Dragonfly S系列中首款迈入量产阶段的相机。 作为机器视觉应用领域的入门级产品,Dragonfly S不仅简化了成像系统的快…

实战Kubernetes之快速部署 K8s 集群 v1.28.0

文章目录 一、前言二、主机准备三、系统配置3.1. 关闭防火墙及相关配置3.2. 修改主机名3.3. 主机名DNS解析3.4. 时间同步3.5. 配置网络3.6. 重启服务器 四、安装软件4.1. 安装 Docker4.2. 安装 cri-dockerd4.3. 添加国内YUM源4.4. 安装 kubeadm、kubelet 和 kubectl 五、Master…

docker部署MySQL5.7.43并使用python脚本插入数据——实施案例

目录 一、配置docker环境 1. 阿里云镜像站配置docker环境 1. 安装必要的一些系统工具 ​编辑 2. 添加软件源信息 ​编辑 3. 修改 Docker 的 YUM 仓库配置文件,将 Docker 官方仓库的地址替换为阿里云的镜像源,以提高下载速度。 4. 更新并安装Dock…

【Qt】Qt窗口 | QDockWidget 浮动窗口

文章目录 一. 浮动窗口二. 代码创建&使用浮动窗口1. 创建浮动窗口2. 设置可停靠位置3. 添加控件 一. 浮动窗口 浮动窗口(也称为“停靠窗口”或“工具窗口”),是一个可以在主窗口内或主窗口外部悬浮的窗口。它通常用于显示工具栏、面板或其他附加信息。浮动窗口…

AScript 的UI asui模板的导入

两种方案: 第一种直接在web端,右击UI文件夹 第二种在pycharm,也是右击UI文件夹 调用UI,在init类中直接调用即可

Jupyter安装指南:最简便最详细的步骤

一.介绍 JupyterNotebook 是一个款以网页为基础的交互计算环境,可以创建Jupyter的文档,支持多种语言,包括Python, Julia, R等等。一般来说,如果是使用R语言的话,使用Rstudio居多,使用Python的话&#xff0…

高防服务器租用多少钱

高防服务器租用的具体价格受多种因素影响。通常情况下,高防服务器的租用费用可能从数百元到数万元不等,具体取决于服务提供商、服务器配置、防护级别等因素。下面将详细探讨决定高防服务器租用价格的几个主要因素,rak小编为您整理发布高防服务…

【LeetCode热题100】滑动窗口

这篇博客总结了滑动窗口的8道常见题目,分别是:长度最小的子数组、无重复字符的最长子串、 最大连续1的个数III、将x减到0的最小操作数、水果成篮、找到字符串中所有字母异位词、串联所有单词的子串、最小覆盖子串。 class Solution { public:int minSubA…

解决Vue3+Ts打包项目时会生成很多的map文件

正常打包会生成.js和.map文件 怎么去解决它呢? 正常来说我们会在vite.config.ts配置我们的项目打包方式,如下:(我这里的target:es2022是为了支持模块中顶层await的使用) // Vite 配置文件 export default…

海思NVR源码方案:集成ONVIF、GUI、存储与告警的全功能解决方案

海思平台作为中国领先的半导体厂商之一,其3520D芯片凭借高性能、低功耗和广泛的应用性成为了NVR(网络硬盘录像机)解决方案的核心选择。海思平台的NVR方案不仅支持多种编码格式,且兼容多种视频监控协议,特别是在ONVIF&a…

NC 二叉搜索树的第k个节点

系列文章目录 文章目录 系列文章目录前言 前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 描述 给定一棵结点…

【python】调用openAI api接口批量处理excel中的文本

调用openAI api接口批量处理文本 主页:github; BLOG:BLOG; 教程:视频 1. project简介 (1)概况 用于在python中调用open AI的API,处理xlsx表格中的自然语言文本。一个专门做dirty work的好帮手 &am…

Linux系统-系统信息网络目录文件的相关命令

1.系统信息和性能查看 查看磁盘的占用情况: df -Th 这是参数连着写。相当于df -T -h df -Th 此命令主要用于监控服务器的磁盘空间,如果空间不够用了,会导致服务器和应用的性能严重下降。这时候要手动清理一些不用的垃圾文件,比…

el-image-pro点击文本也能预览图片,支持下载

背景 element-ui:2.15.14 el-image的预览是没有下载功能的,默认是这样的 且默认是通过点击图片才能预览的,有时候我们显示的是图片名称,那么能不能直接点击图片名称来预览呢? 现在想在预览的时候,给它加一…

探秘陆生生态秘境:eDNA视角下的多营养级物种世界

现今的生物多样性和气候危机迫使我们开发更有效的陆地生态系统监测工具,eDNA宏条形码技术(eDNA metabarcoding),能够非侵入性地调查许多生态系统的物种丰富度,不会对生态环境造成干扰。通过分析这些信息,我…

树莓派开发笔记06-树莓派的SPI控制(点亮0.96OLED)

实验说明 我们这里会使用SPI去驱动一个0.96的OLED,首先需要打开SPI sudo raspi-config Interfacing Options------>SPI------>Yes------->OK------->finsh然后将屏幕接到树莓派上,接mosi和sclk的脚,DC接28,RST接29&…

C语言 ——— 学习并使用malloc和free函数

目录 malloc函数的功能 学习malloc函数​编辑 使用malloc函数 free函数的功能 学习并使用free函数​编辑 malloc动态开辟10个整型空间后赋值为0-9,再打印,打印后free malloc函数的功能 malloc函数能向内存申请一块连续可用的空间,并返…