深度学习注意力机制(MHA)的训练(Eigen)

news/2024/7/19 11:27:29 标签: 深度学习, 机器学习, transformer, c++

Multi-head Attention 在 Transformer模型中的位置

简介

本文使用Eigen3在Transformer模型中实现多头注意力的前向传播和反向传播。具体来说,这个eigenMHA (eigenDNN)【源码:https://github.com/jundaf2/eigenMHA】所对应了大致如下的cuDNN的api的功能:

  • cudnnCreateAttnDescriptor()
  • cudnnSetAttnDescriptor()
  • cudnnGetAttnDescriptor()
  • cudnnDestroyAttnDescriptor()
  • cudnnGetMultiHeadAttnBuffers()
  • cudnnGetMultiHeadAttnWeights()
  • cudnnMultiHeadAttnForward()
  • cudnnMultiHeadAttnBackwardData()
  • cudnnMultiHeadAttnBackwardWeights()

Multi-head Attention 的构成
简单来说,MHA作为Tranformer模型中的一个模块,在训练中既要在需要将embedding,通过Q K V的线性层、S=Q*K^T(GEMM)、P=Softmax(Mask(S))、P=Dropout(P)、O=P*V(GEMM)、O的线性层 前向传播到下一层(可能是Layernorm),然后再在反向传播中,将输出O的梯度,通过O的线性层O=P*K(GEMM)P=Dropout(P)P=Softmax(S)S=Q*K^T(GEMM)Q K V的线性层反向传播回输入端(embedding的梯度)。

MHA训练过程涉及到的变量

MHA训练前向

  1. 输入QKV线性层的embeddings (前向起始点)

Q i n K i n V i n \mathbf{Q}_{in} \quad \mathbf{K}_{in} \quad \mathbf{V}_{in} QinKinVin

  1. 线性层权重和偏置

W Q b Q \mathbf{W}_{Q} \quad \mathbf{b}_{Q} WQbQ

W K b K \mathbf{W}_{K} \quad \mathbf{b}_{K} WKbK

W V b V \mathbf{W}_{V} \quad \mathbf{b}_{V} WVbV

W O b O \mathbf{W}_{O} \quad \mathbf{b}_{O} WObO

  1. 计算中间变量
  2. O的线性层输出值 和 目标值

O o u t O t a r g e t \mathbf{O}_{out}\quad\mathbf{O}_{target} OoutOtarget

MHA前向传播公式如下:

Q = Q i n ∗ W Q + b Q \mathbf{Q} = \mathbf{Q}_{in}*\mathbf{W}_{Q}+\mathbf{b}_{Q} Q=QinWQ+bQ

K = K i n ∗ W K + b K \mathbf{K} = \mathbf{K}_{in}*\mathbf{W}_{K}+\mathbf{b}_{K} K=KinWK+bK

V = V i n ∗ W V + b V \mathbf{V} = \mathbf{V}_{in}*\mathbf{W}_{V}+\mathbf{b}_{V} V=VinWV+bV

S = Q ∗ K T \mathbf{S} = \mathbf{Q}*\mathbf{K}^T S=QKT

KaTeX parse error: Undefined control sequence: \bfrac at position 42: …ask(\mathbf{S}*\̲b̲f̲r̲a̲c̲{1}{\sqrt{d}}))…

P = D r o p o u t F W D ( P ) \mathbf{P} = DropoutFWD(\mathbf{P}) P=DropoutFWD(P)

O = P ∗ V \mathbf{O}=\mathbf{P}*\mathbf{V} O=PV

O o u t = O ∗ W O + b O \mathbf{O}_{out} = \mathbf{O}*\mathbf{W}_{O}+\mathbf{b}_{O} Oout=OWO+bO

MSE Loss

在这个训练的计算结构中,反向传播的起始点是损失函数,因为我们仅仅关注于MHA本身,因此将MHA的输出 O o u t \mathbf{O}_{out} Oout和预设的目标 O t a r g e t \mathbf{O}_{target} Otarget输入MSE函数取得误差 l o s s loss loss和反向传播的梯度 g r a d _ O o u t \mathbf{grad\_O}_{out} grad_Oout
l o s s = M S E L o s s ( O o u t , O t a r g e t ) loss = MSELoss(\mathbf{O}_{out},\mathbf{O}_{target}) loss=MSELoss(Oout,Otarget)

