基于KV Cache构建流式帧级别Transformer实现自回归解码

news2024/9/24 1:20:49

在自然语言处理和序列建模中,Transformer模型因其在处理长距离依赖关系上的卓越性能而被广泛使用。传统的Transformer模型在处理长序列时,计算和存储的开销较大,而流式帧级别Transformer通过引入KV Cache(键值缓存)来有效地缓解这一问题。

本文将介绍如何基于KV Cache构建流式帧级别Transformer,并实现自回归解码。通过实际代码示例,详细解释其工作原理和实现细节。
在这里插入图片描述

流式帧级别Transformer简介

流式帧级别Transformer是一种特殊的Transformer变体,设计用于流式输入处理。这种模型可以在序列的每个时间步处理输入,并且利用KV Cache存储历史的键和值,避免重复计算,从而提高效率。自回归解码则意味着模型在生成下一个输出时依赖于之前的输出。

代码实现

我们将实现一个包含编码器和解码器的流式帧级别Transformer模型。编码器和解码器分别利用KV Cache存储和更新历史信息,以实现高效的序列建模和生成。

编码器

首先,定义编码器类StreamSelfAttentionEncoder

import torch
import torch.nn as nn
import math

class StreamSelfAttentionEncoder(nn.Module):
    def __init__(self, model_dim, self_attention_size):
        super(StreamSelfAttentionEncoder, self).__init__()
        self.model_dim = model_dim
        self.self_attention_size = self_attention_size
        self.Q = nn.Linear(model_dim, model_dim)
        self.K = nn.Linear(model_dim, model_dim)
        self.V = nn.Linear(model_dim, model_dim)
        self.softmax = nn.Softmax(dim=-1)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, model_dim * 4),
            nn.ReLU(),
            nn.Linear(model_dim * 4, model_dim)
        )

    def forward(self, x, k_cache=None, v_cache=None, pos=None):
        # Ensure positional encoding is on the same device as x
        if pos is not None:
            pos_enc = self.get_positional_encoding(pos, self.model_dim, x.device)
            x = x + pos_enc.unsqueeze(0).unsqueeze(1)  # (N, 1, model_dim)
        
        # Project inputs to Q, K, V
        q = self.Q(x)  # (N, 1, model_dim)
        k = self.K(x)  # (N, 1, model_dim)
        v = self.V(x)  # (N, 1, model_dim)
        
        batch_size = x.size(0)
        
        # Initialize k_cache and v_cache if not provided
        if k_cache is None:
            k_cache = torch.zeros((batch_size, 0, self.model_dim), device=x.device)
            v_cache = torch.zeros((batch_size, 0, self.model_dim), device=x.device)
        
        # Concatenate past K, V with current K, V
        k_cache = torch.cat([k_cache, k], dim=1)  # (N, seq_len + 1, model_dim)
        v_cache = torch.cat([v_cache, v], dim=1)  # (N, seq_len + 1, model_dim)
        
        # Compute attention scores
        attn_scores = torch.matmul(q, k_cache[:, -self.self_attention_size:].transpose(-2, -1)) / math.sqrt(self.model_dim)
        attn_weights = self.softmax(attn_scores)
        
        # Compute attention output
        attn_output = torch.matmul(attn_weights, v_cache[:, -self.self_attention_size:])
        
        # Apply skip connection and FFN
        attn_output = attn_output + x
        ffn_output = self.ffn(attn_output)
        output = ffn_output + attn_output
        
        return output, k_cache, v_cache

    def get_positional_encoding(self, pos, model_dim, device):
        pe = torch.zeros(model_dim, device=device)
        div_term = torch.exp(torch.arange(0, model_dim, 2, device=device).float() * (-math.log(10000.0) / model_dim))
        pe[0::2] = torch.sin(pos * div_term)
        pe[1::2] = torch.cos(pos * div_term)
        return pe

