LLM2022· NeurIPS 2022· CLASSIC

FlashAttention: Fast and Memory-Efficient Exact Attention

Dao et al.

用 tiling + recomputation 把 O(n²) 注意力从 HBM IO-bound 变成 SRAM compute-bound——**精确** attention,不近似。

arXiv:2205.14135
#attention#inference#efficiency#systems

核心贡献

  • 01发现瓶颈:标准 Attention 在 GPU 上是 memory-bound(大量 HBM 读写 N×N 矩阵)
  • 02Tiling:把 Q、K、V 分块装入 SRAM(L1 缓存),避免中间结果物化
  • 03在线 Softmax 技巧:增量维护 max 和 denominator 统计量
  • 04Backward 用 recomputation 替代存储中间激活——内存 O(N²) → O(N)
  • 05在 A100 上比标准 attention 快 2-4×,且是**精确**计算(不是近似)

为什么标准 Attention 慢?

GPU 内存层级:SRAM(快、小)HBM(慢、大)。A100 的 HBM 带宽 ~1.5TB/s,但 SRAM 带宽 19TB/s。

标准 Attention 的三步:

text
1S = Q @ K^T # [N, N] 存在 HBM
2P = softmax(S) # 又一次读写 HBM
3O = P @ V # 第三次读写 HBM

每一步都要把 N² 的中间矩阵读写 HBM——当 N=8192 时,这是 256MB 一次读写。瓶颈在内存,不在算力

FlashAttention 的解法

Tiling(分块):把 Q 分成大小为 B_r 的行块,K 和 V 分成 B_c 的列块。每次从 HBM 加载一小块进 SRAM,在 SRAM 里算完 S → P → O 对应部分,累加到输出。

在线 Softmax:普通 softmax 需要知道整行的 max 和 sum;tiling 情况下每次只看一部分。论文用 online algorithm 维护 running max m_i 和 running sum ℓ_i,每加载新块就更新这些统计量,保持数学等价性。

mnew=max(mold,mblock)new=emoldmnewold+emblockmnewblockm_{new} = \max(m_{old}, m_{block}) \\ \ell_{new} = e^{m_{old} - m_{new}} \ell_{old} + e^{m_{block} - m_{new}} \ell_{block}

Backward 用 recomputation:前向不存 S 和 P,反向时从 Q、K、V 重算。省大量激活内存。

结果

  • 速度:GPT-2 训练 15%-25% 加速
  • 内存:N=4K 时峰值显存降 20%;N=16K 时降 10×
  • 数学精确的——和标准 Attention 数值完全等价(不像 Linformer、Performer 等近似方案)
面试视角

面试考点

"FlashAttention 为什么快?" 关键词:IO-aware。标准实现是 memory-bound,FlashAttention 通过 tiling 把读写 HBM 的次数从 O(N²) 降到 O(N²/M)(M 是 SRAM 大小)。算力没变,但内存 trip 少了。

"为什么不用 Linformer / Performer 这些 O(N) 方法?" 它们是近似 attention(投影到低维或用核函数);FlashAttention精确 attention。精确是生产级关键要求。

"Recomputation 为什么值得?" FLOPs vs memory trade-off。现代 GPU 算力增速 > 带宽增速,重算比读 HBM 便宜——这是 FlashAttention 的反直觉洞见。

FlashAttention-2(2023):进一步优化 warp 级调度和非 matmul 运算比例;FlashAttention-3(2024):针对 H100 的 TMA 和 WGMMA 指令专门优化。

工程扩展

  • KV Cache + FlashAttention 用于 decoding——FlashDecoding
  • Sparse variant:Block-sparse FlashAttention
  • 应用到 vision:Swin、DiT 都改造了 attention kernel

常见陷阱FlashAttention 改变 Attention 的计算复杂度(仍是 O(N²d)),只是改了内存访问模式。面试被问时要区分 compute 和 memory。

相关论文