注意力计算
- 注意力计算的三要素分别是:Query, Key,Value。而在自注意力计算中,三者则是等价的。
- 结合如下图示例:一个序列有2个词元,每个词元有3个特征 ,即输入为(2, 3)
- 每个Query词元会计算与其它词元Key的“相似度”(包括自己),再经过softmax(每行的和等于1)转换,得到 2 × 2 权重矩阵
- 然后将其与Value矩阵进行乘法运算(2, 2) × (2, 3),得到新的(2, 3)输出结果
- 形象理解:对于词元A的输出特征1,等于输入词元A, B的特征的加权和。
- 多头注意力:本质上可以理解为将特征维度分成多个部分,每个部分称为一个“头”。每个头独立进行注意力计算,然后将所有头的输出合并在一起;以期学习不同的关系和模式。
- 注意力计算本身不涉及可学习参数。一般在input前,out后,各设置一层MLP线性变换。
参考:https://blog.csdn.net/God_WeiYang/article/details/131820781
0. 模拟数据
-
通常情况下,注意力计算的输入数据拥有四个维度:(batch_size, num_heads, seq_length, head_dim)
word embed = num_heads × head_dim
|
|
最终词元的特征长度为:num_heads × head_dim
1. 手动计算
- causal参数表示注意力掩码操作,常用于GPT生成模型中。表示计算序列中第n个词元时,只关注第1到n-1个(除了本身);
- 具体通过
torch.finfo(q.dtype).min
设置为负无穷,则其softmax转换后的权重值为0。
|
|
2. Flash Attention 1/2
- Flash Attention 一种高效的注意力计算方法,旨在优化 Transformer 模型中的注意力机制。它通过减少内存使用和提高计算速度来处理长序列输入。
- https://github.com/Dao-AILab/flash-attention
2.1 安装
- 之前学习的scGPT主要使用了FlashAttention-1,安装方法见之前的整理
|
|
- 本次主要学习FlashAttention-2版本 (with Better Parallelism and Work Partitioning)。安装过程踩了很多坑,目前找到一种可信的方式。
|
|
非本地安装方式为:
1 2 3 4 5
pip install ninja #加速build安装进程,实测发现只对下面的clone安装有作用 pip install flash-attn --no-build-isolation # or clone github repo python setup.py install
由于是在conda环境下安装的cuda,所以在上述过程中会出现类似
cuda_runtime_api.h: No such file or directory
的报错。查了很多教程,比较靠谱的方法是安装cudatoolkit-dev
。这同时带来一个问题,目前其最高版本为11.7,需要安装与之对应的cuda环境,以及torch等
1 2 3 4 5
mamba install -c conda-forge cudatoolkit-dev mamba install cudatoolkit==11.7 -c nvidia mamba install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia python setup.py install
2.2 使用
- 输入维度要求一般是:(batch_size, seqlen, nheads, headdim)
|
|
3. scaled_dot_product_attention
torch.nn.functional.scaled_dot_product_attention
是torch2.0版本更新后,新增的注意力加速计算方法- 其采用了包括Flash-attention2在内的三种加速算法
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Memory-Efficient Attention
- A PyTorch implementation defined in C++ matching the above formulation
- All implementations are enabled by default. Scaled dot product attention attempts to automatically select the most optimal implementation based on the inputs
|
|
4. Benchmark
|
|
- 定义函数,测试注意力计算的时间以及显存消耗
|
|
- 测试不同序列长度,三种计算时间以及显存消耗情况
- (1)序列的长度越长时,pytorch func与flash attention计算优势越明显
- (2)pytorch func与flash attention的差距不太明显。
|
|
- 测试不同特征长度,三种计算时间以及显存消耗情况
- 词元的特征维度越大时,pytorch func与flash attention计算优势会逐渐下降。
|
|
补充~
5. Performer
- https://arxiv.org/pdf/2009.14794v4 2020 Google Research
- https://github.com/lucidrains/performer-pytorch Pytorch版本实现
Performer注意力:将复杂度降低到线性,使得它在处理长序列时更加高效;并从理论角度证明是可行的。
对于长度为L的输入序列,嵌入向量长度是d
计算常规Transfomer自注意力时 (下图左栏),其复杂度为O(L2) → Quadratic;
- (1) 注意力矩阵:Q * (K)T = (L, d) * (d, L) = (L, L) → O(L2 * d) → O(L2)
- (2) Softmax计算归一化权重:(L, L) = (L, L) → O(L2)
- (3) 加权和表示:Q * (K)T * V = (L, L) * (L, d) = (L, d) → O(L2 * d) → O(L2)
Performer注意力计算的注意力计算 (下图右栏) 复杂度为O(L) → Linear
- (1) 首先对Q与K进行随机特征映射(Random Feature Mapping)
- Q:(L, d) → Q’ (L, r), K:(L, d) → K’ (L, r)
- (2) 然后计算 (K’)T * V = (r, L) * (L, d) = (r, d) → O(Lrd) → O(L)
- (3) 最后计算 Q’ * (K’)T * V = (L, r) * (r, d) = (L, d) → O(Lrd) → O(L)
计算复杂度时,可忽略常数项、低阶项,以及系数。