Pytorch手撸Attention

news2025/1/12 4:00:00

Pytorch手撸Attention

注释写的很详细了,对照着公式比较下更好理解,可以参考一下知乎的文章

注意力机制

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size

        # 定义三个全连接层,用于生成查询(Q)、键(K)和值(V)
        # 用Linear线性层让q、k、y能更好的拟合实际需求
        self.value = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.query = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        # x 的形状应为 (batch_size批次数量, seq_len序列长度, embed_size嵌入维度)
        batch_size, seq_len, embed_size = x.shape

        Q = self.query(x)
        K = self.key(x)
        V = self.value(x)
        
        # 计算注意力分数矩阵
        # 使用 Q 矩阵乘以 K 矩阵的转置来得到原始注意力分数
        # 注意力分数的形状为 [batch_size, seq_len, seq_len]
        # K.transpose(1,2)转置后[batch_size, embed_size, seq_len]
        # 为什么不直接使用 .T 直接转置?直接转置就成了[embed_size, seq_len,batch_size],不方便后续进行矩阵乘法
        attention_scores = torch.matmul(Q, K.transpose(1, 2)) / torch.sqrt(
            torch.tensor(self.embed_size, dtype=torch.float32))

        # 应用 softmax 获取归一化的注意力权重,dim=-1表示基于最后一个维度做softmax
        attention_weight = F.softmax(attention_scores, dim=-1)

        # 应用注意力权重到 V 矩阵,得到加权和
        # 输出的形状为 [batch_size, seq_len, embed_size]
        output = torch.matmul(attention_weight, V)

        return output

多头注意力机制

在这里插入图片描述

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        # 整除来确定每个头的维度
        self.head_dim = embed_size // num_heads
		
        # 加入断言,防止head_dim是小数,必须保证可以整除
        assert self.head_dim * num_heads == embed_size

        self.q = nn.Linear(embed_size, embed_size)
        self.k = nn.Linear(embed_size, embed_size)
        self.v = nn.Linear(embed_size, embed_size)
        self.out = nn.Linear(embed_size, embed_size)

    def forward(self, query, key, value):
        # N就是batch_size的数量
        N = query.shape[0]
        
        # *_len是序列长度
        q_len = query.shape[1]
        k_len = key.shape[1]
        v_len = value.shape[1]
		
        # 通过线性变换让矩阵更好的拟合
        queries = self.q(query)
        keys = self.k(key)
        values = self.v(value)
		
        # 重新构建多头的queries,permute调整tensor的维度顺序
        # 结合下文demo进行理解
        queries = queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        keys = keys.reshape(N, k_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        values = values.reshape(N, v_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
		
        # 计算多头注意力分数
        attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(
            torch.tensor(self.head_dim, dtype=torch.float32))
        attention = F.softmax(attention_scores, dim=-1)
		
        # 整合多头注意力机制的计算结果
        out = torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size)
        # 过一遍线性函数
        out = self.out(out)

        return out

demo测试

self-attention测试
# 测试自注意力机制
batch_size = 2
seq_len = 3
embed_size = 4

# 生成一个随机数据 tensor
input_tensor = torch.rand(batch_size, seq_len, embed_size)

# 创建自注意力模型实例
model = SelfAttention(embed_size)

# print输入数据
print("输入数据 [batch_size, seq_len, embed_size]:")
print(input_tensor)

# 运行自注意力模型
output_tensor = model(input_tensor)

# print输出数据
print("输出数据 [batch_size, seq_len, embed_size]:")
print(output_tensor)

=======print=========

输入数据 [batch_size, seq_len, embed_size]:
tensor([[[0.7579, 0.7342, 0.1031, 0.8610],
         [0.8250, 0.0362, 0.8953, 0.1687],
         [0.8254, 0.8506, 0.9826, 0.0440]],

        [[0.0700, 0.4503, 0.1597, 0.6681],
         [0.8587, 0.4884, 0.4604, 0.2724],
         [0.5490, 0.7795, 0.7391, 0.9113]]])

输出数据 [batch_size, seq_len, embed_size]:
tensor([[[-0.3714,  0.6405, -0.0865, -0.0659],
         [-0.3748,  0.6389, -0.0861, -0.0706],
         [-0.3694,  0.6388, -0.0855, -0.0660]],

        [[-0.2365,  0.4541, -0.1811, -0.0354],
         [-0.2338,  0.4455, -0.1871, -0.0370],
         [-0.2332,  0.4458, -0.1867, -0.0363]]], grad_fn=<UnsafeViewBackward0>)
MultiHeadAttention

多头注意力机制务必自己debug一下,主要聚焦在理解如何拆分成多头的,不结合代码你很难理解多头的操作过程

1、queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 处理之后的 size = torch.Size([64, 8, 10, 16])

  • 通过上述操作,queries 张量的最终形状变为 [N, self.num_heads, q_len, self.head_dim]。这样的排列方式使得每个注意力头可以单独处理对应的序列部分,而每个头的处理仅关注其分配到的特定维度 self.head_dim
  • 这个形状是为了后续的矩阵乘法操作准备的,其中每个头的查询将与对应的键进行点乘,以计算注意力分数

2、attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt( torch.tensor(self.head_dim, dtype=torch.float32)) 将reshape后的quries的后两个维度进行转置后点乘,对应了 Q ⋅ K T Q \cdot K^T QKT ;根据demo这里的头数为8,所以公式中对应的下标 i i i 为8

3、在进行完多头注意力机制的计算后通过 torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size) 整合,变回原来的 [batch_size,seq_length,embed_size]形状

