Transformer-XL:打破序列长度限制的Transformer模型

news2025/1/11 3:28:36

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

Transformer-XL

(封面图由ERNIE-ViLG AI 作画大模型生成)

Transformer-XL:打破序列长度限制的Transformer模型

在自然语言处理领域中,序列模型是至关重要的一类模型,但是它们受到了序列长度的限制。在传统的循环神经网络(RNN)模型中,由于梯度消失或梯度爆炸的问题,只能处理较短的序列。为了克服这个问题,Attention机制被引入到了序列模型中,其中Transformer是最著名的例子。

但是,即使是Transformer,它也有一个长度限制。由于输入和输出是一次性给出的,因此Transformer不能处理超过固定长度的序列。为了克服这个问题,Dai等人提出了Transformer-XL,它能够处理超过固定长度的序列,并且能够捕捉更长期的依赖关系。

1. Transformer-XL介绍

Transformer-XL是由Dai等人在2019年提出的,是Transformer模型的一种扩展。Transformer-XL通过使用可重复的缓存机制来解决Transformer模型的长度限制问题。它使用了两种类型的缓存:前向缓存和后向缓存。

前向缓存是指在当前时间步之前的所有时间步的表示形式。它可以被看作是一个保存在模型中的记忆,包含了之前所有时间步的信息。后向缓存是指在当前时间步之后的所有时间步的表示形式。这个缓存是由前向缓存生成的,因此它只能在前向传递后被使用。使用这两种缓存,Transformer-XL能够从之前的计算中获取上下文,并将其用于当前的计算。

具体来说,Transformer-XL使用了一种称为相对位置编码的技术来表示缓存。相对位置编码是指根据缓存中的位置来编码每个时间步的表示形式。相对位置编码不仅考虑了时间步之间的绝对位置,还考虑了它们之间的相对位置。这种编码方式可以帮助Transformer-XL捕捉更长期的依赖关系。

在前向传递期间,Transformer-XL使用前向缓存来计算当前时间步的表示形式。在后向传递期间,它使用后向缓存来计算当前时间步的表示形式。通过这种方式,Transformer-XL能够利用之前计算的信息,并将其用于当前的计算。这使得Transformer-XL能够处理超过固定长度的序列,并捕捉更长期的依赖关系。

2. Transformer-XL原理

Transformer-XL的核心是通过增强Transformer中的循环机制,来增强长序列上下文的记忆。为了理解这个改进,我们首先回顾一下Transformer的结构。Transformer由多个Encoder和Decoder堆叠而成,其中每个Encoder和Decoder均由多头自注意力机制(Multi-Head Self-Attention)和前向神经网络(Feed-Forward Neural Networks)两部分组成。

在Transformer中,每个Encoder或Decoder的自注意力机制的输入是序列中的某个位置,然后它会计算出该位置对所有位置的注意力分数,并对所有位置的值进行加权平均,得到该位置的输出。这个过程是在所有位置上并行计算的,因此Transformer的计算复杂度是线性的。然而,由于需要同时处理整个序列,每个位置的计算都是独立的,因此Transformer无法直接处理超过固定长度的序列。

Transformer-XL的改进之一是增加了一种记忆机制,可以将之前的状态保存下来,并在下一步计算时使用这些状态。具体来说,每个Encoder和Decoder都有一个内存(Memory),用于存储之前的状态。每当计算到一个新的位置时,会从内存中读取之前的状态,并与当前的输入一起计算,得到新的输出,并将输出存储到内存中。这个过程相当于是对前面的序列进行了循环,从而扩展了序列的长度。Transformer-XL中的每个Encoder包含一个内存,可以在计算当前位置时使用之前的内存状态。

另一个Transformer-XL的改进是针对长距离依赖的处理。长距离依赖是指序列中两个相距较远的位置之间存在的依赖关系。传统的循环神经网络(RNN)可以处理长距离依赖,但由于其顺序计算的特性,其计算速度较慢,并且无法进行并行计算。而Transformer可以进行并行计算,但在处理长序列时也存在长距离依赖的问题,因为在自注意力机制中,每个位置只能通过加权平均得到相对位置的信息,无法获取到绝对位置的信息。