在这个编码器中,我们通过以下步骤来处理输入数据:

  1. 位置编码(Positional Encoding)

    if pos is not None:
        pos_enc = self.get_positional_encoding(pos, self.model_dim, x.device)
        x = x + pos_enc.unsqueeze(0).unsqueeze(1)  # (N, 1, model_dim)
    

    这里我们为输入x添加位置编码,以保留序列信息。

  2. 投影(Projection)

    q = self.Q(x)  # (N, 1, model_dim)
    k = self.K(x)  # (N, 1, model_dim)
    v = self.V(x)  # (N, 1, model_dim)
    

    将输入x投影到查询(Query)、键(Key)和值(Value)空间。

  3. KV缓存初始化和更新(KV Cache Initialization and Update)

    if k_cache is None:
        k_cache = torch.zeros((batch_size, 0, self.model_dim), device=x.device)
        v_cache = torch.zeros((batch_size, 0, self.model_dim), device=x.device)
    
    k_cache = torch.cat([k_cache, k], dim=1)  # (N, seq_len + 1, model_dim)
    v_cache = torch.cat([v_cache, v], dim=1)  # (N, seq_len + 1, model_dim)
    

    初始化并更新KV缓存,将当前的kv值拼接到缓存中。

  4. 注意力计算(Attention Calculation)

    attn_scores = torch.matmul(q, k_cache[:, -self.self_attention_size:].transpose(-2, -1)) / math.sqrt(self.model_dim)
    attn_weights = self.softmax(attn_scores)
    attn_output = torch.matmul(attn_weights, v_cache[:, -self.self_attention_size:])
    

    计算查询与缓存中键的点积,然后通过softmax获得注意力权重,再将权重应用到缓存中的值上,得到注意力输出。

  5. 前馈网络(Feed-Forward Network)和跳跃连接(Skip Connection)

    attn_output = attn_output + x
    ffn_output = self.ffn(attn_output)
    output = ffn_output + attn_output
    

    最后,将注意力输出与输入相加,再经过前馈网络和跳跃连接得到最终输出。

解码器

接下来,定义解码器类StreamSelfAttentionDecoder

class StreamSelfAttentionDecoder(nn.Module):
    def __init__(self, model_dim, self_attention_size, cross_attention_size):
        super(StreamSelfAttentionDecoder, self).__init__()
        self.model_dim = model_dim
        self.self_attention_size = self_attention_size
        self.cross_attention_size = cross_attention_size
        self.Qe = nn.Linear(model_dim, model_dim)
        self.Qd = nn.Linear(model_dim, model_dim)
        self.Kd = nn.Linear(model_dim, model_dim)
        self.Vd = nn.Linear(model_dim, model_dim)
        self.softmax = nn.Softmax(dim=-1)
        
        # FFN
        self.ffn = nn.Sequential(
            nn.Linear(model_dim, model_dim * 4),
            nn.ReLU(),
            nn.Linear(model_dim * 4, model_dim)
        )

    def forward(self, x,
                encoder_k_cache,
                encoder_v_cache,
                decoder_k_cache=None,
                decoder_v_cache=None, 
                pos=None):
        
        batch_size = x.size(0)

        # Ensure positional encoding is on the same device as x
        if pos is not None:
            pos_enc = self.get_positional_encoding(pos, self.model_dim, x.device)
            x = x + pos_enc.unsqueeze(0).unsqueeze(1)  # (N, 1, model_dim)
        
        # Initialize caches if not provided
        if decoder_k_cache is None:
            decoder_k_cache = torch.zeros((batch_size, 0, self.model_dim), device=x.device)
            decoder_v_cache = torch.zeros((batch_size, 0, self.model_dim), device=x.device)
        
        # Decoder self-attention
        qd = self.Qd(x)  # (N, 1, model_dim)
        kd = self.Kd(x)  # (N, 1, model_dim)
        vd = self.Vd(x)  # (N, 1, model_dim)

        # Concatenate past K, V with current K, V
        decoder_k_cache = torch.cat([decoder_k_cache, kd], dim=1)  # (N, seq_len + 1, model_dim)
        decoder_v_cache = torch.cat([decoder_v_cache, vd], dim=1)  # (N, seq_len + 1

, model_dim)
        
        # Compute self-attention scores
        attn_self_scores = torch.matmul(qd, decoder_k_cache[:, -self.self_attention_size:].transpose(-2, -1)) / math.sqrt(self.model_dim)
        attn_self_weights = self.softmax(attn_self_scores)
        attn_self_output = torch.matmul(attn_self_weights, decoder_v_cache[:, -self.self_attention_size:])
        attn_self_output = attn_self_output + x

        # Encoder-decoder cross-attention
        qe = self.Qe(attn_self_output)
        attn_cross_scores = torch.matmul(qe, encoder_k_cache[:, -self.cross_attention_size:].transpose(-2, -1)) / math.sqrt(self.model_dim)
        attn_cross_weights = self.softmax(attn_cross_scores)
        attn_cross_output = torch.matmul(attn_cross_weights, encoder_v_cache[:, -self.cross_attention_size:])
        attn_cross_output = attn_cross_output + attn_self_output

        # Apply skip connection and FFN
        ffn_output = self.ffn(attn_cross_output)
        output = ffn_output + attn_cross_output
        
        return output, decoder_k_cache, decoder_v_cache

    def get_positional_encoding(self, pos, model_dim, device):
        pe = torch.zeros(model_dim, device=device)
        div_term = torch.exp(torch.arange(0, model_dim, 2, device=device).float() * (-math.log(10000.0) / model_dim))
        pe[0::2] = torch.sin(pos * div_term)
        pe[1::2] = torch.cos(pos * div_term)
        return pe

