【模型架构】学习RNN、LSTM、TextCNN和Transformer以及PyTorch代码实现

news2025/1/16 17:05:06

一、前言

在自然语言处理(NLP)领域,模型架构的不断发展极大地推动了技术的进步。从早期的循环神经网络(RNN)到长短期记忆网络(LSTM)、Transformer再到当下火热的Mamba(放在下一节),每一种架构都带来了不同的突破和应用。本文将详细介绍这些经典的模型架构及其在PyTorch中的实现,由于我只是门外汉(想扩展一下知识面),如果有理解不到位的地方欢迎评论指正~。

个人感觉NLP的任务本质上是一个序列到序列的过程,给定输入序列 X=\left\{ x_1,x_2,x_3,...,x_n \right\},要通过一个函数实现映射,得到输出序列Y=\left\{y_1,y_2,y_3,...,y_n \right\},这里的x1、x2、x3可以理解为一个个单词,NLP的具体应用有:

  • 机器翻译:将源语言的句子(序列)翻译成目标语言的句子(序列)。

  • 文本生成:根据输入序列生成相关的输出文本,如文章生成、对话生成等。

  • 语音识别:将语音信号(序列)转换为文本(序列)。

  • 文本分类:尽管最终输出是一个类别标签,但在一些高级应用中,也可以将其看作是将文本序列映射到某个特定的输出序列(如标签序列)。

二、RNN和LSTM

2.1 RNN

循环神经网络(RNN)是一种适合处理序列数据的神经网络架构。与传统的前馈神经网络(线性层)不同,RNN具有循环连接,能够在序列数据的处理过程中保留和利用之前的状态信息。网络结构如下所示:

RNN的网络结构

x和隐藏状态h的计算过程

RNN通过在网络中引入循环连接,将前一个时间步的输出作为当前时间步的输入之一,使得网络能够记住以前的状态。具体来说,RNN的每个时间步都会接收当前输入和前一个时间步的隐藏状态,并输出新的隐藏状态。其核心公式为:

\begin{aligned}&h_{t}=\sigma(W_{hx}x_t+W_{hh}h_{t-1}+b_h)\\&y_{t}=\phi(W_{hy}h_t+b_y)\end{aligned}

  • 𝑥𝑡 是当前时间步的输入。

  • ℎ𝑡 是当前时间步的隐藏状态。

  • ℎ𝑡−1 是前一个时间步的隐藏状态(如果是第一个输入,这项是0)。

  • 𝑦𝑡 是当前时间步的输出。

  • 𝑊ℎ𝑥𝑊ℎℎ𝑊ℎ𝑦 都是权重矩阵,是可以共享参数的。

  • 𝑏ℎ 𝑏𝑦 是偏置。

  • 𝜎𝜙 是激活函数。

2.1.1 RNN的优点

  • 处理序列数据:RNN可以处理任意长度的序列数据,并能够记住序列中的上下文信息。

  • 参数共享:RNN在不同时间步之间共享参数,使得模型在处理不同长度的序列时更加高效。

2.1.2 RNN的缺点

  • 梯度消失和爆炸:在训练过程中,RNN会遇到梯度消失和梯度爆炸的问题,导致模型难以训练或收敛缓慢。

  • 长距离依赖问题:RNN在处理长序列数据时,容易遗忘较早的上下文信息,难以捕捉长距离依赖关系。

  • 不能并行训练:每个时间步的计算需要依赖于前一个时间步的结果,这导致RNN的计算不能完全并行化,必须按顺序进行。这种顺序性限制了RNN的训练速度,但是推理不受影响(尽管推理过程同样受到顺序性依赖的限制,但相比训练过程,推理的计算量相对较小,因为推理时不需要进行反向传播和梯度计算。推理过程主要集中在前向传播,即根据输入数据通过模型得到输出。因此,推理过程中的计算相对较快,尽管不能并行化,但在许多实际应用中仍然可以达到实时或接近实时的性能)。

关于长距离依赖问题的理解:

在RNN中,每个时间步的信息会被传递到下一个时间步。然而,随着序列长度的增加,早期时间步的信息需要通过许多次的传递才能影响到后续时间步。每次传递过程中,信息可能会逐渐衰减。这种逐步衰减导致RNN在处理长序列数据时,早期时间步的信息可能被遗忘或淹没在新信息中。

