人工智能-注意力机制之Transformer

news/2024/7/19 10:13:39 标签: transformer, 深度学习

Transformer

比较了卷积神经网络(CNN)、循环神经网络(RNN)和自注意力(self-attention)。值得注意的是,自注意力同时具有并行计算和最短的最大路径长度这两个优势。因此,使用自注意力来设计深度架构是很有吸引力的。对比之前仍然依赖循环神经网络实现输入表示的自注意力模Transformer模型完全基于注意力机制,没有任何卷积层或循环神经网络层。尽管Transformer最初是应用于在文本数据上的序列到序列学习,但现在已经推广到各种现代的深度学习中,例如语言、视觉、语音和强化学习领域。

模型

Transformer作为编码器-解码器架构的一个实例,其整体架构图如下图展示。正如所见到的,Transformer是由编码器和解码器组成的。与基于Bahdanau注意力实现的序列到序列的学习相比,Transformer的编码器和解码器是基于自注意力的模块叠加而成的,源(输入)序列和目标(输出)序列的嵌入(embedding)表示将加上位置编码(positional encoding),再分别输入到编码器和解码器中。

transformer架构 

Transformer解码器也是由多个相同的层叠加而成的,并且层中使用了残差连接和层规范化。除了编码器中描述的两个子层之外,解码器还在这两个子层之间插入了第三个子层,称为编码器-解码器注意力(encoder-decoder attention)层。在编码器-解码器注意力中,查询来自前一个解码器层的输出,而键和值来自整个编码器的输出。在解码器自注意力中,查询、键和值都来自上一个解码器层的输出。但是,解码器中的每个位置只能考虑该位置之前的所有位置。这种掩蔽(masked)注意力保留了自回归(auto-regressive)属性,确保预测仅依赖于已生成的输出词元。

接下来将实现Transformer模型的剩余部分。

import math
import warnings
import pandas as pd
from d2l import paddle as d2l

warnings.filterwarnings("ignore")
import paddle
from paddle import nn

 基于位置的前馈网络

基于位置的前馈网络对序列中的所有位置的表示进行变换时使用的是同一个多层感知机(MLP),这就是称前馈网络是基于位置的(positionwise)的原因。在下面的实现中,输入X的形状(批量大小,时间步数或序列长度,隐单元数或特征维度)将被一个两层的感知机转换成形状为(批量大小,时间步数,ffn_num_outputs)的输出张量。

#@save
class PositionWiseFFN(nn.Module):
    """基于位置的前馈网络"""
    def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super(PositionWiseFFN, self).__init__(**kwargs)
        self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

    def forward(self, X):
        return self.dense2(self.relu(self.dense1(X)))

下面的例子显示,改变张量的最里层维度的尺寸,会改变成基于位置的前馈网络的输出尺寸。因为用同一个多层感知机对所有位置上的输入进行变换,所以当所有这些位置的输入相同时,它们的输出也是相同的。

ffn = PositionWiseFFN(4, 4, 8)
ffn.eval()
ffn(torch.ones((2, 3, 4)))[0]
tensor([[-0.8290,  1.0067,  0.3619,  0.3594, -0.5328,  0.2712,  0.7394,  0.0747],
        [-0.8290,  1.0067,  0.3619,  0.3594, -0.5328,  0.2712,  0.7394,  0.0747],
        [-0.8290,  1.0067,  0.3619,  0.3594, -0.5328,  0.2712,  0.7394,  0.0747]],
       grad_fn=<SelectBackward0>)

 


http://www.niftyadmin.cn/n/5215836.html

相关文章

Java中的泛型是什么?如何使用泛型类和泛型方法?

Java 中的泛型是一种编程机制&#xff0c;允许你编写可以与多种数据类型一起工作的代码&#xff0c;同时提供编译时类型检查以确保类型的安全性。泛型的主要目的是提高代码的可重用性、类型安全性和程序的整体性能。 泛型类&#xff08;Generic Class&#xff09;: 在泛型类中…

【STM32】新建工程

学习来源&#xff1a;[2-2] 新建工程_哔哩哔哩_bilibili 目前STM32的开发主要有基于寄存器的开发方式、基于标准库也就是库函数的方式和基于HAL库的方式。本学习是基于库函数的方式。&#xff08;各种资料去百度云下载&#xff09; 1 建立工程文件夹 Keil中新建工程&#xf…

***利用SecureCRT上传、下载文件(使用sz与rz命令)

使用SecureCrt连接到服务器。 1、上传文件&#xff1a;rz命令 输入“rz”&#xff0c;回车&#xff0c;在弹窗的文件选择框中选择本地磁盘中需要上传的文件&#xff0c;点击【Add】按钮&#xff0c;再点击传输指令即可。 注意&#xff08;如果没有权限不可能成功&#xff0c;…

[递归]排队游戏

例题(15.2)排队游戏 题目描述 在幼儿园中&#xff0c;老师安排小朋友做一个排队的游戏。首先老师精心的把数目相同的小男孩和小女孩编排在一个队列中&#xff0c;每个小孩按其在队列中的位置发给一个编号&#xff08;编号从0开始&#xff09;。然后老师告诉小朋友们&#xff…

GoLang Filepath.Walk遍历优化

原生标准库在文件量过大时效率和内存均表现不好 1400万文件遍历Filepath.Walk 1400万文件重写直接调用windows api并处理细节 结论 1400万文件遍历时对比 对比条目filepath.walkwindows api并触发黑科技运行时间710秒22秒内存占用480M38M 关键代码 //超级快的文件遍历 fun…

SpringBoot使用ObjectMapper之Long和BigDemical类型的属性字符串处理,防止前端丢失数值精度

SpringBoot使用ObjectMapper之Long和BigDemical类型的属性字符串处理&#xff0c;防止前端丢失数值精度! 方式一&#xff1a;注解 使用注解 JsonFormat(shape JsonFormat.Shape.STRING)&#xff0c;如下&#xff1a; import com.fasterxml.jackson.annotation.JsonFormat; …

JavaFX开发调用AWT创建系统托盘MenuItem菜单中文乱码

打开系统托盘MenuItem只能显示英文字符和中文显示方框 解决办法&#xff1a; 打开Edit Configurations… 选择Mofidy options 勾选Add VM options 在VM optios中填入以下代码 -Dfile.encodingGBK

XML Schema中的simpleContent 元素

XML Schema中的simpleContent 元素出现在complexType 内部&#xff0c;是对complexType 的一种扩展、或者限制。 simpleContent 元素在complexType元素内部最多只能出现1次。 simpleContent元素下面必须包含1个restriction或者extension元素。 例如&#xff0c;下面的Schema片…