MHA训练反向

  1. MHA输出(O的线性层输出)的梯度 (来自于 LayerNorm,反向起始点)

g r a d _ O o u t \mathbf{grad\_O}_{out} grad_Oout

  1. 中间变量的梯度
  2. 输入的梯度

g r a d _ Q i n g r a d _ K i n g r a d _ V i n \mathbf{grad\_Q}_{in} \quad \mathbf{grad\_K}_{in} \quad \mathbf{grad\_V}_{in} grad_Qingrad_Kingrad_Vin

  1. 权重和偏置的梯度

g r a d _ W Q g r a d _ b Q \mathbf{grad\_W}_{Q} \quad \mathbf{grad\_b}_{Q} grad_WQgrad_bQ

g r a d _ W K g r a d _ b K \mathbf{grad\_W}_{K} \quad \mathbf{grad\_b}_{K} grad_WKgrad_bK

g r a d _ W V g r a d _ b V \mathbf{grad\_W}_{V} \quad \mathbf{grad\_b}_{V} grad_WVgrad_bV

g r a d _ W O g r a d _ b O \mathbf{grad\_W}_{O} \quad \mathbf{grad\_b}_{O} grad_WOgrad_bO

MHA反向传播公式如下:

g r a d _ O = g r a d _ O o u t ∗ W O \mathbf{grad\_O} = \mathbf{grad\_O}_{out}*\mathbf{W}_{O} grad_O=grad_OoutWO

g r a d _ W O = g r a d _ O o u t T ∗ O \mathbf{grad\_W}_{O} = \mathbf{grad\_O}_{out}^T*\mathbf{O} grad_WO=grad_OoutTO

g r a d _ b O = c o l s u m ( g r a d _ O o u t ) \mathbf{grad\_b}_{O} = colsum(\mathbf{grad\_O}_{out}) grad_bO=colsum(grad_Oout)

g r a d _ P = g r a d _ O ∗ V T \mathbf{grad\_P} = \mathbf{grad\_O}*\mathbf{V}^T grad_P=grad_OVT

g r a d _ V = P T ∗ g r a d _ O \mathbf{grad\_V} = \mathbf{P}^T*\mathbf{grad\_O} grad_V=PTgrad_O

g r a d _ P = D r o p o u t B W D ( g r a d _ P ) \mathbf{grad\_P} = DropoutBWD(\mathbf{grad\_P}) grad_P=DropoutBWD(grad_P)

g r a d _ S = S o f t m a x B W D ( P , g r a d _ P ) ∗ 1 d \mathbf{grad\_S} = SoftmaxBWD(\mathbf{P},\mathbf{grad\_P})*\frac{1}{\sqrt{d}} grad_S=SoftmaxBWD(P,grad_P)d 1

g r a d _ Q = g r a d _ S ∗ K \mathbf{grad\_Q} = \mathbf{grad\_S}*\mathbf{K} grad_Q=grad_SK

g r a d _ K = g r a d _ S T ∗ Q \mathbf{grad\_K} = \mathbf{grad\_S}^T*\mathbf{Q} grad_K=grad_STQ

g r a d _ Q i n = g r a d _ Q ∗ W Q \mathbf{grad\_Q}_{in} = \mathbf{grad\_Q}*\mathbf{W}_{Q} grad_Qin=grad_QWQ

g r a d _ W Q = g r a d _ Q T ∗ Q i n \mathbf{grad\_W}_{Q} = \mathbf{grad\_Q}^T*\mathbf{Q}_{in} grad_WQ=grad_QTQin

g r a d _ b Q = c o l s u m ( g r a d _ Q ) \mathbf{grad\_b}_{Q} = colsum(\mathbf{grad\_Q}) grad_bQ=colsum(grad_Q)

g r a d _ K i n = g r a d _ K ∗ W K \mathbf{grad\_K}_{in} = \mathbf{grad\_K}*\mathbf{W}_{K} grad_Kin=grad_KWK

g r a d _ W K = g r a d _ K T ∗ K i n \mathbf{grad\_W}_{K} = \mathbf{grad\_K}^T*\mathbf{K}_{in} grad_WK=grad_KTKin

g r a d _ b K = c o l s u m ( g r a d _ K ) \mathbf{grad\_b}_{K} = colsum(\mathbf{grad\_K}) grad_bK=colsum(grad_K)

