Attention Free Transformer(AFT)

news/2024/7/19 9:22:48 标签: transformer, 深度学习, 人工智能

Attention Free Transformer(AFT)

paper: An Attention Free Transformer

date: 2021-05

org: Apple

1 Motivation

原本基于dot product self attention Transformer的时间复杂度和空间复杂度都很高。提出了一个新的AFT层来降低transformer的计算量。

在这里插入图片描述

2 Method

2.1 Multi-Head Attention回顾

首先回顾一下经典的Multi-Head Attention(MHA),每一个head的计算如下

f i ( X ) = σ ( Q i ( K i ) T d k ) V i ,   s . t .   Q i = X W i Q , K i = X W i K , V i = X W i V , (1) f _ { i } ( X ) = \sigma ( \frac { Q _ { i } ( K _ { i } ) ^ { T } } { \sqrt { d _ { k } } } ) V _ { i } , \ \mathrm { s . t . } \ Q _ { i } = X W _ { i } ^ { Q } , K _ { i } = X W _ { i } ^ { K } , V _ { i } = X W _ { i } ^ { V } , \tag{1} fi(X)=σ(dk Qi(Ki)T)Vi, s.t. Qi=XWiQ,Ki=XWiK,Vi=XWiV,(1)

其中: W i Q    ∈    R d × d k , W i K    ∈    R d × d k , W i V    ∈    R d × d υ W _ { i } ^ { Q } \; \in \; R ^ { d \times d _ { k } } , W _ { i } ^ { K } \; \in \; R ^ { d \times d _ { k } } , W _ { i } ^ { V } \; \in \; R ^ { d \times d _ { \upsilon } } WiQRd×dk,WiKRd×dk,WiVRd×dυ σ \sigma σ是非线性函数,默认为 s o f t m a x softmax softmax。通常情况下 d v = d k , h = d d k d_v = d_k, h = \frac{d}{d_k} dv=dk,h=dkd。假定输入 X ∈ R T × d X \in \mathbb {R}^ {T \times d} XRT×d, 经过 f i f_i fi转化后的输出 f i ( X ) ∈ R T × d v f_i{(X)} \in \mathbb{R} ^{T \times d_v} fi(X)RT×dv。将所有head的结果拼接起来得到最后的输出 R T × d \mathbb{R} ^{T \times d} RT×d

单头Attention的时间复杂度计算:

  • Q K V QKV QKV 的计算,此处有3个矩阵乘法,计算量为 d × d k × T × 3 d \times d_k \times T \times 3 d×dk×T×3, 时间复杂度为: O ( 1 h T d 2 ) \mathcal{O}(\frac{1}{h}Td^2) O(h1Td2)
  • Q K T QK^T QKT的计算,计算量为: d k × T × T d_k \times T \times T dk×T×T, 时间复杂度为: O ( 1 h T 2 d ) \mathcal{O}(\frac{1}{h}T^2d) O(h1T2d)
  • scale 的计算量为: T × T T \times T T×T, 时间复杂度为: O ( T 2 ) \mathcal{O}(T^2) O(T2)
  • softmax的计算量为: T × T T \times T T×T, 时间复杂度为: O ( T 2 ) \mathcal{O}(T^2) O(T2)
  • 最后加权乘法计算量为 d k × T × T d_k \times T \times T dk×T×T,时间复杂度为: O ( 1 h T 2 d ) \mathcal{O}(\frac{1}{h}T^2d) O(h1T2d)

对于MHA,时间复杂度为 O ( T d 2 ) \mathcal{O}(Td^2) O(Td2)

2.2 Attention Free Transofrmer(AFT)

2.2.1 AFT full

第一步和MHA一样,输入 X X X经过三个linear transfer得到 Q K V QKV QKV,3个矩阵, 维度为 R T × d \mathbb{R}^{T \times d} RT×d。AFT引入了一个新的可训练参数矩阵 w ∈ R T × T w \in \mathbb{R}^{T \times T} wRT×T,论文将其称之为可学习的一对一位置偏置(learned pair-wise position biases)。