同时,在训练RNN时,通过时间反向传播算法(Backpropagation Through Time, BPTT)来更新参数。这一过程中,梯度会从输出层反向传播到输入层。然而,长序列中的梯度在多次反向传播时,容易出现梯度消失(梯度逐渐变小,趋近于零)或梯度爆炸(梯度过大,导致数值不稳定)的现象。梯度消失会导致模型难以学习和记住长距离依赖的信息,梯度爆炸则会导致模型参数更新不稳定。

2.1.3 代码实现

以下的实现都是基于文本分类任务进行的:

import torch
import torch.nn as nn

class TextRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, num_classes):
        super(TextRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.embedding(x)

        rnn_out, hidden = self.rnn(x)

        x = self.dropout(rnn_out[:, -1, :])
        x = self.fc(x)
        return x

如果不用torch自带RNN的api的话,下面是自搭版本:

import torch
import torch.nn as nn

class CustomRNNLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CustomRNNLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.i2h = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.h2o = nn.Linear(hidden_dim, hidden_dim)
        self.tanh = nn.Tanh()

    def forward(self, input, hidden):
        combined = torch.cat((input, hidden), 1)
        hidden = self.tanh(self.i2h(combined))
        output = self.h2o(hidden)
        return output, hidden

class TextRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, num_classes):
        super(TextRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.rnn1 = CustomRNNLayer(embedding_dim, hidden_dim)
        self.rnn2 = CustomRNNLayer(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)

        batch_size, seq_len, _ = x.shape
        hidden1 = torch.zeros(batch_size, self.hidden_dim).to(x.device)
        hidden2 = torch.zeros(batch_size, self.hidden_dim).to(x.device)

        for t in range(seq_len):
            input_t = x[:, t, :]
            hidden1, _ = self.rnn1(input_t, hidden1)
            hidden2, _ = self.rnn2(hidden1, hidden2)

        x = self.dropout(hidden2)
        x = self.fc(x)
        return x

初始化 hidden1 和 hidden2 为零张量,表示第一个和第二个RNN层的初始隐藏状态,遍历序列长度 seq_len 的每个时间步,将当前时间步的输入向量 input_t 输入到第一个RNN层,更新 hidden1;再将 hidden1 输入到第二个RNN层,更新 hidden2。

特别解释一下,input_t = x[:, t, :] 是提取当前时间步 t 的输入向量,原本的x是(batch_size, seq_len, embedding_dim),处理后是(batch_size, embedding_dim)。

2.2 LSTM

长短期记忆网络(LSTM)是一种特殊的循环神经网络(RNN)架构,旨在解决传统RNN在处理长序列数据时的梯度消失和梯度爆炸问题。LSTM通过引入记忆单元(cell state)和门控机制(gate mechanism),能够更好地捕捉和保留长距离依赖关系。

LSTM的基本结构包括一个记忆单元和三个门:输入门、遗忘门和输出门。这些门用于控制信息在LSTM单元中的流动。LSTM的工作原理可以用以下步骤描述:

  • 遗忘门(Forget Gate):决定记忆单元中的哪些信息需要被遗忘。

  • 输入门(Input Gate):决定哪些新信息需要被存储到记忆单元中。

  • 输出门(Output Gate):决定记忆单元中的哪些信息需要输出。

LSTM的网络结构,可以看到和RNN相似,但是用到门控和记忆机制

LSTM在每个时间步的计算可以分为以下4个阶段,也对应了上图出现的顺序:

遗忘门的计算:

f_t=\sigma(W_f\cdot[h_{t-1},x_t]+b_f)

遗忘门 ft决定了前一个时间步的记忆单元状态C t-1 中哪些信息需要被遗忘。 σ是 sigmoid 激活函数(输出限制在 [0, 1] 之间,0就代表了遗忘,不许任何量通过,1就指允许任意量通过,从而使得网络就能了解哪些数据是需要遗忘,哪些数据是需要保存), wf是遗忘门的权重矩阵,bf是偏置。 [h_{t-1},x_t] 这是一个concat连接操作。

输入门的计算:

\begin{aligned}&i_{t}=\sigma(W_{i}\cdot[h_{t-1},x_{t}]+b_{i})\\&\tilde{C}_{t}=\tanh(W_{C}\cdot[h_{t-1},x_{t}]+b_{C})\end{aligned}

输入门 it决定了当前输入xt中哪些信息需要被添加到记忆单元中, Ct是新的候选记忆, Wi和Wc 分别是输入门和候选记忆的权重矩阵,bi和bc 是偏置。

tanh激活函数的范围是-1~1,它对新信息进行变换,使得新信息能够取正值和负值。这样可以更灵活地调整记忆单元状态,从而保留和更新信息

更新记忆单元状态:

C_t=f_t*C_{t-1}+i_t*\tilde{C}_t

记忆单元状态Ct通过遗忘门和输入门的输出进行更新,融合了前一个时间步的记忆和当前时间步的新信息。

输出门的计算:

\begin{aligned}&o_{t}=\sigma(W_o\cdot[h_{t-1},x_t]+b_o)\\&h_{t}=o_t*\tanh(C_t)\end{aligned}

输出门 ot 决定了记忆单元中哪些信息需要输出,最终的隐藏状态 ht 通过记忆单元状态 Ct​ 以及输出门的控制生成。

单个计算模块的细节

2.2.1 LSTM的优点

  • 解决长距离依赖问题:LSTM通过引入记忆单元(cell state)和门控机制(遗忘门、输入门和输出门),有效地解决了传统RNN的长距离依赖问题。它能够记住长时间跨度内的重要信息,避免了信息在多次传递逐渐衰减。

  • 缓解梯度消失和梯度爆炸问题:在传统RNN中,梯度消失和梯度爆炸是常见的问题,特别是在处理长序列时。LSTM通过其门控机制,能够更稳定地传递梯度,减少了梯度消失和爆炸的发生,从而提高了训练效果。

  • 灵活的记忆更新:LSTM的记忆单元和门控机制使得网络能够有选择性地记住和遗忘信息。这种灵活性使得LSTM在处理复杂的时间序列数据时表现出色,能够捕捉到数据中的重要模式和特征。

2.2.2 LSTM的缺点

  • 计算复杂度高:相较于简单的RNN,LSTM的结构更复杂,包含更多的参数(如多个门和记忆单元)。这种复杂性增加了计算成本,导致训练和推理速度较慢。

  • 难以并行化:LSTM的顺序计算特性限制了其并行化的能力。在处理长序列时,每个时间步的计算依赖于前一个时间步的结果,这使得计算不能完全并行化,从而影响训练和推理的效率。

  • 对长序列仍有局限:尽管LSTM在处理长距离依赖问题上比传统RNN有显著改善,但在极长序列的情况下,仍可能遇到信息遗忘和梯度衰减的问题。

2.2.3 代码实现

import torch
import torch.nn as nn

class TextLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, num_classes):
        super(TextLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)

        batch_size, seq_len, _ = x.shape
        h_0 = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).to(x.device)
        c_0 = torch.zeros(self.lstm.num_layers, batch_size, self.lstm.hidden_size).to(x.device)

        x, (h_n, c_n) = self.lstm(x, (h_0, c_0))
        x = self.dropout(h_n[-1])

        x = self.fc(x)
        return x

