在现代深度学习中,模型和数据集的规模爆炸式增长,使得分布式训练成为常态。PyTorch 的 torch.distributed 包提供了一系列高效的通信原语(Collective Operations),这些原语是实现数据并行(DDP)和模型并行(FSDP)的关键。理解它们之间的差异和适用场景,对于优化训练速度至关重要。
本文将深入对比 PyTorch 中最常用的三个通信原语:AllReduce、AllGather 和 ReduceScatter,并通过实操代码展示其用法。
准备工作:初始化分布式环境
所有分布式操作都依赖于进程组(Process Group)的初始化。我们假设环境已配置好 rank 和 world_size。
import torch
import torch.distributed as dist
import os
def setup(rank, world_size):
# 使用NCCL后端(GPU环境推荐)或Gloo后端
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12355'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
def cleanup():
dist.destroy_process_group()
1. AllReduce:梯度同步的基石
AllReduce 是最常见的原语,它将所有进程上的输入张量进行归约(如求和、求平均),然后将最终结果分发给所有进程。
适用场景: 分布式数据并行(DDP)中的梯度同步。
工作原理
- 输入: $N$ 个进程各自拥有 $T$。
- 输出: $N$ 个进程各自获得 $ ext{Reduce}(T_1, T_2, …, T_N)$。
代码示例
假设我们有两个进程,rank 0 有张量 10,rank 1 有张量 20。我们使用求和操作。
# 假设已在setup函数中初始化环境
# rank = current_process_rank
# world_size = 2
def run_all_reduce(rank, world_size):
setup(rank, world_size)
input_data = torch.tensor([10.0 * (rank + 1)], device=rank)
print(f"Rank {rank} - Initial data: {input_data}")
# 执行 AllReduce (SUM)
dist.all_reduce(input_data, op=dist.ReduceOp.SUM)
# 归约结果:10.0 + 20.0 = 30.0
print(f"Rank {rank} - After AllReduce: {input_data}")
cleanup()
# 运行方式 (通常通过 torchrun 或 multiproc)
# Process 0 Output: Initial data: [10.0], After AllReduce: [30.0]
# Process 1 Output: Initial data: [20.0], After AllReduce: [30.0]
2. AllGather:收集全局信息
AllGather 将所有进程的输入张量收集起来,并将其拼接成一个更大的张量,然后把这个完整的结果分发给所有进程。
适用场景: 需要计算全局统计信息(如 LARS 优化器),或者在 BatchNorm 层中收集所有 GPU 上的统计量。
工作原理
- 输入: $N$ 个进程各自拥有 $T$。
- 输出: $N$ 个进程各自获得 $ ext{Concat}(T_1, T_2, …, T_N)$。
代码示例
每个进程拥有一个 2×2 的张量,world_size = 2。
def run_all_gather(rank, world_size):
setup(rank, world_size)
# 每个进程的输入张量
input_tensor = torch.ones(2, 2, device=rank) * (rank + 1)
# 准备输出列表,大小等于 world_size,用于存放收集到的张量
output_list = [torch.zeros_like(input_tensor) for _ in range(world_size)]
dist.all_gather(output_list, input_tensor)
# 结果是一个包含所有进程数据的列表
# Rank 0 input: [[1., 1.], [1., 1.]]
# Rank 1 input: [[2., 2.], [2., 2.]]
# All processes receive: [tensor(1s), tensor(2s)]
print(f"Rank {rank} - Collected data shape: {[o.shape for o in output_list]}")
print(f"Rank {rank} - First collected item (from rank 0):\n{output_list[0]}")
cleanup()
3. ReduceScatter:高效的模型并行基石
ReduceScatter 结合了 Reduce(归约)和 Scatter(分散)的功能。它首先对所有进程的输入数据列表进行归约,然后将归约后的结果分散回每个进程,每个进程只接收结果的一部分。
适用场景: 大规模模型(如 FSDP)中的高效梯度同步。它允许在归约过程中同时减少通信带宽,因为每个进程只接收其所需的部分。
工作原理
- 输入: $N$ 个进程各自拥有一组数据 $L = [T_1, T_2, …, T_N]$ (通常是已经拼接好的大张量,但在 PyTorch API 中通常输入的是列表)。
- 操作: 计算 $ ext{Reduce}(L_1, L_2, …, L_N)$。
- 输出: 进程 $i$ 获得归约结果的第 $i$ 部分。
代码示例
我们将使用 reduce_scatter_tensor API,它更符合实际操作中对大张量的处理。
def run_reduce_scatter(rank, world_size):
setup(rank, world_size)
# 假设每个进程的输入是一个大张量,包含所有进程所需的数据部分
# 输入大小必须是世界大小的倍数 (此处 4x2 = 8)
# 目标是:对所有进程的输入求和,然后将结果分散给每个进程 (2x2)
input_tensor_full = torch.ones(world_size * 2, 2, device=rank) * (rank + 1)
# 准备输出张量,大小为输入大小 / world_size
output_tensor = torch.zeros(2, 2, device=rank)
# 使用 dist.reduce_scatter_tensor
dist.reduce_scatter_tensor(output_tensor, input_tensor_full)
# 归约求和: 假设 Rank 0 和 Rank 1 都有 [[1,1], [1,1], [1,1], [1,1]] 和 [[2,2], [2,2], [2,2], [2,2]]
# 归约结果 (SUM): [[3,3], [3,3], [3,3], [3,3]]
# Scatter: Rank 0 收到前一半 [[3,3], [3,3]]
# Scatter: Rank 1 收到后一半 [[3,3], [3,3]]
print(f"Rank {rank} - Output (received part of the reduced result):\n{output_tensor}")
cleanup()
总结对比
| 原语 | 功能描述 | 数据流向 | 常见用途 |
|---|---|---|---|
| AllReduce | 归约所有数据,所有进程获得完整结果。 | $N$ 到 1 (Reduce) -> 1 到 $N$ (Broadcast) | DDP 梯度同步 |
| AllGather | 收集所有数据并拼接,所有进程获得完整拼接结果。 | $N$ 到 $N$ (收集/广播) | 收集 Batch Norm 统计量、全局特征 |
| ReduceScatter | 归约所有数据,并将结果分散到每个进程的不同部分。 | $N$ 到 1 (Reduce) -> 1 到 $N$ (Scatter) | FSDP 梯度高效同步 |
在选择通信原语时,核心原则是:如果只需要一个全局统计值(如平均梯度),使用 AllReduce;如果需要所有进程的数据,使用 AllGather;如果需要高效地归约大张量并分散结果的不同部分,使用 ReduceScatter。
汤不热吧