FlashAttention: Fast and Memory-Efficient Exact Attention
Dao et al.
用 tiling + recomputation 把 O(n²) 注意力从 HBM IO-bound 变成 SRAM compute-bound——**精确** attention,不近似。
arXiv:2205.14135核心贡献
- 01发现瓶颈:标准 Attention 在 GPU 上是 memory-bound(大量 HBM 读写 N×N 矩阵)
- 02Tiling:把 Q、K、V 分块装入 SRAM(L1 缓存),避免中间结果物化
- 03在线 Softmax 技巧:增量维护 max 和 denominator 统计量
- 04Backward 用 recomputation 替代存储中间激活——内存 O(N²) → O(N)
- 05在 A100 上比标准 attention 快 2-4×,且是**精确**计算(不是近似)
为什么标准 Attention 慢?
GPU 内存层级:SRAM(快、小) → HBM(慢、大)。A100 的 HBM 带宽 ~1.5TB/s,但 SRAM 带宽 19TB/s。
标准 Attention 的三步:
1S = Q @ K^T # [N, N] 存在 HBM2P = softmax(S) # 又一次读写 HBM3O = P @ V # 第三次读写 HBM
每一步都要把 N² 的中间矩阵读写 HBM——当 N=8192 时,这是 256MB 一次读写。瓶颈在内存,不在算力。
FlashAttention 的解法
Tiling(分块):把 Q 分成大小为 B_r 的行块,K 和 V 分成 B_c 的列块。每次从 HBM 加载一小块进 SRAM,在 SRAM 里算完 S → P → O 对应部分,累加到输出。
在线 Softmax:普通 softmax 需要知道整行的 max 和 sum;tiling 情况下每次只看一部分。论文用 online algorithm 维护 running max m_i 和 running sum ℓ_i,每加载新块就更新这些统计量,保持数学等价性。
Backward 用 recomputation:前向不存 S 和 P,反向时从 Q、K、V 重算。省大量激活内存。
结果
- 速度:
GPT-2训练 15%-25% 加速 - 内存:N=4K 时峰值显存降 20%;N=16K 时降 10×
- 是数学精确的——和标准 Attention 数值完全等价(不像 Linformer、Performer 等近似方案)
面试考点
"FlashAttention 为什么快?" 关键词:IO-aware。标准实现是 memory-bound,FlashAttention 通过 tiling 把读写 HBM 的次数从 O(N²) 降到 O(N²/M)(M 是 SRAM 大小)。算力没变,但内存 trip 少了。
"为什么不用 Linformer / Performer 这些 O(N) 方法?" 它们是近似 attention(投影到低维或用核函数);FlashAttention 是精确 attention。精确是生产级关键要求。
"Recomputation 为什么值得?" FLOPs vs memory trade-off。现代 GPU 算力增速 > 带宽增速,重算比读 HBM 便宜——这是 FlashAttention 的反直觉洞见。
FlashAttention-2(2023):进一步优化 warp 级调度和非 matmul 运算比例;FlashAttention-3(2024):针对 H100 的 TMA 和 WGMMA 指令专门优化。
工程扩展:
KV Cache+FlashAttention用于 decoding——FlashDecoding- Sparse variant:Block-sparse
FlashAttention - 应用到 vision:Swin、DiT 都改造了 attention kernel
常见陷阱:FlashAttention 不改变 Attention 的计算复杂度(仍是 O(N²d)),只是改了内存访问模式。面试被问时要区分 compute 和 memory。