大模型基础——从零实现一个Transformer(2)

news2024/12/27 17:24:17

大模型基础——从零实现一个Transformer(1)

一、引言

上一章主要实现了一下Transformer里面的BPE算法和 Embedding模块定义
本章主要讲一下 Transformer里面的位置编码以及多头注意力

二、位置编码

2.1正弦位置编码(Sinusoidal Position Encoding)

其中:

pos:表示token在文本中的位置
: i代表词向量具体的某一维度,即位置编码的每个维度对应一个波长不同的正弦或余弦波
d : d表示位置编码的最大维度,和词嵌入的维度相同,假设是512

对于位置0的编码为:

对于位置1的编码为:

2.2 正弦位置编码特性

  • 相对位置关系:pos + k的位置编码可以被位置pos的位置编码线性表示
    三角函数公式如下:

对于pos + k的位置编码:

根据式( 3 )和( 4 )整理上式有:

  • 位置之间的相对距离

𝑃𝐸𝑝𝑜𝑠+𝑘∙𝑃𝐸𝑝𝑜𝑠 的内积:

位置之间内积的关系大小如下:

可以看到内积会随着相对位置的递增而减少,从而可以表示位置的相对距离。内积的结果是对称的,所以没有方向信息。

2.3 代码实现

import torch
from torch import nn,Tensor
import math


class PositionalEmbedding(nn.Module):
    def __init__(self,d_model:int=512,dropout:float=0.1,max_positions:int=1024) -> None:
        '''

        :param d_model: embedding向量的维度
        :param dropout:
        :param max_positions: 最大长度
        '''
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Position Embedding  (max_positions,d_model)
        pe = torch.zeros(max_positions,d_model)

        # 创建position index列表 ,形状为:(max_positions, 1)
        position = torch.arange(0,max_positions).unsqueeze(1)

        # d_model 维度 偶数位是sin ,奇数位是cos
        # 计算除数,这里的除数将用于计算正弦和余弦的频率
        div_term = torch.exp(
            torch.arange(0,d_model,2) * -(math.log(10000.0) /d_model)
        )

        # 对矩阵的偶数列(0,2,4...)进行正弦函数编码
        pe[:, 0::2] = torch.sin(position * div_term)

        # 对矩阵的奇数列(1,3,5...)进行余弦函数编码
        pe[:, 1::2] = torch.cos(position * div_term)

        # 扩展维度,增加batch_size: pe (1, max_positions, d_model)
        pe = pe.unsqueeze(0)

        # buffers will not be trained
        self.register_buffer("pe", pe)

    def forward(self,x:Tensor) ->Tensor:
        """

                Args:
                    x (Tensor): (batch_size, seq_len, d_model) embeddings

                Returns:
                    Tensor: (batch_size, seq_len, d_model)
        """

        # x.size(1)是指当前x的最大长度
        x = x + self.pe[:,:x.size(1)]
        return self.dropout(x)

if __name__ == '__main__':
    seq_len = 128
    d_model = 512

    pe = PositionalEmbedding(d_model)

    x = torch.rand((1,100,d_model))
    print(pe(x).shape)

三、多头注意力

3.1 自注意力

公式如下:

  • 假设一个矩阵X,分别乘上权重矩阵,,就得到了Q , K , V向量矩阵

  • 然后除以 𝑑𝑘 进行缩放,再经过Softmax,得到注意力权重矩阵,接着乘以value向量矩阵V,就一次得到了所有单词的输出矩阵Z

3.2 多头注意力

将原来n_head分割乘Nx n_sub_head.对于每个头i,都有它自己不同的key,query和value矩阵: 𝑊𝑖𝐾,𝑊𝑖𝑄,𝑊𝑖𝑉 。在多头注意力中,key和query的维度是 𝑑𝑘 ,value嵌入的维度是 𝑑𝑣 (其中key,query和value的维度可以不同,Transformer里面一般设置的是相同的),这样每个头i,权重 𝑊𝑖𝑄∈𝑅𝑑×𝑑𝑘,𝑊𝑖𝐾∈𝑅𝑑×𝑑𝑘,𝑊𝑖𝑉∈𝑅𝑑×𝑑𝑣 ,然后与压缩到X中的输入相乘,得到 𝑄∈𝑅𝑁×𝑑𝑘,𝐾∈𝑅𝑁×𝑑𝑘,𝑉∈𝑅𝑁×𝑑𝑣 .

3.3 代码实现

import math

import torch
from torch import nn,Tensor
from typing import *

