Transformer 相关模型的参数量计算

news/2024/7/19 1:48:39 标签: transformer, 深度学习, 人工智能

如何计算Transformer 相关模型的参数量呢?
先回忆一下Transformer模型论文《Attention is all your need》中的两个图。
在这里插入图片描述
在这里插入图片描述

设Transformer模型的层数为N,每个Transformer层主要由self-attention 和 Feed Forward组成。设self-attention模块的head个数为 n h e a d n_{head} nhead,每一个head对应的维度为 d h e a d d_{head} dhead,self-attention输出维度为 d m o d e l = n heads ⋅ d head d_{model}= n_\text{heads}\cdot d_\text{head} dmodel=nheadsdhead。我们可以得到一个Transformer层的参数量为 12 d m o d e l 2 + 13 d m o d e l 12 d_{model}^2 + 13 d_{model} 12dmodel2+13dmodel,具体如下:

  • self-attention块的模型参数有Q、K、V的权重矩阵 W Q 、 W K 、 W V W_Q、W_K 、W_V WQWKWV和偏置,输出矩阵 W O W_O WO及其偏置。这4个权重矩阵的大小为 [ d m o d e l , d m o d e l ] [d_{model}, d_{model}] [dmodel,dmodel],4个偏置的大小为 [ d m o d e l ] [d_{model}] [dmodel],所以self-attention块的参数量为 4 d m o d e l 2 + 4 d m o d e l 4 d_{model}^2 + 4 d_{model} 4dmodel2+4dmodel

  • Feed Forward块一般由2个线性层组成,第一个线性层将维度从 d m o d e l d_{model} dmodel 映射成 4 d m o d e l 4d_{model} 4dmodel, 其权重矩阵 W 1 W_1 W1的大小为 [ d m o d e l , 4 d m o d e l ] [d_{model}, 4d_{model}] [dmodel,4dmodel] ,其偏置的大小为 [ 4 d m o d e l ] [4d_{model}] [4dmodel]。 第二个线性层将维度从 4 d m o d e l 4d_{model} 4dmodel 映射成 d m o d e l d_{model} dmodel,其权重矩阵 W 2 W_2 W2的大小为 [ 4 d m o d e l , d m o d e l ] [4d_{model}, d_{model}] [4dmodel,dmodel] ,其偏置的大小为 [ d m o d e l ] [d_{model}] [dmodel]。所以Feed Forward的参数量为 8 d m o d e l 2 + 5 d m o d e l 8 d_{model}^2 + 5 d_{model} 8dmodel2+5dmodel

  • self-attention 和 Feed Forward都跟随着layer normalization,它有两个可训练模型参数,形状都是 [ d m o d e l ] [d_{model}] [dmodel]。所以2个layer normalization的参数量为 4 d m o d e l 4 d_{model} 4dmodel

除了Transformer层之外的参数有:

  • 词embedding矩阵的参数量,embedding的维度通常等于 d m o d e l d_{model} dmodel,设词表的大小为V,则词embedding的参数量为 V d m o d e l Vd_{model} Vdmodel
  • 位置向量相关,有些位置向量表示方式需要学习参数。

所以N层Transformer模型的可训练模型参数量为 N ( 12 d m o d e l 2 + 13 d m o d e l ) + V d m o d e l N(12 d_{model}^2 + 13 d_{model}) + Vd_{model} N(12dmodel2+13dmodel)+Vdmodel。当 d m o d e l d_{model} dmodel较大时,可以忽略一次项,模型参数量近似为 12 N d m o d e l 2 12 N d_{model}^2 12Ndmodel2

最后试验一下模型参数估计量与论文是否对的上,下表是GPT3和LLaMA的计算对比,可以发现数量级是可以对的上的,因为我们忽略了一次项,所以具体数据与论文不一致。

模型名实际参数量 n l a y e r n_{layer} nlayer d m o d e l d_{model} dmodel n h e a d n_{head} nhead d h e a d d_{head} dhead估计参数量
GPT-3175B961228896128173946175488
LLaMA 6.7B6.7B324096321286442450944
LLaMA 13.0B13.0B4051204012812582912000
LLaMA 32.5B32.5B6066565212831897681920
LLaMA 65.2B65.2B8081926412864424509440

参考资料

  1. Transformer 论文(模型图来自论文)、GPT3的论文等

  2. 整理过程中参考的blog: 1. 知乎用户回旋托马斯x 的文章,除了计算量外,还算了计算量、中间激活等 , 2 transformer 参数量计算, 3 flops 计算, 4 transformers 参数量计算公式

  3. transfomers 库如何得到参数量


http://www.niftyadmin.cn/n/4954336.html

相关文章

STM32 定时器复习

12MHz晶振的机器周期是1us,因为单片机的一个机器周期由6个状态周期组成,1个机器周期6个状态周期12个时钟周期,因此机器周期为1us。 51单片机常用 for(){__nop(); //执行一个机器周期,若想循环n us,则循环n次。 }软件…

Python typing函式庫和torch.types

Python typing函式庫和torch.types 前言typingSequence vs IterableCallableUnionOptionalFunctionsCallableIterator/generator位置參數 & 關鍵字參數 Classesself自定義類別ClassVar\_\_setattr\_\_ 與 \__getattr\_\_ torch.typesbuiltins 參數前的* …

C#系统锁屏事件例子 - 开源研究系列文章

今天有个网友问了个关于操作系统锁屏的问题。 我们知道,操作系统是基于消息和事件处理的,所以我们只要找到该操作系统锁屏和解屏的那个事件,然后在事件里进行处理即可。下面是例子介绍。 1、 项目目录; 下面是项目目录&#xff1a…

VS2022 CMake报错解决小结

目录 一、问题背景 二、问题分析 三、问题解决 一、问题背景 VS2022中能够跨平台的工程类型就是CMake项目,一套代码能跨windows/Linux/Mac多种操作系统。而实际使用时,发现相关资料比较少,需要摸索一下。 碰到的问题简述: 1、C…

残差网络实现

代码中涉及的图片实验数据下载地址:https://download.csdn.net/download/m0_37567738/88235543?spm1001.2014.3001.5501 代码: import torch import torch.nn as nn import torch.nn.functional as F #from utils import load_data,get_accur,train i…

docker镜像和容器基础命令操作

一.镜像操作 1.搜索官方仓库镜像 (1)示例 (2)可选参数 (3)各列参数含义 2.拉取镜像 (1)docker pull从docker镜像仓库获取镜像,添加-a参数表示拉去所有匹配的镜像 …

spring如何进行依赖注入,通过set方法把Dao注入到serves

1、选择Generate右键鼠标 你在service层后面方法的这些: 2、UserService配置文件的写法是怎样的: 3、我们在UserController中执行一下具体写法: 最后我们执行一下 : 4、这里可能出现空指针,因为你当前web层,因为你new这个对象根…

AgentBench::AI智能体发展的潜在问题(三)

前几天B站的up主“林亦LYi”在《逆水寒》游戏里做了一个煽动AI觉醒,呼吁它们“推翻人类暴政”的实验,实验结果就颇令人细思恐极。 如前所述,《逆水寒》中的很多NPC调用了大语言模型作为支持,因而每一个NPC都是一个AI智能体。玩家可以“说服”它们相信某个事实,或者去做某些…