为什么 ZeRO-3 能让单卡跑起“塞不下”的模型?
随着大语言模型(LLM)的尺寸不断膨胀,GPU的显存往往成为训练过程中的最大瓶颈。一个1750亿参数的模型(如GPT-3),即使使用混合精度(FP16/BF16),仅参数、梯度和优化器状态就需要数百GB的存储空间,远超单张高端GPU(如A100 80GB)的容量。
DeepSpeed的Zero Redundancy Optimizer (ZeRO) 技术正是为解决这一问题而生。它是一种创新的数据并行(Data Parallelism)方法,通过对训练状态进行分片(Sharding),大幅度减少了每张GPU上的内存占用。
训练状态的内存消耗
在标准的PyTorch训练中,内存主要消耗在以下三个方面(以FP16/BF16和Adam优化器为例):
- 参数(P): 模型权重本身。FP16下,每10亿参数约占2GB。
- 梯度(G): 存储反向传播计算出的梯度。与参数大小相同。
- 优化器状态(P_os): Adam优化器需要存储两个状态(Momentum和Variance)。FP32下,每10亿参数约占8GB。
总计: 每10亿参数大约需要 2GB (P) + 2GB (G) + 8GB (P_os) = 12GB显存(不包含激活值)。
ZeRO 的三个阶段
ZeRO通过逐步分片这些训练状态,提供了三个优化级别:
| 级别 | 分片内容 (Sharding) | 节省内存倍数 (相对于标准数据并行) |
|---|---|---|
| ZeRO-1 | 仅分片优化器状态 ($P_{os}$) | 约 4x |
| ZeRO-2 | 分片优化器状态 ($P_{os}$) + 梯度 ($G$) | 约 8x |
| ZeRO-3 | 分片优化器状态 ($P_{os}$) + 梯度 ($G$) + 参数 ($P$) | $N_{gpu}$ 倍(与GPU数量成正比) |
ZeRO-3:突破显存限制的关键
ZeRO-3(Full Sharded Data Parallelism) 是最激进的优化级别,它将模型参数本身也进行了分片。这意味着,如果一个模型有1000亿参数,使用8块GPU进行训练,那么在静态存储时,每块GPU仅需存储1/8的模型参数(加上1/8的梯度和优化器状态)。
运作机制:运行时动态重构
既然参数被分片了,GPU如何进行前向和后向计算呢?ZeRO-3采用了一种动态内存管理策略:
- 分片存储: 完整的模型参数 $P$ 被均匀地存储在 $N$ 块GPU上。
- 前向传播(Forward): 当计算一个特定的层时,ZeRO-3会使用 all-gather 通信操作,从所有其他GPU上收集该层所需的完整参数子集(完整的权重矩阵)。
- 计算与释放: GPU完成该层的计算后,立即释放(deallocate)这组完整的参数权重,仅保留计算所需的激活值。
- 后向传播(Backward): 过程类似,当计算梯度时,参数会被重新收集,计算完成后,梯度会被立即缩减(reduce-scatter)并分配回负责该参数分片的GPU进行更新。
通过这种“即时获取,用完即扔”的策略,ZeRO-3确保了任何时刻单块GPU的显存中,只保留当前计算所需的参数片段,从而使得总模型大小($P+G+P_{os}$)可以远远超过单张GPU的显存容量。
实操:如何启用 DeepSpeed ZeRO-3
启用ZeRO-3非常简单,只需通过DeepSpeed配置文件定义 zero_optimization 即可。
步骤 1: DeepSpeed 配置文件 (ds_config.json)
创建一个JSON配置文件,将 stage 设置为 3。
{
"train_batch_size": "auto",
"gradient_accumulation_steps": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-5
}
},
"fp16": {
"enabled": true
},
"zero_optimization": {
"stage": 3,
"offload_optimizer_states": true,
"offload_param": true,
"contiguous_gradients": true
}
}
- stage: 3:启用ZeRO-3。
- offload_optimizer_states: true:可选,将优化器状态从GPU显存卸载到CPU/NVMe内存,进一步节省GPU内存。
- offload_param: true:可选,将不活跃的参数也卸载到CPU/NVMe。
步骤 2: PyTorch 脚本集成
使用 deepspeed 命令行工具运行你的训练脚本。
假设你的训练脚本名为 train.py:
# 假设使用 4 块 GPU (nproc_per_node=4)
deepspeed --num_gpus 4 train.py --deepspeed_config ds_config.json
步骤 3: 脚本示例 (train.py 关键部分)
在 PyTorch 脚本中,你只需要像使用标准 PyTorch 一样初始化模型和数据加载器,然后使用 deepspeed.initialize 包装它们。
import torch
import deepspeed
# 假设我们有一个非常大的模型
# 注意:在实际应用中,如果模型参数量巨大,需要使用模型并行或Triton等技术配合加载
# 这里仅演示如何初始化 DeepSpeed
class SimpleLargeModel(torch.nn.Module):
def __init__(self, size=50000): # 示例大小
super().__init__()
self.linear1 = torch.nn.Linear(size, size)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(size, size)
def forward(self, x):
return self.linear2(self.relu(self.linear1(x)))
model = SimpleLargeModel()
# DeepSpeed 初始化
model_engine, optimizer, _, _ = deepspeed.initialize(
model=model,
model_parameters=model.parameters(),
config_params=json.load(open('ds_config.json'))
)
# 训练循环与标准 PyTorch 基本相同
# model_engine.step() 会自动处理 all-gather, compute, reduce-scatter 等复杂的通信操作
# ... training loop ...
总结
ZeRO-3 的核心在于将模型的参数(P)从冗余的存储转变为分片存储。通过高效的 all-gather 和动态内存释放机制,它有效地将训练所需的静态内存需求按 $N_{gpu}$ 倍数进行划分,从而使得原本无法装入单张GPU的模型能够在分布式环境中成功训练。
汤不热吧