Transformer位置编码详解

news/2024/7/19 10:39:13 标签: transformer, 深度学习, 人工智能

在处理自然语言时候,因Transformer是基于注意力机制,不像RNN有词位置顺序信息,故需要加入词的位置信息来显示的表明词的上下文关系。具体是将词经过位置编码(positional encoding),然后与emb词向量求和,作为编码块(Encoder block)的输入信息。在《Attention Is All You Need》论文中,位置编码信息如下:

其中PE的维度为:[序列长度,编码维度](即[seq_len,emb_dim])

pos表示词语在句子中的位置

d_{model}  表示编码(emb)的维度

i表示词向量的位置,偶数位置用sin,奇数位置用cos

据此,即可根据不同的pos信息和i信息得到不同的位置嵌入信息。具体计算时候,由于sin和cos后半部分相同,采用log将次方拿下,方便计算。

具体Pytorch代码实现如下

# coding=utf8

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import math

def get_position_encoding(seq_len, emb_dim):
    pe = torch.zeros(seq_len, emb_dim)
    pos = torch.arange(0, seq_len, dtype=torch.float)
    pos = pos.unsqueeze(1)
    locpos = torch.arange(0, emb_dim, 2).float()
    div_term = torch.exp(locpos * (-math.log(10000.0) / emb_dim))    #对应上面公式最后一行
    # 第一维度序列长度,第二维度编码
    pe[:, 0::2] = torch.sin(pos * div_term)
    pe[:, 1::2] = torch.cos(pos * div_term)
    return pe
pe = get_position_encoding(100, 32)
sns.heatmap(pe)
plt.xlabel('emb_dim')
plt.ylabel('seq_len')
plt.show()

生成图如下:

补充知识点:

切片,位置编码赋值:

def clip_pos(x):
    xdata = torch.arange(1, x, 1)
    print("###xdata:###", xdata)
    """
    切片的语法使用冒号(:)来表示,形式为`[start:end:step]`,其中start表示起始索引(包含),end表示结束索引(不包含),step表示步长(默认为1)。
    如果省略start,则默认从序列的第一个元素开始
    如果省略end,则默认截取到序列的最后一个元素 
    如果省略step,则默认以步长为1进行截取
    """
    # print(xdata[0::2])
    print(xdata[0:4:2])
    print(xdata[1::2])


http://www.niftyadmin.cn/n/5472323.html

相关文章

【Linux】Ubuntu 压缩与解压缩

首先在Windows下安装7Zip压缩软件,以便于可以生成 .tar 和 .bz2 的压缩格式的文件。例如新建一个test文件夹,操作后如下。 gzip 压缩工具:负责 .gz 格式的文件的压缩和解压缩,gzip --help 查看使用帮助; 压缩文件&…

HTML1:html基础

HTML 冯诺依曼体系结构 运算器 控制器 存储器 输入设备 输出设备 c/s(client客户端) 客户端架构软件 需要安装,更新麻烦,不跨平台 b/s(browser浏览器) 网页架构软件 无需安装,无需更新,可跨平台 浏览器 浏览器内核: 处理浏览器得到的各种资源 网页: 结构 HTML(超…

【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略

【LangChain学习之旅】—(19)BabyAGI:根据气候变化自动制定鲜花存储策略 AutoGPTBaby AGIHuggingGPTLangChain 目前是将基于 CAMEL 框架的代理定义为 Simulation Agents(模拟代理)。这种代理在模拟环境中进行角色扮演,试图模拟特定场景或行为,而不是在真实世界中完成具体…

【React】React18+Typescript+craco配置最小化批量引入Svg并应用的组件

React18Typescriptcraco配置最小化批量引入Svg并应用的组件 前言创建React Typescript项目通过require.context实现批量引入Svg安装[types/webpack-env](https://github.com/DefinitelyTyped/DefinitelyTyped/blob/master/README.zh-Hans.md)解决类型报错安装[craco](https://…

Android 关于apk反编译d2j-dex2jar classes.dex失败的几种方法

目录 确认路径正确直接定位到指定目录确定目录正确,按如下路径修改下面是未找到相关文件正确操作 确认路径正确 ,即d2j-dex2jar和classes.dex是否都在一个文件夹里(大部分的情况都是路径不正确) 直接定位到指定目录 路径正确的…

数据结构入门系列-栈的结构及栈的实现

🌈个人主页:羽晨同学 💫个人格言:“成为自己未来的主人~” 栈 栈的概念及结构 栈:一种特殊的线性表,其只允许在固定的一段进行插入和删除元素操作,进行数据输入和删除操作的一端称为栈顶,另…

海外仓的出入库流程有什么痛点?位像素海外仓系统怎么提高出入库效率?

随着跨境电商的蓬勃发展,海外仓是其中不可或缺的一个关键环节。而货物的出库与入库则是海外仓管理中的一个核心业务流程,它的运作效率直接影响到整个跨境物流的效率和客户体验。今天,让我们具体来看一看关于海外仓出入库的流程,其…

【随笔】Git 高级篇 -- 撤销变更(十四)

💌 所属专栏:【Git】 😀 作  者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! 💖 欢迎大…