论文精读Transformer: Attention is all you need

news/2024/7/19 11:26:00 标签: transformer, 深度学习, 人工智能

  • 1 基础背景
  • 2 Motivation
  • 3 解决思路
    • 3.1 Encoder
    • 3.2 Decoder
  • 4 复杂度分析
  • 5 结果
  • 6 知识补充
  • 7 评价

1 基础背景

由Google机器翻译Google Brain团队发表。
论文链接:https://arxiv.org/abs/1706.03762
源码链接:https://github.com/tensorflow/tensor2tensor
一句话概述:Attention机制摒弃了循环神经网络和卷积神经网络,一经提出就在机器翻译领域达到了SOTA。

2 Motivation

LSTM和CNN存在如下问题。
LSTM是循环执行的,依赖前一项隐藏层信息h(t-1),天然不适合并行,而且受限于隐藏层大小,更容易丢失以前的信息。
CNN需要很多层才能把距离很远的像素或者输入联系起来,输入长度和计算复杂度成线性关系,导致(1)难以学习远距离的输入;(2)网络结构更加复杂;Transformer能够一次将所有的像素或者输入联系起来。
训练时间都很长。Transformer因为不依赖之前的信息,所以可以进行并行可算,减少计算时间。

3 解决思路

其网络结构如下所示。

3.1 Encoder

  1. Input Embedding
    这一步将一个单词src变成一个向量。
  2. Positional Encoding
    因为Transformer本身并不关注词源和词源之间的位置关系,每个词源会和其他词源做注意力机制,所以缺少对词源在句子中的位置的信息的捕获,在这一步通过加入位置信息,来引导Transformer理解词源的位置。
  3. A stack of 6 identical layers(MultiHead + Add&Norm + FeedForward + Add&Norm)
    (1)Attention
    Attention一般分为加和additive注意力和点积dot product注意力。
    本文使用的是Scaled dot product attention,是对【QKT】除以【根号下向量长度】
    假设存在n个词源,Q的大小为 n x dk,K的大小为 n x dk,得到 n x n 的矩阵,再进行缩放。
    缩放的原因是:在向量长度过长后,向量之间差异较大(部分维度的值大,部分小),算出来的值差异大,经过softmax之后会分别靠向0和1,除以dk之后差异变小,就一定程度上避免向两端靠拢。
    注意:softmax是针对每个词源的,而不是所有放一起做。
    softmax之后的矩阵维度是 n x n,可以看做一个权重矩阵,再乘以V。
    本文采用了Multi-head attention
    在获取Q/K/V矩阵之后,将其分为h个head,每个head i有自己的Wqi/Wki/Wvi,分别乘以QKV,得到每个head对应的Qi/Ki/Vi,其维度是原本的1/h,这个head相当于创建了subspace,进行更丰富的依赖关系学习,生成的向量di再拼起来,再经过一次线性变换后得到完整的输出。
    (2)Add
    借鉴了Resnet网络,将输入直接加到输出中,可以避免梯度消失或梯度爆炸。
    (3)Norm
    这一层Norm采用的是LayerNorm,不是Batchnorm。
    (4)Feedforward
    feedforward层 = linear(512-2048) + relu + linear(2048-512),因为是针对每个词源的,所以叫position-wise。因为之前一步的注意力已经获取了词源和词源之间的联系,所以这一步只需要针对单个词源就可以。

3.2 Decoder

  1. A stack of 6 identical layers(MultiHead + Add&Norm + FeedForward + Add&Norm)
    Encoder的输出同时作为Key和Value,即这两个矩阵是相同的。
    (1)Masked multi-head attention
    因为在实际使用时,是使用的自回归auto-regression,所以前一个词预测出来之后,才会预测后一个词。
    因此,在输入时,也应该到第t个词的时候,掩盖掉t+1之后的词,具体方式是计算完QKT之后将其设置为负无穷大,那么softmax之后就变成了0,相当于0权重。
    masking out (setting to −∞) all values in the input of the softmax

4 复杂度分析

