Transformer-XL模型简单介绍

news/2024/7/19 12:08:42 标签: transformer, 深度学习, 自然语言处理

目录

一、前言

二、整体概要

三、细节描述

 3.1 状态复用的块级别循环

3.2 相对位置编码

四、论文链接


一、前言

以自注意力机制为核心的 Transformer 模型是各种预训练语言模型中的主要组成部分。自注意力机制能够构建序列中各个元素之间的上下文关联程度,挖掘深层次的语义信息。然而,自注意力机制的时空复杂度为O(n^{2}),即时间和空间消耗会随着输入序列的长度呈平方级增长。这种问题的存在使得预训练语言模型处理长文本的效率较低。

传统处理长文本的方法一般是切分输入文本,其中每份的大小设置为预训练语言模型能够单次处理的最大长度(如512)。 最终将多片文本的决策结果进行综合(如对分类结果进行投票)或者拼接(如序列标注或生成任务)得到最终结果 。然而,这种方法不能很好地构建文本块之间的联系,挖掘长距离文本依赖的能力较弱。因此,更好的方法还是需要从根本上提高预训练语言模型单次能够处理的最大文本长度,从而能够更加充分地利用自注意力机制。针对这一挑战,Transformer-XL模型给出了解决方法。

二、整体概要

前面介绍到,Transformer 中处理长文本的传统策略是将文本切分成固定长度的块,并单独编码每个块,块与块之间没有信息交互。下图给出了块长度为4的一个示例。可以看到在训练阶段,Transformer分别对第一块中的序列x1 、x 2 、x 3 、x 4 与第二块中的序列x 5 、x 6 、x 7 、x 8 进行建模。而在测试阶段,由于每次处理的最大长度为4,当模型在处理序列x2 、x 3 、x 4 、x 5 时,无法构建与历史x 1 的关系。另外,由于需要以滑动窗口的方式处理整个序列,所以这种方法的效率也非常低
 

 

为了优化对长文本的建模,Transformer-XL提出了两种改进策略——状态复用的块级别循环(Segment-level Recurrence with State Reuse)和相对位置编码(Relative Positional Encodings)。接下来针对这两种改进策略进行介绍。

三、细节描述

 3.1 状态复用的块级别循环

假设两个连续的长度为n的块分别为s_{\tau} =x_{\tau,1} \cdot \cdot \cdot x_{\tau,n}s_{\tau+1} =x_{\tau+1,1} \cdot \cdot \cdot x_{\tau+1,n},第τ 块在第l 层Transformer的隐含层输出为 h_{\tau }^{[l]}\in \mathbb{R}^{n\times d}(d为隐含层维度大小)。计算第τ+1块在第l层 Transformer的隐含层输出 h_{\tau+1 }^{[l]}
 

式中,函数SG(·)表示停止梯度传输;操作符  ◦  表示沿长度维度进行拼接; W 表示全连接权重。与传统Transformer的主要不同点在 于,键k_{\tau+1 }^{[l]} 和值v_{\tau+1 }^{[l]} 依赖于扩展的上下文信息h_{\tau+1 }^{[l-1]} 以及上一个块h_{\tau }^{[l-1]} 的缓存信息
 
这种状态复用的块级别循环机制应用于语料库中每两个连续的片段,本质上是在隐含状态下产生一个片段级的循环。因此,在这种机制下,Transformer利用的有效上下文可以远远超出两个块。需要注意的是, h_{\tau+1 }^{[l]}h_{\tau}^{[l-1]}之间的循环依赖性使得存在向下一层的计算依赖,这与传统的循环神经网络(RNN)中的同层循环机制(即只存在相同层之间的循环)是不同的。因此,最大可能的依赖长度随块的长度n和层数L呈线性增长(与开头提到的平方级增长形成对比),即O(nL) ,如下图(b)中的阴影部分所示。这种机制和RNN中常用的随时间反向传播机制(Back Propagation Through Time,BPTT)类似。然而,在这里是将整个序列的隐含层状态全部缓存,而不是像BPTT机制中只会保留最后一个状态。
 

另外,这种设计除了能够处理更长的文本序列,还能加快测试速度。作者通过一系列的实验表明,Transformer-XL相比传统Transformer,能够在测试阶段达到1800倍以上的加速。

3.2 相对位置编码

虽然状态复用的块级别循环技术能够将不同块之间的信息联系起来,但在实际应用中还存在一个非常重要的问题:如何区分不同块中的相同位置(如第\tau块和第\tau +1块中的第二个位置)?采用传统Transformer中的绝对位置编码方法是不可行的,其原因可通过下式说明:

