【AI】计算机视觉VIT文章(Transformer)源码解析

news/2024/7/19 10:33:02 标签: 人工智能, 计算机视觉, transformer

论文:Dosovitskiy A, Beyer L, Kolesnikov A, et al. An image is worth 16x16 words: Transformers for image recognition at scale[J]. arXiv preprint arXiv:2010.11929, 2020

源码的Pytorch版:https://github.com/lucidrains/vit-pytorch

0.前言

Transformer提出后在NLP领域中取得了极好的效果,其全Attention的结构,不仅增强了特征提取能力,还保持了并行计算的特点,可以又快又好的完成NLP领域内几乎所有任务,极大地推动自然语言处理的发展。
VIT这篇文章就是将Transformer模型应用在了CV领域,它将图像处理成Transformer模型可以应用的形式,沿用NLP领域中Transformer的方法,直接验证了其精度可以和ResNet不相上下,展示了在计算机视觉中使用纯Transformer结构的可能,为Transformer在CV领域的应用打开了大门。

直接读文章通常比较抽象,英文的原文更能劝退一大部分人,但对于程序员来说,代码是通行于世界的语言,理解起来就比较简单,结合源码理解论文中的结构,就比较事半功倍。

在这里插入图片描述

上图是VIT文章中的结构,我们看图提问题,从数据的流向来看:图像怎么切分重排的?Linear Projection of Flattened Patches对图像作了什么,怎么让图像变成Transformer能够输入的格式?
Position Embedding是怎么做图像位置编码的,为什么会多出来一个0的位置编码?Transformer Encoder中的各个结构分别代表什么,是怎么实现的?输出的类别是什么?

1.图像是如何适配Transformer输入的?

Transformer的输入是一个一维的向量,而我们的图像是二维的,需要把图像拉伸成一维的,最简单的方式就是沿着x轴展开,将所有的行拼接在一起,也就是Flatten的操作,但是这样处理会导致向量维度比较大,而且同一张图片也只能生成一个Embedding,不能适配Multi-head Embedding(这种解释有点牵强,有点拿结果去解释原因)。

1.1 图像切分重排

如结构图所示,VIT采取了图像切分重排的方式,将一个完整的图像按照行列的方向切分成小块,然后再进行后续处理。切分重排是怎么实现的,我们看代码:

self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), #图片切分重排
            nn.Linear(patch_dim, dim), # Linear Projection of Flattened Patches
        )

这里的Rearrange使用到了einops的库,相关的介绍可以查看文档
这里主要是对b c (h p1) (w p2) -> b (h w) (p1 p2 c)表达式的理解,b是batchsize,c是chanel,h和w是图像的高度和宽度,表达式前边的(h p1)表示将输入的h重新拆分成一个h*p1的向量,同理(w p2)是将输入的w拆分成w*p2的向量,这里的p1和p2是模型定义的patch_height和patch_width,可以理解为切分后小图的高和宽;
表达式右边表示输出向量的维度,(p1 p2 c)表示这3个数相乘,表示的是切分后一个小图的一维向量大小,(h w)则表示总共有多少个切分后的小图所生成的向量。这里的h和w的值跟输入的h和w值不同,表示的是原有h和w除以patch_height或patch_width的值,也就是在高和宽上各能切分出几个小图。
通过这样的处理,就将输入的c个channel的h*w的二维向量,转换成小图展开后的一维向量。

1.2 构建patch0

图像在输入Transformer之前,还连接了一个patch0,这一步的操作可以认为是延续了nlp的操作,在VIT这边的操作,可以认为是将分类类别拼接上去。

cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
x = torch.cat((cls_tokens, x), dim=1)

1.3 Positional Embedding

为了在模型中引入位置信息,VIT引入了位置编码的形式,也就是图中的0、1、2、3…

x += self.pos_embedding[:, :(n + 1)]

然后将position_embedding与图像的Embedding相加。

2.Transformer Encoder中的各个结构分别代表什么,是怎么实现的?

结构中的Norm可以认为是归一化处理,MLP是多个全连接层,理解起来都比较简单。其实最主要的结构就是Multi-head Attention。
在这里插入图片描述