Transformer-XL通过增加一种新的方法来解决长距离依赖问题,称为相对位置编码(Relative Positional Encoding)。相对位置编码的思路是通过引入相对位置的概念,来获取序列中不同位置之间的关系。具体来说,对于一个位置i和另一个位置j,相对位置的定义是它们之间的距离d=i-j。通过引入相对位置编码,Transformer可以获取到位置i和位置j之间的相对位置信息,从而处理长距离依赖。相对位置编码的具体实现方式是,在原有的位置编码的基础上,增加一部分相对位置编码,表示当前位置与其他位置之间的相对位置信息。Transformer-XL中的相对位置编码包括了相对位置的信息,从而能够处理长距离依赖。

Transformer-XL的训练过程与传统的语言模型相同,即通过最大化下一个单词的条件概率来训练模型。在Transformer-XL中,每个位置的输入是之前的一段序列,输出是下一个单词的概率分布。因此,Transformer-XL的目标函数可以表示为:

L = − 1 N ∑ i = 1 N log ⁡ P ( w i ∣ w < i ) \mathcal{L} = -\frac{1}{N}\sum_{i=1}^N\log P(w_i|w_{<i}) L=N1i=1NlogP(wiw<i)

其中, w < i w_{<i} w<i表示序列中前i-1个单词, w i w_i wi表示第i个单词,N表示训练样本的总数。在实际训练中,可以使用随机梯度下降(Stochastic Gradient Descent)或其变种方法来优化目标函数,以更新模型参数。

3. Transformer-XL优劣势

(1)优势

Transformer-XL具有以下几个优势:

  1. 能够处理超过固定长度的序列。在传统的Transformer中,输入和输出是一次性给出的,因此Transformer不能处理超过固定长度的序列。Transformer-XL通过使用可重复的缓存机制来解决这个问题。缓存机制使Transformer-XL可以从之前的计算中获取上下文,并将其用于当前的计算。这使得Transformer-XL能够处理比传统Transformer更长的序列。

  2. 能够捕捉更长期的依赖关系。在传统的Transformer中,由于输入和输出是一次性给出的,因此传统Transformer只能利用上下文中较近的信息。Transformer-XL采用的相对位置编码方法使得模型可以处理更长的序列,使得模型更加适合用于语言建模任务,而在较长的序列上,传统的位置编码方法往往存在限制。这使得Transformer-XL在某些任务上比传统Transformer更具优势。

  3. 能够提高模型的可训练性。在传统的Transformer中,模型的训练过程需要在固定长度的序列上进行。这使得模型的训练非常困难。Transformer-XL通过使用可重复的缓存机制,可以将长序列分成多个较短的序列,并在这些较短的序列上进行训练。这提高了模型的可训练性,并使得模型的训练更加稳定。

  4. 在不同的NLP任务中,Transformer-XL也有出色的表现,证明了其在不同场景下的通用性。

(2)劣势

  1. Transformer-XL的模型结构较为复杂,需要更多的计算资源和时间来训练模型,并且模型的训练过程需要使用分布式训练,进一步增加了训练难度和计算成本。
  2. 在处理较短的序列时,Transformer-XL可能会存在过拟合的风险,因为它的模型结构和参数数量比较大,需要更多的数据来避免过拟合。

4. 案例

下面是使用Transformer-XL进行语言建模的一个示例。在这个示例中,我们将使用Penn Treebank数据集,该数据集是一个常用的语言建模数据集,包含了经过标记的英语句子。我们将使用Transformer-XL来训练一个语言建模器,该模型将预测下一个单词的概率,给定之前的单词序列。

我们使用PyTorch实现Transformer-XL,并使用Penn Treebank数据集进行训练。代码如下:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