自搭版本:

import torch
import torch.nn as nn

class CustomLSTMLayer(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(CustomLSTMLayer, self).__init__()
        self.hidden_dim = hidden_dim
        self.i2f = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.i2i = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.i2c = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.i2o = nn.Linear(input_dim + hidden_dim, hidden_dim)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, hidden, cell):
        combined = torch.cat((input, hidden), 1)
        f_t = self.sigmoid(self.i2f(combined))
        i_t = self.sigmoid(self.i2i(combined))
        c_tilde_t = self.tanh(self.i2c(combined))
        c_t = f_t * cell + i_t * c_tilde_t
        o_t = self.sigmoid(self.i2o(combined))
        h_t = o_t * self.tanh(c_t)
        return h_t, c_t


class TextLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout, num_classes):
        super(TextLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        self.lstm1 = CustomLSTMLayer(embedding_dim, hidden_dim)
        self.lstm2 = CustomLSTMLayer(hidden_dim, hidden_dim)

        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = self.embedding(x)

        batch_size, seq_len, _ = x.shape
        hidden1 = torch.zeros(batch_size, self.hidden_dim).to(x.device)
        cell1 = torch.zeros(batch_size, self.hidden_dim).to(x.device)
        hidden2 = torch.zeros(batch_size, self.hidden_dim).to(x.device)
        cell2 = torch.zeros(batch_size, self.hidden_dim).to(x.device)

        for t in range(seq_len):
            input_t = x[:, t, :]
            hidden1, cell1 = self.lstm1(input_t, hidden1, cell1)
            hidden2, cell2 = self.lstm2(hidden1, hidden2, cell2)

        x = self.dropout(hidden2)
        x = self.fc(x)
        return x

三、TextCNN

TextCNN(文本卷积神经网络)是一种应用于自然语言处理(NLP)任务的卷积神经网络(CNN)模型。

TextCNN的基本结构包括以下几个部分:

  • 嵌入层(Embedding Layer):将输入的文本序列转换为稠密的词向量表示。这些词向量可以是预训练的词向量(如Word2Vec、GloVe)或在训练过程中学习到的嵌入。

  • 卷积层(Convolutional Layer):对嵌入后的词向量序列应用卷积操作,提取不同大小的n-gram特征。卷积核的大小可以设定为不同的窗口大小(如2, 3, 4等),以捕捉不同范围的局部特征。

  • 池化层(Pooling Layer):对卷积后的特征图应用最大池化操作,将每个卷积核的输出缩减为一个固定大小的特征向量。这一步有助于提取最重要的特征,并减少特征的维度。

  • 全连接层(Fully Connected Layer):将池化后的特征向量连接成一个长的特征向量,输入到全连接层中进行分类。最后一层通常是一个Softmax层,用于输出分类结果。

具体流程如下:

  • 输入文本:输入一个文本序列,假设长度为n,每个词通过词汇表索引转换为词向量表示,形成一个形状为(n,d)的矩阵,其中 d 是词向量的维度。

  • 卷积操作:使用不同大小的卷积核(如2, 3, 4)对嵌入矩阵进行卷积操作,提取不同n-gram的局部特征。卷积后的特征图形状为(n-k+1, m),其中 k 是卷积核的大小,m 是卷积核的数量。

  • 最大池化:对每个卷积核的输出特征图进行最大池化操作,提取重要的特征。池化后的特征向量形状为 (m, )。

  • 特征拼接:将不同卷积核和池化操作得到的特征向量拼接成一个长的特征向量,输入到全连接层。

  • 分类输出:最后通过全连接层和Softmax层进行分类,输出各类别的概率。

TextCNN的网络结构

3.1 TextCNN的优点

  • 高效提取局部特征:卷积操作能够有效提取不同n-gram范围内的局部特征,对于捕捉文本的局部模式非常有效。

  • 并行计算:卷积操作和池化操作可以并行计算,相对于RNN等顺序模型,训练和推理速度更快。

3.2 TextCNN的缺点

  • 缺乏长距离依赖:由于卷积操作的感受野有限,TextCNN在捕捉长距离依赖方面不如LSTM等序列模型表现好。

  • 固定大小的卷积核:虽然可以通过多种卷积核来捕捉不同的n-gram特征,但仍然受限于卷积核的大小,对于变长依赖的建模能力有限。

3.3 权值共享

权值共享是指在卷积神经网络的卷积操作中,同一卷积核(filter)的参数在整个输入图像或特征图上被重复使用。这意味着,对于一个卷积层中的每一个卷积核,其参数在整个输入图像的不同位置上是相同的。

  • 降低参数:在传统的全连接层中,每个神经元都有自己的权重参数,输入维度较大时会导致参数数量庞大。而在卷积层中,由于使用了权值共享,一个卷积核的参数数量固定,独立于输入图像的大小。这大大减少了模型的参数数量。

  • 提升训练效率:由于参数数量减少,权值共享使得模型训练变得更加高效。需要学习的参数变少了,相应的训练时间也减少了。

  • 空间不变性(Translation Invariance):权值共享意味着卷积核在输入图像的不同位置都使用相同的参数,这使得卷积神经网络可以识别图像中的特征,不管这些特征出现在图像的哪个位置。这样,模型可以更好地处理位移和变形,提高对输入图像的泛化能力。

3.4 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_filters, kernel_sizes, dropout, num_classes):
        super(TextCNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.conv1 = nn.Conv2d(1, num_filters, (kernel_sizes[0], embedding_dim))
        self.conv2 = nn.Conv2d(1, num_filters, (kernel_sizes[1], embedding_dim))
        self.conv3 = nn.Conv2d(1, num_filters, (kernel_sizes[2], embedding_dim))
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(num_filters * len(kernel_sizes), num_classes)

    def forward(self, x):
        x = self.embedding(x)
        x = x.unsqueeze(1)  # 增加通道维度方便卷积处理

        conv1_out = F.relu(self.conv1(x)).squeeze(3)
        pooled1 = F.max_pool1d(conv1_out, conv1_out.size(2)).squeeze(2)

        conv2_out = F.relu(self.conv2(x)).squeeze(3)
        pooled2 = F.max_pool1d(conv2_out, conv2_out.size(2)).squeeze(2)

        conv3_out = F.relu(self.conv3(x)).squeeze(3)
        pooled3 = F.max_pool1d(conv3_out, conv3_out.size(2)).squeeze(2)

        x = torch.cat((pooled1, pooled2, pooled3), 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x

四、Transformer

Transformer是老熟人了,是目前主流的网络架构,当然它最早还是起源于NLP领域。

Transformer模型主要由两个部分组成:编码器(Encoder)和解码器(Decoder)。编码器和解码器各自由多个相同的层(layer)堆叠而成,每一层包含两个主要子层(sublayer):

  • 编码器(Encoder):由多个相同的编码器层堆叠组成,每个编码器层包含一个自注意力子层和一个前馈神经网络子层。

  • 解码器(Decoder):由多个相同的解码器层堆叠组成,每个解码器层包含一个自注意力子层、一个编码器-解码器注意力子层和一个前馈神经网络子层。

4.1 自注意力机制(Self-Attention)

自注意力机制是Transformer的核心组件,用于计算输入序列中每个位置的表示。具体来说,自注意力机制通过计算输入序列中每个位置与其他所有位置的相似度来捕捉全局依赖关系。计算公式如下:

\text{Attention}(Q,K,V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

其中:

  • Q(Query)是查询矩阵。

  • K(Key)是键矩阵。

  • V(Value)是值矩阵。

  • dk​ 是键向量的维度。

  • 其实,QKV都是来自一个x经过不同的权重矩阵计算得到的。

4.2 多头注意力机制(Multi-Head Attention)

为了进一步提升模型的表达能力,Transformer采用了多头注意力机制。多头注意力通过对输入进行多个独立的自注意力计算(称为头),并将结果拼接在一起,通过线性变换生成最终的输出。公式如下:

\mathrm{MultiHead}(Q,K,V)=\mathrm{Concat}(\mathrm{head}_1,\mathrm{head}_2,\ldots,\mathrm{head}_h)W^O

其中每个头的计算为:

\text{head}_i=\text{Attention}(QW_i^Q,KW_i^K,VW_i^V)

4.3 前馈神经网络(Feed-Forward Neural Network)

每个编码器和解码器层还包含一个前馈神经网络子层。这个子层包含两个线性变换和一个激活函数(通常是ReLU):

\mathrm{FFN}(x)=\max(0,xW_1+b_1)W_2+b_2

4.4 整体流程

Transformer网络框架

Transformer模型通过嵌入层和位置编码将输入序列转化为稠密向量表示,然后经过编码器和解码器的多层处理,捕捉序列中的全局依赖关系。

编码器通过多头自注意力机制和前馈神经网络提取输入序列的特征,解码器通过掩码多头自注意力机制(遮住了遮盖掉未来的时间步,防止解码器在生成当前时间步的输出时看到未来的单词,确保自回归性质。)、编码器-解码器注意力机制和前馈神经网络生成输出序列。最后通过线性层和Softmax层生成输出单词的概率分布。加法和归一化(Add & Norm,其实就是残差和LayerNorm)操作在每个子层后确保梯度稳定,帮助训练更深的网络。

在Transformer模型的解码器部分,"outputs (shifted right)" 指的是在解码过程中,模型使用已经生成的输出单词作为当前时间步的输入,同时将这些输出单词整体向右偏移一个位置,以确保模型生成下一个单词时只能依赖之前生成的单词,而不是未来的单词。

假设要生成一个法语句子 "Je suis étudiant"。具体步骤如下:

编码器处理

  1. 编码器接收英语句子 "I am a student"。

  2. 编码器生成全局上下文表示,提供给解码器。

解码器生成

  1. 解码器首先接收起始标记 <sos> 作为输入(这里就体现了右移,因为第一个单词变成了一个特定的符号),生成第一个单词 "Je"。

  2. 在生成 "Je" 后,将 "Je" 作为下一个时间步的输入。解码器现在的输入是 <sos> Je,它只能看到 "Je" 之前的内容。

  3. 解码器生成第二个单词 "suis"。接下来,解码器的输入是 <sos> Je suis。

  4. 这一过程不断重复,解码器在每个时间步只能看到之前生成的单词,而不能看到未来的单词。

多头注意力机制

将查询(Q)、键(K)和值(V)通过多个线性变换,拆分成多个组(头),每个头独立计算注意力分数和加权求和值。最后,所有头的输出拼接在一起,通过一个线性变换恢复到原来的维度。这种设计允许模型在不同的子空间中关注不同部分的信息,从而提高模型的表达能力和捕捉复杂模式的能力。

多头注意力机制示意图

4.5 Transformer的优点

  • 并行化计算:由于不依赖于前一个时间步的计算结果,Transformer可以并行处理整个序列,这显著提高了训练和推理的速度。

  • 捕捉全局依赖:自注意力机制能够捕捉序列中任意两个位置之间的依赖关系(具体体现在是矩阵运算),特别适合长序列的处理。

  • 扩展性强:Transformer架构具有很强的扩展性,可以通过增加层数和头数来提高模型的表达能力。

4.6 Transformer的缺点

  • 计算资源消耗大:自注意力机制的计算复杂度为 𝑂(𝑛2⋅𝑑),其中n是序列长度,d是模型的维度。这使得Transformer在处理非常长的序列时计算资源消耗较大。

  • 需要大量数据:Transformer模型通常需要大量的数据来进行训练,以充分发挥其性能优势。这在数据稀缺的任务中可能成为一个限制因素。主要是在ViT那篇论文中提到了,Transformer结构缺少一些CNN本身设计的归纳偏置(其实就是卷积结构带来的先验经验),比如平移不变性和包含局部关系,因此在规模不足的数据集上表现没有那么好。所以,卷积结构其实是一种trick,而transformer结构是没有这种trick的,就需要更多的数据来让它学习这种结构。

4.7 Pytorch代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class Transformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers,
                 dim_feedforward, dropout=0.1):
        super(Transformer, self).__init__()
        self.d_model = d_model
        # 定义源语言和目标语言的嵌入层
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # 位置编码层
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        # Transformer模型
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward,
                                          dropout)
        # 输出层
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask):
        # 对源输入进行嵌入和位置编码
        src = self.src_embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        src = self.pos_encoder(src)
        # 对目标输入进行嵌入和位置编码
        tgt = self.tgt_embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        tgt = self.pos_encoder(tgt)
        # 编码器
        memory = self.transformer.encoder(src, mask=src_mask, src_key_padding_mask=src_padding_mask)
        # 解码器
        output = self.transformer.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=None,
                                          tgt_key_padding_mask=tgt_padding_mask,
                                          memory_key_padding_mask=memory_key_padding_mask)
        # 输出层
        output = self.fc_out(output)
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # 初始化位置编码矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

