欢迎光临
我们一直在努力

怎样利用 TFRecord 存储格式与 tf.train.Example 构建高效的工业级特征库

在工业级AI项目中,数据I/O效率往往是训练速度的瓶颈。标准的CSV或Parquet文件在处理大规模、异构数据(如包含大量稀疏特征、图像或高维向量)时,性能往往不佳。TensorFlow的官方数据格式TFRecord,结合其核心协议tf.train.Example,提供了高效的数据序列化和反序列化机制,是构建高性能特征库的最佳选择。

TFRecord本质上是一个二进制文件,它将复杂的结构化数据存储为一系列的序列化字符串(Protocol Buffers),实现了零拷贝读取和快速解析,极大地优化了数据加载效率。

步骤一:理解 tf.train.Example 结构

tf.train.Example是用于表示数据样本的标准协议。每个样本由一个或多个特征(Feature)组成,特征必须是以下三种类型之一:

  1. BytesList (字符串或原始字节数据,常用于存储图片、音频、文本或序列化的高维向量)
  2. Int64List (整数类型)
  3. FloatList (浮点数类型)

我们首先编写辅助函数,将Python基本数据类型封装成tf.train.Feature对象。

import tensorflow as tf
import numpy as np
import os

# 1. 辅助函数:将数据封装为 Feature 对象
def _bytes_feature(value):
    """将字符串/字节列表转换为 BytesList Feature."""
    if isinstance(value, type(tf.constant(0))): # 处理张量
        value = value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _int64_feature(value):
    """将整数或列表转换为 Int64List Feature."""
    # 确保 value 是一个列表,即使只有一个元素
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def _float_feature(value):
    """将浮点数或列表转换为 FloatList Feature."""
    if not isinstance(value, list):
        value = [value]
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

步骤二:构建特征序列化函数

假设我们要存储一个推荐系统中的用户特征,包括用户ID(int)、年龄(int)和用户Embedding向量(高维浮点数)。高维向量必须先转换为字节(tobytes())才能存储。

def serialize_example(user_id, age, click_history, embedding):
    """将单个样本的特征字典序列化为 tf.train.Example 字符串."""
    # 关键步骤:高维向量需要转换为原始字节数据
    embedding_bytes = embedding.astype(np.float32).tobytes()

    feature = {
        'user_id': _int64_feature(user_id),
        'age': _int64_feature(age),
        # 存储变长特征,例如点击历史列表
        'click_history': _int64_feature(click_history),
        # 存储高维向量,使用 BytesList
        'user_embedding_raw': _bytes_feature(embedding_bytes),
        'embedding_shape': _int64_feature(embedding.shape[0]) # 记录原始维度,便于反序列化
    }

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

步骤三:写入 TFRecord 文件

使用tf.io.TFRecordWriter将序列化的样本逐一写入磁盘。

# 示例数据生成
NUM_SAMPLES = 10
EMBEDDING_DIM = 64
TFRECORD_FILE = 'industrial_user_features.tfrecord'

print(f"开始写入 {NUM_SAMPLES} 个样本到 {TFRECORD_FILE}...")

with tf.io.TFRecordWriter(TFRECORD_FILE) as writer:
    for i in range(NUM_SAMPLES):
        user_id = 1000 + i
        age = 20 + np.random.randint(1, 10)
        # 模拟变长点击历史
        click_history = np.random.randint(100, 500, size=np.random.randint(5, 15)).tolist()
        embedding = np.random.rand(EMBEDDING_DIM)

        example_string = serialize_example(user_id, age, click_history, embedding)
        writer.write(example_string)

print("写入完成。")

步骤四:高效读取与解析

读取TFRecord文件是使用tf.data.TFRecordDataset的核心步骤。我们必须提供一个解析函数,定义数据的结构和类型,尤其是要将存储为字节的向量反序列化回来。

def parse_tfrecord(example_proto):
    # 1. 定义特征描述(必须与写入时的 key 对应)
    feature_description = {
        'user_id': tf.io.FixedLenFeature([], tf.int64),
        'age': tf.io.FixedLenFeature([], tf.int64),
        # 变长特征使用 VarLenFeature,或者在 FixedLenFeature 中指定 shape=[] 和 default_value=[]
        'click_history': tf.io.VarLenFeature(tf.int64),
        'user_embedding_raw': tf.io.FixedLenFeature([], tf.string),
        'embedding_shape': tf.io.FixedLenFeature([], tf.int64),
    }

    # 2. 解析单个 Example
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)

    # 3. 反序列化高维向量 (Bytes -> Tensor)
    embedding_shape = parsed_features.pop('embedding_shape')

    # tf.io.decode_raw 将字节数据解码为指定的浮点类型
    embedding = tf.io.decode_raw(parsed_features['user_embedding_raw'], tf.float32)
    # 重新指定维度。注意:这里我们用写入时的固定维度 64 来确定 shape
    embedding = tf.reshape(embedding, [EMBEDDING_DIM]) 

    parsed_features['user_embedding'] = embedding
    # 删除原始字节数据
    parsed_features.pop('user_embedding_raw')

    # 对于 VarLenFeature,需要转换为稠密张量
    parsed_features['click_history'] = tf.sparse.to_dense(parsed_features['click_history'])

    return parsed_features

# 5. 构建数据管道
dataset = tf.data.TFRecordDataset(TFRECORD_FILE)
# 使用 map 高效并行地解析数据
dataset = dataset.map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
# 批处理和预取
dataset = dataset.batch(2).prefetch(tf.data.AUTOTUNE)

print("\n--- 验证读取结果 ---")
for batch in dataset.take(1):
    print("用户ID:", batch['user_id'].numpy())
    print("年龄:", batch['age'].numpy())
    print("点击历史样本(变长):", batch['click_history'].numpy())
    print("Embedding 向量 Shape:", batch['user_embedding'].shape)
    print("Embedding 向量 Dtype:", batch['user_embedding'].dtype)

# 清理文件
os.remove(TFRECORD_FILE)

总结

利用 TFRecord 和 tf.train.Example,我们能够将不同类型和维度的特征(如标量、列表、高维向量)统一存储在一个高效的二进制文件中。这种方法消除了训练过程中的数据解析开销,并天然支持 TensorFlow 的 tf.data 管道的并行化和预取机制,是构建工业级高性能AI系统特征库的关键技术。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 怎样利用 TFRecord 存储格式与 tf.train.Example 构建高效的工业级特征库
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址