Multi-Head Attention 代码实现

news2025/1/10 0:08:40

Multi-Head Attention 代码实现

flyfish

MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)

公式的另一种写法
h e a d i = Attention ( Q W i Q , K W i K , V W i V ) MultiHead ( Q , K , V ) = Concat ( h e a d 1 , . . . , h e a d h ) \begin{gather}head_i = \text{Attention}(\boldsymbol{Q}\boldsymbol{W}_i^Q,\boldsymbol{K}\boldsymbol{W}_i^K,\boldsymbol{V}\boldsymbol{W}_i^V)\\\text{MultiHead}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = \text{Concat}(head_1,...,head_h)\end{gather} headi=Attention(QWiQ,KWiK,VWiV)MultiHead(Q,K,V)=Concat(head1,...,headh)
在这里插入图片描述
在这里插入图片描述
多头(Multi-head),其实就是多做几次 Scaled Dot-product Attention,然后把结果拼接

h1 =    [1, 2, 3]    
        [4, 5, 6]
        [7, 8, 9]

h2 =    [10, 11, 12]
        [13, 14, 15]
        [16, 17, 18]

h3 =    [19, 20, 21]
        [22, 23, 24]
        [25, 26, 27]

h_concatenate = [1, 2, 3, 10, 11, 12, 19, 20, 21]    
                [4, 5, 6, 13, 14, 15, 22, 23, 24]
                [7, 8, 9, 16, 17, 18, 25, 26, 27]

第一种 写法
全部



# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import math



def attention(query, key, value, mask=None, dropout=None):
 

     # query的最后⼀维的⼤⼩, ⼀般情况下就等同于词嵌⼊维度, 命名为d_k
     d_k = query.size(-1)


     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
     print("scores.shape:",scores.shape)#scores.shape: torch.Size([1, 12, 12])

     if mask is not None:
         scores = scores.masked_fill(mask == 0, -1e9)


     p_attn = F.softmax(scores, dim = -1)

     if dropout is not None:
         p_attn = dropout(p_attn)

     return torch.matmul(p_attn, value), p_attn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

       
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x +  self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)
#在测试attention的时候需要位置编码PositionalEncoding




import copy


def clones(module, N):
 return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
# 多头注意⼒机制的处理MultiHeadedAttention  """head代表头数,embedding_dim代表词嵌入的维度"""


class MultiHeadedAttention(nn.Module):
    def __init__(self, head, embedding_dim, dropout=0.1):
        super(MultiHeadedAttention, self).__init__()
        # 常⽤的assert语句,判断h是否能被d_model整除,
        # 之后要给每个头分配等量的词特征.也就是embedding_dim/head个.
        assert embedding_dim % head == 0
        # 得到每个头获得的分割词向量维度d_k
        self.d_k = embedding_dim // head

        self.head = head

        # 在多头注意⼒中,Q,K,V各需要⼀个,最后拼接的矩阵还需要⼀个,因此⼀共是四个.
        self.linears = clones(nn.Linear(embedding_dim, embedding_dim), 4)
        # self.attn为None,得到的注意⼒张量
        self.attn = None

        self.dropout = nn.Dropout(p=dropout)
    def forward(self, query, key, value, mask=None):

        # 如果存在掩码张量mask
        if mask is not None:
            mask = mask.unsqueeze(0)
        batch_size = query.size(0)

        query, key, value = \
        [model(x).view(batch_size, -1, self.head, self.d_k).transpose(1,2) for model, x in zip(self.linears, (query, key, value))]

        x, self.attn = attention(query, key, value, mask=mask,dropout=self.dropout)

        x = x.transpose(1, 2).contiguous().view(batch_size, -1, self.head *self.d_k)

        return self.linears[-1](x)



# 词嵌⼊维度是8维
d_model = 512
# 置0⽐率为0.1
dropout = 0.1
# 句⼦最⼤⻓度
max_len=12

x = torch.zeros(1, max_len, d_model)
pe = PositionalEncoding(d_model, dropout, max_len)
                           
pe_result = pe(x)

print("pe_result:", pe_result)
query = key = value = pe_result
print("pe_result.shape:",pe_result.shape)

#没有mask的输出情况
#pe_result.shape: torch.Size([1, 12, 8])
attn, p_attn = attention(query, key, value)
print("no mask\n")
print("attn:", attn)
print("p_attn:", p_attn)

#scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 除以math.sqrt(d_k) 表示这个注意力就是 缩放点积注意力,如果没有,那么就是 点积注意力
#当Q=K=V时,又叫⾃注意⼒机制

#有mask的输出情况
print("mask\n")
mask = torch.zeros(1, max_len, max_len)
attn, p_attn = attention(query, key, value, mask=mask)
print("attn:", attn)
print("p_attn:", p_attn)

#在测试attention的时候需要位置编码PositionalEncoding



# 头数head
head = 8
# 词嵌⼊维度embedding_dim
embedding_dim = d_model


mha = MultiHeadedAttention(head, embedding_dim, dropout)
print(mha)
mha_result = mha(query, key, value, mask)
print(mha_result)

第二种写法
全部

class _MultiheadAttention(nn.Module):
    def __init__(self, hidden_size, n_heads, d_k=None, d_v=None,
                 res_attention=False, attn_dropout=0., proj_dropout=0., qkv_bias=True, lsa=False):
        """
        Multi Head Attention Layer
        Input shape:
            Q:       [batch_size (bs) x max_q_len x hidden_size]
            K, V:    [batch_size (bs) x q_len x hidden_size]
            mask:    [q_len x q_len]
        """
        super().__init__()
        d_k = hidden_size // n_heads if d_k is None else d_k
        d_v = hidden_size // n_heads if d_v is None else d_v

        self.n_heads, self.d_k, self.d_v = n_heads, d_k, d_v

        self.W_Q = nn.Linear(hidden_size, d_k * n_heads, bias=qkv_bias)
        self.W_K = nn.Linear(hidden_size, d_k * n_heads, bias=qkv_bias)
        self.W_V = nn.Linear(hidden_size, d_v * n_heads, bias=qkv_bias)

        # Scaled Dot-Product Attention (multiple heads)
        self.res_attention = res_attention
        self.sdp_attn = _ScaledDotProductAttention(hidden_size, n_heads, attn_dropout=attn_dropout,
                                                   res_attention=self.res_attention, lsa=lsa)

        # Poject output
        self.to_out = nn.Sequential(nn.Linear(n_heads * d_v, hidden_size), nn.Dropout(proj_dropout))

    def forward(self, Q:torch.Tensor, K:Optional[torch.Tensor]=None, V:Optional[torch.Tensor]=None, prev:Optional[torch.Tensor]=None,
                key_padding_mask:Optional[torch.Tensor]=None, attn_mask:Optional[torch.Tensor]=None):

        bs = Q.size(0)
        if K is None: K = Q
        if V is None: V = Q

        # Linear (+ split in multiple heads)
        q_s = self.W_Q(Q).view(bs, -1, self.n_heads, self.d_k).transpose(1,2)       # q_s    : [bs x n_heads x max_q_len x d_k]
        k_s = self.W_K(K).view(bs, -1, self.n_heads, self.d_k).permute(0,2,3,1)     # k_s    : [bs x n_heads x d_k x q_len] - transpose(1,2) + transpose(2,3)
        v_s = self.W_V(V).view(bs, -1, self.n_heads, self.d_v).transpose(1,2)       # v_s    : [bs x n_heads x q_len x d_v]

        # Apply Scaled Dot-Product Attention (multiple heads)
        if self.res_attention:
            output, attn_weights, attn_scores = self.sdp_attn(q_s, k_s, v_s,
                                                    prev=prev, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        else:
            output, attn_weights = self.sdp_attn(q_s, k_s, v_s, key_padding_mask=key_padding_mask, attn_mask=attn_mask)
        # output: [bs x n_heads x q_len x d_v], attn: [bs x n_heads x q_len x q_len], scores: [bs x n_heads x max_q_len x q_len]

        # back to the original inputs dimensions
        output = output.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.d_v) # output: [bs x q_len x n_heads * d_v]
        output = self.to_out(output)

        if self.res_attention: return output, attn_weights, attn_scores
        else: return output, attn_weights


