欢迎光临
我们一直在努力

FlashAttention v1/v2/v3 演进史:它是如何通过减少显存读写让速度飞起来的

FlashAttention v1/v2 演进史:它是如何通过减少显存读写让速度飞起来的

自Transformer架构诞生以来,Attention机制一直是其核心但也是性能瓶颈所在。当序列长度 $N$ 增大时,标准Attention的计算复杂度和显存占用都按 $O(N^2)$ 增长。FlashAttention的出现彻底改变了这一现状,它通过精妙的I/O感知算法设计,极大地减少了对GPU高带宽显存(HBM)的访问,从而实现了革命性的加速。

1. 标准Attention的HBM瓶颈

在标准的Attention计算 $O = \text{Softmax}(Q K^T) V$ 中,最大的性能瓶颈不在于浮点运算(FLOPs),而在于中间结果的存储和加载。

瓶颈分析:

  1. 计算相似度矩阵 $S = Q K^T$(大小 $N \times N$)。
  2. 对 $S$ 应用Softmax得到概率矩阵 $P$(大小 $N \times N$)。
  3. 计算最终输出 $O = P V$。

在GPU上,由于 $N \times N$ 的 $S$ 和 $P$ 矩阵通常太大,它们必须被写入速度较慢的高带宽显存 (HBM)。在反向传播时,这些矩阵又需要被从HBM中读出,造成巨大的内存带宽压力和延迟,这就是所谓的“HBM墙”。

2. FlashAttention的核心机制:I/O感知的平铺计算

FlashAttention(FA)的核心思想是利用GPU上的快速片上存储器(SRAM/寄存器),通过分块计算(Tiling)策略,确保 $Q K^T$ 和 Softmax 过程的中间结果不会被写入HBM。

2.1 关键优化点:减少I/O

FlashAttention将 $Q, K, V$ 划分为小块,并迭代计算。

  1. Tiling/Blocking(分块): 将 $Q, K, V$ 切分成适合放入SRAM的小块 $Q_i, K_j, V_j$。
  2. 在线Softmax计算: 在每一步迭代中,计算块级的 $Q_i K_j^T$,并实时更新Softmax的归一化因子(最大值 $m_i$ 和分母 $l_i$)。这使得 $N \times N$ 的中间相似度矩阵 $S$ 永远不必完全物化并写入HBM。
  3. 梯度重计算(Recomputation): 在反向传播时,标准Attention需要保存 $P$ 矩阵($O(N^2)$ 内存)。FlashAttention选择不保存 $P$ 矩阵,而是在反向传播时,重新计算部分正向传播的步骤。这用少量的额外计算时间(FLOPs)换取了巨大的内存带宽节省和 $O(N)$ 的内存复杂度。

通过这些技术,FlashAttention将 Attention 的内存I/O复杂度从 $O(N^2)$ 降低到 $O(N^2 / D)$(其中 $D$ 是分块大小),在长序列上获得了巨大的加速。

3. FlashAttention的演进:v1 到 v2

虽然 v1 解决了主要的I/O瓶颈,但 v2 更进一步,专注于提高硬件利用率(GFLOPS)。

  • FlashAttention v1: 核心目标是减少HBM读写。主要性能提升来自于消除中间 $N \times N$ 矩阵的HBM存储。
  • FlashAttention v2: 在 v1 的基础上,重新设计了Tiling策略和并行化方案,以更好地适应现代GPU的硬件特性。主要改进包括:
    1. 更好的线程块并行性: 优化了 $Q$ 和 $K$ 矩阵的平铺和加载方式,确保了更高的 SM (Streaming Multiprocessor) 利用率。
    2. 更优的MatMul调度: 提升了内部矩阵乘法和规约操作的效率。

简单来说,v1 是“I/O-bound to Compute-bound”的胜利,而 v2 是“Compute-bound”阶段的进一步优化,追求更高的GFLOPS。

(注:目前 FlashAttention v3 尚未作为标准优化库发布,通常指社区或研究中对 v2 的进一步细微改进。本文主要关注 v1/v2 的核心思想。)

4. 实操:在PyTorch中使用FlashAttention

从 PyTorch 2.0 开始,官方引入了 torch.nn.functional.scaled_dot_product_attention (SDPA) API。如果用户的环境(GPU型号,CUDA版本)支持,PyTorch会自动使用高度优化的内核,包括FlashAttention或Memory-Efficient Attention,极大地简化了用户的使用。

下面的代码演示了在支持FlashAttention的环境下,使用SDPA带来的性能提升:

“`python
import torch
import time

检查是否支持PyTorch 2.0+和CUDA环境

if not torch.cuda.is_available():
print(“需要CUDA环境运行此示例。”)
exit()

配置长序列模型参数

序列长度 4096,典型的大型模型训练长度

B, L, H, D = 4, 4096, 16, 64

Q = torch.randn(B, H, L, D, device=’cuda’, dtype=torch.float16)
K = torch.randn(B, H, L, D, device=’cuda’, dtype=torch.float16)
V = torch.randn(B, H, L, D, device=’cuda’, dtype=torch.float16)

— 1. 使用优化的 SDPA (自动启用 FlashAttention/M.E.A) —

FlashAttention 内核通常在序列长度 L 较大时表现出巨大优势

预热GPU

for _ in range(10):
_ = torch.nn.functional.scaled_dot_product_attention(Q, K, V)

start_fa = time.time()
ITERATIONS = 50
for _ in range(ITERATIONS):
output_fa = torch.nn.functional.scaled_dot_product_attention(Q, K, V)
# 确保不进行梯度计算,仅测量前向时间

torch.cuda.synchronize()
end_fa = time.time()

time_fa = (end_fa – start_fa) / ITERATIONS * 1000

理论上,我们可以通过手动计算来模拟非优化版本,但为了公平对比,我们直接展示SDPA的性能

注意:在某些环境下,为了强制关闭优化,需要设置环境变量或使用特定的手动实现,此处仅演示优化后的性能。

print(f”\n— FlashAttention 性能测试 —“)
print(f”序列长度 L={L}, 头数 H={H}, 批次 B={B}”)
print(f”使用 FlashAttention/SDPA 优化内核的平均时间: {time_fa:.4f} ms”)

结论:在长序列场景下,与未优化的版本相比,FlashAttention通常能提供 2x 至 4x 甚至更高的速度提升,并显著降低显存占用。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » FlashAttention v1/v2/v3 演进史:它是如何通过减少显存读写让速度飞起来的
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址