pytorch 使用 xformers 库 加速多头注意力计算 和 大幅节省显存

news/2024/7/19 11:58:35 标签: pytorch, 深度学习, xformers, transformer

效果概览:
好处:使用 google PALM 架构的小模型做 生成任务,改为 xformers 实现后,加速比为 2倍,显存消耗为原来的 1/3 ,非常给力。
缺点:相比pytorch的原生实现,误差略大。。。

xformers 官方github仓库:https://github.com/facebookresearch/xformers
xformers 官方文档:https://facebookresearch.github.io/xformers/
https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops

前两周 xformers 官方提供了 pypi 和 whl 轮包
windows 和 linux 均可用,最低版本要求 pytorch 1.13.1 版本

pip 安装 xformers

pip install -U xformers

如果需要用于编码器或需要位置偏置,则需要安装 0.17 以上版本
当前(2023/2/26) v0.17 为预发行版,需要使用 --pre 来安装

pip install --pre -U xformers

使用方法

import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMask

device='cuda'
batch = 4
n_head = 8
head_dim = 16
seq_len = 128

q = torch.rand(batch, seq_len, n_head, head_dim).to(device)
k = torch.rand(batch, seq_len, n_head, head_dim).to(device)
v = torch.rand(batch, seq_len, n_head, head_dim).to(device)

# 使用 causal 掩码
o = memory_efficient_attention(q, k, v, LowerTriangularMask())

# 不使用编码
o = memory_efficient_attention(q, k, v)

# 使用自定义的 attn_bias,要求 xformers 版本 大于等于 0.17
## 这里的 from_len,to_len 分别代表Decoder的序列长度,Encoder的序列长度
from_len = seq_len
to_len = seq_len
attn_bias = torch.rand(batch, n_head, from_len, to_len).to(device)
o = memory_efficient_attention(q, k, v, attn_bias)

memory_efficient_attention 的大概的 等效pytorch实现
来自 https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops


def memory_efficient_attention_pytorch(query, key, value, attn_bias=None, p=0., scale=None):
    # q [batch, seq_len, n_head, head_dim]
    # k [batch, seq_len, n_head, head_dim]
    # v [batch, seq_len, n_head, head_dim]
    # attn_bias [batch, n_head, seq_len, seq_len]

    if scale is None:
        scale = 1 / query.shape[-1] ** 0.5

    query = query * scale
    attn = query @ key.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = attn.softmax(-1)
    attn = F.dropout(attn, p)
    return attn @ value

http://www.niftyadmin.cn/n/102010.html

相关文章

在Angular项目中引入NG-ZORRO

在Angular项目中引入NG-ZORRO1.前置2.安装NG-ZORRO并进行初始化配置3.引入样式4.引入组件1.前置 首先创建一个angular项目:angular创建一个新项目的步骤 这是我项目的结构: 2.安装NG-ZORRO并进行初始化配置 安装NG-ZORRO:cd 到当前项目位…

《SQL基础》12. SQL优化

SQL优化SQL优化数据插入insert优化大批量插入数据主键优化order by优化group by优化limit优化count优化count用法update优化SQL优化 数据插入 insert优化 如果我们需要一次性往数据库表中插入多条记录,可以从以下三个方面进行优化。 批量插入手动控制事务主键顺…

信号的FFT变换与加窗

1. fft 傅里叶变换 1.1 傅里叶变换的本质 数学上有一种公式叫做 泰勒展开: 泰勒公式: 其表达的思想,是任意一函数可以有多个指数函数构成 当指数函数的个数趋近于无穷多个,那么组合出来的函数将会逼近原函数; …

【Python从入门到进阶】9、流程控制语句-条件语句(if-else)

接上篇《8、Python的输入输出》 上一篇我们学习了Python的输入和输出相关内容。本篇我们来学习Python的控制流语句。 一、流程控制语句的含义 之前我们分别学习过“变量及数据类型”、“运算符”,其中“变量及数据类型”相当于我们学习自然语言中的“字”&#xf…

mysql如何给字符串字段加索引

现在许多系统支持邮箱登录,给这样的字段设置索引第一种:给String字段创建完整索引alter table User add index index(email);这种方式创建的索引,只需要回到主键索引上取一次值,但是比较占用空间第二种:给String字段创…

【离线数仓-9-数据仓库开发DWS层设计要点-1d/nd/td表设计】

离线数仓-9-数据仓库开发DWS层设计要点-1d/nd/td表设计离线数仓-9-数据仓库开发DWS层设计要点-1d/nd/td表设计一、DWS层设计要点二、DWS层设计分析 - 1d/nd1.DWS层设计一:不考虑用户维度2.DWS层设计二:考虑用户维度2.DWS层设计三 :考虑用户商…

尚硅谷nginx基础

nginx1. nginx安装1.1版本区别1.2安装步骤1.3 启动nginx1.4关于防火墙1.5 安装成系统服务1.6 配置nginx环境变量2. nginx基本使用2.1 基本运行原理2.2 nginx配置文件2.2.1 最小配置2.2.1.1 基本配置说明2.3 虚拟主机2.3.1域名、dns、ip地址的关系2.3.2IP地址和DNS地址的区别2.3…

Java EE|TCP/IP协议栈之应用层协议DNS详解

文章目录一、对DNS的感性认识简介特点一些常见疑问二、DNSDNS域名结构域名的分级三、域名服务器四、域名解析过程参考一、对DNS的感性认识 简介 DNS,即Domain Name System,是域名系统的简称。它是Internet上解决网上机器命名的一种系统。 TCP/IP中的IP地址是由四…