在这里插入图片描述

我们以 y t y_t yt 为视角看每一步的具体流程。

SETP1: w e i g h t e d ( K ( t ) ) \mathrm{weighted}(K^{(t)}) weighted(K(t))。从 w w w t = t t=t t=t的向量, 和 K K K做点乘后以列方向进行 s o f t m a x \mathrm{softmax} softmax。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)

W e i g h t e d ( K ( t ) ) = exp ⁡ ( K + w t ) ∑ i = 1 T exp ⁡ ( k i + w t i ) (2) \mathrm{Weighted}(K^{(t)}) = \frac{\exp (K + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{2} Weighted(K(t))=i=1Texp(ki+wti)exp(K+wt)(2)

在这里插入图片描述

STEP2: 求 A t t e n t i o n ( t ) \mathrm{Attention}^{(t)} Attention(t)矩阵。将q_t用sigmoid变换后,点乘wighted(K)。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)

A t t e n t i o n ( t ) = σ ( q t ) ⊙ W e i g h t e d ( K ( t ) ) = σ ( q t ) ⊙ exp ⁡ ( K + w t ) ∑ i = 1 T exp ⁡ ( k i + w t i ) (3) \mathrm{Attention^{(t)}} = \sigma(q_t) \odot \mathrm{Weighted}(K^{(t)})= \frac{\sigma(q_t) \odot \exp (K + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{3} Attention(t)=σ(qt)Weighted(K(t))=i=1Texp(ki+wti)σ(qt)exp(K+wt)(3)
在这里插入图片描述

STEP3: 计算 y t y_t yt。该步骤的计算复杂度为 O ( T × d ) \mathcal{O}(T \times d) O(T×d)

y t = ∑ i = 1 T ( A t t e n t i o n ( t ) i ⊙ v i ) = ∑ i = 1 T σ ( q t ) ⊙ exp ⁡ ( k i + w t ) ∑ i = 1 T exp ⁡ ( k i + w t i ) ⊙ v i (4) y_t = \sum_{i=1}^{T}(\mathrm{Attention^{(t)}}_i \odot v_i) = \sum_{i=1}^{T} \frac{\sigma(q_t) \odot \exp (k_i + w_t ) }{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \odot v_i \tag{4} yt=i=1T(Attention(t)ivi)=i=1Ti=1Texp(ki+wti)σ(qt)exp(ki+wt)vi(4)

在这里插入图片描述

对式(4)稍做变形,可得论文中的计算公式

y t = σ ( q t ) ⊙ ∑ i = 1 T exp ⁡ ( k i + w t ) ⊙ v i ∑ i = 1 T exp ⁡ ( k i + w t i ) (5) y_t = \sigma(q_t)\odot \frac{ \sum_{i=1}^{T}\exp (k_i + w_t ) \odot v_i}{\sum_{i=1}^{T} \exp (k_i + w_{ti}) } \tag{5} yt=σ(qt)i=1Texp(ki+wti)i=1Texp(ki+wt)vi(5)

将所有的步骤串起来的流程如下。可以看到AFT其实也用到了attention的思想。但AFT中的Attention Score的计算并没有用到矩阵乘法,只用到了向量点乘。虽整体的计算复杂度仍然是 O ( T 2 d ) \mathcal{O}(T^2d) O(T2d),但计算量已有所下降。

式(4)计算pipeline

在这里插入图片描述

式(5)计算pipeline

在这里插入图片描述

2.2.1 AFT local

在许多情况下,局部性是一个很重要的归纳偏置(inductive bias),而标准的Transformer的计算中没有引入局部信息。因此,作者提出AFT-local。其形式与AFT-Full一致。区别在于,引入了下式限制

w t , t ′ = { w t , t ′ , i f ∣ t − t ′ ∣ < s 0 , o t h e r w i s e . (6) w_{t, t'} = \begin{cases} w_{t, t'}, \quad \mathrm{if} |t - t'| < s \\ 0, \quad \mathrm{otherwise.}\end{cases} \tag{6} wt,t={wt,t,iftt<s0,otherwise.(6)

式中的 s s s就是定义的局部窗口大小(local window size)。它进一步降低了计算量。变换后的 w w w如下图所示(此时 s = 2 s=2 s=2, 黑色方块为0)。

在这里插入图片描述

2.2.2 AFT simple

AFT simple是AFT local当 s = 0 s = 0 s=0时的特殊形式。此时没有位置偏置。可将式5化简为,因为对不同的 t t t ∑ i = 1 T ( s o f t m a x ( K ) ⊙ V ) i \sum_{i=1}^{T} (\mathrm{softmax}(K) \odot V)_{i} i=1T(softmax(K)V)i都是相同的。AFT simple的时间复杂度为 O ( T d ) \mathcal{O}(Td) O(Td)

y t = σ ( q t ) ⊙ ∑ i = 1 T exp ⁡ ( k i ) ⊙ v i ∑ i = 1 T exp ⁡ ( k i ) = σ ( q t ) ⊙ ∑ i = 1 T ( s o f t m a x ( K ) ⊙ V ) i (6) y_t = \sigma(q_t)\odot \frac{ \sum_{i=1}^{T}\exp (k_i) \odot v_i}{\sum_{i=1}^{T} \exp (k_i) } = \sigma(q_t)\odot \sum_{i=1}^{T} (\mathrm{softmax}(K) \odot V)_{i}\tag{6} yt=σ(qt)i=1Texp(ki)i=1Texp(ki)vi=σ(qt)i=1T(softmax(K)V)i(6)

2.2.3 AFT conv

作者进一步将局部性的思想扩展到空间权重共享(如卷积),提出AFT-conv。具体来说,让 w t , t ′ w_{t,t'} wt,t的值仅依赖 t t t t ′ t' t的相对位置。为了考虑参数数量随着 h e a d head head数增加而增长的情况,作者采用了一个设计选择,将 K K K的维度与head数绑定在一起(MHA的思路)。这使得AFT-conv可以采用深度可分离卷积、全局池化和element-wise操作的实现方式。

可以看到与AFT simple相比,AFT conv引入了head思想,并通过1维卷积的计算结果引入局部信息。其形式与式(6)相比分子分母中新增了 c o n v 1 d ( exp ⁡ ( K j ) ⊙ V j ,    exp ⁡ ( w j )   − 1 ) \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) \odot V ^ { j } , \; \exp ( w ^ { j } ) \, - 1 ) conv1d(exp(Kj)Vj,exp(wj)1) c o n v 1 d ( exp ⁡ ( K j ) ,    exp ⁡ ( w j )    − 1 ) \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) , \; \exp ( w ^ { j } ) \; - 1 ) conv1d(exp(Kj),exp(wj)1)。(上标 j j j表示第 j j j个head)。此时的 w w w为conv1d的filter。

y t j = σ q ( q t j ) ⊙ c o n v 1 d ( exp ⁡ ( K j ) ⊙ V j ,    exp ⁡ ( w j )   − 1 ) + ∑ i = 1 T exp ⁡ ( k i j ) ⊙ v i j c o n v 1 d ( exp ⁡ ( K j ) ,    exp ⁡ ( w j )    − 1 ) + ∑ i = 1 T exp ⁡ ( k i j ) (7) y _ { t } ^ { j } = \sigma _ { q } ( q _ { t } ^ { j } ) \odot \frac { \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) \odot V ^ { j } , \; \exp ( w ^ { j } ) \, - 1 ) + \sum _ { i = 1 } ^ { T } \exp ( k _ { i } ^ { j } ) \odot v _ { i } ^ { j } } { \mathrm { c o n v 1 d } ( \exp ( K ^ { j } ) , \; \exp ( w ^ { j } ) \; - 1 ) + \sum _ { i = 1 } ^ { T } \exp ( k _ {i } ^ { j } ) } \tag{7} ytj=σq(qtj)conv1d(exp(Kj),exp(wj)1)+i=1Texp(kij)conv1d(exp(Kj)Vj,exp(wj)1)+i=1Texp(kij)vij(7)

