欢迎光临
我们一直在努力

大模型断点续训(Checkpointing)优化:如何秒级保存与加载 TB 级的权重

在训练万亿参数(TB级权重)的大型语言模型(LLM)时,断点续训(Checkpointing)是至关重要的一环。然而,传统的PyTorch保存方式通常需要Rank 0节点聚合所有权重,这会导致严重的I/O瓶颈和内存溢出,使得保存一次权重可能需要数小时。要实现“秒级”的TB级权重保存,核心在于利用分布式训练框架提供的并行分片保存能力。

本实践将聚焦于PyTorch的Fully Sharded Data Parallel (FSDP),结合其分布式状态字典(Distributed State Dict)机制,实现高效的并行Checkpointing。

1. 核心挑战与解决方案:分片状态字典

挑战: 传统方法需要将所有权重(假设1TB)聚合到单个节点,然后写入磁盘。

解决方案: FSDP允许我们在保存时使用StateDictType.SHARDED_STATE_DICT。这意味着每个GPU(Rank)只负责保存其本地持有的模型权重分片、优化器状态分片和梯度分片。这样,100个GPU就可以并行地将1TB数据写入磁盘,将串行操作转化为并行操作,极大地提升了I/O速度。

2. 环境准备

确保你已经安装了支持分布式训练的PyTorch版本,并且配置了多GPU环境(例如使用torchrunmpirun)。

# 假设环境已配置好PyTorch和CUDA
# 示例运行命令(4个GPU)
# torchrun --nproc_per_node=4 your_script.py

3. 实现并行保存与加载(FSDP)

我们将定义一个简单的FSDP模型,并实现一个高度并行的保存和加载逻辑。

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import StateDictType, ShardedStateDictConfig
import os
import time

def setup_ddp():
    # 初始化DDP环境
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return rank, local_rank

# 示例模型 (为了演示,使用一个小模型)
class DummyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2**14, 2**14)

# 路径定义
CHECKPOINT_DIR = "fsdp_sharded_checkpoint"

def save_sharded_checkpoint(model: FSDP, rank: int, path: str):
    """并行保存分片状态字典"""
    print(f"Rank {rank}: 开始保存...")
    start_time = time.time()

    # 1. 配置为SHARDED_STATE_DICT模式
    # 必须指定offload_to_cpu=True以确保数据在保存前在CPU上进行准备
    save_config = ShardedStateDictConfig(offload_to_cpu=True)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, state_dict_config=save_config):
        # 2. 获取状态字典
        # 在此模式下,state_dict()返回的字典只包含当前Rank拥有的分片数据
        state_dict = model.state_dict()

        # 3. Rank 0 负责创建目录
        if rank == 0:
            os.makedirs(path, exist_ok=True)
        dist.barrier() # 确保目录创建完毕

        # 4. 每个Rank独立保存自己的分片
        save_file = os.path.join(path, f"model_shard_{rank}.pt")
        torch.save(state_dict, save_file)

    dist.barrier() # 等待所有Rank保存完毕
    if rank == 0:
        print(f"Rank 0: 全部保存完成,耗时 {time.time() - start_time:.2f} 秒")


def load_sharded_checkpoint(model: FSDP, rank: int, path: str):
    """并行加载分片状态字典"""
    print(f"Rank {rank}: 开始加载...")
    start_time = time.time()

    load_config = ShardedStateDictConfig(offload_to_cpu=True)
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, state_dict_config=load_config):
        # 1. 读取当前Rank对应的分片文件
        load_file = os.path.join(path, f"model_shard_{rank}.pt")
        sharded_state_dict = torch.load(load_file, map_location=f'cuda:{model.device.index}')

        # 2. FSDP的load_state_dict会正确地将分片加载到对应的GPU内存中
        model.load_state_dict(sharded_state_dict)

    dist.barrier()
    if rank == 0:
        print(f"Rank 0: 全部加载完成,耗时 {time.time() - start_time:.2f} 秒")

# --- 主执行逻辑 ---
if __name__ == '__main__':
    # 1. 初始化环境
    rank, local_rank = setup_ddp()

    # 2. 创建并FSDP封装模型
    base_model = DummyModel().to(local_rank)
    # 注意:对于TB级模型,需要配合ParameterGroup等优化策略
    fsdp_model = FSDP(base_model)

    # 3. 示例:模型训练后进行保存
    save_sharded_checkpoint(fsdp_model, rank, CHECKPOINT_DIR)

    # 4. 示例:重新初始化模型并加载权重
    dist.barrier()
    new_base_model = DummyModel().to(local_rank)
    new_fsdp_model = FSDP(new_base_model)
    load_sharded_checkpoint(new_fsdp_model, rank, CHECKPOINT_DIR)

    dist.destroy_process_group()

4. 总结优化原理

这种方法的“秒级”优势并非来自压缩或神奇的算法,而是来自于:

  1. 并行性: N个GPU同时写入N个文件,理论I/O速度提升N倍。
  2. 避免聚合: 绕过了将所有TB级数据传输到Rank 0的CPU内存和磁盘带宽限制。
  3. 分布式文件系统: 搭配高性能并行文件系统(如Lustre, CephFS, GPFS等),可以充分发挥并行写入的性能。
【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 大模型断点续训(Checkpointing)优化:如何秒级保存与加载 TB 级的权重
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址