在处理推荐系统或大规模广告系统时,我们经常遇到具有数百万甚至数十亿唯一值的类别特征(如用户ID、商品ID)。如果直接将这些ID作为输入并依赖传统的 Keras Embedding 层,模型在内存和初始化速度上都会面临巨大挑战。
解决这个问题的关键在于:将类别到整数索引的映射过程剥离出模型变量,并使用高效的 TensorFlow 查找表 (tf.lookup) 来实现快速、内存友好的查询。
本文将聚焦于如何使用 tf.lookup.StaticHashTable 结合外部词汇表文件,实现超大规模类别特征的快速嵌入映射。
1. 技术背景:为什么需要 tf.lookup?
如果一个词汇表有1000万个唯一项,我们不能直接在 tf.keras.layers.Embedding 中使用字符串作为输入。通常的做法是先将字符串映射为整数ID,然后将ID输入Embedding层。
当词汇表巨大时,如果将词汇映射逻辑作为模型的一部分(例如,使用 tf.keras.layers.StringLookup 并将其权重存储在模型检查点中),会导致检查点文件过于庞大,加载速度慢,且内存占用高。
StaticHashTable 允许我们从外部文件(如 CSV 或文本文件)加载映射关系,该查找表在图初始化时构建,并且查找速度极快,同时不占用模型可训练变量的内存。
2. 实操步骤:实现字符串到索引的映射
我们将演示如何创建一个包含10万个类别的词汇表,并将其高效地映射到索引。
步骤 2.1: 准备词汇表文件
我们首先生成一个模拟的超大规模词汇表文件 large_vocab.txt,每一行是一个唯一的类别字符串。
import tensorflow as tf
import numpy as np
import os
# 定义文件路径和词汇量大小
vocab_filename = 'large_vocab.txt'
vocab_size = 100000 # 模拟10万个类别
embedding_dim = 64
# 检查并生成词汇表文件
if not os.path.exists(vocab_filename):
print(f"生成 {vocab_size} 个词汇项...")
with open(vocab_filename, 'w') as f:
for i in range(vocab_size):
f.write(f"category_{i}\n")
print("词汇表生成完毕。")
# OOV (Out of Vocabulary) 索引:我们将使用最后一个索引位 (vocab_size) 作为 OOV 的默认值。
OOV_INDEX = vocab_size
步骤 2.2: 初始化静态哈希查找表
使用 tf.lookup.TextFileInitializer 从文件中读取键值对。由于我们的文件只有键(类别名),我们将使用行号作为值(索引)。
# 1. 定义初始化器:从文件中读取键值对
keys_init = tf.lookup.TextFileInitializer(
filename=vocab_filename,
key_dtype=tf.string,
value_dtype=tf.int64,
# 指定读取配置:使用整行作为Key
key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
# 使用行号(从0开始)作为Value (索引)
value_index=tf.lookup.TextFileIndex.LINE_NUMBER
)
# 2. 创建静态哈希表
# default_value设置为 OOV_INDEX,用于处理不在词汇表中的类别。
table = tf.lookup.StaticHashTable(
keys_init,
default_value=tf.constant(OOV_INDEX, dtype=tf.int64)
)
print("查找表初始化成功。")
步骤 2.3: 查找与嵌入集成
现在,我们可以在模型的数据预处理或输入层中直接使用这个查找表,将输入的字符串映射为Embedding层所需的整数索引。
# 模拟输入数据
input_categories = tf.constant([
"category_100",
"category_99999",
"category_OOV_test", # OOV项
"category_50"
], dtype=tf.string)
# 执行查找:字符串 -> 索引
indices = table.lookup(input_categories)
print("原始类别:\n", input_categories.numpy())
print("映射索引:\n", indices.numpy())
# 3. 结合 Keras Embedding 层
# Embedding层大小需要是 (vocab_size + 1) 来容纳 OOV 索引
embedding_layer = tf.keras.layers.Embedding(
input_dim=vocab_size + 1,
output_dim=embedding_dim,
name="large_scale_embedding"
)
# 获取嵌入向量
embeddings = embedding_layer(indices)
print(f"Embedding层输入维度: {embedding_layer.input_dim}")
print(f"输出Embedding形状: {embeddings.shape}")
# 验证 OOV 映射
# OOV_INDEX 应该是 100000
print(f"预期的 OOV 索引: {OOV_INDEX}")
print(f"OOV 项 (category_OOV_test) 实际映射的索引: {indices.numpy()[2]}")
# 清理生成的临时文件
os.remove(vocab_filename)
3. 总结与优势
通过使用 tf.lookup.StaticHashTable,我们实现了对超大规模类别特征的快速、离线映射:
- 内存效率高: 查找表数据存储在图的初始化阶段,不作为模型的可训练变量,极大地减小了模型检查点的大小。
- 查找速度快: 哈希表查询的时间复杂度接近 O(1)。
- 动态管理: 如果词汇表需要更新,只需要替换外部文件并重新初始化模型,无需修改模型结构或重新训练Embedding权重(除非需要重新训练整个模型)。
- 支持 OOV: 通过设置 default_value,可以优雅地处理训练或推理时出现的新类别。
汤不热吧