文献阅读:LONGNET: Scaling Transformers to 1,000,000,000 Tokens

news2024/11/28 2:38:42
  • 文献阅读:LONGNET: Scaling Transformers to 1,000,000,000 Tokens
    • 1. 文章简介
    • 2. 方法原理
      • 1. 方法思路
      • 2. Dilated Attention
        • 1. 具体原理
        • 2. 多头实现
        • 3. 复杂度分析
      • 3. 训练方法
    • 3. 实验结果
    • 4. 结论 & 思考
    • 5. 参考链接
  • 文献链接:https://arxiv.org/abs/2307.02486

1. 文章简介

这篇文章算是我司最近的一篇力作吧,即DeepNet, Foundation Transformer之后,大佬们终于还是盯上了attention layer,毕竟attention层 O ( N 2 ) O(N^2) O(N2)的计算复杂度一直是制约Transformer往长文本发展的主要原因。

想当年,像是线性化Attention的Linformer,或者以更直观的稀疏化attention的Reformer,亦或者结合局部与全局attention的Longformer,或者类似金字塔型的将长文本拆分为短文本然后各自做attention然后逐层往上的方式(不过这篇具体文章给忘了),总之当年零零碎碎有不少关于优化attention层计算量,使之可以拓展到长文本上的工作。

不过可惜的是,虽然当时大家都觉得这个方向很重要,结果以GPT3还有PALM等为代表的大模型反而从工程上发力,直接强行扩展文本长度,从头上干掉了这个问题……

这两年,感觉这方面的工作已经比较少听到了,不过我司的大佬们似乎还是重新抓出了这个方向,然后像是DeepNet那样直接干出了一个量级上碾压的工作,也是真的厉害……

在这里插入图片描述

2. 方法原理

1. 方法思路

LongNet的整体的一个思路其实和之前的Reformer,Linformer等一致,还是在attention层方面做文章,希望将attention layer的计算复杂度从原始的 O ( N 2 d ) O(N^2d) O(N2d)进行优化,使得其与句长 N N N呈线性关系而非平方关系,从而使得模型整体的计算复杂度得到缩减。

对于,文中提出了dilated attention的结构,成功地将attention layer的计算复杂度从 O ( N 2 d ) O(N^2d) O(N2d)降维至 O ( N d ) O(Nd) O(Nd)复杂度。

在这里插入图片描述

需要注意的是,这里的比较没有包含linear transformer,它虽然很早之前已经实现了 O ( N d ) O(Nd) O(Nd)复杂度的attention实现,不过貌似效果不佳,不算是主流的attention方法,因此文中弃用了linear transformer作为对照。

下面,我们就需要具体看一下Dilated Attention层的具体实现方法。

2. Dilated Attention

1. 具体原理

首先,我们给出Dilated Attention层的整体原理图如下:

在这里插入图片描述

具体来说,就是首先给出一个局部窗口长度 w w w和间隔距离 r r r,那么,就可以将总长为 N N N的序列拆分为 N / w N/w N/w个子序列,然后在每一个子序列当中按照间隔 r r r取出token,一共就能够取出 w / r w/r w/r个token,然后用着 w / r w/r w/r个token作为新的序列计算attention,然后把这 N / w N/w N/w个attention矩阵concat起来,就能得到一个 N × N N \times N N×N的稀疏attention矩阵。

考察对于固定的 w , r w,r w,r下的第 i i i个attention矩阵,有:

{ Q i = [ Q i w Q i w + r ⋯ Q ( i + 1 ) w − r ] K i = [ K i w K i w + r ⋯ K ( i + 1 ) w − r ] V i = [ V i w V i w + r ⋯ V ( i + 1 ) w − r ] \left\{ \begin{aligned} Q_i &= [Q_{iw} & Q_{iw+r} & \cdots & Q_{(i+1)w-r}] \\ K_i &= [K_{iw} & K_{iw+r} & \cdots & K_{(i+1)w-r}] \\ V_i &= [V_{iw} & V_{iw+r} & \cdots & V_{(i+1)w-r}] \end{aligned} \right. QiKiVi=[Qiw=[Kiw=[ViwQiw+rKiw+rViw+rQ(i+1)wr]K(i+1)wr]V(i+1)wr]

此时有:

O i = s o f t m a x ( Q i ⋅ K i T d ) V i O_i = \mathop{softmax}(\frac{Q_i \cdot K_i^T}{\sqrt{d}})V_i Oi=softmax(d QiKiT)Vi

当然,这样的一个attention矩阵事实上只包含了局部的attention信息,因此无法兼顾长距离和短距离的attention信息。因此,如果要令总的attention兼顾长距离和短距离的attention信息,就需要取出多组 w , r w,r w,r,分别计算attention然后进行矩阵加和。也就是上图中的合并部分,从而才能获得包含全局attention信息的矩阵。

具体实现上来说,文中采用的是等比数列的方式进行实现,比如如下的方式:

{ w = w , α w , α 2 w , ⋯   , α n w r = r , α r , α 2 r , ⋯   , α n r \left\{ \begin{aligned} w &= {w, \alpha w, \alpha^2 w, \cdots, \alpha^n w} \\ r &= {r, \alpha r, \alpha^2 r, \cdots, \alpha^n r} \end{aligned} \right. {wr=w,αw,α2w,,αnw=r,αr,α2r,,αnr

在上图的demo中,取用的 w , r w,r w,r就是 4 4 4 1 1 1 α \alpha α的取值为 2 2 2

当然,考虑到由于 w , r w,r w,r取值不同导致的attention的密度不同,因此加和的时候需要对权重进行调整,具体而言:

O = ∑ i = 1 k s i ∑ j s j O r i , w i O = \sum\limits_{i=1}^{k}\frac{s_i}{\sum_j s_j}O_{r_i, w_i} O=i=1kjsjsiOri,wi

其中, s i s_i si ( w i , r i ) (w_i, r_i) (wi,ri)这组参数下计算得到的attention矩阵( Q i ⋅ K i T d \frac{Q_i \cdot K_i^T}{\sqrt{d}} d QiKiT)在计算softmax时的分母部分,也就是:

∑ j e Q i ⋅ K i T d \sum\limits_{j} e^{\frac{Q_i \cdot K_i^T}{\sqrt{d}}} jed QiKiT

这样也就得到了一组 n n n维的系数向量,作为我们这里的 s s s

2. 多头实现

关于Dilated Attention的多头实现,整体来说和vanilla transformer的实现方式是一致的,还是在input的向量当中进行split,然后分别过一个上述介绍的Dilated Attention层,最后将output的结果concat起来即可。

不过,感谢作者Shuming大佬的解释,这里和vanilla transformer存在一定的区别,具体就在于对于每一个context window,我们事实上都是等间隔的sample了其中的几个token进行attention的计算,某种意义上来说总是会丢失掉一些信息的。

因此,在设计多头attention的时候,文中进行了一定的优化,即对于input的token位置在不同的head上面给了不同的位置偏移量,从而使得尽可能地覆盖更多的token之间的attention。

具体来说就是,对于第 j j j个head,选取的token为:

{ Q i = [ Q i w + j ( ≡ r ) Q i w + r + j ( ≡ r ) ⋯ Q ( i + 1 ) w − r + j ( ≡ r ) ] K i = [ K i w + j ( ≡ r ) K i w + r + j ( ≡ r ) ⋯ K ( i + 1 ) w − r + j ( ≡ r ) ] V i = [ V i w + j ( ≡ r ) V i w + r + j ( ≡ r ) ⋯ V ( i + 1 ) w − r + j ( ≡ r ) ] \left\{ \begin{aligned} Q_i &= [Q_{iw + j(\equiv r)} & Q_{iw+r + j(\equiv r)} & \cdots & Q_{(i+1)w-r + j(\equiv r)}] \\ K_i &= [K_{iw + j(\equiv r)} & K_{iw+r + j(\equiv r)} & \cdots & K_{(i+1)w-r + j(\equiv r)}] \\ V_i &= [V_{iw + j(\equiv r)} & V_{iw+r + j(\equiv r)} & \cdots & V_{(i+1)w-r + j(\equiv r)}] \end{aligned} \right. QiKiVi=[Qiw+j(r)=[Kiw+j(r)=[Viw+j(r)Qiw+r+j(r)Kiw+r+j(r)Viw+r+j(r)Q(i+1)wr+j(r)]K(i+1)wr+j(r)]V(i+1)wr+j(r)]

可以用文中的图3来对上述不同头的attention进行更为形象化的展示如下:

在这里插入图片描述

3. 复杂度分析

下面,我们来考察一下Dilated Attention层的算法复杂度。

我们首先来考察对于一组确定的 w , r w,r w,r对应的Dilated Attention层的算法复杂度,其对应的结果如下:

F L O P s = 2 N w ⋅ ( w r ) 2 d = 2 N w d r 2 FLOPs = \frac{2N}{w} \cdot (\frac{w}{r})^2d = \frac{2Nwd}{r^2} FLOPs=w2N(rw)2d=r22Nwd

因此,遍历 w , r w,r w,r,我们即可得到完整的Dilated Attention层的算法复杂度如下:

F L O P s = ∑ i = 0 k − 1 2 N w i d r i 2 = 2 N w 0 d r 0 2 ∑ i = 0 k − 1 1 α i < 2 N w 0 d r 0 2 ⋅ α α − 1 ∼ O ( N d ) FLOPs = \sum\limits_{i=0}^{k-1}\frac{2Nw_id}{r_i^2} = \frac{2Nw_0d}{r_0^2} \sum\limits_{i=0}^{k-1} \frac{1}{\alpha^i} < \frac{2Nw_0d}{r_0^2} \cdot \frac{\alpha}{\alpha-1} \sim O(Nd) FLOPs=i=0k1ri22Nwid=r022Nw0di=0k1αi1<r022Nw0dα1αO(Nd)

3. 训练方法

最后,我们看一下文中实际的训练过程。

注意到,这里由于极限的扩展了输入的context的序列长度,因此事实上如何将文本塞入GPU也就成了一个大问题,因此,这方面也需要有一些工程上的实现细节考察。

具体来说,文中给出的方法还是说先对sequence进行一下split,然后由不同的GPU分别计算,最后进行加总实现。

其原理图可以参考文中的图4:

在这里插入图片描述

不过需要注意的是,这里在不同的gpu当中计算完了不同的部分的input seq之后,在计算dilated attention的时候会有一个slice的过程,然后slice之后的得到的dilated attention会在不同的GPU之间进行聚合,从而确保不同的gpu上的token之间的attention能够相互计算和聚合。

由于这里只是slice之后的attention,因此可以避免掉由于过长的文本长度(比如文中给出的1B)导致的内存爆炸的问题。

3. 实验结果

文中使用torchscale作为基准库,然后替换attention layer之后train了一个768维,12层的模型进行实验考察。

得到结果如下:

在这里插入图片描述

而除了最终的ppl之外,文中还比较了transformer与LongNet在处理不同文本长度的文本时所需的计算量。

在这里插入图片描述

可以看到:

  • LongNet可以在更少的计算量下获得相较于原始的transformer更好的ppl。

此外,文中还对LongNet在不同的参数量以及不同的context window进行了一下考察,得到结果如下:

在这里插入图片描述

可以看到:

  • 随着参数量的增长,模型的ppl是在不断减小的,说明LongNet具有很好的扩展能力;
  • context window越大,模型的效果也能够不断地提升,说明LongNet对于长文本有较好的理解能力。

最后,文中还非常直观的给出了将输入文本长度扩展到1B之后vanilla transformer与LongNet的infer时间变化的比较:

在这里插入图片描述

其结果直观地证明了LongNet对于长文本处理能力的能力,较之Vanilla Transformer耗时的快速增长,Dilated Attention基本没有发生什么太大的变化。

4. 结论 & 思考

综上,整体而言这篇文章还是很惊艳的,至少从context length的角度来说这种突破性的震撼确实厉害,结合他之前的foundation transformer等工作,我觉得他们在transformer的基础架构上面确实花了不少的功夫来做优化,这一点确实是厉害。

不过考虑到工程上,这篇文章的主要贡献可能还是在于长文本的关联attention上面,也就意味着其优势必然还是需要长上下文+大语料的前提下才能充分发挥出它的效果,就目前我的工作而言,可能还是有点用不太到……

所以,就只能膜拜一下大佬了,后面有机会的话可以考虑一下在业余时间复现一下看看了,在工作上倒是觉得ROI应该是不会很大了……

5. 参考链接

  1. Longformer: 局部Attention和全局attention的混搭

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1175161.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

onlyoffice 去除:连接器(connector) 表单填(Filling out the form)限制 可使用jsApi级别操作文档

阅读须知&#xff1a;本文针对有对word/excel进行js操作的需求 本次改造基于V7.3.3进行&#xff0c;已经更新进入docker。 小伙伴们须知&#xff1a;改造后的office docker需要付费&#xff08;875元&#xff09;&#xff0c;等于wps一个月费用 欢迎大家一起交流&#xff1a;V&…

pycharm插件推荐:一款能够根据上下文自动提示帮写代码的AI插件

直接上插件&#xff1a; 这个插件有多牛&#xff01;他能够根据注释帮你直接补全代码&#xff08;只需要你按一下tab键&#xff09;&#xff0c;甚至还有学习的能力。 如下&#xff1a; 我注释写完后&#xff0c;一回车就模糊的写出了预计的代码&#xff0c;只要我按下tab键…

【Head First 设计模式】-- 策略模式

一、背景 Head First 设计模式第一章设计模式入门–策略模式 二、工具箱的工具&#xff08;本章&#xff09; 1、OO基础 封装 继承 多态 抽象 2、OO原则 封装变化 面向接口编程&#xff0c;而非面向实现编程 组合优于继承 3、OO模式 策略模式&#xff0c;所谓策略模式就是定义…

基于STM32HAL库(窗口看门狗)-简述

目录 概述 一、开发环境 二、STM32CubeMx配置 三、编码 四、运行结果 五、总结 概述 一个成熟靠谱的项目&#xff0c;离不开“看门狗”的必选项&#xff0c;凡是人写的程序多少都会有出现bug的情况&#xff08;或芯片外设受外界干扰导致故障程序卡死、跑飞的情况&#xf…

中国多主数据库:压强投入,期待破茧

拿破仑曾说&#xff1a;“战争的艺术就是在某一点上集中最大优势兵力”&#xff0c;强调了力量集中的重要性。 如今&#xff0c;国际形势风云变幻&#xff0c;西方世界对中国的围剿不再仅仅体现在军事和地缘政治上&#xff0c;而更多表现在经济与科技上。在科技领域&#xff0…

小程序制作(超详解!!!)第十二节 循环求和计算器

1.index.wxml <view class"box"><view class"title">利用循环语句求和</view><view><input placeholder"请输入起点数值" type"number" bindblur"starNum"></input><!--一旦失去交…

JavaEE-部署项目到服务器

本部分内容为&#xff1a;安装依赖&#xff1a;JDK&#xff0c;Tomcat&#xff0c;Mysql&#xff1b;部署项目到服务器 什么是Tomcat Tomcat简单的说就是一个运行JAVA的网络服务器&#xff0c;底层是Socket的一个程序&#xff0c;它也是JSP和Serlvet的一个容器。 为什么我们需要…

【Docker】Docker中 的AUFS、BTRFS、ZFS、存储池概念的详细讲解

前言 作者简介&#xff1a; 辭七七&#xff0c;目前大二&#xff0c;正在学习C/C&#xff0c;Java&#xff0c;Python等 作者主页&#xff1a; 七七的个人主页 文章收录专栏&#xff1a; 七七的闲谈 欢迎大家点赞 &#x1f44d; 收藏 ⭐ 加关注哦&#xff01;&#x1f496;&…

Jetpack:030-Jetpack中的状态

文章目录 1. 概念介绍2. 使用方法2.1 可监听对象2.2 获取状态值2.3 修改状态值2.4 重组函数 3. 示例代码4. 内容总结 我们在上一章回中介绍了Jetpack中网格布局相关的内容&#xff0c;本章回中主要 介绍状态。闲话休提&#xff0c;让我们一起Talk Android Jetpack吧&#xff0…

好题分析(2023.10.29——2023.11.04)

目录 ​编辑 前情回顾&#xff1a; 前言&#xff1a; 题目一&#xff1a;《合并两个有序数组》 1.运用qsort 2.利用三指针 题目二&#xff1a;《移除链表元素》 题目三&#xff1a;《链表的中间节点》 总结&#xff1a; 前情回顾&#xff1a; 我们在上一篇好题分析…

【图】:常用图搜索(图遍历)算法

目录 概念图遍历深度优先搜索 (DFS)DFS 适用场景DFS 优缺点 广度优先搜索 (BFS)BFS 适用场景BFS 优缺点 DFS & BFS 异同点 图搜索Dijkstra算法A*算法Floyd算法Bellman-Ford算法SPFA算法 概念 图遍历和图搜索是解决图论问题时常用的两种基本操作。 图遍历是指从图中的某一个…

Spring Cloud分布式缓存

目录 单点Redis Redis数据持久化 RDB持久化 bgsave细节 RDB的缺点 AOF持久化 AOF的问题 RDB与AOF对比 搭建Redis主从架构 数据同步原理 全量同步 增量同步 主从同步优化 Redis哨兵 集群检测 选举主节点 故障转移 搭建哨兵集群 RedisTemplate的哨兵模式 单点…

【Leetcode】【每日一题】【中等】187. 重复的DNA序列 官方题解待更新

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/repeated-dna-sequences/descrip…

C++ AVL树 c语言版本

引入平衡树 假设我们有两个节点&#xff1a;当我们插入第三个节点&#xff0c;就失衡了&#xff1a;此刻我们就要把它平衡一下。 为什么要变平衡 为什么说它失衡了呢&#xff0c;又为什么要把它变平衡&#xff1f; 如图a&#xff0c;假设我们要查找30这个节点就要查3次才能…

耳机,耳麦,傻傻分不清,难怪麦克风没有声音

有时候会发现为什么同一根耳机线&#xff0c;插到笔记本上可以同时说话和收音&#xff0c;但是插到台式机就不行呢&#xff1f; 因为在以前&#xff0c;耳机和麦克风的接口都是独立的&#xff08;如上图&#xff09;。现在笔记本为了方便&#xff0c;就普遍使用了二合一接口&a…

正点原子嵌入式linux驱动开发——Linux 网络设备驱动

网络驱动是linux里面驱动三巨头之一&#xff0c;linux下的网络功能非常强大&#xff0c;嵌入式linux中也常常用到网络功能。前面已经讲过了字符设备驱动和块设备驱动&#xff0c;本章就来学习一下linux里面的网络设备驱动。 嵌入式网络简介 嵌入式下的网络硬件接口 本次笔记…

是时候放弃 Java 序列化了

基本概念 Java 序列化和反序列化三连问&#xff1a; 什么是 Java 序列化和反序列化&#xff1f;为什么需要 Java 序列化和反序列化&#xff1f;如何实现 Java 序列化和反序列化&#xff1f; 是什么 一句话就能够说明白什么是 Java 序列化和反序列化&#xff1f;Java 序列化…

【探索Linux】—— 强大的命令行工具 P.13(文件系统 | 软硬链接 | 动态库和静态库)

阅读导航 引言一、文件系统1. 磁盘文件系统2. 磁盘结构&#xff08;1&#xff09;物理结构&#xff08;2&#xff09;存储结构 3. stat 命令4. Linux ext2文件系统 二、软硬链接1. 软连接2. 硬链接 三、动态库和静态库1. 动态库&#xff08;1&#xff09;动态库文件扩展名&…

计算虚拟化1——CPU虚拟化

目录 vCPU的概念 vCPU和CPU的关系 CPU的Ring级别 CPU虚拟化技术 软件辅助全虚拟化 半虚拟化 硬件辅助虚拟化 计算资源的虚拟化可以分为CPU虚拟化、内存虚拟化、I/O虚拟化三个方面 CPU虚拟化&#xff1a;多个虚拟机共享CPU资源&#xff0c;对虚拟机中的敏感指令进行截获…

【JavaSE】基础笔记 - 类和对象(上)

目录 1、面向对象的初步认知 1.1、什么是面向对象 1.2、面向对象与面向过程 2. 类定义和使用 2.1、简单认识类 2.2、类的定义格式 2.3、自定义类举例说明 2.3.1、定义一个狗类 2.3.2、定义一个学生类 3、类的实例化 3.1、什么是实例化 3.2、类和对象的说明 1、面向…