自注意力机制、多头自注意力机制、填充掩码 Python实现

news2025/4/23 2:51:10

原理讲解

【Transformer系列(2)】注意力机制、自注意力机制、多头注意力机制、通道注意力机制、空间注意力机制超详细讲解

自注意力机制

import torch
import torch.nn as nn


# 自注意力机制
class SelfAttention(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)        

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.shape
        
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        atten_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(input_dim, dtype=torch.float))
        if mask is not None:
            mask = mask.unsqueeze(1)
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))        

        atten_scores = torch.softmax(atten_weights, dim=-1)

        attented_values = torch.matmul(atten_scores, v)

        return attented_values

# 自动填充函数
def pad_sequences(sequences, max_len=None):
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    max_len = max_len or lengths.max().item()
    
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq

    mask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)

    return padded, mask.long()


if __name__ == '__main__':
    batch_size = 2
    seq_len = 3
    input_dim = 128
    seq_len_1 = 3
    seq_len_2 = 5
    
    x1 = torch.randn(seq_len_1, input_dim)            
    x2 = torch.randn(seq_len_2, input_dim)
    
    target_seq_len = 10    
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)

    selfattention = SelfAttention(input_dim)    
    attention = selfattention(padded_x)
    print(attention)

多头自注意力机制

import torch
import torch.nn as nn