在这个解码器中,我们通过以下步骤来处理输入数据:

  1. 位置编码(Positional Encoding)

    if pos is not None:
        pos_enc = self.get_positional_encoding(pos, self.model_dim, x.device)
        x = x + pos_enc.unsqueeze(0).unsqueeze(1)  # (N, 1, model_dim)
    

    这里我们为输入x添加位置编码,以保留序列信息。

  2. 投影(Projection)

    qd = self.Qd(x)  # (N, 1, model_dim)
    kd = self.Kd(x)  # (N, 1, model_dim)
    vd = self.Vd(x)  # (N, 1, model_dim)
    

    将输入x投影到查询(Query)、键(Key)和值(Value)空间。

  3. KV缓存初始化和更新(KV Cache Initialization and Update)

    if decoder_k_cache is None:
        decoder_k_cache = torch.zeros((batch_size, 0, self.model_dim), device=x.device)
        decoder_v_cache = torch.zeros((batch_size, 0, self.model_dim), device=x.device)
    
    decoder_k_cache = torch.cat([decoder_k_cache, kd], dim=1)  # (N, seq_len + 1, model_dim)
    decoder_v_cache = torch.cat([decoder_v_cache, vd], dim=1)  # (N, seq_len + 1, model_dim)
    

    初始化并更新解码器的KV缓存,将当前的kdvd值拼接到缓存中。

  4. 自注意力计算(Self-Attention Calculation)

    attn_self_scores = torch.matmul(qd, decoder_k_cache[:, -self.self_attention_size:].transpose(-2, -1)) / math.sqrt(self.model_dim)
    attn_self_weights = self.softmax(attn_self_scores)
    attn_self_output = torch.matmul(attn_self_weights, decoder_v_cache[:, -self.self_attention_size:])
    attn_self_output = attn_self_output + x
    

    计算查询与解码器缓存中键的点积,然后通过softmax获得注意力权重,再将权重应用到缓存中的值上,得到自注意力输出。

  5. 交叉注意力计算(Cross-Attention Calculation)

    qe = self.Qe(attn_self_output)
    attn_cross_scores = torch.matmul(qe, encoder_k_cache[:, -self.cross_attention_size:].transpose(-2, -1)) / math.sqrt(self.model_dim)
    attn_cross_weights = self.softmax(attn_cross_scores)
    attn_cross_output = torch.matmul(attn_cross_weights, encoder_v_cache[:, -self.cross_attention_size:])
    attn_cross_output = attn_cross_output + attn_self_output
    

    计算自注意力输出与编码器缓存中键的点积,然后通过softmax获得注意力权重,再将权重应用到编码器缓存中的值上,得到交叉注意力输出。

  6. 前馈网络(Feed-Forward Network)和跳跃连接(Skip Connection)

    ffn_output = self.ffn(attn_cross_output)
    output = ffn_output + attn_cross_output
    

    最后,将交叉注意力输出与输入相加,再经过前馈网络和跳跃连接得到最终输出。

示例代码

以下代码展示了如何实例化编码器和解码器,并进行前向传播:

