Transformer的PyTorch实现之若干问题探讨(二)

news/2024/7/19 9:01:14 标签: transformer, pytorch, 深度学习

在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。

1.Transformer中decoder的流程

在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
在这里插入图片描述

上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?

2.Decoder中的self attention mask

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model)
        self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])


    def forward(self, dec_inputs, enc_inputs, enc_outputs):
        '''
        这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的
        dec_inputs: [batch_size, tgt_len]
        enc_inpus: [batch_size, src_len]
        enc_outputs: [batch_size, src_len, d_model]
        '''
        dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量
        dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码
        dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]
        dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵
        # 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为False
        dec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +
                                       dec_self_attn_subsequence_mask), 0).cuda()
        # 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉
        dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]

        for layer in self.layers:
            dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)

        return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]

上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:

dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False]],
        [[False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False],
         [False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],
        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]

可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的
        attn_mask: [batch_size, n_heads, seq_len, seq_len]
        '''
        # 1) 计算注意力分数QK^T/sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k)  # scores: [batch_size, n_heads, len_q, len_k]
        # 2)  进行 mask 和 softmax
        # mask为True的位置会被设为-1e9
        scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9
        attn = nn.Softmax(dim=-1)(scores)  # attn: [batch_size, n_heads, len_q, len_k]
        # 3) 乘V得到最终的加权和
        context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]
        '''
        得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,
        换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列

        返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,
        只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来
        '''
        return context # context: [batch_size, n_heads, len_q, d_v]

其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
在这里插入图片描述

如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。


http://www.niftyadmin.cn/n/5372119.html

相关文章

v-if 和v-show 的区别

第074个 查看专栏目录: VUE ------ element UI 专栏目标 在vue和element UI联合技术栈的操控下,本专栏提供行之有效的源代码示例和信息点介绍,做到灵活运用。 提供vue2的一些基本操作:安装、引用,模板使用,computed&a…

如何评价CS2 ?

看过好多评论,以及结合自身感受,发现这么一个有趣的现象:玩过CS1.5和CS1.6的老玩家对CS2的评价是正向的。上手就玩CSGO的玩家对于CS2评价是负向的。 这就说明玩过更多版本cs的玩家对于此次更新有更多的参考对象,能清楚的知道CS2的…

代码随想录 Leetcode55. 跳跃游戏

题目&#xff1a; 代码(首刷自解 2024年2月9日&#xff09;&#xff1a; class Solution { public:bool canJump(vector<int>& nums) {int noz 0;for (int i nums.size() - 2; i > 0; --i) {if (nums[i] 0) {noz;continue;} else {if (nums[i] > noz) noz …

B站弹幕分析系统

视频展示&#xff0c;请点击。 尚硅谷案例 utllib的基本使用 # 使用urllib来获取百度首页的源码 import urllib.request# (1)定义一个url 就是你要访问的地址 url http://www.baidu.com# (2)模拟浏览器先服务器发送请求 response响应 response urllib.request.urlopen(url)…

C语言:分支与循环

创造不易&#xff0c;友友们给个三连吧&#xff01;&#xff01; C语⾔是结构化的程序设计语⾔&#xff0c;这⾥的结构指的是顺序结构、选择结构、循环结构&#xff0c;C语⾔是能够实 现这三种结构的&#xff0c;其实我们如果仔细分析&#xff0c;我们⽇常所⻅的事情都可以拆分…

Vue 3 + Koa2 + MySQL 开发和上线部署个人网站

Vue 3 Koa2 MySQL 开发和上线部署个人网站 记录个人的一个操作步骤, 顺序不分先后, 嫌啰嗦请出门右转! 环境说明: 服务器: 阿里云轻量应用服务器 服务器系统: CentOS8.2 本地环境: macOS 12.7.2 Node: 20.10.0 MySQL: 8.0.26 Vue: 3.3.11 Koa: 2.7.0 pm2: 5.3.1 Nginx: 1.1…

Yii特性

Yii几乎拥有了当今Web 2.0应用发展的全部特性。下面是这些特性的一个简短的清单。 模型-视图-控制器&#xff08;MVC&#xff09;设计模式&#xff1a;Yii在WEB编程中采用这一成熟的技术从而可以更好的将逻辑层和表现层分开。数据库访问对象&#xff08;DAO&#xff09;和Acti…

Spark安装(Yarn模式)

一、解压 链接&#xff1a;https://pan.baidu.com/s/1O8u1SEuLOQv2Yietea_Uxg 提取码&#xff1a;mb4h tar -zxvf /opt/software/spark-3.0.3-bin-hadoop3.2.tgz -C /opt/module/spark-yarn mv spark-3.0.3-bin-hadoop3.2/ spark-yarn 二、配置环境变量 vim /etc/profile…