从ViT可视化attention map中可以看出(横轴为head, 纵轴为layer)。原本的ViT(左边)的不同层,head的attention map的响应最大区域基本都是中心区域。而用了AFT-conv后,不同层、head的attention都有所不同,有助于模型捕获不同尺度的特征。

在这里插入图片描述

3 小结

本文提出了一种Dot Product Attention Free的Transformer,最多能将transofmer的时间复杂度从 O ( T 2 d ) \mathcal{O}(T^2d) O(T2d)降低到 O ( T d ) \mathcal{O}(Td) O(Td)(AFT-simple)。


http://www.niftyadmin.cn/n/5035959.html

相关文章

时间序列预测系列之循环神经网络

文章目录 1.前言2.RNN基础组件1.RNN2.LSTM3.GRU4.FC-LSTM5.ConvLSTM6.CNN-LSTM 1.前言 循环神经网络&#xff08;Recurrent Neural Network&#xff0c;简称RNN&#xff09;是一类在处理序列数据和时间序列数据时非常有用的神经网络架构。RNN的主要特点是它们具有循环连接&…

TimeUnit 时间颗粒度和延时的使用 (demo)

1. TimeUnit介绍 TimeUnit是JDK封装好的java.util.concurrent包下面的一个类&#xff0c;表示给定单元粒度的时间段 import java.util.concurrent.TimeUnit; 2. TimeUnit作用时间颗粒度转换 线程延时 3. 常用的颗粒度 TimeUnit.DAYS //天TimeUnit.HOURS //小…

