深度学习:transpose_qkv()与transpose_output()

news2024/11/20 22:02:33

transpose_qkv 函数的主要作用是将输入的张量重新排列,使其适合多头注意力的计算。具体来说,它将输入张量的形状从 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)

详细步骤

  • 输入形状
    假设输入的张量形状为 (batch_size, seq_len, num_hiddens),其中:
    batch_size 是批次大小。
    seq_len 是序列长度。
    num_hiddens 是隐藏层的维度。

  • 拆分多头
    多头注意力机制将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens // num_heads。

  • 重新排列
    通过重新排列张量的维度,将 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)。

具体实现

假设 transpose_qkv 函数的实现如下:

def transpose_qkv(X, num_heads):
    # X: (batch_size, seq_len, num_hiddens)
    batch_size, seq_len, num_hiddens = X.shape
    num_hiddens_per_head = num_hiddens // num_heads
    
    # 将 num_hiddens 维度拆分成 num_heads 个头
    X = X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 将 batch_size 和 num_heads 合并
    X = X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head):
      将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens_per_head。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 num_heads 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head):
      将 batch_size 和 num_heads 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size * num_heads, seq_len, num_hiddens_per_head)。

总结

transpose_qkv 函数通过以下步骤将输入张量重新排列,使其适合多头注意力的计算:

  • 将 num_hiddens 维度拆分成 num_heads 个头。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 batch_size 和 num_heads 维度,使得每个头的数据连续排列。

最终,transpose_qkv 函数返回形状为 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 的张量,以便进行多头注意力计算。

transpose_output 函数的主要作用是将多头注意力的输出重新排列,使其适合后续的处理。具体来说,它将输入张量的形状从 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 转换为 (batch_size, seq_len, num_hiddens)

具体实现

假设 transpose_output 函数的实现如下:

def transpose_output(X, num_heads):
    # X: (batch_size * num_heads, seq_len, num_hiddens_per_head)
    batch_size_times_num_heads, seq_len, num_hiddens_per_head = X.shape
    batch_size = batch_size_times_num_heads // num_heads
    
    # 将 batch_size 和 num_heads 拆分
    X = X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 将 num_heads 和 num_hiddens_per_head 合并
    X = X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head):
      将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 seq_len 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head):
      将 num_heads 和 num_hiddens_per_head 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_hiddens)。

总结

transpose_output 函数通过以下步骤将多头注意力的输出重新排列,使其适合后续的处理:

  • 将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 num_heads 和 num_hiddens_per_head 维度,使得每个头的数据连续排列。

最终,transpose_output 函数返回形状为 (batch_size, seq_len, num_hiddens) 的张量,以便进行后续的处理。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2244280.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux网络——套接字编程

目录 1. 网络通信基本脉络 2. 端口号 ① 什么是套接字编程? ② 端口号 port && 进程 PID 3. 网络字节序 4. 套接字编程 ① UDP版 ② TCP版 5. 改进方案与拓展 ①多进程版 ②多线程版 ③线程池版 ④守护进程化 1. 简单的重联 2. session &…

Excel如何把两列数据合并成一列,4种方法

Excel如何把两列数据合并成一列,4种方法 参考链接:https://baijiahao.baidu.com/s?id=1786337572531105925&wfr=spider&for=pc 在Excel中,有时候需要把两列或者多列数据合并到一列中,下面介绍4种常见方法,并且提示一些使用注意事项,总有一种方法符合你的要求:…

LabVIEW三针自动校准系统

基于LabVIEW的智能三针自动校准系统采用非接触式激光测径仪对标准三针进行精确测量。系统通过LabVIEW软件平台与硬件设备的协同工作,实现了数据自动采集、处理及报告生成,大幅提高了校准精度与效率,并有效降低了人为操作误差。 一、项目背景…

【Java】JDK集合类源码设计相关笔记

文章目录 前言1. Iterable2. RandomAccess2.1 RandomAccess 使用索引进行二分查找 3. Map3.1 HashMap3.2 IdentityHashMap 4. Collections 工具类4.1 Collections.shuffle() 洗牌 前言 目的: 收集JDK集合类的类图。记录一些有意思的设计。将之前写过的文章建立联系。 1. Ite…

macbook外接2k/1080p显示器调试经验

准备工具 电脑 满足电脑和显示器要求的hdmi线或者转接头或者扩展坞 betterdisplay软件 Dell P2419H的最佳显示信息如下 飞利浦 245Es 2K的最佳显示比例如下 首选1152

Stable Diffusion的解读(二)

Stable Diffusion的解读(二) 文章目录 Stable Diffusion的解读(二)摘要Abstract一、机器学习部分1. 算法梳理1.1 LDM采样算法1.2 U-Net结构组成 2. Stable Diffusion 官方 GitHub 仓库2.1 安装2.2 主函数2.3 DDIM采样器2.4 Unet 3…

