效果概览:
好处:使用 google PALM 架构的小模型做 生成任务,改为 xformers 实现后,加速比为 2倍,显存消耗为原来的 1/3 ,非常给力。
缺点:相比pytorch的原生实现,误差略大。。。
xformers 官方github仓库:https://github.com/facebookresearch/xformers
xformers 官方文档:https://facebookresearch.github.io/xformers/
https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops
前两周 xformers 官方提供了 pypi 和 whl 轮包
windows 和 linux 均可用,最低版本要求 pytorch 1.13.1 版本
pip 安装 xformers
pip install -U xformers
如果需要用于编码器或需要位置偏置,则需要安装 0.17 以上版本
当前(2023/2/26) v0.17 为预发行版,需要使用 --pre 来安装
pip install --pre -U xformers
使用方法
import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMask
device='cuda'
batch = 4
n_head = 8
head_dim = 16
seq_len = 128
q = torch.rand(batch, seq_len, n_head, head_dim).to(device)
k = torch.rand(batch, seq_len, n_head, head_dim).to(device)
v = torch.rand(batch, seq_len, n_head, head_dim).to(device)
# 使用 causal 掩码
o = memory_efficient_attention(q, k, v, LowerTriangularMask())
# 不使用编码
o = memory_efficient_attention(q, k, v)
# 使用自定义的 attn_bias,要求 xformers 版本 大于等于 0.17
## 这里的 from_len,to_len 分别代表Decoder的序列长度,Encoder的序列长度
from_len = seq_len
to_len = seq_len
attn_bias = torch.rand(batch, n_head, from_len, to_len).to(device)
o = memory_efficient_attention(q, k, v, attn_bias)
memory_efficient_attention 的大概的 等效pytorch实现
来自 https://facebookresearch.github.io/xformers/components/ops.html#module-xformers.ops
def memory_efficient_attention_pytorch(query, key, value, attn_bias=None, p=0., scale=None):
# q [batch, seq_len, n_head, head_dim]
# k [batch, seq_len, n_head, head_dim]
# v [batch, seq_len, n_head, head_dim]
# attn_bias [batch, n_head, seq_len, seq_len]
if scale is None:
scale = 1 / query.shape[-1] ** 0.5
query = query * scale
attn = query @ key.transpose(-2, -1)
if attn_bias is not None:
attn = attn + attn_bias
attn = attn.softmax(-1)
attn = F.dropout(attn, p)
return attn @ value