在处理 TB 级别或者需要实时生成的流式数据集时,传统的 PyTorch Dataset(Map-style Dataset,通过 __getitem__ 随机访问)机制会遇到致命的内存瓶颈。因为这类数据集要求在初始化时或者通过索引访问时将所有数据加载或映射到内存中。
解决办法是使用 ****torch.utils.data.IterableDataset****。与 Map-style Dataset 不同,IterableDataset 实现了 Python 的迭代器协议,它只要求实现 __iter__ 方法,用于返回一个数据流迭代器。这意味着数据是即时生成或即时从磁盘读取的,永远不会同时装载进内存,完美适用于大数据流。
本文将通过一个高度实操性的示例,展示如何构建和使用一个基于数据流生成器的 IterableDataset。
1. IterableDataset 的基本原理
IterableDataset 避免了随机访问,每次迭代都会产生下一批数据。它的核心在于 __iter__ 方法必须返回一个迭代器对象(例如一个 generator)。
2. 实操:构建流式数据加载器
我们模拟一个场景:数据并非来自一个固定的文件列表,而是通过一个函数实时生成,或者需要逐行读取一个巨大的日志文件。
import torch
from torch.utils.data import IterableDataset, DataLoader
import time
import os
# 1. 模拟一个大型数据生成器
# 这个函数模拟了从磁盘读取数据或实时生成数据的I/O操作
def large_data_generator(num_samples):
"""生成器函数,按需产生数据"""
print("\n--- 启动数据流生成 (Worker ID 见 DataLoader 输出) ---")
for i in range(num_samples):
# 模拟I/O延迟和数据处理
time.sleep(0.00001)
# 产生数据张量和标签
data = torch.tensor([i, i*2, i*3], dtype=torch.float32)
label = torch.tensor([i % 10], dtype=torch.long)
yield data, label
# 2. 实现 StreamingDataset (继承自 IterableDataset)
class StreamingDataset(IterableDataset):
def __init__(self, generator_func, num_samples):
super().__init__()
self.generator_func = generator_func
self.num_samples = num_samples
def __iter__(self):
# 核心:每次调用返回一个新的数据流迭代器
return self.generator_func(self.num_samples)
# 3. 关键:处理多进程(Multi-Worker)下的数据划分
# 对于 IterableDataset,如果使用 num_workers > 0,每个 worker 都会独立调用 __iter__()。
# 如果不进行处理,所有 worker 将读取相同的数据,导致数据重复。
# 我们需要使用 worker_init_fn 来确保每个 worker 只处理数据流的一部分。
def worker_init_fn(worker_id):
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
# 获取总 worker 数和当前 worker ID
num_workers = worker_info.num_workers
worker_id = worker_info.id
# 将 worker ID 存储到线程本地存储中,供 generator_func 使用
# 实际操作中,generator_func 应该根据 worker_id 来计算读取的起始和结束偏移量
# 这里的简单示例只是打印信息,实际需在 StreamingDataset.__iter__ 内部实现分片。
print(f"[Worker {worker_id}/{num_workers}] 初始化成功,准备读取数据片段")
# 4. 使用和测试
NUM_SAMPLES = 50000
BATCH_SIZE = 128
NUM_WORKERS = 4 # 启用多进程加速数据加载
# 初始化流式数据集
stream_dataset = StreamingDataset(
generator_func=large_data_generator,
num_samples=NUM_SAMPLES
)
# 使用 DataLoader
stream_loader = DataLoader(
stream_dataset,
batch_size=BATCH_SIZE,
num_workers=NUM_WORKERS,
worker_init_fn=worker_init_fn, # 确保数据流在多进程下正确划分
# 注意:IterableDataset 通常不支持内置的 shuffle,需要在 generator_func 内部实现打乱逻辑
)
# 5. 模拟训练循环
print(f"\n开始遍历数据流,总样本数: {NUM_SAMPLES}, 批次大小: {BATCH_SIZE}, Workers: {NUM_WORKERS}")
start_time = time.time()
total_batches = 0
for batch_idx, (data, labels) in enumerate(stream_loader):
# 模拟模型训练步骤
if batch_idx % 100 == 0:
print(f"Processed batch {batch_idx}: Data shape {data.shape}, Example label: {labels[0]}")
total_batches += 1
if total_batches > 400: # 避免运行时间过长,只演示部分训练
break
end_time = time.time()
print(f"\n成功处理了 {total_batches} 个批次 (约 {total_batches * BATCH_SIZE} 个样本)。")
print(f"总耗时: {end_time - start_time:.2f} 秒")
3. 关键点总结
- 内存效率:数据流实时产生,内存占用恒定且低,与数据集大小无关。
- ****iter****:这是 IterableDataset 唯一必需的方法,它必须返回一个迭代器(如 yield 语句构成的生成器函数)。
- 多进程处理:当设置 num_workers > 0 时,DataLoader 会为每个 worker 启动一个进程,并让每个进程独立调用 __iter__。因此,对于文件读取场景,必须配合 worker_init_fn 在 __iter__ 内部实现数据的分片逻辑(例如,worker 0 读取文件的前 1/N,worker 1 读取第 1/N 到 2/N,等等),否则所有 worker 将读取相同的数据。
汤不热吧