def generate_square_subsequent_mask(sz):
    # 生成一个上三角矩阵,防止解码器看到未来的token
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

def create_padding_mask(seq):
    # 创建填充mask,忽略填充部分
    seq = seq == 0
    return seq


# 使用示例
src_vocab_size = 10000
tgt_vocab_size = 10000
d_model = 512
nhead = 8
num_encoder_layers = 6
num_decoder_layers = 6
dim_feedforward = 2048
dropout = 0.1

model = Transformer(src_vocab_size, tgt_vocab_size, d_model, nhead, num_encoder_layers, num_decoder_layers,
                    dim_feedforward, dropout)

src = torch.randint(0, src_vocab_size, (10, 32))  # (源序列长度, 批次大小)
tgt = torch.randint(0, tgt_vocab_size, (20, 32))  # (目标序列长度, 批次大小)
src_mask = generate_square_subsequent_mask(src.size(0))
tgt_mask = generate_square_subsequent_mask(tgt.size(0))  # 生成shifted mask,防止解码器看到未来的token
src_padding_mask = create_padding_mask(src).transpose(0, 1)  # 调整mask形状为 (批次大小, 源序列长度)
tgt_padding_mask = create_padding_mask(tgt).transpose(0, 1)  # 调整mask形状为 (批次大小, 目标序列长度)
memory_key_padding_mask = src_padding_mask