if __name__ == "__main__":
    batch_size = 2
    model_dim = 64
    attention_size = 10
    self_attention_size = 8
    cross_attention_size = 6
    seq_len = 1
    decoder_step = 4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Instantiate the self-attention encoder and decoder
    encoder = StreamSelfAttentionEncoder(model_dim, attention_size).to(device)
    decoder = StreamSelfAttentionDecoder(model_dim, self_attention_size, cross_attention_size).to(device)
    
    encoder_k_cache = encoder_v_cache = None
    decoder_k_cache = decoder_v_cache = None
    
    for t in range(100):
        x = torch.rand(batch_size, seq_len, model_dim).to(device)  # (N, 1, model_dim)
        pos = t  # Current position
        
        # Encoder forward pass
        encoder_output, encoder_k_cache, encoder_v_cache = encoder(x, encoder_k_cache, encoder_v_cache, pos)
        print(f"Encoder Output shape at time step {t}: {encoder_output.shape}")  # (N, 1, model_dim)
        print(f"Encoder k_cache shape: {encoder_k_cache.shape}")  # (N, seq_len + 1, model_dim)
        print(f"Encoder v_cache shape: {encoder_v_cache.shape}")  # (N, seq_len + 1, model_dim)
        print()

        if t % decoder_step == 0:
            # Decoder forward pass
            decoder_output, decoder_k_cache, decoder_v_cache = decoder(encoder_output, encoder_k_cache, encoder_v_cache, decoder_k_cache, decoder_v_cache, pos)
            print(f"Decoder Output shape at time step {t}: {decoder_output.shape}")  # (N, 1, model_dim)
            print(f"Decoder k_cache shape: {decoder_k_cache.shape}")  # (N, seq_len + 1, model_dim)
            print(f"Decoder v_cache shape: {decoder_v_cache.shape}")  # (N, seq_len + 1, model_dim)
            print()

运行结果如下(对解码器进行跳帧处理)
在这里插入图片描述

结论

通过本文的介绍和示例代码,我们详细阐述了如何基于KV Cache构建流式帧级别Transformer并实现自回归解码。这种方法不仅能有效处理长序列数据,还能显著提升计算效率。希望这篇文章能帮助读者更好地理解和应用流式帧级别Transformer模型。

通过实践和调整参数,读者可以进一步优化模型性能,以满足不同任务的需求。流式帧级别Transformer的应用前景广泛,无论是在自然语言处理、语音识别还是其他序列数据处理领域,都有很大的潜力。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1915382.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

AMD X3D CPU 史诗级进化,锐龙7 9800X3D默秒全

6 月份刚刚结束,这有关下半年新一代 PC 硬件消息便愈发蠢蠢欲动起来。 上个月初台北国际电脑展上,AMD 正式公布了下一代 Zen 5 架构 Ryzen 9000 系列桌面处理器。 AMD 前脚刚大吹特吹性能吊锤 Intel i9 14900K 云云,没想到反手又来了一波被自…

【两大3D转换SDK对比】HOOPS Exchange VS. CAD Exchanger

在现代工业和工程设计领域,CAD数据转换工具是确保不同软件系统间数据互通的关键环节。HOOPS Exchange和CAD Exchanger是两款备受关注的工具,它们在功能、支持格式、性能和应用场景等方面有着显著差异。 本文将从背景、支持格式、功能和性能、应用场景等…

小程序内容管理系统设计

设计一个小程序内容管理系统(CMS)时,需要考虑以下几个关键方面来确保其功能完善、用户友好且高效: 1. 需求分析 目标用户:明确你的目标用户群体,比如企业、媒体、个人博主等,这将决定系统的功…

本地部署,图片细节处理大模型Tile Controlnet

目录 什么是 Tile ControlNet? 工作原理 应用场景 优势与挑战 优势 挑战 本地部署 运行结果 未来展望 结论 Tip: 在近年来的深度学习和计算机视觉领域,生成对抗网络(GAN)和扩散模型等技术取得了显著的进展。…

NI 5G大规模MIMO测试台:将理论变为现实

目录 概览引言MIMO原型验证系统MIMO原型验证系统硬件LabVIEW通信系统设计套件(简称LabVIEW Communications)CPU开发代码FPGA代码开发硬件和软件紧密集成 LabVIEW Communications MIMO应用框架MIMO应用框架特性单用户MIMO和多用户MIMO基站和移动站天线数量…

LINUX命令行curl指令与python内置urllib模块

urllib是python御用的易用的轻便模块,curl是Linux功能强大的命令行工具,都是参与Web的利器。 (笔记模板由python脚本于2024年07月10日 18:41:12创建,本篇笔记适合喜欢Python和Linux的coder翻阅) 【学习的细节是欢悦的历程】 Python 官网&…

【最强八股文 -- 计算机网络】【快速版】WWW 构建技术 (3 项)

