Attention Computation

  • Attention is computed from three components: Query, Key, and Value. In self-attention, all three are derived from the same input.
  • Take the figure below as an example: a sequence of 2 tokens, each with 3 features, i.e. an input of shape (2, 3)
    • Each Query token computes a "similarity" with every token's Key (including its own); after a softmax (each row sums to 1), this yields a 2 × 2 weight matrix
    • Multiplying this weight matrix by the Value matrix, (2, 2) × (2, 3), gives a new (2, 3) output
      • Intuitively: output feature 1 of token A is a weighted sum of the features of input tokens A and B (a minimal worked sketch follows this list).
  • Multi-head attention: in essence, the feature dimension is split into several parts, each called a "head". Each head computes attention independently and the head outputs are then concatenated, with the aim of learning different relations and patterns.

[Figure: attention computation on a 2-token, 3-feature input]

  • The attention computation itself involves no learnable parameters. In practice, an MLP/linear projection layer is placed before the input and after the output.
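A minimal worked sketch of the (2, 3) example above, using plain scaled dot-product attention with no learned projections (the numbers are arbitrary):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3)                    # 2 tokens, 3 features; in self-attention Q = K = V = x
score = x @ x.T / (x.size(-1) ** 0.5)    # (2, 2) similarity matrix, scaled by sqrt(d)
weight = F.softmax(score, dim=-1)        # each row sums to 1
out = weight @ x                         # (2, 2) x (2, 3) -> (2, 3)
# out[0] is a weighted sum of the features of tokens A and B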

Reference: https://blog.csdn.net/God_WeiYang/article/details/131820781

0. Simulated Data

  • Typically, the input to the attention computation has four dimensions: (batch_size, num_heads, seq_length, head_dim)

    word embedding dim = num_heads × head_dim

import torch

# Note: the dimension order here differs slightly from the one above,
# i.e. (batch_size, seq_length, num_heads, head_dim)
def sample_input(batch_size, seq_length, n_head, n_dim):
    q = torch.randn((batch_size, seq_length, n_head, n_dim)).to("cuda:0", torch.float16)
    k = torch.rand_like(q)
    v = torch.rand_like(q)
    return q, k, v

q, k, v = sample_input(32, 100, 8, 64)
q.shape, k.shape, v.shape
# (torch.Size([32, 100, 8, 64]),
#  torch.Size([32, 100, 8, 64]),
#  torch.Size([32, 100, 8, 64]))

The final feature length of each token is: num_heads × head_dim (the heads are merged back into the embedding dimension, as sketched below).
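A shape-only sketch of merging the heads back into the embedding dimension (reusing the q tensor defined above):

# (batch_size, seq_length, n_head, n_dim) -> (batch_size, seq_length, n_head * n_dim)
merged = q.reshape(32, 100, 8 * 64)
merged.shape   # torch.Size([32, 100, 512]), i.e. the word-embedding length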

1. Manual Attention Computation

  • The causal argument enables the attention mask commonly used in GPT-style generative models: when computing token n of the sequence, it only attends to tokens 1 through n (itself and everything before it, never the future);
  • Concretely, the masked positions are set to torch.finfo(q.dtype).min (effectively negative infinity), so their weights become 0 after the softmax.
import math
import torch.nn.functional as F

def custom_attention(q, k, v, causal=False):
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask==1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o

q1 = q.transpose(1, 2) # torch.Size([32, 8, 100, 64])
k1 = k.transpose(1, 2)
v1 = v.transpose(1, 2)

o1 = custom_attention(q1, k1, v1)
o1.transpose(1, 2).shape
# torch.Size([32, 100, 8, 64])

2. Flash Attention 1/2

  • Flash Attention is an efficient attention algorithm designed to optimize the attention mechanism in Transformer models. It handles long sequences by reducing memory usage and increasing compute speed.
  • https://github.com/Dao-AILab/flash-attention

2.1 Installation

  • scGPT, covered in earlier notes, mainly uses FlashAttention-1; see those notes for its installation.
# pseudocode sketch of scGPT usage
from flash_attn.flash_attention import FlashAttention

# __init__
self.self_attn = FlashAttention(attention_dropout=attention_dropout)

# forward
pcpt_context, pcpt_attn_weights = self.self_attn(
    pcpt_qkv,
    key_padding_mask=pcpt_key_padding_mask,
    need_weights=need_weights,
    causal=self.causal, # If true, autoregressive modeling
)
  • This time the focus is FlashAttention-2 (with Better Parallelism and Work Partitioning). The installation involved many pitfalls; the following is a reliable approach.
conda create -n flash python=3.10 mamba -y
# CUDA 11.8
mamba install cudatoolkit==11.8 -c nvidia
# torch 2.3.0
mamba install pytorch==2.3.0 torchvision==0.18.0 torchaudio==2.3.0 pytorch-cuda=11.8 -c pytorch -c nvidia
# nvcc
conda install nvidia/label/cuda-11.8.0::cuda-nvcc
nvcc -V
# download the prebuilt wheel from https://github.com/Dao-AILab/flash-attention/releases and install it locally
pip install 'flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl'

The non-local installation (building from source) is:

pip install ninja # speeds up the build; in practice it only helped the clone-and-build route below

pip install flash-attn --no-build-isolation
# or clone github repo
python setup.py install

Because CUDA was installed inside the conda environment, the steps above can fail with errors like cuda_runtime_api.h: No such file or directory. Among the many tutorials, the most reliable fix is to install cudatoolkit-dev. This brings a new constraint: its highest available version is currently 11.7, so the CUDA environment, torch, etc. must be matched to that version.

mamba install -c conda-forge cudatoolkit-dev
mamba install cudatoolkit==11.7 -c nvidia
mamba install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia

python setup.py install

2.2 Usage

  • The expected input shape is generally (batch_size, seqlen, nheads, headdim)
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func, flash_attn_kvpacked_func

# flash_attn_func: standard attention computation
o2 = flash_attn_func(q, k, v)
o2.shape
# torch.Size([32, 100, 8, 64])

# flash_attn_qkvpacked_func: packed-QKV computation
qkv_pack = torch.concat([q.unsqueeze(2), k.unsqueeze(2), v.unsqueeze(2)], dim=2)
# (batch_size, seqlen, 3, nheads, headdim)
qkv_pack.shape
# torch.Size([32, 100, 3, 8, 64])

flash_attn_qkvpacked_func(qkv_pack).shape
# torch.Size([32, 100, 8, 64])

3. F.scaled_dot_product_attention

import torch.nn.functional as F

# expected input shape: (batch_size, nheads, seqlen, headdim)
q3 = q.transpose(1, 2) # torch.Size([32, 8, 100, 64])
k3 = k.transpose(1, 2)
v3 = v.transpose(1, 2)

o3 = F.scaled_dot_product_attention(q3, k3, v3)
o3.transpose(1, 2).shape
# torch.Size([32, 100, 8, 64])

4. Benchmark

import math
import time
from einops import rearrange
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

def custom_attention(q, k, v, causal=False):
    score = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if causal:
        mask = torch.triu(torch.ones(score.shape[-2], score.shape[-1]), diagonal=1)
        mask = mask.masked_fill(mask==1, torch.finfo(q.dtype).min)
        mask = mask.to(q.device, q.dtype)
        score = score + mask
    attn = F.softmax(score, dim=-1)
    o = torch.matmul(attn, v)
    return o

def pytorch_func(q, k, v, causal=False):
    o = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
    return o

def flash_attention(q, k, v, causal=False):
    o = flash_attn_func(q, k, v, causal=causal)
    return o
  • Define a test function that measures each attention implementation's run time and peak GPU memory
def test(func_name, q, k, v, *args, **kwargs):
    if func_name in ["custom_attention", "pytorch_func"]:
        q = rearrange(q, "a b c d -> a c b d")
        k = rearrange(k, "a b c d -> a c b d")
        v = rearrange(v, "a b c d -> a c b d")
 
    torch.cuda.reset_peak_memory_stats() # reset the CUDA peak-memory statistics
    torch.cuda.synchronize()             # make sure all pending CUDA ops have finished
    # globals(): dict of all module-level names; used here to look up the function by its name
    for _ in range(5):                   # warm-up runs
        o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    st = time.time()
    o = globals()[func_name](q, k, v, *args, **kwargs)
    torch.cuda.synchronize()
    tt = time.time() - st
    max_memory = torch.cuda.max_memory_allocated() // 2**20  # in MB
    torch.cuda.empty_cache() # release cached memory that is no longer in use
 
    if func_name in ["custom_attention", "pytorch_func"]:
        o = rearrange(o, "a c b d -> a b c d")
 
    return o, tt, max_memory
  • Benchmark the three implementations across different sequence lengths, recording run time and peak memory
    • (1) The longer the sequence, the larger the advantage of the PyTorch function and flash attention
    • (2) The gap between the PyTorch function and flash attention is small.
for seqlen in [256, 512, 1024]:
    print(f"## Sequence length: {seqlen}")
    q, k, v = sample_input(32, seqlen, 8, 64)
    ## (1)
    o, t, m = test("custom_attention", q, k, v, causal=False)
    print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")
    ## (2)
    pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=False)
    print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
    assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)
    ## (3)
    fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=False)
    print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
    assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)
    
# ## Sequence length: 256
# custom pytorch time: 0.000259, peak memory: 216 MB
# pytorch func time: 0.000058, speedup: 4.44; peak memory: 120 MB, save: 44%
# flash attention time: 0.000073, speedup: 3.54; peak memory: 96 MB, save: 55%

# ## Sequence length: 512
# custom pytorch time: 0.001135, peak memory: 384 MB
# pytorch func time: 0.000142, speedup: 7.98; peak memory: 120 MB, save: 68%
# flash attention time: 0.000154, speedup: 7.38; peak memory: 128 MB, save: 66%

# ## Sequence length: 1024
# custom pytorch time: 0.004096, peak memory: 1272 MB
# pytorch func time: 0.000478, speedup: 8.58; peak memory: 233 MB, save: 81%
# flash attention time: 0.000493, speedup: 8.30; peak memory: 249 MB, save: 80%
  • Benchmark the three implementations across different head dimensions, recording run time and peak memory
    • As the per-token head dimension grows, the speedup of the PyTorch function and flash attention over the manual implementation gradually shrinks.
for headdim in [32, 64, 128]:
    print(f"## Head dim: {headdim}")
    q, k, v = sample_input(32, 1024, 8, headdim)
    ## (1)
    o, t, m = test("custom_attention", q, k, v, causal=False)
    print(f"custom pytorch time: {t:.6f}, peak memory: {m} MB")
    ## (2)
    pf_o, pf_t, pf_m = test("pytorch_func", q, k, v, causal=False)
    print(f"pytorch func time: {pf_t:.6f}, speedup: {t/pf_t:.2f}; peak memory: {pf_m} MB, save: {int((m-pf_m)/m*100)}%")
    assert torch.allclose(o, pf_o, rtol=1e-2, atol=1e-2)
    ## (3)
    fa_o, fa_t, fa_m = test("flash_attention", q, k, v, causal=False)
    print(f"flash attention time: {fa_t:.6f}, speedup: {t/fa_t:.2f}; peak memory: {fa_m} MB, save: {int((m-fa_m)/m*100)}%")
    assert torch.allclose(o, fa_o, rtol=1e-2, atol=1e-2)
    
# ## Head dim: 32
# custom pytorch time: 0.003847, peak memory: 1128 MB
# pytorch func time: 0.000276, speedup: 13.92; peak memory: 105 MB, save: 90%
# flash attention time: 0.000286, speedup: 13.45; peak memory: 121 MB, save: 89%
# ## Head dim: 64
# custom pytorch time: 0.004104, peak memory: 1272 MB
# pytorch func time: 0.000485, speedup: 8.45; peak memory: 233 MB, save: 81%
# flash attention time: 0.000500, speedup: 8.21; peak memory: 249 MB, save: 80%
# ## Head dim: 128
# custom pytorch time: 0.004594, peak memory: 1512 MB
# pytorch func time: 0.000952, speedup: 4.82; peak memory: 457 MB, save: 69%
# flash attention time: 0.000956, speedup: 4.81; peak memory: 489 MB, save: 67%

5. nn.MultiheadAttention and Mask Operations

torch.nn.MultiheadAttention

https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html

  • The standard multi-head attention layer implemented in torch, including the full input and output projection weights
import torch
import torch.nn as nn

query = key = value = torch.randn(4, 10, 512)
# instantiate the layer
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, 
                            batch_first=True, dropout=0.0)
# forward pass
out_put, attn_weight = mha(query, key, value, need_weights = True)

out_put.shape  # torch.Size([4, 10, 512])
attn_weight.shape  # torch.Size([4, 10, 10])


# inspect the attention layer's parameters
for name, param in mha.named_parameters():
    print(f"Name: {name}")
    print(f"Shape: {param.shape}\n")
# Name: in_proj_weight
# Shape: torch.Size([1536, 512])

# Name: in_proj_bias
# Shape: torch.Size([1536])

# Name: out_proj.weight
# Shape: torch.Size([512, 512])

# Name: out_proj.bias
# Shape: torch.Size([512])

need_weights defaults to True, i.e. the attention-weight matrix (averaged across heads) is also returned. If set to False, the layer will "use the optimized scaled_dot_product_attention and achieve the best performance for MHA".

Mask-related arguments of the forward pass

key_padding_mask: marks the pad (padding) tokens in each sequence so that queries do not attend to them.

  • Its shape is usually (Batch_size, Seq_len).
  • True marks a pad token
key_padding_mask = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
                                 [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
                                 [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
                                 [0, 0, 0, 0, 1, 1, 1, 1, 1, 1]]).bool() #(4, 10)
out_put, _ = mha(query, key, value, need_weights = False,
                 key_padding_mask=key_padding_mask)
out_put[0,:3,:3]
# tensor([[ 0.3662, -0.3266, -0.2634],
#         [ 0.0777, -0.1854, -0.0766],
#         [ 0.2575, -0.0160, -0.1000]], grad_fn=<SliceBackward0>)

attn_mask: an attention mask matrix that masks specific query-key pairs directly

attn_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)
attn_mask = attn_mask.expand(4, 8, 10, 10)
attn_mask = attn_mask.reshape(4 * 8, 10, 10)
# torch.Size([32, 10, 10])
out_put, _ = mha(query, key, value, need_weights = False,
              attn_mask=attn_mask)
out_put[0,:3,:3]
# tensor([[ 0.3662, -0.3266, -0.2634],
#         [ 0.0777, -0.1854, -0.0766],
#         [ 0.2575, -0.0160, -0.1000]], grad_fn=<SliceBackward0>)

The conversion of key_padding_mask above involves some details worth noting.

mask = torch.randn(3, 4)   # (Batch_size, Seq_len), num_head = 2
## A plain repeat is NOT what we want
mask.unsqueeze(1).repeat(2,1,1)  # repeats the whole batch as a block, so rows no longer line up with the heads
# tensor([[[ 0.5375,  1.3711, -0.3028,  1.0184]],
#         [[-0.1414,  1.3067, -0.0363,  0.8482]],
#         [[ 0.3573, -0.9005, -0.3998, -0.8608]],
#         [[ 0.5375,  1.3711, -0.3028,  1.0184]],
#         [[-0.1414,  1.3067, -0.0363,  0.8482]],
#         [[ 0.3573, -0.9005, -0.3998, -0.8608]]])

## Both of the following conversions behave as expected
mask.unsqueeze(1).unsqueeze(2).expand(3, 2, 1, 4).reshape(6, 1, 4) # each sequence repeated once per head (expected)
# tensor([[[ 0.5375,  1.3711, -0.3028,  1.0184]],
#         [[ 0.5375,  1.3711, -0.3028,  1.0184]],
#         [[-0.1414,  1.3067, -0.0363,  0.8482]],
#         [[-0.1414,  1.3067, -0.0363,  0.8482]],
#         [[ 0.3573, -0.9005, -0.3998, -0.8608]],
#         [[ 0.3573, -0.9005, -0.3998, -0.8608]]])

mask.unsqueeze(1).repeat_interleave(repeats=2, dim=0)  # each sequence repeated once per head (expected)

6. Flash-Attn V2 Mask Operations

  • In flash_attn v1, the FlashAttention computation supported a key_padding_mask argument, and a wrapped FlashMHA attention layer (including learnable projection weights) was also provided.
    • False marks a pad token, which is the opposite of nn.MultiheadAttention above
  • In flash_attn v2, flash_attn_func only performs the attention computation itself; there is no padding-mask argument (only the causal flag, compared against the manual implementation in the sketch after the snippet below).
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func, flash_attn_kvpacked_func

help(flash_attn_func)
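As a small check (assuming the q, k, v tensors from section 0 and the custom_attention function from section 1 are still in scope), the causal flag can be compared against the manual implementation:

# causal masking is still available in flash_attn v2 via the causal flag,
# even though there is no padding-mask argument
o_flash = flash_attn_func(q, k, v, causal=True)                  # (batch, seqlen, nheads, headdim)
o_ref = custom_attention(q.transpose(1, 2), k.transpose(1, 2),
                         v.transpose(1, 2), causal=True).transpose(1, 2)
torch.allclose(o_flash, o_ref, rtol=1e-2, atol=1e-2)             # expected: True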
  • F.scaled_dot_product_attention, introduced in torch v2, integrates Flash-Attn v2 (see section 3 above). Notably, it does support the attn_mask parameter.

attn_mask (optional Tensor) – Attention mask; shape must be broadcastable to the shape of attention weights, which is (N, ..., L, S). Two types of masks are supported. A boolean mask where a value of True indicates that the element should take part in attention. A float mask of the same type as query, key, value that is added to the attention score.

The parameter description above has two points worth noting:

(1) Regarding shape: it does not have to be (Batch_size*Num_head, Seq_len, Seq_len); any broadcastable shape also works;

(2) As with Flash-Attn, False means padded/masked (note that this is the opposite of the nn.MultiheadAttention key_padding_mask above, where True marks padding, so that mask would need to be inverted with ~ before being reused here).

# Approach 1
attn_mask = key_padding_mask.unsqueeze(1)
attn_mask.shape # torch.Size([4, 1, 10])
output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask
    )
output[0,:3,:3]
# tensor([[ 0.7855, -0.2025, -0.5377],
#         [ 0.1681, -2.8067,  0.0794],
#         [-0.1817,  0.0185,  0.1878]])

# Approach 2
attn_mask = key_padding_mask.unsqueeze(1).repeat(1,10,1)
attn_mask.shape # torch.Size([4, 10, 10])
output = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask
    )
output[0,:3,:3]
# tensor([[ 0.7855, -0.2025, -0.5377],
#         [ 0.1681, -2.8067,  0.0794],
#         [-0.1817,  0.0185,  0.1878]])

Notes:

(1) The above uses single-head attention only;

(2) F.scaled_dot_product_attention only provides the attention computation; it cannot serve as a complete attention layer (it has no weight parameters)

Bottom line: there is no fundamental difference between the attn_mask and key_padding_mask arguments.

Supplement


SP1. Performer

Performer attention reduces the complexity to linear, making it much more efficient on long sequences, and its feasibility is justified theoretically.

For an input sequence of length L with embedding dimension d:

Computing standard Transformer self-attention (left panel of the figure below) has complexity O(L²) → quadratic:

  • (1) Attention matrix: Q * Kᵀ = (L, d) × (d, L) = (L, L) → O(L² · d) → O(L²)
  • (2) Softmax-normalized weights: (L, L) → (L, L) → O(L²)
  • (3) Weighted sum: (Q * Kᵀ) * V = (L, L) × (L, d) = (L, d) → O(L² · d) → O(L²)

Performer's attention computation (right panel of the figure below) has complexity O(L) → linear:

  • (1) First apply a Random Feature Mapping to Q and K
    • Q: (L, d) → Q′ (L, r), K: (L, d) → K′ (L, r)
  • (2) Then compute (K′)ᵀ * V = (r, L) × (L, d) = (r, d) → O(L · r · d) → O(L)
  • (3) Finally compute Q′ * ((K′)ᵀ * V) = (L, r) × (r, d) = (L, d) → O(L · r · d) → O(L)

When analyzing the complexity, constant terms, lower-order terms, and coefficients are ignored. A minimal sketch of the two multiplication orders follows.
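A minimal sketch of the multiplication-order trick, using random matrices as a stand-in for the mapped features Q′ and K′ (this only illustrates the associativity argument, not the actual random feature mapping or its normalization):

import torch

L, d, r = 1024, 64, 256
Q_p = torch.randn(L, r)    # Q' after the (assumed) random feature mapping
K_p = torch.randn(L, r)    # K'
V = torch.randn(L, d)

# standard order: (Q' K'^T) V materializes an (L, L) matrix -> O(L^2)
out_quadratic = (Q_p @ K_p.T) @ V      # (L, L) @ (L, d)

# Performer order: Q' (K'^T V) only materializes an (r, d) matrix -> O(L)
out_linear = Q_p @ (K_p.T @ V)         # (L, r) @ (r, d)

torch.allclose(out_quadratic, out_linear, rtol=1e-3, atol=1e-3)  # same result, different cost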

[Figure: standard quadratic attention (left) vs. Performer linear attention (right)]