在工业级AI项目中,数据I/O效率往往是训练速度的瓶颈。标准的CSV或Parquet文件在处理大规模、异构数据(如包含大量稀疏特征、图像或高维向量)时,性能往往不佳。TensorFlow的官方数据格式TFRecord,结合其核心协议tf.train.Example,提供了高效的数据序列化和反序列化机制,是构建高性能特征库的最佳选择。
TFRecord本质上是一个二进制文件,它将复杂的结构化数据存储为一系列的序列化字符串(Protocol Buffers),实现了零拷贝读取和快速解析,极大地优化了数据加载效率。
步骤一:理解 tf.train.Example 结构
tf.train.Example是用于表示数据样本的标准协议。每个样本由一个或多个特征(Feature)组成,特征必须是以下三种类型之一:
- BytesList (字符串或原始字节数据,常用于存储图片、音频、文本或序列化的高维向量)
- Int64List (整数类型)
- FloatList (浮点数类型)
我们首先编写辅助函数,将Python基本数据类型封装成tf.train.Feature对象。
import tensorflow as tf
import numpy as np
import os
# 1. 辅助函数:将数据封装为 Feature 对象
def _bytes_feature(value):
"""将字符串/字节列表转换为 BytesList Feature."""
if isinstance(value, type(tf.constant(0))): # 处理张量
value = value.numpy()
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _int64_feature(value):
"""将整数或列表转换为 Int64List Feature."""
# 确保 value 是一个列表,即使只有一个元素
if not isinstance(value, list):
value = [value]
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
def _float_feature(value):
"""将浮点数或列表转换为 FloatList Feature."""
if not isinstance(value, list):
value = [value]
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
步骤二:构建特征序列化函数
假设我们要存储一个推荐系统中的用户特征,包括用户ID(int)、年龄(int)和用户Embedding向量(高维浮点数)。高维向量必须先转换为字节(tobytes())才能存储。
def serialize_example(user_id, age, click_history, embedding):
"""将单个样本的特征字典序列化为 tf.train.Example 字符串."""
# 关键步骤:高维向量需要转换为原始字节数据
embedding_bytes = embedding.astype(np.float32).tobytes()
feature = {
'user_id': _int64_feature(user_id),
'age': _int64_feature(age),
# 存储变长特征,例如点击历史列表
'click_history': _int64_feature(click_history),
# 存储高维向量,使用 BytesList
'user_embedding_raw': _bytes_feature(embedding_bytes),
'embedding_shape': _int64_feature(embedding.shape[0]) # 记录原始维度,便于反序列化
}
example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
return example_proto.SerializeToString()
步骤三:写入 TFRecord 文件
使用tf.io.TFRecordWriter将序列化的样本逐一写入磁盘。
# 示例数据生成
NUM_SAMPLES = 10
EMBEDDING_DIM = 64
TFRECORD_FILE = 'industrial_user_features.tfrecord'
print(f"开始写入 {NUM_SAMPLES} 个样本到 {TFRECORD_FILE}...")
with tf.io.TFRecordWriter(TFRECORD_FILE) as writer:
for i in range(NUM_SAMPLES):
user_id = 1000 + i
age = 20 + np.random.randint(1, 10)
# 模拟变长点击历史
click_history = np.random.randint(100, 500, size=np.random.randint(5, 15)).tolist()
embedding = np.random.rand(EMBEDDING_DIM)
example_string = serialize_example(user_id, age, click_history, embedding)
writer.write(example_string)
print("写入完成。")
步骤四:高效读取与解析
读取TFRecord文件是使用tf.data.TFRecordDataset的核心步骤。我们必须提供一个解析函数,定义数据的结构和类型,尤其是要将存储为字节的向量反序列化回来。
def parse_tfrecord(example_proto):
# 1. 定义特征描述(必须与写入时的 key 对应)
feature_description = {
'user_id': tf.io.FixedLenFeature([], tf.int64),
'age': tf.io.FixedLenFeature([], tf.int64),
# 变长特征使用 VarLenFeature,或者在 FixedLenFeature 中指定 shape=[] 和 default_value=[]
'click_history': tf.io.VarLenFeature(tf.int64),
'user_embedding_raw': tf.io.FixedLenFeature([], tf.string),
'embedding_shape': tf.io.FixedLenFeature([], tf.int64),
}
# 2. 解析单个 Example
parsed_features = tf.io.parse_single_example(example_proto, feature_description)
# 3. 反序列化高维向量 (Bytes -> Tensor)
embedding_shape = parsed_features.pop('embedding_shape')
# tf.io.decode_raw 将字节数据解码为指定的浮点类型
embedding = tf.io.decode_raw(parsed_features['user_embedding_raw'], tf.float32)
# 重新指定维度。注意:这里我们用写入时的固定维度 64 来确定 shape
embedding = tf.reshape(embedding, [EMBEDDING_DIM])
parsed_features['user_embedding'] = embedding
# 删除原始字节数据
parsed_features.pop('user_embedding_raw')
# 对于 VarLenFeature,需要转换为稠密张量
parsed_features['click_history'] = tf.sparse.to_dense(parsed_features['click_history'])
return parsed_features
# 5. 构建数据管道
dataset = tf.data.TFRecordDataset(TFRECORD_FILE)
# 使用 map 高效并行地解析数据
dataset = dataset.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
# 批处理和预取
dataset = dataset.batch(2).prefetch(tf.data.AUTOTUNE)
print("\n--- 验证读取结果 ---")
for batch in dataset.take(1):
print("用户ID:", batch['user_id'].numpy())
print("年龄:", batch['age'].numpy())
print("点击历史样本(变长):", batch['click_history'].numpy())
print("Embedding 向量 Shape:", batch['user_embedding'].shape)
print("Embedding 向量 Dtype:", batch['user_embedding'].dtype)
# 清理文件
os.remove(TFRECORD_FILE)
总结
利用 TFRecord 和 tf.train.Example,我们能够将不同类型和维度的特征(如标量、列表、高维向量)统一存储在一个高效的二进制文件中。这种方法消除了训练过程中的数据解析开销,并天然支持 TensorFlow 的 tf.data 管道的并行化和预取机制,是构建工业级高性能AI系统特征库的关键技术。
汤不热吧