复杂度如下表
(1)当句子长度n小于token维度d时,self-attention的复杂度会低于RNN;
(2)卷积神经网络核大小k一般小于n,但是需要很多层卷积才能联系起来;
(3)如果句子长度太大了,可以限制每个词源的搜索范围,只限定和附近r个做注意力机制。

5 结果

在英语转德语的比赛和英语转法语的比赛中,均达到了SOTA。

6 知识补充

batchnormlayernorm
相同点都是针对某部分输入进行标准化
不同点针对某一个feature,对所有样本进行标准化针对某一个样本,对所有feature进行标准化

7 评价

1.相比于RNN,Transformer更少关注之前的序列信息,导致序列信息capture sequential information获取更难。
2.位置信息使用偏少lack of explicit modeling of positional information。这导致难以获取长范围的信息may not capture long-range dependencies。
3.难以获取局部的模式difficulty in capturing local patterns。因为CNN通过卷积和池化能够把局部信息联系起来,而Transformer更容易忽略这些信息。
4.Transformer的时间复杂度与序列长度成平方关系,而RNN和CNN与序列长度成正比,所以在数据集太大的时候,计算复杂度会更高computational complexity。


http://www.niftyadmin.cn/n/5045037.html

相关文章

WPF 的一些坑

关于 本文记录开发程序以来 WPF 的一些坑 Background 当你使用下面的代码时&#xff0c;会发现 DragOver 不起作用 <GridGrid.Column"2"AllowDrop"True"DragLeave"OnDragFileLeave"DragOver"OnDragFileOver"Drop"OnDragFi…

PostgreSQL 查询某个属性相同内容出现的次数

查询某个数据库表属性 name 相同内容出现出现的次数&#xff0c;并按次数从大到小排序 SELECT name, COUNT(*) AS count FROM your_table GROUP BY name ORDER BY count DESC;示例 select project_id, COUNT(*) AS count from app_ads_positions group by project_id order b…

【送面试题】深入解析Cookie和Session的请求区别及使用场景

AI绘画关于SD,MJ,GPT,SDXL百科全书 面试题分享点我直达 2023Python面试题 2023最新面试合集链接 2023大厂面试题PDF 面试题PDF版本 java、python面试题 项目实战:AI文本 OCR识别最佳实践 AI Gamma一键生成PPT工具直达链接 玩转cloud Studio 在线编码神器 玩转 GPU AI…

学历不高,为何我还要转行编程?这个行业的秘密你知道吗?

Python 技能变现之路&#xff1a;掌握这些赚钱思路&#xff0c;开启财富之门 互联网行业&#xff1a;是走下坡路还是瘦死的骆驼比马大&#xff1f;看看这个你就知道了&#xff01; 招聘寒冬中&#xff0c;Python 程序员如何突出重围&#xff1f; 【文末福利】这样写简历&…

关于ES5内置函数Object的新方法--Object.create()

在今天的学习中&#xff0c;更深层次的了解了一下ES5中内置函数Object的新方法Object.create(),觉得这个api功能真的十分强大&#xff0c;并且使用的场景也有很多&#xff0c;现和各位同行们分享学习成果&#xff0c;欢迎各位大佬们指正&#xff0c;废话不多说&#xff0c;开整…

协议的定义

协议是网络通信实体之间在数据交换过程中需要遵循的规则或约定&#xff0c;是计算机网络有序运行的重要保证。 任何一个协议都会显式或隐式地定义3个基本要素&#xff1a;语法、语义和时序&#xff0c;称为协议三要素。 语法&#xff1a;语法定义实体之间交换信息的格式与结…

【draw】draw.io怎么设置默认字体大小

默认情况下draw里面的字体大小是12pt&#xff0c;我们可以将字体改成我们想要的大小&#xff0c;例如18pt。 然后点击样式&#xff0c;设置为默认样式。 下一次当我们使用文字大小时就是18pt了。

Postman全局配置变量token

Postman全局配置变量token 这里主要是介绍在 Postman 中全局配置token&#xff0c;以及方便以后查阅&#xff01;&#xff01;&#xff01; 一、简介 用户在开发或调试网络程序和网页B/S模式的程序时需要一些方法来跟踪网页请求&#xff0c;可使用一些网络的监视工具如Firebu…