欢迎光临
我们一直在努力

详解 ZeRO 冗余消除技术:为什么 ZeRO-3 能让单卡跑起“塞不下”的模型

为什么 ZeRO-3 能让单卡跑起“塞不下”的模型?

随着大语言模型(LLM)的尺寸不断膨胀,GPU的显存往往成为训练过程中的最大瓶颈。一个1750亿参数的模型(如GPT-3),即使使用混合精度(FP16/BF16),仅参数、梯度和优化器状态就需要数百GB的存储空间,远超单张高端GPU(如A100 80GB)的容量。

DeepSpeed的Zero Redundancy Optimizer (ZeRO) 技术正是为解决这一问题而生。它是一种创新的数据并行(Data Parallelism)方法,通过对训练状态进行分片(Sharding),大幅度减少了每张GPU上的内存占用。

训练状态的内存消耗

在标准的PyTorch训练中,内存主要消耗在以下三个方面(以FP16/BF16和Adam优化器为例):

  1. 参数(P): 模型权重本身。FP16下,每10亿参数约占2GB。
  2. 梯度(G): 存储反向传播计算出的梯度。与参数大小相同。
  3. 优化器状态(P_os): Adam优化器需要存储两个状态(Momentum和Variance)。FP32下,每10亿参数约占8GB。

总计: 每10亿参数大约需要 2GB (P) + 2GB (G) + 8GB (P_os) = 12GB显存(不包含激活值)。

ZeRO 的三个阶段

ZeRO通过逐步分片这些训练状态,提供了三个优化级别:

级别 分片内容 (Sharding) 节省内存倍数 (相对于标准数据并行)
ZeRO-1 仅分片优化器状态 ($P_{os}$) 约 4x
ZeRO-2 分片优化器状态 ($P_{os}$) + 梯度 ($G$) 约 8x
ZeRO-3 分片优化器状态 ($P_{os}$) + 梯度 ($G$) + 参数 ($P$) $N_{gpu}$ 倍(与GPU数量成正比)

ZeRO-3:突破显存限制的关键

ZeRO-3(Full Sharded Data Parallelism) 是最激进的优化级别,它将模型参数本身也进行了分片。这意味着,如果一个模型有1000亿参数,使用8块GPU进行训练,那么在静态存储时,每块GPU仅需存储1/8的模型参数(加上1/8的梯度和优化器状态)。

运作机制:运行时动态重构

既然参数被分片了,GPU如何进行前向和后向计算呢?ZeRO-3采用了一种动态内存管理策略:

  1. 分片存储: 完整的模型参数 $P$ 被均匀地存储在 $N$ 块GPU上。
  2. 前向传播(Forward): 当计算一个特定的层时,ZeRO-3会使用 all-gather 通信操作,从所有其他GPU上收集该层所需的完整参数子集(完整的权重矩阵)。
  3. 计算与释放: GPU完成该层的计算后,立即释放(deallocate)这组完整的参数权重,仅保留计算所需的激活值。
  4. 后向传播(Backward): 过程类似,当计算梯度时,参数会被重新收集,计算完成后,梯度会被立即缩减(reduce-scatter)并分配回负责该参数分片的GPU进行更新。

通过这种“即时获取,用完即扔”的策略,ZeRO-3确保了任何时刻单块GPU的显存中,只保留当前计算所需的参数片段,从而使得总模型大小($P+G+P_{os}$)可以远远超过单张GPU的显存容量。

实操:如何启用 DeepSpeed ZeRO-3

启用ZeRO-3非常简单,只需通过DeepSpeed配置文件定义 zero_optimization 即可。

步骤 1: DeepSpeed 配置文件 (ds_config.json)

创建一个JSON配置文件,将 stage 设置为 3

{
  "train_batch_size": "auto",
  "gradient_accumulation_steps": 1,
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 1e-5
    }
  },
  "fp16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer_states": true, 
    "offload_param": true, 
    "contiguous_gradients": true
  }
}
  • stage: 3:启用ZeRO-3。
  • offload_optimizer_states: true:可选,将优化器状态从GPU显存卸载到CPU/NVMe内存,进一步节省GPU内存。
  • offload_param: true:可选,将不活跃的参数也卸载到CPU/NVMe。

步骤 2: PyTorch 脚本集成

使用 deepspeed 命令行工具运行你的训练脚本。

假设你的训练脚本名为 train.py

# 假设使用 4 块 GPU (nproc_per_node=4)
deepspeed --num_gpus 4 train.py --deepspeed_config ds_config.json

步骤 3: 脚本示例 (train.py 关键部分)

在 PyTorch 脚本中,你只需要像使用标准 PyTorch 一样初始化模型和数据加载器,然后使用 deepspeed.initialize 包装它们。

import torch
import deepspeed

# 假设我们有一个非常大的模型
# 注意:在实际应用中,如果模型参数量巨大,需要使用模型并行或Triton等技术配合加载
# 这里仅演示如何初始化 DeepSpeed
class SimpleLargeModel(torch.nn.Module):
    def __init__(self, size=50000): # 示例大小
        super().__init__()
        self.linear1 = torch.nn.Linear(size, size)
        self.relu = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(size, size)

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

model = SimpleLargeModel()

# DeepSpeed 初始化
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=json.load(open('ds_config.json'))
)

# 训练循环与标准 PyTorch 基本相同
# model_engine.step() 会自动处理 all-gather, compute, reduce-scatter 等复杂的通信操作
# ... training loop ...

总结

ZeRO-3 的核心在于将模型的参数(P)从冗余的存储转变为分片存储。通过高效的 all-gather 和动态内存释放机制,它有效地将训练所需的静态内存需求按 $N_{gpu}$ 倍数进行划分,从而使得原本无法装入单张GPU的模型能够在分布式环境中成功训练。

【本站文章皆为原创,未经允许不得转载】:汤不热吧 » 详解 ZeRO 冗余消除技术:为什么 ZeRO-3 能让单卡跑起“塞不下”的模型
分享到: 更多 (0)

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址