Github 2024-11-16Rust开源项目日报 Top10

根据Github Trendings的统计,今日(2024-11-16统计)共有10个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Rust项目10Go项目1Python项目1Lapce:用 Rust 编写的极快且强大的代码编辑器 创建周期:2181 天开发语言:Rust协议类型:Apache License 2.0St…

Redis作为分布式锁,得会避坑

日常开发中,经常会碰到秒杀抢购等业务场景。为了避免并发请求造成的库存超卖等问题,我们一般会用到Redis分布式锁。但是使用Redis分布式锁之前要知道有哪些坑是需要我们避过去的。 1. 非原子操作(setnx expire) 一说到实现Redis…

Qt、C++实现五子棋人机对战与本地双人对战(高难度AI,极少代码)

介绍 本项目基于 Qt C 实现了一个完整的五子棋游戏,支持 人机对战 和 人人对战 模式,并提供了三种难度选择(简单、中等、困难)。界面美观,逻辑清晰,是一个综合性很强的 Qt 小项目 标题项目核心功能 棋盘…

Vulnhub靶场案例渗透[12]-Grotesque: 1.0.1

文章目录 一、靶场搭建1. 靶场描述2. 下载靶机环境3. 靶场搭建 二、渗透靶场1. 确定靶机IP2. 探测靶场开放端口及对应服务3. 目录扫描4. 敏感信息获取5. 反弹shell6. 权限提升 一、靶场搭建 1. 靶场描述 get flags difficulty: medium about vm: tested and exported from vi…

git日志查询和导出

背景 查看git的提交记录并下载 操作 1、找到你idea代码的路径,然后 git bash here打开窗口 2、下载所有的日志记录 git log > commit.log3、下载特定日期范围内记录 git log --since"2024-09-01" --until"2024-11-18" 你的分支 > c…

LeetCode Hot100 15.三数之和

题干: 思路: 首先想到的是哈希表,类似于两数之和的想法,共两层循环,将遍历到的第一个元素和第二个元素存入哈希表中,然后按条件找第三个元素,但是这道题有去重的要求,哈希表实现较为…

Vue3、Vite5、Primevue、Oxlint、Husky9 简单快速搭建最新的Web项目模板

Vue3、Vite5、Oxlint、Husky9 简单搭建最新的Web项目模板 特色进入正题创建基础模板配置API自动化导入配置组件自动化导入配置UnoCss接入Primevue接入VueRouter4配置项目全局环境变量 封装Axios接入Pinia状态管理接入Prerttier OXLint ESLint接入 husky lint-staged&#xf…

基于深度学习的文本信息提取方法研究(pytorch python textcnn框架)

💗博主介绍💗:✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示:文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…

Linux(命令格式详细+字符集 图片+大白话)

后面也会持续更新,学到新东西会在其中补充。 建议按顺序食用,欢迎批评或者交流! 缺什么东西欢迎评论!我都会及时修改的! 在这里真的很感谢这位老师的教学视频让迷茫的我找到了很好的学习视频 王晓春老师的个人空间…

机器学习中的概率超能力:如何用朴素贝叶斯算法结合标注数据做出精准预测

💗💗💗欢迎来到我的博客,你将找到有关如何使用技术解决问题的文章,也会找到某个技术的学习路线。无论你是何种职业,我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章,也欢…

01 —— Webpack打包流程及一个例子

静态模块打包工具 静态模块:html、css、js、图片等固定内容的文件 打包:把静态模块内容,压缩、转译等 Webpack打包流程 src中新建一个index.js模块文件;然后将check.js模块内的两个函数导入过来,进行使用下载webpack…

时间类的实现

在现实生活中,我们常常需要计算某一天的前/后xx天是哪一天,算起来十分麻烦,为此我们不妨写一个程序,来减少我们的思考时间。 1.基本实现过程 为了实现时间类,我们需要将代码写在3个文件中,以增强可读性&a…

学习笔记024——Ubuntu 安装 Redis遇到相关问题

目录 1、更新APT存储库缓存: 2、apt安装Redis: 3、如何查看检查 Redis版本: 4、配置文件相关设置: 5、重启服务,配置生效: 6、查看服务状态: 1、更新APT存储库缓存: sudo apt…

【时间之外】IT人求职和创业应知【35】-RTE三进宫

目录 新闻一:京东工业发布11.11战报,多项倍增数据体现工业经济信心提升 新闻二:阿里云100万核算力支撑天猫双11,弹性计算规模刷新纪录 新闻三:声网CEO赵斌:RTE将成为生成式AI时代AI Infra的关键部分 认知…