Transformer中的 Add Norm

news/2024/7/19 12:31:29 标签: transformer, 深度学习, 人工智能

Transformer中的 Add & Norm

flyfish

Add

同一个意思 Residual connections,Skip Connections
在这里插入图片描述

Norm

在这里插入图片描述
包括Post layer normalization和Pre layer normalization
Post layer normalization:Transformer 论文中使用的方式,将 Layer normalization 放在 Skip Connections 之间
在这里插入图片描述

class PoswiseFeedForwardNet(nn.Module):
    def __init__(self, d_ff=2048):
        super(PoswiseFeedForwardNet, self).__init__()
        # 定义一维卷积层 1,用于将输入映射到更高维度
        self.conv1 = nn.Conv1d(in_channels=d_embedding, out_channels=d_ff, kernel_size=1)
        # 定义一维卷积层 2,用于将输入映射回原始维度
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_embedding, kernel_size=1)
        # 定义层归一化
        self.layer_norm = nn.LayerNorm(d_embedding)
    def forward(self, inputs): 
        #------------------------- 维度信息 -------------------------------- 
        # inputs [batch_size, len_q, embedding_dim]
        #----------------------------------------------------------------                       
        residual = inputs  # 保留残差连接 
        # 在卷积层 1 后使用 ReLU 激活函数 
        output = nn.ReLU()(self.conv1(inputs.transpose(1, 2))) 
        #------------------------- 维度信息 -------------------------------- 
        # output [batch_size, d_ff, len_q]
        #----------------------------------------------------------------
        # 使用卷积层 2 进行降维 
        output = self.conv2(output).transpose(1, 2) 
        #------------------------- 维度信息 -------------------------------- 
        # output [batch_size, len_q, embedding_dim]
        #----------------------------------------------------------------
        # 与输入进行残差链接,并进行层归一化
        output = self.layer_norm(output + residual) 
        #------------------------- 维度信息 -------------------------------- 
        # output [batch_size, len_q, embedding_dim]
        #----------------------------------------------------------------
        return output # 返回加入残差连接后层归一化的结果

Pre layer normalization:将 Layer Normalization 放置于 Skip Connections 的范围内。
(常用方式)

class FeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear_1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.linear_2 = nn.Linear(config.intermediate_size, config.hidden_size)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.gelu(x)
        x = self.linear_2(x)
        x = self.dropout(x)
        return x




class  EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size)
        self.attention = MultiHeadAttention(config)
        self.feed_forward = FeedForward(config)

    def forward(self, x, mask=None):
        # Apply layer normalization and then copy input into query, key, value
        hidden_state = self.layer_norm_1(x)
        # Apply attention with a skip connection
        x = x + self.attention(hidden_state, hidden_state, hidden_state, mask=mask)
        # Apply feed-forward layer with a skip connection
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x

初始化
定义两个LayerNorm层用于归一化输入,
一个MultiHeadAttention模块负责自注意力机制,以及一个FeedForward模块。

在前向传播过程中:

1 对输入先做LayerNorm得到标准化后的hidden_state。
2 使用自注意力机制对hidden_state进行处理,并与原始输入相加,实现残差连接。
3 再次对上一步的结果进行LayerNorm,并通过FeedForward层进行处理。
4将FeedForward层输出与经过自注意力机制后的结果相加,再次使用残差连接。
5 返回最终处理后的编码器层输出x。


http://www.niftyadmin.cn/n/5414772.html

相关文章

Springboot启动时设置自定义的banner(横幅)在控制台中显示

1、创建文件 1.1 创建banner.txt文件 在项目的src/main/resources目录下创建一个名为banner.txt的文件。 1.2 简单banner.txt 示例: 我的项目名- Spring Boot Application Version: ${spring-boot.version} ${application.version} spring-boot.version是Spring…

【Java代码审计】JNDI+RMI绕过高版本JDK的限制

【Java代码审计】JNDIRMI绕过高版本JDK的限制 1.高版本JDK利用注入导致的问题2.绕过分析3.Tomcat8绕过4.工具绕过 1.高版本JDK利用注入导致的问题 JDK 6u132、7u122、8u113 开始 com.sun.jndi.rmi.object.trustURLCodebase 默认值为false,运行时需加入参数 -Dcom.s…

VBA中类的解读及应用第十讲:限制文本框的输入,使其只能输入数值(上)

《VBA中类的解读及应用》教程【10165646】是我推出的第五套教程,目前已经是第一版修订了。这套教程定位于最高级,是学完初级,中级后的教程。 类,是非常抽象的,更具研究的价值。随着我们学习、应用VBA的深入&#xff0…

【数据结构】二、线性表:4.循环链表的定义及其基本操作(循环单链表,循环双链表的初始化、判空、判断头结点、尾结点、插入、删除)

文章目录 4.循环链表4.1循环单链表4.1.1初始化4.1.2判断单链表是否为空4.1.3判断p结点是否为循环单链表的表尾结点 4.2循环双链表4.2.1初始化4.2.2判断循环链表是否为空4.2.3判断结点p是否为循环双链表的表尾结点4.2.4双链表的插入4.2.5双链表的删除 4.循环链表 4.1循环单链表…

使用k8s前配置环境

安装 kubectl:首先,确保你已经安装了 kubectl 命令行工具。你可以从 Kubernetes 官方的 GitHub 仓库中获取适合你操作系统的发行版,或者使用适合你操作系统的包管理器进行安装。 获取集群凭证:通常情况下,你需要从 Ku…

LeetCode 刷题 [C++] 第139题.单词拆分

题目描述 给你一个字符串 s 和一个字符串列表 wordDict 作为字典。如果可以利用字典中出现的一个或多个单词拼接出 s 则返回 true。 注意:不要求字典中出现的单词全部都使用,并且字典中的单词可以重复使用。 题目分析 背包问题特征: 是否…

golang 实现http请求的调用,访问并读取页面数据和内置的一些方法

下午就不能好好学习一下golang,业务一直找个不停,自己定的业务规则都能忘得一干二净,让你查半天,完全是浪费时间。 golang实现访问并读取页面数据 package mainimport ("fmt""net/http" )var urls []string{…

mysql和oracle数据库的区别与联系(值得收藏)

1、mysql和oracle都是关系型数据库。 mysql默认端口:3306 默认用户root oracle默认端口 1521 默认用户system mysql的安装配置和卸载简单,oracle比较麻烦,严重的可能要你重做系统。 oracle在命令行用命令登陆:sqlplus---然后录…