nlp中的transformer中的mask

news/2024/7/19 9:56:24 标签: 自然语言处理, transformer, 人工智能

由于在实现多头注意力时需要考虑到各种情况下的掩码,因此在这里需要先对这部分内容进行介绍。在Transformer中,主要有两个地方会用到掩码这一机制。第1个地方就是在上一篇文章用介绍到的Attention Mask,用于在训练过程中解码的时候掩盖掉当前时刻之后的信息;第2个地方便是对一个batch中不同长度的序列在Padding到相同长度后,对Padding部分的信息进行掩盖。下面分别就这两种情况进行介绍。

1.Attention Mask

实现:generate_square_subsequent_mask

 def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

2.Padding Mask

实现:

用法:

https://blog.csdn.net/vivi_cin/article/details/135390462

参考:

nn.TransformerEncoderLayer中的src_mask,src_key_padding_mask解析_src_mask和src_key_padding_mask-CSDN博客

(32 封私信 / 4 条消息) transformer中: self-attention部分是否需要进行mask? - 知乎 (zhihu.com) 几个很好的回答:

Q:transformer中attention_mask一定需要嘛?

A:Transformer结构包括编码器和解码器,在编码过程中目的就是为了让模型看到当前位置前后的信息,所以不需要attention mask。但是在解码过程中为了模拟在真实的inference场景中,当前位置看不到下一位置,且同时需要上一位置的信息,所以在训练的时候加了attention mask。

所以,如果你的任务在实际的inference中也符合这样的特点,那么你在训练的时候也是需要attention,相反则不需要。

参考:(32 封私信 / 4 条消息) transformer中attention_mask一定需要嘛? - 知乎 (zhihu.com)

还有一个写的很好的博主:

 nn.TransformerEncoderLayer中的src_mask,src_key_padding_mask解析_src_mask和src_key_padding_mask-CSDN博客

参考的github上关于pad mask 实现 :

https://github.com/HIT-SCIR/plm-nlp-code/blob/64564b643a09cb85163ccca1f8c41fc94f5fc9ec/chp4/utils.py

关键代码:

def length_to_mask(lengths):
    max_len = torch.max(lengths)
    mask = torch.arange(max_len, device=lengths.device).expand(lengths.shape[0], max_len) < lengths.unsqueeze(1)
    return mask

 model:

class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
                 dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
        super(Transformer, self).__init__()
        # 词嵌入层
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
        # 编码层:使用Transformer
        encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        # 输出层
        self.output = nn.Linear(hidden_dim, num_class)

    def forward(self, inputs, lengths):
        inputs = torch.transpose(inputs, 0, 1)
        hidden_states = self.embeddings(inputs)
        hidden_states = self.position_embedding(hidden_states)
        attention_mask = length_to_mask(lengths) == False
        hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
        logits = self.output(hidden_states)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

模型完整代码:

class Transformer(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class,
                 dim_feedforward=512, num_head=2, num_layers=2, dropout=0.1, max_len=512, activation: str = "relu"):
        super(Transformer, self).__init__()
        # 词嵌入层
        self.embedding_dim = embedding_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embedding = PositionalEncoding(embedding_dim, dropout, max_len)
        # 编码层:使用Transformer
        encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        # 输出层
        self.output = nn.Linear(hidden_dim, num_class)

    def forward(self, inputs, lengths):
        inputs = torch.transpose(inputs, 0, 1)
        hidden_states = self.embeddings(inputs)
        hidden_states = self.position_embedding(hidden_states)
        attention_mask = length_to_mask(lengths) == False
        hidden_states = self.transformer(hidden_states, src_key_padding_mask=attention_mask).transpose(0, 1)
        logits = self.output(hidden_states)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs


http://www.niftyadmin.cn/n/5313229.html

相关文章

【Verilog】行为级建模

系列文章 数值&#xff08;整数&#xff0c;实数&#xff0c;字符串&#xff09;与数据类型&#xff08;wire、reg、mem、parameter&#xff09; 运算符 数据流建模 系列文章定义过程语句initial过程语句always过程语句过程语句使用中的注意事项过程赋值语句连续赋值语句 条件…

【LLM】vLLM部署与int8量化

Acceleration & Quantization vLLM vLLM是一个开源的大型语言模型&#xff08;LLM&#xff09;推理和服务库&#xff0c;它通过一个名为PagedAttention的新型注意力算法来解决传统LLM在生产环境中部署时所遇到的高内存消耗和计算成本的挑战。PagedAttention算法能有效管理…

【leetcode 447. 回旋镖的数量】审慎思考与推倒重来

447. 回旋镖的数量 题目描述 给定平面上 **n **对 互不相同 的点 points &#xff0c;其中 points[i] [xi, yi] 。回旋镖 是由点 (i, j, k) 表示的元组 &#xff0c;其中 i 和 j 之间的距离和 i 和 k 之间的欧式距离相等&#xff08;需要考虑元组的顺序&#xff09;。 返回平…

使用numpy处理图片——基础操作

numpy是一款非常优秀的处理多维数组的Python基础包。在现实中&#xff0c;我们最经常接触的多维数组相关的场景就是图像处理。本系列将通过若干篇对图像处理相关的探讨&#xff0c;来介绍numpy的使用方法&#xff0c;以获得直观的体验。 本系列使用的照片使用的是RGBA色彩空间模…

什么是API网关代理?

带有API网关的代理服务显着增强了用户体验和性能。特别是对于那些使用需要频繁创建和轮换代理的工具的人来说&#xff0c;使用 API 可以节省大量时间并提高效率。 了解API API&#xff08;即应用程序编程接口&#xff09;充当服务提供商和用户之间的连接网关。通过 API 连接&a…

Android性能优化系列——内存优化

内存&#xff0c;是Android应用的生命线&#xff0c;一旦在内存上出现问题&#xff0c;轻者内存泄漏造成App卡顿&#xff0c;重者直接crash&#xff0c;因此一个应用保持健壮&#xff0c;要做好内存的使用和优化。网上有很多讲JAVA内存虚拟机的好文章&#xff0c;我就不赘述了。…

为什么制作文件二维码?文件做成二维码的优势

现在经常会遇到查看或者下载文件的情况&#xff0c;通过这种方式来完成文件的传递&#xff0c;那么为什么将文件做成二维码的方式来展示呢&#xff1f;二维码的优势有很多&#xff0c;比如能够让更多人同时获取内容才&#xff0c;方便更快的传播&#xff0c;而且没有有效期的限…

2024年跨境电商上半年营销日历最全整理

2024年伊始&#xff0c;跨境电商开启新一轮的营销竞技&#xff0c;那么首先需要客户需求&#xff0c;节假日与用户需求息息相关&#xff0c;那么接下来小编为大家整理2024上半年海外都有哪些节日和假期&#xff1f;跨境卖家如何见针对营销日历选品&#xff0c;助力卖家把握2024…