class MultiHeadAttention(nn.Module):
    def __init__(self,d_model: int = 512,n_heads: int=8,dropout: float = 0.1):
        '''

        :param d_model: embedding大小
        :param n_heads: 多头个数
        :param dropout:
        '''
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_key = d_model // n_heads

        self.q = nn.Linear(d_model,d_model)
        self.k = nn.Linear(d_model,d_model)
        self.k = nn.Linear(d_model,d_model)

        self.concat = nn.Linear(d_model,d_model)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self,x:Tensor,is_key : bool = False) -> Tensor:
        '''
        分割向量为N个头,如果是key的话,softmax时候,key需要转置一下
        :param x:
        :param is_key:
        :return:
        '''
        batch_size = x.size(0)

        # x (batch_size,seq_len,n_heads,d_key)
        x = x.view(batch_size,-1,self.n_heads,self.d_key)
        if is_key:
            # (batch_size,n_heads,d_key,seq_len)
            return x.permute(0,2,3,1)

        # (batch_size,n_heads,seq_len,d_key
        return x.transpose(1,2)

    def merge_heads(self,x: Tensor) -> Tensor:
        x = x.transpose(1,2).contigouse().view(x.size(0),-1,self.d_model)
        return x

    def attention(self,
                  query:Tensor,
                  key:Tensor,
                  value:Tensor,
                  mask:Tensor = None,
                  keep_attentions:bool = False):

        scores = torch.matmul(query,key) / math.sqrt(self.d_key)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)

        # weights (batch_size,n_heads,q_length,k_length)
        weights = self.dropout(torch.softmax(scores,dim=-1))

        # (batch_size,n_heads,q_length,k_length) x (batch_size,n_heads,v_length,d_key)
        # -> (batch_size,n_heads,q_length,d_key)
        # assert k_length == v_length

        # attn_output (batch_size, n_heads, q_length, d_key)
        atten_output = torch.matmul(weights,value)

        if keep_attentions:
            self.weights = weights
        else:
            del weights

        return atten_output

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                mask: Tensor = None,
                keep_attentions: bool = False)-> Tuple[Tensor,Tensor]:
        '''

        :param query:(batch_size, q_length, d_model)
        :param key:(batch_size, k_length, d_model)
        :param value:(batch_size, v_length, d_model)
        :param mask: mask for padding or decoder. Defaults to None.
        :param keep_attentions: whether keep attention weigths or not. Defaults to False.
        :return: (batch_size, q_length, d_model) attention output
        '''
        query = self.q(query)
        key = self.k(key)
        value = self.v(value)

        query,key,value = (
            self.split_heads(query),
            self.split_heads(key,is_key=True),
            self.split_heads(value)
        )

        atten_output = self.attention(query,key,value,mask,keep_attentions)

        del query
        del key
        del value

        # concat
        concat_output = self.merge_heads(atten_output)

        # the final liear
        # output (batch_size, q_length, d_model)
        output = self.concat(concat_output)
        
        return output

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1809495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

linux中xterm窗口怎么调整字体大小

需求:打开的xterm窗口字体比较小,怎么才能调整字体大小,打开的大写: 解决方法: 在home目录下搞一个设置文件 .Xresource,里面内容如下 然后把设置文件添加到 .tcshrc 文件中生效 这样重新打开的xterm字…

MySQL数据库(二)和java复习

一.MySQL数据库学习(二) (一).DQL查询数据 DQL(Data Query Language)是用于从数据库中检索数据的语言。常见的 DQL 语句包括 SELECT、FROM、WHERE、GROUP BY、HAVING 和 ORDER BY 等关键字,用于指定要检索的数据、数据源、过滤条件、分组方…

《编程小白变大神:DjangoBlog带你飞越代码海洋》

还在为你的博客加载速度慢而烦恼?DjangoBlog性能优化大揭秘,让你的网站速度飞跃提升!本文将带你深入了解缓存策略、数据库优化、静态文件处理等关键技术,更有Gunicorn和Nginx的黄金搭档,让你的博客部署如虎添翼。无论你…

助力高考,一组彩色的文字

1、获取文本内容 首先&#xff0c;获取每个<div>元素的文本内容&#xff0c;并清空其内部HTML&#xff08;innerHTML ""&#xff09;。 2、创建<span>元素 然后&#xff0c;它遍历文本的每个字符&#xff0c;为每个字符创建一个新的<span>元素…

《python程序语言设计》2018版第5章第36题改造4.17 石头 剪刀 布某一方超过2次就结束。

代码编写记录 2024.05.04 05.36.01version 换一个什么数代替剪子 我先建立一个函数judgement condition 石头3 剪子2 布1 如何构建一个循环进行的架构&#xff0c;是我们最需要的想法 循环以什么条件开始呢 是小于2个还是大于2个。 guess_num random.randint(1, 3) computer…

nginx优化与防盗链【☆☆☆】

目录 一、用户层面的优化 1、隐藏版本号 方法一&#xff1a;修改配置文件 方法二&#xff1a;修改源码文件&#xff0c;重新编译安装 2、修改nginx用户与组 3、配置nginx网页缓存时间 4、nginx的日志切割 5、配置nginx实现连接超时 6、更改nginx运行进程数 7、开启网…

IPv4 子网掩码计算器—python代码实现