output = model(src, tgt, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
print(output.shape)  # 应该是 (目标序列长度, 批次大小, 目标词汇表大小)

generate_square_subsequent_mask 函数

  • torch.ones(sz, sz):生成一个全是1的矩阵,形状为 (sz, sz)。

  • torch.triu():将矩阵的下三角部分置为0,上三角部分保持为1。torch.triu(torch.ones(sz, sz)) 生成一个上三角矩阵。

  • transpose(0, 1):对矩阵进行转置,使其符合注意力机制的输入格式。

  • mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)):将上三角矩阵中0的位置填充为负无穷,1的位置填充为0。

create_padding_mask 函数

用于生成一个填充mask,标记序列中的填充部分。具体来说,这个mask会告诉模型哪些位置是填充值(通常是0),模型在计算注意力时会忽略这些填充值,从而只关注有效的输入。

在自然语言处理任务中,输入序列通常具有不同的长度。为了使所有输入序列具有相同的长度,通常会在较短的序列末尾添加填充值(通常为0)。但是,这些填充值在计算注意力时是不应该被考虑的,因为它们不包含实际信息。因此,需要一个mask来标记这些填充值的位置,使模型在计算注意力时忽略它们。

Input: 
序列1: [5, 7, 2, 0, 0] 序列2: [1, 3, 0, 0, 0] 
Output: 
tensor([[False, False, False, True, True], [False, False, True, True, True]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1720302.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Linux|虚拟机|Windows 11 家庭版的Hyper虚拟机服务开启

前言&#xff1a; Windows11的版本是比较多的&#xff0c;但有的时候笔记本预装的可能是家庭版&#xff0c;而家庭版的Windows通常是不支持虚拟机的&#xff0c;也就是说Hyper服务根本就看不到 Windows的程序和功能大体如下&#xff1a; &#x1f197;&#xff0c;那么如何开…

ChaosBlade混沌测试实践

ChaosBlade: 一个简单易用且功能强大的混沌实验实施工具 官方仓库&#xff1a;https://github.com/chaosblade-io/chaosblade 1. 项目介绍 ChaosBlade 是阿里巴巴开源的一款遵循混沌工程原理和混沌实验模型的实验注入工具&#xff0c;帮助企业提升分布式系统的容错能力&…

Nuxt3项目实现 OG:Image

目录 前言 1、安装 2、设置网站 URL 3、启用 Nuxt DevTools 4、创建您的第一个Og:Image a. 定义OG镜像 b. 查看您的Og:Image 5、自定义NuxtSeo模板 a. 定义 NuxtSeo模板 b. 使用其他可用的社区模板 6、创建自己的模板 a. 定义组件 BlogPost.vue b. 使用新模板 c.…

【爱空间_登录安全分析报告】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 暴力破解密码&#xff0c;造成用户信息泄露短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造成亏损无底洞 …

模型 FABE(特性 优势 好处 证据)法则

说明&#xff1a;系列文章 分享 模型&#xff0c;了解更多&#x1f449; 模型_思维模型目录。特性、优势、好处、证据&#xff0c;一气呵成。 1 FABE法则的应用 1.1 FABE法则营销商用跑步机 一家高端健身器材公司的销售代表正在向一家新开的健身房推销他们的商用跑步机。以下…

孩子用的灯什么样的好?安利几款适合孩子用的护眼台灯

随着学生们重返校园&#xff0c;家长和孩子们忙于新学期的准备工作&#xff0c;眼睛健康的考量自然也在其中。这也是为何近年来护眼台灯越来越受到欢迎的原因之一。作为一个长期近视并且日常用眼时间较长的人&#xff0c;我本人对护眼台灯有着长期的使用经历&#xff0c;并对它…

如何创建一个Angular项目(超简单)

1、安装Node.js&#xff08;官网Node.js下载&#xff09; 2、运行node -v和npm -v两条命令&#xff08;检验是否下载成功Node.js&#xff09; 3、npm i -g cnpm --registryhttps://registry.npmmirror.com&#xff08;用npm安装cnpm&#xff0c;将镜像源设置为国内镜像源&…

接入knife4j-openapi3访问/doc.html页面空白问题

大概率拦截器拦截下来了&#xff0c;我们F12看网络请求进行排查 都是 /webjars/ 路径下的资源被拦截了&#xff0c;只需在拦截器中添加该白名单即可"/webjars/**" 具体配置如下&#xff1a; Configuration public class WebConfig implements WebMvcConfigurer {priv…

云端数据提取:安全、高效地利用无限资源

在当今的大数据时代&#xff0c;企业和组织越来越依赖于云平台存储和处理海量数据。然而&#xff0c;随着数据的指数级增长&#xff0c;数据的安全性和高效的数据处理成为了企业最为关心的议题之一。本文将探讨云端数据安全的重要性&#xff0c;并提出一套既高效又安全的数据提…

图像加雾算法的研究与应用

目录 前言 一、图像加雾 1、基于传统方法的雾图合成 2、基于深度学习的雾图合成 3、基于Lightroom Classic实现软件加雾 4、基于深度图的方法实现加雾 二、开源的数据集 三、参考文章 前言 在去雾任务当中&#xff0c;训练和评估去雾算法需要大量的带有雾霾和无雾霾的…

超维小课堂 | 1、Pixhawk硬件平台和PX4固件以及QGC地面站之间的关联和区别

Pixhawk硬件平台和PX4固件以及QGC地面站之间的关联和区别 一、Pixhawk是无人机飞控的硬件平台。 主要集成了相关的单片机芯片和各种传感器的外围电路&#xff0c;因此&#xff0c;可以直观的认为Pixhawk一个由各种硬件模块整合成的硬件平台&#xff0c;类似于集成好的单片机开…

【Qt知识】Qt中的对象树是什么?

在深入Qt编程的世界时&#xff0c;你会频繁遇到一个核心概念——对象树&#xff08;Object Tree&#xff09;。这个概念是Qt框架管理内存、处理事件和组织用户界面元素的基础。 什么是Qt对象树&#xff1f; 如果你的Qt应用程序就像一片茂盛的森林&#xff0c;而这片森林中的每…

CameraProvider启动流程

从Android 8.0之后&#xff0c;Android 引入Treble机制&#xff0c;主要是为了解决目前Android 版本之间升级麻烦的问题&#xff0c;将OEM适配的部分vendor与google 对android 大框架升级的部分system部分做了分离&#xff0c;一旦适配了一个版本的vendor信息之后&#xff0c;之…

炫云亮相第二十届中国国际动漫节国际动漫游戏商务大会!

5月28日-29日&#xff0c;备受瞩目的第二十届中国国际动漫节国际动漫游戏商务大会(iABC2024)在杭州滨江开元名都大酒店隆重召开&#xff01;本届大会以动漫IP为核心&#xff0c;从源头到全系列数字内容&#xff0c;探索创新协同、融合发展、价值转化&#xff0c;并对重点作品和…

校园安保巡逻机器人

2023年8月5日&#xff0c;陕西西安一高校实验室起火冒烟&#xff0c;导致学校化学实验室发生火灾。2022年8月3日&#xff0c;一名歹徒持械闯入江西吉安安福县城的一家私立幼儿园&#xff0c;对着无辜的幼儿行凶伤人&#xff0c;造成3死6伤。 像这样的事故有不断地发生&#xf…

计算机基础学习路线

计算机基础学习路线 整理自学计算机基础的过程&#xff0c;虽学习内容众多&#xff0c;然始终相信世上无难事&#xff0c;只怕有心人&#xff0c;期间也遇到许多志同道合的同学&#xff0c;现在也分享自己的学习过程来帮助有需要的。 一、数据结构与算法 视频方面我看的是青…

DNF手游攻略:勇士进阶指南!

在即将到来的6月5日&#xff0c;《DNF手游》将迎来一场盛大的更新&#xff0c;此次更新带来了大量新内容和玩法&#xff0c;极大丰富了游戏的体验。本文将为广大玩家详细解析此次更新的亮点&#xff0c;包括新增的组队挑战玩法“罗特斯入门团本”、新星使宠物的推出、宠物进化功…

【Android】

hint在text显示提示内容 设置主键&#xff0c;在mainactivity // 获取SharedPreferences对象存放的用户名和密码&#xff0c;并设为相应组件的值 //指定key的值&#xff0c;及获取不到值时使用的默认值 String sName sp.getString("name", "unknown")…

pom文件中,Maven导入依赖出现 Dependency not found

解决方案&#xff1a; 1、首先看一下自己的Maven是否配置好了 2、再检查一下镜像是否正确 3、如果上面都没有问题&#xff0c;看 dependencyManagement 标签 我这个出错&#xff0c;爆一大片红就是因为 这个标签 dependencyManagement 解决方法&#xff1a;在父工程中进行依…

实现样式一键切换

实现效果实现方法操作步骤 1.实现效果类似于&#xff1a; 点击切换后更改样式为&#xff1a; 2.实现方法 1.通过css变量切换主题 2.通过link引入css文件切换主题 3.实现方法 定义下拉框双向绑定数据&#xff08;写到根组件app.vue中&#xff09;&#xff0c; 设置theme默认值…