大型语言模型(LLM)在生成文本时采用自回归(Autoregressive)方式,即逐词生成。虽然这种方式保证了生成内容的连贯性,但也带来了严重的性能挑战,尤其是在长序列生成时。核心问题在于Transformer模型中的自注意力(Self-Attention)机制。KV Cache正是为了解决这一瓶颈而诞生的。
1. Transformer自注意力机制的瓶颈
在标准的Transformer解码器中,要生成当前时间步 $t$ 的词元 $Y_t$,模型必须让当前的Query $Q_t$ 关注(Attend)到序列中所有过去的词元 $Y_{1…t-1}$。
自注意力机制的计算公式如下:
$$Attention(Q, K, V) = Softmax(\frac{QK^T}{\sqrt{d_k}})V$$
在自回归推理过程中,生成第 $t$ 个词元时,我们计算 $Q_t$,但必须使用整个历史序列 $K_{1…t}$ 和 $V_{1…t}$。如果没有缓存机制,每生成一个新的词元 $Y_{t+1}$,模型都需要重新计算从 $Y_1$ 到 $Y_t$ 整个序列的 $K$ 和 $V$ 投影矩阵,并进行矩阵乘法。随着序列长度 $L$ 的增加,每一步的计算复杂度大约是 $O(L^2)$,导致延时迅速攀升。
2. KV Cache的本质是什么?
KV Cache,即Key-Value Cache,是指在LLM的自回归推理过程中,将Transformer的自注意力模块在处理完历史词元后生成的Key(K)和Value(V)中间表示存储起来的机制。
本质定义:
- K(Key)和 V(Value): 它们是输入词嵌入经过线性投影(权重矩阵 $W_K$ 和 $W_V$)生成的向量,代表了词元的信息内容。
- 缓存内容: KV Cache 存储的正是这些 K 和 V 向量,按Transformer层和注意力头进行划分。对于一个多层模型,每层都有自己的 K/V 缓存。
当模型生成下一个词元 $Y_{t+1}$ 时:
- 它只计算新的 $Q_{t+1}$,以及新的 $K_{t+1}$ 和 $V_{t+1}$。
- 它将 $K_{t+1}$ 和 $V_{t+1}$ 追加(Append)到之前缓存的 $K_{1…t}$ 和 $V_{1…t}$ 后面。
- 注意力计算时,使用 $Q_{t+1}$ 与完整的缓存 $[K_{1…t}, K_{t+1}]$ 进行乘积,避免了重新计算历史 $K_{1…t}$ 和 $V_{1…t}$。
3. KV Cache如何大幅降低首词后的延时
KV Cache主要带来了两个维度的性能提升:
- 计算量大幅减少: 在没有缓存时,生成长度为 $L$ 的序列,总计算量接近 $O(L^3)$。有了KV Cache后,每一步新词元的生成复杂度从 $O(L^2)$ 降低到 $O(L)$。这是因为矩阵乘法只涉及新的 Query (长度为 1) 与累积的 K (长度为 $L$),而不是整个历史序列的重新计算。
- 内存带宽优化: 虽然KV Cache需要额外的内存来存储K/V矩阵,但在实际推理中,计算效率往往受限于内存带宽(Memory Bandwidth)而不是计算单元(FLOPS)。通过避免重复读取和计算历史词元的Key和Value,KV Cache大大减轻了内存访问的负担,降低了首词后的延迟(Time to First Token 之后的部分)。
4. 实际操作示例:利用PyTorch模拟KV Cache
以下是一个概念性的PyTorch示例,展示了在自回归解码中,如何利用缓存来避免重新计算K和V。
import torch
# 假设词嵌入维度 d_k = 64
d_k = 64
# 假设序列中已经缓存了 L_cached = 10 个词元
L_cached = 10
# 1. 初始化模拟的KV Cache (Key/Value的投影结果)
# 形状:(batch_size, num_heads, sequence_length, head_dim)
cached_keys = torch.randn(1, 12, L_cached, d_k)
cached_values = torch.randn(1, 12, L_cached, d_k)
print(f"缓存初始化长度: {L_cached}")
# 2. 生成下一个 Token (t=11)
# 新的 Query, Key, Value 都是针对长度为 1 的序列
new_query = torch.randn(1, 12, 1, d_k)
new_key = torch.randn(1, 12, 1, d_k)
new_value = torch.randn(1, 12, 1, d_k)
# -------- KV Cache 核心步骤:更新缓存 --------
# 3. 拼接 K 和 V,形成完整的 Key/Value 序列
current_keys = torch.cat([cached_keys, new_key], dim=2)
current_values = torch.cat([cached_values, new_value], dim=2)
L_new = current_keys.shape[2]
print(f"新序列总长度 L_new: {L_new}")
# 4. 计算注意力分数
# Q (1x1) 只需要与完整的 K (1xL_new) 进行矩阵乘法
# scores 形状: (batch_size, num_heads, 1, L_new)
scores = torch.matmul(new_query, current_keys.transpose(-2, -1)) / (d_k ** 0.5)
# 5. 聚合 Value
attention_output = torch.matmul(scores.softmax(dim=-1), current_values)
print(f"注意力输出形状: {attention_output.shape}")
print("通过缓存,我们成功避免了重新计算 L_cached 个 Token 的 K 和 V 投影。")
在上述示例中,如果不对K/V进行缓存,模型需要重新计算包含所有 11 个词元的 $K$ 和 $V$ 矩阵,从而重复了大量的计算工作。KV Cache通过空间换时间的方式,有效保证了LLM推理的高效性和低延迟。
汤不热吧