Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

6、WindowAttention类

6.1 构造函数

class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)
  1. dim:输入特征维度
  2. window_size:窗口大小
  3. num_heads:多头注意力头数
  4. head_dim:每头注意力的头数
  5. scale :缩放因子
  6. relative_position_bias_table:相对位置偏置表,它对每个头存储不同窗口位置之间的偏置,以模拟位置信息
  7. coords_h 、coords_w、coords:窗口内每个位置的坐标
  8. coords_flatten :将坐标展平,为计算相对位置做准备
  9. 第1个relative_coords:计算窗口内每个位置相对于其他位置的坐标差
  10. 第2个relative_coords:重排坐标差的维度以符合预期的格式
  11. relative_coords[:, :, 0]、relative_coords[:, :, 1]、relative_coords[:, :, 0]:调整坐标差,使其能够映射到相对位置偏置表中的索引
  12. relative_position_index :计算每对位置之间的相对位置索引
  13. register_buffer:将相对位置索引注册为模型的缓冲区,这样它就不会在训练过程中被更新
  14. qkv :创建一个线性层,用于生成QKV
  15. attn_drop、proj、proj_drop:初始化注意力dropout、输出投影层及其dropout
  16. trunc_normal_:使用截断正态分布初始化相对位置偏置表
  17. softmax :初始化softmax层,用于计算注意力权重

6.2 前向传播

    def forward(self, x, mask=None):
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
  1. B_, N, C = x.shape原始输入: torch.Size([256, 49, 96]),B_, N, C即原始输入的维度
  2. qkv = self.qkv(x).reshape...qkv: torch.Size([3, 256, 3, 49, 32]),被重塑的一个五维张量,分别代表qkv三个维度、256个窗口、3个注意力头数但是不会一直是3越往后会越多、49是一个窗口有7*7=49元素、每个头的特征维度。在之前的Transformer以及Vision Transformer中,都是用x接上各自的全连接后分别生成QKV,这这里直接一起生成了。
  3. q: torch.Size([256, 3, 49, 32]),k: torch.Size([256, 3, 49, 32]),v: torch.Size([256, 3, 49, 32]),从qkv中分解出q、k、v,而且已经包含了多头注意力机制
  4. attn: torch.Size([256, 3, 49, 49]),attn是q和k的点积
  5. relative_position_bias: torch.Size([49, 49, 3]),从相对位置偏置表中索引出每对位置之间的偏置,并重塑以匹配注意力分数的形状
  6. relative_position_bias: torch.Size([3, 49, 49]),重新排列,位置编码在Transformer中一直当成偏置加进去的,而这个位置编码是对一个窗口的,所以每一个窗口的都对应了相同的位置编码
  7. attn: torch.Size([256, 3, 49, 49]),将位置编码加到注意力分数上,到这里就算完了全部的注意力机制了
  8. attn: torch.Size([256, 3, 49, 49]),掩码加到注意力分数上,使用softmax函数归一化注意力分数,得到注意力权重,应用注意力dropout
  9. x: torch.Size([256, 49, 96]),使用注意力权重对v向量进行重构,然后对结果进行转置和重塑
  10. x: torch.Size([256, 49, 96]),将加权的注意力输出通过一个线性投影层,应用输出dropout,这就是最后WindowAttention的输出,一共256个窗口,每个窗口有49个特征,每个特征对应96维的向量

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)


http://www.niftyadmin.cn/n/5369275.html

相关文章

Spring Boot 笔记 002 整合mybatis做数据库读取

概念 MyBatis 是一款优秀的持久层框架,它支持自定义 SQL、存储过程以及高级映射。MyBatis 免除了几乎所有的 JDBC 代码以及设置参数和获取结果集的工作。MyBatis 可以通过简单的 XML 或注解来配置和映射原始类型、接口和 Java POJO(Plain Old Java Objec…

SpringCloud-Eureka原理分析

Eureka是Netflix开源的一款用于实现服务注册与发现的工具。在微服务架构中,服务的动态注册和发现是必不可少的组成部分,而Eureka正是为了解决这一问题而诞生的。 一、为何需要Eureka 在微服务架构中,服务之间的协同合作和高效通信是至关重要…

计算机网络相关题目及答案(第五章)

第五章 复习题: R2. 基于逻辑上集中控制的控制平面意味着什么?在这种有情况下,数据平面和控制平面是在相同的设备或在分离的设备中实现的吗?请解释。 答:基于逻辑上集中控制的控制平面意味着控制平面的具体实现不在每个路由器中, 而是在某个集中的地方(服务器). 这种情…

详解各种LLM系列|LLaMA 1 模型架构、预训练、部署优化特点总结

作者 | Sunnyyyyy 整理 | NewBeeNLP https://zhuanlan.zhihu.com/p/668698204 后台留言『交流』,加入 NewBee讨论组 LLaMA 是Meta在2023年2月发布的一系列从 7B到 65B 参数的基础语言模型。LLaMA作为第一个向学术界开源的模型,在大模型爆发的时代具有标…

Unity SRP 管线【第十讲:SRP/URP 图形API】

Unity 封装的图形API 文章目录 Unity 封装的图形API一、 CommandBuffer 要执行的图形命令列表1. CommandBuffer 属性2. CommandBuffer 常用图形API(方法)(1)设置(2)获取临时纹理 GetTemporaryRT以及释放(3)设置纹理为渲染目标 SetRenderTarget(4)Command…

【flink状态管理(2)各状态初始化入口】状态初始化流程详解与源码剖析

文章目录 1. 状态初始化总流程梳理2.创建StreamOperatorStateContext3. StateInitializationContext的接口设计。4. 状态初始化举例:UDF状态初始化 在TaskManager中启动Task线程后,会调用StreamTask.invoke()方法触发当前Task中算子的执行,在…

泽攸科技ZEM系列台扫助力环境科研创新:可见光催化抗生素降解的探索

环境污染和能源短缺是当今人类社会面临的最严重威胁之一。为了克服这些问题,特别是在污水处理过程中,寻找新的技术来实现清洁、高效、经济的发展显得尤为重要。在各种工业废水中,抗生素的过量排放引起了广泛关注。抗生素的残留会污染土壤、水…

Android:IntentActivity,Service,BroadcastReceiver

3.14 Android三大组件 1、Intent页面跳转 Intent(意图):将要做某一件事。Android的3大组件:Activity、Service、BroadcastReceiver,通过Intent启动,并且Intent可以携带数据。 Intent类方法setComponent()设置组件; setClass(packageContext,cls)设置类、 setActi…