g r a d _ V i n = g r a d _ V ∗ W V \mathbf{grad\_V}_{in} = \mathbf{grad\_V}*\mathbf{W}_{V} grad_Vin=grad_VWV

g r a d _ W V = g r a d _ V T ∗ V i n \mathbf{grad\_W}_{V} = \mathbf{grad\_V}^T*\mathbf{V}_{in} grad_WV=grad_VTVin

g r a d _ b V = c o l s u m ( g r a d _ V ) \mathbf{grad\_b}_{V} = colsum(\mathbf{grad\_V}) grad_bV=colsum(grad_V)

MHA训练库的组成部分

MSE损失函数

损失函数作为深度学习系统的起源,产生了损失量和回传梯度,是深度学习系统的基本组成部分。
请添加图片描述

eidnnStatus_t eidnnMSELoss(
    eidnnHandle_t handle,
    const Tensor<float, 3> &output, 
    const Tensor<float, 3> &target,
    Tensor<float, 0> &loss,
    Tensor<float, 3> &d_loss);

线性层

cuDNN 没有给线性层操作提供了专门的API

在eigenDNN, 我们有

eidnnStatus_t eidnnLinearForward(eidnnHandle_t handle,
                    const Tensor<float, 3>& x, // data
                    const Tensor<float, 2>& w, // weight
                    const Tensor<float, 1>& bias, // bias
                    Tensor<float, 3>& y);

eidnnStatus_t eidnnLinearBackward(eidnnHandle_t handle,
                     const Tensor<float, 3>& dy,
                     const Tensor<float, 3>& x,
                     const Tensor<float, 2>& w,
                     Tensor<float, 3>& dx, // gradient of input data
                     Tensor<float, 2>& dw, // accumulated gradient of weight
                     Tensor<float, 1>& dbias // accumulated gradient of bias
                     );

批量矩阵乘法

C = β ∗ C + α ∗ O p c ( M a t M u l ( O p a ( A ) , O p b ( B ) ) ) C = \beta * C + \alpha*Op_c(MatMul(Op_a(A),Op_b(B))) C=βC+αOpc(MatMul(Opa(A),Opb(B)))

, 其中 O p m ( M ) Op_m(M) Opm(M) 是对 M M M 是否采取转置操作.

cuDNN 没有给批量矩阵乘法操作提供了专门的API

在eigenDNN, 我们有

eidnnStatus_t eidnnStridedBatchedGemmForward(
    eidnnHandle_t handle,
    float alpha,
    float beta,
    bool trans_A, // Op_a
    bool trans_B, // Op_b
    bool trans_C, // Op_c
    const Tensor<float, 4> &A, 
    const Tensor<float, 4> &B, 
    Tensor<float, 4> &C);
eidnnStatus_t eidnnStridedBatchedGemmBackward(
    eidnnHandle_t handle,
    float alpha,
    float beta,
    bool trans_A, // Op_a
    bool trans_B, // Op_b
    bool trans_C, // Op_c
    const Tensor<float, 4> &A, // A
    const Tensor<float, 4> &B, // B
    const Tensor<float, 4> &d_C, // gradient of C
    Tensor<float, 4> &d_A, // gradient of A
    Tensor<float, 4> &d_B // gradient of B
    );

Softmax

cuDNN 给softmax 操作提供了如下 API.

  • cudnnSoftmaxForward()
  • cudnnSoftmaxBackward()

在eigenDNN, 我们有

eidnnStatus_t eidnnSoftmaxForward(eidnnHandle_t handle,
                    eidnnSoftmaxAlgorithm_t algo,
                    eidnnSoftmaxMode_t mode,
                    const Tensor<float, 4>& x,
                    Tensor<float, 4>& y);
eidnnStatus_t eidnnSoftmaxBackward(eidnnHandle_t handle,
                     eidnnSoftmaxAlgorithm_t algo,
                     eidnnSoftmaxMode_t mode,
                     const Tensor<float, 4>& y,
                     const Tensor<float, 4>& dy,
                     Tensor<float, 4>& dx);

Dropout

cuDNN 给dropout 操作提供了如下 API.

  • cudnnCreateDropoutDescriptor()
  • cudnnDestroyDropoutDescriptor()
  • cudnnDropoutGetStatesSize()
  • cudnnDropoutGetReserveSpaceSize()
  • cudnnDropoutForward()
  • cudnnGetDropoutDescriptor()
  • cudnnRestoreDropoutDescriptor()
  • cudnnSetDropoutDescriptor()
  • cudnnDropoutBackward()

