Attention Computation

  • The three ingredients of attention are the Query, Key, and Value. In self-attention, all three are computed from the same input sequence.
  • Take the example illustrated in the figure below: a sequence of 2 tokens, each with 3 features, i.e. an input of shape (2, 3)
    • Each Query token computes a "similarity" with the Key of every token (including itself); after a softmax (each row sums to 1), this gives a 2 × 2 weight matrix
    • That weight matrix is then multiplied with the Value matrix, (2, 2) × (2, 3), producing a new (2, 3) output
      • Intuitively: output feature 1 of token A equals a weighted sum of the corresponding features of input tokens A and B (a numeric sketch follows the figure below).
  • Multi-head attention: in essence, the feature dimension is split into several parts, each called a "head". Every head performs attention independently, and the outputs of all heads are then concatenated, with the goal of learning different relations and patterns.

(figure: attention computation on a 2-token, 3-feature input)
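To make the weighted-sum picture concrete, here is a tiny numeric sketch of the 2-token, 3-feature example (the input values are made up; Q, K, and V are all set to the raw input, i.e. self-attention without learned projections):

import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 0.0, 1.0],
                  [0.0, 2.0, 0.0]])      # (2, 3): tokens A and B, 3 features each
score = x @ x.T / x.size(-1) ** 0.5      # (2, 2): Query-Key similarity, scaled by sqrt(d)
weight = F.softmax(score, dim=-1)        # (2, 2): each row sums to 1
out = weight @ x                         # (2, 2) × (2, 3) -> (2, 3)
# out[0] is a weighted sum of the features of tokens A and B, using the weights in weight[0]
out.shape
# torch.Size([2, 3])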

  • Attention computation itself involves no learnable parameters. In practice, a linear (MLP) projection layer is usually placed before the input and after the output, as in the sketch below.
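As a minimal sketch of where the learnable parameters live (a single-head toy module; the class and layer names are illustrative, not from any specific library):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSelfAttention(nn.Module):
    # learnable linear layers wrap a parameter-free attention core
    def __init__(self, embed_dim):
        super().__init__()
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)  # projection before attention
        self.out_proj = nn.Linear(embed_dim, embed_dim)      # projection after attention

    def forward(self, x):  # x: (batch, seq, embed_dim)
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)
        w = F.softmax(q @ k.transpose(-2, -1) / q.size(-1) ** 0.5, dim=-1)
        return self.out_proj(w @ v)  # the attention computation itself adds no parameters

SimpleSelfAttention(512)(torch.randn(2, 10, 512)).shape
# torch.Size([2, 10, 512])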

Reference: https://blog.csdn.net/God_WeiYang/article/details/131820781

0. Simulated Data

  • Typically, the input to attention computation has four dimensions: (batch_size, num_heads, seq_length, head_dim)

    word embedding dim = num_heads × head_dim

# Note: the dimension order here differs from the one described above
import torch

def sample_input(batch_size, seq_length, n_head, n_dim):
    q = torch.randn((batch_size, seq_length, n_head, n_dim)).to("cuda:0", torch.float16)
    k = torch.rand_like(q)
    v = torch.rand_like(q)
    return q, k, v

q, k, v = sample_input(32, 100, 8, 64)
q.shape, k.shape, v.shape
# (torch.Size([32, 100, 8, 64]),
#  torch.Size([32, 100, 8, 64]),
#  torch.Size([32, 100, 8, 64]))

The final feature length of each token is num_heads × head_dim.
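As a quick check using the q sampled above, flattening the last two dimensions recovers the full per-token embedding:

# merge heads: (batch_size, seq_length, num_heads, head_dim) -> (batch_size, seq_length, num_heads * head_dim)
q.reshape(32, 100, -1).shape
# torch.Size([32, 100, 512]), since 8 × 64 = 512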

1. Manual Computation

  • The causal parameter applies the attention mask commonly used in GPT-style generative models: when computing the n-th token, attention covers only tokens 1 through n (itself and the earlier tokens), never future tokens;
  • Concretely, masked positions are set to torch.finfo(q.dtype).min (effectively negative infinity), so their weights become 0 after the softmax.
import math
import torch.nn.functional as F

def custom_attention(q, k, v, causal=False):
    # q, k, v: (batch_size, num_heads, seq_length, head_dim)
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        # upper-triangular mask (above the diagonal) blocks attention to future tokens
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask==1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o

q1 = q.transpose(1, 2) # torch.Size([32, 8, 100, 64])
k1 = k.transpose(1, 2)
v1 = v.transpose(1, 2)

o1 = custom_attention(q1, k1, v1)
o1.transpose(1, 2).shape
# torch.Size([32, 100, 8, 64])

2. Flash Attention 1/2

  • Flash Attention is an efficient attention algorithm designed to optimize the attention mechanism in Transformer models. It reduces memory usage and increases computation speed, which makes long input sequences tractable.
  • https://github.com/Dao-AILab/flash-attention

2.1 Installation

  • scGPT, covered previously, mainly uses FlashAttention-1; see the earlier notes for how to install it.
# scGPT pseudocode example
from flash_attn.flash_attention import FlashAttention

# __init__
self.self_attn = FlashAttention(attention_dropout=attention_dropout)

# forward
pcpt_context, pcpt_attn_weights = self.self_attn(
    pcpt_qkv,
    key_padding_mask=pcpt_key_padding_mask,
    need_weights=need_weights,
    causal=self.causal, # If true, autoregressive modeling
)
  • The focus this time is FlashAttention-2 (with Better Parallelism and Work Partitioning). Installation involved many pitfalls; the following is a reliable approach.
conda create -n flash python=3.10 mamba -y
# CUDA 11.8
mamba install cudatoolkit==11.8 -c nvidia
# torch 2.3.0
mamba install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
# nvcc
conda install nvidia/label/cuda-11.8.0::cuda-nvcc
nvcc -V
# download the prebuilt wheel and install locally: https://github.com/Dao-AILab/flash-attention/releases
pip install 'flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl'
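A quick way to confirm that the installed wheel matches the environment (assuming the package exposes __version__, which recent releases do):

import flash_attn
print(flash_attn.__version__)
# should print 2.6.3, matching the wheel above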

Installing without a local wheel:

pip install ninja # speeds up the build; in practice it only helped the clone-and-install route below

pip install flash-attn --no-build-isolation
# or clone github repo
python setup.py install

Because CUDA was installed inside the conda environment, the steps above can fail with errors such as cuda_runtime_api.h: No such file or directory. After trying many tutorials, the most reliable fix is installing cudatoolkit-dev. This brings its own constraint: its highest available version is currently 11.7, so a matching CUDA environment, torch, etc. have to be installed instead:

mamba install -c conda-forge cudatoolkit-dev
mamba install cudatoolkit==11.7 -c nvidia
mamba install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

python setup.py install

2.2 Usage

  • The expected input shape is generally (batch_size, seqlen, nheads, headdim)
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func, flash_attn_kvpacked_func

# flash_attn_func: standard attention computation
o2 = flash_attn_func(q, k, v)
o2.shape
# torch.Size([32, 100, 8, 64])

# flash_attn_qkvpacked_func: computation with packed QKV
qkv_pack = torch.concat([q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)], dim=2)
# (batch_size, seqlen, 3, nheads, headdim)
qkv_pack.shape
# torch.Size([32, 100, 3, 8, 64])

flash_attn_qkvpacked_func(qkv_pack).shape
# torch.Size([32, 100, 8, 64])
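As a sanity check (a small sketch reusing the custom_attention defined in section 1), the FlashAttention output should match the manual computation within fp16 tolerance:

o_ref = custom_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2)
torch.allclose(o_ref, o2, rtol=1e-2, atol=1e-2)
# expected: True (same tolerance as the asserts in the benchmark below)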

3. scaled_dot_product_attention

import torch.nn.functional as F

# expected input shape: (batch_size, nheads, seqlen, headdim)
q3 = q.transpose(1, 2) # torch.Size([32, 8, 100, 64])
k3 = k.transpose(1, 2)
v3 = v.transpose(1, 2)

o3 = F.scaled_dot_product_attention(q3, k3, v3)
o3.transpose(1, 2).shape
# torch.Size([32, 100, 8, 64])
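Optionally, the backend used by scaled_dot_product_attention can be restricted. A minimal sketch using torch.backends.cuda.sdp_kernel, the PyTorch 2.x context manager (newer releases prefer torch.nn.attention.sdpa_kernel):

# force the FlashAttention backend; requires fp16/bf16 CUDA tensors such as q3/k3/v3 above
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    o3_flash = F.scaled_dot_product_attention(q3, k3, v3)
o3_flash.shape
# torch.Size([32, 8, 100, 64])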

4. Benchmark

import math
import time
from einops import rearrange
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

def custom_attention(q, k, v, causal=False):
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask==1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o

def pytorch_func(q, k, v, causal=False):
    o = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    # o = F.scaled_dot_product_attention(q, k, v, is_causal=causal)[0]
    return o

def flash_attention(q, k, v, causal=False):
    o = flash_attn_func(q, k, v, causal=causal)
    return o
  • Define a function that measures the computation time and peak GPU memory of a given attention implementation
def test(func_name, q, k, v, *args, **kwargs):
    if func_name in ["custom_attention", "pytorch_func"]:
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")
 
    torch.cuda.reset_peak_memory_stats() # reset CUDA peak-memory statistics
    torch.cuda.synchronize()             # make sure all pending CUDA ops finish before timing
    # globals(): dict of all global names in the current module; used to look up the function by name
    for _ in range(5):                   # warm-up runs
        o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    st = time.time()
    o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    tt = time.time() - st
    max_memory = torch.cuda.max_memory_allocated() // 2**20  # peak memory, in MB
    torch.cuda.empty_cache() # release cached GPU memory that is no longer in use
 
    if func_name in ["custom_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")
 
    return o, tt, max_memory
  • Compare the three implementations across different sequence lengths, measuring computation time and memory consumption
    • (1) The longer the sequence, the more pronounced the advantage of pytorch func and flash attention
    • (2) The gap between pytorch func and flash attention themselves is small.
for seqlen in [256, 512, 1024]:
    print(f"## Sequence length: {seqlen}")
    q, k, v = sample_input(32, seqlen, 8, 64)
    ## (1)
    o, t, m = test("custom_attention", q, k, v, causal=False)
    print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")
    ## (2)
    pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=False)
    print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
    assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)
    ## (3)
    fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=False)
    print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
    assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)
    
# ## Sequence length: 256
# custom pytorch time: 0.000259, peak memory: 216 MB
# pytorch func time: 0.000058, speedup: 4.44; peak memory: 120 MB, save: 44%
# flash attention time: 0.000073, speedup: 3.54; peak memory: 96 MB, save: 55%

# ## Sequence length: 512
# custom pytorch time: 0.001135, peak memory: 384 MB
# pytorch func time: 0.000142, speedup: 7.98; peak memory: 120 MB, save: 68%
# flash attention time: 0.000154, speedup: 7.38; peak memory: 128 MB, save: 66%

# ## Sequence length: 1024
# custom pytorch time: 0.004096, peak memory: 1272 MB
# pytorch func time: 0.000478, speedup: 8.58; peak memory: 233 MB, save: 81%
# flash attention time: 0.000493, speedup: 8.30; peak memory: 249 MB, save: 80%
  • Compare the three implementations across different head dimensions, measuring computation time and memory consumption
    • As the per-token feature dimension grows, the advantage of pytorch func and flash attention gradually shrinks.
for head_dim in [32, 64, 128]:
    print(f"## Head dim: {head_dim}")
    q, k, v = sample_input(32, 1024, 8, head_dim)
    ## (1)
    o, t, m = test("custom_attention", q, k, v, causal=False)
    print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")
    ## (2)
    pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=False)
    print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
    assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)
    ## (3)
    fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=False)
    print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
    assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)
    
# ## Head dim: 32
# custom pytorch time: 0.003847, peak memory: 1128 MB
# pytorch func time: 0.000276, speedup: 13.92; peak memory: 105 MB, save: 90%
# flash attention time: 0.000286, speedup: 13.45; peak memory: 121 MB, save: 89%
# ## Head dim: 64
# custom pytorch time: 0.004104, peak memory: 1272 MB
# pytorch func time: 0.000485, speedup: 8.45; peak memory: 233 MB, save: 81%
# flash attention time: 0.000500, speedup: 8.21; peak memory: 249 MB, save: 80%
# ## Head dim: 128
# custom pytorch time: 0.004594, peak memory: 1512 MB
# pytorch func time: 0.000952, speedup: 4.82; peak memory: 457 MB, save: 69%
# flash attention time: 0.000956, speedup: 4.81; peak memory: 489 MB, save: 67%

Addendum


5. Performer

Performer attention reduces the complexity of attention to linear, which makes it far more efficient on long sequences; its soundness is also established theoretically.

For an input sequence of length L with embedding dimension d:

For standard Transformer self-attention (left panel of the figure below), the complexity is O(L²) → quadratic:

  • (1) Attention matrix: Q × Kᵀ = (L, d) × (d, L) = (L, L) → O(L² · d) → O(L²)
  • (2) Softmax normalization of the weights: (L, L) → (L, L) → O(L²)
  • (3) Weighted sum: (Q × Kᵀ) × V = (L, L) × (L, d) = (L, d) → O(L² · d) → O(L²)

Performer attention (right panel of the figure below) has complexity O(L) → linear:

  • (1) First apply a random feature mapping to Q and K
    • Q: (L, d) → Q′ (L, r), K: (L, d) → K′ (L, r)
  • (2) Then compute (K′)ᵀ × V = (r, L) × (L, d) = (r, d) → O(L·r·d) → O(L)
  • (3) Finally compute Q′ × ((K′)ᵀ × V) = (L, r) × (r, d) = (L, d) → O(L·r·d) → O(L)

When reasoning about complexity, constant factors, lower-order terms, and coefficients are ignored; here r and d are treated as constants relative to L.
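A minimal sketch of the linear-attention ordering, using simplified positive random features to approximate the softmax kernel (function and variable names are illustrative, not the official Performer API):

import torch

def random_feature_map(x, W):
    # positive random features: phi(x) = exp(x W^T - ||x||^2 / 2) / sqrt(r),
    # so that E[phi(q) · phi(k)] ≈ exp(q · k)
    r = W.shape[0]
    return torch.exp(x @ W.T - (x ** 2).sum(-1, keepdim=True) / 2) / r ** 0.5

def performer_like_attention(q, k, v, r=64):
    # q, k, v: (L, d); r random features; complexity O(L · r · d), linear in L
    L, d = q.shape
    q, k = q / d ** 0.25, k / d ** 0.25                    # so phi(q) · phi(k) ≈ exp(q · k / sqrt(d))
    W = torch.randn(r, d, device=q.device, dtype=q.dtype)  # random projection
    q_prime, k_prime = random_feature_map(q, W), random_feature_map(k, W)   # (L, r)
    kv = k_prime.T @ v                                     # (r, L) × (L, d) = (r, d)
    z = k_prime.T @ torch.ones(L, device=q.device, dtype=q.dtype)           # (r,) normalizer
    return (q_prime @ kv) / (q_prime @ z).unsqueeze(-1)    # (L, r) × (r, d) = (L, d)

x = torch.randn(1024, 64)
performer_like_attention(x, x, x).shape
# torch.Size([1024, 64])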

(figure: standard attention, left, vs. Performer attention, right)