如何使用 PyTorch FSDP 解决超大模型单卡显存不足问题
随着大语言模型(LLM)的参数量突破百亿甚至万亿级别,传统的分布式训练方案(如DDP,数据并行)已经无法满足需求,因为DDP要求每张GPU都复制完整的模型权重、梯度和优化器状态,这迅速触及了单卡显存的极限。全分片指数级并行(Fully Sharded Data Parallel, FSDP)是 PyTorch 社区为解决这一“显存墙”问题而推出的终极方案,它通过在多卡之间高效地切分模型的各个组件,极大地降低了单卡的内存占用。
一、FSDP 原理速览:全分片的魔力
FSDP 的核心思想是将模型的三个关键内存占用部分在所有 GPU 之间进行分片(Sharding):
- P (Parameters): 模型权重参数。
- G (Gradients): 梯度。
- O (Optimizer States): 优化器状态(如 Adam 的两个动量)。
不同于 DDP 只分片数据,FSDP 将 P、G、O 全部切分。在 FSDP 的全分片(Full Sharding,对应 ZeRO Stage 3)模式下,每张 GPU 在任何给定时刻,都只存储模型总参数的 $1/N$(N为GPU数量)部分。这使得训练万亿参数的模型在有限的 GPU 集群上成为可能。
二、FSDP 的运行机制
FSDP 遵循即时通信(Just-In-Time Communication)策略,确保内存效率和计算效率的平衡:
- 前向传播 (Forward Pass): 当某一层的计算需要权重时,FSDP 会通过 All-Gather 操作,从所有其他 GPU 上收集该层完整的权重。一旦该层的计算完成,完整的权重就会被立即释放(或重新分片),从而最小化峰值显存占用。
- 反向传播 (Backward Pass): 梯度是在局部计算的。FSDP 随后使用 Reduce-Scatter 操作,将计算得到的完整梯度分散回其对应参数的持有者 GPU 上,并进行同步求和。
三、实操指南:使用 PyTorch FSDP 训练大模型
我们使用 PyTorch 内置的 torch.distributed.fsdp 模块来演示如何快速部署 FSDP。
步骤 1: 环境准备
确保安装 PyTorch 1.12 或更高版本,并准备多张 GPU 环境。
步骤 2: 定义模型和 FSDP 策略
为了演示 FSDP 的效果,我们定义一个结构简单但参数量较大的模型(例如,一个具有多层 Transformer 块的模型)。
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import ShardingStrategy
from torch.distributed.fsdp.wrap import enable_wrap, wrap, size_based_auto_wrap_policy
import os
# 模拟一个大型Transformer块,通常FSDP在这些大块上进行分片
class TransformerBlock(nn.Module):
def __init__(self, hidden_size):
super().__init__()
# 故意使用大参数量
self.norm = nn.LayerNorm(hidden_size)
self.attn = nn.Linear(hidden_size, hidden_size * 4)
self.mlp = nn.Linear(hidden_size * 4, hidden_size)
def forward(self, x):
res = x
x = self.norm(x)
x = self.attn(x)
x = self.mlp(x)
return x + res
class LargeModel(nn.Module):
def __init__(self, hidden_size=4096, num_layers=24):
super().__init__()
self.layers = nn.ModuleList([
TransformerBlock(hidden_size) for _ in range(num_layers)
])
self.final_proj = nn.Linear(hidden_size, 1000)
def forward(self, x):
for layer in self.layers:
x = layer(x)
return self.final_proj(x)
# 定义自动分片策略:基于模块大小自动分片,避免对小模块进行分片
# 当子模块的参数量超过1亿时,将其视为一个FSDP单元
my_auto_wrap_policy = size_based_auto_wrap_policy(
min_num_params=100000000 # 1亿参数
)
步骤 3: 初始化和 FSDP 封装
我们需要在每个进程(GPU)上运行初始化代码,并使用 FSDP 封装模型。
def setup(rank, world_size):
# 使用nccl后端进行GPU通信
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def train_fsdp(rank, world_size):
setup(rank, world_size)
torch.cuda.set_device(rank)
# 1. 实例化模型并移至对应设备
model = LargeModel().to(rank)
# 2. FSDP 封装
# sharding_strategy=ShardingStrategy.FULL_SHARD 是ZeRO-3的核心
fsdp_model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=torch.cuda.current_device(),
auto_wrap_policy=my_auto_wrap_policy,
# 注意:使用 Transformer 结构的 FSDP 建议启用 Backward_Prefetch
)
# 3. 定义优化器 (注意: FSDP优化器会自动管理分片的参数)
optimizer = optim.AdamW(fsdp_model.parameters(), lr=1e-5)
# 4. 模拟训练步骤
print(f"Rank {rank}: 模型参数已分片完成,仅占用总参数的 1/{world_size} 部分显存")
dummy_input = torch.randn(2, 512, 4096).to(rank) # Batch_size 2, Seq_len 512, Hidden 4096
loss_fn = nn.MSELoss()
# 前向传播:FSDP 自动 All-Gather
output = fsdp_model(dummy_input)
target = torch.randn_like(output)
loss = loss_fn(output, target)
# 反向传播:FSDP 自动 Reduce-Scatter
loss.backward()
# 优化器更新:FSDP 自动处理分片的优化器状态
optimizer.step()
print(f"Rank {rank}: 训练步完成,Loss: {loss.item():.4f}")
dist.destroy_process_group()
# 5. 启动脚本 (通常使用 torchrun)
# 示例启动命令 (假设你有4张GPU):
# torchrun --nproc_per_node=4 your_script_name.py
if __name__ == '__main__':
# 在实际环境中,使用 torchrun 或 torch.distributed.launch 启动主函数
# 以下代码仅为说明 train_fsdp 函数的调用结构
# world_size = torch.cuda.device_count()
# rank = int(os.environ['LOCAL_RANK'])
# train_fsdp(rank, world_size)
print("请使用 'torchrun --nproc_per_node=<GPU数量> your_script.py' 运行此代码")
四、总结
FSDP 提供了对超大规模模型训练的全面支持,通过将参数 (P)、梯度 (G) 和优化器状态 (O) 完全分片到集群的所有 GPU 上,将单卡显存的占用从 $O(N)$ 降至 $O(N/W)$ (W为GPU数量),从而彻底解决了单卡显存不足的问题。对于致力于训练或微调大型 Transformer 模型的 AI 工程师而言,掌握 FSDP 是进入大模型时代的关键技能。
汤不热吧