class _ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention module (Attention is all you need by Vaswani et al., 2017) with optional residual attention from previous layer
    (Realformer: Transformer likes residual attention by He et al, 2020) and locality self sttention (Vision Transformer for Small-Size Datasets
    by Lee et al, 2021)
    """

    def __init__(self, hidden_size, n_heads, attn_dropout=0., res_attention=False, lsa=False):
        super().__init__()
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.res_attention = res_attention
        head_dim = hidden_size // n_heads
        self.scale = nn.Parameter(torch.tensor(head_dim ** -0.5), requires_grad=lsa)
        self.lsa = lsa

    def forward(self, q:torch.Tensor, k:torch.Tensor, v:torch.Tensor,
                prev:Optional[torch.Tensor]=None, key_padding_mask:Optional[torch.Tensor]=None,
                attn_mask:Optional[torch.Tensor]=None):
        '''
        Input shape:
            q               : [bs x n_heads x max_q_len x d_k]
            k               : [bs x n_heads x d_k x seq_len]
            v               : [bs x n_heads x seq_len x d_v]
            prev            : [bs x n_heads x q_len x seq_len]
            key_padding_mask: [bs x seq_len]
            attn_mask       : [1 x seq_len x seq_len]
        Output shape:
            output:  [bs x n_heads x q_len x d_v]
            attn   : [bs x n_heads x q_len x seq_len]
            scores : [bs x n_heads x q_len x seq_len]
        '''

        # Scaled MatMul (q, k) - similarity scores for all pairs of positions in an input sequence
        attn_scores = torch.matmul(q, k) * self.scale      # attn_scores : [bs x n_heads x max_q_len x q_len]

        # Add pre-softmax attention scores from the previous layer (optional)
        if prev is not None: attn_scores = attn_scores + prev

        # Attention mask (optional)
        if attn_mask is not None:                                     # attn_mask with shape [q_len x seq_len] - only used when q_len == seq_len
            if attn_mask.dtype == torch.bool:
                attn_scores.masked_fill_(attn_mask, -np.inf)
            else:
                attn_scores += attn_mask

        # Key padding mask (optional)
        if key_padding_mask is not None:                              # mask with shape [bs x q_len] (only when max_w_len == q_len)
            attn_scores.masked_fill_(key_padding_mask.unsqueeze(1).unsqueeze(2), -np.inf)

        # normalize the attention weights
        attn_weights = F.softmax(attn_scores, dim=-1)                 # attn_weights   : [bs x n_heads x max_q_len x q_len]
        attn_weights = self.attn_dropout(attn_weights)

        # compute the new values given the attention weights
        output = torch.matmul(attn_weights, v)                        # output: [bs x n_heads x max_q_len x d_v]

        if self.res_attention: return output, attn_weights, attn_scores
        else: return output, attn_weights

第三种写法
全部

import math
import torch
from torch import nn
from d2l import torch as d2l

class MultiHeadAttention(nn.Module):
    """多头注意力"""
    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 num_heads, dropout, bias=False, **kwargs):
        super(MultiHeadAttention, self).__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = d2l.DotProductAttention(dropout)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)
        self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)

    def forward(self, queries, keys, values, valid_lens):
        # queries,keys,values的形状:
        # (batch_size,查询或者“键-值”对的个数,num_hiddens)
        # valid_lens 的形状:
        # (batch_size,)或(batch_size,查询的个数)
        # 经过变换后,输出的queries,keys,values 的形状:
        # (batch_size*num_heads,查询或者“键-值”对的个数,
        # num_hiddens/num_heads)
        queries = transpose_qkv(self.W_q(queries), self.num_heads)
        keys = transpose_qkv(self.W_k(keys), self.num_heads)
        values = transpose_qkv(self.W_v(values), self.num_heads)

        if valid_lens is not None:
            # 在轴0,将第一项(标量或者矢量)复制num_heads次,
            # 然后如此复制第二项,然后诸如此类。
            valid_lens = torch.repeat_interleave(
                valid_lens, repeats=self.num_heads, dim=0)

        # output的形状:(batch_size*num_heads,查询的个数,
        # num_hiddens/num_heads)
        output = self.attention(queries, keys, values, valid_lens)

        # output_concat的形状:(batch_size,查询的个数,num_hiddens)
        output_concat = transpose_output(output, self.num_heads)
        return self.W_o(output_concat)


def transpose_qkv(X, num_heads):
    """为了多注意力头的并行计算而变换形状"""
    # 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
    # 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,
    # num_hiddens/num_heads)
    X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)

    # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    X = X.permute(0, 2, 1, 3)

    # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,
    # num_hiddens/num_heads)
    return X.reshape(-1, X.shape[2], X.shape[3])


def transpose_output(X, num_heads):
    """逆转transpose_qkv函数的操作"""
    X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
    X = X.permute(0, 2, 1, 3)
    return X.reshape(X.shape[0], X.shape[1], -1)
num_hiddens, num_heads = 100, 5
attention = MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                               num_hiddens, num_heads, 0.5)
attention.eval()

batch_size, num_queries = 2, 4
num_kvpairs, valid_lens =  6, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
Y = torch.ones((batch_size, num_kvpairs, num_hiddens))
attention(X, Y, Y, valid_lens).shape

多头(Multi-head),其实就是多做几次 Scaled Dot-product Attention,然后把结果拼接.
拼接结果然后再传给各个Decoder

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1566582.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

kettle使用MD5加密增量获取接口数据

kettle使用MD5加密增量获取接口数据 场景介绍: 使用JavaScript组件进行MD5加密得到Http header,调用API接口增量获取接口数据,使用json input组件解析数据入库 案例适用范围: MD5加密可参考、增量过程可参考、调用API接口获取…

C++(set和map详解,包含常用函数的分析)

set set是关联性容器 set的底层是在极端情况下都不会退化成单只的红黑树,也就是平衡树,本质是二叉搜索树. set的性质:set的key是不允许被修改的 使用set需要包含头文件 set<int> s;s.insert(1);s.insert(1);s.insert(1);s.insert(1);s.insert(2);s.insert(56);s.inser…

面试题:RabbitMQ 消息队列中间件

1. 确保消息不丢失 生产者确认机制 确保生产者的消息能到达队列&#xff0c;如果报错可以先记录到日志中&#xff0c;再去修复数据持久化功能 确保消息未消费前在队列中不会丢失&#xff0c;其中的交换机、队列、和消息都要做持久化消费者确认机制 由spring确认消息处理成功后…

Java基础之流程控制语句(循环)

文章目录 Java基础之流程控制语句(循环)1.顺序结构2.分支结构if语句的第一种格式if语句的第二种格式if语句的第三种格式Switch语句格式Switch的其他知识点default的位置和省略case穿透Switch的新特性 3.循环结构循环的分类for 循环while 循环for循环 与 while循环 的对比 4.do.…

RAG原理、综述与论文应用全解析

1. 背景 1.1 定义 检索增强生成 (Retrieval-Augmented Generation, RAG) 是指在利用大语言模型回答问题之前&#xff0c;先从外部知识库检索相关信息。 早在2020年就已经有人提及RAG的概念&#xff08;paper&#xff1a;Retrieval-augmented generation for knowledge-inten…

IDEA 解决 java: 找不到符号 符号: 类 __ (使用了lombok的注解)

原因IDEA版本太高&#xff0c;在 ProcessingEnvironement 预编译的时候是以代理的方式来执行的&#xff0c;不再是直接 javac方式, lombok依赖的 javac方式的 annotation processors 不再生效了 解决办法&#xff1a;下面这一句&#xff0c;加在下图中 -Djps.track.ap.depen…

八口快速以太网交换机芯片方案分享/JL5110

以太网交换机(switch)是一种网络设备&#xff0c;用于在局域网中连接多个计算机和其他网络设备。它可以实现多个端口之间的同时传输&#xff0c;并根据MAC地址进行帧过滤和转发。交换机通过自学习的方式&#xff0c;将MAC地址与相应的接口关联起来&#xff0c;以便将数据帧准确…

C语言中的数组与函数指针:深入解析与应用

文章目录 一、引言二、数组的定义1、数组的定义与初始化2、char*与char[]的区别1. 存储与表示2. 修改内容3. 作为函数参数 三、字符串指针数组1. 定义与概念2. 使用示例3. 内存管理 四、从字符串指针数组到函数指针的过渡1、字符串指针数组的应用场景2、函数指针的基本概念3、如…

【RedHat9.0】Timer定时器——创建单调定时器实例

一个timer&#xff08;定时器&#xff09;的单元类型&#xff0c;用来定时触发用户定义的操作。要使用timer的定时器&#xff0c;关键是要创建一个定时器单元文件和一个配套的服务单元文件&#xff0c;然后启动这些单元文件。 定时器类型&#xff1a; 单调定时器&#xff1a;即…

解析html内容的h标签成目录树(markdown解析出来的html)

一.本人用的markdown插件是cherry-markdown&#xff0c;个人觉得比较好用&#xff0c;画图和数学公式都整合的很好 https://github.com/Tencent/cherry-markdown 二.背景 经过markdown解析的html&#xff0c;要取里面的h标签转换成目录树&#xff0c;发现这里面都要人工计算&…

EXCEL-VB编程实现自动抓取多工作簿多工作表中的单元格数据

一、VB编程基础 1、 EXCEL文件启动宏设置 文件-选项-信任中心-信任中心设置-宏设置-启用所有宏 汇总文件保存必须以宏启动工作簿格式类型进行保存 2、 VB编程界面与入门 参考收藏 https://blog.csdn.net/O_MMMM_O/article/details/107260402?spm1001.2014.3001.5506 二、…

用Vue仿了一个类似抖音的App

大家好&#xff0c;我是 Java陈序员。 今天&#xff0c;给大家介绍一个基于 Vue3 实现的高仿抖音开源项目。 关注微信公众号&#xff1a;【Java陈序员】&#xff0c;获取开源项目分享、AI副业分享、超200本经典计算机电子书籍等。 项目介绍 douyin —— 一个基于 Vue、Vite 实…

解决uCharts图表在微信小程序层级过高问题

uniapp微信小程图表层级过高问题&#xff1f; 项目中涉及 uCharts图表&#xff0c;在 App/H5端均正常使用&#xff0c;微信小程序 存在层级问题&#xff01; 文章目录 uniapp微信小程图表层级过高问题&#xff1f;效果图遇到问题解决方案 啰嗦一下~&#xff0c;自己的粗心 在实…

Firefox 关键词高亮插件的简单实现

目录 1、配置 manifest.json 文件 2、编写侧边栏结构 3、查找关键词并高亮的方法 3-1&#xff09; 如果直接使用 innerHTML 进行替换 4、清除关键词高亮 5、页面脚本代码 6、参考 1、配置 manifest.json 文件 {"manifest_version": 2,"name": &quo…

HDLbits 刷题 --Always nolatches

学习: Your circuit has one 16-bit input, and four outputs. Build this circuit that recognizes these four scancodes and asserts the correct output. To avoid creating latches, all outputs must be assigned a value in all possible conditions (See also always…

内存管理是如何影响系统的性能的

大家好&#xff0c;今天给大家介绍内存管理是如何影响系统的性能的&#xff0c;文章末尾附有分享大家一个资料包&#xff0c;差不多150多G。里面学习内容、面经、项目都比较新也比较全&#xff01;可进群免费领取。 内存管理对系统性能的影响至关重要&#xff0c;主要体现在以下…

课程设计项目3.2:基于振动信号分析的电机轴承故障检测

01.课程设计内容 02.代码效果图 获取代码请关注MATLAB科研小白的个人公众号&#xff08;即文章下方二维码&#xff09;&#xff0c;并回复&#xff1a;MATLAB课程设计本公众号致力于解决找代码难&#xff0c;写代码怵。各位有什么急需的代码&#xff0c;欢迎后台留言~不定时更新…

Js之运算符与表达式——②

运算符&#xff1a;也叫操作符&#xff0c;是一种符号。通过运算符可以对一个或多个值进行运算&#xff0c;并获取运算结果。 表达式&#xff1a;由数字、运算符、变量的组合&#xff08;组成的式子&#xff09;。 表达式最终都会有一个运算结果&#xff0c;我们将这个结果称…

【cache】卡常

来源于《国家集训队2024论文集》中的《论现代硬件上的常数优化》 个人总结&#xff1a; 不要开二的次幂作为维度的数组&#xff0c;否则常数会变大【存疑】。总是保证内存访问连续性更高。比如 st表&#xff0c;把log维放在第二维&#xff0c;会导致内存访问距离最大为N*LOG。…

权限提升技术:攻防实战与技巧

本次活动赠书1本&#xff0c;包邮到家。参与方式&#xff1a;点赞收藏文章即可。获奖者将以私信方式告知。 网络安全已经成为当今社会非常重要的话题&#xff0c;尤其是近几年来&#xff0c;我们目睹了越来越多的网络攻击事件&#xff0c;例如公民个人信息泄露&#xff0c;企业…