文献阅读:LONGNET: Scaling Transformers to 1,000,000,000 Tokens

news/2024/7/19 10:47:45 标签: LongNet, LLM, scaling, Dilated注意力, Transformer
  • 文献阅读:LONGNET: Scaling Transformers to 1,000,000,000 Tokens
    • 1. 文章简介
    • 2. 方法原理
      • 1. 方法思路
      • 2. Dilated Attention
        • 1. 具体原理
        • 2. 多头实现
        • 3. 复杂度分析
      • 3. 训练方法
    • 3. 实验结果
    • 4. 结论 & 思考
    • 5. 参考链接
  • 文献链接:https://arxiv.org/abs/2307.02486

1. 文章简介

这篇文章算是我司最近的一篇力作吧,即DeepNet, Foundation Transformer之后,大佬们终于还是盯上了attention layer,毕竟attention层 O ( N 2 ) O(N^2) O(N2)的计算复杂度一直是制约Transformer往长文本发展的主要原因。

想当年,像是线性化Attention的Linformer,或者以更直观的稀疏化attention的Reformer,亦或者结合局部与全局attention的Longformer,或者类似金字塔型的将长文本拆分为短文本然后各自做attention然后逐层往上的方式(不过这篇具体文章给忘了),总之当年零零碎碎有不少关于优化attention层计算量,使之可以拓展到长文本上的工作。

不过可惜的是,虽然当时大家都觉得这个方向很重要,结果以GPT3还有PALM等为代表的大模型反而从工程上发力,直接强行扩展文本长度,从头上干掉了这个问题……

这两年,感觉这方面的工作已经比较少听到了,不过我司的大佬们似乎还是重新抓出了这个方向,然后像是DeepNet那样直接干出了一个量级上碾压的工作,也是真的厉害……

在这里插入图片描述

2. 方法原理

1. 方法思路

LongNet的整体的一个思路其实和之前的Reformer,Linformer等一致,还是在attention层方面做文章,希望将attention layer的计算复杂度从原始的 O ( N 2 d ) O(N^2d) O(N2d)进行优化,使得其与句长 N N N呈线性关系而非平方关系,从而使得模型整体的计算复杂度得到缩减。

对于,文中提出了dilated attention的结构,成功地将attention layer的计算复杂度从 O ( N 2 d ) O(N^2d) O(N2d)降维至 O ( N d ) O(Nd) O(Nd)复杂度。

在这里插入图片描述

需要注意的是,这里的比较没有包含linear transformer,它虽然很早之前已经实现了 O ( N d ) O(Nd) O(Nd)复杂度的attention实现,不过貌似效果不佳,不算是主流的attention方法,因此文中弃用了linear transformer作为对照。

下面,我们就需要具体看一下Dilated Attention层的具体实现方法。

2. Dilated Attention

1. 具体原理

首先,我们给出Dilated Attention层的整体原理图如下:

在这里插入图片描述

具体来说,就是首先给出一个局部窗口长度 w w w和间隔距离 r r r,那么,就可以将总长为 N N N的序列拆分为 N / w N/w N/w个子序列,然后在每一个子序列当中按照间隔 r r r取出token,一共就能够取出 w / r w/r w/r个token,然后用着 w / r w/r w/r个token作为新的序列计算attention,然后把这 N / w N/w N/w个attention矩阵concat起来,就能得到一个 N × N N \times N N×N的稀疏attention矩阵。

考察对于固定的 w , r w,r w,r下的第 i i i个attention矩阵,有:

{ Q i = [ Q i w Q i w + r ⋯ Q ( i + 1 ) w − r ] K i = [ K i w K i w + r ⋯ K ( i + 1 ) w − r ] V i = [ V i w V i w + r ⋯ V ( i + 1 ) w − r ] \left\{ \begin{aligned} Q_i &= [Q_{iw} & Q_{iw+r} & \cdots & Q_{(i+1)w-r}] \\ K_i &= [K_{iw} & K_{iw+r} & \cdots & K_{(i+1)w-r}] \\ V_i &= [V_{iw} & V_{iw+r} & \cdots & V_{(i+1)w-r}] \end{aligned} \right. QiKiVi=[Qiw=[Kiw=[ViwQiw+rKiw+rViw+rQ(i+1)wr]K(i+1)wr]V(i+1)wr]

此时有:

O i = s o f t m a x ( Q i ⋅ K i T d ) V i O_i = \mathop{softmax}(\frac{Q_i \cdot K_i^T}{\sqrt{d}})V_i Oi=softmax(d QiKiT)Vi

当然,这样的一个attention矩阵事实上只包含了局部的attention信息,因此无法兼顾长距离和短距离的attention信息。因此,如果要令总的attention兼顾长距离和短距离的attention信息,就需要取出多组 w , r w,r w,r,分别计算attention然后进行矩阵加和。也就是上图中的合并部分,从而才能获得包含全局attention信息的矩阵。

具体实现上来说,文中采用的是等比数列的方式进行实现,比如如下的方式:

{ w = w , α w , α 2 w , ⋯   , α n w r = r , α r , α 2 r , ⋯   , α n r \left\{ \begin{aligned} w &= {w, \alpha w, \alpha^2 w, \cdots, \alpha^n w} \\ r &= {r, \alpha r, \alpha^2 r, \cdots, \alpha^n r} \end{aligned} \right. {wr=w,αw,α2w,,αnw=r,αr,α2r,,αnr

在上图的demo中,取用的 w , r w,r w,r就是 4 4 4 1 1 1 α \alpha α的取值为 2 2 2

当然,考虑到由于 w , r w,r w,r取值不同导致的attention的密度不同,因此加和的时候需要对权重进行调整,具体而言:

O = ∑ i = 1 k s i ∑ j s j O r i , w i O = \sum\limits_{i=1}^{k}\frac{s_i}{\sum_j s_j}O_{r_i, w_i} O=i=1kjsjsiOri,wi

其中, s i s_i si ( w i , r i ) (w_i, r_i) (wi,ri)这组参数下计算得到的attention矩阵( Q i ⋅ K i T d \frac{Q_i \cdot K_i^T}{\sqrt{d}} d QiKiT)在计算softmax时的分母部分,也就是:

∑ j e Q i ⋅ K i T d \sum\limits_{j} e^{\frac{Q_i \cdot K_i^T}{\sqrt{d}}} jed QiKiT

这样也就得到了一组 n n n维的系数向量,作为我们这里的 s s s

2. 多头实现

关于Dilated Attention的多头实现,整体来说和vanilla transformer的实现方式是一致的,还是在input的向量当中进行split,然后分别过一个上述介绍的Dilated Attention层,最后将output的结果concat起来即可。

不过,感谢作者Shuming大佬的解释,这里和vanilla transformer存在一定的区别,具体就在于对于每一个context window,我们事实上都是等间隔的sample了其中的几个token进行attention的计算,某种意义上来说总是会丢失掉一些信息的。

因此,在设计多头attention的时候,文中进行了一定的优化,即对于input的token位置在不同的head上面给了不同的位置偏移量,从而使得尽可能地覆盖更多的token之间的attention。

具体来说就是,对于第 j j j个head,选取的token为:

{ Q i = [ Q i w + j ( ≡ r ) Q i w + r + j ( ≡ r ) ⋯ Q ( i + 1 ) w − r + j ( ≡ r ) ] K i = [ K i w + j ( ≡ r ) K i w + r + j ( ≡ r ) ⋯ K ( i + 1 ) w − r + j ( ≡ r ) ] V i = [ V i w + j ( ≡ r ) V i w + r + j ( ≡ r ) ⋯ V ( i + 1 ) w − r + j ( ≡ r ) ] \left\{ \begin{aligned} Q_i &= [Q_{iw + j(\equiv r)} & Q_{iw+r + j(\equiv r)} & \cdots & Q_{(i+1)w-r + j(\equiv r)}] \\ K_i &= [K_{iw + j(\equiv r)} & K_{iw+r + j(\equiv r)} & \cdots & K_{(i+1)w-r + j(\equiv r)}] \\ V_i &= [V_{iw + j(\equiv r)} & V_{iw+r + j(\equiv r)} & \cdots & V_{(i+1)w-r + j(\equiv r)}] \end{aligned} \right. QiKiVi=[Qiw+j(r)=[Kiw+j(r)=[Viw+j(r)Qiw+r+j(r)Kiw+r+j(r)Viw+r+j(r)Q(i+1)wr+j(r)]K(i+1)wr+j(r)]V(i+1)wr+j(r)]

可以用文中的图3来对上述不同头的attention进行更为形象化的展示如下:

在这里插入图片描述

3. 复杂度分析

下面,我们来考察一下Dilated Attention层的算法复杂度。

我们首先来考察对于一组确定的 w , r w,r w,r对应的Dilated Attention层的算法复杂度,其对应的结果如下:

F L O P s = 2 N w ⋅ ( w r ) 2 d = 2 N w d r 2 FLOPs = \frac{2N}{w} \cdot (\frac{w}{r})^2d = \frac{2Nwd}{r^2} FLOPs=w2N(rw)2d=r22Nwd

因此,遍历 w , r w,r w,r,我们即可得到完整的Dilated Attention层的算法复杂度如下:

F L O P s = ∑ i = 0 k − 1 2 N w i d r i 2 = 2 N w 0 d r 0 2 ∑ i = 0 k − 1 1 α i < 2 N w 0 d r 0 2 ⋅ α α − 1 ∼ O ( N d ) FLOPs = \sum\limits_{i=0}^{k-1}\frac{2Nw_id}{r_i^2} = \frac{2Nw_0d}{r_0^2} \sum\limits_{i=0}^{k-1} \frac{1}{\alpha^i} < \frac{2Nw_0d}{r_0^2} \cdot \frac{\alpha}{\alpha-1} \sim O(Nd) FLOPs=i=0k1ri22Nwid=r022Nw0di=0k1αi1<r022Nw0dα1αO(Nd)

3. 训练方法

最后,我们看一下文中实际的训练过程。

注意到,这里由于极限的扩展了输入的context的序列长度,因此事实上如何将文本塞入GPU也就成了一个大问题,因此,这方面也需要有一些工程上的实现细节考察。

具体来说,文中给出的方法还是说先对sequence进行一下split,然后由不同的GPU分别计算,最后进行加总实现。

其原理图可以参考文中的图4:

在这里插入图片描述

不过需要注意的是,这里在不同的gpu当中计算完了不同的部分的input seq之后,在计算dilated attention的时候会有一个slice的过程,然后slice之后的得到的dilated attention会在不同的GPU之间进行聚合,从而确保不同的gpu上的token之间的attention能够相互计算和聚合。

由于这里只是slice之后的attention,因此可以避免掉由于过长的文本长度(比如文中给出的1B)导致的内存爆炸的问题。

3. 实验结果

文中使用torchscale作为基准库,然后替换attention layer之后train了一个768维,12层的模型进行实验考察。

得到结果如下:

在这里插入图片描述

而除了最终的ppl之外,文中还比较了transformer与LongNet在处理不同文本长度的文本时所需的计算量。

在这里插入图片描述

可以看到:

  • LongNet可以在更少的计算量下获得相较于原始的transformer更好的ppl。

此外,文中还对LongNet在不同的参数量以及不同的context window进行了一下考察,得到结果如下:

在这里插入图片描述

可以看到:

  • 随着参数量的增长,模型的ppl是在不断减小的,说明LongNet具有很好的扩展能力;
  • context window越大,模型的效果也能够不断地提升,说明LongNet对于长文本有较好的理解能力。

最后,文中还非常直观的给出了将输入文本长度扩展到1B之后vanilla transformer与LongNet的infer时间变化的比较:

在这里插入图片描述

其结果直观地证明了LongNet对于长文本处理能力的能力,较之Vanilla Transformer耗时的快速增长,Dilated Attention基本没有发生什么太大的变化。

4. 结论 & 思考

综上,整体而言这篇文章还是很惊艳的,至少从context length的角度来说这种突破性的震撼确实厉害,结合他之前的foundation transformer等工作,我觉得他们在transformer的基础架构上面确实花了不少的功夫来做优化,这一点确实是厉害。

不过考虑到工程上,这篇文章的主要贡献可能还是在于长文本的关联attention上面,也就意味着其优势必然还是需要长上下文+大语料的前提下才能充分发挥出它的效果,就目前我的工作而言,可能还是有点用不太到……

所以,就只能膜拜一下大佬了,后面有机会的话可以考虑一下在业余时间复现一下看看了,在工作上倒是觉得ROI应该是不会很大了……

5. 参考链接

  1. Longformer: 局部Attention和全局attention的混搭

http://www.niftyadmin.cn/n/5153281.html

相关文章

【Linux系统编程】系统用户和权限的操作

目录 一&#xff0c;Linux的用户 1&#xff0c;用户之间的切换 2&#xff0c;超级用户权限的使用 二&#xff0c;Linux的文件权限 1&#xff0c;文件信息的介绍 2&#xff0c;文件权限的修改 3&#xff0c;用户的修改 3-1&#xff0c;拥有者的更改 3-2&#xff0c;所属…

两种MySQL OCP认证应该如何选?

很多同学都找姚远老师说要参加MySQL OCP认证培训&#xff0c;但绝大部分同学并不知道MySQL OCP认证有两种&#xff0c;以MySQL 8.0为例。 一种是管理方向&#xff0c;叫&#xff1a;Oracle Certified Professional, MySQL 8.0 Database Administrator&#xff08;我考试的比较…

【SpringSecurity】简介

SpringSecurity简介 Spring Security 的前身是Acegi Security&#xff0c;在被收纳为Spring 子项目后正式更名为Spring Security。Spring Security目前已经到了6.x&#xff0c;并且加入了原生OAuth2.0框架&#xff0c;支持更加现代化的密码加密方式。可以预见&#xff0c;在Ja…

网络服务退出一个问题的解析

一、问题 在实际开发中遇到一个问题&#xff0c;解决的过程虽然不长&#xff0c;但确实是想得比较多&#xff0c;总结一下&#xff0c;以供参考。这是一个网络通信的服务端而且使用的是别人封装好的库&#xff0c;通信等都没有问题&#xff0c;但在退出时会报一个错误&#xf…

一个JS版寻路的实现

js版的寻路的测试 20231104_161146 path get_v8: function (x_inc, y_inc) {if (x_inc 0) {if (y_inc < 0) {return [[0, -1], [-1, -1], [1, -1], [-1, 0], [1, 0], [-1, 1], [1, 1], [0, 1]];} else if (y_inc > 0) {return [[0, 1], [-1, 1], [1, 1], [-1, 0], [1, 0…

OpenFeign 的超时重试机制以及底层实现原理

目录 1. 什么是 OpenFeign&#xff1f; 2. OpenFeign 的功能升级 3. OpenFeign 内置的超时重试机制 3.1 配置超时重试 3.2 覆盖 Retryer 对象 4. 自定义超时重试机制 4.1 为什么需要自定义超时重试机制 4.2 如何自定义超时重试机制 5. OpenFeign 超时重试的底层原理 5…

“一键批量拆分HTML文本,高效整理文件,提升工作效率“

您是否曾经被大量的HTML文本文件困扰&#xff0c;难以找到所需的特定信息&#xff1f;现在&#xff0c;我们向您推荐一款强大的工具&#xff0c;它能够一键拆分HTML文本&#xff0c;让您轻松实现文件整理&#xff0c;提高工作效率&#xff01; 首先&#xff0c;在首助编辑高手…

C现代方法(第18章)笔记——声明

文章目录 第18章 声明18.1 声明的语法18.2 存储类型18.2.1 变量的性质18.2.2 auto存储类型18.2.3 static存储类型18.2.4 extern存储类型18.2.5 register存储类型18.2.6 函数的存储类型18.2.7 小结 18.3 类型限定符18.4 声明符18.4.1 解释复杂声明18.4.2 使用类型定义来简化声明…