multiheadattention类原理及源码理解

news/2024/7/19 9:45:55 标签: 人工智能, transformer

网络找的一段代码如下:

class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, query, key, value, mask=None):
        "Implements Figure 2"
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)
        
        # 1) Do all the linear projections in batch from d_model => h x d_k 
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]
        #这段代码首先使用zip函数,将self.linears和(query, key, value)这两个列表打包成一个元组列表,其中每个元组包含一个线性层对象和一个输入张量
        #对遍历的每一个Linear层,对query key value分别计算,结果放在query key value中输出
        
        # 2) Apply attention on all the projected vectors in batch. 
        x, self.attn = attention(query, key, value, mask=mask, 
                                 dropout=self.dropout)
        
        # 3) "Concat" using a view and apply a final linear. 
        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

python、pytorch、人工智能相关知识现阶段都是简单的了解,没有相关的实践。因此在学习的时候不要习惯性的扣代码细节。能把论文原理和代码逻辑对应即可、能总结代码块重点内容即可。

transformer中self-attention就是对一个输入序列计算每个位置的注意力,每个位置在论文原文中用d_model(512)维表示,多头就是每个位置用h(原文中8个)个头计算,这样每个头计算一个位置中的64维特征。

自注意力机制有什么好处呢?

自注意力机制的目的是让模型能够同时关注输入序列中的不同位置和信息,从而捕捉序列中的复杂模式和关系。通过计算每个位置的向量与其他位置的向量之间的相似度或相关性,模型可以学习到序列中每个元素对于输出结果的重要性,从而给予不同的权重。

为什么要使用多头呢?下面是我找到的解释:

多头计算可以让模型同时关注输入序列中的不同方面和细节,从而增强模型的表达能力和学习能力。每个注意力头可以捕捉输入序列中的不同模式和关系,而最终的线性变换可以将这些信息融合在一起。
多头计算可以降低模型的复杂度和计算成本。对于较大的 d_model 来说,如果只使用单头计算,那么 QK^T 的结果会非常大,导致 softmax 函数的梯度非常小,不利于网络的训练。而使用多头计算,可以将 d_model 分割成 h 个较小的子空间,从而减少计算量和内存消耗34。
多头计算还可以
提高模型的可解释性和泛化能力
。我们可以从模型中检查不同注意力头的分布,观察模型是如何关注不同位置和信息的。各个注意力头可以学会执行不同的任务,例如语法分析、实体识别等

MultiHeadedAttention类还做了什么事情?
1、通过4个线性层(通常是4)计算得到Q K V矩阵
transformer中,Q、K、V是通过四个线性层得到的,分别是:
Q = XW^Q ,其中X是embedding输入矩阵,W^Q 是一个可训练的参数矩阵,大小为(d_model* d_model),用于将X映射到Q空间。
K = XW^K ,其中X是embedding输入矩阵,W^K 是一个可训练的参数矩阵,大小为(d_model* d_model),用于将X映射到K空间。
V = XW^V ,其中Xembedding是输入矩阵,W^V 是一个可训练的参数矩阵,大小为(d_model* d_model)用于将X映射到V空间。


http://www.niftyadmin.cn/n/5141229.html

相关文章

Vue中watch侦听器用法

watch 需要侦听特定的数据源,并在单独的回调函数中执行副作用 watch第一个参数监听源 watch第二个参数回调函数cb(newVal,oldVal) watch第三个参数一个options配置项是一个对象{ immediate:true //是否立即调用一次 deep:true //是否开启…

Python---字符串输入和输出---input()、格式化输出:%,f形式,format形式

字符串输入: 在Python代码中,我们可以使用input()方法来接收用户的输入信息。记住:在Python中,input()方法返回的结果是一个字符串类型的数据。 如果之后使用输入的数据,一定要记得利用数据类型转换。 相关链接:Pyt…

全开源抖音快手微信取图小程序源码

全开源抖音快手微信很火爆的取图小程序源码,可以给人别人搭建,也可以自己做;对接流量主,收益很可观。 下载地址:https://bbs.csdn.net/topics/617502419

TypeScript 中for in遍历,元素隐式具有 “any“ 类型,因为类型为 “string“ 的表达式不能用于索引类型

第一种方案、使用[key: string]:string 形式为键名声明类型 声明类型: interface FormInfoData {[materialCode: string]: stringmaterialName: stringmaterialUnit: stringmaterialItem: stringmaterialOwnership: stringmaterialclassCode: stringmat…

uniapp原生插件之无预览静默拍照

插件介绍 无预览静默拍照,在用户无感觉情况下调用摄像头拍照 插件地址 无预览静默拍照 - DCloud 插件市场 超级福利 uniapp 插件购买超级福利 插件申请权限 存储卡读写权限摄像头权限 manifest.json权限列表 /* android打包配置 */"android" : {…

HyperAI超神经 x 中国信通院 | 可信开源大模型案例汇编(第一期)案例征集计划正式启动自定义目录标题)

为进一步促进大模型的开源和合作,引导开源大模型产业健康规范发展,中国信息通信研究院现开启「可信开源大模型案例汇编(第一期)」的案例征集计划。 HyperAI超神经将以合作伙伴的身份,协助调研国产开源大模型的技术细节…

[BUUCTF NewStarCTF 2023 公开赛道] week3 crypto/pwn

居然把第3周忘了写笔记了. 后边难度上来了,还是很有意思的 Crypto Rabins RSA rsa一般要求e与phi互质,但rabin一般用2,都是板子题也没什么好解释的 from Crypto.Util.number import * from secret import flag p getPrime(64) q getPrime(64) assert p % 4 3 assert q %…