# 测试多头注意力
embed_size = 128  # 嵌入维度
num_heads = 8    # 头数
attention = MultiHeadAttention(embed_size, num_heads)

# 创建随机数据模拟 [batch_size, seq_length, embedding_dim]
batch_size = 64
seq_length = 10
dummy_values = torch.rand(batch_size, seq_length, embed_size)
dummy_keys = torch.rand(batch_size, seq_length, embed_size)
dummy_queries = torch.rand(batch_size, seq_length, embed_size)

# 计算多头注意力输出
output = attention(dummy_values, dummy_keys, dummy_queries)
print(output.shape)  # [batch_size, seq_length, embed_size]

=======print=========

torch.Size([64, 10, 128])

如果你难以理解权重矩阵的拼接和拆分,推荐李宏毅的attention课程(YouTobe)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1609580.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kafka安装Windows版

系列文章目录 文章目录 系列文章目录前言前言 前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站,这篇文章男女通用,看懂了就去分享给你的码吧。 Kafka 是一个由 LinkedIn 开发的分布式消息系统,它于2011年年初开源,现…

java byte数组转String

hi&#xff0c;我是程序员王也&#xff0c;一个资深Java开发工程师&#xff0c;平时十分热衷于技术副业变现和各种搞钱项目的程序员~&#xff0c;如果你也是&#xff0c;可以一起交流交流。 今天我们聊聊Java中Byte数组转String的用法~ 转换方法概览 在Java中&#xff0c;将b…

【人工智能】机器学习算法综述及常见算法详解

目录 推荐 1、机器学习算法简介 1.1 机器学习算法包含的两个步骤 1.2 机器学习算法的分类 2、线性回归算法 2.1 线性回归的假设是什么&#xff1f; 2.2 如何确定线性回归模型的拟合优度&#xff1f; 2.3 如何处理线性回归中的异常值&#xff1f; 3、逻辑回归算法 3.1 …

蓝桥杯2024年第十五届省赛

E:宝石组合 根据给的公式化简后变为gcd(a,b,c)根据算数基本定理&#xff0c;推一下就可以了 然后我们对1到mx的树求约数&#xff0c;并记录约数的次数&#xff0c;我们选择一个最大的且次数大于等3的就是gcd int mx; vector<int> g[N]; vector<int> cnt[N]; int…

深入刨析 mysql 底层索引结构B+树

文章目录 前言一、什么是索引&#xff1f;二、不同索引结构对比2.1 二叉树2.2 平衡二叉树2.3 B-树2.4 B树 三、mysql 的索引3.1 聚簇索引3.2 非聚簇索引 前言 很多人看过mysql索引的介绍&#xff1a;hash表、B-树、B树、聚簇索引、主键索引、唯一索引、辅助索引、二级索引、联…

【面经】2024春招-云计算后台研发工程师2(三大行 TW等)

【面经】2024春招-云计算后台研发工程师2&#xff08;三大行 & TW等&#xff09; 文章目录 岗位与面经基础1&#xff1a;数据库 & 网络基础2&#xff1a;系统 & 网络编程模板3&#xff1a;算法 & 行测 岗位与面经 1、银行面经&#xff08;重点&#xff09; …

MinIO自定义权限控制浅研

转载说明&#xff1a;如果您喜欢这篇文章并打算转载它&#xff0c;请私信作者取得授权。感谢您喜爱本文&#xff0c;请文明转载&#xff0c;谢谢。 MinIO搭建好之后&#xff0c;出于不同场景的需要&#xff0c;有时候需要对不同的用户和Bucket做一些针对性的权限控制。 MinIO的…

2024年短剧素材怎么下载

看到好看的短剧&#xff0c;想要把它剪辑出来&#xff0c;该怎么下载呢&#xff0c;本文就教大家用工具进行下载 需要工具的点击下面的链接 下载高手链接&#xff1a;https://pan.baidu.com/s/1qJ81sNBzzzU0w6DWf-9Nxw?pwdl09r 提取码&#xff1a;l09r --来自百度网盘超级…