今天聊一下&#xff0c;我用python和vscode工具实现一个IPv4计算器的一些思路&#xff0c;以及使用Python编写IPv4计算器一些好处&#xff1f; 首先&#xff0c;一、Python语法简洁易读&#xff0c;便于理解和维护&#xff0c;即使对编程不熟悉的用户也能快速了解代码逻辑。其…

阿里通义千问 Qwen2 大模型开源发布

阿里通义千问 Qwen2 大模型开源发布 Qwen2 系列模型是 Qwen1.5 系列模型的重大升级。该系列包括了五个不同尺寸的预训练和指令微调模型&#xff1a;Qwen2-0.5B、Qwen2-1.5B、Qwen2-7B、Qwen2-57B-A14B 以及 Qwen2-72B。 在中文和英文的基础上&#xff0c;Qwen2 系列的训练数…

已解决Error || RuntimeError: size mismatch, m1: [32 x 100], m2: [500 x 10]

已解决Error || RuntimeError: size mismatch, m1: [32 x 100], m2: [500 x 10] 原创作者&#xff1a; 猫头虎 作者微信号&#xff1a; Libin9iOak 作者公众号&#xff1a; 猫头虎技术团队 更新日期&#xff1a; 2024年6月6日 博主猫头虎的技术世界 &#x1f31f; 欢迎来…

情景题之小明的Linux实习之旅:linux实战练习1(下)【基础命令,权限修改,日志查询,进程管理...】

小明的Linux实习之旅&#xff1a;基础指令练习情景练习题下 前景提要小明是怎么做的场景1&#xff1a;初识Linux&#xff0c;创建目录和文件场景2&#xff1a;权限管理&#xff0c;小明的权限困惑场景3&#xff1a;打包与解压&#xff0c;小明的备份操作场景4&#xff1a;使用G…

分享一个 .NET Core Console 项目中应用 NLog 写日志的详细例子

前言 日志在软件开发中扮演着非常重要的角色&#xff0c;通常我们用它来记录应用程序运行时发生的事件、错误信息、警告以及其他相关信息&#xff0c;帮助在调试和排查问题时更快速地定位和解决 Bug。 通过日志&#xff0c;我们可以做到&#xff1a; 故障排除和调试&#xff…

让GNSSRTK不再难【第一天】

第1讲 GNSS系统组成以及应用 北斗导航科普动画_哔哩哔哩_bilibili 1.1 GNSS系统 1.1.1 基本概念 全球卫星导航系统&#xff08;Global Navigation Satellite System, GNSS&#xff09;&#xff0c;是能在地球表面或近地空间的任何地点为用户提供全天候的三维坐标、速度以及…

ISO 19115-2:2019 第6章 获取和处理元数据

6 获取和处理元数据 6.1 获取和处理要求的元数据 ISO 19115-1 确定了描述数字地理资源所需的元数据。本文件扩展了 ISO 19115-1 中确定的元数据,并确定了描述地理资源获取和处理所需的附加元数据。 6.2 获取和处理元数据包及其依赖关系 ISO 地理信息系列标准使用一个或多个…

【C语言】C语言—通讯录管理系统(源码)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;公众号&#x1f448;&#xff1a;测试开发自动化【获取源码商业合作】 &#x1f449;荣__誉&#x1f448;&#xff1a;阿里云博客专家博主、5…

算法笔记1-高精度模板(加减乘除)个人模板

目录 加法 减法 乘法 ​编辑 除法 加法 #include <iostream> #include <cstring> #include <algorithm> #include <cmath> #include <queue>using namespace std;typedef pair<int,int> PII;const int N 1e5 10;int n; int a[N],…

nw.js 如何调用activeX控件 (控件是C++编写的dll文件)

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

《永生之后》读后

文章以2120年背景创作&#xff0c;人类进入永生之年&#xff0c;发现了延长寿命的药物。停滞的死亡&#xff0c;新生的继续造生了人口大爆炸&#xff0c;于是分成两个阵营-长生区&#xff08;不再繁衍后代&#xff09;与生死区&#xff08;不服用药物&#xff0c;仍然生老病死&…

665. 非递减数列(中等)

665. 非递减数列 1. 题目描述2.详细题解3.代码实现3.1 Python3.2 Java 1. 题目描述 题目中转&#xff1a;665. 非递减数列 2.详细题解 判断在最多改变 1 个元素的情况下&#xff0c;该数组能否变成一个非递减数列&#xff0c;一看到题目&#xff0c;不就是遍历判断有几处不…

【全网最简单的解决办法】vscode中点击运行出现仅当从 VS 开发人员命令提示符处运行 VS Code 时,cl.exe 生成和调试才可用

首先确保你是否下载好了gcc编译器&#xff01;&#xff01;&#xff01; 检测方法&#xff1a; winR 打开cmd命令窗 输入where gcc(如果出现路径则说明gcc配置好啦&#xff01;) where gcc 然后打开我们的vscode 把这个文件删除掉 再次点击运行代码&#xff0c;第一个出现…