# 定义多头自注意力模块
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads

        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)        

    def forward(self, x, mask=None):
        batch_size, seq_len, input_dim = x.shape

        # 将输入向量拆分为多个头
        ## transpose(1,2)后变成 (batch_size, self.num_heads, seq_len, self.head_dim)形式
        q = self.query(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.key(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.value(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        
        # 计算注意力权重
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

        # 应用 padding mask
        if mask is not None:
            # mask: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len) 用于广播
            mask = mask.unsqueeze(1).unsqueeze(2)  # 扩展维度以便于广播
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))        
        
        attn_scores = torch.softmax(attn_weights, dim=-1)

        # 注意力加权求和
        attended_values = torch.matmul(attn_scores, v).transpose(1, 2).contiguous().view(batch_size, seq_len, input_dim)

        return attended_values

# 自动填充函数
def pad_sequences(sequences, max_len=None):
    batch_size = len(sequences)
    input_dim = sequences[0].shape[-1]
    lengths = torch.tensor([seq.shape[0] for seq in sequences])
    max_len = max_len or lengths.max().item()
    
    padded = torch.zeros(batch_size, max_len, input_dim)
    for i, seq in enumerate(sequences):
        seq_len = seq.shape[0]
        padded[i, :seq_len, :] = seq

    mask = torch.arange(max_len).expand(batch_size, max_len) < lengths.unsqueeze(1)

    return padded, mask.long()

if __name__ == '__main__':
    heads = 2
    batch_size = 2
    seq_len_1 = 3
    seq_len_2 = 5
    input_dim = 128
    
    x1 = torch.randn(seq_len_1, input_dim)            
    x2 = torch.randn(seq_len_2, input_dim)
    
    target_seq_len = 10    
    padded_x, mask = pad_sequences([x1, x2], target_seq_len)
    
    multiheadattention = MultiHeadSelfAttention(input_dim, heads)
    
    attention = multiheadattention(padded_x, mask)    
    print(attention)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2340462.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

目标检测篇---R-CNN梳理

目标检测系列文章 第一章 R-CNN 目录 目标检测系列文章&#x1f4c4; 论文标题&#x1f9e0; 论文逻辑梳理1. 引言部分梳理 (动机与思想) &#x1f4dd; 三句话总结&#x1f50d; 方法逻辑梳理&#x1f680; 关键创新点&#x1f517; 方法流程图补充边界框回归 (BBR)1. BBR 的…

C#处理网络传输中不完整的数据流

1、背景 在读取byte数组的场景&#xff08;例如&#xff1a;读取文件、网络传输数据&#xff09;中&#xff0c;特别是网络传输的场景中&#xff0c;非常有可能接收了不完整的byte数组&#xff0c;在将byte数组转换时&#xff0c;因字符的缺失/增多&#xff0c;转为乱码。如下…

HTML 初识

段落标签 <p><!-- 段落标签 -->Lorem ipsum dolor sit amet consectetur adipisicing elit. Fugiat, voluptate iure. Obcaecati explicabo sint ipsum impedit! Dolorum omnis voluptas sint unde sed, ipsa molestiae quo sapiente quos et ad reprehenderit.&l…

MATLAB 训练CNN模型 yolo v4

学生对小车控制提出了更好的要求&#xff0c;能否加入深度学习模型。 考虑到小车用matlab来做&#xff0c;yolo v5及以上版本都需要在pytorch下训练&#xff0c;还是用早期版本来演示。 1 yolov4 调用 参考 trainYOLOv4ObjectDetector (mathworks.com) name "tiny-yo…

【前端】跟着maxkb学习logicflow流程图画法

文章目录 背景1. 选定学习对象-maxkb应用逻辑编排2. 确定实现框架3. 关键逻辑&#xff1a;查看app-node.js4. 学习开始节点绘制流程数据形式 5. 给节点增加表单输入框遇到过的问题 背景 看看前端如何绘制流程图&#xff0c;界面好看点。 "logicflow/core": "1.…

【漏洞复现】CVE-2024-38856(ApacheOfbiz RCE)

【漏洞复现】CVE-2024-38856&#xff08;ApacheOfbiz RCE&#xff09; 1. 漏洞描述 Apache OFBiz 是一个开源的企业资源规划&#xff08;ERP&#xff09;系统。它提供了一套企业应用程序&#xff0c;用于集成和自动化企业的许多业务流程。 这个漏洞是由于对 CVE-2023-51467 的…

超详细VMware虚拟机扩容磁盘容量-无坑版

1.环境&#xff1a; 虚拟机&#xff1a;VMware Workstation 17 Pro-17.5.2 Linux系统&#xff1a;Ubuntu 22.04 LTS 2.硬盘容量 虚拟机当前硬盘容量180G -> 扩展至 300G 3.操作步骤 &#xff08;1&#xff09;在虚拟机关机的状态下&#xff0c;虚拟机硬盘扩容之前必…

全面理解Linux 系统日志:核心文件与查看方法

全文目录 1 Linux 系统日志分类及功能1.1 通用日志1.1.1 ‌/var/log/messages1.1.2 ‌/var/log/syslog 1.2 安全相关日志1.2.1 ‌/var/log/auth.log‌&#xff08;Debian/Ubuntu&#xff09;或 ‌/var/log/secure‌&#xff08;RHEL/CentOS&#xff09;1.2.2 /var/log/audit/au…

机器学习-08-关联规则更新

总结 本系列是机器学习课程的系列课程&#xff0c;主要介绍机器学习中关联规则和协同过滤。 参考 机器学习&#xff08;三&#xff09;&#xff1a;Apriori算法&#xff08;算法精讲&#xff09; Apriori 算法 理论 重点 【手撕算法】【Apriori】关联规则Apriori原理、代码…

Flutter与FastAPI的OSS系统实现

作者&#xff1a;孙嘉成 目录 一、对象存储 二、FastAPI与对象存储 2.1 缤纷云S4服务API对接与鉴权实现 2.2 RESTful接口设计与异步路由优化 三、Flutter界面与数据交互开发 3.1 应用的创建 3.2页面的搭建 3.3 文件的上传 关键词&#xff1a;对象存储、FastAPI、Flutte…

Kubernetes控制平面组件:API Server详解(二)

云原生学习路线导航页&#xff08;持续更新中&#xff09; kubernetes学习系列快捷链接 Kubernetes架构原则和对象设计&#xff08;一&#xff09;Kubernetes架构原则和对象设计&#xff08;二&#xff09;Kubernetes架构原则和对象设计&#xff08;三&#xff09;Kubernetes控…

MySQL-锁机制3-意向共享锁与意向排它锁、死锁

文章目录 一、意向锁二、死锁应该如何避免死锁问题&#xff1f; 总结 一、意向锁 在表获取共享锁或者排它锁时&#xff0c;需要先检查该表有没有被其它事务获取过X锁&#xff0c;通过意向锁可以避免大量的行锁扫描&#xff0c;提升表获取锁的效率。意向锁是一种表级锁&#xf…

报告系统状态的连续日期 mysql + pandas(连续值判断)

本题用到知识点&#xff1a;row_number(), union, date_sub(), to_timedelta()…… 目录 思路 pandas Mysql 思路 链接&#xff1a;报告系统状态的连续日期 思路&#xff1a; 判断连续性常用的一个方法&#xff0c;增量相同的两个列的差值是固定的。 让日期与行号 * 天数…

Tailwind 武林奇谈:bg-blue-400 失效,如何重拾蓝衣神功?

前言 江湖有云,Tailwind CSS,乃前端武林中的轻功秘籍。习得此技,排版如行云流水,配色似御风随形,收放自如,随心所欲。 某日,小侠你奋笔敲码,正欲施展“蓝衣神功”(bg-blue-400),让按钮怒气冲冠、蓝光满面,怎料一招使出,画面竟一片白茫茫大地真干净,毫无半点杀气…

开始放飞之先搞个VSCode

文章目录 开始放飞之先搞个VSCode重要提醒安装VSCode下载MinGW-w64回到VSCode中去VSCode原生调试键盘问题遗留问题参考文献 开始放飞之先搞个VSCode 突然发现自己的新台式机上面连个像样的编程环境都没有&#xff0c;全是游戏了&#xff01;&#xff01;&#xff01;&#xff…

基于SA模拟退火算法的车间调度优化matlab仿真,输出甘特图和优化收敛曲线

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于SA模拟退火算法的车间调度优化matlab仿真,输出甘特图和优化收敛曲线。输出指标包括最小平均流动时间&#xff0c;最大完工时间&#xff0c;最小间隙时间。 2…

【仿Mudou库one thread per loop式并发服务器实现】SERVER服务器模块实现

SERVER服务器模块实现 1. Buffer模块2. Socket模块3. Channel模块4. Poller模块5. EventLoop模块5.1 TimerQueue模块5.2 TimeWheel整合到EventLoop5.1 EventLoop与线程结合5.2 EventLoop线程池 6. Connection模块7. Acceptor模块8. TcpServer模块 1. Buffer模块 Buffer模块&…

uniapp h5接入地图选点组件

uniapp h5接入地图选点组件 1、申请腾讯地图key2、代码接入2.1入口页面 &#xff08;pages/map/map&#xff09;templatescript 2.2选点页面&#xff08;pages/map/mapselect/mapselect&#xff09;templatescript 该内容只针对uniapp 打包h5接入地图选点组件做详细说明&#x…

【随缘更新,免积分下载】Selenium chromedriver驱动下载(最新版135.0.7049.42)

目录 一、chromedriver概述 二、chromedriver使用方式 三、chromedriver新版本下载&#x1f525;&#x1f525;&#x1f525; 四、Selenium与Chrome参数设置&#x1f525;&#x1f525; 五、Selenium直接操控已打开的Chrome浏览器&#x1f525;&#x1f525;&#x1f525;…

jenkins批量复制Job项目的shell脚本实现

背景 现在需要将“测试” 目录中的所有job全部复制到 一个新目录中 test2。可以结合jenkins提供的apilinux shell 进行实现。 测试目录的实际文件夹名称是 test。 脚本运行效果如下&#xff1a; [qdevsom5f-dev-hhyl shekk]$ ./copy_jenkins_job.sh 创建文件夹 test2 获取源…