CommunityToolkit.Mvvm笔记---ObservableValidator

ObservableValidator 是实现 INotifyDataErrorInfo 接口的基类&#xff0c;它支持验证向其他应用程序模块公开的属性。 它也继承自 ObservableObject&#xff0c;因此它还可实现 INotifyPropertyChanged 和 INotifyPropertyChanging。 它可用作需要支持属性更改通知和属性验证的…

Pytest精通指南(19)断言和异常处理

文章目录 Pytest assert 简介assert 应用场景assert 测试结果assert 基本用法Pytest raises 简介raises 用途和作用raises 与 try 的区别python代码中使用try在测试用例中使用try使用pytest.raises() Pytest assert 简介 断言&#xff08;Assertion&#xff09;是编程中的一个基…

代码随想录算法训练营Day4 | 24.两两交换链表中的节点、19删除链表中的第N个节点、链表相交、142.环形链表

24.两两交换链表中的节点 题目&#xff1a;给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点交换&#xff09;。 题目链接&#xff1a;24. 两两交换…

蛋白质亚细胞定位预测(生物信息学工具-017)

直奔主题&#xff0c;下面这张表图怎么制作&#xff0c;一般都是毕业论文hh&#xff0c;蛋白质的亚细胞定位如何预测&#xff1f; 01 方法 https://wolfpsort.hgc.jp/ #官网小程序&#xff0c;简单好用&#xff0c;不用R包&#xff0c;python包&#xff0c;linux程序&#x…

Linux的学习之路:14、文件(1)

摘要 有一说一文件一天学不完&#xff0c;细节太多了&#xff0c;所以这里也没更新完&#xff0c;这里部分文件知识&#xff0c;然后C语言和os两种的文件操作 目录 摘要 一、文件预备 二、c文件操作 三、OS文件操作 1、系统文件I/O 2、接口介绍 四、思维导图 一、文件…

C++笔记:类和对象(一)

类和对象 认识类和对象 先来回忆一下C语言中的类型和变量&#xff0c;类型就像是定义了数据的规则&#xff0c;而变量则是根据这些规则来实际存储数据的容器。类是我们自己定义的一种数据类型&#xff0c;而对象则是这种数据类型的一个具体实例。类就可以理解为类型&#xff0c…

在Ubuntu 22.04上安装配置VNC实现可视化

前面安装的部分可以看我这篇文章 在Ubuntu 18.04上安装配置VNC实现Spinach测试可视化_ubuntu18开vnc-CSDN博客 命令差不多一样&#xff1a; sudo apt update sudo apt install xfce4 xfce4-goodies sudo apt install tightvncserver这个时候就可以启动server了 启动server&…

音频---数字mic

一、常见的数字mic pdm麦通过codec芯片将数字麦转换为i2s信号输入到SOC 纯pdm麦就是直接进入SOC的pdm接口&#xff0c;走的是PDM信号&#xff0c;PDM信号就是两个线&#xff0c;一根数据线一根时钟线&#xff08;如顺芯ES7201/7202把MIC信号转换成PDM&#xff09;。 二、DMIC…

等保合规:保护企业网络安全的必要性与优势

前言 无论是多部网络安全法律法规的出台&#xff0c;还是最近的“滴滴被安全审查”事件&#xff0c;我们听得最多的一个词&#xff0c;就是“等保。” 只要你接触安全类工作&#xff0c;听得最多的一个词&#xff0c;一定是“等保。” 那么&#xff0c;到底什么是“等保”呢…

Docker部署SpringBoot服务(Jar包映射部署)

介绍 项目在docker部署运行以后&#xff0c;每次需更新jar包时&#xff0c;都得重新制作镜像&#xff0c;再重新制作容器。流程及其繁琐&#xff0c;效率极低。 以下步骤是在不更新镜像和容器的前提下&#xff0c;直接更新jar完成项目更新的操作。 不重新制作镜像部署 1. 创…

Python简化远程部署和系统管理的工具库之fabric使用详解

概要 Python Fabric库是一个用于简化远程部署和系统管理的工具库。它提供了一组简洁而强大的函数和工具,可以帮助开发者轻松地在多台远程主机上执行命令、上传文件、下载文件等操作,从而实现自动化部署和管理。 安装 要使用Python Fabric库,首先需要安装Fabric模块。可以通…

OpenHarmony实战开发-如何使用text组件的enableDataDetector属性实现文本特殊文字识别。

介绍 本示例介绍使用text组件的enableDataDetector属性实现文本特殊文字识别。 效果图预览 使用说明 1.进入页面&#xff0c;输入带有特殊文字的信息并发送&#xff0c;对话列表中文本会自动识别并标识特殊文字。目前支持识别的类型包括电话号码、链接、邮箱和地址&#xff…