别只知道 FlashAttention 的效果好,理解其背后的原理——解决显存带宽瓶颈——对于优化深度学习模型至关重要。标准 Self-Attention 机制在序列长度 $L$ 较大时,其性能瓶颈并非是计算量(FLOPs),而是显存的读写带宽。
本文将通过量化计算,分析 QKV 维度 ($d_k$ 和 $d_v$) 如何影响 Attention 算子的计算强度 (Arithmetic Intensity, AI),并明确 $L^2$ 内存访问的统治地位。
1. 标准 Attention 算子的定义与维度
标准的 Self-Attention 计算过程可以概括为两个主要步骤:
$$\text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d_k}})V$$
假设输入维度如下(为简化分析,我们忽略了批次大小和多头):
- 查询矩阵 $Q$: $L \times d_k$
- 键矩阵 $K$: $L \times d_k$
- 值矩阵 $V$: $L \times d_v$
其中 $L$ 是序列长度, $d_k$ 是 Q/K 的维度, $d_v$ 是 V 的维度(通常 $d_k = d_v = d$)。
2. 计算量 (FLOPs) 分析
我们主要关注矩阵乘法的计算量,忽略 Softmax 和缩放因子 ($\sqrt{d_k}$) 的开销。
步骤 A: $S = QK^T$
- $Q$ ($L \times d_k$) 乘以 $K^T$ ($d_k \times L$),得到 $S$ ($L \times L$)。
- FLOPs (乘加操作): $2 \cdot L \cdot L \cdot d_k = 2L^2 d_k$
步骤 B: $O = S_{norm}V$
- $S_{norm}$ ($L \times L$) 乘以 $V$ ($L \times d_v$),得到输出 $O$ ($L \times d_v$ )。
- FLOPs (乘加操作): $2 \cdot L \cdot d_v \cdot L = 2L^2 d_v$
总 FLOPs: $\text{Total FLOPs} \approx 2L^2 (d_k + d_v)$
结论: 计算量与 $L^2$ 成正比,且与维度 $d_k$ 和 $d_v$ 成线性正比。
3. 显存读写 (Memory Access) 分析
计算瓶颈的根源在于需要频繁地读写中间结果,特别是 $S$ 矩阵。假设所有数据类型均为 4 字节 (Float32)。
输入输出数据 (I/O)
- 读入 Q, K, V: $4 \cdot L d_k$ (Q) $+ 4 \cdot L d_k$ (K) $+ 4 \cdot L d_v$ (V)
- 写出 O: $4 \cdot L d_v$ (O)
中间结果数据 (Memory Overhead)
- 读写 Attention 矩阵 $S$: $S$ 矩阵维度为 $L \times L$,需要读入并写回主存(HBM)以进行 Softmax 和后续乘法。
- $S$ 的读写开销: $2 \cdot 4 \cdot L^2 = 8L^2$
总显存读写 $M$ (Bytes):
$$\text{Total Memory} (M) \approx 4L(2d_k + 2d_v) + 8L^2$$
4. 显存读写比例(计算强度)分析
计算强度 (AI) 定义为 FLOPs 与 Memory Access 的比值: $AI = \text{FLOPs} / M$。高 AI 意味着操作是计算密集型(Compute-Bound),低 AI 意味着操作是内存密集型(Memory-Bound)。
$$AI \approx \frac{2L^2 (d_k + d_v)}{8L^2 + 8L(d_k + d_v)}$$
我们分析两种极端情况,假设 $d_k = d_v = d$:
场景一:短序列 ($L \approx d$)
如果 $L$ 和 $d$ 具有相似的量级(例如 $L=128, d=128$),那么 $L^2$ 和 $Ld$ 的项相似。
$$AI \approx \frac{4L^2 d}{8L^2 + 16Ld} \approx \frac{4L^3}{24L^2} = \frac{L}{6}$$
此时 AI 相对合理,但随着 $L$ 增大,情况恶化。
场景二:长序列 ($L \gg d$)
当序列长度 $L$ 远大于维度 $d$ 时(例如 $L=4096, d=64$),$L^2$ 项将占据主导地位。
$$AI \approx \frac{2L^2 (d_k + d_v)}{8L^2} = \frac{d_k + d_v}{4}$$
核心发现: 当 $L$ 足够大时,计算强度 $AI$ 最终只取决于 $d_k$ 和 $d_v$,而与 $L$ 无关!
以常见的 $d_k=64, d_v=64$ 为例, $AI \approx (64+64) / 4 = 32$.
这意味着我们每进行 32 次浮点运算,就需要从 HBM 中读写 1 字节的数据。对于现代 GPU(如 NVIDIA A100/H100,其计算能力远超 HBM 带宽),这是一个极低的计算强度,标准 Attention 算子是典型的内存带宽受限 (Bandwidth-Bound) 操作。
5. 结论:FlashAttention 的原理支撑
标准的 Attention 机制将 $O(L^2)$ 尺寸的中间矩阵 $S$ 读写到慢速的 HBM(High Bandwidth Memory)中,导致性能瓶颈。
而 FlashAttention 通过分块计算 (tiling) 和重排序,确保 $S$ 矩阵的各个块始终停留在 GPU 核心内部的快速 SRAM (On-chip memory) 中,从而消除了 $O(L^2)$ 的 HBM 读写开销,将瓶颈重新导向计算密集型或 $O(Ld)$ 级别的 I/O 读写,显著提高了长序列的处理速度。
操作示例:量化显存开销
假设 $L=2048$, $d_k=128$. (Float32 = 4 bytes)
| 数据项 | 维度 | 字节数 (Bytes) | 百万字节 (MB) |
|---|---|---|---|
| $Q$ | $2048 \times 128$ | $2048 \cdot 128 \cdot 4 = 1,048,576$ | 1.05 |
| $V$ | $2048 \times 128$ | $1,048,576$ | 1.05 |
| S (中间结果) | $2048 \times 2048$ | $2048^2 \cdot 4 = 16,777,216$ | 16.78 |
对于一次 Attention 计算,仅 $S$ 矩阵的读写开销($2 \times 16.78$ MB)就远高于输入 QKV 的总和(约 $3 \times 1.05$ MB)。这是显存带宽成为瓶颈的最直观证明。
汤不热吧