2.1 Attention 定义

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        #得到qkv
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

2.2 Attention 的理解

VIT不同于之前RNN的地方就是引入了Attention机制,其本质可以认为是相似度计算,是计算每一个输入值与其他输入值的相似度,然后带入到计算中,如图所示,q是输入的查询值,k是关键词,v是计算值,计算每一个q与其他k的相似度,然后再带入计算中去。
在这里插入图片描述

其对应的计算公式为:
在这里插入图片描述
相似度计算是用的点积的形式,除以根号下dk是为了抑制极端值,保证softmax之后数值不至于丧失梯度。对应的代码如下:

qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

然后过一个softmax

attn = self.attend(dots) # self.attend = nn.Softmax(dim = -1)
attn = self.dropout(attn)

然后跟value值相乘

out = torch.matmul(attn, v)

2.3 Multihead Attention

多头注意力机制可以认为是定义多个Attention(每个attention关注的重点不同)分别来对数据进行处理,这里从源码中的循环结构可以体现出来:

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

3.使用示例

这里可以参考官方的readme文档

$ pip install vit-pytorch

使用示例

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

http://www.niftyadmin.cn/n/5290423.html

相关文章

【linux】如何查看服务器磁盘IO性能

查看服务器磁盘IO性能 在服务器运维过程中,了解服务器的磁盘IO性能是非常重要的。磁盘IO性能直接影响到服务器的响应速度和处理能力。本文将介绍如何使用dd命令来查看服务器磁盘IO性能。 1. 什么是dd命令? dd命令是Linux系统中的一个非常强大的工具&a…

手动创建idea SpringBoot 项目

步骤一: 步骤二: 选择Spring initializer -> Project SDK 选择自己的JDK版本 ->Next 步骤三: Maven POM ->Next 步骤四: 根据JDK版本选择Spring Boot版本 11版本及以上JDK建议选用3.2版本,JDK为11版本…

LeetCode 75| 位运算

目录 338 比特位计数 136 只出现一次的数字 1318 或运算的最小翻转次数 338 比特位计数 class Solution { public:vector<int> countBits(int n) {vector<int>res(n 1);for(int i 0;i < n;i)res[i] cal(i);return res;}int cal(int num){int res 0;for(i…

uniapp 输入手机号并且正则校验

1.<input input“onInput” :value“phoneNum” type“number” maxlength“11”/> 3. method里面写 onInput(e){ this.phoneNum e.detail.value }, 4.调用接口时候校验正则 if (!/^1[3456789]\d{9}$/.test(this.phoneNum)) {uni.showToast({title: 请输入正确的手机号…

【Python百宝箱】个性化推荐算法探幽:从协同过滤到深度学习,推荐系统库选择指南

个性化推荐&#xff1a;Surprise、LightFM、Implicit、Cornac、TuriCreate、Spotlight全方位解析 前言 在当今数字时代&#xff0c;推荐系统在引导用户发现个性化内容、产品和体验方面发挥着至关重要的作用。本文深入探讨了推荐系统领域的多个先进工具和库&#xff0c;为读者…

Vue 自定义ip地址输入组件

实现效果&#xff1a; 组件代码 <template><div class"ip-input flex flex-space-between flex-center-cz"><input type"text" v-model"value1" maxlength"3" ref"ip1" :placeholder"placeholder"…

计算机网络(第八版)期末复习(第二章物理层)

重要已用加粗表示&#xff0c;这些是复习内容所以并没有包括许多细节&#xff0c;仅包括重要知识点方便快速过。 物理层考虑的是怎样才能在连接各种计算机的传输媒体上传输数据比特流&#xff0c;而不是指具体的传输媒体确定与传输媒体的接口的一些特性: 机械特性电气特性功能…

爬虫工作量由小到大的思维转变---<第三十章 Scrapy Redis 第一步(配置同步redis)>

前言: 要迈向scrapy-redis进行编写了;首要的一步是,如何让他们互通?也就是让多台电脑连一个任务(这后面会讲); 现在来做一个准备工作,配置好redis的同步!! 针对的是windows版本的redis同步,实现主服务和从服务共享一个redis库; 正文: 正常的redis for windows 的安装这里就…