式中, v_{\tau }\in \mathbb{R}^{n\times d}表示块s_{\tau } 对应的词向量;v^{p}表示位置向量;f 表示变换函数。 可以看到对于不同的块,使用的位置向量是一样的。例如,对于第  τ 块中的x_{\tau,i} 和第 τ+1 块中的x_{\tau+1,i} 的位置信息是完全相同的,而这 显然是不合理的。
 
为了解决这个问题,Transformer-XL引入了 相对位置编码 策略。位置信息的重要性主要体现在注意力矩阵的计算上,用于构建不同词之间的关联关系。应用相对位置编码后,第i个词与第j个词的注意力值a_{i,j}为:

 

 

式中,W 和u\in \mathbb{R}^{d}表示可训练的权重; v_{x_{i}}表示词xi对应的词向量;R\in \mathbb{R}^{N\times d}表示相对位置矩阵(N表示最大编码长度),是一个不可训练的正弦编码矩阵,其第i行表示相对位置间隔为i的位置向量。接下来针对上式中的各个部分进行介绍。

基于内容的相关度(a):计算查询xi与键xj的内容之间关联信息;

内容相关的位置偏置(b):计算查询xi的内容与键xj的位置编码之间的关联信息,R_{i-j}表示两者的相对位置信息,表示取R中的第i−j行;

全局内容偏置(c):计算查询xi的位置编码与键xj的内容之间的关联信息;

全局位置偏置(d):计算查询xi与键xj的位置编码之间关联信息。

想深入学习的读者可以参考下方论文链接了解更多细节部分,同时模型代码也一并附加到文章顶部。

四、论文链接

https://arxiv.org/abs/1901.02860


http://www.niftyadmin.cn/n/458029.html

相关文章

SpringBoot 如何使用 @RestControllerAdvice 注解进行 RESTful 异常处理

SpringBoot 如何使用 RestControllerAdvice 注解进行 RESTful 异常处理 在 SpringBoot 应用程序中,RESTful 异常处理是一个非常重要的话题。当 RESTful API 出现异常时,我们需要对异常进行处理,以保证 API 的稳定性和可靠性。SpringBoot 提供…

C++之判断文件是否存在的几种方法

文章目录 1. 方法一:C语言之access2. 方法二:C方法之ifstream3. 方法三:fopen方法4. 方法四:sys中的stat函数方法 1. 方法一:C语言之access 可以使用C语言中unistd.h里的函数access()来判断文件是否存在,…

避免由于数据导入导致 Load average 过载

通过htop发现过载 htop输出的头部理解: - 左上是CPU的使用情况:有多少核就有多少行,每行中0表示负载低、100表示高负载;颜色编码CPU和Memory由什么进程占用(红色——kernel进程,绿色——普通用户进程&…

Windows Command Prompt 中编写的批处理脚本

在 Windows Command Prompt 中编写的批处理脚本。del /s 1.docxs Command Prompt 中编写的批处理脚本。del /s 1.docx 这个批处理脚本的功能是删除名为 "1.docx" 的文件,使用了 "/s" 标志来删除所有子目录中的文件。 如果我有多个文件要删除&a…

Prompt Engineering 面面观

作者:紫气东来 项目地址:https://zhuanlan.zhihu.com/p/632369186 一、概述 提示工程(Prompt Engineering),也称为 In-Context Prompting,是指在不更新模型权重的情况下如何与 LLM 交互以引导其行为以获得…

mitmproxy抓包原理

文章目录 mitmproxy原理详解1 mitmproxy 基本原理2 作为中间代理获取HTTP请求信息2.1 应对显式HTTP请求2.2 应对隐式HTTP请求 3 作为中间代理获取HTTPS请求信息3.1 显式HTTPS请求1) 获取远程主机名2) 处理主题备用名称SAN3) 处理服务器名称指示SNI4) 显式HTTPS请求信息获取整个…

腾讯云服务器ping不通怎么解决?什么原因?

腾讯云服务器ping不通什么原因?ping不通公网IP地址还是域名?新手站长从云服务器公网IP、安全组、Linux系统和Windows操作系统多方面来详细说明腾讯云服务器ping不通的解决方法: 目录 腾讯云服务器ping不通原因分析及解决方法 安全组ICMP协…

<C++> C++11 Lambda表达式

C11 Lambda表达式 1.C98中的一个例子 在C98中&#xff0c;如果想要对一个数据集合中的元素进行排序&#xff0c;可以使用std::sort方法。 #include <algorithm> #include <functional> int main() {int array[] {4, 1, 8, 5, 3, 7, 0, 9, 2, 6};// 默认按照小于…