from pytorch_transformers import (XLNetConfig, XLNetTokenizer,
                                  XLNetForSequenceClassification, XLNetModel,
                                  AdamW, WarmupLinearSchedule)

# Define the model
class TransformerXLModel(nn.Module):
    def __init__(self, ntoken, n_layer, n_head, d_model, d_head, d_inner, dropout, **kwargs):
        super().__init__()
        self.transformer = XLNetModel(XLNetConfig(n_layer=n_layer,
                                                   n_head=n_head,
                                                   d_model=d_model,
                                                   d_head=d_head,
                                                   d_inner=d_inner,
                                                   dropout=dropout,
                                                   **kwargs))
        self.drop = nn.Dropout(dropout)
        self.decoder = nn.Linear(d_model, ntoken)

    def forward(self, input_ids, mems=None):
        if mems is None:
            mems = self.init_mems(input_ids.size(0))
        output, new_mems = self.transformer(input_ids, mems)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0) * output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), new_mems

    def init_mems(self, bsz):
        mems = []
        for i in range(self.transformer.config.n_layer):
            empty = torch.zeros(self.transformer.config.mem_len, bsz, self.transformer.config.d_model).to(next(self.parameters()))
            mems.append(empty)
        return mems

# Define the dataset
class PTBDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = []
        with open(data, 'r') as f:
            for line in f:
                self.data.extend(tokenizer.encode(line.rstrip()))

    def __len__(self):
        return len(self.data) - 1

    def __getitem__(self, idx):
        return (self.data[idx], self.data[idx+1])

# Define the training function
def train(model, train_loader, optimizer, scheduler, criterion, device):
    model.train()
    total_loss = 0.
    for step, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        x = x.to(device)
        y = y.to(device)
        logits, _ = model(x)
        logits = logits.view(-1, logits.size(2))
        y = y.view(-1)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

# Define the evaluation function
def evaluate(model, eval_loader, criterion, device):
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        for step, (x, y) in enumerate(eval_loader):
            x = x.to(device)
            y = y.to(device)
            logits, _ = model(x)
            logits = logits.view(-1, logits.size(2))
            y = y.view(-1)
        	loss = criterion(logits, y)
        	total_loss += loss.item()
    return total_loss / len(eval_loader)
# Set hyperparameters
batch_size = 32
lr = 5e-5
n_layer = 6
n_head = 8
d_model = 512
d_head = 64
d_inner = 2048
dropout = 0.1
n_epochs = 10

# Load the dataset
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
train_dataset = PTBDataset('ptb.train.txt', tokenizer)
eval_dataset = PTBDataset('ptb.valid.txt', tokenizer)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=batch_size)
# Initialize the model and optimizer
model = TransformerXLModel(tokenizer.vocab_size, n_layer, n_head, d_model, d_head, d_inner, dropout)
model.to('cuda')
optimizer = AdamW(model.parameters(), lr=lr, correct_bias=False)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=0.1, t_total=len(train_loader) * n_epochs)
criterion = nn.CrossEntropyLoss()

# Train the model
for epoch in range(n_epochs):
	train_loss = train(model, train_loader, optimizer, scheduler, criterion, 'cuda')
	eval_loss = evaluate(model, eval_loader, criterion, 'cuda')
	print(f"Epoch {epoch+1} - train_loss: {train_loss:.4f} - eval_loss: {eval_loss:.4f}")
# Test the model
test_dataset = PTBDataset('ptb.test.txt', tokenizer)
test_loader = DataLoader(test_dataset, batch_size=batch_size)
test_loss = evaluate(model, test_loader, criterion, 'cuda')
print(f"Test loss: {test_loss:.4f}")

❤️觉得内容不错的话,欢迎点赞收藏加关注😊😊😊,后续会继续输入更多优质内容❤️

👉有问题欢迎大家加关注私戳或者评论(包括但不限于NLP算法相关,linux学习相关,读研读博相关......)👈

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/403187.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis经典五种数据类型底层实现原理解析

