欢迎光临
我们一直在努力

别只知道 FlashAttention:带你算算 Attention 算子在不同 QKV 维度下的显存读写比例。

别只知道 FlashAttention 的效果好,理解其背后的原理——解决显存带宽瓶颈——对于优化深度学习模型至关重要。标准 Self-Attention 机制在序列长度 $L$ 较大时,其性能瓶颈并非是计算量(FLOPs),而是显存的读写带宽。

本文将通过量化计算,分析 QKV 维度 ($d_k$ 和 $d_v$) 如何影响 Attention 算子的计算强度 (Arithmetic Intensity, AI),并明确 $L^2$ 内存访问的统治地位。

1. 标准 Attention 算子的定义与维度

标准的 Self-Attention 计算过程可以概括为两个主要步骤:

$$\text{Attention}(Q, K, V) = \text{Softmax}(\frac{QK^T}{\sqrt{d_k}})V$$

假设输入维度如下(为简化分析,我们忽略了批次大小和多头):

  • 查询矩阵 $Q$: $L \times d_k$
  • 键矩阵 $K$: $L \times d_k$
  • 值矩阵 $V$: $L \times d_v$

其中 $L$ 是序列长度, $d_k$ 是 Q/K 的维度, $d_v$ 是 V 的维度(通常 $d_k = d_v = d$)。

2. 计算量 (FLOPs) 分析

我们主要关注矩阵乘法的计算量,忽略 Softmax 和缩放因子 ($\sqrt{d_k}$) 的开销。

步骤 A: $S = QK^T$

  • $Q$ ($L \times d_k$) 乘以 $K^T$ ($d_k \times L$),得到 $S$ ($L \times L$)。
  • FLOPs (乘加操作): $2 \cdot L \cdot L \cdot d_k = 2L^2 d_k$

步骤 B: $O = S_{norm}V$

  • $S_{norm}$ ($L \times L$) 乘以 $V$ ($L \times d_v$),得到输出 $O$ ($L \times d_v$ )。
  • FLOPs (乘加操作): $2 \cdot L \cdot d_v \cdot L = 2L^2 d_v$

总 FLOPs: $\text{Total FLOPs} \approx 2L^2 (d_k + d_v)$

结论: 计算量与 $L^2$ 成正比,且与维度 $d_k$ 和 $d_v$ 成线性正比。

3. 显存读写 (Memory Access) 分析

计算瓶颈的根源在于需要频繁地读写中间结果,特别是 $S$ 矩阵。假设所有数据类型均为 4 字节 (Float32)。

输入输出数据 (I/O)

  1. 读入 Q, K, V: $4 \cdot L d_k$ (Q) $+ 4 \cdot L d_k$ (K) $+ 4 \cdot L d_v$ (V)
  2. 写出 O: $4 \cdot L d_v$ (O)

中间结果数据 (Memory Overhead)

  1. 读写 Attention 矩阵 $S$: $S$ 矩阵维度为 $L \times L$,需要读入并写回主存(HBM)以进行 Softmax 和后续乘法。
    • $S$ 的读写开销: $2 \cdot 4 \cdot L^2 = 8L^2$

总显存读写 $M$ (Bytes):

$$\text{Total Memory} (M) \approx 4L(2d_k + 2d_v) + 8L^2$$

4. 显存读写比例(计算强度)分析

计算强度 (AI) 定义为 FLOPs 与 Memory Access 的比值: $AI = \text{FLOPs} / M$。高 AI 意味着操作是计算密集型(Compute-Bound),低 AI 意味着操作是内存密集型(Memory-Bound)。

$$AI \approx \frac{2L^2 (d_k + d_v)}{8L^2 + 8L(d_k + d_v)}$$

我们分析两种极端情况,假设 $d_k = d_v = d$:

场景一:短序列 ($L \approx d$)

如果 $L$ 和 $d$ 具有相似的量级(例如 $L=128, d=128$),那么 $L^2$ 和 $Ld$ 的项相似。

$$AI \approx \frac{4L^2 d}{8L^2 + 16Ld} \approx \frac{4L^3}{24L^2} = \frac{L}{6}$$

此时 AI 相对合理,但随着 $L$ 增大,情况恶化。

场景二:长序列 ($L \gg d$)

当序列长度 $L$ 远大于维度 $d$ 时(例如 $L=4096, d=64$),$L^2$ 项将占据主导地位。

$$AI \approx \frac{2L^2 (d_k + d_v)}{8L^2} = \frac{d_k + d_v}{4}$$

核心发现: 当 $L$ 足够大时,计算强度 $AI$ 最终只取决于 $d_k$ 和 $d_v$,而与 $L$ 无关!

以常见的 $d_k=64, d_v=64$ 为例, $AI \approx (64+64) / 4 = 32$.

这意味着我们每进行 32 次浮点运算,就需要从 HBM 中读写 1 字节的数据。对于现代 GPU(如 NVIDIA A100/H100,其计算能力远超 HBM 带宽),这是一个极低的计算强度,标准 Attention 算子是典型的内存带宽受限 (Bandwidth-Bound) 操作

5. 结论:FlashAttention 的原理支撑

标准的 Attention 机制将 $O(L^2)$ 尺寸的中间矩阵 $S$ 读写到慢速的 HBM(High Bandwidth Memory)中,导致性能瓶颈。

而 FlashAttention 通过分块计算 (tiling) 和重排序,确保 $S$ 矩阵的各个块始终停留在 GPU 核心内部的快速 SRAM (On-chip memory) 中,从而消除了 $O(L^2)$ 的 HBM 读写开销,将瓶颈重新导向计算密集型或 $O(Ld)$ 级别的 I/O 读写,显著提高了长序列的处理速度。

操作示例:量化显存开销

假设 $L=2048$, $d_k=128$. (Float32 = 4 bytes)

数据项 维度 字节数 (Bytes) 百万字节 (MB)
$Q$ $2048 \times 128$ $2048 \cdot 128 \cdot 4 = 1,048,576$ 1.05
$V$ $2048 \times 128$ $1,048,576$ 1.05
S (中间结果) $2048 \times 2048$ $2048^2 \cdot 4 = 16,777,216$ 16.78

对于一次 Attention 计算,仅 $S$ 矩阵的读写开销($2 \times 16.78$ MB)就远高于输入 QKV 的总和(约 $3 \times 1.05$ MB)。这是显存带宽成为瓶颈的最直观证明。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 别只知道 FlashAttention:带你算算 Attention 算子在不同 QKV 维度下的显存读写比例。
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址