在大型语言模型(LLM)的推理过程中,通常分为两个关键阶段:Prefill(预填充/处理Prompt)阶段和Decode(解码/自回归生成)阶段。这两个阶段对硬件资源的需求截然不同,理解它们的瓶颈对于优化推理性能至关重要。
1. 概念定义:Prefill 与 Decode
Prefill 阶段 (Compute-Bound)
Prefill 阶段是指模型并行处理用户输入的长序列Prompt(例如,一个长达几百个词的提问)。在这个阶段,模型需要一次性计算所有输入Token的表示,并初始化Key-Value (KV) Cache。
Decode 阶段 (Memory-Bound)
Decode 阶段是模型自回归地一个接一个生成新Token的过程。每生成一个新Token,都需要将其与先前所有已生成的Token(存储在KV Cache中)进行注意力计算。
2. Prefill 阶段:算力受限(Compute-Bound)
Prefill 阶段的特点是高并行度。假设输入序列长度为 $L_P$,隐藏层维度为 $H$。模型需要执行大规模的矩阵乘法(MatMul)来计算注意力权重和前馈网络。
算力分析
在这个阶段,计算量主要集中在Self-Attention机制上。虽然实际操作中会使用优化后的矩阵乘法内核,但理论上,注意力得分计算 $Q \times K^T$ 的计算复杂度约为 $O(L_P^2 \times H)$,而随后的多层感知机(MLP)计算复杂度约为 $O(L_P \times H^2)$。
由于 $L_P$ 通常较大(几百到几千),$H$ 也很大(几千到上万),Prefill 阶段产生了巨量的浮点运算(FLOPs)。
根据屋顶模型(Roofline Model),当算术强度(Arithmetic Intensity = FLOPs / Bytes)非常高时,性能的瓶颈就落在了计算单元的处理速度上(即GPU的TFLOPS)。
3. Decode 阶段:访存受限(Memory-Bound)
Decode 阶段的特点是低并行度,高访存需求。
访存分析
在生成第 $k$ 个Token时,我们只输入一个Query Token $Q_{new}$。此时,计算量(FLOPs)只与 $Q_{new}$ 和 KV Cache 中的 $k-1$ 个旧Token相关,计算复杂度约为 $O(k \times H)$。
虽然计算量相对较小,但是为了计算这个新的注意力,GPU必须从高带宽内存(HBM)中读取整个庞大且不断增长的KV Cache。
假设KV Cache存储了 $k$ 个Token,总大小约为 $k \times H \times 2 \times 2$ 字节(Key和Value,每个元素float16/bfloat16)。随着 $k$ 的增长,每次解码都需要读取巨大的数据块。
此时,算术强度极低(FLOPs很小,Bytes很大)。性能瓶颈转移到了内存带宽上(即GPU的内存读写速度)。我们称之为访存受限。
4. 示例代码:矩阵形状对比
以下使用Python/PyTorch概念代码演示 Prefill 和 Decode 阶段的矩阵尺寸差异,直观展示计算和访存的侧重不同(假设 $H=4096$, $L_P=512$):
import torch
H = 4096 # 隐藏维度
L_P = 512 # Prefill序列长度
B = 1 # 批量大小
L_current = 1000 # 当前KV Cache长度
# --- 1. Prefill 阶段 (Compute-Bound) ---
# 输入矩阵: (Batch, Seq_Len, Hidden_Dim)
Input_Prefill = torch.randn(B, L_P, H)
# 大规模矩阵乘法 (简化示例,模拟 Attention QK^T)
# 计算量大致与 L_P^2 成正比
print(f"Prefill 输入形状: {Input_Prefill.shape}")
# 核心计算是处理 L_P x L_P 的注意力矩阵,FLOPs 密集
# --- 2. Decode 阶段 (Memory-Bound) ---
# 新的 Query token: (Batch, 1, Hidden_Dim)
Q_new = torch.randn(B, 1, H)
# 已积累的 KV Cache: (Batch, L_current, Hidden_Dim)
K_cache = torch.randn(B, L_current, H)
V_cache = torch.randn(B, L_current, H)
# Attention 计算: Q_new 乘以 K_cache (1 x H) @ (H x L_current) -> (1 x L_current)
# FLOPs: 1 * H * L_current (相对 Prefill 小得多)
# **访存**: 必须读取 K_cache 和 V_cache (2 * L_current * H 字节)
print(f"Decode Q 形状: {Q_new.shape}")
print(f"KV Cache 形状: {K_cache.shape}, {V_cache.shape}")
print(f"解码阶段瓶颈在于读取 {L_current} 个 Token 对应的大块缓存数据")
5. 优化策略总结
由于瓶颈不同,LLM推理的优化也需要分阶段进行:
- Prefill 优化(针对算力): 侧重于使用高效的MatMul内核(如FlashAttention,它减少了HBM读写,但本质上是优化了计算效率),以及最大化GPU利用率。
- Decode 优化(针对访存): 侧重于减少KV Cache的大小和访存需求,例如:KV Cache量化(减少读取的字节数)、Paged Attention(提高内存碎片利用率)、或使用高带宽内存(如HBM3)。
汤不热吧