【MySQL】ibd2sdi工具介绍和使用

文章目录 【MySQL】ibd2sdi工具介绍ibd2sdi的解析对象ibd2sdi的选项参数ibd2sdi的输出例参考 【免责声明】文章仅供学习交流&#xff0c;观点代表个人&#xff0c;与任何公司无关。 编辑|SQL和数据库技术(ID:SQLplusDB) 【MySQL】ibd2sdi工具介绍 MySQL提供了叫做ibd2sdi的实用…

程序员必掌握的核心算法:提升编程技能的关键路径

一&#xff1a;引言 作为程序员&#xff0c;算法是我们编程生涯中的灵魂。算法是解决问题的方法和步骤&#xff0c;它们在计算机科学中扮演着至关重要的角色。无论你是初学者还是经验丰富的专业人士&#xff0c;都需要掌握一些核心算法&#xff0c;因为它们在各种应用场景中频…

smtp邮箱有啥用?smtp协议是什么?

SMTP&#xff08;Simple Mail Transfer Protocol&#xff09;是邮件传输的标准协议&#xff0c;是互联网邮件的基石。SMTP协议主要用于从一个邮件客户端发送邮件到一个邮件服务器。蜂邮今天就带大家了解一下。 SMTP邮箱是基于SMTP协议的邮件服务&#xff0c;它使得用户可以使用…

Java - RSA 不限制长度加解密算法,你就只知道个分段法?

问题描述 java javax.crypto.IllegalBlockSizeException: Data must not be longer than XXX bytes 今天发现用网上千篇一律的 RSA 加解密算法&#xff0c;待加密数据一旦比较大&#xff0c;就会报以上错误。查了网上一些解决方案——分段法&#xff0c;看起来实在有点繁琐&…

SQL server中merge语句添加where条件

SQL server中merge语句添加where条件 1.merge语句添加where条件 在SQL Server中&#xff0c;可以使用MERGE语句将INSERT、UPDATE和DELETE操作组合在一起&#xff0c;根据指定的条件将数据合并到目标表中。如果想在MERGE语句中添加WHERE条件&#xff0c;可以按照以下格式编写语…

记录每日LeetCode 198.打家劫舍 Java实现

题目描述: 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金&#xff0c;影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统&#xff0c;如果两间相邻的房屋在同一晚上被小偷闯入&#xff0c;系统会自动报警。 给定一个代表每个房屋…