在万卡(10000+ GPU)规模的AI集群中进行大模型训练时,CheckPoint(检查点)的可靠性和速度是决定训练效率的关键。一个TB级甚至PB级的检查点,如果采用传统同步方式存储,可能导致训练停顿数小时,极大增加了恢复时间目标(RTO)。本文将深入探讨在大规模分布式训练环境下,如何通过技术手段实现快速、异步且高可靠的检查点机制。
Contents
1. 万卡集群检查点面临的核心挑战
- I/O 瓶颈: 数万个进程同时向分布式文件系统(如Lustre或NFS)写入大量小文件,导致文件系统拥塞。带宽饱和是常态。
- 恢复时长 (RTO): 加载一个PB级的检查点通常需要大量时间,如果训练中频繁发生故障,加载时间会抵消训练进度。
- 单点故障: 任何一个节点的I/O失败都可能导致整个CheckPoint过程失败。
2. 核心策略:异步与分层存储
为了解决同步I/O带来的停顿问题,必须将CheckPoint的写入过程从主训练循环中解耦。我们采用异步检查点(Asynchronous Checkpointing)结合分层存储(Hierarchical Storage)的策略。
2.1 异步检查点实现
异步检查点的核心思想是,在训练迭代 $N$ 结束时,训练进程只将模型状态和优化器状态的Tensor引用交给一个独立的后台工作进程/线程池。主训练循环立即进入 $N+1$ 迭代。后台进程负责实际的磁盘I/O。
以下是一个简化的PyTorch/Python异步保存逻辑示例(借鉴了DeepSpeed的Zero Stage 3或Megatron-LM的思路):
******python
import torch
import threading
import os
class CheckpointManager:
def init(self, save_dir):
self.save_dir = save_dir
self.queue = []
self.lock = threading.Lock()
self.stop_flag = threading.Event()
self.worker = threading.Thread(target=self._worker_loop)
self.worker.start()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24 def _worker_loop(self):
while not self.stop_flag.is_set():
if self.queue:
# 保证操作是原子的
with self.lock:
state_dict, step = self.queue.pop(0)
# 实际的I/O操作
path = os.path.join(self.save_dir, f'step_{step}_rank_{torch.distributed.get_rank()}.pt')
print(f"[Rank {torch.distributed.get_rank()}] Saving checkpoint {step} asynchronously to {path}")
torch.save(state_dict, path)
else:
self.stop_flag.wait(timeout=0.5)
def save_async(self, model_state, step):
# 注意:这里我们只保存了当前rank的模型/优化器分片
with self.lock:
# Deep copy the state dict to ensure the tensors aren't modified by training loop
copied_state = {k: v.clone() for k, v in model_state.items()}
self.queue.append((copied_state, step))
def stop(self):
self.stop_flag.set()
self.worker.join()
示例使用
if torch.distributed.get_rank() == 0:
manager = CheckpointManager(‘./async_ckpt’)
# 假设这是训练循环中的状态
current_step = 100
model_state = {‘weights’: torch.randn(100)}
manager.save_async(model_state, current_step)
# 训练继续…
# manager.stop() 在训练结束时调用
2.2 Fused Checkpointing 与节点聚合
在万卡集群中,避免生成数万个小文件至关重要。采用融合检查点(Fused Checkpointing),即在一个节点内部,将该节点上所有GPU(例如8个)的模型分片聚合成一个或少数几个大文件,再写入分布式存储。这能极大地减少文件句柄开销和元数据操作。
理想的写入路径:
- GPU Memory -> Node Local NVMe/SSD: Ranks将自己的模型分片写入本地高速存储。这是最快的写入,几乎不阻塞网络。
- Node Local NVMe -> Distributed File System (Lustre/BeeGFS): 节点上的一个聚合进程(或使用rsync/异步I/O工具)将聚合后的大文件移动到分布式文件系统。
- Distributed File System -> Object Storage (S3/OSS): 最终由专门的归档服务将稳定状态的检查点归档到高耐久性的对象存储。
这种分层结构确保了:即使分布式文件系统临时拥塞,训练也能快速完成本地写入并继续。恢复时,我们优先从分布式文件系统加载,如果它不可用,再从对象存储加载。
3. 快速恢复的关键:元数据和状态字典优化
快速恢复(快速加载)与快速保存同样重要。在万卡环境下,模型加载是顺序访问,加载速度被网络带宽严格限制。
3.1 仅加载所需的状态
对于超大模型,优化器状态(如AdamW的m和v)可能占据 CheckPoint 80%以上的空间。许多框架(如DeepSpeed Zero Stage 3)支持在恢复时选择性地重建优化器状态,或者只在必要时才加载优化器状态。
3.2 优化状态字典结构
传统的torch.save(model.state_dict(), …)会生成一个包含所有张量的字典。为了提高加载速度,我们应该优化状态字典,使其更容易被并行读取:
- 分离模型参数与优化器状态: 将模型参数(较稳定)和优化器状态(巨大且易变)保存为独立的文件集合。
- 使用索引文件: 创建一个轻量级的 JSON 或 YAML 元数据文件,记录每个 CheckPoint 分片的文件名、大小和哈希值。在恢复时,主节点只需读取这个元数据文件,即可并行调度所有工作节点同时从分布式文件系统拉取各自的分片,无需等待主节点解析整个状态字典。
******json
{
“step”: 1000,
“model_shards”: [
{“rank”: 0, “file”: “model_r0.pt”, “size_bytes”: 1073741824},
{“rank”: 1, “file”: “model_r1.pt”, “size_bytes”: 1073741824},
// … 更多分片
],
“optimizer_shards”: [
{“rank”: 0, “file”: “optim_r0.pt”},
// …
]
}
这种机制将恢复过程从串行元数据解析转变为并行数据加载,显著降低了RTO,使得万卡集群能够在几分钟内而非几小时内恢复训练。
汤不热吧