如何在推理服务中为 LLM 生成结果集成不确定性度量?
在生产环境中部署大语言模型(LLM)时,模型生成的“幻觉”(Hallucination)是影响业务落地的核心挑战。为了提升系统的可靠性,在 AI Infra 层面集成不确定性(Uncertainty)指标至关重要。本文将介绍如何通过提取推理过程中的 Logits 来计算序列置信度。
1. 为什么需要不确定性度量?
在 LLM 推理管线中加入不确定性度量可以实现以下功能:
– 质量门控:自动识别低置信度回答并触发人工审核或回退机制。
– 级联推理:当低成本小模型表现出高不确定性时,自动切换到高成本的大模型(如 GPT-4)。
– 幻觉检测:研究表明,词元熵(Token Entropy)与模型的事实准确性具有强相关性。
2. 技术原理:基于 Logits 的序列熵
最常用的不确定性指标是 Normalized Sequence Entropy (归一化序列熵)。其核心思想是计算生成序列中每个词元在概率分布上的混乱程度。对于生成的每个词元 $t$,其熵定义为:
$$H(t) = -\sum_{i \in V} p_i \log p_i$$
其中 $V$ 是词表,$p_i$ 是第 $i$ 个词元的概率。序列的总得分则是所有词元熵的平均值。
3. 实战:基于 Hugging Face Transformers 的实现
在现代推理框架中,我们需要开启 output_scores 才能获取概率分布。以下代码演示了如何计算生成结果的平均负对数似然(NLL),这是衡量不确定性的一个常用替代指标。
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
def generate_with_confidence(model, tokenizer, prompt):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# 关键:设置 output_scores=True 和 return_dict_in_generate=True
outputs = model.generate(
**inputs,
max_new_tokens=32,
output_scores=True,
return_dict_in_generate=True,
temperature=1.0
)
# 获取生成的序列(排除 input 部分)
gen_sequences = outputs.sequences[:, inputs.input_ids.shape[-1]:]
# 使用 compute_transition_scores 获取每个 token 的 log_probs
probs = model.compute_transition_scores(
outputs.sequences, outputs.scores, normalize_logits=True
)
# 计算平均负对数似然 (Average NLL)
# NLL 越低,置信度越高;NLL 越高,不确定性越大
uncertainty_score = -torch.mean(probs).item()
decoded_text = tokenizer.decode(gen_sequences[0], skip_special_tokens=True)
return {
"text": decoded_text,
"uncertainty": uncertainty_score,
"confidence": torch.exp(torch.mean(probs)).item()
}
# 初始化模型 (以轻量级模型为例)
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
res = generate_with_confidence(model, tokenizer, "The capital of France is")
print(f"Generated: {res['text']}
Uncertainty Score: {res['uncertainty']:.4f}")
4. 工程化建议:在 vLLM 等服务中集成
在生产级 AI Infra 中,通常不会直接调用 Transformers。若使用 vLLM 或 TGI:
1. API 配置:在调用 /v1/completions 时设置 logprobs 参数(例如 logprobs=1)。
2. 后端计算:在 Inference Gateway 层根据返回的 logprobs 数组计算平均值或几何平均数。
3. 阈值触发:设定一个动态阈值(如 NLL > 0.8),当超过该值时,API 返回中自动包含 is_reliable: false 字段,指导前端展示。
这种不依赖于额外采样(Sampling)的方法性能损耗极低,是当前 LLM 部署中增强可靠性的首选方案。
汤不热吧