1.HTML(HyperText Markup Language):作为页面的文本标记语言 2.HTTP(HyperTextTransfer Protocol):文档传递协议 3.URL(Uniform Resource Locator):指定文档所在地址 HTTPS 和 HTTP 的区别: HTTP: 以明文的方式在网络中传输数据,HTTPS 解决了HTTP 不安全的缺陷&…

芋道源码 yudao-cloud 文档,视频,开发指南如何看全部

进入官网后可以看到相关内容 但是后端手册开始就看不了了 必须加入知识知识星球才行,很烦 闲**鱼搜索用户 水城打坐的藤壶 找到这个链接 这下大家都懂了吧 现在就可以看到看不到的内容了 在线文档的弹窗可技术去除,很简单 直接起飞哈 包括更新sq…

DELTA: DEGRADATION-FREE FULLY TEST-TIME ADAPTATION--论文笔记

论文笔记 资料 1.代码地址 2.论文地址 https://arxiv.org/abs/2301.13018 3.数据集地址 https://github.com/bwbwzhao/DELTA 论文摘要的翻译 完全测试时间自适应旨在使预训练模型在实时推理过程中适应测试数据流,当测试数据分布与训练数据分布不同时&#x…

前端面试题40(浅谈MVVM双向数据绑定)

MVVM(Model-View-ViewModel)架构模式是一种用于简化用户界面(UI)开发的软件架构设计模式,尤其在现代前端开发中非常流行,例如在使用Angular、React、Vue.js等框架时。MVVM模式源于经典的MVC(Mod…

【C++修行之道】string类练习题

目录 387. 字符串中的第一个唯一字符 125. 验证回文串 917. 仅仅反转字母 415. 字符串相加(重点) 541. 反转字符串 II 387. 字符串中的第一个唯一字符 字符串中的第一个唯一字符 - 力扣(LeetCode) 给定一个字符串 s &#…

【UE5.3】笔记10-时间轴的使用

时间轴 右键--Add Timeline(在最下面) --> 双击进入时间轴的编辑界面: 左上角可以添加不同类型的轨道,可以自定义轨道的长度,单位秒,一次可以添加多个 可以通过右键添加关键帧,快捷键:shift鼠标左键按…

ssrf结合redis未授权getshell

目录 漏洞介绍 SSRF Redis未授权 利用原理 环境搭建 利用过程 rockylinux cron计划任务反弹shell 写公钥免密登录 ubuntu 写公钥免密登录 漏洞介绍 SSRF SSRF(server side request forgrey)服务端请求伪造,因后端未过滤用户输入&…

LeetCode(2)合并链表、环形链表的约瑟夫问题、链表分割

一、合并链表 . - 力扣(LeetCode) 题目描述: /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ typedef struct ListNode ListNode; struct ListNode* mergeTwoLists(struct …

skywalking-1-服务端安装

skywalking很优秀。 安装服务端 skywalking的服务端主要是aop服务,为了方便查看使用还需要安装ui。另外采集的数据我们肯定要存起来,这个数据库就直接用官方的banyandb。也就是aop、ui、banyandb都使用官方包。 我们的目的是快速使用和体验&#xff0c…

stm32按键设置闹钟数进退位不正常?如何解决

🏆本文收录于《CSDN问答解惑-专业版》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收…

MMII 的多模态医学图像交互框架:更直观地理解人体解剖结构和疾病

医生在诊断和治疗过程中依赖于人体解剖图像,如磁共振成像(MRI),难以全面捕捉人体组织的复杂性,例如组织之间的空间关系、质地、大小等。然而,实时感知有关患者解剖结构和疾病的多模态信息对于医疗程序的成功…

✅小程序申请+备案教程

##red## 🔴 大家好,我是雄雄,欢迎关注微信公众号,雄雄的小课堂。 零、注意事项 需要特别注意的是,如果公司主体的微信公众号已经交过300块钱的认证费了的话,注册小程序通过公众号来注册,可以免…

手搓前端day1

断断续续的学了些前端,今天开始写写代码,就当是记录一下自己前端的成长过程 效果: 写了点css,实现了简单的前端页面的跳转 文件目录 代码如下: styles.css body{margin: 0;padding: 0;}header{background-color: bl…

3102.力扣每日一题7/9 Java(TreeMap)

博客主页:音符犹如代码系列专栏:算法练习关注博主,后期持续更新系列文章如果有错误感谢请大家批评指出,及时修改感谢大家点赞👍收藏⭐评论✍ 目录 TreeMap详解 解题思路 解题方法 时间复杂度 空间复杂度 Code T…