在eigenDNN, 我们有

// dropout rate, 
// pointer to memory space of states (allocated by forward pass), 
// size of memory space in bytes (calculated by forward pass), 
// random seed
using eidnnDropoutDescriptor_t = std::tuple<float, void*, size_t, unsigned long long>; 
eidnnStatus_t eidnnDropoutForward(
    eidnnHandle_t                       handle,
    eidnnDropoutDescriptor_t      &dropoutDesc,
    const Tensor<float, 4>         &x, // input data
    Tensor<float, 4>               &y // input data after dropout
    );
eidnnStatus_t eidnnDropoutBackward(
    eidnnHandle_t                   handle,
    const eidnnDropoutDescriptor_t  dropoutDesc,
    const Tensor<float, 4>       &dy, // gradient of dropout output data
    Tensor<float, 4>             &dx // gradient of dropout input data
    );

Please star this project [https://github.com/jundaf2/eigenMHA] if you find it useful~


http://www.niftyadmin.cn/n/178391.html

相关文章

Windows Server 2016 中文版、英文版下载 (updated Mar 2023)

Windows Server 2016 Version 1607&#xff0c;2023 年 3 月更新 请访问原文链接&#xff1a;https://sysin.org/blog/windows-server-2016/&#xff0c;查看最新版。原创作品&#xff0c;转载请保留出处。 作者主页&#xff1a;sysin.org 本站将不定期发布官方原版风格月度更…

记录--你还在傻傻的npm run serve吗?快来尝尝这个!

这里给大家分享我在网上总结出来的一些知识&#xff0c;希望对大家有所帮助 背景 大家在日常开发中应该经常会有需要切换不同环境地址的情况。当一个项目代码切换环境地址时&#xff0c;vue-cli没有能够感知文件的变化&#xff0c;所以代理的还是旧的地址&#xff0c;所以通常我…

springboot就业信息管理系统

041-springboot就业信息管理系统演示录像2022开发语言&#xff1a;Java 框架&#xff1a;springboot JDK版本&#xff1a;JDK1.8 服务器&#xff1a;tomcat7 数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09; 数据库工具&#xff1a;Navicat11 开发软件&…

GTC 2023的算力协奏曲,NVIDIA与宁畅“共舞”AI时代

开年以来&#xff0c;伴随着大模型的关键性突破&#xff0c;大众对人工智能的关注达到了全新高点。以深度学习为核心的AI&#xff0c;正在邀世界各行各业一同起舞。AI计算硬件的进化&#xff0c;也必须与整个社会的算力需求&#xff0c;同频共振。大模型让算力需求指数级增长&a…

字节面试代码题二解析

一.题目描述 给你一个高度数组Hi&#xff0c;每提升Hi一个单位的代价是Wi&#xff0c;求让相邻两个高度不同的最小代价。比如Hi [2&#xff0c;3&#xff0c;4&#xff0c;4]&#xff1b;Wi [1&#xff0c; 2&#xff0c;3&#xff0c;4]&#xff0c;就是让第三个提升1。 二.…

【数据库原理与应用 - 第二章】关系数据库基础(更新中)

目录 一、关系的概念 1、域 2、笛卡尔积 3、关系 4、相关术语 &#xff08;1&#xff09;候选码与主码 &#xff08;2&#xff09;主属性与非主属性 二、关系数据模型 1、关系模型的数据结构 2、关系操作与完整性约束 &#xff08;1&#xff09;关系操作语言 &…

YOLOv7-tiny网络结构图及yaml文件 详细备注

YOLOv7-tiny 整体网络结构图yolov7-tiny.yaml组件模块MXCBLSPPCSP结构图yaml构建代码MCB结构图yaml文件表示common.py代码参考整体网络结构图 yolov7-tiny.yaml # parameters nc: 80 # number of classes depth_multiple: 1.0 # model depth multiple width_multiple: 1.0 …

GPS定位

b’KaTeX parse error: Undefined control sequence: \n at position 68: …230323,,,A*7D\r\̲n̲ bGPVTG,T,M,0.039,N,0.073,K,A2D\r\n’ b’KaTeX parse error: Undefined control sequence: \n at position 75: …M,-2.4,M,,*4C\r\̲n̲ bGPGSA,A,3,26,32,31,28,16,10,2.87…