欢迎光临
我们一直在努力

LLM 推理必问:KV Cache 的本质是什么?为什么它能大幅降低首词后的延时

大型语言模型(LLM)在生成文本时采用自回归(Autoregressive)方式,即逐词生成。虽然这种方式保证了生成内容的连贯性,但也带来了严重的性能挑战,尤其是在长序列生成时。核心问题在于Transformer模型中的自注意力(Self-Attention)机制。KV Cache正是为了解决这一瓶颈而诞生的。

1. Transformer自注意力机制的瓶颈

在标准的Transformer解码器中,要生成当前时间步 $t$ 的词元 $Y_t$,模型必须让当前的Query $Q_t$ 关注(Attend)到序列中所有过去的词元 $Y_{1…t-1}$。

自注意力机制的计算公式如下:
$$Attention(Q, K, V) = Softmax(\frac{QK^T}{\sqrt{d_k}})V$$

在自回归推理过程中,生成第 $t$ 个词元时,我们计算 $Q_t$,但必须使用整个历史序列 $K_{1…t}$ 和 $V_{1…t}$。如果没有缓存机制,每生成一个新的词元 $Y_{t+1}$,模型都需要重新计算从 $Y_1$ 到 $Y_t$ 整个序列的 $K$ 和 $V$ 投影矩阵,并进行矩阵乘法。随着序列长度 $L$ 的增加,每一步的计算复杂度大约是 $O(L^2)$,导致延时迅速攀升。

2. KV Cache的本质是什么?

KV Cache,即Key-Value Cache,是指在LLM的自回归推理过程中,将Transformer的自注意力模块在处理完历史词元后生成的Key(K)和Value(V)中间表示存储起来的机制。

本质定义:

  1. K(Key)和 V(Value): 它们是输入词嵌入经过线性投影(权重矩阵 $W_K$ 和 $W_V$)生成的向量,代表了词元的信息内容。
  2. 缓存内容: KV Cache 存储的正是这些 K 和 V 向量,按Transformer层和注意力头进行划分。对于一个多层模型,每层都有自己的 K/V 缓存。

当模型生成下一个词元 $Y_{t+1}$ 时:

  • 它只计算新的 $Q_{t+1}$,以及新的 $K_{t+1}$ 和 $V_{t+1}$。
  • 它将 $K_{t+1}$ 和 $V_{t+1}$ 追加(Append)到之前缓存的 $K_{1…t}$ 和 $V_{1…t}$ 后面。
  • 注意力计算时,使用 $Q_{t+1}$ 与完整的缓存 $[K_{1…t}, K_{t+1}]$ 进行乘积,避免了重新计算历史 $K_{1…t}$ 和 $V_{1…t}$。

3. KV Cache如何大幅降低首词后的延时

KV Cache主要带来了两个维度的性能提升:

  1. 计算量大幅减少: 在没有缓存时,生成长度为 $L$ 的序列,总计算量接近 $O(L^3)$。有了KV Cache后,每一步新词元的生成复杂度从 $O(L^2)$ 降低到 $O(L)$。这是因为矩阵乘法只涉及新的 Query (长度为 1) 与累积的 K (长度为 $L$),而不是整个历史序列的重新计算。
  2. 内存带宽优化: 虽然KV Cache需要额外的内存来存储K/V矩阵,但在实际推理中,计算效率往往受限于内存带宽(Memory Bandwidth)而不是计算单元(FLOPS)。通过避免重复读取和计算历史词元的Key和Value,KV Cache大大减轻了内存访问的负担,降低了首词后的延迟(Time to First Token 之后的部分)。

4. 实际操作示例:利用PyTorch模拟KV Cache

以下是一个概念性的PyTorch示例,展示了在自回归解码中,如何利用缓存来避免重新计算K和V。

import torch

# 假设词嵌入维度 d_k = 64
d_k = 64
# 假设序列中已经缓存了 L_cached = 10 个词元
L_cached = 10

# 1. 初始化模拟的KV Cache (Key/Value的投影结果)
# 形状:(batch_size, num_heads, sequence_length, head_dim)
cached_keys = torch.randn(1, 12, L_cached, d_k)
cached_values = torch.randn(1, 12, L_cached, d_k)

print(f"缓存初始化长度: {L_cached}")

# 2. 生成下一个 Token (t=11)
# 新的 Query, Key, Value 都是针对长度为 1 的序列
new_query = torch.randn(1, 12, 1, d_k)
new_key = torch.randn(1, 12, 1, d_k)
new_value = torch.randn(1, 12, 1, d_k)

# -------- KV Cache 核心步骤:更新缓存 --------
# 3. 拼接 K 和 V,形成完整的 Key/Value 序列
current_keys = torch.cat([cached_keys, new_key], dim=2)
current_values = torch.cat([cached_values, new_value], dim=2)

L_new = current_keys.shape[2]
print(f"新序列总长度 L_new: {L_new}")

# 4. 计算注意力分数
# Q (1x1) 只需要与完整的 K (1xL_new) 进行矩阵乘法
# scores 形状: (batch_size, num_heads, 1, L_new)
scores = torch.matmul(new_query, current_keys.transpose(-2, -1)) / (d_k ** 0.5)

# 5. 聚合 Value
attention_output = torch.matmul(scores.softmax(dim=-1), current_values)

print(f"注意力输出形状: {attention_output.shape}")
print("通过缓存,我们成功避免了重新计算 L_cached 个 Token 的 K 和 V 投影。")

在上述示例中,如果不对K/V进行缓存,模型需要重新计算包含所有 11 个词元的 $K$ 和 $V$ 矩阵,从而重复了大量的计算工作。KV Cache通过空间换时间的方式,有效保证了LLM推理的高效性和低延迟。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » LLM 推理必问:KV Cache 的本质是什么?为什么它能大幅降低首词后的延时
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址