目录总纲redis的k,v键值对新的三大类型五种经典数据类型redisObject结构图示结构讲解数据类型与数据结构关系图示string数据类型三大编码格式SDS详解代码结构为什么要重新设计源码解析三大编码格式hash数据类型ziplist和hashtable编码格式ziplist详解结构剖析ziplist的优势(为什…

TypeScript 基础学习之泛型和 extends 关键字

越来越多的团队开始使用 TS 写工程项目&#xff0c; TS 的优缺点也不在此赘述&#xff0c;相信大家都听的很多了。平时对 TS 说了解&#xff0c;仔细思考了解的也不深&#xff0c;借机重新看了 TS 文档&#xff0c;边学习边分享&#xff0c;提升对 TS 的认知的同时&#xff0c;…

Qt静态扫描(命令行操作)

Qt静态扫描&#xff08;命令行操作&#xff09; 前沿&#xff1a; 静态代码分析是指无需运行被测代码&#xff0c;通过词法分析、语法分析、控制流、数据流分析等技术对程序代码进行扫描&#xff0c;找出代码隐藏的错误和缺陷&#xff0c;如参数不匹配&#xff0c;有歧义的嵌…

Linux查看UTC时间

先了解一下几个时间概念。 GMT时间&#xff1a;Greenwich Mean Time&#xff0c;格林尼治平时&#xff0c;又称格林尼治平均时间或格林尼治标准时间。是指位于英国伦敦郊区的皇家格林尼治天文台的标准时间。 GMT时间存在较大误差&#xff0c;因此不再被作为标准时间使用。现在…

数据传输服务DTS的应用场景(阿里巴巴)

数据传输服务DTS的应用场景(阿里巴巴) 数据传输服务DTS&#xff08;Data Transmission Service&#xff09;支持数据迁移、数据订阅和数据实时同步功能&#xff0c;帮助您实现多种典型应用场景。 不停机迁移数据库 传输方式&#xff1a;数据迁移 为了保证数据的一致性&#…

【17】组合逻辑 - VL17/VL19/VL20 用3-8译码器 或 4选1多路选择器 实现逻辑函数

VL17 用3-8译码器实现全减器 【本题我的也是绝境】 因为把握到了题目的本质要求【用3-8译码器】来实现全减器。 其实我对全减器也是不大清楚,但是仿照对全加器的理解,全减器就是低位不够减来自低位的借位 和 本单元位不够减向后面一位索要的借位。如此而已,也没有很难理解…

Python3简单实现图像风格迁移

导语T_T之前似乎发过类似的文章&#xff0c;那时候是用Keras实现的&#xff0c;现在用的PyTorch&#xff0c;而且那时候发的内容感觉有些水&#xff0c;于是我决定。。。好吧我确实只是为了写点PyTorch练手然后顺便过来水一篇美文~~~利用Python实现图像风格的迁移&#xff01;&…

Python实现性能测试(locust)

一、安装locustpip install locust -- 安装&#xff08;在pycharm里面安装或cmd命令行安装都可&#xff09;locust -V -- 查看版本&#xff0c;显示了就证明安装成功了或者直接在Pycharm中安装locust:搜索locust并点击安装&#xff0c;其他的第三方包也可以通过这种方式二、loc…

JavaScript Math(算数)对象

Math&#xff08;算数&#xff09;对象的作用是&#xff1a;执行常见的算数任务。在线实例round()如何使用 round()。random()如何使用 random() 来返回 0 到 1 之间的随机数。max()如何使用 max() 来返回两个给定的数中的较大的数。&#xff08;在 ECMASCript v3 之前&#xf…

站外seo优化有用吗?值得投入时间和精力吗?

随着互联网的普及和竞争的激烈化&#xff0c;SEO&#xff08;Search Engine Optimization&#xff0c;搜索引擎优化&#xff09;已经成为各种网站推广的必备技能。 而站外SEO优化就是指通过在其他网站上增加链接和引用等方式&#xff0c;来提高自己网站的搜索引擎排名和曝光度…

【6G 新技术】6G数据面介绍

博主未授权任何人或组织机构转载博主任何原创文章&#xff0c;感谢各位对原创的支持&#xff01; 博主链接 本人就职于国际知名终端厂商&#xff0c;负责modem芯片研发。 在5G早期负责终端数据业务层、核心网相关的开发工作&#xff0c;目前牵头6G算力网络技术标准研究。 博客…

window.onresize的详细使用

最近做的项目老是涉及到大小屏切换&#xff0c;但是因为屏幕宽高不一样的原因&#xff0c;老是要计算表格高度 window.onresize&#xff1a;监听window窗口变化&#xff0c;当窗口大小发生变化时&#xff0c;会触发此事件 含义 MDN中的定义是这样子的&#xff1a; 文档视图调…

GitHub与PicGo搭建免费稳定图床并实现Typora内复制自动上传

本文介绍基于Github平台与PicGo工具&#xff0c;构建免费、稳定的图床&#xff0c;并实现在Typora内撰写Markdown文档时&#xff0c;粘贴图片就可以将这一图片自动上传到搭建好的图床中的方法。 1 配置GitHub 首先&#xff0c;我们需要配置Github&#xff0c;创建一个仓库从而…

mysql 查询一个表的数据,并修改部分数据,再插回原来的表中,复制某个用户的数据给另一个用户

mysql 查询一个表的数据&#xff0c;并修改部分数据&#xff0c;再插回原来的表中&#xff0c;复制某个用户的数据给另一个用户 一、需求 我有一表日记的表&#xff0c;表中盛放着所有用户的日记数据。 在做演示项目的时候&#xff0c;我需要将一个用户的数据复制给另一个用户…

PlotNeuralNet + ChatGPT创建专业的神经网络的可视化图形

PlotNeuralNet&#xff1a;可以创建任何神经网络的可视化图表&#xff0c;并且这个LaTeX包有Python接口&#xff0c;我们可以方便的调用。 但是他的最大问题是需要我们手动的编写网络的结构&#xff0c;这是一个很麻烦的事情&#xff0c;这时 ChatGPT 就出来了&#xff0c;它可…

JavaScript学习笔记(3.0)

数组是一种特殊类型的对象。在JavaScript中对数组使用typeof运算符会返回“object”。 但是&#xff0c;JavaScript数组最好以数组来描述。 数组使用数字来访问其“元素”。比如person[0]访问person数组中的第一个元素。 <!DOCTYPE html> <html> <body>&l…

【JavaEE进阶】——第一节.Maven国内源配置

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言 操作步骤 1.打开项目配置界面&#xff08;当前项目配置&#xff09; 2.检查并配置国内源 3.再次打开项目配置界面&#xff08;新项目配置&#xff09; 4…

Android RecyclerView的notify方法和动画的刷新详解

前些天发现了一个蛮有意思的人工智能学习网站,8个字形容一下"通俗易懂&#xff0c;风趣幽默"&#xff0c;感觉非常有意思,忍不住分享一下给大家。 &#x1f449;点击跳转到教程 前言&#xff1a; 本篇讲解了RecyclerView关于通知列表刷新的常用的notify方法。和Recy…

综合练习7 摄氏度转华氏温度(“\t“的使用,循环语句)

综合练习7 摄氏度转华氏温度 使用do…while循环&#xff0c;在控制台输入摄氏温度与华氏温度的对照表。 对照表从摄氏温度-30℃到50℃&#xff0c;每行间隔10℃&#xff0c;运行如下&#xff1a; 摄氏温度&#xff1a;-30℃ 华氏温度&#xff1a;-22.0℉ 摄氏温度&#xff1a;…

【专项训练】动态规划-3

动态规划:状态转移方程、找重复性和最优子结构 分治 + 记忆化搜索,可以过度到动态规划(动态递推) function DP():# DP状态定义# 需要经验,需把现实问题定义为一个数组,一维、二维、三维……dp =[][] # 二维情况for i = 0...M: