在训练万亿参数(TB级权重)的大型语言模型(LLM)时,断点续训(Checkpointing)是至关重要的一环。然而,传统的PyTorch保存方式通常需要Rank 0节点聚合所有权重,这会导致严重的I/O瓶颈和内存溢出,使得保存一次权重可能需要数小时。要实现“秒级”的TB级权重保存,核心在于利用分布式训练框架提供的并行分片保存能力。
本实践将聚焦于PyTorch的Fully Sharded Data Parallel (FSDP),结合其分布式状态字典(Distributed State Dict)机制,实现高效的并行Checkpointing。
1. 核心挑战与解决方案:分片状态字典
挑战: 传统方法需要将所有权重(假设1TB)聚合到单个节点,然后写入磁盘。
解决方案: FSDP允许我们在保存时使用StateDictType.SHARDED_STATE_DICT。这意味着每个GPU(Rank)只负责保存其本地持有的模型权重分片、优化器状态分片和梯度分片。这样,100个GPU就可以并行地将1TB数据写入磁盘,将串行操作转化为并行操作,极大地提升了I/O速度。
2. 环境准备
确保你已经安装了支持分布式训练的PyTorch版本,并且配置了多GPU环境(例如使用torchrun或mpirun)。
# 假设环境已配置好PyTorch和CUDA
# 示例运行命令(4个GPU)
# torchrun --nproc_per_node=4 your_script.py
3. 实现并行保存与加载(FSDP)
我们将定义一个简单的FSDP模型,并实现一个高度并行的保存和加载逻辑。
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import StateDictType, ShardedStateDictConfig
import os
import time
def setup_ddp():
# 初始化DDP环境
dist.init_process_group("nccl")
rank = dist.get_rank()
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
return rank, local_rank
# 示例模型 (为了演示,使用一个小模型)
class DummyModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.layer = torch.nn.Linear(2**14, 2**14)
# 路径定义
CHECKPOINT_DIR = "fsdp_sharded_checkpoint"
def save_sharded_checkpoint(model: FSDP, rank: int, path: str):
"""并行保存分片状态字典"""
print(f"Rank {rank}: 开始保存...")
start_time = time.time()
# 1. 配置为SHARDED_STATE_DICT模式
# 必须指定offload_to_cpu=True以确保数据在保存前在CPU上进行准备
save_config = ShardedStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, state_dict_config=save_config):
# 2. 获取状态字典
# 在此模式下,state_dict()返回的字典只包含当前Rank拥有的分片数据
state_dict = model.state_dict()
# 3. Rank 0 负责创建目录
if rank == 0:
os.makedirs(path, exist_ok=True)
dist.barrier() # 确保目录创建完毕
# 4. 每个Rank独立保存自己的分片
save_file = os.path.join(path, f"model_shard_{rank}.pt")
torch.save(state_dict, save_file)
dist.barrier() # 等待所有Rank保存完毕
if rank == 0:
print(f"Rank 0: 全部保存完成,耗时 {time.time() - start_time:.2f} 秒")
def load_sharded_checkpoint(model: FSDP, rank: int, path: str):
"""并行加载分片状态字典"""
print(f"Rank {rank}: 开始加载...")
start_time = time.time()
load_config = ShardedStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, state_dict_config=load_config):
# 1. 读取当前Rank对应的分片文件
load_file = os.path.join(path, f"model_shard_{rank}.pt")
sharded_state_dict = torch.load(load_file, map_location=f'cuda:{model.device.index}')
# 2. FSDP的load_state_dict会正确地将分片加载到对应的GPU内存中
model.load_state_dict(sharded_state_dict)
dist.barrier()
if rank == 0:
print(f"Rank 0: 全部加载完成,耗时 {time.time() - start_time:.2f} 秒")
# --- 主执行逻辑 ---
if __name__ == '__main__':
# 1. 初始化环境
rank, local_rank = setup_ddp()
# 2. 创建并FSDP封装模型
base_model = DummyModel().to(local_rank)
# 注意:对于TB级模型,需要配合ParameterGroup等优化策略
fsdp_model = FSDP(base_model)
# 3. 示例:模型训练后进行保存
save_sharded_checkpoint(fsdp_model, rank, CHECKPOINT_DIR)
# 4. 示例:重新初始化模型并加载权重
dist.barrier()
new_base_model = DummyModel().to(local_rank)
new_fsdp_model = FSDP(new_base_model)
load_sharded_checkpoint(new_fsdp_model, rank, CHECKPOINT_DIR)
dist.destroy_process_group()
4. 总结优化原理
这种方法的“秒级”优势并非来自压缩或神奇的算法,而是来自于:
- 并行性: N个GPU同时写入N个文件,理论I/O速度提升N倍。
- 避免聚合: 绕过了将所有TB级数据传输到Rank 0的CPU内存和磁盘带宽限制。
- 分布式文件系统: 搭配高性能并行文件系统(如Lustre, CephFS, GPFS等),可以充分发挥并行写入的性能。
汤不热吧